Implemented authentication with Keycloak

This commit is contained in:
Dennis Nemec
2025-10-02 20:22:11 +02:00
parent e8954ba5c1
commit b87d7e0268
15 changed files with 1697 additions and 94 deletions

4
.env.example Normal file
View File

@ -0,0 +1,4 @@
POSTGRES_USER="admin"
POSTGRES_PASSWORD="admin"
KEYCLOAK_ADMIN_PASSWORD="admin"
KC_HOSTNAME="localhost"

874
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -5,8 +5,8 @@ edition = "2024"
[dependencies] [dependencies]
axum = "0.8.6" axum = "0.8.6"
axum-keycloak-auth = "0.8.3"
chrono = "0.4.42" chrono = "0.4.42"
http = "1.3.1"
log = "0.4.28" log = "0.4.28"
redis = { version = "0.32.6", features = ["connection-manager", "tokio-comp"] } redis = { version = "0.32.6", features = ["connection-manager", "tokio-comp"] }
reqwest = { version = "0.12.23", features = ["json"] } reqwest = { version = "0.12.23", features = ["json"] }
@ -15,3 +15,6 @@ serde_json = "1.0.145"
simplelog = "0.12.2" simplelog = "0.12.2"
tokio = { version = "1.47.1", features = ["full"] } tokio = { version = "1.47.1", features = ["full"] }
toml = "0.9.7" toml = "0.9.7"
oauth2 = "5.0.0"
uuid = "1.18.1"
axum-extra = { version = "0.10.3", features = ["cookie"] }

View File

@ -3,7 +3,16 @@ host_ip = "127.0.0.1"
host_port = 3000 host_port = 3000
redis_url = "redis://127.0.0.1:6379" redis_url = "redis://127.0.0.1:6379"
gsd_app_key = "GSD-RestApi" gsd_app_key = "GSD-RestApi"
frontend_url = "http://127.0.0.1:3000"
gsd_rest_url = "http://192.168.1.9:8334" gsd_rest_url = "http://192.168.1.9:8334"
gsd_user = "GSDWebServiceTmp" gsd_user = "GSDWebServiceTmp"
gsd_password = "<PASSWORD>" gsd_password = "<PASSWORD>"
gsd_app_names = ["GSD-RestApi"] gsd_app_names = ["GSD-RestApi"]
[keycloak]
realm_url = "http://localhost:8080/realms/master"
client_id = "delivery-app"
client_secret = "<SECRET>"
auth_url = "http://localhost:8080/realms/master/protocol/openid-connect/auth"
token_url = "http://localhost:8080/realms/master/protocol/openid-connect/token"
redirect_url = "http://127.0.0.1:3000/callback"

View File

@ -23,7 +23,7 @@ services:
dockerfile: Dockerfile dockerfile: Dockerfile
container_name: rust-microservice container_name: rust-microservice
ports: ports:
- "8080:8080" - "3000:8080"
environment: environment:
- REDIS_URL=redis://redis:6379 - REDIS_URL=redis://redis:6379
- RUST_LOG=info - RUST_LOG=info
@ -34,6 +34,40 @@ services:
- app-network - app-network
restart: unless-stopped restart: unless-stopped
keycloak_web:
image: quay.io/keycloak/keycloak:23.0.7
container_name: keycloak_web
environment:
KC_DB: postgres
KC_DB_URL: jdbc:postgresql://keycloakdb:5432/keycloak
KC_DB_USERNAME: ${POSTGRES_USER}
KC_DB_PASSWORD: ${POSTGRES_PASSWORD}
KC_HOSTNAME: localhost
KC_HOSTNAME_PORT: 8080
KC_HOSTNAME_STRICT: false
KC_HOSTNAME_STRICT_HTTPS: false
KC_LOG_LEVEL: info
KC_METRICS_ENABLED: true
KC_HEALTH_ENABLED: true
KEYCLOAK_ADMIN: admin
KEYCLOAK_ADMIN_PASSWORD: ${KEYCLOAK_ADMIN_PASSWORD}
command: start-dev
depends_on:
- keycloakdb
ports:
- 8080:8080
keycloakdb:
image: postgres:15
volumes:
- postgres_data:/var/lib/postgresql/data
environment:
POSTGRES_DB: keycloak
POSTGRES_USER: ${POSTGRES_USER}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD}
networks: networks:
app-network: app-network:
driver: bridge driver: bridge
@ -41,3 +75,5 @@ networks:
volumes: volumes:
redis-data: redis-data:
driver: local driver: local
postgres_data:
driver: local

View File

