1
use finance::{commodity::Commodity, tag::Tag};
2
use num_rational::Rational64;
3
use sqlx::types::Uuid;
4
use std::{collections::HashMap, fmt::Debug};
5
use supp_macro::command;
6

            
7
use super::{CmdError, CmdResult};
8
use crate::{command::FinanceEntity, config::ConfigError, user::User};
9

            
10
// Fetches a single commodity by id together with its tags.
//
// Returns `CmdResult::TaggedEntities` holding exactly one
// `(FinanceEntity::Commodity, tags)` pair, where `tags` maps each tag name to
// its `FinanceEntity::Tag`. The commodity is loaded with `fetch_one`, so an
// unknown `commodity_id` yields `Err` rather than an empty result.
command! {
    GetCommodity {
        #[required]
        user_id: Uuid,
        #[required]
        commodity_id: Uuid,
    } => {
        let user = User { id: user_id };

        let mut conn = user.get_connection().await.map_err(|err| {
            log::error!("{}", t!("Database error: %{err}", err = err : {:?}));
            ConfigError::DB
        })?;

        let comm = sqlx::query_file_as!(Commodity, "sql/select/commodities/by_id.sql", &commodity_id)
            .fetch_one(&mut *conn)
            .await?;

        // Collect this commodity's tags keyed by tag name. If two rows share a
        // name, the later row overwrites the earlier one (plain `HashMap`
        // collect semantics).
        let tags: HashMap<String, FinanceEntity> =
            sqlx::query_file!("sql/select/tags/by_commodity.sql", &commodity_id)
                .fetch_all(&mut *conn)
                .await?
                .into_iter()
                .map(|row| {
                    (
                        row.tag_name.clone(),
                        FinanceEntity::Tag(Tag {
                            id: row.id,
                            tag_name: row.tag_name,
                            tag_value: row.tag_value,
                            description: row.description,
                        }),
                    )
                })
                .collect();

        // A single lookup produces exactly one tagged entity and no pagination.
        Ok(Some(CmdResult::TaggedEntities {
            entities: vec![(FinanceEntity::Commodity(comm), tags)],
            pagination: None,
        }))
    }
}
55

            
56
// Creates a commodity for `user_id` from the given `symbol` and `name` via
// `User::create_commodity`, and returns the new commodity's id rendered as a
// string.
command! {
    CreateCommodity {
        #[required]
        symbol: String,
        #[required]
        name: String,
        #[required]
        user_id: Uuid,
    } => {
        let owner = User { id: user_id };

        let created = owner.create_commodity(symbol, name).await?;
        let id_text = created.id.to_string();

        Ok(Some(id_text.into()))
    }
}
76

            
77
// Lists every commodity returned by `sql/select/commodities/all.sql` on the
// user's connection, each paired with a map of its tags keyed by tag name.
//
// Tags are loaded with one `by_commodity.sql` query per commodity, in the
// order the commodities were returned.
command! {
    ListCommodities {
        #[required]
        user_id: Uuid,
    } => {
        let owner = User { id: user_id };
        let mut conn = owner.get_connection().await.map_err(|err| {
            log::error!("{}", t!("Database error: %{err}", err = err : {:?}));
            ConfigError::DB
        })?;

        // Load every commodity id first...
        let commodity_rows = sqlx::query_file!("sql/select/commodities/all.sql")
            .fetch_all(&mut *conn)
            .await?;
        let all_commodities: Vec<Commodity> = commodity_rows
            .into_iter()
            .map(|rec| Commodity { id: rec.id })
            .collect();

        // ...then attach each commodity's tags.
        let mut entities = Vec::new();
        for commodity in all_commodities {
            let tag_rows = sqlx::query_file!("sql/select/tags/by_commodity.sql", &commodity.id)
                .fetch_all(&mut *conn)
                .await?;
            let tag_map: HashMap<String, FinanceEntity> = tag_rows
                .into_iter()
                .map(|tag_row| {
                    let key = tag_row.tag_name.clone();
                    let tag = Tag {
                        id: tag_row.id,
                        tag_name: tag_row.tag_name,
                        tag_value: tag_row.tag_value,
                        description: tag_row.description,
                    };
                    (key, FinanceEntity::Tag(tag))
                })
                .collect();

            entities.push((FinanceEntity::Commodity(commodity), tag_map));
        }

        Ok(Some(CmdResult::TaggedEntities {
            entities,
            pagination: None,
        }))
    }
}
125

            
126
// Converts a source-commodity amount into a target-commodity amount
127
// using the most recent Price row that links the two. Looks up the
128
// direct `source -> target` row first; on miss, tries the inverse
129
// `target -> source` row and inverts the ratio.
130
command! {
131
    ConvertCommodity {
132
        #[required]
133
        user_id: Uuid,
134
        #[required]
135
        amount_num: i64,
136
        #[required]
137
        amount_denom: i64,
138
        #[required]
139
        source_commodity_id: Uuid,
140
        #[required]
141
        target_commodity_id: Uuid,
142
    } => {
143
        let user = User { id: user_id };
144
        let mut conn = user.get_connection().await.map_err(|err| {
145
            log::error!("{}", t!("Database error: %{err}", err = err : {:?}));
146
            ConfigError::DB
147
        })?;
148

            
149
        if amount_denom == 0 {
150
            return Err(CmdError::Args(
151
                "convert-commodity: amount has zero denominator".to_string(),
152
            ));
153
        }
154
        let amount = Rational64::new(amount_num, amount_denom);
155

            
156
        if source_commodity_id == target_commodity_id {
157
            return Ok(Some(CmdResult::Rational(amount)));
158
        }
159

            
160
        if let Some(row) = sqlx::query_file!(
161
            "sql/select/prices/latest_between.sql",
162
            &source_commodity_id,
163
            &target_commodity_id,
164
        )
165
        .fetch_optional(&mut *conn)
166
        .await? {
167
            let price = Rational64::new(row.value_num, row.value_denom);
168
            return Ok(Some(CmdResult::Rational(amount * price)));
169
        }
170

            
171
        if let Some(row) = sqlx::query_file!(
172
            "sql/select/prices/latest_between.sql",
173
            &target_commodity_id,
174
            &source_commodity_id,
175
        )
176
        .fetch_optional(&mut *conn)
177
        .await? {
178
            if row.value_num == 0 {
179
                return Err(CmdError::Args(
180
                    "convert-commodity: inverse price has zero numerator".to_string(),
181
                ));
182
            }
183
            let inverse = Rational64::new(row.value_denom, row.value_num);
184
            return Ok(Some(CmdResult::Rational(amount * inverse)));
185
        }
186

            
187
        Err(CmdError::Args(format!(
188
            "convert-commodity: no Price row between {source_commodity_id} and {target_commodity_id}"
189
        )))
190
    }
191
360
}
192

            
193
#[cfg(test)]
mod command_tests {
    use super::*;
    use crate::db::DB_POOL;
    use sqlx::PgPool;
    use supp_macro::local_db_sqlx_test;
    use tokio::sync::OnceCell;

