1
use finance::{commodity::Commodity, tag::Tag};
2
use sqlx::types::Uuid;
3
use std::{collections::HashMap, fmt::Debug};
4
use supp_macro::command;
5

            
6
use super::{CmdError, CmdResult};
7
use crate::{command::FinanceEntity, config::ConfigError, user::User};
8

            
9
// GetCommodity: fetch a single commodity by id together with all of its tags.
//
// On success returns `CmdResult::TaggedEntities` holding exactly one entry,
// `(FinanceEntity::Commodity(..), tags)`, where `tags` maps each tag's name to
// its `FinanceEntity::Tag`. `pagination` is always `None`.
//
// Errors:
// - failing to obtain the user's DB connection is logged and mapped to
//   `ConfigError::DB` (then converted by `?`);
// - an unknown `commodity_id` makes `fetch_one` fail (sqlx reports a
//   row-not-found error), which is propagated via `?`.
command! {
    GetCommodity {
        #[required]
        user_id: Uuid,
        #[required]
        commodity_id: Uuid,
    } => {
        let user = User { id: user_id };

        let mut conn = user.get_connection().await.map_err(|err| {
            log::error!("{}", t!("Database error: %{err}", err = err : {:?}));
            ConfigError::DB
        })?;

        let comm = sqlx::query_file_as!(Commodity, "sql/select/commodities/by_id.sql", &commodity_id)
            .fetch_one(&mut *conn)
            .await?;

        // Tags of this single commodity, keyed by tag name. If two rows share
        // a name, the later row overwrites the earlier one (HashMap collect).
        let tags: HashMap<String, FinanceEntity> =
            sqlx::query_file!("sql/select/tags/by_commodity.sql", &commodity_id)
                .fetch_all(&mut *conn)
                .await?
                .into_iter()
                .map(|row| {
                    (
                        row.tag_name.clone(),
                        FinanceEntity::Tag(Tag {
                            id: row.id,
                            tag_name: row.tag_name,
                            tag_value: row.tag_value,
                            description: row.description,
                        }),
                    )
                })
                .collect();

        // Exactly one entity is returned, so build the vector directly.
        Ok(Some(CmdResult::TaggedEntities {
            entities: vec![(FinanceEntity::Commodity(comm), tags)],
            pagination: None,
        }))
    }
}
54

            
55
// CreateCommodity: create a new commodity for the given user from a symbol
// and a display name.
//
// Delegates to `User::create_commodity` and returns the new commodity's id,
// rendered with `to_string()` and converted into the command result via
// `.into()`. Any error from `create_commodity` is propagated via `?`.
command! {
    CreateCommodity {
        #[required]
        symbol: String,
        #[required]
        name: String,
        #[required]
        user_id: Uuid,
    } => {
        let user = User { id: user_id };

        let created = user.create_commodity(symbol, name).await?;
        let id_text = created.id.to_string();

        Ok(Some(id_text.into()))
    }
}
75

            
76
// ListCommodities: list every commodity reachable through the user's
// connection, each paired with its tags.
//
// Returns `CmdResult::TaggedEntities` with one `(FinanceEntity::Commodity, tags)`
// entry per row of `sql/select/commodities/all.sql`, in the order the query
// yields them. `tags` maps tag name -> `FinanceEntity::Tag`. `pagination` is
// always `None`.
//
// NOTE(review): tags are loaded with one query per commodity (N+1 queries);
// a joined query would be cheaper for large catalogs.
command! {
    ListCommodities {
        #[required]
        user_id: Uuid,
    } => {
        let user = User { id: user_id };
        let mut connection = user.get_connection().await.map_err(|err| {
            log::error!("{}", t!("Database error: %{err}", err = err : {:?}));
            ConfigError::DB
        })?;

        // Load the commodity rows first, then project each onto `Commodity`.
        let commodity_rows = sqlx::query_file!("sql/select/commodities/all.sql")
            .fetch_all(&mut *connection)
            .await?;
        let all: Vec<Commodity> = commodity_rows
            .into_iter()
            .map(|rec| Commodity { id: rec.id })
            .collect();

        let mut entities = Vec::new();
        for item in all {
            let tag_rows = sqlx::query_file!("sql/select/tags/by_commodity.sql", &item.id)
                .fetch_all(&mut *connection)
                .await?;

            // Build each Tag first, then key it by a copy of its own name.
            let tags: HashMap<String, FinanceEntity> = tag_rows
                .into_iter()
                .map(|rec| {
                    let tag = Tag {
                        id: rec.id,
                        tag_name: rec.tag_name,
                        tag_value: rec.tag_value,
                        description: rec.description,
                    };
                    (tag.tag_name.clone(), FinanceEntity::Tag(tag))
                })
                .collect();

            entities.push((FinanceEntity::Commodity(item), tags));
        }

        Ok(Some(CmdResult::TaggedEntities {
            entities,
            pagination: None,
        }))
    }
}
124

            
125
#[cfg(test)]
mod command_tests {
    use super::*;
    use crate::db::DB_POOL;
    use sqlx::PgPool;
    use supp_macro::local_db_sqlx_test;
    use tokio::sync::OnceCell;

