1
use axum::{
2
    body::Body,
3
    http::{Request, StatusCode, header},
4
    middleware::Next,
5
    response::{IntoResponse, Redirect, Response},
6
};
7

            
8
12
pub async fn redirect_on_auth_error(req: Request<Body>, next: Next) -> Response {
9
12
    let accepts_html = req
10
12
        .headers()
11
12
        .get(header::ACCEPT)
12
12
        .and_then(|h| h.to_str().ok())
13
12
        .is_some_and(|accept| accept.contains("text/html"));
14

            
15
12
    let response = next.run(req).await;
16

            
17
    // Check if this is an authentication error that should redirect HTML requests
18
12
    let should_redirect = accepts_html
19
4
        && (response.status() == StatusCode::UNAUTHORIZED
20
2
            || response.status() == StatusCode::FORBIDDEN);
21

            
22
12
    if should_redirect {
23
2
        Redirect::to("/").into_response()
24
    } else {
25
10
        response
26
    }
27
12
}
28

            
29
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        Router,
        http::Method,
        middleware,
        response::Json,
        routing::{MethodRouter, get},
    };
    use serde_json::json;
    use tower::ServiceExt;

    /// Handler that always rejects with a JSON `401 Unauthorized` body.
    async fn mock_handler_401() -> (StatusCode, Json<serde_json::Value>) {
        (
            StatusCode::UNAUTHORIZED,
            Json(json!({"status": "fail", "message": "Unauthorized"})),
        )
    }

    /// Handler that always rejects with a bare `403 Forbidden`.
    async fn mock_handler_403() -> StatusCode {
        StatusCode::FORBIDDEN
    }

    /// Handler that always succeeds.
    async fn mock_handler_200() -> &'static str {
        "OK"
    }

    /// Mounts `route` at `/test` behind the middleware, issues a `GET /test` carrying the
    /// given `Accept` header (or no `Accept` header when `None`), and returns the response.
    async fn send(route: MethodRouter, accept: Option<&str>) -> Response {
        let app = Router::new()
            .route("/test", route)
            .layer(middleware::from_fn(redirect_on_auth_error));

        let mut builder = Request::builder().method(Method::GET).uri("/test");
        if let Some(accept) = accept {
            builder = builder.header(header::ACCEPT, accept);
        }
        let request = builder.body(Body::empty()).unwrap();

        app.oneshot(request).await.unwrap()
    }

    #[tokio::test]
    async fn test_redirect_on_html_401() {
        let response = send(get(mock_handler_401), Some("text/html,application/xhtml+xml")).await;

        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert_eq!(response.headers().get(header::LOCATION).unwrap(), "/");
    }

    #[tokio::test]
    async fn test_redirect_on_html_403() {
        let response = send(get(mock_handler_403), Some("text/html")).await;

        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert_eq!(response.headers().get(header::LOCATION).unwrap(), "/");
    }

    #[tokio::test]
    async fn test_preserve_json_401() {
        let response = send(get(mock_handler_401), Some("application/json")).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert!(response.headers().get(header::LOCATION).is_none());
    }

    #[tokio::test]
    async fn test_preserve_401_without_accept_header() {
        let response = send(get(mock_handler_401), None).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert!(response.headers().get(header::LOCATION).is_none());
    }

    #[tokio::test]
    async fn test_passthrough_200() {
        let response = send(get(mock_handler_200), Some("text/html")).await;

        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.headers().get(header::LOCATION).is_none());
    }
}