1
use std::sync::Arc;
2

            
3
use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier, password_hash::SaltString};
4
use axum::{
5
    Extension, Json,
6
    extract::State,
7
    http::{HeaderMap, Response, StatusCode, header},
8
    response::IntoResponse,
9
};
10
use axum_extra::extract::{
11
    CookieJar,
12
    cookie::{Cookie, SameSite},
13
};
14
use serde_json::json;
15

            
16
use crate::{
17
    AppState,
18
    jwt_auth::JWTAuthMiddleware,
19
    model::{LoginUserSchema, RegisterUserSchema, User},
20
    response::FilteredUser,
21
    token::{self, TokenDetails},
22
};
23
use rust_i18n::t;
24

            
25
use redis::AsyncCommands;
26
use server::{command::CmdResult, db::get_connection};
27

            
28
pub async fn get_salt() -> SaltString {
29
    SaltString::from_rng(&mut rand::rng())
30
}
31

            
32
pub async fn register_user_handler(
33
    State(_data): State<Arc<AppState>>,
34
    Json(body): Json<RegisterUserSchema>,
35
) -> Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
36
    let mut conn = get_connection().await.map_err(|e| {
37
        let msg = format!("{} {}", t!("Database error:"), e);
38
        log::error!("{msg}");
39
        let error_response = serde_json::json!({
40
            "status": "fail",
41
            "message": msg,
42
        });
43
        (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
44
    })?;
45

            
46
    let user_exists: Option<bool> =
47
        sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)")
48
            .bind(body.email.clone().to_ascii_lowercase())
49
            .fetch_one(&mut *conn)
50
            .await
