1
use finance::{commodity::Commodity, tag::Tag};
2
use num_rational::Rational64;
3
use sqlx::types::Uuid;
4
use std::{collections::HashMap, fmt::Debug};
5
use supp_macro::command;
6

            
7
use super::{CmdError, CmdResult};
8
use crate::{command::FinanceEntity, config::ConfigError, user::User};
9

            
10
// Fetches a single commodity by id for the given user and returns it, paired with
// its tags, as a one-element `CmdResult::TaggedEntities`.
//
// Errors:
// - the user's database connection cannot be obtained (logged, mapped to `ConfigError::DB`);
// - no commodity row matches `commodity_id` (`fetch_one` reports `RowNotFound`, propagated by `?`);
// - the tag query fails.
command! {
    GetCommodity {
        #[required]
        user_id: Uuid,
        #[required]
        commodity_id: Uuid,
    } => {
        let user = User { id: user_id };

        let mut conn = user.get_connection().await.map_err(|err| {
            log::error!("{}", t!("Database error: %{err}", err = err : {:?}));
            ConfigError::DB
        })?;

        // `fetch_one` fails when the query returns no row, so an unknown id
        // surfaces as an error rather than an empty result.
        let comm = sqlx::query_file_as!(Commodity, "sql/select/commodities/by_id.sql", &commodity_id)
            .fetch_one(&mut *conn)
            .await?;

        // Tags of this commodity, keyed by tag name. If several rows share a
        // tag name, the last one collected wins (HashMap insertion semantics).
        let tags: HashMap<String, FinanceEntity> =
            sqlx::query_file!("sql/select/tags/by_commodity.sql", &commodity_id)
                .fetch_all(&mut *conn)
                .await?
                .into_iter()
                .map(|row| {
                    (
                        row.tag_name.clone(),
                        FinanceEntity::Tag(Tag {
                            id: row.id,
                            tag_name: row.tag_name,
                            tag_value: row.tag_value,
                            description: row.description,
                        }),
                    )
                })
                .collect();

        // Same result shape as `ListCommodities`, holding exactly one entry.
        Ok(Some(CmdResult::TaggedEntities(vec![(
            FinanceEntity::Commodity(comm),
            tags,
        )])))
    }
}
52

            
53
// Creates a commodity for the given user and returns its new id as a string
// result.
//
// `fraction` is handed to `User::create_commodity` as an integer via
// `Rational64::to_integer`, which rounds toward zero. A non-integer value is
// therefore truncated. That is logged as a warning instead of happening silently.
command! {
    CreateCommodity {
        #[required]
        fraction: Rational64,
        #[required]
        symbol: String,
        #[required]
        name: String,
        #[required]
        user_id: Uuid,
    } => {
        let user = User { id: user_id };

        // Surface lossy conversions. The stored value is still the truncated
        // integer, exactly as before.
        if !fraction.is_integer() {
            log::warn!(
                "Commodity fraction {} is not an integer; storing truncated value {}",
                fraction,
                fraction.to_integer()
            );
        }

        Ok(Some(
            user.create_commodity(fraction.to_integer(), symbol, name)
                .await?
                .id
                .to_string()
                .into(),
        ))
    }
}
75

            
76
// Lists every commodity returned by `sql/select/commodities/all.sql` on the
// user's connection. Each commodity is paired with a map of its tags keyed by
// tag name, and the result is returned as `CmdResult::TaggedEntities`.
//
// Errors:
// - the connection cannot be obtained (logged, mapped to `ConfigError::DB`);
// - any of the queries fails.
command! {
    ListCommodities {
        #[required]
        user_id: Uuid,
    } => {
        let user = User { id: user_id };
        let mut conn = user.get_connection().await.map_err(|err| {
            log::error!("{}", t!("Database error: %{err}", err = err : {:?}));
            ConfigError::DB
        })?;

        // Load all commodities; only `id` and `fraction` are taken from each row.
        let commodities: Vec<Commodity> = sqlx::query_file!("sql/select/commodities/all.sql")
            .fetch_all(&mut *conn)
            .await?
            .into_iter()
            .map(|row| Commodity {
                id: row.id,
                fraction: row.fraction,
            })
            .collect();

        // One result entry per commodity, so reserve the exact size once.
        // NOTE(review): this issues one tag query per commodity (N+1 round-trips);
        // a single joined query would scale better for large commodity lists.
        let mut tagged_entities = Vec::with_capacity(commodities.len());
        for commodity in commodities {
            // Tags keyed by name; duplicate names keep the last row collected.
            let tags: HashMap<String, FinanceEntity> =
                sqlx::query_file!("sql/select/tags/by_commodity.sql", &commodity.id)
                    .fetch_all(&mut *conn)
                    .await?
                    .into_iter()
                    .map(|row| {
                        (
                            row.tag_name.clone(),
                            FinanceEntity::Tag(Tag {
                                id: row.id,
                                tag_name: row.tag_name,
                                tag_value: row.tag_value,
                                description: row.description,
                            }),
                        )
                    })
                    .collect();

            tagged_entities.push((FinanceEntity::Commodity(commodity), tags));
        }
        Ok(Some(CmdResult::TaggedEntities(tagged_entities)))
    }
}
124

            
125
#[cfg(test)]
mod command_tests {
    use super::*;
    use crate::db::DB_POOL;
    use sqlx::PgPool;
    use supp_macro::local_db_sqlx_test;
    use tokio::sync::OnceCell;

