This commit introduces:
- AES-256-GCM encryption for LLM provider API keys stored in the database.
- HMAC-SHA256 signed session tokens with activity-based refresh logic.
- Standardized frontend XSS protection using a global `escapeHtml` utility.
- Hardened security headers and request body size limits.
- Improved database integrity through foreign key enforcement and atomic transactions.
- Integration tests covering the full lifecycle of encrypted key storage and proxy usage.
311 lines
11 KiB
Rust
311 lines
11 KiB
Rust
use chrono::{DateTime, Duration, Utc};
|
|
use hmac::{Hmac, Mac};
|
|
use serde::{Deserialize, Serialize};
|
|
use sha2::{Sha256, digest::generic_array::GenericArray};
|
|
use std::collections::HashMap;
|
|
use std::env;
|
|
use std::sync::Arc;
|
|
use tokio::sync::RwLock;
|
|
use uuid::Uuid;
|
|
|
|
use base64::{engine::general_purpose::URL_SAFE, Engine as _};
|
|
|
|
/// Format tag written into every signed token's `version` field.
const TOKEN_VERSION: &str = "v2";

/// A signed token is reissued once fewer than this many minutes remain before its `exp`.
const REFRESH_WINDOW_MINUTES: i64 = 15;
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct Session {
|
|
pub username: String,
|
|
pub role: String,
|
|
pub created_at: DateTime<Utc>,
|
|
pub expires_at: DateTime<Utc>,
|
|
pub session_id: String, // unique identifier for the session (UUID)
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct SessionManager {
|
|
sessions: Arc<RwLock<HashMap<String, Session>>>, // key = session_id
|
|
ttl_hours: i64,
|
|
secret: Vec<u8>,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
struct SessionPayload {
|
|
session_id: String,
|
|
username: String,
|
|
role: String,
|
|
iat: i64, // issued at (Unix timestamp)
|
|
exp: i64, // expiry (Unix timestamp)
|
|
version: String,
|
|
}
|
|
|
|
impl SessionManager {
|
|
pub fn new(ttl_hours: i64) -> Self {
|
|
let secret = load_session_secret();
|
|
Self {
|
|
sessions: Arc::new(RwLock::new(HashMap::new())),
|
|
ttl_hours,
|
|
secret,
|
|
}
|
|
}
|
|
|
|
/// Create a new session and return a signed session token.
|
|
pub async fn create_session(&self, username: String, role: String) -> String {
|
|
let session_id = Uuid::new_v4().to_string();
|
|
let now = Utc::now();
|
|
let expires_at = now + Duration::hours(self.ttl_hours);
|
|
let session = Session {
|
|
username: username.clone(),
|
|
role: role.clone(),
|
|
created_at: now,
|
|
expires_at,
|
|
session_id: session_id.clone(),
|
|
};
|
|
// Store session by session_id
|
|
self.sessions.write().await.insert(session_id.clone(), session);
|
|
// Create signed token
|
|
self.create_signed_token(&session_id, &username, &role, now.timestamp(), expires_at.timestamp())
|
|
}
|
|
|
|
/// Validate a session token and return the session if valid and not expired.
|
|
/// If the token is within the refresh window, returns a new token as well.
|
|
pub async fn validate_session(&self, token: &str) -> Option<Session> {
|
|
self.validate_session_with_refresh(token).await.map(|(session, _)| session)
|
|
}
|
|
|
|
/// Validate a session token and return (session, optional new token if refreshed).
|
|
pub async fn validate_session_with_refresh(&self, token: &str) -> Option<(Session, Option<String>)> {
|
|
// Legacy token format (UUID)
|
|
if token.starts_with("session-") {
|
|
let sessions = self.sessions.read().await;
|
|
return sessions.get(token).and_then(|s| {
|
|
if s.expires_at > Utc::now() {
|
|
Some((s.clone(), None))
|
|
} else {
|
|
None
|
|
}
|
|
});
|
|
}
|
|
|
|
// Signed token format
|
|
let payload = match verify_signed_token(token, &self.secret) {
|
|
Ok(p) => p,
|
|
Err(_) => return None,
|
|
};
|
|
|
|
// Check expiry
|
|
let now = Utc::now().timestamp();
|
|
if payload.exp <= now {
|
|
return None;
|
|
}
|
|
|
|
// Look up session by session_id
|
|
let sessions = self.sessions.read().await;
|
|
let session = match sessions.get(&payload.session_id) {
|
|
Some(s) => s.clone(),
|
|
None => return None, // session revoked or not found
|
|
};
|
|
|
|
// Ensure session username/role matches (should always match)
|
|
if session.username != payload.username || session.role != payload.role {
|
|
return None;
|
|
}
|
|
|
|
// Check if token is within refresh window (last REFRESH_WINDOW_MINUTES of validity)
|
|
let refresh_threshold = payload.exp - REFRESH_WINDOW_MINUTES * 60;
|
|
let new_token = if now >= refresh_threshold {
|
|
// Generate a new token with same session data but updated iat/exp?
|
|
// According to activity-based refresh, we should extend the session expiry.
|
|
// We'll extend from now by ttl_hours (or keep original expiry?).
|
|
// Let's extend from now by ttl_hours (sliding window).
|
|
let new_exp = Utc::now() + Duration::hours(self.ttl_hours);
|
|
// Update session expiry in store
|
|
drop(sessions); // release read lock before acquiring write lock
|
|
self.update_session_expiry(&payload.session_id, new_exp).await;
|
|
// Create new token with updated iat/exp
|
|
let new_token = self.create_signed_token(
|
|
&payload.session_id,
|
|
&payload.username,
|
|
&payload.role,
|
|
now,
|
|
new_exp.timestamp(),
|
|
);
|
|
Some(new_token)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
Some((session, new_token))
|
|
}
|
|
|
|
/// Revoke (delete) a session by token.
|
|
/// Supports both legacy tokens (token is key) and signed tokens (extract session_id).
|
|
pub async fn revoke_session(&self, token: &str) {
|
|
if token.starts_with("session-") {
|
|
self.sessions.write().await.remove(token);
|
|
return;
|
|
}
|
|
// For signed token, try to extract session_id
|
|
if let Ok(payload) = verify_signed_token(token, &self.secret) {
|
|
self.sessions.write().await.remove(&payload.session_id);
|
|
}
|
|
}
|
|
|
|
/// Remove all expired sessions from the store.
|
|
pub async fn cleanup_expired(&self) {
|
|
let now = Utc::now();
|
|
self.sessions.write().await.retain(|_, s| s.expires_at > now);
|
|
}
|
|
|
|
// --- Private helpers ---
|
|
|
|
fn create_signed_token(&self, session_id: &str, username: &str, role: &str, iat: i64, exp: i64) -> String {
|
|
let payload = SessionPayload {
|
|
session_id: session_id.to_string(),
|
|
username: username.to_string(),
|
|
role: role.to_string(),
|
|
iat,
|
|
exp,
|
|
version: TOKEN_VERSION.to_string(),
|
|
};
|
|
sign_token(&payload, &self.secret)
|
|
}
|
|
|
|
async fn update_session_expiry(&self, session_id: &str, new_expires_at: DateTime<Utc>) {
|
|
let mut sessions = self.sessions.write().await;
|
|
if let Some(session) = sessions.get_mut(session_id) {
|
|
session.expires_at = new_expires_at;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Load session secret from environment variable SESSION_SECRET (hex or base64 encoded).
|
|
/// If not set, generates a random 32-byte secret and logs a warning.
|
|
fn load_session_secret() -> Vec<u8> {
|
|
let secret_str = env::var("SESSION_SECRET").unwrap_or_else(|_| {
|
|
// Also check LLM_PROXY__SESSION_SECRET for consistency with config prefix
|
|
env::var("LLM_PROXY__SESSION_SECRET").unwrap_or_else(|_| {
|
|
// Generate a random secret (32 bytes) and encode as hex
|
|
use rand::RngCore;
|
|
let mut bytes = [0u8; 32];
|
|
rand::rng().fill_bytes(&mut bytes);
|
|
let hex_secret = hex::encode(bytes);
|
|
tracing::warn!(
|
|
"SESSION_SECRET environment variable not set. Using a randomly generated secret. \
|
|
This will invalidate all sessions on restart. Set SESSION_SECRET to a fixed hex or base64 encoded 32-byte value."
|
|
);
|
|
hex_secret
|
|
})
|
|
});
|
|
|
|
// Decode hex or base64
|
|
hex::decode(&secret_str)
|
|
.or_else(|_| URL_SAFE.decode(&secret_str))
|
|
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(&secret_str))
|
|
.unwrap_or_else(|_| {
|
|
panic!("SESSION_SECRET must be hex or base64 encoded (32 bytes)");
|
|
})
|
|
}
|
|
|
|
/// Sign a session payload and return a token string in format base64_url(payload).base64_url(signature).
|
|
fn sign_token(payload: &SessionPayload, secret: &[u8]) -> String {
|
|
let json = serde_json::to_vec(payload).expect("Failed to serialize payload");
|
|
let payload_b64 = URL_SAFE.encode(&json);
|
|
let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
|
|
mac.update(&json);
|
|
let signature = mac.finalize().into_bytes();
|
|
let signature_b64 = URL_SAFE.encode(signature);
|
|
format!("{}.{}", payload_b64, signature_b64)
|
|
}
|
|
|
|
/// Verify a signed token and return the decoded payload if valid.
|
|
fn verify_signed_token(token: &str, secret: &[u8]) -> Result<SessionPayload, TokenError> {
|
|
let parts: Vec<&str> = token.split('.').collect();
|
|
if parts.len() != 2 {
|
|
return Err(TokenError::InvalidFormat);
|
|
}
|
|
let payload_b64 = parts[0];
|
|
let signature_b64 = parts[1];
|
|
|
|
let json = URL_SAFE.decode(payload_b64).map_err(|_| TokenError::InvalidFormat)?;
|
|
let signature = URL_SAFE.decode(signature_b64).map_err(|_| TokenError::InvalidFormat)?;
|
|
|
|
// Verify HMAC
|
|
let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
|
|
mac.update(&json);
|
|
// Convert signature slice to GenericArray
|
|
let tag = GenericArray::from_slice(&signature);
|
|
mac.verify(tag).map_err(|_| TokenError::InvalidSignature)?;
|
|
|
|
// Deserialize payload
|
|
let payload: SessionPayload = serde_json::from_slice(&json).map_err(|_| TokenError::InvalidPayload)?;
|
|
Ok(payload)
|
|
}
|
|
|
|
/// Reasons a signed token can fail verification.
#[derive(Debug)]
enum TokenError {
    /// Not two dot-separated parts, or a part is not valid base64.
    InvalidFormat,
    /// The HMAC tag does not authenticate the payload.
    InvalidSignature,
    /// The authenticated bytes are not a valid session payload.
    InvalidPayload,
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Signing key shared by the token tests.
    const TEST_SECRET: &[u8] = b"test-secret-must-be-32-bytes-long!";

    /// A fixed, current-version payload for round-trip tests.
    fn sample_payload() -> SessionPayload {
        SessionPayload {
            session_id: "test-session".to_string(),
            username: "testuser".to_string(),
            role: "user".to_string(),
            iat: 1000,
            exp: 2000,
            version: TOKEN_VERSION.to_string(),
        }
    }

    #[test]
    fn test_sign_and_verify_token() {
        let payload = sample_payload();
        let token = sign_token(&payload, TEST_SECRET);
        let verified = verify_signed_token(&token, TEST_SECRET).unwrap();

        assert_eq!(verified.session_id, payload.session_id);
        assert_eq!(verified.username, payload.username);
        assert_eq!(verified.role, payload.role);
        assert_eq!(verified.iat, payload.iat);
        assert_eq!(verified.exp, payload.exp);
        assert_eq!(verified.version, payload.version);
    }

    #[test]
    fn test_tampered_token() {
        let token = sign_token(&sample_payload(), TEST_SECRET);
        let (payload_b64, signature_b64) = token.split_once('.').unwrap();

        // Flip bits in the payload while keeping the original signature.
        let mut payload_bytes = URL_SAFE.decode(payload_b64).unwrap();
        payload_bytes[0] ^= 0xFF;
        let tampered = format!("{}.{}", URL_SAFE.encode(payload_bytes), signature_b64);

        assert!(matches!(
            verify_signed_token(&tampered, TEST_SECRET),
            Err(TokenError::InvalidSignature)
        ));
    }

    #[test]
    fn test_wrong_secret_rejected() {
        let token = sign_token(&sample_payload(), TEST_SECRET);
        assert!(matches!(
            verify_signed_token(&token, b"a-completely-different-secret-key"),
            Err(TokenError::InvalidSignature)
        ));
    }

    #[test]
    fn test_malformed_tokens_rejected() {
        for token in ["", "no-dot-here", "a.b.c", "!!!.???"] {
            assert!(
                matches!(
                    verify_signed_token(token, TEST_SECRET),
                    Err(TokenError::InvalidFormat)
                ),
                "token {token:?} should be rejected as malformed"
            );
        }
    }

    #[test]
    fn test_load_session_secret_from_env() {
        // Remember any pre-existing value so the test does not clobber the environment.
        let previous = env::var_os("SESSION_SECRET");

        // SAFETY: `set_var`/`remove_var` are unsafe because concurrent environment access from
        // other threads is unsound on some platforms. NOTE(review): no other test in this module
        // touches the environment; keep it that way or serialize env-mutating tests.
        unsafe {
            env::set_var("SESSION_SECRET", hex::encode([0xAA; 32]));
        }

        let secret = load_session_secret();

        // SAFETY: same invariant as above; restore the original state before asserting.
        unsafe {
            match previous {
                Some(value) => env::set_var("SESSION_SECRET", value),
                None => env::remove_var("SESSION_SECRET"),
            }
        }

        assert_eq!(secret, vec![0xAA; 32]);
    }
}