use std::sync::Arc;

use askama::Template;
use axum::Json;
use axum::extract::Query;
use axum::http::HeaderMap;
use axum::http::header::ACCEPT;
use axum::{Extension, extract::State, http::StatusCode, response::IntoResponse};
use num_rational::Rational64;
use serde::Deserialize;
use server::command::{
    CmdResult, ReportData, ReportNode, commodity::ListCommodities, report::TrialBalance,
};
use sqlx::types::Uuid;
use sqlx::types::chrono::NaiveDate;

use crate::{AppState, jwt_auth::JWTAuthMiddleware, pages::HtmlTemplate};

use super::balance::{self, CommodityOption};
use super::{build_report_filter, empty_string_as_none};

#[derive(Template)]
22
#[template(path = "pages/report/trial_balance.html")]
23
struct TrialBalanceReportPage;
24

            
25
pub async fn trial_balance_report_page() -> impl IntoResponse {
26
    HtmlTemplate(TrialBalanceReportPage)
27
}
28

            
29
struct ReportRowView {
30
    account_name: String,
31
    depth: usize,
32
    amounts: Vec<AmountView>,
33
}
34

            
35
struct AmountView {
36
    commodity_symbol: String,
37
    amount: Rational64,
38
}
39

            
40
fn flatten_nodes(nodes: &[ReportNode]) -> Vec<ReportRowView> {
41
    let mut rows = Vec::new();
42
    for node in nodes {
43
        rows.push(ReportRowView {
44
            account_name: node.account_name.clone(),
45
            depth: node.depth,
46
            amounts: node
47
                .amounts
48
                .iter()
49
                .map(|a| AmountView {
50
                    commodity_symbol: a.commodity_symbol.clone(),
51
                    amount: a.amount,
52
                })
53
                .collect(),
54
        });
55
        rows.extend(flatten_nodes(&node.children));
56
    }
57
    rows
58
}
59

            
60
#[derive(Template)]
61
#[template(path = "components/report/trial_balance_table.html")]
62
struct TrialBalanceTableTemplate {
63
    commodities: Vec<CommodityOption>,
64
    rows: Vec<ReportRowView>,
65
    date_from: Option<String>,
66
    date_to: Option<String>,
67
    target_commodity_id: Option<String>,
68
    tag_filters: String,
69
    tag_filter_mode: String,
70
    scripting_enabled: bool,
71
}
72

            
73
#[derive(Deserialize)]
74
pub struct TrialBalanceParams {
75
    #[serde(default, deserialize_with = "empty_string_as_none")]
76
    date_from: Option<String>,
77
    #[serde(default, deserialize_with = "empty_string_as_none")]
78
    date_to: Option<String>,
79
    #[serde(default, deserialize_with = "empty_string_as_none")]
80
    target_commodity_id: Option<String>,
81
    #[serde(default, deserialize_with = "empty_string_as_none")]
82
    tag_filters: Option<String>,
83
    #[serde(default, deserialize_with = "empty_string_as_none")]
84
    tag_filter_mode: Option<String>,
85
}
86

            
87
pub async fn trial_balance_report_table(
88
    Query(params): Query<TrialBalanceParams>,
89
    State(_data): State<Arc<AppState>>,
90
    Extension(jwt_auth): Extension<JWTAuthMiddleware>,
91
    headers: HeaderMap,
92
) -> Result<impl IntoResponse, StatusCode> {
93
    let commodity_entities = ListCommodities::new()
94
        .user_id(jwt_auth.user.id)
95
        .run()
96
        .await
97
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
98

            
99
    let commodities = commodity_entities
100
        .and_then(|r| {
101
            if let CmdResult::TaggedEntities { entities, .. } = r {
102
                Some(balance::fetch_commodity_list(entities))
103
            } else {
104
                None
105
            }
106
        })
107
        .unwrap_or_default();
108

            
109
    let date_from = params
110
        .date_from
111
        .as_deref()
112
        .and_then(|s| NaiveDate::parse_from_str(s, "%Y-%m-%d").ok())
113
        .map(|d| d.and_hms_opt(0, 0, 0).unwrap().and_utc());
114

            
115
    let date_to = params
116
        .date_to
117
        .as_deref()
118
        .and_then(|s| NaiveDate::parse_from_str(s, "%Y-%m-%d").ok())
119
        .map(|d| d.and_hms_opt(23, 59, 59).unwrap().and_utc());
120

            
121
    let tag_filter_mode = params
122
        .tag_filter_mode
123
        .clone()
124
        .unwrap_or_else(|| "visual".to_string());
125

            
126
    let (Some(df), Some(dt)) = (date_from, date_to) else {
127
        return Ok(HtmlTemplate(TrialBalanceTableTemplate {
128
            commodities,
129
            rows: vec![],
130
            date_from: params.date_from,
131
            date_to: params.date_to,
132
            target_commodity_id: params.target_commodity_id,
133
            tag_filters: params.tag_filters.unwrap_or_default(),
134
            tag_filter_mode,
135
            scripting_enabled: cfg!(feature = "scripting"),
136
        })
137
        .into_response());
138
    };
139

            
140
    let mut cmd = TrialBalance::new()
141
        .user_id(jwt_auth.user.id)
142
        .date_from(df)
143
        .date_to(dt);
144

            
145
    if let Some(ref tid_str) = params.target_commodity_id
146
        && let Ok(tid) = tid_str.parse::<Uuid>()
147
    {
148
        cmd = cmd.target_commodity_id(tid);
149
    }
150

            
151
    if let Some(filter) = build_report_filter(
152
        params.tag_filters.as_deref(),
153
        Some(tag_filter_mode.as_str()),
154
    ) {
155
        cmd = cmd.report_filter(filter);
156
    }
157

            
158
    let result = cmd
159
        .run()
160
        .await
161
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
162

            
163
    let report_data = match result {
164
        Some(CmdResult::Report(data)) => data,
165
        _ => ReportData {
166
            meta: server::command::ReportMeta {
167
                date_from: None,
168
                date_to: None,
169
                target_commodity_id: None,
170
            },
171
            periods: vec![],
172
        },
173
    };
174

            
175
    if headers
176
        .get("accept")
177
        .and_then(|v| v.to_str().ok())
178
        .is_some_and(|v| v.contains("application/json"))
179
    {
180
        return Ok(Json(report_data).into_response());
181
    }
182

            
183
    let rows: Vec<ReportRowView> = report_data
184
        .periods
185
        .iter()
186
        .flat_map(|p| flatten_nodes(&p.roots))
187
        .collect();
188

            
189
    Ok(HtmlTemplate(TrialBalanceTableTemplate {
190
        commodities,
191
        rows,
192
        date_from: params.date_from,
193
        date_to: params.date_to,
194
        target_commodity_id: params.target_commodity_id,
195
        tag_filters: params.tag_filters.unwrap_or_default(),
196
        tag_filter_mode,
197
        scripting_enabled: cfg!(feature = "scripting"),
198
    })
199
    .into_response())
200
}