use chrono::{DateTime, Duration, Utc};
use hmac::{Hmac, Mac};
use serde::{Deserialize, Serialize};
use sha2::{Sha256, digest::generic_array::GenericArray};
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
use base64::{engine::general_purpose::URL_SAFE, Engine as _};

/// Version tag embedded in every signed token payload.
const TOKEN_VERSION: &str = "v2";
/// A signed token is re-issued when it is within this many minutes of its expiry.
const REFRESH_WINDOW_MINUTES: i64 = 15;

/// Server-side record of an authenticated session.
#[derive(Clone, Debug)]
pub struct Session {
    pub username: String,
    pub role: String,
    pub created_at: DateTime<Utc>,
    pub expires_at: DateTime<Utc>,
    /// Unique identifier of the session and its key in the session store
    /// (a UUID v4 string for sessions created by `create_session`).
    pub session_id: String,
}

/// In-memory session store plus the HMAC secret used to sign session tokens.
///
/// Clones share the same underlying session store (it is behind an `Arc`).
#[derive(Clone)]
pub struct SessionManager {
    /// Live sessions, keyed by `session_id`.
    sessions: Arc<RwLock<HashMap<String, Session>>>,
    /// Session lifetime, also used as the sliding-refresh extension.
    ttl_hours: i64,
    /// Raw HMAC-SHA256 key.
    secret: Vec<u8>,
}

/// Claims carried (as JSON) inside a signed session token.
#[derive(Debug, Serialize, Deserialize)]
struct SessionPayload {
    session_id: String,
    username: String,
    role: String,
    /// Issued-at, Unix timestamp (seconds).
    iat: i64,
    /// Expiry, Unix timestamp (seconds).
    exp: i64,
    version: String,
}

impl SessionManager {
    /// Create a manager with an empty store whose sessions live for `ttl_hours`.
    ///
    /// # Panics
    /// Panics (via `load_session_secret`) if the configured session secret is invalid.
    pub fn new(ttl_hours: i64) -> Self {
        let secret = load_session_secret();
        Self {
            sessions: Arc::new(RwLock::new(HashMap::new())),
            ttl_hours,
            secret,
        }
    }

    /// Create a new session, store it, and return its signed session token.
    pub async fn create_session(&self, username: String, role: String) -> String {
        let session_id = Uuid::new_v4().to_string();
        let now = Utc::now();
        let expires_at = now + Duration::hours(self.ttl_hours);

        let token = self.create_signed_token(
            &session_id,
            &username,
            &role,
            now.timestamp(),
            expires_at.timestamp(),
        );

        let session = Session {
            username,
            role,
            created_at: now,
            expires_at,
            session_id: session_id.clone(),
        };
        self.sessions.write().await.insert(session_id, session);

        token
    }

    /// Validate a session token and return the session if it is valid and unexpired.
    ///
    /// This delegates to [`Self::validate_session_with_refresh`] and discards any
    /// re-issued token. A token inside the refresh window therefore still gets its
    /// server-side expiry extended, but the caller receives no replacement token.
    pub async fn validate_session(&self, token: &str) -> Option<Session> {
        self.validate_session_with_refresh(token)
            .await
            .map(|(session, _)| session)
    }

