1
use chrono::{DateTime, Utc};
2
use num_rational::Rational64;
3
use sqlx::types::Uuid;
4
use std::collections::HashMap;
5
use supp_macro::command;
6

            
7
use crate::{config::ConfigError, user::User};
8

            
9
use super::super::{
10
    BreakdownData, BreakdownPeriod, BreakdownRow, BreakdownSort, CmdError, CmdResult,
11
    CommodityAmount, FilterEntity, PeriodGrouping, ReportFilter, ReportMeta, UNCATEGORIZED_KEY,
12
};
13
use super::fetch::{
14
    BreakdownSplit, fetch_date_range_breakdown_filtered_no_conversion,
15
    fetch_date_range_breakdown_filtered_with_conversion, fetch_target_symbol,
16
};
17
use super::period::{
18
    generate_month_boundaries, generate_quarter_boundaries, generate_year_boundaries,
19
};
20

            
21
/// Tag name handed to the fetch helpers when the command is invoked without `tag_name`.
const DEFAULT_TAG_NAME: &str = "category";
22

            
23
/// Default scope filter: only count splits that live on income or expense
24
/// accounts. Without this, the balancing asset/liability leg of every
25
/// transaction lands in `Uncategorized` and every balanced transaction
26
/// sums to zero across the breakdown.
27
///
28
// TODO(script-classifier): mirror ActivityReport's future script hook —
29
// swap this fixed TagIn for a user-supplied classifier expression.
30
12
fn default_scope_filter() -> ReportFilter {
31
12
    ReportFilter::TagIn {
32
12
        entity: FilterEntity::Account,
33
12
        name: "type".to_owned(),
34
12
        values: vec!["income".to_owned(), "expense".to_owned()],
35
12
    }
36
12
}
37

            
38
12
fn compose_scope(user_filter: Option<&ReportFilter>) -> ReportFilter {
39
12
    match user_filter {
40
1
        Some(f) => ReportFilter::And(vec![default_scope_filter(), f.clone()]),
41
11
        None => default_scope_filter(),
42
    }
43
12
}
44

            
45
1
fn boundaries_for(
46
1
    grouping: PeriodGrouping,
47
1
    from: DateTime<Utc>,
48
1
    to: DateTime<Utc>,
49
1
) -> Vec<(String, DateTime<Utc>, DateTime<Utc>)> {
50
1
    match grouping {
51
1
        PeriodGrouping::Month => generate_month_boundaries(from, to),
52
        PeriodGrouping::Quarter => generate_quarter_boundaries(from, to),
53
        PeriodGrouping::Year => generate_year_boundaries(from, to),
54
    }
55
1
}
56

            
57
13
async fn fetch_splits(
58
13
    conn: &mut sqlx::PgConnection,
59
13
    target_commodity_id: Option<Uuid>,
60
13
    target_symbol: Option<&str>,
61
13
    from: DateTime<Utc>,
62
13
    to: DateTime<Utc>,
63
13
    tag_name: &str,
64
13
    filter: &ReportFilter,
65
13
) -> Result<Vec<BreakdownSplit>, CmdError> {
66
13
    match (target_commodity_id, target_symbol) {
67
2
        (Some(tid), Some(sym)) => {
68
2
            fetch_date_range_breakdown_filtered_with_conversion(
69
2
                conn, tid, sym, from, to, tag_name, filter,
70
2
            )
71
2
            .await
72
        }
73
        _ => {
74
11
            fetch_date_range_breakdown_filtered_no_conversion(conn, from, to, tag_name, filter)
75
11
                .await
76
        }
77
    }
78
13
}
79

            
80
/// Per-commodity accumulator used by `aggregate`: commodity id mapped to
/// `(summed value, commodity symbol)`.
type AmountByCommodity = HashMap<Uuid, (Rational64, String)>;
81

            
82
13
fn aggregate(splits: Vec<BreakdownSplit>, include_uncategorized: bool) -> Vec<BreakdownRow> {
83
13
    let mut by_tag: HashMap<String, AmountByCommodity> = HashMap::new();
84

            
85
23
    for s in splits {
86
23
        let key = s.category.unwrap_or_else(|| UNCATEGORIZED_KEY.to_owned());
87
23
        by_tag
88
23
            .entry(key)
89
23
            .or_default()
90
23
            .entry(s.commodity_id)
91
23
            .and_modify(|(sum, _)| *sum += s.value)
92
23
            .or_insert_with(|| (s.value, s.commodity_symbol));
93
    }
94

            
95
13
    by_tag
96
13
        .into_iter()
97
22
        .filter(|(k, _)| include_uncategorized || k != UNCATEGORIZED_KEY)
98
21
        .map(|(tag_value, amounts)| {
99
21
            let is_uncategorized = tag_value == UNCATEGORIZED_KEY;
100
21
            let mut amounts: Vec<CommodityAmount> = amounts
101
21
                .into_iter()
102
21
                .map(
103
                    |(commodity_id, (amount, commodity_symbol))| CommodityAmount {
104
21
                        commodity_id,
105
21
                        commodity_symbol,
106
21
                        amount,
107
21
                    },
108
                )
109
21
                .collect();
110
21
            amounts.sort_by_key(|a| a.commodity_id);
111
21
            BreakdownRow {
112
21
                tag_value,
113
21
                is_uncategorized,
114
21
                amounts,
115
21
            }
116
21
        })
117
13
        .collect()
118
13
}
119

            
120
16
fn row_sort_key(row: &BreakdownRow, target_commodity_id: Option<Uuid>) -> Rational64 {
121
16
    match target_commodity_id {
122
12
        Some(tid) => row
123
12
            .amounts
124
12
            .iter()
125
12
            .find(|a| a.commodity_id == tid)
126
12
            .map_or(Rational64::new(0, 1), |a| a.amount),
127
4
        None => row
128
4
            .amounts
129
4
            .iter()
130
4
            .map(|a| {
131
4
                if a.amount < Rational64::new(0, 1) {
132
                    -a.amount
133
                } else {
134
4
                    a.amount
135
                }
136
4
            })
137
4
            .fold(Rational64::new(0, 1), |acc, v| acc + v),
138
    }
139
16
}
140

            
141
13
fn sort_rows(rows: &mut [BreakdownRow], order: BreakdownSort, target_commodity_id: Option<Uuid>) {
142
13
    match order {
143
11
        BreakdownSort::AmountDesc => rows.sort_by(|a, b| {
144
5
            row_sort_key(b, target_commodity_id).cmp(&row_sort_key(a, target_commodity_id))
145
5
        }),
146
3
        BreakdownSort::AmountAsc => rows.sort_by(|a, b| {
147
3
            row_sort_key(a, target_commodity_id).cmp(&row_sort_key(b, target_commodity_id))
148
3
        }),
149
3
        BreakdownSort::NameAsc => rows.sort_by(|a, b| a.tag_value.cmp(&b.tag_value)),
150
        BreakdownSort::NameDesc => rows.sort_by(|a, b| b.tag_value.cmp(&a.tag_value)),
151
    }
152
13
}
153

            
154
// CategoryBreakdown: totals the user's splits per tag value between
// `date_from` and `date_to`, restricted to the default income/expense scope
// (AND-ed with `report_filter` when given).
//
// Optional arguments and their defaults:
//   * `tag_name`              - DEFAULT_TAG_NAME when absent.
//   * `target_commodity_id`   - when set, its symbol is looked up and the
//                               converting fetch is used; it also selects
//                               the amount used for amount-based sorting.
//   * `period_grouping`       - when set and it yields boundaries, one
//                               labelled period is produced per boundary;
//                               otherwise a single unlabelled period covers
//                               the whole range.
//   * `sort_order`            - `BreakdownSort::default()` when absent.
//   * `include_uncategorized` - `true` when absent.
//   * `result_script`         - accepted but currently ignored.
command! {
    CategoryBreakdown {
        #[required]
        user_id: Uuid,
        #[required]
        date_from: DateTime<Utc>,
        #[required]
        date_to: DateTime<Utc>,
        #[optional]
        tag_name: String,
        #[optional]
        target_commodity_id: Uuid,
        #[optional]
        period_grouping: PeriodGrouping,
        #[optional]
        report_filter: ReportFilter,
        #[optional]
        sort_order: BreakdownSort,
        #[optional]
        include_uncategorized: bool,
        // TODO(result-script): accept a nomiscript expression to filter/reshape
        // BreakdownRow entries before returning. The format is not yet
        // designed; see doc/reporting.org "Future work" for the intent.
        #[optional]
        result_script: String,
    } => {
        let user = User { id: user_id };
        let mut conn = user.get_connection().await.map_err(|err| {
            log::error!("{}", t!("Database error: %{err}", err = err : {:?}));
            ConfigError::DB
        })?;

        // Resolve optional arguments to their effective values.
        let tag_name = tag_name.unwrap_or_else(|| DEFAULT_TAG_NAME.to_owned());
        let include_uncategorized = include_uncategorized.unwrap_or(true);
        let sort_order = sort_order.unwrap_or_default();
        // Not implemented yet; consumed explicitly so the argument is not unused.
        let _ = result_script;
        let scoped_filter = compose_scope(report_filter.as_ref());

        // The symbol is only needed (and only fetched) when converting.
        let target_symbol = match target_commodity_id {
            Some(tid) => Some(fetch_target_symbol(&mut conn, tid).await?),
            None => None,
        };

        let boundaries = match period_grouping {
            Some(grouping) => boundaries_for(grouping, date_from, date_to),
            None => Vec::new(),
        };

        // One entry per period to report. With no grouping, or a grouping
        // that produced no boundaries, report the whole range unlabelled.
        let ranges: Vec<(Option<String>, DateTime<Utc>, DateTime<Utc>)> =
            if boundaries.is_empty() {
                vec![(None, date_from, date_to)]
            } else {
                boundaries
                    .into_iter()
                    .map(|(label, start, end)| (Some(label), start, end))
                    .collect()
            };

        let mut periods = Vec::with_capacity(ranges.len());
        for (label, range_from, range_to) in ranges {
            let splits = fetch_splits(
                &mut conn,
                target_commodity_id,
                target_symbol.as_deref(),
                range_from,
                range_to,
                &tag_name,
                &scoped_filter,
            )
            .await?;
            let mut rows = aggregate(splits, include_uncategorized);
            sort_rows(&mut rows, sort_order, target_commodity_id);
            periods.push(BreakdownPeriod { label, rows });
        }

        Ok(Some(CmdResult::Breakdown(BreakdownData {
            meta: ReportMeta {
                date_from: Some(date_from),
                date_to: Some(date_to),
                target_commodity_id,
            },
            tag_name,
            periods,
        })))
    }
}