1
use std::collections::HashMap;
2

            
3
use finance::split::Split;
4
use finance::transaction::Transaction;
5

            
6
use crate::format::{
7
    AccountData, BASE_OFFSET, CommodityData, ContextType, ENTITY_HEADER_SIZE, EntityFlags,
8
    EntityHeader, EntityType, GLOBAL_HEADER_SIZE, GlobalHeader, OUTPUT_HEADER_SIZE, Operation,
9
    OutputHeader, SplitData, TAG_DATA_SIZE, TagData, TransactionData,
10
};
11

            
12
48
fn transaction_to_data(tx: &Transaction) -> TransactionData {
13
48
    TransactionData {
14
48
        post_date: tx.post_date.timestamp_millis(),
15
48
        enter_date: tx.enter_date.timestamp_millis(),
16
48
        split_count: 0,
17
48
        tag_count: 0,
18
48
        is_multi_currency: 0,
19
48
        reserved: [0; 23],
20
48
    }
21
48
}
22

            
23
48
fn split_to_data(split: &Split) -> SplitData {
24
    SplitData {
25
48
        account_id: *split.account_id.as_bytes(),
26
48
        commodity_id: *split.commodity_id.as_bytes(),
27
48
        value_num: split.value_num,
28
48
        value_denom: split.value_denom,
29
48
        reconcile_state: split.reconcile_state.map_or(0, u8::from),
30
48
        reserved: [0; 7],
31
48
        reconcile_date: split
32
48
            .reconcile_date
33
48
            .map_or(0, |d: chrono::DateTime<chrono::Utc>| d.timestamp_millis()),
34
    }
35
48
}
36

            
37
pub struct MemorySerializer {
38
    context_type: ContextType,
39
    primary_entity_type: EntityType,
40
    primary_entity_idx: u32,
41
    entities: Vec<SerializedEntity>,
42
    strings_pool: Vec<u8>,
43
    string_cache: HashMap<String, (u32, u16)>,
44
}
45

            
46
struct SerializedEntity {
47
    header: EntityHeader,
48
    data: Vec<u8>,
49
}
50

            
51
impl Default for MemorySerializer {
52
    fn default() -> Self {
53
        Self::new()
54
    }
55
}
56

            
57
impl MemorySerializer {
58
    #[must_use]
59
58
    pub fn new() -> Self {
60
58
        Self {
61
58
            context_type: ContextType::EntityCreate,
62
58
            primary_entity_type: EntityType::Transaction,
63
58
            primary_entity_idx: 0,
64
58
            entities: Vec::new(),
65
58
            strings_pool: Vec::new(),
66
58
            string_cache: HashMap::new(),
67
58
        }
68
58
    }
69

            
70
57
    pub fn set_context(&mut self, context_type: ContextType, primary_entity_type: EntityType) {
71
57
        self.context_type = context_type;
72
57
        self.primary_entity_type = primary_entity_type;
73
57
    }
74

            
75
57
    pub fn set_primary(&mut self, entity_idx: u32) {
76
57
        self.primary_entity_idx = entity_idx;
77
57
    }
78

            
79
85
    pub fn add_string(&mut self, s: &str) -> (u32, u16) {
80
85
        if let Some(&cached) = self.string_cache.get(s) {
81
1
            return cached;
82
84
        }
83
84
        let offset = self.strings_pool.len() as u32;
84
84
        let len = s.len() as u16;
85
84
        self.strings_pool.extend_from_slice(s.as_bytes());
86
84
        self.string_cache.insert(s.to_string(), (offset, len));
87
84
        (offset, len)
88
85
    }
89

            
90
1
    pub fn add_transaction(
91
1
        &mut self,
92
1
        id: [u8; 16],
93
1
        parent_idx: i32,
94
1
        is_primary: bool,
95
1
        is_context: bool,
96
1
        post_date: i64,
97
1
        enter_date: i64,
98
1
        split_count: u32,
99
1
        tag_count: u32,
100
1
        is_multi_currency: bool,
101
1
    ) -> u32 {
102
1
        let flags = EntityFlags::make(is_primary, is_context);
103
1
        let data = TransactionData {
104
1
            post_date,
105
1
            enter_date,
106
1
            split_count,
107
1
            tag_count,
108
1
            is_multi_currency: u8::from(is_multi_currency),
109
1
            reserved: [0; 23],
110
1
        };
111
1
        let header = EntityHeader::new(
112
1
            EntityType::Transaction,
113
1
            Operation::Nop,
114
1
            flags,
115
1
            id,
116
1
            parent_idx,
117
            0,
118
1
            data.to_bytes().len() as u32,
119
        );
120
1
        let idx = self.entities.len() as u32;
121
1
        self.entities.push(SerializedEntity {
122
1
            header,
123
1
            data: data.to_bytes().to_vec(),
124
1
        });
125
1
        idx
126
1
    }
127

            
128
1
    pub fn add_split(
129
1
        &mut self,
130
1
        id: [u8; 16],
131
1
        parent_idx: i32,
132
1
        is_primary: bool,
133
1
        is_context: bool,
134
1
        account_id: [u8; 16],
135
1
        commodity_id: [u8; 16],
136
1
        value_num: i64,
137
1
        value_denom: i64,
138
1
        reconcile_state: u8,
139
1
        reconcile_date: i64,
140
1
    ) -> u32 {
141
1
        let flags = EntityFlags::make(is_primary, is_context);
142
1
        let data = SplitData {
143
1
            account_id,
144
1
            commodity_id,
145
1
            value_num,
146
1
            value_denom,
147
1
            reconcile_state,
148
1
            reserved: [0; 7],
149
1
            reconcile_date,
150
1
        };
151
1
        let header = EntityHeader::new(
152
1
            EntityType::Split,
153
1
            Operation::Nop,
154
1
            flags,
155
1
            id,
156
1
            parent_idx,
157
            0,
158
1
            data.to_bytes().len() as u32,
159
        );
160
1
        let idx = self.entities.len() as u32;
161
1
        self.entities.push(SerializedEntity {
162
1
            header,
163
1
            data: data.to_bytes().to_vec(),
164
1
        });
165
1
        idx
166
1
    }
167

            
168
33
    pub fn add_tag(
169
33
        &mut self,
170
33
        id: [u8; 16],
171
33
        parent_idx: i32,
172
33
        is_primary: bool,
173
33
        is_context: bool,
174
33
        name: &str,
175
33
        value: &str,
176
33
    ) -> u32 {
177
33
        let flags = EntityFlags::make(is_primary, is_context);
178
33
        let (name_offset, name_len) = self.add_string(name);
179
33
        let (value_offset, value_len) = self.add_string(value);
180
33
        let data = TagData {
181
33
            name_offset,
182
33
            value_offset,
183
33
            name_len,
184
33
            value_len,
185
33
            reserved: [0; 4],
186
33
        };
187
33
        let header = EntityHeader::new(
188
33
            EntityType::Tag,
189
33
            Operation::Nop,
190
33
            flags,
191
33
            id,
192
33
            parent_idx,
193
            0,
194
33
            TAG_DATA_SIZE as u32,
195
        );
196
33
        let idx = self.entities.len() as u32;
197
33
        self.entities.push(SerializedEntity {
198
33
            header,
199
33
            data: data.to_bytes().to_vec(),
200
33
        });
201
33
        idx
202
33
    }
203

            
204
8
    pub fn add_account(
205
8
        &mut self,
206
8
        id: [u8; 16],
207
8
        parent_idx: i32,
208
8
        is_primary: bool,
209
8
        is_context: bool,
210
8
        parent_account_id: [u8; 16],
211
8
        name: &str,
212
8
        path: &str,
213
8
        tag_count: u32,
214
8
    ) -> u32 {
215
8
        let flags = EntityFlags::make(is_primary, is_context);
216
8
        let (name_offset, name_len) = self.add_string(name);
217
8
        let (path_offset, path_len) = self.add_string(path);
218
8
        let data = AccountData {
219
8
            parent_account_id,
220
8
            name_offset,
221
8
            path_offset,
222
8
            tag_count,
223
8
            name_len,
224
8
            path_len,
225
8
            reserved: [0; 16],
226
8
        };
227
8
        let header = EntityHeader::new(
228
8
            EntityType::Account,
229
8
            Operation::Nop,
230
8
            flags,
231
8
            id,
232
8
            parent_idx,
233
            0,
234
8
            data.to_bytes().len() as u32,
235
        );
236
8
        let idx = self.entities.len() as u32;
237
8
        self.entities.push(SerializedEntity {
238
8
            header,
239
8
            data: data.to_bytes().to_vec(),
240
8
        });
241
8
        idx
242
8
    }
243

            
244
    pub fn add_commodity(
245
        &mut self,
246
        id: [u8; 16],
247
        parent_idx: i32,
248
        is_primary: bool,
249
        is_context: bool,
250
        symbol: &str,
251
        name: &str,
252
        fraction: u32,
253
        tag_count: u32,
254
    ) -> u32 {
255
        let flags = EntityFlags::make(is_primary, is_context);
256
        let (symbol_offset, symbol_len) = self.add_string(symbol);
257
        let (name_offset, name_len) = self.add_string(name);
258
        let data = CommodityData {
259
            symbol_offset,
260
            name_offset,
261
            fraction,
262
            tag_count,
263
            symbol_len,
264
            name_len,
265
            reserved: [0; 12],
266
        };
267
        let header = EntityHeader::new(
268
            EntityType::Commodity,
269
            Operation::Nop,
270
            flags,
271
            id,
272
            parent_idx,
273
            0,
274
            data.to_bytes().len() as u32,
275
        );
276
        let idx = self.entities.len() as u32;
277
        self.entities.push(SerializedEntity {
278
            header,
279
            data: data.to_bytes().to_vec(),
280
        });
281
        idx
282
    }
283

            
284
    #[must_use]
285
1
    pub fn entity_count(&self) -> u32 {
286
1
        self.entities.len() as u32
287
1
    }
288

            
289
48
    pub fn add_transaction_from(
290
48
        &mut self,
291
48
        tx: &Transaction,
292
48
        is_primary: bool,
293
48
        split_count: u32,
294
48
        tag_count: u32,
295
48
        is_multi_currency: bool,
296
48
    ) -> u32 {
297
48
        let flags = EntityFlags::make(is_primary, false);
298
48
        let mut data = transaction_to_data(tx);
299
48
        data.split_count = split_count;
300
48
        data.tag_count = tag_count;
301
48
        data.is_multi_currency = u8::from(is_multi_currency);
302
48
        let header = EntityHeader::new(
303
48
            EntityType::Transaction,
304
48
            Operation::Nop,
305
48
            flags,
306
48
            *tx.id.as_bytes(),
307
            -1,
308
            0,
309
48
            data.to_bytes().len() as u32,
310
        );
311
48
        let idx = self.entities.len() as u32;
312
48
        self.entities.push(SerializedEntity {
313
48
            header,
314
48
            data: data.to_bytes().to_vec(),
315
48
        });
316
48
        idx
317
48
    }
318

            
319
48
    pub fn add_split_from(&mut self, split: &Split, parent_idx: i32) -> u32 {
320
48
        let data = split_to_data(split);
321
48
        let header = EntityHeader::new(
322
48
            EntityType::Split,
323
48
            Operation::Nop,
324
            0,
325
48
            *split.id.as_bytes(),
326
48
            parent_idx,
327
            0,
328
48
            data.to_bytes().len() as u32,
329
        );
330
48
        let idx = self.entities.len() as u32;
331
48
        self.entities.push(SerializedEntity {
332
48
            header,
333
48
            data: data.to_bytes().to_vec(),
334
48
        });
335
48
        idx
336
48
    }
337

            
338
    #[must_use]
339
57
    pub fn finalize(mut self, output_size: u32) -> Vec<u8> {
340
57
        let entity_count = self.entities.len() as u32;
341
57
        let entities_offset = BASE_OFFSET + GLOBAL_HEADER_SIZE as u32;
342

            
343
57
        let mut entities_total_size = 0u32;
344
139
        for entity in &self.entities {
345
139
            entities_total_size += ENTITY_HEADER_SIZE as u32 + entity.data.len() as u32;
346
139
        }
347

            
348
57
        let strings_pool_offset = entities_offset + entities_total_size;
349
57
        let strings_pool_size = self.strings_pool.len() as u32;
350
57
        let output_offset = strings_pool_offset + strings_pool_size;
351

            
352
57
        let output_header = OutputHeader::new(entity_count);
353

            
354
57
        let mut global_header = GlobalHeader::new(
355
57
            self.context_type,
356
57
            self.primary_entity_type,
357
57
            entity_count,
358
57
            self.primary_entity_idx,
359
        );
360
57
        global_header.entities_offset = entities_offset;
361
57
        global_header.strings_pool_offset = strings_pool_offset;
362
57
        global_header.strings_pool_size = strings_pool_size;
363
57
        global_header.output_offset = output_offset;
364
57
        global_header.output_size = output_size;
365

            
366
57
        let total_size = GLOBAL_HEADER_SIZE
367
57
            + entities_total_size as usize
368
57
            + strings_pool_size as usize
369
57
            + output_size as usize;
370
57
        let mut buffer = vec![0u8; total_size];
371

            
372
57
        buffer[..GLOBAL_HEADER_SIZE].copy_from_slice(global_header.as_bytes());
373

            
374
57
        let mut current_offset = entities_offset;
375
57
        let mut write_pos = GLOBAL_HEADER_SIZE;
376

            
377
139
        for entity in &mut self.entities {
378
139
            entity.header.data_offset = current_offset + ENTITY_HEADER_SIZE as u32;
379
139
            let header_bytes = entity.header.to_bytes();
380
139
            buffer[write_pos..write_pos + ENTITY_HEADER_SIZE].copy_from_slice(&header_bytes);
381
139
            write_pos += ENTITY_HEADER_SIZE;
382
139
            buffer[write_pos..write_pos + entity.data.len()].copy_from_slice(&entity.data);
383
139
            write_pos += entity.data.len();
384
139
            current_offset += ENTITY_HEADER_SIZE as u32 + entity.data.len() as u32;
385
139
        }
386

            
387
57
        buffer[write_pos..write_pos + self.strings_pool.len()].copy_from_slice(&self.strings_pool);
388
57
        write_pos += self.strings_pool.len();
389

            
390
57
        buffer[write_pos..write_pos + OUTPUT_HEADER_SIZE]
391
57
            .copy_from_slice(&output_header.to_bytes());
392

            
393
57
        buffer
394
57
    }
395
}
396

            
397
#[cfg(test)]
mod tests {
    use super::*;
    use crate::format::MAGIC_NOMI;