    /// Validate a session token and return `(session, refreshed_token)`.
    ///
    /// Returns `None` if the token fails signature/format checks, has expired, refers
    /// to a session that is no longer stored (revoked or cleaned up), or carries a
    /// username/role that disagrees with the stored session.
    ///
    /// If a signed token is within its last `REFRESH_WINDOW_MINUTES`, the session's
    /// expiry is slid to `now + ttl_hours`. A freshly signed token is then returned, and
    /// the returned `Session` reflects the new expiry.
    pub async fn validate_session_with_refresh(
        &self,
        token: &str,
    ) -> Option<(Session, Option<String>)> {
        // Legacy tokens ("session-" prefix) are the raw store key; they are never refreshed.
        if token.starts_with("session-") {
            let sessions = self.sessions.read().await;
            return sessions
                .get(token)
                .filter(|s| s.expires_at > Utc::now())
                .map(|s| (s.clone(), None));
        }

        let payload = verify_signed_token(token, &self.secret).ok()?;

        // Single clock reading so the refreshed token's iat/exp are consistent.
        let now = Utc::now();
        let now_ts = now.timestamp();
        if payload.exp <= now_ts {
            return None;
        }

        // Scope the read guard so it is released before any write lock is requested.
        let mut session = {
            let sessions = self.sessions.read().await;
            sessions.get(&payload.session_id)?.clone()
        };

        // The signature already binds these, but the store is the source of truth.
        if session.username != payload.username || session.role != payload.role {
            return None;
        }

        let refresh_threshold = payload.exp - REFRESH_WINDOW_MINUTES * 60;
        if now_ts < refresh_threshold {
            return Some((session, None));
        }

        // Sliding expiration: extend the session by a full TTL from now.
        let new_expires_at = now + Duration::hours(self.ttl_hours);
        if !self
            .update_session_expiry(&payload.session_id, new_expires_at)
            .await
        {
            // Revoked between the read above and the write: do not mint a token for it.
            return None;
        }
        session.expires_at = new_expires_at;

        let new_token = self.create_signed_token(
            &payload.session_id,
            &payload.username,
            &payload.role,
            now_ts,
            new_expires_at.timestamp(),
        );
        Some((session, Some(new_token)))
    }

    /// Revoke (delete) the session referenced by `token`.
    ///
    /// Legacy tokens are used directly as the store key. For signed tokens, the session id
    /// is taken from the payload after signature verification (expiry is not checked, so an
    /// expired token can still revoke its session). Invalid tokens are ignored.
    pub async fn revoke_session(&self, token: &str) {
        if token.starts_with("session-") {
            self.sessions.write().await.remove(token);
            return;
        }

        if let Ok(payload) = verify_signed_token(token, &self.secret) {
            self.sessions.write().await.remove(&payload.session_id);
        }
    }

    /// Remove all sessions whose stored expiry is not in the future.
    pub async fn cleanup_expired(&self) {
        let now = Utc::now();
        self.sessions
            .write()
            .await
            .retain(|_, s| s.expires_at > now);
    }

    // --- Private helpers ---

    /// Build a payload for the given claims and sign it with this manager's secret.
    fn create_signed_token(
        &self,
        session_id: &str,
        username: &str,
        role: &str,
        iat: i64,
        exp: i64,
    ) -> String {
        let payload = SessionPayload {
            session_id: session_id.to_string(),
            username: username.to_string(),
            role: role.to_string(),
            iat,
            exp,
            version: TOKEN_VERSION.to_string(),
        };
        sign_token(&payload, &self.secret)
    }

    /// Set the stored expiry of `session_id`.
    ///
    /// Returns `false` (and changes nothing) if the session is no longer in the store.
    async fn update_session_expiry(&self, session_id: &str, new_expires_at: DateTime<Utc>) -> bool {
        match self.sessions.write().await.get_mut(session_id) {
            Some(session) => {
                session.expires_at = new_expires_at;
                true
            }
            None => false,
        }
    }
}

