1
use std::sync::Arc;
2

            
3
use askama::Template;
4
use axum::Json;
5
use axum::extract::Query;
6
use axum::http::HeaderMap;
7
use axum::{Extension, extract::State, http::StatusCode, response::IntoResponse};
8
use num_rational::Rational64;
9
use serde::Deserialize;
10
use server::command::{
11
    CmdResult, FinanceEntity, ReportData, ReportNode, commodity::ListCommodities,
12
    report::BalanceReport,
13
};
14
use sqlx::types::Uuid;
15
use sqlx::types::chrono::NaiveDate;
16

            
17
use crate::{AppState, jwt_auth::JWTAuthMiddleware, pages::HtmlTemplate};
18

            
19
use super::{build_report_filter, empty_string_as_none};
20

            
21
#[derive(Template)]
22
#[template(path = "pages/report/balance.html")]
23
struct BalanceReportPage;
24

            
25
pub async fn balance_report_page() -> impl IntoResponse {
26
    HtmlTemplate(BalanceReportPage)
27
}
28

            
29
pub(super) struct CommodityOption {
30
    pub(super) id: Uuid,
31
    pub(super) symbol: String,
32
    pub(super) name: String,
33
}
34

            
35
struct ReportRowView {
36
    account_name: String,
37
    depth: usize,
38
    amounts: Vec<AmountView>,
39
}
40

            
41
struct AmountView {
42
    commodity_symbol: String,
43
    amount: Rational64,
44
}
45

            
46
fn flatten_nodes(nodes: &[ReportNode]) -> Vec<ReportRowView> {
47
    let mut rows = Vec::new();
48
    for node in nodes {
49
        rows.push(ReportRowView {
50
            account_name: node.account_name.clone(),
51
            depth: node.depth,
52
            amounts: node
53
                .amounts
54
                .iter()
55
                .map(|a| AmountView {
56
                    commodity_symbol: a.commodity_symbol.clone(),
57
                    amount: a.amount,
58
                })
59
                .collect(),
60
        });
61
        rows.extend(flatten_nodes(&node.children));
62
    }
63
    rows
64
}
65

            
66
pub(super) fn fetch_commodity_list(
67
    entities: Vec<(
68
        FinanceEntity,
69
        std::collections::HashMap<String, FinanceEntity>,
70
    )>,
71
) -> Vec<CommodityOption> {
72
    let mut commodities = Vec::new();
73
    for (entity, tags) in entities {
74
        if let FinanceEntity::Commodity(commodity) = entity
75
            && let (FinanceEntity::Tag(s), FinanceEntity::Tag(n)) = (&tags["symbol"], &tags["name"])
76
        {
77
            commodities.push(CommodityOption {
78
                id: commodity.id,
79
                symbol: s.tag_value.clone(),
80
                name: n.tag_value.clone(),
81
            });
82
        }
83
    }
84
    commodities
85
}
86

            
87
#[derive(Template)]
88
#[template(path = "components/report/balance_table.html")]
89
struct BalanceTableTemplate {
90
    commodities: Vec<CommodityOption>,
91
    rows: Vec<ReportRowView>,
92
    as_of: Option<String>,
93
    target_commodity_id: Option<String>,
94
    tag_filters: String,
95
    tag_filter_mode: String,
96
    scripting_enabled: bool,
97
}
98

            
99
#[derive(Deserialize)]
100
pub struct BalanceParams {
101
    #[serde(default, deserialize_with = "empty_string_as_none")]
102
    target_commodity_id: Option<String>,
103
    #[serde(default, deserialize_with = "empty_string_as_none")]
104
    as_of: Option<String>,
105
    #[serde(default, deserialize_with = "empty_string_as_none")]
106
    tag_filters: Option<String>,
107
    #[serde(default, deserialize_with = "empty_string_as_none")]
108
    tag_filter_mode: Option<String>,
109
}
110

            
111
pub async fn balance_report_table(
112
    Query(params): Query<BalanceParams>,
113
    State(_data): State<Arc<AppState>>,
114
    Extension(jwt_auth): Extension<JWTAuthMiddleware>,
115
    headers: HeaderMap,
116
) -> Result<impl IntoResponse, StatusCode> {
117
    let commodity_entities = ListCommodities::new()
118
        .user_id(jwt_auth.user.id)
119
        .run()
120
        .await
121
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
122

            
123
    let commodities = commodity_entities
124
        .and_then(|r| {
125
            if let CmdResult::TaggedEntities { entities, .. } = r {
126
                Some(fetch_commodity_list(entities))
127
            } else {
128
                None
129
            }
130
        })
131
        .unwrap_or_default();
132

            
133
    let mut cmd = BalanceReport::new().user_id(jwt_auth.user.id);
134

            
135
    if let Some(ref tid_str) = params.target_commodity_id
136
        && let Ok(tid) = tid_str.parse::<Uuid>()
137
    {
138
        cmd = cmd.target_commodity_id(tid);
139
    }
140

            
141
    if let Some(ref as_of_str) = params.as_of
142
        && let Ok(date) = NaiveDate::parse_from_str(as_of_str, "%Y-%m-%d")
143
    {
144
        cmd = cmd.as_of(date.and_hms_opt(23, 59, 59).unwrap().and_utc());
145
    }
146

            
147
    if let Some(filter) = build_report_filter(
148
        params.tag_filters.as_deref(),
149
        params.tag_filter_mode.as_deref(),
150
    ) {
151
        cmd = cmd.report_filter(filter);
152
    }
153

            
154
    let result = cmd
155
        .run()
156
        .await
157
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
158

            
159
    let report_data = match result {
160
        Some(CmdResult::Report(data)) => data,
161
        _ => ReportData {
162
            meta: server::command::ReportMeta {
163
                date_from: None,
164
                date_to: None,
165
                target_commodity_id: None,
166
            },
167
            periods: vec![],
168
        },
169
    };
170

            
171
    if headers
172
        .get("accept")
173
        .and_then(|v| v.to_str().ok())
174
        .is_some_and(|v| v.contains("application/json"))
175
    {
176
        return Ok(Json(report_data).into_response());
177
    }
178

            
179
    let rows: Vec<ReportRowView> = report_data
180
        .periods
181
        .iter()
182
        .flat_map(|p| flatten_nodes(&p.roots))
183
        .collect();
184

            
185
    let tag_filter_mode = params
186
        .tag_filter_mode
187
        .clone()
188
        .unwrap_or_else(|| "visual".to_string());
189

            
190
    Ok(HtmlTemplate(BalanceTableTemplate {
191
        commodities,
192
        rows,
193
        as_of: params.as_of,
194
        target_commodity_id: params.target_commodity_id,
195
        tag_filters: params.tag_filters.unwrap_or_default(),
196
        tag_filter_mode,
197
        scripting_enabled: cfg!(feature = "scripting"),
198
    })
199
    .into_response())
200
}