use crate::auth::{OAuthClient, UserSession, refresh_access_token_internal}; use crate::config::Config; use crate::gsd::service::GSDService; use crate::repository::RedisRepository; use crate::util::set_and_log_session; use axum::extract::{Request, State}; use axum::http::{HeaderValue, StatusCode}; use axum::middleware::Next; use axum::response::{IntoResponse, Response}; use axum_extra::extract::CookieJar; use log::{error, info, warn}; use std::sync::Arc; #[derive(Clone)] pub struct AppState { pub config: Config, pub repository: RedisRepository, pub gsd_service: GSDService, pub oauth_client: OAuthClient, pub frontend_url: String, } /// Middleware to validate session and refresh tokens if needed pub async fn session_auth_middleware( jar: CookieJar, State(state): State>, mut request: Request, next: Next, ) -> Response { // 1. Extract session ID from cookie let session_id = match jar.get("session_id") { Some(cookie) => cookie.value().to_string(), None => { warn!("No session cookie found"); return (StatusCode::UNAUTHORIZED, "No session cookie").into_response(); } }; // 2. Find session in Redis let session_key = format!("user_session:{}", session_id); let session_json = match state.repository.get(&session_key).await { Ok(Some(json)) => json, Ok(None) => { warn!("Session not found in Redis: {}", session_id); return (StatusCode::UNAUTHORIZED, "Session expired or invalid").into_response(); } Err(e) => { error!("Redis error while fetching session: {:?}", e); return (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response(); } }; // 3. Parse session data let mut user_session: UserSession = match serde_json::from_str(&session_json) { Ok(session) => session, Err(e) => { error!("Failed to parse session JSON: {:?}", e); return (StatusCode::INTERNAL_SERVER_ERROR, "Invalid session data").into_response(); } }; // 4. Check if access token is expired let now = chrono::Utc::now().timestamp(); if user_session.expires_at <= now { info!( "Access token expired for session {}, attempting refresh", session_id ); // 5. Refresh the access token using refresh token match refresh_access_token_internal( &state.oauth_client, &state.repository, &session_id, &mut user_session, ) .await { Ok(new_access_token) => { info!( "Successfully refreshed access token for session {}", session_id ); user_session.access_token = new_access_token; } Err(e) => { error!("Failed to refresh access token: {}", e); // Clean up invalid session let _ = state.repository.delete(&session_key).await; return ( StatusCode::UNAUTHORIZED, "Session expired, please login again", ) .into_response(); } } } else { info!( "Access token still valid for session {} (expires in {} seconds)", session_id, user_session.expires_at - now ); } // 6. Attach validated access token to request for downstream handlers match HeaderValue::from_str(format!("Bearer {}", &user_session.access_token).as_str()) { Ok(header_value) => { request.headers_mut().insert("authorization", header_value); } Err(e) => { error!("Failed to create authorization header: {:?}", e); return (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response(); } } // 7. Pass the request to the next handler next.run(request).await } pub async fn gsd_decorate_header( State(state): State>, mut request: Request, next: Next, ) -> Response { let state_cloned = state.clone(); let session = state_cloned.repository.get_session().await; match session { Ok(session) => { let session_value; if session.is_none() { match state_cloned.gsd_service.get_session().await { Ok(session) => { session_value = session.clone(); set_and_log_session(&state_cloned, session.clone()).await; } Err(error) => { error!("Error getting session: {:?}", error); return StatusCode::UNAUTHORIZED.into_response(); } } } else { session_value = session.unwrap(); } request.headers_mut().insert( "sessionId", HeaderValue::from_str(session_value.as_str()).unwrap(), ); } Err(error) => { error!( "Redis error occured during fetching current session id. Error: {}", error ); return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } } request.headers_mut().insert( "appkey", HeaderValue::from_str(state_cloned.config.gsd_app_key.as_str()).unwrap(), ); next.run(request).await }