1
use finance::{commodity::Commodity, tag::Tag};
2
use num_rational::Rational64;
3
use sqlx::types::Uuid;
4
use std::{collections::HashMap, fmt::Debug};
5
use supp_macro::command;
6

            
7
use super::{CmdError, CmdResult};
8
use crate::{command::FinanceEntity, config::ConfigError, user::User};
9

            
10
// Looks up a single commodity by id and returns it together with its tags as a
// one-element `CmdResult::TaggedEntities` (no pagination).
//
// Arguments:
//   user_id      - id of the user whose database connection is used.
//   commodity_id - id of the commodity to fetch.
//
// Errors:
//   - `ConfigError::DB` when the connection cannot be obtained (the underlying
//     error is only logged).
//   - Any error from the two queries is propagated with `?`; `fetch_one` fails
//     when no row matches `commodity_id`.
command! {
    GetCommodity {
        #[required]
        user_id: Uuid,
        #[required]
        commodity_id: Uuid,
    } => {
        let user = User { id: user_id };

        // Log the concrete failure here; callers only get the generic DB error.
        let mut conn = user.get_connection().await.map_err(|err| {
            log::error!("{}", t!("Database error: %{err}", err = err : {:?}));
            ConfigError::DB
        })?;

        let comm = sqlx::query_file_as!(
            Commodity,
            "sql/select/commodities/by_id.sql",
            &commodity_id
        )
        .fetch_one(&mut *conn)
        .await?;

        // Gather this commodity's tags, keyed by tag name. Rows sharing a tag
        // name collapse into one map entry (the last row wins).
        let tags: HashMap<String, FinanceEntity> =
            sqlx::query_file!("sql/select/tags/by_commodity.sql", &commodity_id)
                .fetch_all(&mut *conn)
                .await?
                .into_iter()
                .map(|row| {
                    (
                        row.tag_name.clone(),
                        FinanceEntity::Tag(Tag {
                            id: row.id,
                            tag_name: row.tag_name,
                            tag_value: row.tag_value,
                            description: row.description,
                        }),
                    )
                })
                .collect();

        // Exactly one entity is returned, so build the vector directly instead
        // of creating an empty one and pushing into it.
        Ok(Some(CmdResult::TaggedEntities {
            entities: vec![(FinanceEntity::Commodity(comm), tags)],
            pagination: None,
        }))
    }
}
55

            
56
// Creates a commodity for the given user and returns the id of the created
// commodity rendered as a string.
//
// Arguments:
//   fraction - commodity fraction; only `fraction.to_integer()` is forwarded.
//   symbol   - commodity symbol, passed through to `User::create_commodity`.
//   name     - commodity name, passed through to `User::create_commodity`.
//   user_id  - id of the user that creates the commodity.
//
// NOTE(review): `to_integer()` discards any fractional part of `fraction` —
// confirm that this is the intended handling of non-integer input.
command! {
    CreateCommodity {
        #[required]
        fraction: Rational64,
        #[required]
        symbol: String,
        #[required]
        name: String,
        #[required]
        user_id: Uuid,
    } => {
        let owner = User { id: user_id };

        // Errors from creation are propagated unchanged via `?`.
        let created = owner
            .create_commodity(fraction.to_integer(), symbol, name)
            .await?;
        let new_id = created.id.to_string();

        Ok(Some(new_id.into()))
    }
}
78

            
79
// Lists every commodity returned by `sql/select/commodities/all.sql`, each one
// paired with its tags, as `CmdResult::TaggedEntities` without pagination.
//
// Arguments:
//   user_id - id of the user whose database connection is used.
//
// Errors:
//   - `ConfigError::DB` when the connection cannot be obtained (the underlying
//     error is only logged).
//   - Any query error is propagated with `?`.
command! {
    ListCommodities {
        #[required]
        user_id: Uuid,
    } => {
        let owner = User { id: user_id };
        let mut db = owner.get_connection().await.map_err(|err| {
            log::error!("{}", t!("Database error: %{err}", err = err : {:?}));
            ConfigError::DB
        })?;

        // Fetch all commodity rows up front; the connection is reused below
        // for the per-commodity tag lookups.
        let commodity_rows = sqlx::query_file!("sql/select/commodities/all.sql")
            .fetch_all(&mut *db)
            .await?;

        // One tag query is issued per commodity.
        let mut entities = Vec::with_capacity(commodity_rows.len());
        for commodity_row in commodity_rows {
            let commodity = Commodity {
                id: commodity_row.id,
                fraction: commodity_row.fraction,
            };

            let tag_rows = sqlx::query_file!("sql/select/tags/by_commodity.sql", &commodity.id)
                .fetch_all(&mut *db)
                .await?;

            // Key the tags by name; a repeated name replaces the earlier entry.
            let mut tags: HashMap<String, FinanceEntity> =
                HashMap::with_capacity(tag_rows.len());
            for tag_row in tag_rows {
                tags.insert(
                    tag_row.tag_name.clone(),
                    FinanceEntity::Tag(Tag {
                        id: tag_row.id,
                        tag_name: tag_row.tag_name,
                        tag_value: tag_row.tag_value,
                        description: tag_row.description,
                    }),
                );
            }

            entities.push((FinanceEntity::Commodity(commodity), tags));
        }

        Ok(Some(CmdResult::TaggedEntities {
            entities,
            pagination: None,
        }))
    }
}
130

            
131
#[cfg(test)]
mod command_tests {
    use super::*;
    use crate::db::DB_POOL;
    use sqlx::PgPool;
    use supp_macro::local_db_sqlx_test;
    use tokio::sync::OnceCell;

