Skip to main content

web/
jwt_auth.rs

1use std::sync::Arc;
2
3use axum::{
4    Extension, Json,
5    body::Body,
6    extract::State,
7    http::{HeaderValue, Request, StatusCode, header},
8    middleware::Next,
9    response::{IntoResponse, Response},
10};
11
12use axum_extra::extract::{
13    CookieJar,
14    cookie::{Cookie, SameSite},
15};
16use serde::{Deserialize, Serialize};
17
18use crate::{AppState, model::User};
19use redis::AsyncCommands;
20use server::db::get_connection;
21
/// JSON body returned to the client when authentication or authorization
/// fails in this module.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
    /// Coarse outcome label; this file uses `"fail"` for rejected requests
    /// and `"error"` for internal failures.
    pub status: &'static str,
    /// Human-readable explanation of what went wrong.
    pub message: String,
}
27
/// Authentication context that the `auth` middleware stores in the request
/// extensions so that later layers and handlers can read it back.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct JWTAuthMiddleware {
    /// The account row loaded for the authenticated user.
    pub user: User,
    /// Identifier of the access token that authenticated this request.
    pub access_token_uuid: uuid::Uuid,
}
33
/// Rejection type shared by the auth helpers: an HTTP status paired with a
/// JSON [`ErrorResponse`] body.
type AuthError = (StatusCode, Json<ErrorResponse>);
35
36fn unauthorized(message: impl Into<String>) -> AuthError {
37    (
38        StatusCode::UNAUTHORIZED,
39        Json(ErrorResponse {
40            status: "fail",
41            message: message.into(),
42        }),
43    )
44}
45
46fn internal(message: impl Into<String>) -> AuthError {
47    (
48        StatusCode::INTERNAL_SERVER_ERROR,
49        Json(ErrorResponse {
50            status: "error",
51            message: message.into(),
52        }),
53    )
54}
55
56fn extract_access_token(cookie_jar: &CookieJar, req: &Request<Body>) -> Option<String> {
57    cookie_jar
58        .get("access_token")
59        .map(|cookie| cookie.value().to_string())
60        .or_else(|| {
61            req.headers()
62                .get(header::AUTHORIZATION)
63                .and_then(|auth_header| auth_header.to_str().ok())
64                .and_then(|auth_value| {
65                    auth_value
66                        .strip_prefix("Bearer ")
67                        .map(std::borrow::ToOwned::to_owned)
68                })
69        })
70}
71
72async fn resolve_access(
73    data: &Arc<AppState>,
74    access_token: Option<&str>,
75) -> Option<(uuid::Uuid, uuid::Uuid)> {
76    let token_str = access_token?;
77    let details = crate::auth_keys::verify(token_str, crate::token::TokenType::Access).await?;
78    let mut redis = data
79        .redis_client
80        .get_multiplexed_async_connection()
81        .await
82        .ok()?;
83    let user_id_str: String = redis.get(details.token_uuid.to_string()).await.ok()?;
84    let user_id = uuid::Uuid::parse_str(&user_id_str).ok()?;
85    // Bind the verified JWT subject to the Redis session owner: the token was
86    // signed by `details.user_id`'s key, so a mismatch means the token_uuid was
87    // paired with a different user's session — reject it.
88    if user_id != details.user_id {
89        return None;
90    }
91    Some((user_id, details.token_uuid))
92}
93
/// Result of a successful refresh-token rotation.
struct RefreshOutcome {
    /// User the rotated session belongs to.
    user_id: uuid::Uuid,
    /// Identifier of the newly minted access token.
    access_token_uuid: uuid::Uuid,
    /// `Set-Cookie` header values carrying the new cookies for the client.
    set_cookies: Vec<HeaderValue>,
}
99
100async fn try_refresh_session(
101    data: &Arc<AppState>,
102    cookie_jar: &CookieJar,
103) -> Result<RefreshOutcome, AuthError> {
104    let refresh_token = cookie_jar
105        .get("refresh_token")
106        .map(|c| c.value().to_string())
107        .ok_or_else(|| unauthorized("You are not logged in, please provide token"))?;
108
109    let refresh_details =
110        crate::auth_keys::verify(&refresh_token, crate::token::TokenType::Refresh)
111            .await
112            .ok_or_else(|| unauthorized("Token is invalid or session has expired"))?;
113
114    let mut redis_client = data
115        .redis_client
116        .get_multiplexed_async_connection()
117        .await
118        .map_err(|e| internal(format!("Redis error: {e}")))?;
119
120    // Atomically CONSUME the old refresh session (GETDEL): the first request to
121    // redeem this token gets the user id and removes the entry; a concurrent
122    // replay of the same token gets `None` and is rejected. This makes refresh
123    // rotation one-time-use and race-safe — set-new-then-del-old would let two
124    // concurrent refreshes both win.
125    let user_id_str: Option<String> = redis_client
126        .get_del(refresh_details.token_uuid.to_string())
127        .await
128        .map_err(|e| internal(format!("Redis error: {e}")))?;
129    let user_id_str =
130        user_id_str.ok_or_else(|| unauthorized("Token is invalid or session has expired"))?;
131
132    let user_id = uuid::Uuid::parse_str(&user_id_str)
133        .map_err(|_| unauthorized("Token is invalid or session has expired"))?;
134
135    // Bind the verified subject to the Redis session owner (see resolve_access).
136    if user_id != refresh_details.user_id {
137        return Err(unauthorized("Token is invalid or session has expired"));
138    }
139
140    let access_details = crate::auth_keys::mint(
141        user_id,
142        data.conf.access_token_max_age,
143        crate::token::TokenType::Access,
144    )
145    .await
146    .map_err(|e| internal(format!("error generating token: {e}")))?;
147
148    // Rotate the refresh token: mint a FRESH uuid (not the old one) so the
149    // previous refresh token can be invalidated. Reusing the uuid would let a
150    // replayed pre-refresh token keep working until its own exp.
151    let new_refresh_details = crate::auth_keys::mint(
152        user_id,
153        data.conf.refresh_token_max_age,
154        crate::token::TokenType::Refresh,
155    )
156    .await
157    .map_err(|e| internal(format!("error generating token: {e}")))?;
158
159    let _: bool = redis_client
160        .set_ex(
161            access_details.token_uuid.to_string(),
162            user_id.to_string(),
163            (data.conf.access_token_max_age * 60) as u64,
164        )
165        .await
166        .map_err(|e| internal(format!("Redis error: {e}")))?;
167
168    // Register the new refresh session. The old one was already consumed
169    // atomically by the GETDEL above, so no separate delete is needed.
170    let _: bool = redis_client
171        .set_ex(
172            new_refresh_details.token_uuid.to_string(),
173            user_id.to_string(),
174            (data.conf.refresh_token_max_age * 60) as u64,
175        )
176        .await
177        .map_err(|e| internal(format!("Redis error: {e}")))?;
178
179    let access_cookie = Cookie::build((
180        "access_token",
181        access_details.token.clone().unwrap_or_default(),
182    ))
183    .path("/")
184    .max_age(time::Duration::minutes(data.conf.access_token_max_age * 60))
185    .same_site(SameSite::Lax)
186    .secure(crate::auth_keys::secure_cookies())
187    .http_only(true)
188    .build();
189
190    let refresh_cookie = Cookie::build((
191        "refresh_token",
192        new_refresh_details.token.unwrap_or_default(),
193    ))
194    .path("/")
195    .max_age(time::Duration::minutes(
196        data.conf.refresh_token_max_age * 60,
197    ))
198    .same_site(SameSite::Lax)
199    .secure(crate::auth_keys::secure_cookies())
200    .http_only(true)
201    .build();
202
203    let logged_in_cookie = Cookie::build(("logged_in", "true"))
204        .path("/")
205        .max_age(time::Duration::minutes(data.conf.access_token_max_age * 60))
206        .same_site(SameSite::Lax)
207        .http_only(false)
208        .build();
209
210    let set_cookies = [access_cookie, refresh_cookie, logged_in_cookie]
211        .iter()
212        .map(|c| c.to_string().parse().unwrap())
213        .collect();
214
215    Ok(RefreshOutcome {
216        user_id,
217        access_token_uuid: access_details.token_uuid,
218        set_cookies,
219    })
220}
221
222pub async fn auth(
223    cookie_jar: CookieJar,
224    State(data): State<Arc<AppState>>,
225    mut req: Request<Body>,
226    next: Next,
227) -> Result<Response, AuthError> {
228    let access_token = extract_access_token(&cookie_jar, &req);
229
230    let (user_id, access_token_uuid, refreshed_cookies) =
231        match resolve_access(&data, access_token.as_deref()).await {
232            Some((uid, atid)) => (uid, atid, Vec::new()),
233            None => {
234                let outcome = try_refresh_session(&data, &cookie_jar).await?;
235                (
236                    outcome.user_id,
237                    outcome.access_token_uuid,
238                    outcome.set_cookies,
239                )
240            }
241        };
242
243    let mut conn = get_connection()
244        .await
245        .map_err(|e| internal(format!("DB error: {e}")))?;
246
247    let user = sqlx::query_as!(User, "SELECT id, user_name as name, email, user_password as password, user_role as role, photo, verified, db_name as database, created_at, updated_at FROM users WHERE id = $1", user_id)
248        .fetch_optional(&mut *conn)
249        .await
250        .map_err(|e| internal(format!("Error fetching user from database: {e}")))?
251        .ok_or_else(|| unauthorized("The user belonging to this token no longer exists"))?;
252
253    if !user.verified {
254        return Err(unauthorized("The user is not verified yet"));
255    }
256
257    req.extensions_mut().insert(JWTAuthMiddleware {
258        user,
259        access_token_uuid,
260    });
261
262    let mut response = next.run(req).await;
263    for cookie in refreshed_cookies {
264        response.headers_mut().append(header::SET_COOKIE, cookie);
265    }
266    Ok(response)
267}
268
269pub async fn admin(
270    Extension(jwtauth): Extension<JWTAuthMiddleware>,
271    req: Request<Body>,
272    next: Next,
273) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
274    if jwtauth.user.role != "admin" {
275        let error_response = ErrorResponse {
276            status: "fail",
277            message: "The user does not have admin permissions".to_string(),
278        };
279        return Err((StatusCode::UNAUTHORIZED, Json(error_response)));
280    }
281
282    Ok(next.run(req).await)
283}