1use std::sync::Arc;
2
3use axum::{
4 Extension, Json,
5 body::Body,
6 extract::State,
7 http::{HeaderValue, Request, StatusCode, header},
8 middleware::Next,
9 response::{IntoResponse, Response},
10};
11
12use axum_extra::extract::{
13 CookieJar,
14 cookie::{Cookie, SameSite},
15};
16use serde::{Deserialize, Serialize};
17
18use crate::{AppState, model::User, token};
19use redis::AsyncCommands;
20use server::db::get_connection;
21
/// JSON error body returned by the auth middlewares in this module.
///
/// `status` is `"fail"` for rejections attributable to the caller (built by
/// `unauthorized` and by the `admin` guard) and `"error"` for server-side
/// failures (built by `internal`).
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
    /// Short outcome tag: `"fail"` or `"error"`.
    pub status: &'static str,
    /// Human-readable explanation of why the request was rejected.
    pub message: String,
}
27
/// Authenticated-request context that [`auth`] inserts into the request
/// extensions once the caller has been identified.
///
/// Downstream handlers (and the [`admin`] guard) obtain it through the
/// `Extension<JWTAuthMiddleware>` extractor.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct JWTAuthMiddleware {
    /// The user row loaded from the `users` table for this request.
    pub user: User,
    /// UUID of the access token in effect: either the one the client presented
    /// or, if the session was just refreshed, the newly issued one.
    pub access_token_uuid: uuid::Uuid,
}
33
/// Rejection type used by the middlewares here: an HTTP status paired with a
/// JSON [`ErrorResponse`] body (axum renders the tuple as the response).
type AuthError = (StatusCode, Json<ErrorResponse>);
35
36fn unauthorized(message: impl Into<String>) -> AuthError {
37 (
38 StatusCode::UNAUTHORIZED,
39 Json(ErrorResponse {
40 status: "fail",
41 message: message.into(),
42 }),
43 )
44}
45
46fn internal(message: impl Into<String>) -> AuthError {
47 (
48 StatusCode::INTERNAL_SERVER_ERROR,
49 Json(ErrorResponse {
50 status: "error",
51 message: message.into(),
52 }),
53 )
54}
55
56fn extract_access_token(cookie_jar: &CookieJar, req: &Request<Body>) -> Option<String> {
57 cookie_jar
58 .get("access_token")
59 .map(|cookie| cookie.value().to_string())
60 .or_else(|| {
61 req.headers()
62 .get(header::AUTHORIZATION)
63 .and_then(|auth_header| auth_header.to_str().ok())
64 .and_then(|auth_value| {
65 auth_value
66 .strip_prefix("Bearer ")
67 .map(std::borrow::ToOwned::to_owned)
68 })
69 })
70}
71
72async fn resolve_access(
73 data: &Arc<AppState>,
74 access_token: Option<&str>,
75) -> Option<(uuid::Uuid, uuid::Uuid)> {
76 let token_str = access_token?;
77 let details = token::verify_jwt_token(&data.conf.access_token_public_key, token_str).ok()?;
78 let mut redis = data
79 .redis_client
80 .get_multiplexed_async_connection()
81 .await
82 .ok()?;
83 let user_id_str: String = redis.get(details.token_uuid.to_string()).await.ok()?;
84 let user_id = uuid::Uuid::parse_str(&user_id_str).ok()?;
85 Some((user_id, details.token_uuid))
86}
87
/// Result of a successful [`try_refresh_session`] call.
struct RefreshOutcome {
    /// User the refreshed session belongs to (as stored in Redis).
    user_id: uuid::Uuid,
    /// UUID of the newly issued access token.
    access_token_uuid: uuid::Uuid,
    /// `Set-Cookie` header values (`access_token`, `refresh_token`,
    /// `logged_in`) to append to the outgoing response.
    set_cookies: Vec<HeaderValue>,
}
93
94async fn try_refresh_session(
95 data: &Arc<AppState>,
96 cookie_jar: &CookieJar,
97) -> Result<RefreshOutcome, AuthError> {
98 let refresh_token = cookie_jar
99 .get("refresh_token")
100 .map(|c| c.value().to_string())
101 .ok_or_else(|| unauthorized("You are not logged in, please provide token"))?;
102
103 let refresh_details =
104 token::verify_jwt_token(&data.conf.refresh_token_public_key, &refresh_token)
105 .map_err(|e| unauthorized(format!("{e:?}")))?;
106
107 let mut redis_client = data
108 .redis_client
109 .get_multiplexed_async_connection()
110 .await
111 .map_err(|e| internal(format!("Redis error: {e}")))?;
112
113 let user_id_str: String = redis_client
114 .get(refresh_details.token_uuid.to_string())
115 .await
116 .map_err(|_| unauthorized("Token is invalid or session has expired"))?;
117
118 let user_id = uuid::Uuid::parse_str(&user_id_str)
119 .map_err(|_| unauthorized("Token is invalid or session has expired"))?;
120
121 let access_details = token::generate_jwt_token(
122 user_id,
123 data.conf.access_token_max_age,
124 &data.conf.access_token_private_key,
125 )
126 .map_err(|e| internal(format!("error generating token: {e}")))?;
127
128 let new_refresh_details = token::generate_jwt_token_with_uuid(
129 user_id,
130 refresh_details.token_uuid,
131 data.conf.refresh_token_max_age,
132 &data.conf.refresh_token_private_key,
133 )
134 .map_err(|e| internal(format!("error generating token: {e}")))?;
135
136 let _: bool = redis_client
137 .set_ex(
138 access_details.token_uuid.to_string(),
139 user_id.to_string(),
140 (data.conf.access_token_max_age * 60) as u64,
141 )
142 .await
143 .map_err(|e| internal(format!("Redis error: {e}")))?;
144
145 let _: bool = redis_client
146 .expire(
147 refresh_details.token_uuid.to_string(),
148 data.conf.refresh_token_max_age * 60,
149 )
150 .await
151 .map_err(|e| internal(format!("Redis error: {e}")))?;
152
153 let access_cookie = Cookie::build((
154 "access_token",
155 access_details.token.clone().unwrap_or_default(),
156 ))
157 .path("/")
158 .max_age(time::Duration::minutes(data.conf.access_token_max_age * 60))
159 .same_site(SameSite::Lax)
160 .http_only(true)
161 .build();
162
163 let refresh_cookie = Cookie::build((
164 "refresh_token",
165 new_refresh_details.token.unwrap_or_default(),
166 ))
167 .path("/")
168 .max_age(time::Duration::minutes(
169 data.conf.refresh_token_max_age * 60,
170 ))
171 .same_site(SameSite::Lax)
172 .http_only(true)
173 .build();
174
175 let logged_in_cookie = Cookie::build(("logged_in", "true"))
176 .path("/")
177 .max_age(time::Duration::minutes(data.conf.access_token_max_age * 60))
178 .same_site(SameSite::Lax)
179 .http_only(false)
180 .build();
181
182 let set_cookies = [access_cookie, refresh_cookie, logged_in_cookie]
183 .iter()
184 .map(|c| c.to_string().parse().unwrap())
185 .collect();
186
187 Ok(RefreshOutcome {
188 user_id,
189 access_token_uuid: access_details.token_uuid,
190 set_cookies,
191 })
192}
193
194pub async fn auth(
195 cookie_jar: CookieJar,
196 State(data): State<Arc<AppState>>,
197 mut req: Request<Body>,
198 next: Next,
199) -> Result<Response, AuthError> {
200 let access_token = extract_access_token(&cookie_jar, &req);
201
202 let (user_id, access_token_uuid, refreshed_cookies) =
203 match resolve_access(&data, access_token.as_deref()).await {
204 Some((uid, atid)) => (uid, atid, Vec::new()),
205 None => {
206 let outcome = try_refresh_session(&data, &cookie_jar).await?;
207 (
208 outcome.user_id,
209 outcome.access_token_uuid,
210 outcome.set_cookies,
211 )
212 }
213 };
214
215 let mut conn = get_connection()
216 .await
217 .map_err(|e| internal(format!("DB error: {e}")))?;
218
219 let user = sqlx::query_as!(User, "SELECT id, user_name as name, email, user_password as password, user_role as role, photo, verified, db_name as database, created_at, updated_at FROM users WHERE id = $1", user_id)
220 .fetch_optional(&mut *conn)
221 .await
222 .map_err(|e| internal(format!("Error fetching user from database: {e}")))?
223 .ok_or_else(|| unauthorized("The user belonging to this token no longer exists"))?;
224
225 if !user.verified {
226 return Err(unauthorized("The user is not verified yet"));
227 }
228
229 req.extensions_mut().insert(JWTAuthMiddleware {
230 user,
231 access_token_uuid,
232 });
233
234 let mut response = next.run(req).await;
235 for cookie in refreshed_cookies {
236 response.headers_mut().append(header::SET_COOKIE, cookie);
237 }
238 Ok(response)
239}
240
241pub async fn admin(
242 Extension(jwtauth): Extension<JWTAuthMiddleware>,
243 req: Request<Body>,
244 next: Next,
245) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
246 if jwtauth.user.role != "admin" {
247 let error_response = ErrorResponse {
248 status: "fail",
249 message: "The user does not have admin permissions".to_string(),
250 };
251 return Err((StatusCode::UNAUTHORIZED, Json(error_response)));
252 }
253
254 Ok(next.run(req).await)
255}