1
#[macro_use]
2
extern crate rust_i18n;
3

            
4
i18n!("locales", fallback = "en");
5

            
6
mod config;
7
mod files;
8
mod handler;
9
mod jwt_auth;
10
mod model;
11
mod pages;
12
mod redirect_middleware;
13
mod response;
14
mod route;
15
mod token;
16

            
17
use axum::{
18
    Router,
19
    extract::{MatchedPath, Request},
20
    middleware::{self, Next},
21
    response::IntoResponse,
22
    routing::get,
23
};
24

            
25
use axum::http::{
26
    HeaderValue, Method,
27
    header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE},
28
};
29

            
30
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
31
use redis::Client;
32
use std::sync::Arc;
33
use std::{future::ready, time::Instant};
34
use tower_http::{cors::CorsLayer, services::ServeDir};
35
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
36

            
37
use config::Config;
38
use route::{
39
    create_accounts_router, create_api_router, create_pages_router, create_tags_router,
40
    create_transactions_router,
41
};
42

            
43
pub struct AppState {
44
    conf: Config,
45
    redis_client: Client,
46
    frac: i64,
47
}
48

            
49
fn metrics_app() -> Router {
50
    let recorder_handle = setup_metrics_recorder();
51
    Router::new().route("/metrics", get(move || ready(recorder_handle.render())))
52
}
53

            
54
fn setup_metrics_recorder() -> PrometheusHandle {
55
    const EXPONENTIAL_SECONDS: &[f64] = &[
56
        0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0,
57
    ];
58

            
59
    PrometheusBuilder::new()
60
        .set_buckets_for_metric(
61
            Matcher::Full("http_requests_duration_seconds".to_string()),
62
            EXPONENTIAL_SECONDS,
63
        )
64
        .unwrap()
65
        .install_recorder()
66
        .unwrap()
67
}
68

            
69
/// Middleware that records a request counter and a latency histogram.
///
/// Each request increments `http_requests_total` and records its duration in
/// seconds under `http_requests_duration_seconds`. Both metrics carry the
/// labels `method`, `path` and `status`. The response from the inner service
/// is returned untouched.
async fn track_metrics(req: Request, next: Next) -> impl IntoResponse {
    let started_at = Instant::now();

    // Capture everything needed from the request before it is moved into `next`.
    let method = req.method().to_string();
    // Use the matched route pattern when one is present; otherwise fall back
    // to the raw URI path.
    let path = req
        .extensions()
        .get::<MatchedPath>()
        .map(|matched| matched.as_str().to_owned())
        .unwrap_or_else(|| req.uri().path().to_owned());

    let response = next.run(req).await;

    let elapsed_secs = started_at.elapsed().as_secs_f64();
    let labels = [
        ("method", method),
        ("path", path),
        ("status", response.status().as_u16().to_string()),
    ];

    metrics::counter!("http_requests_total", &labels).increment(1);
    metrics::histogram!("http_requests_duration_seconds", &labels).record(elapsed_secs);

    response
}
94

            
95
#[tokio::main]
96
async fn main() -> anyhow::Result<()> {
97
    tracing_subscriber::registry()
98
        .with(
99
            tracing_subscriber::EnvFilter::try_from_default_env()
100
                .unwrap_or_else(|_| "with_axum_htmx_askama=debug".into()),
101
        )
102
        .with(tracing_subscriber::fmt::layer())
103
        .try_init()
104
        .expect("Failed to set tracing subscriber");
105

            
106
    let conf = Config::init().await.unwrap();
107
    log::debug!("Config ready");
108

            
109
    let redis_client = Client::open(conf.redis_url.clone())
110
        .and_then(|client| {
111
            client.get_connection().map(|_| {
112
                log::debug!(
113
                    "Connection to the redis db {} successful!",
114
                    client.get_connection_info().redis.db
115
                );
116
                client
117
            })
118
        })
119
        .unwrap_or_else(|e| {
120
            log::error!("Error connecting to Redis: {e}");
121
            std::process::exit(1);
122
        });
123

            
124
    let cors = CorsLayer::new()
125
        .allow_origin(conf.site_url.parse::<HeaderValue>()?)
126
        .allow_methods([Method::GET, Method::POST, Method::PATCH, Method::DELETE])
127
        .allow_credentials(true)
128
        .allow_headers([AUTHORIZATION, ACCEPT, CONTENT_TYPE]);
129

            
130
    let state = Arc::new(AppState {
131
        conf,
132
        redis_client: redis_client.clone(),
133
        frac: 0,
134
    });
135

            
136
    let router = Router::new()
137
        .route("/", get(pages::index))
138
        .route("/register", get(pages::register))
139
        .route("/login", get(pages::login))
140
        .merge(create_pages_router(state.clone()))
141
        .merge(create_accounts_router(state.clone()))
142
        .merge(create_transactions_router(state.clone()))
143
        .merge(create_tags_router(state.clone()))
144
        .nest("/api", create_api_router(state.clone()))
145
        .nest_service(
146
            "/static",
147
            ServeDir::new(std::env::var("STATIC_PATH").unwrap_or("web/static".to_string())),
148
        )
149
        .with_state(state)
150
        .layer(cors)
151
        .route_layer(middleware::from_fn(track_metrics));
152

            
153
    let app = metrics_app();
154

            
155
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
156
    let metrics_listener = tokio::net::TcpListener::bind("0.0.0.0:3001").await?;
157
    let (_main_server, _metrics_server) = tokio::join!(
158
        run_server(metrics_listener, app),
159
        run_server(listener, router)
160
    );
161
    Ok(())
162
}
163

            
164
async fn run_server(listener: tokio::net::TcpListener, app: Router) {
165
    axum::serve(listener, app).await.unwrap();
166
}