Skip to main content

web/
redirect_middleware.rs

1use axum::{
2    body::Body,
3    http::{Request, StatusCode, header},
4    middleware::Next,
5    response::{IntoResponse, Redirect, Response},
6};
7
8pub async fn redirect_on_auth_error(req: Request<Body>, next: Next) -> Response {
9    let accepts_html = req
10        .headers()
11        .get(header::ACCEPT)
12        .and_then(|h| h.to_str().ok())
13        .is_some_and(|accept| accept.contains("text/html"));
14
15    let response = next.run(req).await;
16
17    // Check if this is an authentication error that should redirect HTML requests
18    let should_redirect = accepts_html
19        && (response.status() == StatusCode::UNAUTHORIZED
20            || response.status() == StatusCode::FORBIDDEN);
21
22    if should_redirect {
23        Redirect::to("/").into_response()
24    } else {
25        response
26    }
27}
28
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Router, http::Method, middleware, response::Json, routing::get};
    use serde_json::json;
    use tower::ServiceExt;

    /// Handler that always answers `401 Unauthorized` with a JSON error body.
    async fn mock_handler_401() -> (StatusCode, Json<serde_json::Value>) {
        (
            StatusCode::UNAUTHORIZED,
            Json(json!({"status": "fail", "message": "Unauthorized"})),
        )
    }

    /// Handler that always answers `403 Forbidden` with a JSON error body.
    async fn mock_handler_403() -> (StatusCode, Json<serde_json::Value>) {
        (
            StatusCode::FORBIDDEN,
            Json(json!({"status": "fail", "message": "Forbidden"})),
        )
    }

    /// Handler that always succeeds with a plain-text body.
    async fn mock_handler_200() -> &'static str {
        "OK"
    }

    /// Builds a router with one route per mock handler, wrapped in the middleware under test.
    fn test_app() -> Router {
        Router::new()
            .route("/unauthorized", get(mock_handler_401))
            .route("/forbidden", get(mock_handler_403))
            .route("/ok", get(mock_handler_200))
            .layer(middleware::from_fn(redirect_on_auth_error))
    }

    /// Sends a body-less `GET` to `uri` through a fresh `test_app()`.
    ///
    /// `accept` is used as the `Accept` header value; `None` omits the header entirely.
    async fn send_get(uri: &str, accept: Option<&str>) -> Response {
        let mut builder = Request::builder().method(Method::GET).uri(uri);
        if let Some(accept) = accept {
            builder = builder.header(header::ACCEPT, accept);
        }
        let request = builder
            .body(Body::empty())
            .expect("test request should be well-formed");

        test_app()
            .oneshot(request)
            .await
            .expect("router service should not fail")
    }

    #[tokio::test]
    async fn test_redirect_on_html_401() {
        let response = send_get("/unauthorized", Some("text/html,application/xhtml+xml")).await;

        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert_eq!(response.headers().get(header::LOCATION).unwrap(), "/");
    }

    // The middleware treats 403 the same as 401; this branch was previously untested.
    #[tokio::test]
    async fn test_redirect_on_html_403() {
        let response = send_get("/forbidden", Some("text/html,application/xhtml+xml")).await;

        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert_eq!(response.headers().get(header::LOCATION).unwrap(), "/");
    }

    #[tokio::test]
    async fn test_preserve_json_401() {
        let response = send_get("/unauthorized", Some("application/json")).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert!(response.headers().get(header::LOCATION).is_none());
    }

    // A request with no `Accept` header must not be treated as an HTML request.
    #[tokio::test]
    async fn test_preserve_401_without_accept_header() {
        let response = send_get("/unauthorized", None).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert!(response.headers().get(header::LOCATION).is_none());
    }

    #[tokio::test]
    async fn test_passthrough_200() {
        let response = send_get("/ok", Some("text/html")).await;

        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.headers().get(header::LOCATION).is_none());
    }
}