    /// Context for keeping environment intact
    static CONTEXT: OnceCell<()> = OnceCell::const_new();
    /// Test user shared by every test; initialised once by `setup()`.
    static USER: OnceCell<User> = OnceCell::const_new();

    /// One-time test initialisation: optional logger (feature `testlog`) and
    /// the shared random-id `USER`. Both cells are `get_or_init`, so repeated
    /// calls are no-ops.
    ///
    /// NOTE(review): nothing in this module calls `setup()` explicitly, yet
    /// every test does `USER.get().unwrap()`; presumably
    /// `#[local_db_sqlx_test]` injects the call (and the trailing `Ok(())`,
    /// since the test bodies end in unit expressions) — confirm in `supp_macro`.
    async fn setup() {
        CONTEXT
            .get_or_init(|| async {
                #[cfg(feature = "testlog")]
                let _ = env_logger::builder()
                    .is_test(true)
                    .filter_level(log::LevelFilter::Trace)
                    .try_init();
            })
            .await;
        USER.get_or_init(|| async { User { id: Uuid::new_v4() } })
            .await;
    }

    /// Assert that `tags` holds exactly the `symbol` and `name` tags created
    /// for the "TST" / "Test Commodity" fixture.
    ///
    /// Panics if any value is not a `FinanceEntity::Tag`. Such a value was
    /// previously skipped silently, which let the check pass vacuously.
    /// Also panics if a map key differs from the tag's own name, or if an
    /// unexpected tag appears.
    fn assert_fixture_tags(tags: &HashMap<String, FinanceEntity>) {
        assert_eq!(tags.len(), 2, "Expected symbol and name tags");
        for (key, entity) in tags {
            match entity {
                FinanceEntity::Tag(t) => {
                    assert_eq!(key, &t.tag_name, "Tag map key must equal tag_name");
                    match t.tag_name.as_str() {
                        "symbol" => assert_eq!(t.tag_value, "TST"),
                        "name" => assert_eq!(t.tag_value, "Test Commodity"),
                        _ => panic!("Unexpected tag: {}", t.tag_name),
                    }
                }
                _ => panic!("Expected Tag entity under key {}", key),
            }
        }
    }

    #[local_db_sqlx_test]
    async fn test_list_commodities_empty(pool: PgPool) -> anyhow::Result<()> {
        let user = USER.get().unwrap();
        user.commit()
            .await
            .expect("Failed to commit user to database");

        if let Some(CmdResult::TaggedEntities { entities, .. }) =
            ListCommodities::new().user_id(user.id).run().await?
        {
            assert!(
                entities.is_empty(),
                "Expected no commodities in empty database"
            );
        } else {
            panic!("Expected TaggedEntities result");
        }
    }

    #[local_db_sqlx_test]
    async fn test_list_commodities_with_data(pool: PgPool) -> anyhow::Result<()> {
        let user = USER.get().unwrap();
        user.commit()
            .await
            .expect("Failed to commit user to database");

        // Create a test commodity; creation is expected to attach the
        // `symbol` and `name` tags checked below.
        CreateCommodity::new()
            .symbol("TST".to_string())
            .name("Test Commodity".to_string())
            .user_id(user.id)
            .run()
            .await?;

        // List commodities
        if let Some(CmdResult::TaggedEntities { entities, .. }) =
            ListCommodities::new().user_id(user.id).run().await?
        {
            assert_eq!(entities.len(), 1, "Expected one commodity");

            let (entity, tags) = &entities[0];
            if let FinanceEntity::Commodity(_c) = entity {
                assert_fixture_tags(tags);
            } else {
                panic!("Expected Commodity entity");
            }
        } else {
            panic!("Expected TaggedEntities result");
        }
    }

    #[local_db_sqlx_test]
    async fn test_get_commodity(pool: PgPool) -> anyhow::Result<()> {
        let user = USER.get().unwrap();
        user.commit()
            .await
            .expect("Failed to commit user to database");

        // First create a commodity
        let commodity_result = CreateCommodity::new()
            .symbol("TST".to_string())
            .name("Test Commodity".to_string())
            .user_id(user.id)
            .run()
            .await?;

        // CreateCommodity returns the new id as a string result.
        let commodity_id = if let Some(CmdResult::String(id)) = commodity_result {
            Uuid::parse_str(&id)?
        } else {
            panic!("Expected commodity ID string result");
        };

        // Test GetCommodity command
        if let Some(CmdResult::TaggedEntities { entities, .. }) = GetCommodity::new()
            .user_id(user.id)
            .commodity_id(commodity_id)
            .run()
            .await?
        {
            assert_eq!(entities.len(), 1, "Expected one commodity");

            let (entity, tags) = &entities[0];
            if let FinanceEntity::Commodity(c) = entity {
                assert_eq!(c.id, commodity_id);
                assert_fixture_tags(tags);
            } else {
                panic!("Expected Commodity entity");
            }
        } else {
            panic!("Expected TaggedEntities result");
        }

        // An unknown id must surface as an error, not as an empty result.
        let result = GetCommodity::new()
            .user_id(user.id)
            .commodity_id(Uuid::new_v4())
            .run()
            .await;
        assert!(result.is_err(), "Expected error for non-existent commodity");
    }
}