    /// Context for keeping environment intact
    static CONTEXT: OnceCell<()> = OnceCell::const_new();
    /// User shared by all tests in this module; created once with a random id.
    static USER: OnceCell<User> = OnceCell::const_new();

    /// Initialises the shared test state exactly once: optional trace logging
    /// (only with the `testlog` feature) and the shared [`USER`].
    ///
    /// NOTE(review): nothing in this module calls `setup` directly; presumably
    /// `#[local_db_sqlx_test]` invokes it before each test body — confirm.
    async fn setup() {
        CONTEXT
            .get_or_init(|| async {
                #[cfg(feature = "testlog")]
                let _ = env_logger::builder()
                    .is_test(true)
                    .filter_level(log::LevelFilter::Trace)
                    .try_init();
            })
            .await;
        USER.get_or_init(|| async { User { id: Uuid::new_v4() } })
            .await;
    }

    /// Asserts that `tags` holds exactly the `symbol` ("TST") and `name`
    /// ("Test Commodity") tags created by the tests below.
    ///
    /// Panics on any entry that is not a `FinanceEntity::Tag`, so an entity of
    /// the wrong kind cannot slip through the check unnoticed.
    fn assert_test_commodity_tags(tags: &HashMap<String, FinanceEntity>) {
        assert_eq!(tags.len(), 2); // symbol and name tags
        for tag in tags.values() {
            match tag {
                FinanceEntity::Tag(t) => match t.tag_name.as_str() {
                    "symbol" => assert_eq!(t.tag_value, "TST"),
                    "name" => assert_eq!(t.tag_value, "Test Commodity"),
                    _ => panic!("Unexpected tag: {}", t.tag_name),
                },
                _ => panic!("Expected Tag entity in tag map"),
            }
        }
    }

    /// Listing commodities when none were created yields an empty entity list.
    #[local_db_sqlx_test]
    async fn test_list_commodities_empty(pool: PgPool) -> anyhow::Result<()> {
        let user = USER.get().unwrap();
        user.commit()
            .await
            .expect("Failed to commit user to database");

        if let Some(CmdResult::TaggedEntities { entities, .. }) =
            ListCommodities::new().user_id(user.id).run().await?
        {
            assert!(
                entities.is_empty(),
                "Expected no commodities in empty database"
            );
        } else {
            panic!("Expected TaggedEntities result");
        }
    }

    /// After creating one commodity, listing returns it with its two tags.
    #[local_db_sqlx_test]
    async fn test_list_commodities_with_data(pool: PgPool) -> anyhow::Result<()> {
        let user = USER.get().unwrap();
        user.commit()
            .await
            .expect("Failed to commit user to database");

        // Create a test commodity with tags
        CreateCommodity::new()
            .fraction(1.into())
            .symbol("TST".to_string())
            .name("Test Commodity".to_string())
            .user_id(user.id)
            .run()
            .await?;

        // List commodities
        if let Some(CmdResult::TaggedEntities { entities, .. }) =
            ListCommodities::new().user_id(user.id).run().await?
        {
            assert_eq!(entities.len(), 1, "Expected one commodity");

            let (entity, tags) = &entities[0];
            if let FinanceEntity::Commodity(c) = entity {
                assert_eq!(c.fraction, 1);

                // Check tags
                assert_test_commodity_tags(tags);
            } else {
                panic!("Expected Commodity entity");
            }
        } else {
            panic!("Expected TaggedEntities result");
        }
    }

    /// `GetCommodity` returns the created commodity with its tags and fails for
    /// an id that does not exist.
    #[local_db_sqlx_test]
    async fn test_get_commodity(pool: PgPool) -> anyhow::Result<()> {
        let user = USER.get().unwrap();
        user.commit()
            .await
            .expect("Failed to commit user to database");

        // First create a commodity
        let commodity_result = CreateCommodity::new()
            .fraction(1.into())
            .symbol("TST".to_string())
            .name("Test Commodity".to_string())
            .user_id(user.id)
            .run()
            .await?;

        // Get the commodity ID
        let commodity_id = if let Some(CmdResult::String(id)) = commodity_result {
            uuid::Uuid::parse_str(&id)?
        } else {
            panic!("Expected commodity ID string result");
        };

        // Test GetCommodity command
        if let Some(CmdResult::TaggedEntities { entities, .. }) = GetCommodity::new()
            .user_id(user.id)
            .commodity_id(commodity_id)
            .run()
            .await?
        {
            assert_eq!(entities.len(), 1, "Expected one commodity");

            let (entity, tags) = &entities[0];
            if let FinanceEntity::Commodity(c) = entity {
                assert_eq!(c.id, commodity_id);
                assert_eq!(c.fraction, 1);

                // Check tags
                assert_test_commodity_tags(tags);
            } else {
                panic!("Expected Commodity entity");
            }
        } else {
            panic!("Expected TaggedEntities result");
        }

        // Test with non-existent commodity ID
        let result = GetCommodity::new()
            .user_id(user.id)
            .commodity_id(Uuid::new_v4())
            .run()
            .await;
        assert!(result.is_err(), "Expected error for non-existent commodity");
    }
}