1
use std::sync::Arc;
2

            
3
use axum::{
4
    Extension, Json,
5
    body::Body,
6
    extract::State,
7
    http::{HeaderValue, Request, StatusCode, header},
8
    middleware::Next,
9
    response::{IntoResponse, Response},
10
};
11

            
12
use axum_extra::extract::{
13
    CookieJar,
14
    cookie::{Cookie, SameSite},
15
};
16
use serde::{Deserialize, Serialize};
17

            
18
use crate::{AppState, model::User, token};
19
use redis::AsyncCommands;
20
use server::db::get_connection;
21

            
22
#[derive(Debug, Serialize)]
23
pub struct ErrorResponse {
24
    pub status: &'static str,
25
    pub message: String,
26
}
27

            
28
#[derive(Debug, Serialize, Deserialize, Clone)]
29
pub struct JWTAuthMiddleware {
30
    pub user: User,
31
    pub access_token_uuid: uuid::Uuid,
32
}
33

            
34
type AuthError = (StatusCode, Json<ErrorResponse>);
35

            
36
12
fn unauthorized(message: impl Into<String>) -> AuthError {
37
12
    (
38
12
        StatusCode::UNAUTHORIZED,
39
12
        Json(ErrorResponse {
40
12
            status: "fail",
41
12
            message: message.into(),
42
12
        }),
43
12
    )
44
12
}
45

            
46
fn internal(message: impl Into<String>) -> AuthError {
47
    (
48
        StatusCode::INTERNAL_SERVER_ERROR,
49
        Json(ErrorResponse {
50
            status: "error",
51
            message: message.into(),
52
        }),
53
    )
54
}
55

            
56
12
fn extract_access_token(cookie_jar: &CookieJar, req: &Request<Body>) -> Option<String> {
57
12
    cookie_jar
58
12
        .get("access_token")
59
12
        .map(|cookie| cookie.value().to_string())
60
12
        .or_else(|| {
61
12
            req.headers()
62
12
                .get(header::AUTHORIZATION)
63
12
                .and_then(|auth_header| auth_header.to_str().ok())
64
12
                .and_then(|auth_value| {
65
                    auth_value
66
                        .strip_prefix("Bearer ")
67
                        .map(std::borrow::ToOwned::to_owned)
68
                })
69
12
        })
70
12
}
71

            
72
12
async fn resolve_access(
73
12
    data: &Arc<AppState>,
74
12
    access_token: Option<&str>,
75
12
) -> Option<(uuid::Uuid, uuid::Uuid)> {
76
12
    let token_str = access_token?;
77
    let details = token::verify_jwt_token(&data.conf.access_token_public_key, token_str).ok()?;
78
    let mut redis = data
79
        .redis_client
80
        .get_multiplexed_async_connection()
81
        .await
82
        .ok()?;
83
    let user_id_str: String = redis.get(details.token_uuid.to_string()).await.ok()?;
84
    let user_id = uuid::Uuid::parse_str(&user_id_str).ok()?;
85
    Some((user_id, details.token_uuid))
86
12
}
87

            
88
struct RefreshOutcome {
89
    user_id: uuid::Uuid,
90
    access_token_uuid: uuid::Uuid,
91
    set_cookies: Vec<HeaderValue>,
92
}
93

            
94
12
async fn try_refresh_session(
95
12
    data: &Arc<AppState>,
96
12
    cookie_jar: &CookieJar,
97
12
) -> Result<RefreshOutcome, AuthError> {
98
12
    let refresh_token = cookie_jar
99
12
        .get("refresh_token")
100
12
        .map(|c| c.value().to_string())
101
12
        .ok_or_else(|| unauthorized("You are not logged in, please provide token"))?;
102

            
103
    let refresh_details =
104
        token::verify_jwt_token(&data.conf.refresh_token_public_key, &refresh_token)
105
            .map_err(|e| unauthorized(format!("{e:?}")))?;
106

            
107
    let mut redis_client = data
108
        .redis_client
109
        .get_multiplexed_async_connection()
110
        .await
111
        .map_err(|e| internal(format!("Redis error: {e}")))?;
112

            
113
    let user_id_str: String = redis_client
114
        .get(refresh_details.token_uuid.to_string())
115
        .await
116
        .map_err(|_| unauthorized("Token is invalid or session has expired"))?;
117

            
118
    let user_id = uuid::Uuid::parse_str(&user_id_str)
119
        .map_err(|_| unauthorized("Token is invalid or session has expired"))?;
120

            
121
    let access_details = token::generate_jwt_token(
122
        user_id,
123
        data.conf.access_token_max_age,
124
        &data.conf.access_token_private_key,
125
    )
126
    .map_err(|e| internal(format!("error generating token: {e}")))?;
127

            
128
    let new_refresh_details = token::generate_jwt_token_with_uuid(
129
        user_id,
130
        refresh_details.token_uuid,
131
        data.conf.refresh_token_max_age,
132
        &data.conf.refresh_token_private_key,
133
    )
134
    .map_err(|e| internal(format!("error generating token: {e}")))?;
135

            
136
    let _: bool = redis_client
137
        .set_ex(
138
            access_details.token_uuid.to_string(),
139
            user_id.to_string(),
140
            (data.conf.access_token_max_age * 60) as u64,
141
        )
142
        .await
143
        .map_err(|e| internal(format!("Redis error: {e}")))?;
144

            
145
    let _: bool = redis_client
146
        .expire(
147
            refresh_details.token_uuid.to_string(),
148
            data.conf.refresh_token_max_age * 60,
149
        )
150
        .await
151
        .map_err(|e| internal(format!("Redis error: {e}")))?;
152

            
153
    let access_cookie = Cookie::build((
154
        "access_token",
155
        access_details.token.clone().unwrap_or_default(),
156
    ))
157
    .path("/")
158
    .max_age(time::Duration::minutes(data.conf.access_token_max_age * 60))
159
    .same_site(SameSite::Lax)
160
    .http_only(true)
161
    .build();
162

            
163
    let refresh_cookie = Cookie::build((
164
        "refresh_token",
165
        new_refresh_details.token.unwrap_or_default(),
166
    ))
167
    .path("/")
168
    .max_age(time::Duration::minutes(
169
        data.conf.refresh_token_max_age * 60,
170
    ))
171
    .same_site(SameSite::Lax)
172
    .http_only(true)
173
    .build();
174

            
175
    let logged_in_cookie = Cookie::build(("logged_in", "true"))
176
        .path("/")
177
        .max_age(time::Duration::minutes(data.conf.access_token_max_age * 60))
178
        .same_site(SameSite::Lax)
179
        .http_only(false)
180
        .build();
181

            
182
    let set_cookies = [access_cookie, refresh_cookie, logged_in_cookie]
183
        .iter()
184
        .map(|c| c.to_string().parse().unwrap())
185
        .collect();
186

            
187
    Ok(RefreshOutcome {
188
        user_id,
189
        access_token_uuid: access_details.token_uuid,
190
        set_cookies,
191
    })
192
12
}
193

            
194
12
pub async fn auth(
195
12
    cookie_jar: CookieJar,
196
12
    State(data): State<Arc<AppState>>,
197
12
    mut req: Request<Body>,
198
12
    next: Next,
199
12
) -> Result<Response, AuthError> {
200
12
    let access_token = extract_access_token(&cookie_jar, &req);
201

            
202
    let (user_id, access_token_uuid, refreshed_cookies) =
203
12
        match resolve_access(&data, access_token.as_deref()).await {
204
            Some((uid, atid)) => (uid, atid, Vec::new()),
205
            None => {
206
12
                let outcome = try_refresh_session(&data, &cookie_jar).await?;
207
                (
208
                    outcome.user_id,
209
                    outcome.access_token_uuid,
210
                    outcome.set_cookies,
211
                )
212
            }
213
        };
214

            
215
    let mut conn = get_connection()
216
        .await
217
        .map_err(|e| internal(format!("DB error: {e}")))?;
218

            
219
    let user = sqlx::query_as!(User, "SELECT id, user_name as name, email, user_password as password, user_role as role, photo, verified, db_name as database, created_at, updated_at FROM users WHERE id = $1", user_id)
220
        .fetch_optional(&mut *conn)
221
        .await
222
        .map_err(|e| internal(format!("Error fetching user from database: {e}")))?
223
        .ok_or_else(|| unauthorized("The user belonging to this token no longer exists"))?;
224

            
225
    if !user.verified {
226
        return Err(unauthorized("The user is not verified yet"));
227
    }
228

            
229
    req.extensions_mut().insert(JWTAuthMiddleware {
230
        user,
231
        access_token_uuid,
232
    });
233

            
234
    let mut response = next.run(req).await;
235
    for cookie in refreshed_cookies {
236
        response.headers_mut().append(header::SET_COOKIE, cookie);
237
    }
238
    Ok(response)
239
12
}
240

            
241
pub async fn admin(
242
    Extension(jwtauth): Extension<JWTAuthMiddleware>,
243
    req: Request<Body>,
244
    next: Next,
245
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
246
    if jwtauth.user.role != "admin" {
247
        let error_response = ErrorResponse {
248
            status: "fail",
249
            message: "The user does not have admin permissions".to_string(),
250
        };
251
        return Err((StatusCode::UNAUTHORIZED, Json(error_response)));
252
    }
253

            
254
    Ok(next.run(req).await)
255
}