51
            .map_err(|e| {
52
                let msg = format!("{} {}", t!("Database error:"), e);
53
                log::error!("{msg}");
54
                let error_response = serde_json::json!({
55
                    "status": "fail",
56
                    "message": msg,
57
                });
58
                (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
59
            })?;
60

            
61
    if let Some(exists) = user_exists
62
        && exists
63
    {
64
        let msg = format!("{}", t!("User with that email already exists"));
65
        log::error!("{msg}");
66
        let error_response = serde_json::json!({
67
            "status": "fail",
68
            "message": msg,
69
        });
70
        return Err((StatusCode::CONFLICT, Json(error_response)));
71
    }
72

            
73
    let salt = get_salt().await;
74
    let hashed_password = Argon2::default()
75
        .hash_password(body.password.as_bytes(), &salt)
76
        .map_err(|e| {
77
            let msg = format!("{} {}", t!("Error while hashing password:"), e);
78
            log::error!("{msg}");
79
            let error_response = serde_json::json!({
80
                "status": "fail",
81
                "message": msg,
82
            });
83
            (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
84
        })
85
        .map(|hash| hash.to_string())?;
86
    let mut conn = get_connection().await.map_err(|e| {
87
        let msg = format!("{} {}", t!("Database error:"), e);
88
        log::error!("{msg}");
89
        let error_response = serde_json::json!({
90
            "status": "fail",
91
            "message": msg,
92
        });
93
        (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
94
    })?;
95

            
96
    let user = sqlx::query_as!(
97
        User,
98
        "INSERT INTO users (user_name,email,user_password) VALUES ($1, $2, $3) RETURNING id, user_name as name, email, user_password as password, user_role as role, photo, verified, db_name as database, created_at, updated_at",
99
        body.name.to_string(),
100
        body.email.to_string().to_ascii_lowercase(),
101
        hashed_password
102
    )
103
    .fetch_one(&mut *conn)
104
    .await
105
    .map_err(|e| {
106
        let msg = format!("{} {}", t!("Database error:"), e);
107
        log::error!("{msg}");
108
        let error_response = serde_json::json!({
109
            "status": "fail",
110
            "message": msg,
111
        });
112
        (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
113
    })?;
114

            
115
    let user_response = serde_json::json!({"status": "success","data": serde_json::json!({
116
        "user": filter_user_record(&user)
117
    })});
118

            
119
    Ok(Json(user_response))
120
}
121

            
122
pub async fn login_user_handler(
123
    State(data): State<Arc<AppState>>,
124
    Json(body): Json<LoginUserSchema>,
125
) -> Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
126
    let mut conn = get_connection().await.map_err(|e| {
127
        let msg = format!("{} {}", t!("Database error:"), e);
128
        log::error!("{msg}");
129
        let error_response = serde_json::json!({
130
            "status": "fail",
131
            "message": msg,
132
        });
133
        (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
134
    })?;
135

            
136
    let user = sqlx::query_as!(
137
        User,
138
        "SELECT id, user_name as name, email, user_password as password, user_role as role, photo, verified, db_name as database, created_at, updated_at FROM users WHERE email = $1",
139
        body.email.to_ascii_lowercase()
140
    )
141
    .fetch_optional(&mut *conn)
142
    .await
143
    .map_err(|e| {
144
        let msg = format!("{} {}", t!("Database error:"), e);
145
        log::error!("{msg}");
146
        let error_response = serde_json::json!({
147
            "status": "error",
148
            "message": msg,
149
        });
150
        (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
151
    })?
152
    .ok_or_else(|| {
153
        let msg = format!("{}", t!("Invalid email or password"));
154
        log::error!("{msg}");
155
        let error_response = serde_json::json!({
156
            "status": "fail",
157
            "message": msg,
158
        });
159
        (StatusCode::BAD_REQUEST, Json(error_response))
160
    })?;
161

            
162
    let is_valid = match PasswordHash::new(&user.password) {
163
        Ok(parsed_hash) => Argon2::default()
164
            .verify_password(body.password.as_bytes(), &parsed_hash)
165
            .is_ok_and(|()| true),
166
        Err(_) => false,
167
    };
168

            
169
    if !is_valid {
170
        let msg = format!("{}", t!("Invalid email or password"));
171
        log::error!("{msg}");
172
        let error_response = serde_json::json!({
173
            "status": "fail",
174
            "message": msg,
175
        });
176
        return Err((StatusCode::BAD_REQUEST, Json(error_response)));
177
    }
178

            
179
    let access_token_details = generate_token(
180
        user.id,
181
        data.conf.access_token_max_age,
182
        &data.conf.access_token_private_key,
183
    )?;
184
    let refresh_token_details = generate_token(
185
        user.id,
186
        data.conf.refresh_token_max_age,
187
        &data.conf.refresh_token_private_key,
188
    )?;
189

            
190
    save_token_data_to_redis(&data, &access_token_details, data.conf.access_token_max_age).await?;
191
    save_token_data_to_redis(
192
        &data,
193
        &refresh_token_details,
194
        data.conf.refresh_token_max_age,
195
    )
196
    .await?;
197

            
198
    let access_cookie = Cookie::build((
199
        "access_token",
200
        access_token_details.token.clone().unwrap_or_default(),
201
    ))
202
    .path("/")
203
    .max_age(time::Duration::minutes(data.conf.access_token_max_age * 60))
204
    .same_site(SameSite::Lax)
205
    .http_only(true);
206

            
207
    let refresh_cookie = Cookie::build((
208
        "refresh_token",
209
        refresh_token_details.token.unwrap_or_default(),
210
    ))
211
    .path("/")
212
    .max_age(time::Duration::minutes(
213
        data.conf.refresh_token_max_age * 60,
214
    ))
215
    .same_site(SameSite::Lax)
216
    .http_only(true);
217

            
218
    let logged_in_cookie = Cookie::build(("logged_in", "true"))
219
        .path("/")
220
        .max_age(time::Duration::minutes(data.conf.access_token_max_age * 60))
221
        .same_site(SameSite::Lax)
222
        .http_only(false);
223

            
224
    let mut response = Response::new(
225
        json!({"status": "success", "access_token": access_token_details.token.unwrap()})
226
            .to_string(),
227
    );
228
    let mut headers = HeaderMap::new();
229
    headers.append(
230
        header::SET_COOKIE,
231
        access_cookie.to_string().parse().unwrap(),
232
    );
233
    headers.append(
234
        header::SET_COOKIE,
235
        refresh_cookie.to_string().parse().unwrap(),
236
    );
237
    headers.append(
238
        header::SET_COOKIE,
239
        logged_in_cookie.to_string().parse().unwrap(),
240
    );
241
    headers.append("HX-Redirect", "/".to_string().parse().unwrap());
242

            
243
    response.headers_mut().extend(headers);
244
    Ok(response)
245
}
246

            
247
6
pub async fn refresh_access_token_handler(
248
6
    cookie_jar: CookieJar,
249
6
    State(data): State<Arc<AppState>>,
250
9
) -> Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
251
6
    let message = t!("could not refresh access token");
252

            
253
6
    let refresh_token = cookie_jar
254
6
        .get("refresh_token")
255
6
        .map(|cookie| cookie.value().to_string())
256
6
        .ok_or_else(|| {
257
4
            let error_response = serde_json::json!({
258
4
                "status": "fail",
259
4
                "message": message
260
            });
261
4
            (StatusCode::FORBIDDEN, Json(error_response))
262
4
        })?;
263

            
264
    let refresh_token_details =
265
2
        match token::verify_jwt_token(&data.conf.refresh_token_public_key, &refresh_token) {
266
            Ok(token_details) => token_details,
267
2
            Err(e) => {
268
2
                let error_response = serde_json::json!({
269
2
                    "status": "fail",
270
2
                    "message": format_args!("{:?}", e)
271
                });
272
2
                return Err((StatusCode::UNAUTHORIZED, Json(error_response)));
273
            }
274
        };
275

            
276
    let mut redis_client = data
277
        .redis_client
278
        .get_multiplexed_async_connection()
279
        .await
280
        .map_err(|e| {
281
            let msg = format!("{} {}", t!("Redis error:"), e);
282
            log::error!("{msg}");
283
            let error_response = serde_json::json!({
284
                "status": "error",
285
                "message": msg,
286
            });
287
            (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
288
        })?;
289

            
290
    let redis_token_user_id = redis_client
291
        .get::<_, String>(refresh_token_details.token_uuid.to_string())
292
        .await
293
        .map_err(|_| {
294
            let msg = format!("{}", t!("Token is invalid or session has expired"));
295
            log::error!("{msg}");
296
            let error_response = serde_json::json!({
297
                "status": "error",
298
                "message": msg,
299
            });
300
            (StatusCode::UNAUTHORIZED, Json(error_response))
301
        })?;
302

            
303
    let user_id_uuid = uuid::Uuid::parse_str(&redis_token_user_id).map_err(|_| {
304
        let msg = format!("{}", t!("Token is invalid or session has expired"));
305
        log::error!("{msg}");
306
        let error_response = serde_json::json!({
307
            "status": "error",
308
            "message": msg,
309
        });
310
        (StatusCode::UNAUTHORIZED, Json(error_response))
311
    })?;
312

            
313
    let mut conn = get_connection().await.map_err(|e| {
314
        let msg = format!("{} {}", t!("Database error:"), e);
315
        log::error!("{msg}");
316
        let error_response = serde_json::json!({
317
            "status": "fail",
318
            "message": msg,
319
        });
320
        (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
321
    })?;
322

            
323
    let user = sqlx::query_as!(User, "SELECT id, user_name as name, email, user_password as password, user_role as role, photo, verified, db_name as database, created_at, updated_at FROM users WHERE id = $1", user_id_uuid)
324
        .fetch_optional(&mut *conn)
325
        .await
326
        .map_err(|e| {
327
            let msg = format!("{} {}", t!("Error fetching user from database:"), e);
328
            log::error!("{msg}");
329
            let error_response = serde_json::json!({
330
                "status": "fail",
331
                "message": msg,
332
            });
333
            (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
334
        })?;
335

            
336
    let user = user.ok_or_else(|| {
337
        let msg = format!(
338
            "{}",
339
            t!("The user belonging to this token no longer exists")
340
        );
341
        log::error!("{msg}");
342
        let error_response = serde_json::json!({
343
            "status": "fail",
344
            "message": msg,
345
        });
346
        (StatusCode::UNAUTHORIZED, Json(error_response))
347
    })?;
348

            
349
    let access_token_details = generate_token(
350
        user.id,
351
        data.conf.access_token_max_age,
352
        &data.conf.access_token_private_key,
353
    )?;
354

            
355
    save_token_data_to_redis(&data, &access_token_details, data.conf.access_token_max_age).await?;
356

            
357
    let access_cookie = Cookie::build((
358
        "access_token",
359
        access_token_details.token.clone().unwrap_or_default(),
360
    ))
361
    .path("/")
362
    .max_age(time::Duration::minutes(data.conf.access_token_max_age * 60))
363
    .same_site(SameSite::Lax)
364
    .http_only(true);
365

            
366
    let logged_in_cookie = Cookie::build(("logged_in", "true"))
367
        .path("/")
368
        .max_age(time::Duration::minutes(data.conf.access_token_max_age * 60))
369
        .same_site(SameSite::Lax)
370
        .http_only(false);
371

            
372
    let mut response = Response::new(
373
        json!({"status": "success", "access_token": access_token_details.token.unwrap()})
374
            .to_string(),
375
    );
376
    let mut headers = HeaderMap::new();
377
    headers.append(
378
        header::SET_COOKIE,
379
        access_cookie.to_string().parse().unwrap(),
380
    );
381
    headers.append(
382
        header::SET_COOKIE,
383
        logged_in_cookie.to_string().parse().unwrap(),
384
    );
385

            
386
    response.headers_mut().extend(headers);
387
    Ok(response)
388
6
}
389

            
390
pub async fn logout_handler(
391
    cookie_jar: CookieJar,
392
    Extension(auth_guard): Extension<JWTAuthMiddleware>,
393
    State(data): State<Arc<AppState>>,
394
) -> Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
395
    let message = t!("Token is invalid or session has expired");
396

            
397
    let refresh_token = cookie_jar
398
        .get("refresh_token")
399
        .map(|cookie| cookie.value().to_string())
400
        .ok_or_else(|| {
401
            let error_response = serde_json::json!({
402
                "status": "fail",
403
                "message": message
404
            });
405
            (StatusCode::FORBIDDEN, Json(error_response))
406
        })?;
407

            
408
    let refresh_token_details =
409
        match token::verify_jwt_token(&data.conf.refresh_token_public_key, &refresh_token) {
410
            Ok(token_details) => token_details,
411
            Err(e) => {
412
                let error_response = serde_json::json!({
413
                    "status": "fail",
414
                    "message": format_args!("{:?}", e)
415
                });
416
                return Err((StatusCode::UNAUTHORIZED, Json(error_response)));
417
            }
418
        };
419

            
420
    let mut redis_client = data
421
        .redis_client
422
        .get_multiplexed_async_connection()
423
        .await
424
        .map_err(|e| {
425
            let msg = format!("{} {}", t!("Redis error:"), e);
426
            log::error!("{msg}");
427
            let error_response = serde_json::json!({
428
                "status": "error",
429
                "message": msg,
430
            });
431
            (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
432
        })?;
433

            
434
    let _: bool = redis_client
435
        .del(&[
436
            refresh_token_details.token_uuid.to_string(),
437
            auth_guard.access_token_uuid.to_string(),
438
        ])
439
        .await
440
        .map_err(|e| {
441
            let error_response = serde_json::json!({
442
                "status": "error",
443
                "message": format_args!("{:?}", e)
444
            });
445
            (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
446
        })?;
447

            
448
    let access_cookie = Cookie::build(("access_token", ""))
449
        .path("/")
450
        .max_age(time::Duration::minutes(-1))
451
        .same_site(SameSite::Lax)
452
        .http_only(true);
453
    let refresh_cookie = Cookie::build(("refresh_token", ""))
454
        .path("/")
455
        .max_age(time::Duration::minutes(-1))
456
        .same_site(SameSite::Lax)
457
        .http_only(true);
458

            
459
    let logged_in_cookie = Cookie::build(("logged_in", "true"))
460
        .path("/")
461
        .max_age(time::Duration::minutes(-1))
462
        .same_site(SameSite::Lax)
463
        .http_only(false);
464

            
465
    let mut headers = HeaderMap::new();
466
    headers.append(
467
        header::SET_COOKIE,
468
        access_cookie.to_string().parse().unwrap(),
469
    );
470
    headers.append(
471
        header::SET_COOKIE,
472
        refresh_cookie.to_string().parse().unwrap(),
473
    );
474
    headers.append(
475
        header::SET_COOKIE,
476
        logged_in_cookie.to_string().parse().unwrap(),
477
    );
478
    headers.append("HX-Redirect", "/".to_string().parse().unwrap());
479

            
480
    let mut response = Response::new(json!({"status": "success"}).to_string());
481
    response.headers_mut().extend(headers);
482
    Ok(response)
483
}
484

            
485
pub async fn get_me_handler(
486
    Extension(jwtauth): Extension<JWTAuthMiddleware>,
487
) -> Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
488
    let json_response = serde_json::json!({
489
        "status":  "success",
490
        "data": serde_json::json!({
491
            "user": filter_user_record(&jwtauth.user)
492
        })
493
    });
494

            
495
    Ok(Json(json_response))
496
}
497

            
498
6
pub async fn get_version() -> Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
499
4
    if let Some(CmdResult::String(version)) = server::command::config::GetVersion::new()
500
4
        .run()
501
4
        .await
502
4
        .map_err(|_| {
503
            (
504
                StatusCode::INTERNAL_SERVER_ERROR,
505
                Json(json!("Can't get version")),
506
            )
507
        })?
508
4
        && let Some(CmdResult::String(build_date)) = server::command::config::GetBuildDate::new()
509
4
            .run()
510
4
            .await
511
4
            .map_err(|_| {
512
                (
513
                    StatusCode::INTERNAL_SERVER_ERROR,
514
                    Json(json!("Can't get version")),
515
                )
516
            })?
517
    {
518
4
        Ok(Response::new(format!(
519
4
            "<span class=\"version\">{version}</span><span class=\"build_date\" data-iso=\"{build_date}\"><script>
520
4
    const el = document.currentScript.parentElement;
521
4
    el.textContent = new Date(el.dataset.iso.trim()).toLocaleString();
522
4
  </script></span>"
523
4
        )))
524
    } else {
525
        Ok(Response::new("Unversioned".to_string()))
526
    }
527
4
}
528

            
529
3
pub async fn get_logout_link() -> Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
530
2
    Ok(Response::new(
531
2
        "<a hx-get=\"/api/auth/logout\">Logout</a>".to_string(),
532
2
    ))
533
2
}
534

            
535
pub async fn get_home_link() -> Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
536
    Ok(Response::new("<a href='/'>Back to main</a>".to_string()))
537
}
538

            
539
fn filter_user_record(user: &User) -> FilteredUser {
540
    FilteredUser {
541
        id: user.id.to_string(),
542
        email: user.email.clone(),
543
        name: user.name.clone(),
544
        photo: user.photo.clone(),
545
        role: user.role.clone(),
546
        verified: user.verified,
547
        createdAt: user.created_at.unwrap(),
548
        updatedAt: user.updated_at.unwrap(),
549
    }
550
}
551

            
552
fn generate_token(
553
    user_id: uuid::Uuid,
554
    max_age: i64,
555
    private_key: &str,
556
) -> Result<TokenDetails, (StatusCode, Json<serde_json::Value>)> {
557
    token::generate_jwt_token(user_id, max_age, private_key).map_err(|e| {
558
        let msg = format!("{} {}", t!("error generating token:"), e);
559
        log::error!("{msg}");
560
        let error_response = serde_json::json!({
561
            "status": "error",
562
            "message": msg,
563
        });
564
        (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
565
    })
566
}
567

            
568
async fn save_token_data_to_redis(
569
    data: &Arc<AppState>,
570
    token_details: &TokenDetails,
571
    max_age: i64,
572
) -> Result<(), (StatusCode, Json<serde_json::Value>)> {
573
    let mut redis_client = data
574
        .redis_client
575
        .get_multiplexed_async_connection()
576
        .await
577
        .map_err(|e| {
578
            let msg = format!("{} {}", t!("Redis error:"), e);
579
            log::error!("{msg}");
580
            let error_response = serde_json::json!({
581
                "status": "error",
582
                "message": msg,
583
            });
584
            (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response))
585
        })?;
586
    let _: bool = redis_client
587
        .set_ex(
588
            token_details.token_uuid.to_string(),
589
            token_details.user_id.to_string(),
590
            (max_age * 60) as u64,
591
        )
592
        .await
593
        .map_err(|e| {
594
            let error_response = serde_json::json!({
595
                "status": "error",
596
                "message": format_args!("{}", e),
597
            });
598
            (StatusCode::UNPROCESSABLE_ENTITY, Json(error_response))
599
        })?;
600
    Ok(())
601
}