    /// Context for keeping environment intact
    static CONTEXT: OnceCell<()> = OnceCell::const_new();
    static USER: OnceCell<User> = OnceCell::const_new();

    // One-time test initialisation: optionally installs a trace-level test
    // logger (feature `testlog`) and creates the shared random-id test user.
    // NOTE(review): not called directly in this module; presumably invoked by
    // `#[local_db_sqlx_test]` — confirm before renaming.
    async fn setup() {
        CONTEXT
            .get_or_init(|| async {
                #[cfg(feature = "testlog")]
                let _ = env_logger::builder()
                    .is_test(true)
                    .filter_level(log::LevelFilter::Trace)
                    .try_init();
            })
            .await;
        USER.get_or_init(|| async { User { id: Uuid::new_v4() } })
            .await;
    }

    // Checks that `tags` holds exactly the `symbol` / `name` tags produced by
    // creating the "TST" / "Test Commodity" fixture used in these tests.
    fn assert_fixture_tags(tags: &HashMap<String, FinanceEntity>) {
        assert_eq!(tags.len(), 2); // symbol and name tags
        let tag_structs = tags.values().filter_map(|entity| match entity {
            FinanceEntity::Tag(t) => Some(t),
            _ => None,
        });
        for t in tag_structs {
            match t.tag_name.as_str() {
                "symbol" => assert_eq!(t.tag_value, "TST"),
                "name" => assert_eq!(t.tag_value, "Test Commodity"),
                _ => panic!("Unexpected tag: {}", t.tag_name),
            }
        }
    }

    #[local_db_sqlx_test]
    async fn test_list_commodities_empty(pool: PgPool) -> anyhow::Result<()> {
        let user = USER.get().unwrap();
        user.commit()
            .await
            .expect("Failed to commit user to database");

        let listed = ListCommodities::new().user_id(user.id).run().await?;
        let entities = match listed {
            Some(CmdResult::TaggedEntities(entities)) => entities,
            _ => panic!("Expected TaggedEntities result"),
        };
        assert!(
            entities.is_empty(),
            "Expected no commodities in empty database"
        );
    }

    #[local_db_sqlx_test]
    async fn test_list_commodities_with_data(pool: PgPool) -> anyhow::Result<()> {
        let user = USER.get().unwrap();
        user.commit()
            .await
            .expect("Failed to commit user to database");

        // Seed one commodity; creation attaches `symbol` and `name` tags.
        CreateCommodity::new()
            .fraction(1.into())
            .symbol("TST".to_string())
            .name("Test Commodity".to_string())
            .user_id(user.id)
            .run()
            .await?;

        let listed = ListCommodities::new().user_id(user.id).run().await?;
        let entities = match listed {
            Some(CmdResult::TaggedEntities(entities)) => entities,
            _ => panic!("Expected TaggedEntities result"),
        };
        assert_eq!(entities.len(), 1, "Expected one commodity");

        let (entity, tags) = &entities[0];
        match entity {
            FinanceEntity::Commodity(c) => {
                assert_eq!(c.fraction, 1);
                assert_fixture_tags(tags);
            }
            _ => panic!("Expected Commodity entity"),
        }
    }

    #[local_db_sqlx_test]
    async fn test_get_commodity(pool: PgPool) -> anyhow::Result<()> {
        let user = USER.get().unwrap();
        user.commit()
            .await
            .expect("Failed to commit user to database");

        // Create the commodity whose id we will look up.
        let created = CreateCommodity::new()
            .fraction(1.into())
            .symbol("TST".to_string())
            .name("Test Commodity".to_string())
            .user_id(user.id)
            .run()
            .await?;

        let commodity_id = match created {
            Some(CmdResult::String(id)) => uuid::Uuid::parse_str(&id)?,
            _ => panic!("Expected commodity ID string result"),
        };

        // Look the commodity up again through GetCommodity.
        let fetched = GetCommodity::new()
            .user_id(user.id)
            .commodity_id(commodity_id)
            .run()
            .await?;
        let entities = match fetched {
            Some(CmdResult::TaggedEntities(entities)) => entities,
            _ => panic!("Expected TaggedEntities result"),
        };
        assert_eq!(entities.len(), 1, "Expected one commodity");

        let (entity, tags) = &entities[0];
        match entity {
            FinanceEntity::Commodity(c) => {
                assert_eq!(c.id, commodity_id);
                assert_eq!(c.fraction, 1);
                assert_fixture_tags(tags);
            }
            _ => panic!("Expected Commodity entity"),
        }

        // An id that was never created must yield an error, not an empty result.
        let missing = GetCommodity::new()
            .user_id(user.id)
            .commodity_id(Uuid::new_v4())
            .run()
            .await;
        assert!(missing.is_err(), "Expected error for non-existent commodity");
    }
}