Skip to main content

web/
jwt_auth.rs

1use std::sync::Arc;
2
3use axum::{
4    Extension, Json,
5    body::Body,
6    extract::State,
7    http::{HeaderValue, Request, StatusCode, header},
8    middleware::Next,
9    response::{IntoResponse, Response},
10};
11
12use axum_extra::extract::{
13    CookieJar,
14    cookie::{Cookie, SameSite},
15};
16use serde::{Deserialize, Serialize};
17
18use crate::{AppState, model::User, token};
19use redis::AsyncCommands;
20use server::db::get_connection;
21
22#[derive(Debug, Serialize)]
23pub struct ErrorResponse {
24    pub status: &'static str,
25    pub message: String,
26}
27
28#[derive(Debug, Serialize, Deserialize, Clone)]
29pub struct JWTAuthMiddleware {
30    pub user: User,
31    pub access_token_uuid: uuid::Uuid,
32}
33
34type AuthError = (StatusCode, Json<ErrorResponse>);
35
36fn unauthorized(message: impl Into<String>) -> AuthError {
37    (
38        StatusCode::UNAUTHORIZED,
39        Json(ErrorResponse {
40            status: "fail",
41            message: message.into(),
42        }),
43    )
44}
45
46fn internal(message: impl Into<String>) -> AuthError {
47    (
48        StatusCode::INTERNAL_SERVER_ERROR,
49        Json(ErrorResponse {
50            status: "error",
51            message: message.into(),
52        }),
53    )
54}
55
56fn extract_access_token(cookie_jar: &CookieJar, req: &Request<Body>) -> Option<String> {
57    cookie_jar
58        .get("access_token")
59        .map(|cookie| cookie.value().to_string())
60        .or_else(|| {
61            req.headers()
62                .get(header::AUTHORIZATION)
63                .and_then(|auth_header| auth_header.to_str().ok())
64                .and_then(|auth_value| {
65                    auth_value
66                        .strip_prefix("Bearer ")
67                        .map(std::borrow::ToOwned::to_owned)
68                })
69        })
70}
71
72async fn resolve_access(
73    data: &Arc<AppState>,
74    access_token: Option<&str>,
75) -> Option<(uuid::Uuid, uuid::Uuid)> {
76    let token_str = access_token?;
77    let details = token::verify_jwt_token(&data.conf.access_token_public_key, token_str).ok()?;
78    let mut redis = data
79        .redis_client
80        .get_multiplexed_async_connection()
81        .await
82        .ok()?;
83    let user_id_str: String = redis.get(details.token_uuid.to_string()).await.ok()?;
84    let user_id = uuid::Uuid::parse_str(&user_id_str).ok()?;
85    Some((user_id, details.token_uuid))
86}
87
88struct RefreshOutcome {
89    user_id: uuid::Uuid,
90    access_token_uuid: uuid::Uuid,
91    set_cookies: Vec<HeaderValue>,
92}
93
94async fn try_refresh_session(
95    data: &Arc<AppState>,
96    cookie_jar: &CookieJar,
97) -> Result<RefreshOutcome, AuthError> {
98    let refresh_token = cookie_jar
99        .get("refresh_token")
100        .map(|c| c.value().to_string())
101        .ok_or_else(|| unauthorized("You are not logged in, please provide token"))?;
102
103    let refresh_details =
104        token::verify_jwt_token(&data.conf.refresh_token_public_key, &refresh_token)
105            .map_err(|e| unauthorized(format!("{e:?}")))?;
106
107    let mut redis_client = data
108        .redis_client
109        .get_multiplexed_async_connection()
110        .await
111        .map_err(|e| internal(format!("Redis error: {e}")))?;
112
113    let user_id_str: String = redis_client
114        .get(refresh_details.token_uuid.to_string())
115        .await
116        .map_err(|_| unauthorized("Token is invalid or session has expired"))?;
117
118    let user_id = uuid::Uuid::parse_str(&user_id_str)
119        .map_err(|_| unauthorized("Token is invalid or session has expired"))?;
120
121    let access_details = token::generate_jwt_token(
122        user_id,
123        data.conf.access_token_max_age,
124        &data.conf.access_token_private_key,
125    )
126    .map_err(|e| internal(format!("error generating token: {e}")))?;
127
128    let new_refresh_details = token::generate_jwt_token_with_uuid(
129        user_id,
130        refresh_details.token_uuid,
131        data.conf.refresh_token_max_age,
132        &data.conf.refresh_token_private_key,
133    )
134    .map_err(|e| internal(format!("error generating token: {e}")))?;
135
136    let _: bool = redis_client
137        .set_ex(
138            access_details.token_uuid.to_string(),
139            user_id.to_string(),
140            (data.conf.access_token_max_age * 60) as u64,
141        )
142        .await
143        .map_err(|e| internal(format!("Redis error: {e}")))?;
144
145    let _: bool = redis_client
146        .expire(
147            refresh_details.token_uuid.to_string(),
148            data.conf.refresh_token_max_age * 60,
149        )
150        .await
151        .map_err(|e| internal(format!("Redis error: {e}")))?;
152
153    let access_cookie = Cookie::build((
154        "access_token",
155        access_details.token.clone().unwrap_or_default(),
156    ))
157    .path("/")
158    .max_age(time::Duration::minutes(data.conf.access_token_max_age * 60))
159    .same_site(SameSite::Lax)
160    .http_only(true)
161    .build();
162
163    let refresh_cookie = Cookie::build((
164        "refresh_token",
165        new_refresh_details.token.unwrap_or_default(),
166    ))
167    .path("/")
168    .max_age(time::Duration::minutes(
169        data.conf.refresh_token_max_age * 60,
170    ))
171    .same_site(SameSite::Lax)
172    .http_only(true)
173    .build();
174
175    let logged_in_cookie = Cookie::build(("logged_in", "true"))
176        .path("/")
177        .max_age(time::Duration::minutes(data.conf.access_token_max_age * 60))
178        .same_site(SameSite::Lax)
179        .http_only(false)
180        .build();
181
182    let set_cookies = [access_cookie, refresh_cookie, logged_in_cookie]
183        .iter()
184        .map(|c| c.to_string().parse().unwrap())
185        .collect();
186
187    Ok(RefreshOutcome {
188        user_id,
189        access_token_uuid: access_details.token_uuid,
190        set_cookies,
191    })
192}
193
194pub async fn auth(
195    cookie_jar: CookieJar,
196    State(data): State<Arc<AppState>>,
197    mut req: Request<Body>,
198    next: Next,
199) -> Result<Response, AuthError> {
200    let access_token = extract_access_token(&cookie_jar, &req);
201
202    let (user_id, access_token_uuid, refreshed_cookies) =
203        match resolve_access(&data, access_token.as_deref()).await {
204            Some((uid, atid)) => (uid, atid, Vec::new()),
205            None => {
206                let outcome = try_refresh_session(&data, &cookie_jar).await?;
207                (
208                    outcome.user_id,
209                    outcome.access_token_uuid,
210                    outcome.set_cookies,
211                )
212            }
213        };
214
215    let mut conn = get_connection()
216        .await
217        .map_err(|e| internal(format!("DB error: {e}")))?;
218
219    let user = sqlx::query_as!(User, "SELECT id, user_name as name, email, user_password as password, user_role as role, photo, verified, db_name as database, created_at, updated_at FROM users WHERE id = $1", user_id)
220        .fetch_optional(&mut *conn)
221        .await
222        .map_err(|e| internal(format!("Error fetching user from database: {e}")))?
223        .ok_or_else(|| unauthorized("The user belonging to this token no longer exists"))?;
224
225    if !user.verified {
226        return Err(unauthorized("The user is not verified yet"));
227    }
228
229    req.extensions_mut().insert(JWTAuthMiddleware {
230        user,
231        access_token_uuid,
232    });
233
234    let mut response = next.run(req).await;
235    for cookie in refreshed_cookies {
236        response.headers_mut().append(header::SET_COOKIE, cookie);
237    }
238    Ok(response)
239}
240
241pub async fn admin(
242    Extension(jwtauth): Extension<JWTAuthMiddleware>,
243    req: Request<Body>,
244    next: Next,
245) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
246    if jwtauth.user.role != "admin" {
247        let error_response = ErrorResponse {
248            status: "fail",
249            message: "The user does not have admin permissions".to_string(),
250        };
251        return Err((StatusCode::UNAUTHORIZED, Json(error_response)));
252    }
253
254    Ok(next.run(req).await)
255}