/// Load the session signing secret from `SESSION_SECRET`, falling back to
/// `LLM_PROXY__SESSION_SECRET`. The value is decoded as hex, then URL-safe base64, then
/// standard base64. If neither variable is set, a random 32-byte secret is generated and a
/// warning is logged; sessions signed with it do not survive a restart.
///
/// # Panics
/// Panics if the configured value is not a valid hex/base64 secret.
fn load_session_secret() -> Vec { let secret_str = env::var("SESSION_SECRET").unwrap_or_else(|_| { // Also check LLM_PROXY__SESSION_SECRET for consistency with config prefix env::var("LLM_PROXY__SESSION_SECRET").unwrap_or_else(|_| { // Generate a random secret (32 bytes) and encode as hex use rand::RngCore; let mut bytes = [0u8; 32]; rand::rng().fill_bytes(&mut bytes); let hex_secret = hex::encode(bytes); tracing::warn!( "SESSION_SECRET environment variable not set. Using a randomly generated secret. \ This will invalidate all sessions on restart. Set SESSION_SECRET to a fixed hex or base64 encoded 32-byte value." ); hex_secret }) }); // Decode hex or base64 hex::decode(&secret_str) .or_else(|_| URL_SAFE.decode(&secret_str)) .or_else(|_| base64::engine::general_purpose::STANDARD.decode(&secret_str)) .unwrap_or_else(|_| { panic!("SESSION_SECRET must be hex or base64 encoded (32 bytes)"); }) } /// Sign a session payload and return a token string in format base64_url(payload).base64_url(signature). fn sign_token(payload: &SessionPayload, secret: &[u8]) -> String { let json = serde_json::to_vec(payload).expect("Failed to serialize payload"); let payload_b64 = URL_SAFE.encode(&json); let mut mac = Hmac::::new_from_slice(secret).expect("HMAC can take key of any size"); mac.update(&json); let signature = mac.finalize().into_bytes(); let signature_b64 = URL_SAFE.encode(signature); format!("{}.{}", payload_b64, signature_b64) } /// Verify a signed token and return the decoded payload if valid. 
fn verify_signed_token(token: &str, secret: &[u8]) -> Result { let parts: Vec<&str> = token.split('.').collect(); if parts.len() != 2 { return Err(TokenError::InvalidFormat); } let payload_b64 = parts[0]; let signature_b64 = parts[1]; let json = URL_SAFE.decode(payload_b64).map_err(|_| TokenError::InvalidFormat)?; let signature = URL_SAFE.decode(signature_b64).map_err(|_| TokenError::InvalidFormat)?; // Verify HMAC let mut mac = Hmac::::new_from_slice(secret).expect("HMAC can take key of any size"); mac.update(&json); // Convert signature slice to GenericArray let tag = GenericArray::from_slice(&signature); mac.verify(tag).map_err(|_| TokenError::InvalidSignature)?; // Deserialize payload let payload: SessionPayload = serde_json::from_slice(&json).map_err(|_| TokenError::InvalidPayload)?; Ok(payload) } #[derive(Debug)] enum TokenError { InvalidFormat, InvalidSignature, InvalidPayload, } #[cfg(test)] mod tests { use super::*; use std::env; #[test] fn test_sign_and_verify_token() { let secret = b"test-secret-must-be-32-bytes-long!"; let payload = SessionPayload { session_id: "test-session".to_string(), username: "testuser".to_string(), role: "user".to_string(), iat: 1000, exp: 2000, version: TOKEN_VERSION.to_string(), }; let token = sign_token(&payload, secret); let verified = verify_signed_token(&token, secret).unwrap(); assert_eq!(verified.session_id, payload.session_id); assert_eq!(verified.username, payload.username); assert_eq!(verified.role, payload.role); assert_eq!(verified.iat, payload.iat); assert_eq!(verified.exp, payload.exp); assert_eq!(verified.version, payload.version); } #[test] fn test_tampered_token() { let secret = b"test-secret-must-be-32-bytes-long!"; let payload = SessionPayload { session_id: "test-session".to_string(), username: "testuser".to_string(), role: "user".to_string(), iat: 1000, exp: 2000, version: TOKEN_VERSION.to_string(), }; let mut token = sign_token(&payload, secret); // Tamper with payload part let mut parts: Vec<&str> = 
token.split('.').collect(); let mut payload_bytes = URL_SAFE.decode(parts[0]).unwrap(); payload_bytes[0] ^= 0xFF; // flip some bits let tampered_payload = URL_SAFE.encode(payload_bytes); parts[0] = &tampered_payload; token = parts.join("."); assert!(verify_signed_token(&token, secret).is_err()); } #[test] fn test_load_session_secret_from_env() { unsafe { env::set_var("SESSION_SECRET", hex::encode([0xAA; 32])); } let secret = load_session_secret(); assert_eq!(secret, vec![0xAA; 32]); unsafe { env::remove_var("SESSION_SECRET"); } } }