    /// Context for keeping environment intact
    static CONTEXT: OnceCell<()> = OnceCell::const_new();
    /// Test user shared by the tests in this module; initialised by `setup`.
    static USER: OnceCell<User> = OnceCell::const_new();

    /// One-time test initialisation.
    ///
    /// With the `testlog` feature it enables trace-level logging. It also
    /// creates the shared `USER` with a fresh random id.
    ///
    /// NOTE(review): nothing in this module calls `setup` directly. Presumably
    /// `#[local_db_sqlx_test]` does, since every test unwraps `USER`; confirm.
    async fn setup() {
        CONTEXT
            .get_or_init(|| async {
                #[cfg(feature = "testlog")]
                let _ = env_logger::builder()
                    .is_test(true)
                    .filter_level(log::LevelFilter::Trace)
                    .try_init();
            })
            .await;
        USER.get_or_init(|| async { User { id: Uuid::new_v4() } })
            .await;
    }

    /// Asserts that `tags` holds exactly the `symbol` and `name` tags written
    /// by `CreateCommodity`, with the expected values.
    ///
    /// Panics on any other tag name or on a map entry that is not a
    /// `FinanceEntity::Tag`.
    fn assert_commodity_tags(tags: &HashMap<String, FinanceEntity>, symbol: &str, name: &str) {
        assert_eq!(tags.len(), 2, "Expected symbol and name tags");
        for tag in tags.values() {
            if let FinanceEntity::Tag(t) = tag {
                match t.tag_name.as_str() {
                    "symbol" => assert_eq!(t.tag_value, symbol),
                    "name" => assert_eq!(t.tag_value, name),
                    _ => panic!("Unexpected tag: {}", t.tag_name),
                }
            } else {
                panic!("Expected Tag entity in commodity tag map");
            }
        }
    }

