web/
redirect_middleware.rs1use axum::{
2 body::Body,
3 http::{Request, StatusCode, header},
4 middleware::Next,
5 response::{IntoResponse, Redirect, Response},
6};
7
8pub async fn redirect_on_auth_error(req: Request<Body>, next: Next) -> Response {
9 let accepts_html = req
10 .headers()
11 .get(header::ACCEPT)
12 .and_then(|h| h.to_str().ok())
13 .is_some_and(|accept| accept.contains("text/html"));
14
15 let response = next.run(req).await;
16
17 let should_redirect = accepts_html
19 && (response.status() == StatusCode::UNAUTHORIZED
20 || response.status() == StatusCode::FORBIDDEN);
21
22 if should_redirect {
23 Redirect::to("/").into_response()
24 } else {
25 response
26 }
27}
28
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Router, http::Method, middleware, response::Json, routing::get};
    use serde_json::json;
    use tower::ServiceExt;

    async fn mock_handler_401() -> (StatusCode, Json<serde_json::Value>) {
        (
            StatusCode::UNAUTHORIZED,
            Json(json!({"status": "fail", "message": "Unauthorized"})),
        )
    }

    async fn mock_handler_403() -> (StatusCode, Json<serde_json::Value>) {
        (
            StatusCode::FORBIDDEN,
            Json(json!({"status": "fail", "message": "Forbidden"})),
        )
    }

    async fn mock_handler_200() -> &'static str {
        "OK"
    }

    /// Mounts `route` at `/test` behind the middleware and sends a single
    /// `GET /test`, with the given `Accept` header when `accept` is `Some`.
    async fn send(route: axum::routing::MethodRouter, accept: Option<&str>) -> Response {
        let app = Router::new()
            .route("/test", route)
            .layer(middleware::from_fn(redirect_on_auth_error));

        let mut builder = Request::builder().method(Method::GET).uri("/test");
        if let Some(accept) = accept {
            builder = builder.header(header::ACCEPT, accept);
        }
        let request = builder.body(Body::empty()).unwrap();

        app.oneshot(request).await.unwrap()
    }

    #[tokio::test]
    async fn test_redirect_on_html_401() {
        let response = send(
            get(mock_handler_401),
            Some("text/html,application/xhtml+xml"),
        )
        .await;

        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert_eq!(response.headers().get(header::LOCATION).unwrap(), "/");
    }

    #[tokio::test]
    async fn test_redirect_on_html_403() {
        let response = send(get(mock_handler_403), Some("text/html")).await;

        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        assert_eq!(response.headers().get(header::LOCATION).unwrap(), "/");
    }

    #[tokio::test]
    async fn test_preserve_json_401() {
        let response = send(get(mock_handler_401), Some("application/json")).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert!(response.headers().get(header::LOCATION).is_none());
    }

    #[tokio::test]
    async fn test_preserve_401_without_accept() {
        let response = send(get(mock_handler_401), None).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert!(response.headers().get(header::LOCATION).is_none());
    }

    #[tokio::test]
    async fn test_passthrough_200() {
        let response = send(get(mock_handler_200), Some("text/html")).await;

        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.headers().get(header::LOCATION).is_none());
    }
}