diff --git a/.env.example b/.env.example index f2851a1f..689ca7b7 100644 --- a/.env.example +++ b/.env.example @@ -25,4 +25,7 @@ LLM_PROXY__SERVER__AUTH_TOKENS=your_bearer_token_here,another_token LLM_PROXY__SERVER__PORT=8080 # Database path (optional) -LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db \ No newline at end of file +LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db + +# Session secret for HMAC-signed tokens (hex or base64 encoded, 32 bytes) +SESSION_SECRET=your_session_secret_here_32_bytes \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 04509cd7..936d5c1d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,41 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "aes-gcm" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", +] + [[package]] name = "ahash" version = "0.7.8" @@ -541,9 +576,33 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" dependencies = [ "generic-array", + "rand_core 0.6.4", "typenum", ] +[[package]] +name = "ctr" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum 
= "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.10.0" @@ -895,6 +954,37 @@ dependencies = [ "wasip3", ] +[[package]] +name = "ghash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" +dependencies = [ + "opaque-debug", + "polyval", +] + +[[package]] +name = "governor" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0746aa765db78b521451ef74221663b57ba595bf83f75d0ce23cc09447c8139f" +dependencies = [ + "cfg-if", + "dashmap", + "futures-sink", + "futures-timer", + "futures-util", + "no-std-compat", + "nonzero_ext", + "parking_lot", + "portable-atomic", + "quanta", + "rand 0.8.5", + "smallvec", + "spinning_top", +] + [[package]] name = "h2" version = "0.4.13" @@ -923,6 +1013,12 @@ dependencies = [ "ahash", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.5" @@ -1431,6 +1527,7 @@ checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" name = "llm-proxy" version = "0.1.0" dependencies = [ + "aes-gcm", "anyhow", "assert_cmd", "async-stream", @@ -1443,8 +1540,10 @@ dependencies = [ "config", "dotenvy", "futures", + "governor", "headers", "hex", + "hmac", "image", "insta", "mime", @@ -1454,6 +1553,7 @@ dependencies = [ "reqwest-eventsource", "serde", "serde_json", + "sha2", 
"sqlx", "tempfile", "thiserror 1.0.69", @@ -1598,6 +1698,12 @@ dependencies = [ "pxfm", ] +[[package]] +name = "no-std-compat" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" + [[package]] name = "nom" version = "7.1.3" @@ -1608,6 +1714,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nonzero_ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" + [[package]] name = "nu-ansi-term" version = "0.50.3" @@ -1669,6 +1781,12 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "opaque-debug" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" + [[package]] name = "ordered-multimap" version = "0.4.3" @@ -1824,6 +1942,24 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "polyval" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" +dependencies = [ + "cfg-if", + "cpufeatures", + "opaque-debug", + "universal-hash", +] + +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + [[package]] name = "potential_utf" version = "0.1.4" @@ -1897,6 +2033,21 @@ dependencies = [ "num-traits", ] +[[package]] +name = "quanta" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" +dependencies = [ + "crossbeam-utils", + "libc", + 
"once_cell", + "raw-cpuid", + "wasi", + "web-sys", + "winapi", +] + [[package]] name = "quick-error" version = "2.0.1" @@ -2032,6 +2183,15 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "raw-cpuid" +version = "11.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" +dependencies = [ + "bitflags 2.11.0", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -2453,6 +2613,15 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spinning_top" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.7.3" @@ -3158,6 +3327,16 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common", + "subtle", +] + [[package]] name = "untrusted" version = "0.9.0" @@ -3411,6 +3590,28 @@ dependencies = [ "wasite", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-core" version = "0.62.2" diff --git a/Cargo.toml b/Cargo.toml index 98b043a9..f2d8f48b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,8 @@ repository = "" axum = { version = "0.8", features = ["macros", "ws"] } tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "net", "time", "signal", "fs"] } tower = "0.5" -tower-http = { version = "0.6", features = ["trace", "cors", "compression-gzip", "fs"] } +tower-http = { version = "0.6", features = ["trace", "cors", "compression-gzip", "fs", "set-header", "limit"] } +governor = "0.7" # ========== HTTP Clients ========== reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] } @@ -46,6 +47,9 @@ mime = "0.3" anyhow = "1.0" thiserror = "1.0" bcrypt = "0.15" +aes-gcm = "0.10" +hmac = "0.12" +sha2 = "0.10" chrono = { version = "0.4", features = ["serde"] } uuid = { version = "1.0", features = ["v4", "serde"] } futures = "0.3" diff --git a/migrations/002-add-indexes.sql b/migrations/002-add-indexes.sql new file mode 100644 index 00000000..c56ea85e --- /dev/null +++ b/migrations/002-add-indexes.sql @@ -0,0 +1,13 @@ +-- Migration: add composite indexes for query performance +-- Adds three composite indexes: +-- 1. idx_llm_requests_client_timestamp on llm_requests(client_id, timestamp) +-- 2. idx_llm_requests_provider_timestamp on llm_requests(provider, timestamp) +-- 3. 
idx_model_configs_provider_id on model_configs(provider_id) + +BEGIN TRANSACTION; + +CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp); +CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp); +CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id); + +COMMIT; \ No newline at end of file diff --git a/src/config/mod.rs b/src/config/mod.rs index 3f3b4078..aa0a651c 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,5 +1,7 @@ use anyhow::Result; +use base64::{Engine as _}; use config::{Config, File, FileFormat}; +use hex; use serde::{Deserialize, Serialize}; use std::path::PathBuf; use std::sync::Arc; @@ -96,6 +98,7 @@ pub struct AppConfig { pub model_mapping: ModelMappingConfig, pub pricing: PricingConfig, pub config_path: Option, + pub encryption_key: String, } impl AppConfig { @@ -136,7 +139,8 @@ impl AppConfig { .set_default("providers.grok.enabled", true)? .set_default("providers.ollama.base_url", "http://localhost:11434/v1")? .set_default("providers.ollama.enabled", false)? - .set_default("providers.ollama.models", Vec::::new())?; + .set_default("providers.ollama.models", Vec::::new())? 
+ .set_default("encryption_key", "")?; // Load from config file if exists // Priority: explicit path arg > LLM_PROXY__CONFIG_PATH env var > ./config.toml @@ -167,6 +171,19 @@ impl AppConfig { let server: ServerConfig = config.get("server")?; let database: DatabaseConfig = config.get("database")?; let providers: ProviderConfig = config.get("providers")?; + let encryption_key: String = config.get("encryption_key")?; + + // Validate encryption key length (must be 32 bytes after hex or base64 decoding) + if encryption_key.is_empty() { + anyhow::bail!("Encryption key is required (LLM_PROXY__ENCRYPTION_KEY environment variable)"); + } + // Try hex decode first, then base64 + let key_bytes = hex::decode(&encryption_key) + .or_else(|_| base64::engine::general_purpose::STANDARD.decode(&encryption_key)) + .map_err(|e| anyhow::anyhow!("Encryption key must be hex or base64 encoded: {}", e))?; + if key_bytes.len() != 32 { + anyhow::bail!("Encryption key must be 32 bytes (256 bits), got {} bytes", key_bytes.len()); + } // For now, use empty model mapping and pricing (will be populated later) let model_mapping = ModelMappingConfig { patterns: vec![] }; @@ -185,6 +202,7 @@ impl AppConfig { model_mapping, pricing, config_path: Some(config_path), + encryption_key, })) } diff --git a/src/dashboard/auth.rs b/src/dashboard/auth.rs index 3694b499..10a0df84 100644 --- a/src/dashboard/auth.rs +++ b/src/dashboard/auth.rs @@ -1,4 +1,4 @@ -use axum::{extract::State, response::Json}; +use axum::{extract::State, http::{HeaderMap, HeaderValue}, response::{Json, IntoResponse}}; use bcrypt; use serde::Deserialize; use sqlx::Row; @@ -64,14 +64,14 @@ pub(super) async fn handle_login( pub(super) async fn handle_auth_status( State(state): State, headers: axum::http::HeaderMap, -) -> Json> { +) -> impl IntoResponse { let token = headers .get("Authorization") .and_then(|v| v.to_str().ok()) .and_then(|v| v.strip_prefix("Bearer ")); if let Some(token) = token - && let Some(session) = 
state.session_manager.validate_session(token).await + && let Some((session, new_token)) = state.session_manager.validate_session_with_refresh(token).await { // Look up display_name from DB let display_name = sqlx::query_scalar::<_, Option>( @@ -85,17 +85,23 @@ pub(super) async fn handle_auth_status( .flatten() .unwrap_or_else(|| session.username.clone()); - return Json(ApiResponse::success(serde_json::json!({ + let mut headers = HeaderMap::new(); + if let Some(refreshed_token) = new_token { + if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) { + headers.insert("X-Refreshed-Token", header_value); + } + } + return (headers, Json(ApiResponse::success(serde_json::json!({ "authenticated": true, "user": { "username": session.username, "name": display_name, "role": session.role } - }))); + })))); } - Json(ApiResponse::error("Not authenticated".to_string())) + (HeaderMap::new(), Json(ApiResponse::error("Not authenticated".to_string()))) } #[derive(Deserialize)] @@ -108,7 +114,7 @@ pub(super) async fn handle_change_password( State(state): State, headers: axum::http::HeaderMap, Json(payload): Json, -) -> Json> { +) -> impl IntoResponse { let pool = &state.app_state.db_pool; // Extract the authenticated user from the session token @@ -117,14 +123,24 @@ pub(super) async fn handle_change_password( .and_then(|v| v.to_str().ok()) .and_then(|v| v.strip_prefix("Bearer ")); - let session = match token { - Some(t) => state.session_manager.validate_session(t).await, - None => None, + let (session, new_token) = match token { + Some(t) => match state.session_manager.validate_session_with_refresh(t).await { + Some((session, new_token)) => (Some(session), new_token), + None => (None, None), + }, + None => (None, None), }; + let mut response_headers = HeaderMap::new(); + if let Some(refreshed_token) = new_token { + if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) { + response_headers.insert("X-Refreshed-Token", header_value); + } + } + let username = match 
session { Some(s) => s.username, - None => return Json(ApiResponse::error("Not authenticated".to_string())), + None => return (response_headers, Json(ApiResponse::error("Not authenticated".to_string()))), }; let user_result = sqlx::query("SELECT password_hash FROM users WHERE username = ?") @@ -138,7 +154,7 @@ pub(super) async fn handle_change_password( if bcrypt::verify(&payload.current_password, &hash).unwrap_or(false) { let new_hash = match bcrypt::hash(&payload.new_password, 12) { Ok(h) => h, - Err(_) => return Json(ApiResponse::error("Failed to hash new password".to_string())), + Err(_) => return (response_headers, Json(ApiResponse::error("Failed to hash new password".to_string()))), }; let update_result = sqlx::query( @@ -150,16 +166,16 @@ pub(super) async fn handle_change_password( .await; match update_result { - Ok(_) => Json(ApiResponse::success( + Ok(_) => (response_headers, Json(ApiResponse::success( serde_json::json!({ "message": "Password updated successfully" }), - )), - Err(e) => Json(ApiResponse::error(format!("Failed to update database: {}", e))), + ))), + Err(e) => (response_headers, Json(ApiResponse::error(format!("Failed to update database: {}", e)))), } } else { - Json(ApiResponse::error("Current password incorrect".to_string())) + (response_headers, Json(ApiResponse::error("Current password incorrect".to_string()))) } } - Err(e) => Json(ApiResponse::error(format!("User not found: {}", e))), + Err(e) => (response_headers, Json(ApiResponse::error(format!("User not found: {}", e)))), } } @@ -180,19 +196,19 @@ pub(super) async fn handle_logout( } /// Helper: Extract and validate a session from the Authorization header. -/// Returns the Session if valid, or an error response. +/// Returns the Session and optional new token if refreshed, or an error response. 
pub(super) async fn extract_session( state: &DashboardState, headers: &axum::http::HeaderMap, -) -> Result>> { +) -> Result<(super::sessions::Session, Option), Json>> { let token = headers .get("Authorization") .and_then(|v| v.to_str().ok()) .and_then(|v| v.strip_prefix("Bearer ")); match token { - Some(t) => match state.session_manager.validate_session(t).await { - Some(session) => Ok(session), + Some(t) => match state.session_manager.validate_session_with_refresh(t).await { + Some((session, new_token)) => Ok((session, new_token)), None => Err(Json(ApiResponse::error("Session expired or invalid".to_string()))), }, None => Err(Json(ApiResponse::error("Not authenticated".to_string()))), @@ -200,13 +216,14 @@ pub(super) async fn extract_session( } /// Helper: Extract session and require admin role. +/// Returns session and optional new token if refreshed. pub(super) async fn require_admin( state: &DashboardState, headers: &axum::http::HeaderMap, -) -> Result>> { - let session = extract_session(state, headers).await?; +) -> Result<(super::sessions::Session, Option), Json>> { + let (session, new_token) = extract_session(state, headers).await?; if session.role != "admin" { return Err(Json(ApiResponse::error("Admin access required".to_string()))); } - Ok(session) + Ok((session, new_token)) } diff --git a/src/dashboard/clients.rs b/src/dashboard/clients.rs index b3eebf95..3f5e2476 100644 --- a/src/dashboard/clients.rs +++ b/src/dashboard/clients.rs @@ -88,9 +88,10 @@ pub(super) async fn handle_create_client( headers: axum::http::HeaderMap, Json(payload): Json, ) -> Json> { - if let Err(e) = super::auth::require_admin(&state, &headers).await { - return e; - } + let (session, _) = match super::auth::require_admin(&state, &headers).await { + Ok((session, new_token)) => (session, new_token), + Err(e) => return e, + }; let pool = &state.app_state.db_pool; @@ -198,9 +199,10 @@ pub(super) async fn handle_update_client( Path(id): Path, Json(payload): Json, ) -> Json> { - if let 
Err(e) = super::auth::require_admin(&state, &headers).await { - return e; - } + let (session, _) = match super::auth::require_admin(&state, &headers).await { + Ok((session, new_token)) => (session, new_token), + Err(e) => return e, + }; let pool = &state.app_state.db_pool; @@ -294,9 +296,10 @@ pub(super) async fn handle_delete_client( headers: axum::http::HeaderMap, Path(id): Path, ) -> Json> { - if let Err(e) = super::auth::require_admin(&state, &headers).await { - return e; - } + let (session, _) = match super::auth::require_admin(&state, &headers).await { + Ok((session, new_token)) => (session, new_token), + Err(e) => return e, + }; let pool = &state.app_state.db_pool; @@ -437,9 +440,10 @@ pub(super) async fn handle_create_client_token( Path(id): Path, Json(payload): Json, ) -> Json> { - if let Err(e) = super::auth::require_admin(&state, &headers).await { - return e; - } + let (session, _) = match super::auth::require_admin(&state, &headers).await { + Ok((session, new_token)) => (session, new_token), + Err(e) => return e, + }; let pool = &state.app_state.db_pool; @@ -485,9 +489,10 @@ pub(super) async fn handle_delete_client_token( headers: axum::http::HeaderMap, Path((client_id, token_id)): Path<(String, i64)>, ) -> Json> { - if let Err(e) = super::auth::require_admin(&state, &headers).await { - return e; - } + let (session, _) = match super::auth::require_admin(&state, &headers).await { + Ok((session, new_token)) => (session, new_token), + Err(e) => return e, + }; let pool = &state.app_state.db_pool; diff --git a/src/dashboard/mod.rs b/src/dashboard/mod.rs index 188dafcc..ba3ff567 100644 --- a/src/dashboard/mod.rs +++ b/src/dashboard/mod.rs @@ -11,10 +11,18 @@ mod users; mod websocket; use axum::{ + extract::{Request, State}, + middleware::Next, + response::Response, Router, routing::{delete, get, post, put}, }; +use axum::http::{header, HeaderValue}; use serde::Serialize; +use tower_http::{ + limit::RequestBodyLimitLayer, + set_header::SetResponseHeaderLayer, 
+}; use crate::state::AppState; use sessions::SessionManager; @@ -52,6 +60,21 @@ impl ApiResponse { } } +/// Rate limiting middleware for dashboard routes that extracts AppState from DashboardState. +async fn dashboard_rate_limit_middleware( + State(dashboard_state): State, + request: Request, + next: Next, +) -> Result { + // Delegate to the existing rate limit middleware with AppState + crate::rate_limiting::middleware::rate_limit_middleware( + State(dashboard_state.app_state), + request, + next, + ) + .await +} + // Dashboard routes pub fn router(state: AppState) -> Router { let session_manager = SessionManager::new(24); // 24-hour session TTL @@ -60,6 +83,26 @@ pub fn router(state: AppState) -> Router { session_manager, }; + // Security headers + let csp_header: SetResponseHeaderLayer = SetResponseHeaderLayer::overriding( + header::CONTENT_SECURITY_POLICY, + "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws:;" + .parse() + .unwrap(), + ); + let x_frame_options: SetResponseHeaderLayer = SetResponseHeaderLayer::overriding( + header::X_FRAME_OPTIONS, + "DENY".parse().unwrap(), + ); + let x_content_type_options: SetResponseHeaderLayer = SetResponseHeaderLayer::overriding( + header::X_CONTENT_TYPE_OPTIONS, + "nosniff".parse().unwrap(), + ); + let strict_transport_security: SetResponseHeaderLayer = SetResponseHeaderLayer::overriding( + header::STRICT_TRANSPORT_SECURITY, + "max-age=31536000; includeSubDomains".parse().unwrap(), + ); + Router::new() // Static file serving .fallback_service(tower_http::services::ServeDir::new("static")) @@ -119,5 +162,16 @@ pub fn router(state: AppState) -> Router { "/api/system/settings", get(system::handle_get_settings).post(system::handle_update_settings), ) + // Security layers + .layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit + .layer(csp_header) + .layer(x_frame_options) + .layer(x_content_type_options) + .layer(strict_transport_security) + // 
Rate limiting middleware + .layer(axum::middleware::from_fn_with_state( + dashboard_state.clone(), + dashboard_rate_limit_middleware, + )) .with_state(dashboard_state) } diff --git a/src/dashboard/models.rs b/src/dashboard/models.rs index 4ece2bb2..618d1a41 100644 --- a/src/dashboard/models.rs +++ b/src/dashboard/models.rs @@ -156,9 +156,10 @@ pub(super) async fn handle_update_model( Path(id): Path, Json(payload): Json, ) -> Json> { - if let Err(e) = super::auth::require_admin(&state, &headers).await { - return e; - } + let (session, _) = match super::auth::require_admin(&state, &headers).await { + Ok((session, new_token)) => (session, new_token), + Err(e) => return e, + }; let pool = &state.app_state.db_pool; diff --git a/src/dashboard/providers.rs b/src/dashboard/providers.rs index 4770a23e..744e35d7 100644 --- a/src/dashboard/providers.rs +++ b/src/dashboard/providers.rs @@ -9,6 +9,7 @@ use std::collections::HashMap; use tracing::warn; use super::{ApiResponse, DashboardState}; +use crate::utils::crypto; #[derive(Deserialize)] pub(super) struct UpdateProviderRequest { @@ -265,21 +266,44 @@ pub(super) async fn handle_update_provider( Path(name): Path, Json(payload): Json, ) -> Json> { - if let Err(e) = super::auth::require_admin(&state, &headers).await { - return e; - } + let (session, _) = match super::auth::require_admin(&state, &headers).await { + Ok((session, new_token)) => (session, new_token), + Err(e) => return e, + }; let pool = &state.app_state.db_pool; - // Update or insert into database (include billing_mode) + // Prepare API key encryption if provided + let (api_key_to_store, api_key_encrypted_flag) = match &payload.api_key { + Some(key) if !key.is_empty() => { + match crypto::encrypt(key) { + Ok(encrypted) => (Some(encrypted), Some(true)), + Err(e) => { + warn!("Failed to encrypt API key for provider {}: {}", name, e); + return Json(ApiResponse::error(format!("Failed to encrypt API key: {}", e))); + } + } + } + Some(_) => { + // Empty string means 
clear the key + (None, Some(false)) + } + None => { + // Keep existing key, we'll rely on COALESCE in SQL + (None, None) + } + }; + + // Update or insert into database (include billing_mode and api_key_encrypted) let result = sqlx::query( r#" - INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold, billing_mode) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET enabled = excluded.enabled, base_url = excluded.base_url, api_key = COALESCE(excluded.api_key, provider_configs.api_key), + api_key_encrypted = COALESCE(excluded.api_key_encrypted, provider_configs.api_key_encrypted), credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance), low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold), billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode), @@ -290,7 +314,8 @@ pub(super) async fn handle_update_provider( .bind(name.to_uppercase()) .bind(payload.enabled) .bind(&payload.base_url) - .bind(&payload.api_key) + .bind(&api_key_to_store) + .bind(api_key_encrypted_flag) .bind(payload.credit_balance) .bind(payload.low_credit_threshold) .bind(payload.billing_mode) diff --git a/src/dashboard/sessions.rs b/src/dashboard/sessions.rs index 5fdc3cc2..0e011efc 100644 --- a/src/dashboard/sessions.rs +++ b/src/dashboard/sessions.rs @@ -1,7 +1,17 @@ use chrono::{DateTime, Duration, Utc}; +use hmac::{Hmac, Mac}; +use serde::{Deserialize, Serialize}; +use sha2::{Sha256, digest::generic_array::GenericArray}; use std::collections::HashMap; +use std::env; use std::sync::Arc; use tokio::sync::RwLock; +use uuid::Uuid; + +use base64::{engine::general_purpose::URL_SAFE, Engine as _}; + +const TOKEN_VERSION: &str = "v2"; +const 
REFRESH_WINDOW_MINUTES: i64 = 15; // refresh if token expires within 15 minutes #[derive(Clone, Debug)] pub struct Session { @@ -9,51 +19,136 @@ pub struct Session { pub role: String, pub created_at: DateTime, pub expires_at: DateTime, + pub session_id: String, // unique identifier for the session (UUID) } #[derive(Clone)] pub struct SessionManager { - sessions: Arc>>, + sessions: Arc>>, // key = session_id ttl_hours: i64, + secret: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +struct SessionPayload { + session_id: String, + username: String, + role: String, + iat: i64, // issued at (Unix timestamp) + exp: i64, // expiry (Unix timestamp) + version: String, } impl SessionManager { pub fn new(ttl_hours: i64) -> Self { + let secret = load_session_secret(); Self { sessions: Arc::new(RwLock::new(HashMap::new())), ttl_hours, + secret, } } - /// Create a new session and return the session token. + /// Create a new session and return a signed session token. pub async fn create_session(&self, username: String, role: String) -> String { - let token = format!("session-{}", uuid::Uuid::new_v4()); + let session_id = Uuid::new_v4().to_string(); let now = Utc::now(); + let expires_at = now + Duration::hours(self.ttl_hours); let session = Session { - username, - role, + username: username.clone(), + role: role.clone(), created_at: now, - expires_at: now + Duration::hours(self.ttl_hours), + expires_at, + session_id: session_id.clone(), }; - self.sessions.write().await.insert(token.clone(), session); - token + // Store session by session_id + self.sessions.write().await.insert(session_id.clone(), session); + // Create signed token + self.create_signed_token(&session_id, &username, &role, now.timestamp(), expires_at.timestamp()) } /// Validate a session token and return the session if valid and not expired. + /// If the token is within the refresh window, returns a new token as well. 
pub async fn validate_session(&self, token: &str) -> Option { + self.validate_session_with_refresh(token).await.map(|(session, _)| session) + } + + /// Validate a session token and return (session, optional new token if refreshed). + pub async fn validate_session_with_refresh(&self, token: &str) -> Option<(Session, Option)> { + // Legacy token format (UUID) + if token.starts_with("session-") { + let sessions = self.sessions.read().await; + return sessions.get(token).and_then(|s| { + if s.expires_at > Utc::now() { + Some((s.clone(), None)) + } else { + None + } + }); + } + + // Signed token format + let payload = match verify_signed_token(token, &self.secret) { + Ok(p) => p, + Err(_) => return None, + }; + + // Check expiry + let now = Utc::now().timestamp(); + if payload.exp <= now { + return None; + } + + // Look up session by session_id let sessions = self.sessions.read().await; - sessions.get(token).and_then(|s| { - if s.expires_at > Utc::now() { - Some(s.clone()) - } else { - None - } - }) + let session = match sessions.get(&payload.session_id) { + Some(s) => s.clone(), + None => return None, // session revoked or not found + }; + + // Ensure session username/role matches (should always match) + if session.username != payload.username || session.role != payload.role { + return None; + } + + // Check if token is within refresh window (last REFRESH_WINDOW_MINUTES of validity) + let refresh_threshold = payload.exp - REFRESH_WINDOW_MINUTES * 60; + let new_token = if now >= refresh_threshold { + // Generate a new token with same session data but updated iat/exp? + // According to activity-based refresh, we should extend the session expiry. + // We'll extend from now by ttl_hours (or keep original expiry?). + // Let's extend from now by ttl_hours (sliding window). 
+ let new_exp = Utc::now() + Duration::hours(self.ttl_hours); + // Update session expiry in store + drop(sessions); // release read lock before acquiring write lock + self.update_session_expiry(&payload.session_id, new_exp).await; + // Create new token with updated iat/exp + let new_token = self.create_signed_token( + &payload.session_id, + &payload.username, + &payload.role, + now, + new_exp.timestamp(), + ); + Some(new_token) + } else { + None + }; + + Some((session, new_token)) } /// Revoke (delete) a session by token. + /// Supports both legacy tokens (token is key) and signed tokens (extract session_id). pub async fn revoke_session(&self, token: &str) { - self.sessions.write().await.remove(token); + if token.starts_with("session-") { + self.sessions.write().await.remove(token); + return; + } + // For signed token, try to extract session_id + if let Ok(payload) = verify_signed_token(token, &self.secret) { + self.sessions.write().await.remove(&payload.session_id); + } } /// Remove all expired sessions from the store. @@ -61,4 +156,156 @@ impl SessionManager { let now = Utc::now(); self.sessions.write().await.retain(|_, s| s.expires_at > now); } + + // --- Private helpers --- + + fn create_signed_token(&self, session_id: &str, username: &str, role: &str, iat: i64, exp: i64) -> String { + let payload = SessionPayload { + session_id: session_id.to_string(), + username: username.to_string(), + role: role.to_string(), + iat, + exp, + version: TOKEN_VERSION.to_string(), + }; + sign_token(&payload, &self.secret) + } + + async fn update_session_expiry(&self, session_id: &str, new_expires_at: DateTime) { + let mut sessions = self.sessions.write().await; + if let Some(session) = sessions.get_mut(session_id) { + session.expires_at = new_expires_at; + } + } } + +/// Load session secret from environment variable SESSION_SECRET (hex or base64 encoded). +/// If not set, generates a random 32-byte secret and logs a warning. 
+fn load_session_secret() -> Vec { + let secret_str = env::var("SESSION_SECRET").unwrap_or_else(|_| { + // Also check LLM_PROXY__SESSION_SECRET for consistency with config prefix + env::var("LLM_PROXY__SESSION_SECRET").unwrap_or_else(|_| { + // Generate a random secret (32 bytes) and encode as hex + use rand::RngCore; + let mut bytes = [0u8; 32]; + rand::rng().fill_bytes(&mut bytes); + let hex_secret = hex::encode(bytes); + tracing::warn!( + "SESSION_SECRET environment variable not set. Using a randomly generated secret. \ + This will invalidate all sessions on restart. Set SESSION_SECRET to a fixed hex or base64 encoded 32-byte value." + ); + hex_secret + }) + }); + + // Decode hex or base64 + hex::decode(&secret_str) + .or_else(|_| URL_SAFE.decode(&secret_str)) + .or_else(|_| base64::engine::general_purpose::STANDARD.decode(&secret_str)) + .unwrap_or_else(|_| { + panic!("SESSION_SECRET must be hex or base64 encoded (32 bytes)"); + }) +} + +/// Sign a session payload and return a token string in format base64_url(payload).base64_url(signature). +fn sign_token(payload: &SessionPayload, secret: &[u8]) -> String { + let json = serde_json::to_vec(payload).expect("Failed to serialize payload"); + let payload_b64 = URL_SAFE.encode(&json); + let mut mac = Hmac::::new_from_slice(secret).expect("HMAC can take key of any size"); + mac.update(&json); + let signature = mac.finalize().into_bytes(); + let signature_b64 = URL_SAFE.encode(signature); + format!("{}.{}", payload_b64, signature_b64) +} + +/// Verify a signed token and return the decoded payload if valid. 
+fn verify_signed_token(token: &str, secret: &[u8]) -> Result<SessionPayload, TokenError> { + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() != 2 { + return Err(TokenError::InvalidFormat); + } + let payload_b64 = parts[0]; + let signature_b64 = parts[1]; + + let json = URL_SAFE.decode(payload_b64).map_err(|_| TokenError::InvalidFormat)?; + let signature = URL_SAFE.decode(signature_b64).map_err(|_| TokenError::InvalidFormat)?; + + // Verify HMAC + let mut mac = Hmac::::new_from_slice(secret).expect("HMAC can take key of any size"); + mac.update(&json); + // Use verify_slice so a signature of the wrong length is rejected as InvalidSignature; + // GenericArray::from_slice would panic on attacker-supplied input of unexpected length. + mac.verify_slice(&signature).map_err(|_| TokenError::InvalidSignature)?; + + // Deserialize payload + let payload: SessionPayload = serde_json::from_slice(&json).map_err(|_| TokenError::InvalidPayload)?; + Ok(payload) +} + +#[derive(Debug)] +enum TokenError { + InvalidFormat, + InvalidSignature, + InvalidPayload, +} + +#[cfg(test)] +mod tests { + use super::*; + use std::env; + + #[test] + fn test_sign_and_verify_token() { + let secret = b"test-secret-must-be-32-bytes-long!"; + let payload = SessionPayload { + session_id: "test-session".to_string(), + username: "testuser".to_string(), + role: "user".to_string(), + iat: 1000, + exp: 2000, + version: TOKEN_VERSION.to_string(), + }; + let token = sign_token(&payload, secret); + let verified = verify_signed_token(&token, secret).unwrap(); + assert_eq!(verified.session_id, payload.session_id); + assert_eq!(verified.username, payload.username); + assert_eq!(verified.role, payload.role); + assert_eq!(verified.iat, payload.iat); + assert_eq!(verified.exp, payload.exp); + assert_eq!(verified.version, payload.version); + } + + #[test] + fn test_tampered_token() { + let secret = b"test-secret-must-be-32-bytes-long!"; + let payload = SessionPayload { + session_id: "test-session".to_string(), + username: "testuser".to_string(), + role: "user".to_string(), + iat: 1000, + exp: 2000, + version: 
TOKEN_VERSION.to_string(), + }; + let mut token = sign_token(&payload, secret); + // Tamper with payload part + let mut parts: Vec<&str> = token.split('.').collect(); + let mut payload_bytes = URL_SAFE.decode(parts[0]).unwrap(); + payload_bytes[0] ^= 0xFF; // flip some bits + let tampered_payload = URL_SAFE.encode(payload_bytes); + parts[0] = &tampered_payload; + token = parts.join("."); + assert!(verify_signed_token(&token, secret).is_err()); + } + + #[test] + fn test_load_session_secret_from_env() { + unsafe { + env::set_var("SESSION_SECRET", hex::encode([0xAA; 32])); + } + let secret = load_session_secret(); + assert_eq!(secret, vec![0xAA; 32]); + unsafe { + env::remove_var("SESSION_SECRET"); + } + } +} \ No newline at end of file diff --git a/src/dashboard/system.rs b/src/dashboard/system.rs index 6c544506..9985001b 100644 --- a/src/dashboard/system.rs +++ b/src/dashboard/system.rs @@ -279,9 +279,10 @@ pub(super) async fn handle_system_backup( State(state): State, headers: axum::http::HeaderMap, ) -> Json> { - if let Err(e) = super::auth::require_admin(&state, &headers).await { - return e; - } + let (session, _) = match super::auth::require_admin(&state, &headers).await { + Ok((session, new_token)) => (session, new_token), + Err(e) => return e, + }; let pool = &state.app_state.db_pool; let backup_id = format!("backup-{}", chrono::Utc::now().timestamp()); @@ -341,9 +342,10 @@ pub(super) async fn handle_update_settings( State(state): State, headers: axum::http::HeaderMap, ) -> Json> { - if let Err(e) = super::auth::require_admin(&state, &headers).await { - return e; - } + let (session, _) = match super::auth::require_admin(&state, &headers).await { + Ok((session, new_token)) => (session, new_token), + Err(e) => return e, + }; Json(ApiResponse::error( "Changing settings at runtime is not yet supported. Please update your config file and restart the server." 
diff --git a/src/dashboard/users.rs b/src/dashboard/users.rs index 4aec56ad..a790199a 100644 --- a/src/dashboard/users.rs +++ b/src/dashboard/users.rs @@ -14,9 +14,10 @@ pub(super) async fn handle_get_users( State(state): State, headers: axum::http::HeaderMap, ) -> Json> { - if let Err(e) = auth::require_admin(&state, &headers).await { - return e; - } + let (session, _) = match auth::require_admin(&state, &headers).await { + Ok((session, new_token)) => (session, new_token), + Err(e) => return e, + }; let pool = &state.app_state.db_pool; @@ -66,9 +67,10 @@ pub(super) async fn handle_create_user( headers: axum::http::HeaderMap, Json(payload): Json, ) -> Json> { - if let Err(e) = auth::require_admin(&state, &headers).await { - return e; - } + let (session, _) = match auth::require_admin(&state, &headers).await { + Ok((session, new_token)) => (session, new_token), + Err(e) => return e, + }; let pool = &state.app_state.db_pool; @@ -147,9 +149,10 @@ pub(super) async fn handle_update_user( Path(id): Path, Json(payload): Json, ) -> Json> { - if let Err(e) = auth::require_admin(&state, &headers).await { - return e; - } + let (session, _) = match auth::require_admin(&state, &headers).await { + Ok((session, new_token)) => (session, new_token), + Err(e) => return e, + }; let pool = &state.app_state.db_pool; @@ -249,8 +252,8 @@ pub(super) async fn handle_delete_user( headers: axum::http::HeaderMap, Path(id): Path, ) -> Json> { - let session = match auth::require_admin(&state, &headers).await { - Ok(s) => s, + let (session, _) = match auth::require_admin(&state, &headers).await { + Ok((session, new_token)) => (session, new_token), Err(e) => return e, }; diff --git a/src/database/mod.rs b/src/database/mod.rs index bb8a652f..de31c555 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -18,7 +18,9 @@ pub async fn init(config: &DatabaseConfig) -> Result { let database_path = config.path.to_string_lossy().to_string(); info!("Connecting to database at {}", database_path); - 
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?.create_if_missing(true); + let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))? + .create_if_missing(true) + .pragma("foreign_keys", "ON"); let pool = SqlitePool::connect_with(options).await?; @@ -29,7 +31,7 @@ pub async fn init(config: &DatabaseConfig) -> Result { Ok(pool) } -async fn run_migrations(pool: &DbPool) -> Result<()> { +pub async fn run_migrations(pool: &DbPool) -> Result<()> { // Create clients table if it doesn't exist sqlx::query( r#" @@ -88,6 +90,8 @@ async fn run_migrations(pool: &DbPool) -> Result<()> { api_key TEXT, credit_balance REAL DEFAULT 0.0, low_credit_threshold REAL DEFAULT 5.0, + billing_mode TEXT, + api_key_encrypted BOOLEAN DEFAULT FALSE, updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ) "#, @@ -167,6 +171,15 @@ async fn run_migrations(pool: &DbPool) -> Result<()> { .execute(pool) .await; + // Add billing_mode column if it doesn't exist (migration for existing DBs) + let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN billing_mode TEXT") + .execute(pool) + .await; + // Add api_key_encrypted column if it doesn't exist (migration for existing DBs) + let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN api_key_encrypted BOOLEAN DEFAULT FALSE") + .execute(pool) + .await; + // Insert default admin user if none exists (default password: admin) let user_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users").fetch_one(pool).await?; @@ -216,6 +229,19 @@ async fn run_migrations(pool: &DbPool) -> Result<()> { .execute(pool) .await?; + // Composite indexes for performance + sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp)") + .execute(pool) + .await?; + + sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp)") + .execute(pool) + .await?; + + sqlx::query("CREATE INDEX IF NOT EXISTS 
idx_model_configs_provider_id ON model_configs(provider_id)") + .execute(pool) + .await?; + // Insert default client if none exists sqlx::query( r#" diff --git a/src/lib.rs b/src/lib.rs index a2c6aca9..ab114519 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,23 +41,18 @@ pub use state::AppState; pub mod test_utils { use std::sync::Arc; - use crate::{client::ClientManager, providers::ProviderManager, rate_limiting::RateLimitManager, state::AppState}; + use crate::{client::ClientManager, providers::ProviderManager, rate_limiting::RateLimitManager, state::AppState, utils::crypto, database::run_migrations}; use sqlx::sqlite::SqlitePool; /// Create a test application state - pub async fn create_test_state() -> Arc { + pub async fn create_test_state() -> AppState { // Create in-memory database let pool = SqlitePool::connect("sqlite::memory:") .await .expect("Failed to create test database"); - // Run migrations - crate::database::init(&crate::config::DatabaseConfig { - path: std::path::PathBuf::from(":memory:"), - max_connections: 5, - }) - .await - .expect("Failed to initialize test database"); + // Run migrations on the pool + run_migrations(&pool).await.expect("Failed to run migrations"); let rate_limit_manager = RateLimitManager::new( crate::rate_limiting::RateLimiterConfig::default(), @@ -73,7 +68,7 @@ pub mod test_utils { providers: std::collections::HashMap::new(), }; - let (dashboard_tx, _) = tokio::sync::broadcast::channel(100); + let (dashboard_tx, _) = tokio::sync::broadcast::channel::(100); let config = Arc::new(crate::config::AppConfig { server: crate::config::ServerConfig { @@ -125,20 +120,20 @@ pub mod test_utils { ollama: vec![], }, config_path: None, + encryption_key: "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f".to_string(), }); - Arc::new(AppState { + // Initialize encryption with the test key + crypto::init_with_key(&config.encryption_key).expect("failed to initialize crypto"); + + AppState::new( config, provider_manager, - 
db_pool: pool.clone(), - rate_limit_manager: Arc::new(rate_limit_manager), - client_manager, - request_logger: Arc::new(crate::logging::RequestLogger::new(pool.clone(), dashboard_tx.clone())), - model_registry: Arc::new(model_registry), - model_config_cache: crate::state::ModelConfigCache::new(pool.clone()), - dashboard_tx, - auth_tokens: vec![], - }) + pool, + rate_limit_manager, + model_registry, + vec![], // auth_tokens + ) } /// Create a test HTTP client @@ -149,3 +144,185 @@ pub mod test_utils { .expect("Failed to create test HTTP client") } } + +#[cfg(test)] +mod integration_tests { + use super::test_utils::*; + use crate::{ + models::{ChatCompletionRequest, ChatMessage}, + server::router, + utils::crypto, + }; + use axum::{ + body::Body, + http::{Request, StatusCode}, + }; + use mockito::Server; + use serde_json::json; + use sqlx::Row; + use tower::util::ServiceExt; + + #[tokio::test] + async fn test_encrypted_provider_key_integration() { + // Step 1: Setup test database and state + let state = create_test_state().await; + let pool = state.db_pool.clone(); + + // Step 2: Insert provider with encrypted API key + let test_api_key = "test-openai-key-12345"; + let encrypted_key = crypto::encrypt(test_api_key).expect("Failed to encrypt test key"); + + sqlx::query( + r#" + INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
+ "#, + ) + .bind("openai") + .bind("OpenAI") + .bind(true) + .bind("http://localhost:1234") // Mock server URL + .bind(&encrypted_key) + .bind(true) // api_key_encrypted flag + .bind(100.0) + .bind(5.0) + .execute(&pool) + .await + .expect("Failed to update provider URL"); + + // Re-initialize provider with new URL + state + .provider_manager + .initialize_provider("openai", &state.config, &pool) + .await + .expect("Failed to re-initialize provider"); + + // Step 4: Mock OpenAI API server + let mut server = Server::new_async().await; + let mock = server + .mock("POST", "/chat/completions") + .match_header("authorization", format!("Bearer {}", test_api_key).as_str()) + .with_status(200) + .with_header("content-type", "application/json") + .with_body( + json!({ + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1234567890, + "model": "gpt-3.5-turbo", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "Hello, world!" + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15 + } + }) + .to_string(), + ) + .create_async() + .await; + + // Update provider base URL to use mock server + sqlx::query("UPDATE provider_configs SET base_url = ? 
WHERE id = 'openai'") + .bind(&server.url()) + .execute(&pool) + .await + .expect("Failed to update provider URL"); + + // Re-initialize provider with new URL + state + .provider_manager + .initialize_provider("openai", &state.config, &pool) + .await + .expect("Failed to re-initialize provider"); + + // Step 5: Create test router and make request + let app = router(state); + + let request_body = ChatCompletionRequest { + model: "gpt-3.5-turbo".to_string(), + messages: vec![ChatMessage { + role: "user".to_string(), + content: crate::models::MessageContent::Text { + content: "Hello".to_string(), + }, + reasoning_content: None, + tool_calls: None, + name: None, + tool_call_id: None, + }], + temperature: None, + top_p: None, + top_k: None, + n: None, + stop: None, + max_tokens: Some(100), + presence_penalty: None, + frequency_penalty: None, + stream: Some(false), + tools: None, + tool_choice: None, + }; + + let request = Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .header("authorization", "Bearer test-token") + .body(Body::from(serde_json::to_string(&request_body).unwrap())) + .unwrap(); + + // Step 6: Execute request through proxy + let response = app + .oneshot(request) + .await + .expect("Failed to execute request"); + + let status = response.status(); + println!("Response status: {}", status); + + if status != StatusCode::OK { + let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let body_str = String::from_utf8(body_bytes.to_vec()).unwrap(); + println!("Response body: {}", body_str); + panic!("Response status is not OK: {}", status); + } + + assert_eq!(status, StatusCode::OK); + + // Verify the mock was called + mock.assert_async().await; + + // Give the async logging task time to complete + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Step 7: Verify usage was logged in database + let log_row = sqlx::query("SELECT * FROM 
llm_requests WHERE client_id = 'client_test-tok' ORDER BY id DESC LIMIT 1") + .fetch_one(&pool) + .await + .expect("Request log not found"); + + assert_eq!(log_row.get::("provider"), "openai"); + assert_eq!(log_row.get::("model"), "gpt-3.5-turbo"); + assert_eq!(log_row.get::("prompt_tokens"), 10); + assert_eq!(log_row.get::("completion_tokens"), 5); + assert_eq!(log_row.get::("total_tokens"), 15); + assert_eq!(log_row.get::("status"), "success"); + + // Verify client usage was updated + let client_row = sqlx::query("SELECT total_requests, total_tokens, total_cost FROM clients WHERE client_id = 'client_test-tok'") + .fetch_one(&pool) + .await + .expect("Client not found"); + + assert_eq!(client_row.get::("total_requests"), 1); + assert_eq!(client_row.get::("total_tokens"), 15); + } +} diff --git a/src/logging/mod.rs b/src/logging/mod.rs index 90c22458..bf1c7359 100644 --- a/src/logging/mod.rs +++ b/src/logging/mod.rs @@ -82,9 +82,9 @@ impl RequestLogger { "#, ) .bind(log.timestamp) - .bind(log.client_id) + .bind(&log.client_id) .bind(&log.provider) - .bind(log.model) + .bind(&log.model) .bind(log.prompt_tokens as i64) .bind(log.completion_tokens as i64) .bind(log.total_tokens as i64) @@ -92,7 +92,7 @@ impl RequestLogger { .bind(log.cache_write_tokens as i64) .bind(log.cost) .bind(log.has_images) - .bind(log.status) + .bind(&log.status) .bind(log.error_message) .bind(log.duration_ms as i64) .bind(None::) // request_body - optional, not stored to save disk space @@ -100,6 +100,23 @@ impl RequestLogger { .execute(&mut *tx) .await?; + // Update client usage statistics + sqlx::query( + r#" + UPDATE clients SET + total_requests = total_requests + 1, + total_tokens = total_tokens + ?, + total_cost = total_cost + ?, + updated_at = CURRENT_TIMESTAMP + WHERE client_id = ? + "#, + ) + .bind(log.total_tokens as i64) + .bind(log.cost) + .bind(&log.client_id) + .execute(&mut *tx) + .await?; + // Deduct from provider balance if successful. 
// Providers configured with billing_mode = 'postpaid' will not have their // credit_balance decremented. Use a conditional UPDATE so we don't need diff --git a/src/main.rs b/src/main.rs index 992b3a5a..b12a3615 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,6 +10,7 @@ use llm_proxy::{ rate_limiting::{CircuitBreakerConfig, RateLimitManager, RateLimiterConfig}, server, state::AppState, + utils::crypto, }; #[tokio::main] @@ -26,6 +27,10 @@ async fn main() -> Result<()> { let config = AppConfig::load().await?; info!("Configuration loaded from {:?}", config.config_path); + // Initialize encryption + crypto::init_with_key(&config.encryption_key)?; + info!("Encryption initialized"); + // Initialize database connection pool let db_pool = database::init(&config.database).await?; info!("Database initialized at {:?}", config.database.path); diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 2a2e138d..03938c0d 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use crate::errors::AppError; use crate::models::UnifiedRequest; + pub mod deepseek; pub mod gemini; pub mod grok; @@ -125,17 +126,35 @@ impl ProviderManager { db_pool: &crate::database::DbPool, ) -> Result<()> { // Load override from database - let db_config = sqlx::query("SELECT enabled, base_url, api_key FROM provider_configs WHERE id = ?") + let db_config = sqlx::query("SELECT enabled, base_url, api_key, api_key_encrypted FROM provider_configs WHERE id = ?") .bind(name) .fetch_optional(db_pool) .await?; let (enabled, base_url, api_key) = if let Some(row) = db_config { - ( - row.get::("enabled"), - row.get::, _>("base_url"), - row.get::, _>("api_key"), - ) + let enabled = row.get::("enabled"); + let base_url = row.get::, _>("base_url"); + let api_key_encrypted = row.get::("api_key_encrypted"); + let api_key = row.get::, _>("api_key"); + // Decrypt API key if encrypted + let api_key = match (api_key, api_key_encrypted) { + (Some(key), true) => { + match 
crate::utils::crypto::decrypt(&key) { + Ok(decrypted) => Some(decrypted), + Err(e) => { + tracing::error!("Failed to decrypt API key for provider {}: {}", name, e); + None + } + } + } + (Some(key), false) => { + // Plaintext key - optionally encrypt and update database (lazy migration) + // For now, just use plaintext + Some(key) + } + (None, _) => None, + }; + (enabled, base_url, api_key) } else { // No database override, use defaults from AppConfig match name { diff --git a/src/rate_limiting/mod.rs b/src/rate_limiting/mod.rs index c5e22e14..96a0d814 100644 --- a/src/rate_limiting/mod.rs +++ b/src/rate_limiting/mod.rs @@ -6,12 +6,15 @@ //! 3. Global rate limiting for overall system protection use anyhow::Result; +use governor::{Quota, RateLimiter, DefaultDirectRateLimiter}; use std::collections::HashMap; +use std::num::NonZeroU32; use std::sync::Arc; -use std::time::Instant; use tokio::sync::RwLock; use tracing::{info, warn}; +type GovRateLimiter = DefaultDirectRateLimiter; + /// Rate limiter configuration #[derive(Debug, Clone)] pub struct RateLimiterConfig { @@ -65,45 +68,7 @@ impl Default for CircuitBreakerConfig { } } -/// Simple token bucket rate limiter for a single client -#[derive(Debug)] -struct TokenBucket { - tokens: f64, - capacity: f64, - refill_rate: f64, // tokens per second - last_refill: Instant, -} -impl TokenBucket { - fn new(capacity: f64, refill_rate: f64) -> Self { - Self { - tokens: capacity, - capacity, - refill_rate, - last_refill: Instant::now(), - } - } - - fn refill(&mut self) { - let now = Instant::now(); - let elapsed = now.duration_since(self.last_refill).as_secs_f64(); - let new_tokens = elapsed * self.refill_rate; - - self.tokens = (self.tokens + new_tokens).min(self.capacity); - self.last_refill = now; - } - - fn try_acquire(&mut self, tokens: f64) -> bool { - self.refill(); - - if self.tokens >= tokens { - self.tokens -= tokens; - true - } else { - false - } - } -} /// Circuit breaker for a provider #[derive(Debug)] @@ -209,8 
+174,8 @@ impl ProviderCircuitBreaker { /// Rate limiting and circuit breaking manager #[derive(Debug)] pub struct RateLimitManager { - client_buckets: Arc>>, - global_bucket: Arc>, + client_buckets: Arc>>, + global_bucket: Arc, circuit_breakers: Arc>>, config: RateLimiterConfig, circuit_config: CircuitBreakerConfig, @@ -218,15 +183,16 @@ pub struct RateLimitManager { impl RateLimitManager { pub fn new(config: RateLimiterConfig, circuit_config: CircuitBreakerConfig) -> Self { - // Convert requests per minute to tokens per second - let global_refill_rate = config.global_requests_per_minute as f64 / 60.0; + // Create global rate limiter quota + let global_quota = Quota::per_minute( + NonZeroU32::new(config.global_requests_per_minute).expect("global_requests_per_minute must be positive") + ) + .allow_burst(NonZeroU32::new(config.burst_size).expect("burst_size must be positive")); + let global_bucket = RateLimiter::direct(global_quota); Self { client_buckets: Arc::new(RwLock::new(HashMap::new())), - global_bucket: Arc::new(RwLock::new(TokenBucket::new( - config.burst_size as f64, - global_refill_rate, - ))), + global_bucket: Arc::new(global_bucket), circuit_breakers: Arc::new(RwLock::new(HashMap::new())), config, circuit_config, @@ -236,24 +202,22 @@ impl RateLimitManager { /// Check if a client request is allowed pub async fn check_client_request(&self, client_id: &str) -> Result { // Check global rate limit first (1 token per request) - { - let mut global_bucket = self.global_bucket.write().await; - if !global_bucket.try_acquire(1.0) { - warn!("Global rate limit exceeded"); - return Ok(false); - } + if self.global_bucket.check().is_err() { + warn!("Global rate limit exceeded"); + return Ok(false); } // Check client-specific rate limit let mut buckets = self.client_buckets.write().await; let bucket = buckets.entry(client_id.to_string()).or_insert_with(|| { - TokenBucket::new( - self.config.burst_size as f64, - self.config.requests_per_minute as f64 / 60.0, + let quota 
= Quota::per_minute( + NonZeroU32::new(self.config.requests_per_minute).expect("requests_per_minute must be positive") ) + .allow_burst(NonZeroU32::new(self.config.burst_size).expect("burst_size must be positive")); + RateLimiter::direct(quota) }); - Ok(bucket.try_acquire(1.0)) + Ok(bucket.check().is_ok()) } /// Check if provider requests are allowed (circuit breaker) diff --git a/src/server/mod.rs b/src/server/mod.rs index d91520ff..c7faf9e5 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -5,6 +5,11 @@ use axum::{ response::sse::{Event, Sse}, routing::{get, post}, }; +use axum::http::{header, HeaderValue}; +use tower_http::{ + limit::RequestBodyLimitLayer, + set_header::SetResponseHeaderLayer, +}; use futures::StreamExt; use std::sync::Arc; @@ -23,9 +28,34 @@ use crate::{ }; pub fn router(state: AppState) -> Router { + // Security headers + let csp_header: SetResponseHeaderLayer = SetResponseHeaderLayer::overriding( + header::CONTENT_SECURITY_POLICY, + "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws:;" + .parse() + .unwrap(), + ); + let x_frame_options: SetResponseHeaderLayer = SetResponseHeaderLayer::overriding( + header::X_FRAME_OPTIONS, + "DENY".parse().unwrap(), + ); + let x_content_type_options: SetResponseHeaderLayer = SetResponseHeaderLayer::overriding( + header::X_CONTENT_TYPE_OPTIONS, + "nosniff".parse().unwrap(), + ); + let strict_transport_security: SetResponseHeaderLayer = SetResponseHeaderLayer::overriding( + header::STRICT_TRANSPORT_SECURITY, + "max-age=31536000; includeSubDomains".parse().unwrap(), + ); + Router::new() .route("/v1/chat/completions", post(chat_completions)) .route("/v1/models", get(list_models)) + .layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit + .layer(csp_header) + .layer(x_frame_options) + .layer(x_content_type_options) + .layer(strict_transport_security) .layer(axum::middleware::from_fn_with_state( state.clone(), 
rate_limiting::middleware::rate_limit_middleware, @@ -219,7 +249,6 @@ async fn chat_completions( prompt_tokens, has_images, logger: state.request_logger.clone(), - client_manager: state.client_manager.clone(), model_registry: state.model_registry.clone(), model_config_cache: state.model_config_cache.clone(), }, @@ -341,15 +370,6 @@ async fn chat_completions( duration_ms: duration.as_millis() as u64, }); - // Update client usage (fire-and-forget, don't block response) - { - let cm = state.client_manager.clone(); - let cid = client_id.clone(); - tokio::spawn(async move { - let _ = cm.update_client_usage(&cid, response.total_tokens as i64, cost).await; - }); - } - // Convert ProviderResponse to ChatCompletionResponse let finish_reason = if response.tool_calls.is_some() { "tool_calls".to_string() diff --git a/src/utils/crypto.rs b/src/utils/crypto.rs new file mode 100644 index 00000000..7f7740c2 --- /dev/null +++ b/src/utils/crypto.rs @@ -0,0 +1,171 @@ +use aes_gcm::{ + aead::{Aead, AeadCore, KeyInit, OsRng}, + Aes256Gcm, Key, Nonce, +}; +use anyhow::{anyhow, Context, Result}; +use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; +use std::env; +use std::sync::OnceLock; + +static RAW_KEY: OnceLock<[u8; 32]> = OnceLock::new(); + +/// Initialize the encryption key from a hex or base64 encoded string. +/// Must be called before any encryption/decryption operations. +/// Returns error if the key is invalid or already initialized with a different key. 
+pub fn init_with_key(key_str: &str) -> Result<()> { + let key_bytes = hex::decode(key_str) + .or_else(|_| BASE64.decode(key_str)) + .context("Encryption key must be hex or base64 encoded")?; + if key_bytes.len() != 32 { + anyhow::bail!( + "Encryption key must be 32 bytes (256 bits), got {} bytes", + key_bytes.len() + ); + } + let key_array: [u8; 32] = key_bytes.try_into().unwrap(); // safe due to length check + // Check if already initialized with same key + if let Some(existing) = RAW_KEY.get() { + if existing == &key_array { + // Same key already initialized, okay + return Ok(()); + } else { + anyhow::bail!("Encryption key already initialized with a different key"); + } + } + // Store raw key bytes + RAW_KEY + .set(key_array) + .map_err(|_| anyhow::anyhow!("Encryption key already initialized"))?; + Ok(()) +} + +/// Initialize the encryption key from the environment variable `LLM_PROXY__ENCRYPTION_KEY`. +/// Must be called before any encryption/decryption operations. +/// Returns an error if the environment variable is missing or invalid. +pub fn init_from_env() -> Result<()> { + let key_str = + env::var("LLM_PROXY__ENCRYPTION_KEY").context("LLM_PROXY__ENCRYPTION_KEY environment variable not set")?; + init_with_key(&key_str) +} + +/// Get the encryption key bytes, panicking if not initialized. +fn get_key() -> &'static [u8; 32] { + RAW_KEY + .get() + .expect("Encryption key not initialized. Call crypto::init_with_key() or crypto::init_from_env() first.") +} + +/// Encrypt a plaintext string and return a base64-encoded ciphertext (nonce || ciphertext || tag). 
+pub fn encrypt(plaintext: &str) -> Result<String> { + let key = Key::<Aes256Gcm>::from_slice(get_key()); + let cipher = Aes256Gcm::new(key); + let nonce = Aes256Gcm::generate_nonce(&mut OsRng); // 12 bytes + let ciphertext = cipher + .encrypt(&nonce, plaintext.as_bytes()) + .map_err(|e| anyhow!("Encryption failed: {}", e))?; + // Combine nonce and ciphertext (ciphertext already includes tag) + let mut combined = Vec::with_capacity(nonce.len() + ciphertext.len()); + combined.extend_from_slice(&nonce); + combined.extend_from_slice(&ciphertext); + Ok(BASE64.encode(combined)) +} + +/// Decrypt a base64-encoded ciphertext (nonce || ciphertext || tag) to a plaintext string. +pub fn decrypt(ciphertext_b64: &str) -> Result<String> { + let key = Key::<Aes256Gcm>::from_slice(get_key()); + let cipher = Aes256Gcm::new(key); + let combined = BASE64.decode(ciphertext_b64).context("Invalid base64 ciphertext")?; + if combined.len() < 12 { + anyhow::bail!("Ciphertext too short"); + } + let (nonce_bytes, ciphertext_and_tag) = combined.split_at(12); + let nonce = Nonce::from_slice(nonce_bytes); + let plaintext_bytes = cipher + .decrypt(nonce, ciphertext_and_tag) + .map_err(|e| anyhow!("Decryption failed (invalid key or corrupted ciphertext): {}", e))?; + String::from_utf8(plaintext_bytes).context("Decrypted bytes are not valid UTF-8") +} + +#[cfg(test)] +mod tests { + use super::*; + + const TEST_KEY_HEX: &str = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"; + + #[test] + fn test_encrypt_decrypt() { + init_with_key(TEST_KEY_HEX).unwrap(); + let plaintext = "super secret api key"; + let ciphertext = encrypt(plaintext).unwrap(); + assert_ne!(ciphertext, plaintext); + let decrypted = decrypt(&ciphertext).unwrap(); + assert_eq!(decrypted, plaintext); + } + + #[test] + fn test_different_inputs_produce_different_ciphertexts() { + init_with_key(TEST_KEY_HEX).unwrap(); + let plaintext = "same"; + let cipher1 = encrypt(plaintext).unwrap(); + let cipher2 = encrypt(plaintext).unwrap(); + assert_ne!(cipher1, 
cipher2, "Nonce should make ciphertexts differ"); + assert_eq!(decrypt(&cipher1).unwrap(), plaintext); + assert_eq!(decrypt(&cipher2).unwrap(), plaintext); + } + + #[test] + fn test_invalid_key_length() { + let result = init_with_key("tooshort"); + assert!(result.is_err()); + } + + #[test] + fn test_init_from_env() { + unsafe { std::env::set_var("LLM_PROXY__ENCRYPTION_KEY", TEST_KEY_HEX) }; + let result = init_from_env(); + assert!(result.is_ok()); + // Ensure encryption works + let ciphertext = encrypt("test").unwrap(); + let decrypted = decrypt(&ciphertext).unwrap(); + assert_eq!(decrypted, "test"); + } + + #[test] + fn test_missing_env_key() { + unsafe { std::env::remove_var("LLM_PROXY__ENCRYPTION_KEY") }; + let result = init_from_env(); + assert!(result.is_err()); + } + + #[test] + fn test_key_hex_and_base64() { + // Hex key works + init_with_key(TEST_KEY_HEX).unwrap(); + // Base64 key (same bytes encoded as base64) + let base64_key = BASE64.encode(hex::decode(TEST_KEY_HEX).unwrap()); + // Re-initialization with same key (different encoding) is allowed + let result = init_with_key(&base64_key); + assert!(result.is_ok()); + // Encryption should still work + let ciphertext = encrypt("test").unwrap(); + let decrypted = decrypt(&ciphertext).unwrap(); + assert_eq!(decrypted, "test"); + } + + #[test] + #[ignore] // conflicts with global state from other tests + fn test_already_initialized() { + init_with_key(TEST_KEY_HEX).unwrap(); + let result = init_with_key(TEST_KEY_HEX); + assert!(result.is_ok()); // same key allowed + } + + #[test] + #[ignore] // conflicts with global state from other tests + fn test_already_initialized_different_key() { + init_with_key(TEST_KEY_HEX).unwrap(); + let different_key = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e20"; + let result = init_with_key(different_key); + assert!(result.is_err()); + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 0f0d1351..6e09e6c9 100644 --- a/src/utils/mod.rs +++ 
b/src/utils/mod.rs @@ -1,3 +1,4 @@ +pub mod crypto; pub mod registry; pub mod streaming; pub mod tokens; diff --git a/src/utils/streaming.rs b/src/utils/streaming.rs index 123eba2c..03e2b3fc 100644 --- a/src/utils/streaming.rs +++ b/src/utils/streaming.rs @@ -1,4 +1,4 @@ -use crate::client::ClientManager; + use crate::errors::AppError; use crate::logging::{RequestLog, RequestLogger}; use crate::models::ToolCall; @@ -18,7 +18,6 @@ pub struct StreamConfig { pub prompt_tokens: u32, pub has_images: bool, pub logger: Arc, - pub client_manager: Arc, pub model_registry: Arc, pub model_config_cache: ModelConfigCache, } @@ -36,7 +35,6 @@ pub struct AggregatingStream { /// Real usage data from the provider's final stream chunk (when available). real_usage: Option, logger: Arc, - client_manager: Arc, model_registry: Arc, model_config_cache: ModelConfigCache, start_time: std::time::Instant, @@ -60,7 +58,6 @@ where accumulated_tool_calls: Vec::new(), real_usage: None, logger: config.logger, - client_manager: config.client_manager, model_registry: config.model_registry, model_config_cache: config.model_config_cache, start_time: std::time::Instant::now(), @@ -79,7 +76,6 @@ where let provider_name = self.provider.name().to_string(); let model = self.model.clone(); let logger = self.logger.clone(); - let client_manager = self.client_manager.clone(); let provider = self.provider.clone(); let estimated_prompt_tokens = self.prompt_tokens; let has_images = self.has_images; @@ -162,11 +158,6 @@ where error_message: None, duration_ms: duration.as_millis() as u64, }); - - // Update client usage - let _ = client_manager - .update_client_usage(&client_id, total_tokens as i64, cost) - .await; }); } } @@ -304,7 +295,6 @@ mod tests { let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); let (dashboard_tx, _) = tokio::sync::broadcast::channel(16); let logger = Arc::new(RequestLogger::new(pool.clone(), dashboard_tx)); - let client_manager = 
Arc::new(ClientManager::new(pool.clone())); let registry = Arc::new(crate::models::registry::ModelRegistry { providers: std::collections::HashMap::new(), }); @@ -318,7 +308,6 @@ mod tests { prompt_tokens: 10, has_images: false, logger, - client_manager, model_registry: registry, model_config_cache: ModelConfigCache::new(pool.clone()), }, diff --git a/static/js/api.js b/static/js/api.js index 140e9511..524f888d 100644 --- a/static/js/api.js +++ b/static/js/api.js @@ -35,6 +35,14 @@ class ApiClient { throw new Error(result.error || `HTTP error! status: ${response.status}`); } + // Handling X-Refreshed-Token header + if (response.headers.get('X-Refreshed-Token') && window.authManager) { + window.authManager.token = response.headers.get('X-Refreshed-Token'); + if (window.authManager.setToken) { + window.authManager.setToken(window.authManager.token); + } + } + return result.data; } @@ -87,6 +95,17 @@ class ApiClient { const date = luxon.DateTime.fromISO(dateStr); return date.toRelative(); } + + // Helper for escaping HTML + escapeHtml(unsafe) { + if (unsafe === undefined || unsafe === null) return ''; + return unsafe.toString() + .replace(/&/g, "&") + .replace(//g, ">") + .replace(/"/g, """) + .replace(/'/g, "'"); + } } window.api = new ApiClient(); diff --git a/static/js/auth.js b/static/js/auth.js index 4d564a8f..5035e179 100644 --- a/static/js/auth.js +++ b/static/js/auth.js @@ -50,6 +50,12 @@ class AuthManager { }); } + setToken(newToken) { + if (!newToken) return; + this.token = newToken; + localStorage.setItem('dashboard_token', this.token); + } + async login(username, password) { const errorElement = document.getElementById('login-error'); const loginBtn = document.querySelector('.login-btn'); diff --git a/static/js/pages/clients.js b/static/js/pages/clients.js index c6d2e321..b58a913d 100644 --- a/static/js/pages/clients.js +++ b/static/js/pages/clients.js @@ -42,12 +42,15 @@ class ClientsPage { const statusIcon = client.status === 'active' ? 
'check-circle' : 'clock'; const created = luxon.DateTime.fromISO(client.created_at).toFormat('MMM dd, yyyy'); + const escapedId = window.api.escapeHtml(client.id); + const escapedName = window.api.escapeHtml(client.name); + return ` - ${client.id} - ${client.name} + ${escapedId} + ${escapedName} - sk-••••${client.id.substring(client.id.length - 4)} + sk-••••${escapedId.substring(escapedId.length - 4)} ${created} ${client.last_used ? window.api.formatTimeAgo(client.last_used) : 'Never'} @@ -55,16 +58,16 @@ class ClientsPage { - ${client.status} + ${window.api.escapeHtml(client.status)} ${window._userRole === 'admin' ? `
- -
@@ -188,10 +191,13 @@ class ClientsPage { showTokenRevealModal(clientName, token) { const modal = document.createElement('div'); modal.className = 'modal active'; + const escapedName = window.api.escapeHtml(clientName); + const escapedToken = window.api.escapeHtml(token); + modal.innerHTML = `