1use std::sync::Arc;
2
3use axum::{
4 Extension, Json,
5 body::Body,
6 extract::State,
7 http::{HeaderValue, Request, StatusCode, header},
8 middleware::Next,
9 response::{IntoResponse, Response},
10};
11
12use axum_extra::extract::{
13 CookieJar,
14 cookie::{Cookie, SameSite},
15};
16use serde::{Deserialize, Serialize};
17
18use crate::{AppState, model::User};
19use redis::AsyncCommands;
20use server::db::get_connection;
21
/// JSON body sent back when the auth middleware rejects a request.
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
    /// Short status tag; this file uses `"fail"` for rejections (401/403) and `"error"`
    /// for server-side failures (500).
    pub status: &'static str,
    /// Human-readable explanation of why the request was rejected.
    pub message: String,
}
27
/// Per-request authentication context that the `auth` middleware inserts into the
/// request extensions, for downstream handlers (and the `admin` layer) to extract.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct JWTAuthMiddleware {
    /// The user row loaded from the database for the authenticated session.
    pub user: User,
    /// UUID of the access token that authorized this request; when the session was
    /// refreshed during this request, this is the newly minted token's UUID.
    pub access_token_uuid: uuid::Uuid,
}
33
/// Rejection type used by the auth middleware: an HTTP status plus a JSON
/// `ErrorResponse` body (axum turns the tuple into a response).
type AuthError = (StatusCode, Json<ErrorResponse>);
35
36fn unauthorized(message: impl Into<String>) -> AuthError {
37 (
38 StatusCode::UNAUTHORIZED,
39 Json(ErrorResponse {
40 status: "fail",
41 message: message.into(),
42 }),
43 )
44}
45
46fn internal(message: impl Into<String>) -> AuthError {
47 (
48 StatusCode::INTERNAL_SERVER_ERROR,
49 Json(ErrorResponse {
50 status: "error",
51 message: message.into(),
52 }),
53 )
54}
55
56fn extract_access_token(cookie_jar: &CookieJar, req: &Request<Body>) -> Option<String> {
57 cookie_jar
58 .get("access_token")
59 .map(|cookie| cookie.value().to_string())
60 .or_else(|| {
61 req.headers()
62 .get(header::AUTHORIZATION)
63 .and_then(|auth_header| auth_header.to_str().ok())
64 .and_then(|auth_value| {
65 auth_value
66 .strip_prefix("Bearer ")
67 .map(std::borrow::ToOwned::to_owned)
68 })
69 })
70}
71
72async fn resolve_access(
73 data: &Arc<AppState>,
74 access_token: Option<&str>,
75) -> Option<(uuid::Uuid, uuid::Uuid)> {
76 let token_str = access_token?;
77 let details = crate::auth_keys::verify(token_str, crate::token::TokenType::Access).await?;
78 let mut redis = data
79 .redis_client
80 .get_multiplexed_async_connection()
81 .await
82 .ok()?;
83 let user_id_str: String = redis.get(details.token_uuid.to_string()).await.ok()?;
84 let user_id = uuid::Uuid::parse_str(&user_id_str).ok()?;
85 if user_id != details.user_id {
89 return None;
90 }
91 Some((user_id, details.token_uuid))
92}
93
/// Result of a successful refresh-token rotation in `try_refresh_session`.
struct RefreshOutcome {
    /// User the rotated session belongs to.
    user_id: uuid::Uuid,
    /// UUID of the newly minted access token.
    access_token_uuid: uuid::Uuid,
    /// Pre-encoded `Set-Cookie` header values (access, refresh and `logged_in` cookies)
    /// to append to the downstream response.
    set_cookies: Vec<HeaderValue>,
}
99
100async fn try_refresh_session(
101 data: &Arc<AppState>,
102 cookie_jar: &CookieJar,
103) -> Result<RefreshOutcome, AuthError> {
104 let refresh_token = cookie_jar
105 .get("refresh_token")
106 .map(|c| c.value().to_string())
107 .ok_or_else(|| unauthorized("You are not logged in, please provide token"))?;
108
109 let refresh_details =
110 crate::auth_keys::verify(&refresh_token, crate::token::TokenType::Refresh)
111 .await
112 .ok_or_else(|| unauthorized("Token is invalid or session has expired"))?;
113
114 let mut redis_client = data
115 .redis_client
116 .get_multiplexed_async_connection()
117 .await
118 .map_err(|e| internal(format!("Redis error: {e}")))?;
119
120 let user_id_str: Option<String> = redis_client
126 .get_del(refresh_details.token_uuid.to_string())
127 .await
128 .map_err(|e| internal(format!("Redis error: {e}")))?;
129 let user_id_str =
130 user_id_str.ok_or_else(|| unauthorized("Token is invalid or session has expired"))?;
131
132 let user_id = uuid::Uuid::parse_str(&user_id_str)
133 .map_err(|_| unauthorized("Token is invalid or session has expired"))?;
134
135 if user_id != refresh_details.user_id {
137 return Err(unauthorized("Token is invalid or session has expired"));
138 }
139
140 let access_details = crate::auth_keys::mint(
141 user_id,
142 data.conf.access_token_max_age,
143 crate::token::TokenType::Access,
144 )
145 .await
146 .map_err(|e| internal(format!("error generating token: {e}")))?;
147
148 let new_refresh_details = crate::auth_keys::mint(
152 user_id,
153 data.conf.refresh_token_max_age,
154 crate::token::TokenType::Refresh,
155 )
156 .await
157 .map_err(|e| internal(format!("error generating token: {e}")))?;
158
159 let _: bool = redis_client
160 .set_ex(
161 access_details.token_uuid.to_string(),
162 user_id.to_string(),
163 (data.conf.access_token_max_age * 60) as u64,
164 )
165 .await
166 .map_err(|e| internal(format!("Redis error: {e}")))?;
167
168 let _: bool = redis_client
171 .set_ex(
172 new_refresh_details.token_uuid.to_string(),
173 user_id.to_string(),
174 (data.conf.refresh_token_max_age * 60) as u64,
175 )
176 .await
177 .map_err(|e| internal(format!("Redis error: {e}")))?;
178
179 let access_cookie = Cookie::build((
180 "access_token",
181 access_details.token.clone().unwrap_or_default(),
182 ))
183 .path("/")
184 .max_age(time::Duration::minutes(data.conf.access_token_max_age * 60))
185 .same_site(SameSite::Lax)
186 .secure(crate::auth_keys::secure_cookies())
187 .http_only(true)
188 .build();
189
190 let refresh_cookie = Cookie::build((
191 "refresh_token",
192 new_refresh_details.token.unwrap_or_default(),
193 ))
194 .path("/")
195 .max_age(time::Duration::minutes(
196 data.conf.refresh_token_max_age * 60,
197 ))
198 .same_site(SameSite::Lax)
199 .secure(crate::auth_keys::secure_cookies())
200 .http_only(true)
201 .build();
202
203 let logged_in_cookie = Cookie::build(("logged_in", "true"))
204 .path("/")
205 .max_age(time::Duration::minutes(data.conf.access_token_max_age * 60))
206 .same_site(SameSite::Lax)
207 .http_only(false)
208 .build();
209
210 let set_cookies = [access_cookie, refresh_cookie, logged_in_cookie]
211 .iter()
212 .map(|c| c.to_string().parse().unwrap())
213 .collect();
214
215 Ok(RefreshOutcome {
216 user_id,
217 access_token_uuid: access_details.token_uuid,
218 set_cookies,
219 })
220}
221
222pub async fn auth(
223 cookie_jar: CookieJar,
224 State(data): State<Arc<AppState>>,
225 mut req: Request<Body>,
226 next: Next,
227) -> Result<Response, AuthError> {
228 let access_token = extract_access_token(&cookie_jar, &req);
229
230 let (user_id, access_token_uuid, refreshed_cookies) =
231 match resolve_access(&data, access_token.as_deref()).await {
232 Some((uid, atid)) => (uid, atid, Vec::new()),
233 None => {
234 let outcome = try_refresh_session(&data, &cookie_jar).await?;
235 (
236 outcome.user_id,
237 outcome.access_token_uuid,
238 outcome.set_cookies,
239 )
240 }
241 };
242
243 let mut conn = get_connection()
244 .await
245 .map_err(|e| internal(format!("DB error: {e}")))?;
246
247 let user = sqlx::query_as!(User, "SELECT id, user_name as name, email, user_password as password, user_role as role, photo, verified, db_name as database, created_at, updated_at FROM users WHERE id = $1", user_id)
248 .fetch_optional(&mut *conn)
249 .await
250 .map_err(|e| internal(format!("Error fetching user from database: {e}")))?
251 .ok_or_else(|| unauthorized("The user belonging to this token no longer exists"))?;
252
253 if !user.verified {
254 return Err(unauthorized("The user is not verified yet"));
255 }
256
257 req.extensions_mut().insert(JWTAuthMiddleware {
258 user,
259 access_token_uuid,
260 });
261
262 let mut response = next.run(req).await;
263 for cookie in refreshed_cookies {
264 response.headers_mut().append(header::SET_COOKIE, cookie);
265 }
266 Ok(response)
267}
268
269pub async fn admin(
270 Extension(jwtauth): Extension<JWTAuthMiddleware>,
271 req: Request<Body>,
272 next: Next,
273) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
274 if jwtauth.user.role != "admin" {
275 let error_response = ErrorResponse {
276 status: "fail",
277 message: "The user does not have admin permissions".to_string(),
278 };
279 return Err((StatusCode::UNAUTHORIZED, Json(error_response)));
280 }
281
282 Ok(next.run(req).await)
283}