1
use std::sync::Arc;
2

            
3
use axum::{
4
    Extension, Json,
5
    body::Body,
6
    extract::State,
7
    http::{Request, StatusCode, header},
8
    middleware::Next,
9
    response::IntoResponse,
10
};
11

            
12
use axum_extra::extract::cookie::CookieJar;
13
use serde::{Deserialize, Serialize};
14

            
15
use crate::{AppState, model::User, token};
16
use redis::AsyncCommands;
17
use server::db::get_connection;
18

            
19
#[derive(Debug, Serialize)]
20
pub struct ErrorResponse {
21
    pub status: &'static str,
22
    pub message: String,
23
}
24

            
25
#[derive(Debug, Serialize, Deserialize, Clone)]
26
pub struct JWTAuthMiddleware {
27
    pub user: User,
28
    pub access_token_uuid: uuid::Uuid,
29
}
30

            
31
22
pub async fn auth(
32
22
    cookie_jar: CookieJar,
33
22
    State(data): State<Arc<AppState>>,
34
22
    mut req: Request<Body>,
35
22
    next: Next,
36
33
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
37
22
    let access_token = cookie_jar
38
22
        .get("access_token")
39
22
        .map(|cookie| cookie.value().to_string())
40
22
        .or_else(|| {
41
22
            req.headers()
42
22
                .get(header::AUTHORIZATION)
43
22
                .and_then(|auth_header| auth_header.to_str().ok())
44
22
                .and_then(|auth_value| {
45
                    auth_value
46
                        .strip_prefix("Bearer ")
47
                        .map(std::borrow::ToOwned::to_owned)
48
                })
49
22
        });
50

            
51
22
    let access_token = access_token.ok_or_else(|| {
52
22
        let error_response = ErrorResponse {
53
22
            status: "fail",
54
22
            message: "You are not logged in, please provide token".to_string(),
55
22
        };
56
22
        (StatusCode::UNAUTHORIZED, Json(error_response))
57
22
    })?;
58

            
59
    let access_token_details =
60
        match token::verify_jwt_token(&data.conf.access_token_public_key, &access_token) {
61
            Ok(token_details) => token_details,
62
            Err(e) => {
63
                let error_response = ErrorResponse {
64
                    status: "fail",
65
                    message: format!("{e:?}"),
66
                };
67
                return Err((StatusCode::UNAUTHORIZED, Json(error_response)));
68
            }
69
        };
70

            
71
    let access_token_uuid = uuid::Uuid::parse_str(&access_token_details.token_uuid.to_string())
72
        .map_err(|_| {
73
            let error_response = ErrorResponse {
74
                status: "fail",
75
                message: "Invalid token".to_string(),
76
            };
77
            (StatusCode::UNAUTHORIZED, Json(error_response))
78
        })?;
79

            
80
    let mut redis_client = data
81
        .redis_client
82
        .get_multiplexed_async_connection()
83
        .await
84
        .map_err(|e| {
85
            let error_response = ErrorResponse {
86
                status: "error",
87
                message: format!("Redis error: {e}"),
88
            };
89
            (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
90
        })?;
91

            
92
    let redis_token_user_id = redis_client
93
        .get::<_, String>(access_token_uuid.clone().to_string())
94
        .await
95
        .map_err(|_| {
96
            let error_response = ErrorResponse {
97
                status: "error",
98
                message: "Token is invalid or session has expired".to_string(),
99
            };
100
            (StatusCode::UNAUTHORIZED, Json(error_response))
101
        })?;
102

            
103
    let user_id_uuid = uuid::Uuid::parse_str(&redis_token_user_id).map_err(|_| {
104
        let error_response = ErrorResponse {
105
            status: "fail",
106
            message: "Token is invalid or session has expired".to_string(),
107
        };
108
        (StatusCode::UNAUTHORIZED, Json(error_response))
109
    })?;
110

            
111
    let mut conn = get_connection().await.map_err(|e| {
112
        let error_response = ErrorResponse {
113
            status: "error",
114
            message: format!("DB error: {e}"),
115
        };
116
        (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
117
    })?;
118

            
119
    let user = sqlx::query_as!(User, "SELECT id, user_name as name, email, user_password as password, user_role as role, photo, verified, db_name as database, created_at, updated_at FROM users WHERE id = $1", user_id_uuid)
120
        .fetch_optional(&mut *conn)
121
        .await
122
        .map_err(|e| {
123
            let error_response = ErrorResponse {
124
                status: "fail",
125
                message: format!("Error fetching user from database: {e}"),
126
            };
127
            (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
128
        })?;
129

            
130
    let user = user.ok_or_else(|| {
131
        let error_response = ErrorResponse {
132
            status: "fail",
133
            message: "The user belonging to this token no longer exists".to_string(),
134
        };
135
        (StatusCode::UNAUTHORIZED, Json(error_response))
136
    })?;
137

            
138
    if !user.verified {
139
        let error_response = ErrorResponse {
140
            status: "fail",
141
            message: "The user is not verified yet".to_string(),
142
        };
143
        return Err((StatusCode::UNAUTHORIZED, Json(error_response)));
144
    }
145

            
146
    req.extensions_mut().insert(JWTAuthMiddleware {
147
        user,
148
        access_token_uuid,
149
    });
150

            
151
    Ok(next.run(req).await)
152
22
}
153

            
154
pub async fn admin(
155
    Extension(jwtauth): Extension<JWTAuthMiddleware>,
156
    req: Request<Body>,
157
    next: Next,
158
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
159
    if jwtauth.user.role != "admin" {
160
        let error_response = ErrorResponse {
161
            status: "fail",
162
            message: "The user does not have admin permissions".to_string(),
163
        };
164
        return Err((StatusCode::UNAUTHORIZED, Json(error_response)));
165
    }
166

            
167
    Ok(next.run(req).await)
168
}