// Test coverage summary (from coverage report):
//   Lines:     100 %
//   Functions:  75 %
//   Branches:   (no figure recorded)
use axum::{
body::Body,
http::{Request, StatusCode, header},
middleware::Next,
response::{IntoResponse, Redirect, Response},
};
pub async fn redirect_on_auth_error(req: Request<Body>, next: Next) -> Response {
let accepts_html = req
.headers()
.get(header::ACCEPT)
.and_then(|h| h.to_str().ok())
.is_some_and(|accept| accept.contains("text/html"));
let response = next.run(req).await;
// Check if this is an authentication error that should redirect HTML requests
let should_redirect = accepts_html
&& (response.status() == StatusCode::UNAUTHORIZED
|| response.status() == StatusCode::FORBIDDEN);
if should_redirect {
Redirect::to("/").into_response()
} else {
response
}
#[cfg(test)]
mod tests {
use super::*;
use axum::{Router, http::Method, middleware, response::Json, routing::get};
use serde_json::json;
use tower::ServiceExt;
async fn mock_handler_401() -> (StatusCode, Json<serde_json::Value>) {
(
StatusCode::UNAUTHORIZED,
Json(json!({"status": "fail", "message": "Unauthorized"})),
)
/// Test handler that always succeeds with a plain-text `"OK"` body.
async fn mock_handler_200() -> &'static str {
    "OK"
}
#[tokio::test]
async fn test_redirect_on_html_401() {
let app = Router::new()
.route("/test", get(mock_handler_401))
.layer(middleware::from_fn(redirect_on_auth_error));
let request = Request::builder()
.method(Method::GET)
.uri("/test")
.header(header::ACCEPT, "text/html,application/xhtml+xml")
.body(Body::empty())
.unwrap();
let response = app.oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::SEE_OTHER);
assert_eq!(response.headers().get(header::LOCATION).unwrap(), "/");
async fn test_preserve_json_401() {
.header(header::ACCEPT, "application/json")
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
assert!(response.headers().get(header::LOCATION).is_none());
async fn test_passthrough_200() {
.route("/test", get(mock_handler_200))
.header(header::ACCEPT, "text/html")
assert_eq!(response.status(), StatusCode::OK);