    /// Builds a transaction with one split and one tag, then checks the global
    /// header fields, the buffer layout and the strings-pool contents.
    #[test]
    fn test_serializer_basic() {
        let mut serializer = MemorySerializer::new();
        serializer.set_context(ContextType::EntityCreate, EntityType::Transaction);

        let tx_id = [1u8; 16];
        let tx_idx = serializer.add_transaction(tx_id, -1, true, false, 1000, 2000, 2, 1, false);
        serializer.set_primary(tx_idx);

        let split_id = [2u8; 16];
        let account_id = [3u8; 16];
        let commodity_id = [4u8; 16];
        let split_idx = serializer.add_split(
            split_id,
            tx_idx as i32,
            false,
            false,
            account_id,
            commodity_id,
            -5000,
            100,
            0,
            0,
        );

        let tag_id = [5u8; 16];
        let tag_idx = serializer.add_tag(
            tag_id,
            tx_idx as i32,
            false,
            false,
            "note",
            "test transaction",
        );

        // Indices are assigned in insertion order.
        assert_eq!((tx_idx, split_idx, tag_idx), (0, 1, 2));
        assert_eq!(serializer.entity_count(), 3);

        let output_size = 1024;
        let buffer = serializer.finalize(output_size);

        let header = GlobalHeader::from_bytes(&buffer).unwrap();
        assert_eq!(header.magic, MAGIC_NOMI);
        assert_eq!(header.input_entity_count, 3);
        assert_eq!(header.context_type, ContextType::EntityCreate as u8);
        assert_eq!(header.primary_entity_type, EntityType::Transaction as u8);

        // Entities start right after the global header; the output region follows
        // the strings pool and ends the buffer.
        assert_eq!(header.entities_offset, BASE_OFFSET + GLOBAL_HEADER_SIZE as u32);
        assert_eq!(header.strings_pool_size, 20);
        assert_eq!(
            header.output_offset,
            header.strings_pool_offset + header.strings_pool_size
        );
        assert_eq!(header.output_size, output_size);
        assert_eq!(
            buffer.len(),
            (header.output_offset - BASE_OFFSET + output_size) as usize
        );

        // Buffer index 0 corresponds to BASE_OFFSET.
        let pool_start = (header.strings_pool_offset - BASE_OFFSET) as usize;
        assert_eq!(
            &buffer[pool_start..pool_start + 20],
            &b"notetest transaction"[..]
        );
    }

    /// Repeated strings reuse their first location, and new strings are appended.
    #[test]
    fn test_string_deduplication() {
        let mut serializer = MemorySerializer::new();
        let (offset1, len1) = serializer.add_string("test");
        let (offset2, len2) = serializer.add_string("test");
        let (offset3, len3) = serializer.add_string("other");

        assert_eq!((offset1, len1), (0, 4));
        assert_eq!((offset2, len2), (offset1, len1));
        assert_ne!(offset1, offset3);
        // "other" directly follows the single stored copy of "test".
        assert_eq!((offset3, len3), (4, 5));
        assert_eq!(serializer.strings_pool, b"testother");
    }
}