    #[local_db_sqlx_test]
    async fn test_list_commodities_empty(pool: PgPool) -> anyhow::Result<()> {
        let user = USER.get().unwrap();
        user.commit()
            .await
            .expect("Failed to commit user to database");

        if let Some(CmdResult::TaggedEntities { entities, .. }) =
            ListCommodities::new().user_id(user.id).run().await?
        {
            assert!(
                entities.is_empty(),
                "Expected no commodities in empty database"
            );
        } else {
            panic!("Expected TaggedEntities result");
        }
    }

    #[local_db_sqlx_test]
    async fn test_list_commodities_with_data(pool: PgPool) -> anyhow::Result<()> {
        let user = USER.get().unwrap();
        user.commit()
            .await
            .expect("Failed to commit user to database");

        // Create a test commodity; CreateCommodity stores symbol and name as tags.
        CreateCommodity::new()
            .symbol("TST".to_string())
            .name("Test Commodity".to_string())
            .user_id(user.id)
            .run()
            .await?;

        // List commodities
        if let Some(CmdResult::TaggedEntities { entities, .. }) =
            ListCommodities::new().user_id(user.id).run().await?
        {
            assert_eq!(entities.len(), 1, "Expected one commodity");

            let (entity, tags) = &entities[0];
            if let FinanceEntity::Commodity(_c) = entity {
                assert_commodity_tags(tags, "TST", "Test Commodity");
            } else {
                panic!("Expected Commodity entity");
            }
        } else {
            panic!("Expected TaggedEntities result");
        }
    }

    #[local_db_sqlx_test]
    async fn test_get_commodity(pool: PgPool) -> anyhow::Result<()> {
        let user = USER.get().unwrap();
        user.commit()
            .await
            .expect("Failed to commit user to database");

        // First create a commodity
        let commodity_result = CreateCommodity::new()
            .symbol("TST".to_string())
            .name("Test Commodity".to_string())
            .user_id(user.id)
            .run()
            .await?;

        // CreateCommodity returns the new id as a string.
        let commodity_id = if let Some(CmdResult::String(id)) = commodity_result {
            Uuid::parse_str(&id)?
        } else {
            panic!("Expected commodity ID string result");
        };

        // Test GetCommodity command
        if let Some(CmdResult::TaggedEntities { entities, .. }) = GetCommodity::new()
            .user_id(user.id)
            .commodity_id(commodity_id)
            .run()
            .await?
        {
            assert_eq!(entities.len(), 1, "Expected one commodity");

            let (entity, tags) = &entities[0];
            if let FinanceEntity::Commodity(c) = entity {
                assert_eq!(c.id, commodity_id);
                assert_commodity_tags(tags, "TST", "Test Commodity");
            } else {
                panic!("Expected Commodity entity");
            }
        } else {
            panic!("Expected TaggedEntities result");
        }

        // A non-existent commodity id must be reported as an error.
        let result = GetCommodity::new()
            .user_id(user.id)
            .commodity_id(Uuid::new_v4())
            .run()
            .await;
        assert!(result.is_err(), "Expected error for non-existent commodity");
    }
}