@ -1,26 +1,97 @@
use crate::gsd::dto::GSDResponseDTO;
use crate::middleware::AppState; use crate::middleware::AppState;
use crate::util::set_and_log_session;
use axum::Extension; use axum::Extension;
use axum::extract::Request;
use axum::response::IntoResponse;
use http::StatusCode;
use log::error;
use std::sync::Arc;
use axum::body::Body; use axum::body::Body;
use axum::extract::Request;
use axum::http::{HeaderValue, StatusCode};
use axum::response::IntoResponse;
use log::{error, info};
use std::sync::Arc;
pub async fn handle_post( pub async fn handle_post(
Extension(state): Extension<Arc<AppState>>, Extension(state): Extension<Arc<AppState>>,
request: Request<Body>, request: Request<Body>,
) -> impl IntoResponse { ) -> impl IntoResponse {
match state.clone().gsd_service.forward_post_request(request).await { let cloned_state = state.clone();
Ok(e) => e.text().await.unwrap().into_response(), let (mut parts, body) = request.into_parts();
let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
Err(e) => { let mut forwarded_request = cloned_state
error!("Failed to forward post: {:?}", e); .gsd_service
StatusCode::INTERNAL_SERVER_ERROR.into_response() .forward_post_request(Request::from_parts(
parts.clone(),
Body::from(body_bytes.clone()),
))
.await;
if forwarded_request.is_err() {
error!(
"Failed to forward post: {:?}",
forwarded_request.err().unwrap()
);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
} }
let content_text = forwarded_request.unwrap().text().await;
if content_text.is_err() {
error!("Failed to read content text: {:?}", content_text.err());
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
let content = serde_json::from_str::<GSDResponseDTO>(content_text.as_ref().unwrap().as_str());
if content.is_err() {
error!("Failed to read content json: {:?}", content.err());
error!("Content: {:?}", content_text);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
let content_unwrapped = content.unwrap();
// Invalid session
if content_unwrapped.status.is_some()
&& content_unwrapped.status.unwrap().internal_status == "201"
{
info!("Session invalid. Re-negotiate new session");
match cloned_state.gsd_service.get_session().await {
Ok(session) => {
set_and_log_session(&cloned_state, session.clone()).await;
parts.headers.remove("sessionId");
parts.headers.insert(
"sessionId",
HeaderValue::from_str(session.clone().as_str()).unwrap(),
);
forwarded_request = cloned_state
.gsd_service
.forward_post_request(Request::from_parts(
parts.clone(),
Body::from(body_bytes.clone()),
))
.await;
if let Err(e) = &forwarded_request {
error!("Redis: failed to forward post: {:?}", e);
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
forwarded_request
.unwrap()
.text()
.await
.unwrap()
.into_response()
}
Err(error) => {
error!("Error getting session: {:?}", error);
StatusCode::UNAUTHORIZED.into_response()
}
}
} else {
content_text.unwrap().into_response()
} }
} }
pub async fn handle_login() -> impl IntoResponse { pub async fn handle_login() -> impl IntoResponse {}
}

431
src/auth.rs Normal file
View File

@ -0,0 +1,431 @@
use crate::config::Config;
use crate::middleware::AppState;
use crate::repository::RedisRepository;
use axum::http::{StatusCode, header};
use axum::response::Response;
use axum::{
Router,
extract::{Query, State},
response::{IntoResponse, Redirect},
routing::get,
};
use axum_extra::extract::CookieJar;
use oauth2::basic::{
BasicErrorResponse, BasicRevocationErrorResponse, BasicTokenIntrospectionResponse,
BasicTokenResponse,
};
use oauth2::{
AuthUrl, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken, EndpointNotSet,
EndpointSet, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope, StandardRevocableToken,
TokenResponse, TokenUrl, basic::BasicClient,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use axum::routing::post;
pub type OAuthClient = Client<
BasicErrorResponse,
BasicTokenResponse,
BasicTokenIntrospectionResponse,
StandardRevocableToken,
BasicRevocationErrorResponse,
EndpointSet,
EndpointNotSet,
EndpointNotSet,
EndpointNotSet,
EndpointSet,
>;
pub fn router(state: Arc<AppState>) -> Router {
Router::new()
.route("/login", get(login))
.route("/callback", get(callback))
.route("/logout", post(logout))
.with_state(state)
}
async fn login(State(client): State<Arc<AppState>>) -> impl IntoResponse {
let cloned_client = client.clone();
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let csrf_token = CsrfToken::new_random();
// Store the PKCE verifier in Redis with CSRF token as key
let redis_key = format!("pkce_verifier:{}", csrf_token.secret());
let verifier_secret = pkce_verifier.secret().to_string();
match cloned_client
.repository
.set_with_expiry(&redis_key, &verifier_secret, 600) // 10 minutes expiry
.await
{
Ok(_) => {
let (auth_url, _) = cloned_client
.oauth_client
.authorize_url(|| csrf_token)
.add_scope(Scope::new("openid".to_string()))
.add_scope(Scope::new("profile".to_string()))
.add_scope(Scope::new("email".to_string()))
.set_pkce_challenge(pkce_challenge)
.url();
Redirect::to(auth_url.as_str()).into_response()
}
Err(e) => {
log::error!("Failed to store PKCE verifier in Redis: {:?}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to initiate login",
)
.into_response()
}
}
}
#[derive(Deserialize)]
pub struct Callback {
code: String,
state: String,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct UserSession {
pub(crate) access_token: String,
pub(crate) refresh_token: String,
pub(crate) expires_at: i64,
}
async fn callback(
State(client): State<Arc<AppState>>,
Query(query): Query<Callback>,
) -> impl IntoResponse {
let http_client = reqwest::ClientBuilder::new()
.redirect(reqwest::redirect::Policy::none())
.build()
.expect("Client should build");
let cloned_state = client.clone();
// Retrieve the PKCE verifier from Redis using CSRF token
let redis_key = format!("pkce_verifier:{}", query.state);
let verifier_secret = match cloned_state.repository.get(&redis_key).await {
Ok(Some(secret)) => secret,
Ok(None) => {
log::error!("PKCE verifier not found for state: {}", query.state);
return (
StatusCode::BAD_REQUEST,
"Invalid or expired login session. Please try again.",
)
.into_response();
}
Err(e) => {
log::error!("Failed to retrieve PKCE verifier from Redis: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, "Login failed").into_response();
}
};
// Delete the verifier from Redis (one-time use)
let _ = cloned_state.repository.delete(&redis_key).await;
let pkce_verifier = PkceCodeVerifier::new(verifier_secret);
let token_result = cloned_state
.oauth_client
.exchange_code(AuthorizationCode::new(query.code))
.set_pkce_verifier(pkce_verifier)
.request_async(&http_client)
.await;
match token_result {
Ok(token) => {
let access_token = token.access_token().secret();
let refresh_token = token
.refresh_token()
.map(|rt| rt.secret().to_string())
.unwrap_or_else(|| "No refresh token".to_string());
let expires_at = chrono::Utc::now().timestamp()
+ token
.expires_in()
.map(|d| d.as_secs() as i64)
.unwrap_or(3600);
// ============================================
// 1. GENERATE A UNIQUE SESSION ID
// ============================================
let session_id = uuid::Uuid::now_v7().to_string();
// ============================================
// 2. CREATE THE USER SESSION STRUCT
// ============================================
let user_session = UserSession {
access_token: access_token.clone(),
refresh_token: refresh_token.clone(),
expires_at,
};
// ============================================
// 3. SERIALIZE THE SESSION TO JSON
// ============================================
let session_key = format!("user_session:{}", session_id);
let session_json = match serde_json::to_string(&user_session) {
Ok(json) => json,
Err(e) => {
log::error!("Failed to serialize user session: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, "Login failed").into_response();
}
};
// ============================================
// 4. STORE IN REDIS WITH 24 HOUR EXPIRATION
// This is where the tokens are actually stored!
// ============================================
if let Err(e) = cloned_state
.repository
.set_with_expiry(&session_key, &session_json, 86400) // 86400 = 24 hours
.await
{
log::error!("Failed to store user session in Redis: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, "Login failed").into_response();
}
log::info!("Successfully created session {} for user", session_id);
let cookie = format!(
"session_id={}; Path=/; HttpOnly; SameSite=Lax; Max-Age=86400",
session_id
);
// 4. Redirect to frontend
let redirect_url = format!("{}?login=success", cloned_state.frontend_url);
Response::builder()
.status(StatusCode::FOUND)
.header(header::SET_COOKIE, cookie)
.header(header::LOCATION, redirect_url.clone())
.body::<String>(format!("Redirecting to {}", redirect_url).into())
.unwrap()
.into_response()
}
Err(e) => {
log::error!("Token exchange failed: {:?}", e);
(StatusCode::UNAUTHORIZED, format!("Login failed: {:?}", e)).into_response()
}
}
}
pub fn create_oauth_client(config: &Config) -> OAuthClient {
BasicClient::new(ClientId::new(config.keycloak.client_id.clone()))
.set_client_secret(ClientSecret::new(config.keycloak.client_secret.clone()))
.set_redirect_uri(RedirectUrl::new(config.keycloak.redirect_url.clone()).unwrap())
.set_token_uri(TokenUrl::new(config.keycloak.token_url.clone()).unwrap())
.set_auth_uri(AuthUrl::new(config.keycloak.auth_url.clone()).unwrap())
}
/// Internal helper to refresh access token
pub async fn refresh_access_token_internal(
client: &OAuthClient,
repository: &RedisRepository,
session_id: &str,
user_session: &mut UserSession,
) -> Result<String, String> {
use oauth2::{RefreshToken, TokenResponse};
let http_client = reqwest::ClientBuilder::new()
.redirect(reqwest::redirect::Policy::none())
.build()
.expect("Client should build");
let refresh_token = &user_session.refresh_token;
// Exchange refresh token for new access token
let token_result = client
.exchange_refresh_token(&RefreshToken::new(refresh_token.clone()))
.request_async(&http_client)
.await
.map_err(|e| format!("Token refresh request failed: {:?}", e))?;
// Update session with new tokens
let new_access_token = token_result.access_token().secret().to_string();
user_session.access_token = new_access_token.clone();
// Update refresh token if a new one was provided
if let Some(new_refresh_token) = token_result.refresh_token() {
user_session.refresh_token = new_refresh_token.secret().to_string();
}
// Update expiration time
user_session.expires_at = chrono::Utc::now().timestamp()
+ token_result
.expires_in()
.map(|d| d.as_secs() as i64)
.unwrap_or(3600);
// Save updated session back to Redis
let session_key = format!("user_session:{}", session_id);
let updated_json = serde_json::to_string(&user_session)
.map_err(|e| format!("Failed to serialize session: {:?}", e))?;
repository
.set_with_expiry(&session_key, &updated_json, 86400)
.await
.map_err(|e| format!("Failed to update session in Redis: {:?}", e))?;
Ok(new_access_token)
}
async fn logout(jar: CookieJar, State(oauth_state): State<Arc<AppState>>) -> impl IntoResponse {
// 1. Extract session ID from cookie
let session_id = match jar.get("session_id") {
Some(cookie) => cookie.value().to_string(),
None => {
log::warn!("Logout attempted without session cookie");
return (StatusCode::BAD_REQUEST, "No active session").into_response();
}
};
// 2. Get session from Redis to retrieve tokens
let session_key = format!("user_session:{}", session_id);
let session_json = match oauth_state.repository.get(&session_key).await {
Ok(Some(json)) => json,
Ok(None) => {
log::warn!("Session not found in Redis: {}", session_id);
// Session already gone, just clear cookie
return clear_session_cookie().into_response();
}
Err(e) => {
log::error!("Redis error while fetching session for logout: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, "Logout failed").into_response();
}
};
let user_session: UserSession = match serde_json::from_str(&session_json) {
Ok(session) => session,
Err(e) => {
log::error!("Failed to parse session JSON during logout: {:?}", e);
// Clean up anyway
let _ = oauth_state.repository.delete(&session_key).await;
return clear_session_cookie().into_response();
}
};
// 3. Revoke tokens at Keycloak
let revoke_result = revoke_tokens_at_keycloak(
&oauth_state,
&user_session.access_token,
&user_session.refresh_token,
)
.await;
if let Err(e) = revoke_result {
log::error!("Failed to revoke tokens at Keycloak: {}", e);
// Continue anyway - we'll still delete local session
} else {
log::info!(
"Successfully revoked tokens at Keycloak for session {}",
session_id
);
}
// 4. Delete session from Redis
match oauth_state.repository.delete(&session_key).await {
Ok(_) => {
log::info!("Successfully deleted session {} from Redis", session_id);
}
Err(e) => {
log::error!("Failed to delete session from Redis: {:?}", e);
}
}
// 5. Clear session cookie and respond
clear_session_cookie().into_response()
}
/// Helper function to revoke tokens at Keycloak's revocation endpoint
async fn revoke_tokens_at_keycloak(
oauth_state: &Arc<AppState>,
access_token: &str,
refresh_token: &str,
) -> Result<(), String> {
// Get client credentials from OAuth client
let client_id = oauth_state.oauth_client.client_id().as_str();
let client_secret = oauth_state.config.keycloak.client_secret.as_str();
// Build revocation endpoint URL
// Keycloak's revocation endpoint is typically at:
// {realm_url}/protocol/openid-connect/revoke
let token_url = oauth_state.config.keycloak.token_url.as_str();
// Replace /token with /revoke
let revoke_url = token_url.replace("/token", "/revoke");
log::info!("Revoking tokens at: {}", revoke_url);
let client = reqwest::Client::new();
// Revoke refresh token (this also invalidates the access token)
let revoke_refresh_result = client
.post(&revoke_url)
.form(&[
("token", refresh_token),
("token_type_hint", "refresh_token"),
("client_id", client_id),
("client_secret", client_secret),
])
.send()
.await
.map_err(|e| format!("Failed to send revoke request: {:?}", e))?;
if !revoke_refresh_result.status().is_success() {
let status = revoke_refresh_result.status();
let body = revoke_refresh_result
.text()
.await
.unwrap_or_else(|_| "Unable to read response".to_string());
log::warn!(
"Token revocation returned non-success status {}: {}",
status,
body
);
// Note: Keycloak returns 200 even if token is already invalid, so this is unusual
}
// Optionally also revoke access token explicitly
let revoke_access_result = client
.post(&revoke_url)
.form(&[
("token", access_token),
("token_type_hint", "access_token"),
("client_id", client_id),
("client_secret", client_secret),
])
.send()
.await
.map_err(|e| format!("Failed to send revoke request for access token: {:?}", e))?;
if !revoke_access_result.status().is_success() {
let status = revoke_access_result.status();
let body = revoke_access_result
.text()
.await
.unwrap_or_else(|_| "Unable to read response".to_string());
log::warn!(
"Access token revocation returned non-success status {}: {}",
status,
body
);
}
Ok(())
}
/// Helper function to create a response that clears the session cookie
fn clear_session_cookie() -> Response {
// Set cookie with Max-Age=0 to delete it
let clear_cookie = "session_id=; Path=/; HttpOnly; SameSite=Lax; Max-Age=0";
Response::builder()
.status(StatusCode::OK)
.header(header::SET_COOKIE, clear_cookie)
.body("Logged out successfully".into())
.unwrap()
}

View File

@ -11,12 +11,26 @@ pub struct Config {
pub host_port: u16, pub host_port: u16,
pub redis_url: String, pub redis_url: String,
pub frontend_url: String,
// GSD RestAPI configuration // GSD RestAPI configuration
pub gsd_app_key: String, pub gsd_app_key: String,
pub gsd_rest_url: String, pub gsd_rest_url: String,
pub gsd_user: String, pub gsd_user: String,
pub gsd_password: String, pub gsd_password: String,
pub gsd_app_names: Vec<String>, pub gsd_app_names: Vec<String>,
pub keycloak: Keycloak,
}
#[derive(serde::Serialize, serde::Deserialize, Clone)]
pub struct Keycloak {
pub realm_url: String,
pub client_id: String,
pub client_secret: String,
pub auth_url: String,
pub token_url: String,
pub redirect_url: String,
} }
impl Config { impl Config {
@ -58,8 +72,22 @@ pub fn create_standard_config() -> Config {
redis_url: String::from("redis://127.0.0.1:6379"), redis_url: String::from("redis://127.0.0.1:6379"),
gsd_rest_url: String::from("http://127.0.0.1:8334"), gsd_rest_url: String::from("http://127.0.0.1:8334"),
gsd_app_key: String::from("GSD-RestApi"), gsd_app_key: String::from("GSD-RestApi"),
frontend_url: String::from("http://127.0.0.1:3000"),
gsd_app_names: vec![String::from("GSD-RestApi")], gsd_app_names: vec![String::from("GSD-RestApi")],
gsd_user: String::from("<GSD-USER>"), gsd_user: String::from("<GSD-USER>"),
gsd_password: String::from("<GSD-Password>"), gsd_password: String::from("<GSD-Password>"),
keycloak: Keycloak {
realm_url: String::from("http://127.0.0.1:8080/auth/realms/master"),
client_id: String::from("delivery-backend"),
client_secret: String::from(""),
auth_url: String::from(
"http://127.0.0.1:8080/auth/realms/master/protocol/openid-connect/auth",
),
token_url: String::from(
"http://127.0.0.1:8080/auth/realms/master/protocol/openid-connect/token",
),
redirect_url: String::from("http://127.0.0.1:3000/callback"),
},
} }
} }

2
src/gsd.rs Normal file
View File

@ -0,0 +1,2 @@
pub(crate) mod dto;
pub(crate) mod service;

27
src/gsd/dto.rs Normal file
View File

@ -0,0 +1,27 @@
#[derive(serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GSDLoginRequestDTO {
pub user: String,
pub pass: String,
pub app_names: Vec<String>,
}
#[derive(serde::Deserialize, serde::Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GSDResponseDTO {
pub status: Option<GSDResponseStatusDTO>,
pub data: Option<GSDLoginResponseDataDTO>,
}
#[derive(serde::Deserialize, serde::Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GSDLoginResponseDataDTO {
pub session_id: String,
}
#[derive(serde::Deserialize, serde::Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GSDResponseStatusDTO {
pub internal_status: String,
pub status_message: String,
}

View File

@ -1,37 +1,10 @@
use crate::config::Config;
use crate::gsd::dto::*;
use axum::body::Body; use axum::body::Body;
use axum::extract::Request; use axum::extract::Request;
use log::{error, info}; use log::{error, info};
use crate::config::Config;
use reqwest::Response; use reqwest::Response;
#[derive(serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GSDLoginRequestDTO {
user: String,
pass: String,
app_names: Vec<String>,
}
#[derive(serde::Deserialize, serde::Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GSDLoginResponseDTO {
status: GSDLoginResponseStatusDTO,
data: Option<GSDLoginResponseDataDTO>,
}
#[derive(serde::Deserialize, serde::Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GSDLoginResponseDataDTO {
session_id: String,
}
#[derive(serde::Deserialize, serde::Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GSDLoginResponseStatusDTO {
internal_status: String,
status_message: String,
}
#[derive(Clone)] #[derive(Clone)]
pub struct GSDService { pub struct GSDService {
host_url: String, host_url: String,
@ -50,7 +23,10 @@ pub enum GSDServiceError {
impl GSDService { impl GSDService {
pub async fn get_session(&self) -> Result<String, GSDServiceError> { pub async fn get_session(&self) -> Result<String, GSDServiceError> {
info!("Session: No session found. Generate session from GSD server {}", self.host_url); info!(
"Session: No session found. Generate session from GSD server {}",
self.host_url
);
let dto = GSDLoginRequestDTO { let dto = GSDLoginRequestDTO {
user: self.username.clone(), user: self.username.clone(),
@ -70,31 +46,39 @@ impl GSDService {
GSDServiceError::LoginFailed GSDServiceError::LoginFailed
})?; })?;
let response_dto: GSDLoginResponseDTO = response let response_dto: GSDResponseDTO = response.json().await.map_err(|e| {
.json()
.await
.map_err(|e| {
error!("Session: error request to GSD: {}", e); error!("Session: error request to GSD: {}", e);
GSDServiceError::LoginResponseParsingFailed GSDServiceError::LoginResponseParsingFailed
})?; })?;
if response_dto.status.internal_status != "0" { let response_dto_unwrapped = response_dto.status.unwrap();
error!("Session: error message from GSD: {}", response_dto.status.status_message);
if response_dto_unwrapped.internal_status != "0" {
error!(
"Session: error message from GSD: {}",
response_dto_unwrapped.status_message
);
Err(GSDServiceError::LoginFailed) Err(GSDServiceError::LoginFailed)
} else { } else {
match response_dto.data { match response_dto.data {
Some(data) => { Some(data) => {
info!("Session: successfully obtained session with session id {}", &data.session_id); info!(
"Session: successfully obtained session with session id {}",
&data.session_id
);
Ok(data.session_id.clone()) Ok(data.session_id.clone())
}, }
None => { None => {
error!("Session: failed to obtain session id. No session id in request found."); error!("Session: failed to obtain session id. No session id in request found.");
Err(GSDServiceError::LoginResponseParsingFailed) Err(GSDServiceError::LoginResponseParsingFailed)
}, }
} }
} }
} }
pub async fn forward_post_request(&self, request: Request<Body>) -> Result<Response, GSDServiceError> { pub async fn forward_post_request(
&self,
request: Request<Body>,
) -> Result<Response, GSDServiceError> {
let (parts, body) = request.into_parts(); let (parts, body) = request.into_parts();
reqwest::Client::new() reqwest::Client::new()

View File

@ -1,18 +1,22 @@
use crate::api::{handle_login, handle_post}; use crate::api::handle_post;
use crate::config::load_config; use crate::config::load_config;
use crate::middleware::AppState; use crate::middleware::AppState;
use crate::repository::RedisRepository; use crate::repository::RedisRepository;
use crate::util::initialize_logging; use crate::util::initialize_logging;
use axum::routing::post; use axum::routing::post;
use axum::{Extension, Router}; use axum::{Extension, Router};
use axum_keycloak_auth::instance::{KeycloakAuthInstance, KeycloakConfig};
use axum_keycloak_auth::layer::KeycloakAuthLayer;
use axum_keycloak_auth::{PassthroughMode, Url};
use log::info; use log::info;
use std::sync::Arc; use std::sync::Arc;
mod api; mod api;
mod auth;
mod config; mod config;
mod gsd;
mod middleware; mod middleware;
mod repository; mod repository;
mod service_gsd;
mod util; mod util;
#[tokio::main] #[tokio::main]
@ -26,27 +30,53 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let redis_url = config.redis_url.clone(); let redis_url = config.redis_url.clone();
let host_url = config.get_host_url().clone(); let host_url = config.get_host_url().clone();
info!("Initializing redis server");
let state = Arc::new(AppState { let state = Arc::new(AppState {
config: config.clone(), config: config.clone(),
repository: RedisRepository::try_new(redis_url).await?, repository: RedisRepository::try_new(redis_url).await?,
gsd_service: (&config).into(), gsd_service: (&config).into(),
oauth_client: auth::create_oauth_client(&config),
frontend_url: config.frontend_url.clone(),
}); });
let app = Router::new() info!("Starting axum server");
.route("/login", post(handle_login))
let keycloak_instance: Arc<KeycloakAuthInstance> = Arc::new(KeycloakAuthInstance::new(
KeycloakConfig::builder()
.server(Url::parse("http://localhost:8080/").unwrap())
.realm(String::from("master"))
.build(),
));
let auth_router = auth::router(state.clone());
let proxy_router = Router::new()
.route("/{*wildcard}", post(handle_post)) .route("/{*wildcard}", post(handle_post))
.layer(Extension(state.clone())) .route_layer(Extension(state.clone()))
.route_layer(axum::middleware::from_fn_with_state( .route_layer(axum::middleware::from_fn_with_state(
state.clone(), state.clone(),
middleware::gsd_add_header, middleware::gsd_decorate_header,
)) ))
.route_layer(
KeycloakAuthLayer::<String>::builder()
.instance(keycloak_instance.clone())
.passthrough_mode(PassthroughMode::Block)
.persist_raw_claims(false)
.expected_audiences(vec![String::from("account")])
//.required_roles(vec![])
.build(),
)
.route_layer(axum::middleware::from_fn_with_state( .route_layer(axum::middleware::from_fn_with_state(
state.clone(), state.clone(),
middleware::auth_middleware, middleware::session_auth_middleware,
)) ))
.with_state(state); .with_state(state);
let listener = tokio::net::TcpListener::bind(host_url).await.unwrap(); let app = Router::new().merge(proxy_router).merge(auth_router);
info!("Listening on {}", host_url);
let listener = tokio::net::TcpListener::bind(host_url.clone())
.await
.unwrap();
axum::serve(listener, app).await.unwrap(); axum::serve(listener, app).await.unwrap();

View File

@ -1,11 +1,14 @@
use crate::auth::{OAuthClient, UserSession, refresh_access_token_internal};
use crate::config::Config; use crate::config::Config;
use crate::gsd::service::GSDService;
use crate::repository::RedisRepository; use crate::repository::RedisRepository;
use crate::service_gsd::GSDService; use crate::util::set_and_log_session;
use axum::extract::{Request, State}; use axum::extract::{Request, State};
use axum::http::{HeaderValue, StatusCode}; use axum::http::{HeaderValue, StatusCode};
use axum::middleware::Next; use axum::middleware::Next;
use axum::response::{IntoResponse, Response}; use axum::response::{IntoResponse, Response};
use log::{error, info}; use axum_extra::extract::CookieJar;
use log::{error, info, warn};
use std::sync::Arc; use std::sync::Arc;
#[derive(Clone)] #[derive(Clone)]
@ -13,6 +16,8 @@ pub struct AppState {
pub config: Config, pub config: Config,
pub repository: RedisRepository, pub repository: RedisRepository,
pub gsd_service: GSDService, pub gsd_service: GSDService,
pub oauth_client: OAuthClient,
pub frontend_url: String,
} }
pub async fn auth_middleware( pub async fn auth_middleware(
@ -23,12 +28,110 @@ pub async fn auth_middleware(
next.run(request).await next.run(request).await
} }
pub async fn gsd_add_header( /// Middleware to validate session and refresh tokens if needed
pub async fn session_auth_middleware(
jar: CookieJar,
State(state): State<Arc<AppState>>,
mut request: Request,
next: Next,
) -> Response {
// 1. Extract session ID from cookie
let session_id = match jar.get("session_id") {
Some(cookie) => cookie.value().to_string(),
None => {
warn!("No session cookie found");
return (StatusCode::UNAUTHORIZED, "No session cookie").into_response();
}
};
// 2. Find session in Redis
let session_key = format!("user_session:{}", session_id);
let session_json = match state.repository.get(&session_key).await {
Ok(Some(json)) => json,
Ok(None) => {
warn!("Session not found in Redis: {}", session_id);
return (StatusCode::UNAUTHORIZED, "Session expired or invalid").into_response();
}
Err(e) => {
error!("Redis error while fetching session: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response();
}
};
// 3. Parse session data
let mut user_session: UserSession = match serde_json::from_str(&session_json) {
Ok(session) => session,
Err(e) => {
error!("Failed to parse session JSON: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, "Invalid session data").into_response();
}
};
// 4. Check if access token is expired
let now = chrono::Utc::now().timestamp();
if user_session.expires_at <= now {
info!(
"Access token expired for session {}, attempting refresh",
session_id
);
// 5. Refresh the access token using refresh token
match refresh_access_token_internal(
&state.oauth_client,
&state.repository,
&session_id,
&mut user_session,
)
.await
{
Ok(new_access_token) => {
info!(
"Successfully refreshed access token for session {}",
session_id
);
user_session.access_token = new_access_token;
}
Err(e) => {
error!("Failed to refresh access token: {}", e);
// Clean up invalid session
let _ = state.repository.delete(&session_key).await;
return (
StatusCode::UNAUTHORIZED,
"Session expired, please login again",
)
.into_response();
}
}
} else {
info!(
"Access token still valid for session {} (expires in {} seconds)",
session_id,
user_session.expires_at - now
);
}
// 6. Attach validated access token to request for downstream handlers
match HeaderValue::from_str(format!("Bearer {}", &user_session.access_token).as_str()) {
Ok(header_value) => {
request.headers_mut().insert("authorization", header_value);
}
Err(e) => {
error!("Failed to create authorization header: {:?}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response();
}
}
// 7. Pass the request to the next handler
next.run(request).await
}
pub async fn gsd_decorate_header(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
mut request: Request, mut request: Request,
next: Next, next: Next,
) -> Response { ) -> Response {
let state_cloned = state.clone(); let state_cloned = state.clone();
info!("Gsd decorate header");
let session = state_cloned.repository.get_session().await; let session = state_cloned.repository.get_session().await;
match session { match session {
@ -39,16 +142,7 @@ pub async fn gsd_add_header(
match state_cloned.gsd_service.get_session().await { match state_cloned.gsd_service.get_session().await {
Ok(session) => { Ok(session) => {
session_value = session.clone(); session_value = session.clone();
set_and_log_session(&state_cloned, session.clone()).await;
match state_cloned.repository.set_session(session.clone()).await {
Ok(_) => {
info!("Redis: saved session {}", &session);
}
Err(err) => {
error!("Redis: failed to save session: {}", err);
}
}
} }
Err(error) => { Err(error) => {
error!("Error getting session: {:?}", error); error!("Error getting session: {:?}", error);

View File

@ -1,10 +1,6 @@
use redis::aio::ConnectionManager; use redis::aio::ConnectionManager;
use redis::{AsyncTypedCommands, Connection, RedisError, RedisResult}; use redis::{AsyncTypedCommands, Connection, RedisError, RedisResult};
pub fn get_redis_connection(redis_url: String) -> RedisResult<Connection> {
redis::Client::open(redis_url)?.get_connection()
}
#[derive(Clone)] #[derive(Clone)]
pub struct RedisRepository { pub struct RedisRepository {
connection_manager: ConnectionManager, connection_manager: ConnectionManager,
@ -34,4 +30,22 @@ impl RedisRepository {
Ok(()) Ok(())
} }
pub async fn set_with_expiry(&self, key: &str, value: &str, expiry: u64) -> RedisResult<()> {
self.connection_manager
.clone()
.set_ex(key, value, expiry)
.await
}
pub async fn get(&self, key: &str) -> RedisResult<Option<String>> {
self.connection_manager
.clone()
.get::<String>(key.to_string())
.await
}
pub async fn delete(&self, key: &str) -> RedisResult<usize> {
self.connection_manager.clone().del(key.to_string()).await
}
} }

View File

@ -1,7 +1,11 @@
use crate::config::{Config, generate_log_file_name}; use crate::config::{Config, generate_log_file_name};
use log::LevelFilter; use crate::middleware::AppState;
use axum::body::Body;
use axum::extract::Request;
use log::{LevelFilter, error, info};
use simplelog::{ColorChoice, CombinedLogger, TermLogger, TerminalMode, WriteLogger}; use simplelog::{ColorChoice, CombinedLogger, TermLogger, TerminalMode, WriteLogger};
use std::fs::File; use std::fs::File;
use std::sync::Arc;
pub fn initialize_logging(config: &Config) { pub fn initialize_logging(config: &Config) {
CombinedLogger::init(vec![ CombinedLogger::init(vec![
@ -19,3 +23,15 @@ pub fn initialize_logging(config: &Config) {
]) ])
.unwrap(); .unwrap();
} }
pub async fn set_and_log_session(state: &Arc<AppState>, session: String) {
match state.repository.set_session(session.clone()).await {
Ok(_) => {
info!("Redis: saved session {}", &session);
}
Err(err) => {
error!("Redis: failed to save session: {}", err);
}
}
}