// Coverage: lines 24.73 %, functions 6.67 %, branches 100 %.
use std::sync::Arc;

use axum::{
    Extension, Json,
    body::Body,
    extract::State,
    http::{HeaderValue, Request, StatusCode, header},
    middleware::Next,
    response::{IntoResponse, Response},
};
use axum_extra::extract::{
    CookieJar,
    cookie::{Cookie, SameSite},
};
use redis::AsyncCommands;
use serde::{Deserialize, Serialize};
use server::db::get_connection;

use crate::{AppState, model::User, token};
/// JSON body sent back when the auth middleware rejects a request.
///
/// In this module `status` is `"fail"` for client-side rejections and `"error"` for
/// server-side failures; `message` is a human-readable explanation.
#[derive(Serialize, Debug)]
pub struct ErrorResponse {
    /// Short outcome tag, serialized as `"status"`.
    pub status: &'static str,
    /// Human-readable description, serialized as `"message"`.
    pub message: String,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct JWTAuthMiddleware {
pub user: User,
pub access_token_uuid: uuid::Uuid,
/// Rejection produced by the auth middleware: an HTTP status paired with a JSON
/// [`ErrorResponse`] body.
type AuthError = (StatusCode, Json<ErrorResponse>);
fn unauthorized(message: impl Into<String>) -> AuthError {
(
StatusCode::UNAUTHORIZED,
Json(ErrorResponse {
status: "fail",
message: message.into(),
}),
)
fn internal(message: impl Into<String>) -> AuthError {
StatusCode::INTERNAL_SERVER_ERROR,
status: "error",
fn extract_access_token(cookie_jar: &CookieJar, req: &Request<Body>) -> Option<String> {
cookie_jar
.get("access_token")
.map(|cookie| cookie.value().to_string())
.or_else(|| {
req.headers()
.get(header::AUTHORIZATION)
.and_then(|auth_header| auth_header.to_str().ok())
.and_then(|auth_value| {
auth_value
.strip_prefix("Bearer ")
.map(std::borrow::ToOwned::to_owned)
})
async fn resolve_access(
data: &Arc<AppState>,
access_token: Option<&str>,
) -> Option<(uuid::Uuid, uuid::Uuid)> {
let token_str = access_token?;
let details = token::verify_jwt_token(&data.conf.access_token_public_key, token_str).ok()?;
let mut redis = data
.redis_client
.get_multiplexed_async_connection()
.await
.ok()?;
let user_id_str: String = redis.get(details.token_uuid.to_string()).await.ok()?;
let user_id = uuid::Uuid::parse_str(&user_id_str).ok()?;
Some((user_id, details.token_uuid))
struct RefreshOutcome {
user_id: uuid::Uuid,
access_token_uuid: uuid::Uuid,
set_cookies: Vec<HeaderValue>,
/// Attempts a silent session refresh using the `refresh_token` cookie.
///
/// Flow visible here:
/// 1. Verify the refresh JWT against `refresh_token_public_key`.
/// 2. Read the user id that Redis stores under the refresh token's UUID.
/// 3. Mint a new access token, plus a refresh token via `generate_jwt_token_with_uuid`
///    (given the old refresh UUID).
/// 4. Store `access UUID -> user id` in Redis with a TTL of `access_token_max_age * 60`
///    seconds.
/// 5. Reset the refresh key's TTL to `refresh_token_max_age * 60` seconds.
/// 6. Return `Set-Cookie` values for the `access_token`, `refresh_token` and `logged_in`
///    cookies.
///
/// # Errors
/// * `401` (`unauthorized`): there is no `refresh_token` cookie, the JWT fails verification,
///   or the Redis lookup for its UUID fails.
/// * `500` (`internal`): the Redis connection fails or token generation fails.
///
/// NOTE(review): this copy of the function is incomplete. For example, `data` is used below
/// but not declared among the parameters, and several `.await`s and closing delimiters are
/// absent. Restore the missing lines from version control before relying on this text.
async fn try_refresh_session(
cookie_jar: &CookieJar,
) -> Result<RefreshOutcome, AuthError> {
// No refresh cookie means there is no session to revive.
let refresh_token = cookie_jar
.get("refresh_token")
.map(|c| c.value().to_string())
.ok_or_else(|| unauthorized("You are not logged in, please provide token"))?;
// Signature/claims check. The verifier's error is echoed to the client in Debug form.
let refresh_details =
token::verify_jwt_token(&data.conf.refresh_token_public_key, &refresh_token)
.map_err(|e| unauthorized(format!("{e:?}")))?;
let mut redis_client = data
.map_err(|e| internal(format!("Redis error: {e}")))?;
// The refresh token's UUID must still be a Redis key; its value is the user id.
let user_id_str: String = redis_client
.get(refresh_details.token_uuid.to_string())
.map_err(|_| unauthorized("Token is invalid or session has expired"))?;
let user_id = uuid::Uuid::parse_str(&user_id_str)
// Fresh access token for the same user.
let access_details = token::generate_jwt_token(
user_id,
data.conf.access_token_max_age,
&data.conf.access_token_private_key,
.map_err(|e| internal(format!("error generating token: {e}")))?;
// Re-issued refresh token. It presumably keeps the existing refresh UUID (per the helper's
// name and the UUID passed in), so the Redis key below stays valid. Confirm in `token`.
let new_refresh_details = token::generate_jwt_token_with_uuid(
refresh_details.token_uuid,
data.conf.refresh_token_max_age,
&data.conf.refresh_token_private_key,
// Register the new access token: key = access UUID, value = user id, TTL in seconds.
let _: bool = redis_client
.set_ex(
access_details.token_uuid.to_string(),
user_id.to_string(),
(data.conf.access_token_max_age * 60) as u64,
// Reset the refresh key's TTL (seconds) to the full refresh lifetime.
.expire(
refresh_details.token_uuid.to_string(),
data.conf.refresh_token_max_age * 60,
let access_cookie = Cookie::build((
"access_token",
access_details.token.clone().unwrap_or_default(),
))
.path("/")
// NOTE(review): the Redis TTL above treats `access_token_max_age * 60` as *seconds*, but
// here the same product is passed to `Duration::minutes`. The cookie's Max-Age is therefore
// ~60x longer than the token's Redis entry. If `*_max_age` is in minutes, this should be
// `time::Duration::minutes(data.conf.access_token_max_age)`. Confirm units.
.max_age(time::Duration::minutes(data.conf.access_token_max_age * 60))
.same_site(SameSite::Lax)
.http_only(true)
.build();
// NOTE(review): the refresh cookie's Max-Age below has the same minutes-vs-seconds concern.
let refresh_cookie = Cookie::build((
"refresh_token",
new_refresh_details.token.unwrap_or_default(),
.max_age(time::Duration::minutes(
// Flag cookie that is not HttpOnly, so client-side scripts can read it.
let logged_in_cookie = Cookie::build(("logged_in", "true"))
.http_only(false)
// Serialize each cookie into a `Set-Cookie` header value.
let set_cookies = [access_cookie, refresh_cookie, logged_in_cookie]
.iter()
// NOTE(review): `unwrap` panics if a serialized cookie is not a valid header value.
// Mapping the error to `internal(...)` would turn that into a 500 instead.
.map(|c| c.to_string().parse().unwrap())
.collect();
Ok(RefreshOutcome {
access_token_uuid: access_details.token_uuid,
set_cookies,
pub async fn auth(
cookie_jar: CookieJar,
State(data): State<Arc<AppState>>,
mut req: Request<Body>,
next: Next,
) -> Result<Response, AuthError> {
let access_token = extract_access_token(&cookie_jar, &req);
let (user_id, access_token_uuid, refreshed_cookies) =
match resolve_access(&data, access_token.as_deref()).await {
Some((uid, atid)) => (uid, atid, Vec::new()),
None => {
let outcome = try_refresh_session(&data, &cookie_jar).await?;
outcome.user_id,
outcome.access_token_uuid,
outcome.set_cookies,
let mut conn = get_connection()
.map_err(|e| internal(format!("DB error: {e}")))?;
let user = sqlx::query_as!(User, "SELECT id, user_name as name, email, user_password as password, user_role as role, photo, verified, db_name as database, created_at, updated_at FROM users WHERE id = $1", user_id)
.fetch_optional(&mut *conn)
.map_err(|e| internal(format!("Error fetching user from database: {e}")))?
.ok_or_else(|| unauthorized("The user belonging to this token no longer exists"))?;
if !user.verified {
return Err(unauthorized("The user is not verified yet"));
req.extensions_mut().insert(JWTAuthMiddleware {
user,
access_token_uuid,
});
let mut response = next.run(req).await;
for cookie in refreshed_cookies {
response.headers_mut().append(header::SET_COOKIE, cookie);
Ok(response)
pub async fn admin(
Extension(jwtauth): Extension<JWTAuthMiddleware>,
req: Request<Body>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
if jwtauth.user.role != "admin" {
let error_response = ErrorResponse {
message: "The user does not have admin permissions".to_string(),
return Err((StatusCode::UNAUTHORIZED, Json(error_response)));
Ok(next.run(req).await)