feat(security): implement AES-256-GCM encryption for API keys and HMAC-signed session tokens

This commit introduces:
- AES-256-GCM encryption for LLM provider API keys in the database.
- HMAC-SHA256 signed session tokens with activity-based refresh logic.
- Standardized frontend XSS protection using a global escapeHtml utility.
- Hardened security headers and request body size limits.
- Improved database integrity with foreign key enforcement and atomic transactions.
- Integration tests for the full encrypted key storage and proxy usage lifecycle.
This commit is contained in:
2026-03-06 14:17:56 -05:00
parent 149a7c3a29
commit 9b8483e797
28 changed files with 1260 additions and 227 deletions

View File

@@ -26,3 +26,6 @@ LLM_PROXY__SERVER__PORT=8080
# Database path (optional) # Database path (optional)
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
# Session secret for HMAC-signed session tokens: 32 random bytes, hex- or base64-encoded (e.g. generate with `openssl rand -hex 32`)
SESSION_SECRET=your_session_secret_here_32_bytes

201
Cargo.lock generated
View File

@@ -8,6 +8,41 @@ version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa"
[[package]]
name = "aead"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0"
dependencies = [
"crypto-common",
"generic-array",
]
[[package]]
name = "aes"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0"
dependencies = [
"cfg-if",
"cipher",
"cpufeatures",
]
[[package]]
name = "aes-gcm"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1"
dependencies = [
"aead",
"aes",
"cipher",
"ctr",
"ghash",
"subtle",
]
[[package]] [[package]]
name = "ahash" name = "ahash"
version = "0.7.8" version = "0.7.8"
@@ -541,9 +576,33 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a"
dependencies = [ dependencies = [
"generic-array", "generic-array",
"rand_core 0.6.4",
"typenum", "typenum",
] ]
[[package]]
name = "ctr"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835"
dependencies = [
"cipher",
]
[[package]]
name = "dashmap"
version = "6.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf"
dependencies = [
"cfg-if",
"crossbeam-utils",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]] [[package]]
name = "data-encoding" name = "data-encoding"
version = "2.10.0" version = "2.10.0"
@@ -895,6 +954,37 @@ dependencies = [
"wasip3", "wasip3",
] ]
[[package]]
name = "ghash"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1"
dependencies = [
"opaque-debug",
"polyval",
]
[[package]]
name = "governor"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0746aa765db78b521451ef74221663b57ba595bf83f75d0ce23cc09447c8139f"
dependencies = [
"cfg-if",
"dashmap",
"futures-sink",
"futures-timer",
"futures-util",
"no-std-compat",
"nonzero_ext",
"parking_lot",
"portable-atomic",
"quanta",
"rand 0.8.5",
"smallvec",
"spinning_top",
]
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.4.13" version = "0.4.13"
@@ -923,6 +1013,12 @@ dependencies = [
"ahash", "ahash",
] ]
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.15.5" version = "0.15.5"
@@ -1431,6 +1527,7 @@ checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77"
name = "llm-proxy" name = "llm-proxy"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"aes-gcm",
"anyhow", "anyhow",
"assert_cmd", "assert_cmd",
"async-stream", "async-stream",
@@ -1443,8 +1540,10 @@ dependencies = [
"config", "config",
"dotenvy", "dotenvy",
"futures", "futures",
"governor",
"headers", "headers",
"hex", "hex",
"hmac",
"image", "image",
"insta", "insta",
"mime", "mime",
@@ -1454,6 +1553,7 @@ dependencies = [
"reqwest-eventsource", "reqwest-eventsource",
"serde", "serde",
"serde_json", "serde_json",
"sha2",
"sqlx", "sqlx",
"tempfile", "tempfile",
"thiserror 1.0.69", "thiserror 1.0.69",
@@ -1598,6 +1698,12 @@ dependencies = [
"pxfm", "pxfm",
] ]
[[package]]
name = "no-std-compat"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
[[package]] [[package]]
name = "nom" name = "nom"
version = "7.1.3" version = "7.1.3"
@@ -1608,6 +1714,12 @@ dependencies = [
"minimal-lexical", "minimal-lexical",
] ]
[[package]]
name = "nonzero_ext"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
[[package]] [[package]]
name = "nu-ansi-term" name = "nu-ansi-term"
version = "0.50.3" version = "0.50.3"
@@ -1669,6 +1781,12 @@ version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "opaque-debug"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
[[package]] [[package]]
name = "ordered-multimap" name = "ordered-multimap"
version = "0.4.3" version = "0.4.3"
@@ -1824,6 +1942,24 @@ dependencies = [
"miniz_oxide", "miniz_oxide",
] ]
[[package]]
name = "polyval"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25"
dependencies = [
"cfg-if",
"cpufeatures",
"opaque-debug",
"universal-hash",
]
[[package]]
name = "portable-atomic"
version = "1.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
[[package]] [[package]]
name = "potential_utf" name = "potential_utf"
version = "0.1.4" version = "0.1.4"
@@ -1897,6 +2033,21 @@ dependencies = [
"num-traits", "num-traits",
] ]
[[package]]
name = "quanta"
version = "0.12.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7"
dependencies = [
"crossbeam-utils",
"libc",
"once_cell",
"raw-cpuid",
"wasi",
"web-sys",
"winapi",
]
[[package]] [[package]]
name = "quick-error" name = "quick-error"
version = "2.0.1" version = "2.0.1"
@@ -2032,6 +2183,15 @@ dependencies = [
"getrandom 0.3.4", "getrandom 0.3.4",
] ]
[[package]]
name = "raw-cpuid"
version = "11.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186"
dependencies = [
"bitflags 2.11.0",
]
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.5.18" version = "0.5.18"
@@ -2453,6 +2613,15 @@ dependencies = [
"lock_api", "lock_api",
] ]
[[package]]
name = "spinning_top"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
dependencies = [
"lock_api",
]
[[package]] [[package]]
name = "spki" name = "spki"
version = "0.7.3" version = "0.7.3"
@@ -3158,6 +3327,16 @@ version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "universal-hash"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea"
dependencies = [
"crypto-common",
"subtle",
]
[[package]] [[package]]
name = "untrusted" name = "untrusted"
version = "0.9.0" version = "0.9.0"
@@ -3411,6 +3590,28 @@ dependencies = [
"wasite", "wasite",
] ]
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]] [[package]]
name = "windows-core" name = "windows-core"
version = "0.62.2" version = "0.62.2"

View File

@@ -13,7 +13,8 @@ repository = ""
axum = { version = "0.8", features = ["macros", "ws"] } axum = { version = "0.8", features = ["macros", "ws"] }
tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "net", "time", "signal", "fs"] } tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "net", "time", "signal", "fs"] }
tower = "0.5" tower = "0.5"
tower-http = { version = "0.6", features = ["trace", "cors", "compression-gzip", "fs"] } tower-http = { version = "0.6", features = ["trace", "cors", "compression-gzip", "fs", "set-header", "limit"] }
governor = "0.7"
# ========== HTTP Clients ========== # ========== HTTP Clients ==========
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] } reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
@@ -46,6 +47,9 @@ mime = "0.3"
anyhow = "1.0" anyhow = "1.0"
thiserror = "1.0" thiserror = "1.0"
bcrypt = "0.15" bcrypt = "0.15"
aes-gcm = "0.10"
hmac = "0.12"
sha2 = "0.10"
chrono = { version = "0.4", features = ["serde"] } chrono = { version = "0.4", features = ["serde"] }
uuid = { version = "1.0", features = ["v4", "serde"] } uuid = { version = "1.0", features = ["v4", "serde"] }
futures = "0.3" futures = "0.3"

View File

@@ -0,0 +1,13 @@
-- Migration: add indexes for query performance
-- Adds three indexes (two composite, one single-column):
-- 1. idx_llm_requests_client_timestamp on llm_requests(client_id, timestamp)
-- 2. idx_llm_requests_provider_timestamp on llm_requests(provider, timestamp)
-- 3. idx_model_configs_provider_id on model_configs(provider_id)
BEGIN TRANSACTION;
CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp);
CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp);
CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id);
COMMIT;

View File

@@ -1,5 +1,7 @@
use anyhow::Result; use anyhow::Result;
use base64::{Engine as _};
use config::{Config, File, FileFormat}; use config::{Config, File, FileFormat};
use hex;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
@@ -96,6 +98,7 @@ pub struct AppConfig {
pub model_mapping: ModelMappingConfig, pub model_mapping: ModelMappingConfig,
pub pricing: PricingConfig, pub pricing: PricingConfig,
pub config_path: Option<PathBuf>, pub config_path: Option<PathBuf>,
pub encryption_key: String,
} }
impl AppConfig { impl AppConfig {
@@ -136,7 +139,8 @@ impl AppConfig {
.set_default("providers.grok.enabled", true)? .set_default("providers.grok.enabled", true)?
.set_default("providers.ollama.base_url", "http://localhost:11434/v1")? .set_default("providers.ollama.base_url", "http://localhost:11434/v1")?
.set_default("providers.ollama.enabled", false)? .set_default("providers.ollama.enabled", false)?
.set_default("providers.ollama.models", Vec::<String>::new())?; .set_default("providers.ollama.models", Vec::<String>::new())?
.set_default("encryption_key", "")?;
// Load from config file if exists // Load from config file if exists
// Priority: explicit path arg > LLM_PROXY__CONFIG_PATH env var > ./config.toml // Priority: explicit path arg > LLM_PROXY__CONFIG_PATH env var > ./config.toml
@@ -167,6 +171,19 @@ impl AppConfig {
let server: ServerConfig = config.get("server")?; let server: ServerConfig = config.get("server")?;
let database: DatabaseConfig = config.get("database")?; let database: DatabaseConfig = config.get("database")?;
let providers: ProviderConfig = config.get("providers")?; let providers: ProviderConfig = config.get("providers")?;
let encryption_key: String = config.get("encryption_key")?;
// Validate encryption key length (must be 32 bytes after hex or base64 decoding)
if encryption_key.is_empty() {
anyhow::bail!("Encryption key is required (LLM_PROXY__ENCRYPTION_KEY environment variable)");
}
// Try hex decode first, then base64
let key_bytes = hex::decode(&encryption_key)
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(&encryption_key))
.map_err(|e| anyhow::anyhow!("Encryption key must be hex or base64 encoded: {}", e))?;
if key_bytes.len() != 32 {
anyhow::bail!("Encryption key must be 32 bytes (256 bits), got {} bytes", key_bytes.len());
}
// For now, use empty model mapping and pricing (will be populated later) // For now, use empty model mapping and pricing (will be populated later)
let model_mapping = ModelMappingConfig { patterns: vec![] }; let model_mapping = ModelMappingConfig { patterns: vec![] };
@@ -185,6 +202,7 @@ impl AppConfig {
model_mapping, model_mapping,
pricing, pricing,
config_path: Some(config_path), config_path: Some(config_path),
encryption_key,
})) }))
} }

View File

@@ -1,4 +1,4 @@
use axum::{extract::State, response::Json}; use axum::{extract::State, http::{HeaderMap, HeaderValue}, response::{Json, IntoResponse}};
use bcrypt; use bcrypt;
use serde::Deserialize; use serde::Deserialize;
use sqlx::Row; use sqlx::Row;
@@ -64,14 +64,14 @@ pub(super) async fn handle_login(
pub(super) async fn handle_auth_status( pub(super) async fn handle_auth_status(
State(state): State<DashboardState>, State(state): State<DashboardState>,
headers: axum::http::HeaderMap, headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> { ) -> impl IntoResponse {
let token = headers let token = headers
.get("Authorization") .get("Authorization")
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer ")); .and_then(|v| v.strip_prefix("Bearer "));
if let Some(token) = token if let Some(token) = token
&& let Some(session) = state.session_manager.validate_session(token).await && let Some((session, new_token)) = state.session_manager.validate_session_with_refresh(token).await
{ {
// Look up display_name from DB // Look up display_name from DB
let display_name = sqlx::query_scalar::<_, Option<String>>( let display_name = sqlx::query_scalar::<_, Option<String>>(
@@ -85,17 +85,23 @@ pub(super) async fn handle_auth_status(
.flatten() .flatten()
.unwrap_or_else(|| session.username.clone()); .unwrap_or_else(|| session.username.clone());
return Json(ApiResponse::success(serde_json::json!({ let mut headers = HeaderMap::new();
if let Some(refreshed_token) = new_token {
if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) {
headers.insert("X-Refreshed-Token", header_value);
}
}
return (headers, Json(ApiResponse::success(serde_json::json!({
"authenticated": true, "authenticated": true,
"user": { "user": {
"username": session.username, "username": session.username,
"name": display_name, "name": display_name,
"role": session.role "role": session.role
} }
}))); }))));
} }
Json(ApiResponse::error("Not authenticated".to_string())) (HeaderMap::new(), Json(ApiResponse::error("Not authenticated".to_string())))
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@@ -108,7 +114,7 @@ pub(super) async fn handle_change_password(
State(state): State<DashboardState>, State(state): State<DashboardState>,
headers: axum::http::HeaderMap, headers: axum::http::HeaderMap,
Json(payload): Json<ChangePasswordRequest>, Json(payload): Json<ChangePasswordRequest>,
) -> Json<ApiResponse<serde_json::Value>> { ) -> impl IntoResponse {
let pool = &state.app_state.db_pool; let pool = &state.app_state.db_pool;
// Extract the authenticated user from the session token // Extract the authenticated user from the session token
@@ -117,14 +123,24 @@ pub(super) async fn handle_change_password(
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer ")); .and_then(|v| v.strip_prefix("Bearer "));
let session = match token { let (session, new_token) = match token {
Some(t) => state.session_manager.validate_session(t).await, Some(t) => match state.session_manager.validate_session_with_refresh(t).await {
None => None, Some((session, new_token)) => (Some(session), new_token),
None => (None, None),
},
None => (None, None),
}; };
let mut response_headers = HeaderMap::new();
if let Some(refreshed_token) = new_token {
if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) {
response_headers.insert("X-Refreshed-Token", header_value);
}
}
let username = match session { let username = match session {
Some(s) => s.username, Some(s) => s.username,
None => return Json(ApiResponse::error("Not authenticated".to_string())), None => return (response_headers, Json(ApiResponse::error("Not authenticated".to_string()))),
}; };
let user_result = sqlx::query("SELECT password_hash FROM users WHERE username = ?") let user_result = sqlx::query("SELECT password_hash FROM users WHERE username = ?")
@@ -138,7 +154,7 @@ pub(super) async fn handle_change_password(
if bcrypt::verify(&payload.current_password, &hash).unwrap_or(false) { if bcrypt::verify(&payload.current_password, &hash).unwrap_or(false) {
let new_hash = match bcrypt::hash(&payload.new_password, 12) { let new_hash = match bcrypt::hash(&payload.new_password, 12) {
Ok(h) => h, Ok(h) => h,
Err(_) => return Json(ApiResponse::error("Failed to hash new password".to_string())), Err(_) => return (response_headers, Json(ApiResponse::error("Failed to hash new password".to_string()))),
}; };
let update_result = sqlx::query( let update_result = sqlx::query(
@@ -150,16 +166,16 @@ pub(super) async fn handle_change_password(
.await; .await;
match update_result { match update_result {
Ok(_) => Json(ApiResponse::success( Ok(_) => (response_headers, Json(ApiResponse::success(
serde_json::json!({ "message": "Password updated successfully" }), serde_json::json!({ "message": "Password updated successfully" }),
)), ))),
Err(e) => Json(ApiResponse::error(format!("Failed to update database: {}", e))), Err(e) => (response_headers, Json(ApiResponse::error(format!("Failed to update database: {}", e)))),
} }
} else { } else {
Json(ApiResponse::error("Current password incorrect".to_string())) (response_headers, Json(ApiResponse::error("Current password incorrect".to_string())))
} }
} }
Err(e) => Json(ApiResponse::error(format!("User not found: {}", e))), Err(e) => (response_headers, Json(ApiResponse::error(format!("User not found: {}", e)))),
} }
} }
@@ -180,19 +196,19 @@ pub(super) async fn handle_logout(
} }
/// Helper: Extract and validate a session from the Authorization header. /// Helper: Extract and validate a session from the Authorization header.
/// Returns the Session if valid, or an error response. /// Returns the Session and optional new token if refreshed, or an error response.
pub(super) async fn extract_session( pub(super) async fn extract_session(
state: &DashboardState, state: &DashboardState,
headers: &axum::http::HeaderMap, headers: &axum::http::HeaderMap,
) -> Result<super::sessions::Session, Json<ApiResponse<serde_json::Value>>> { ) -> Result<(super::sessions::Session, Option<String>), Json<ApiResponse<serde_json::Value>>> {
let token = headers let token = headers
.get("Authorization") .get("Authorization")
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer ")); .and_then(|v| v.strip_prefix("Bearer "));
match token { match token {
Some(t) => match state.session_manager.validate_session(t).await { Some(t) => match state.session_manager.validate_session_with_refresh(t).await {
Some(session) => Ok(session), Some((session, new_token)) => Ok((session, new_token)),
None => Err(Json(ApiResponse::error("Session expired or invalid".to_string()))), None => Err(Json(ApiResponse::error("Session expired or invalid".to_string()))),
}, },
None => Err(Json(ApiResponse::error("Not authenticated".to_string()))), None => Err(Json(ApiResponse::error("Not authenticated".to_string()))),
@@ -200,13 +216,14 @@ pub(super) async fn extract_session(
} }
/// Helper: Extract session and require admin role. /// Helper: Extract session and require admin role.
/// Returns session and optional new token if refreshed.
pub(super) async fn require_admin( pub(super) async fn require_admin(
state: &DashboardState, state: &DashboardState,
headers: &axum::http::HeaderMap, headers: &axum::http::HeaderMap,
) -> Result<super::sessions::Session, Json<ApiResponse<serde_json::Value>>> { ) -> Result<(super::sessions::Session, Option<String>), Json<ApiResponse<serde_json::Value>>> {
let session = extract_session(state, headers).await?; let (session, new_token) = extract_session(state, headers).await?;
if session.role != "admin" { if session.role != "admin" {
return Err(Json(ApiResponse::error("Admin access required".to_string()))); return Err(Json(ApiResponse::error("Admin access required".to_string())));
} }
Ok(session) Ok((session, new_token))
} }

View File

@@ -88,9 +88,10 @@ pub(super) async fn handle_create_client(
headers: axum::http::HeaderMap, headers: axum::http::HeaderMap,
Json(payload): Json<CreateClientRequest>, Json(payload): Json<CreateClientRequest>,
) -> Json<ApiResponse<serde_json::Value>> { ) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await { let (session, _) = match super::auth::require_admin(&state, &headers).await {
return e; Ok((session, new_token)) => (session, new_token),
} Err(e) => return e,
};
let pool = &state.app_state.db_pool; let pool = &state.app_state.db_pool;
@@ -198,9 +199,10 @@ pub(super) async fn handle_update_client(
Path(id): Path<String>, Path(id): Path<String>,
Json(payload): Json<UpdateClientPayload>, Json(payload): Json<UpdateClientPayload>,
) -> Json<ApiResponse<serde_json::Value>> { ) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await { let (session, _) = match super::auth::require_admin(&state, &headers).await {
return e; Ok((session, new_token)) => (session, new_token),
} Err(e) => return e,
};
let pool = &state.app_state.db_pool; let pool = &state.app_state.db_pool;
@@ -294,9 +296,10 @@ pub(super) async fn handle_delete_client(
headers: axum::http::HeaderMap, headers: axum::http::HeaderMap,
Path(id): Path<String>, Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> { ) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await { let (session, _) = match super::auth::require_admin(&state, &headers).await {
return e; Ok((session, new_token)) => (session, new_token),
} Err(e) => return e,
};
let pool = &state.app_state.db_pool; let pool = &state.app_state.db_pool;
@@ -437,9 +440,10 @@ pub(super) async fn handle_create_client_token(
Path(id): Path<String>, Path(id): Path<String>,
Json(payload): Json<CreateTokenRequest>, Json(payload): Json<CreateTokenRequest>,
) -> Json<ApiResponse<serde_json::Value>> { ) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await { let (session, _) = match super::auth::require_admin(&state, &headers).await {
return e; Ok((session, new_token)) => (session, new_token),
} Err(e) => return e,
};
let pool = &state.app_state.db_pool; let pool = &state.app_state.db_pool;
@@ -485,9 +489,10 @@ pub(super) async fn handle_delete_client_token(
headers: axum::http::HeaderMap, headers: axum::http::HeaderMap,
Path((client_id, token_id)): Path<(String, i64)>, Path((client_id, token_id)): Path<(String, i64)>,
) -> Json<ApiResponse<serde_json::Value>> { ) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await { let (session, _) = match super::auth::require_admin(&state, &headers).await {
return e; Ok((session, new_token)) => (session, new_token),
} Err(e) => return e,
};
let pool = &state.app_state.db_pool; let pool = &state.app_state.db_pool;

View File

@@ -11,10 +11,18 @@ mod users;
mod websocket; mod websocket;
use axum::{ use axum::{
extract::{Request, State},
middleware::Next,
response::Response,
Router, Router,
routing::{delete, get, post, put}, routing::{delete, get, post, put},
}; };
use axum::http::{header, HeaderValue};
use serde::Serialize; use serde::Serialize;
use tower_http::{
limit::RequestBodyLimitLayer,
set_header::SetResponseHeaderLayer,
};
use crate::state::AppState; use crate::state::AppState;
use sessions::SessionManager; use sessions::SessionManager;
@@ -52,6 +60,21 @@ impl<T> ApiResponse<T> {
} }
} }
/// Rate limiting middleware for dashboard routes that extracts AppState from DashboardState.
///
/// Receives the dashboard router's `DashboardState`, takes its `app_state` field,
/// and forwards it together with the unmodified `request` and `next` handler to
/// `crate::rate_limiting::middleware::rate_limit_middleware`. The result of that
/// call (the downstream `Response` or an `AppError`) is returned unchanged.
async fn dashboard_rate_limit_middleware(
State(dashboard_state): State<DashboardState>,
request: Request,
next: Next,
) -> Result<Response, crate::errors::AppError> {
// Delegate to the existing rate limit middleware with AppState
crate::rate_limiting::middleware::rate_limit_middleware(
State(dashboard_state.app_state),
request,
next,
)
.await
}
// Dashboard routes // Dashboard routes
pub fn router(state: AppState) -> Router { pub fn router(state: AppState) -> Router {
let session_manager = SessionManager::new(24); // 24-hour session TTL let session_manager = SessionManager::new(24); // 24-hour session TTL
@@ -60,6 +83,26 @@ pub fn router(state: AppState) -> Router {
session_manager, session_manager,
}; };
// Security headers
let csp_header: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::CONTENT_SECURITY_POLICY,
"default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws:;"
.parse()
.unwrap(),
);
let x_frame_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_FRAME_OPTIONS,
"DENY".parse().unwrap(),
);
let x_content_type_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_CONTENT_TYPE_OPTIONS,
"nosniff".parse().unwrap(),
);
let strict_transport_security: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::STRICT_TRANSPORT_SECURITY,
"max-age=31536000; includeSubDomains".parse().unwrap(),
);
Router::new() Router::new()
// Static file serving // Static file serving
.fallback_service(tower_http::services::ServeDir::new("static")) .fallback_service(tower_http::services::ServeDir::new("static"))
@@ -119,5 +162,16 @@ pub fn router(state: AppState) -> Router {
"/api/system/settings", "/api/system/settings",
get(system::handle_get_settings).post(system::handle_update_settings), get(system::handle_get_settings).post(system::handle_update_settings),
) )
// Security layers
.layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit
.layer(csp_header)
.layer(x_frame_options)
.layer(x_content_type_options)
.layer(strict_transport_security)
// Rate limiting middleware
.layer(axum::middleware::from_fn_with_state(
dashboard_state.clone(),
dashboard_rate_limit_middleware,
))
.with_state(dashboard_state) .with_state(dashboard_state)
} }

View File

@@ -156,9 +156,10 @@ pub(super) async fn handle_update_model(
Path(id): Path<String>, Path(id): Path<String>,
Json(payload): Json<UpdateModelRequest>, Json(payload): Json<UpdateModelRequest>,
) -> Json<ApiResponse<serde_json::Value>> { ) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await { let (session, _) = match super::auth::require_admin(&state, &headers).await {
return e; Ok((session, new_token)) => (session, new_token),
} Err(e) => return e,
};
let pool = &state.app_state.db_pool; let pool = &state.app_state.db_pool;

View File

@@ -9,6 +9,7 @@ use std::collections::HashMap;
use tracing::warn; use tracing::warn;
use super::{ApiResponse, DashboardState}; use super::{ApiResponse, DashboardState};
use crate::utils::crypto;
#[derive(Deserialize)] #[derive(Deserialize)]
pub(super) struct UpdateProviderRequest { pub(super) struct UpdateProviderRequest {
@@ -265,21 +266,44 @@ pub(super) async fn handle_update_provider(
Path(name): Path<String>, Path(name): Path<String>,
Json(payload): Json<UpdateProviderRequest>, Json(payload): Json<UpdateProviderRequest>,
) -> Json<ApiResponse<serde_json::Value>> { ) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await { let (session, _) = match super::auth::require_admin(&state, &headers).await {
return e; Ok((session, new_token)) => (session, new_token),
} Err(e) => return e,
};
let pool = &state.app_state.db_pool; let pool = &state.app_state.db_pool;
// Update or insert into database (include billing_mode) // Prepare API key encryption if provided
let (api_key_to_store, api_key_encrypted_flag) = match &payload.api_key {
Some(key) if !key.is_empty() => {
match crypto::encrypt(key) {
Ok(encrypted) => (Some(encrypted), Some(true)),
Err(e) => {
warn!("Failed to encrypt API key for provider {}: {}", name, e);
return Json(ApiResponse::error(format!("Failed to encrypt API key: {}", e)));
}
}
}
Some(_) => {
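// Empty string is meant to clear the key.
// NOTE(review): binding NULL here does not clear it on update — the upsert uses
// COALESCE(excluded.api_key, provider_configs.api_key), so the stored key is kept
// while api_key_encrypted is overwritten with false. Confirm and fix the SQL.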
(None, Some(false))
}
None => {
// Keep existing key, we'll rely on COALESCE in SQL
(None, None)
}
};
// Update or insert into database (include billing_mode and api_key_encrypted)
let result = sqlx::query( let result = sqlx::query(
r#" r#"
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode) INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold, billing_mode)
VALUES (?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET ON CONFLICT(id) DO UPDATE SET
enabled = excluded.enabled, enabled = excluded.enabled,
base_url = excluded.base_url, base_url = excluded.base_url,
api_key = COALESCE(excluded.api_key, provider_configs.api_key), api_key = COALESCE(excluded.api_key, provider_configs.api_key),
api_key_encrypted = COALESCE(excluded.api_key_encrypted, provider_configs.api_key_encrypted),
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance), credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold), low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode), billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
@@ -290,7 +314,8 @@ pub(super) async fn handle_update_provider(
.bind(name.to_uppercase()) .bind(name.to_uppercase())
.bind(payload.enabled) .bind(payload.enabled)
.bind(&payload.base_url) .bind(&payload.base_url)
.bind(&payload.api_key) .bind(&api_key_to_store)
.bind(api_key_encrypted_flag)
.bind(payload.credit_balance) .bind(payload.credit_balance)
.bind(payload.low_credit_threshold) .bind(payload.low_credit_threshold)
.bind(payload.billing_mode) .bind(payload.billing_mode)

View File

@@ -1,7 +1,17 @@
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use hmac::{Hmac, Mac};
use serde::{Deserialize, Serialize};
use sha2::{Sha256, digest::generic_array::GenericArray};
use std::collections::HashMap; use std::collections::HashMap;
use std::env;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use uuid::Uuid;
use base64::{engine::general_purpose::URL_SAFE, Engine as _};
const TOKEN_VERSION: &str = "v2";
const REFRESH_WINDOW_MINUTES: i64 = 15; // refresh if token expires within 15 minutes
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Session { pub struct Session {
@@ -9,51 +19,136 @@ pub struct Session {
pub role: String, pub role: String,
pub created_at: DateTime<Utc>, pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>, pub expires_at: DateTime<Utc>,
pub session_id: String, // unique identifier for the session (UUID)
} }
#[derive(Clone)] #[derive(Clone)]
pub struct SessionManager { pub struct SessionManager {
sessions: Arc<RwLock<HashMap<String, Session>>>, sessions: Arc<RwLock<HashMap<String, Session>>>, // key = session_id
ttl_hours: i64, ttl_hours: i64,
secret: Vec<u8>,
}
#[derive(Debug, Serialize, Deserialize)]
struct SessionPayload {
session_id: String,
username: String,
role: String,
iat: i64, // issued at (Unix timestamp)
exp: i64, // expiry (Unix timestamp)
version: String,
} }
impl SessionManager { impl SessionManager {
pub fn new(ttl_hours: i64) -> Self { pub fn new(ttl_hours: i64) -> Self {
let secret = load_session_secret();
Self { Self {
sessions: Arc::new(RwLock::new(HashMap::new())), sessions: Arc::new(RwLock::new(HashMap::new())),
ttl_hours, ttl_hours,
secret,
} }
} }
/// Create a new session and return the session token. /// Create a new session and return a signed session token.
pub async fn create_session(&self, username: String, role: String) -> String { pub async fn create_session(&self, username: String, role: String) -> String {
let token = format!("session-{}", uuid::Uuid::new_v4()); let session_id = Uuid::new_v4().to_string();
let now = Utc::now(); let now = Utc::now();
let expires_at = now + Duration::hours(self.ttl_hours);
let session = Session { let session = Session {
username, username: username.clone(),
role, role: role.clone(),
created_at: now, created_at: now,
expires_at: now + Duration::hours(self.ttl_hours), expires_at,
session_id: session_id.clone(),
}; };
self.sessions.write().await.insert(token.clone(), session); // Store session by session_id
token self.sessions.write().await.insert(session_id.clone(), session);
// Create signed token
self.create_signed_token(&session_id, &username, &role, now.timestamp(), expires_at.timestamp())
} }
/// Validate a session token and return the session if valid and not expired. /// Validate a session token and return the session if valid and not expired.
/// If the token is within the refresh window, returns a new token as well.
pub async fn validate_session(&self, token: &str) -> Option<Session> { pub async fn validate_session(&self, token: &str) -> Option<Session> {
self.validate_session_with_refresh(token).await.map(|(session, _)| session)
}
/// Validate a session token and return (session, optional new token if refreshed).
pub async fn validate_session_with_refresh(&self, token: &str) -> Option<(Session, Option<String>)> {
// Legacy token format (UUID)
if token.starts_with("session-") {
let sessions = self.sessions.read().await; let sessions = self.sessions.read().await;
sessions.get(token).and_then(|s| { return sessions.get(token).and_then(|s| {
if s.expires_at > Utc::now() { if s.expires_at > Utc::now() {
Some(s.clone()) Some((s.clone(), None))
} else { } else {
None None
} }
}) });
}
// Signed token format
let payload = match verify_signed_token(token, &self.secret) {
Ok(p) => p,
Err(_) => return None,
};
// Check expiry
let now = Utc::now().timestamp();
if payload.exp <= now {
return None;
}
// Look up session by session_id
let sessions = self.sessions.read().await;
let session = match sessions.get(&payload.session_id) {
Some(s) => s.clone(),
None => return None, // session revoked or not found
};
// Ensure session username/role matches (should always match)
if session.username != payload.username || session.role != payload.role {
return None;
}
// Check if token is within refresh window (last REFRESH_WINDOW_MINUTES of validity)
let refresh_threshold = payload.exp - REFRESH_WINDOW_MINUTES * 60;
let new_token = if now >= refresh_threshold {
// Generate a new token with same session data but updated iat/exp?
// According to activity-based refresh, we should extend the session expiry.
// We'll extend from now by ttl_hours (or keep original expiry?).
// Let's extend from now by ttl_hours (sliding window).
let new_exp = Utc::now() + Duration::hours(self.ttl_hours);
// Update session expiry in store
drop(sessions); // release read lock before acquiring write lock
self.update_session_expiry(&payload.session_id, new_exp).await;
// Create new token with updated iat/exp
let new_token = self.create_signed_token(
&payload.session_id,
&payload.username,
&payload.role,
now,
new_exp.timestamp(),
);
Some(new_token)
} else {
None
};
Some((session, new_token))
} }
/// Remove the session identified by `token` from the store.
///
/// A legacy token (`session-…`) is used directly as the store key. For a
/// signed token the session id is taken from the verified payload. A token
/// that fails verification is ignored.
pub async fn revoke_session(&self, token: &str) {
    let key = if token.starts_with("session-") {
        token.to_string()
    } else {
        match verify_signed_token(token, &self.secret) {
            Ok(payload) => payload.session_id,
            // Unverifiable token: nothing to revoke.
            Err(_) => return,
        }
    };

    self.sessions.write().await.remove(&key);
}
/// Remove all expired sessions from the store. /// Remove all expired sessions from the store.
@@ -61,4 +156,156 @@ impl SessionManager {
let now = Utc::now(); let now = Utc::now();
self.sessions.write().await.retain(|_, s| s.expires_at > now); self.sessions.write().await.retain(|_, s| s.expires_at > now);
} }
// --- Private helpers ---
/// Assemble a `SessionPayload` from the given claims and sign it with this
/// manager's secret, returning the token string.
fn create_signed_token(&self, session_id: &str, username: &str, role: &str, iat: i64, exp: i64) -> String {
    sign_token(
        &SessionPayload {
            session_id: session_id.to_owned(),
            username: username.to_owned(),
            role: role.to_owned(),
            iat,
            exp,
            version: TOKEN_VERSION.to_owned(),
        },
        &self.secret,
    )
}
/// Set the expiry of the stored session `session_id` to `new_expires_at`.
/// Does nothing when no such session is in the store.
async fn update_session_expiry(&self, session_id: &str, new_expires_at: DateTime<Utc>) {
    let mut store = self.sessions.write().await;
    let Some(entry) = store.get_mut(session_id) else {
        return;
    };
    entry.expires_at = new_expires_at;
}
}
/// Load the HMAC signing secret for session tokens.
///
/// Lookup order is `SESSION_SECRET`, then `LLM_PROXY__SESSION_SECRET`. A
/// variable that is unset, or set to an empty / whitespace-only value, is
/// skipped. An empty value would otherwise decode to a zero-length key and
/// make every token trivially forgeable. When neither variable yields a value,
/// a random 32-byte secret is generated and a warning is logged.
///
/// A configured value is decoded as hex first, then URL-safe base64, then
/// standard base64. A warning is logged when the decoded secret is shorter
/// than 32 bytes.
///
/// # Panics
///
/// Panics if a configured value cannot be decoded with any of the three
/// encodings.
fn load_session_secret() -> Vec<u8> {
    const RECOMMENDED_SECRET_BYTES: usize = 32;

    // First non-empty candidate wins; surrounding whitespace is ignored.
    let configured = ["SESSION_SECRET", "LLM_PROXY__SESSION_SECRET"]
        .into_iter()
        .filter_map(|name| env::var(name).ok())
        .map(|value| value.trim().to_string())
        .find(|value| !value.is_empty());

    let Some(secret_str) = configured else {
        // Nothing configured: fall back to an ephemeral random secret.
        use rand::RngCore;
        let mut bytes = [0u8; 32];
        rand::rng().fill_bytes(&mut bytes);
        tracing::warn!(
            "SESSION_SECRET environment variable not set. Using a randomly generated secret. \
This will invalidate all sessions on restart. Set SESSION_SECRET to a fixed hex or base64 encoded 32-byte value."
        );
        return bytes.to_vec();
    };

    // Decode hex or base64
    let secret = hex::decode(&secret_str)
        .or_else(|_| URL_SAFE.decode(&secret_str))
        .or_else(|_| base64::engine::general_purpose::STANDARD.decode(&secret_str))
        .unwrap_or_else(|_| {
            panic!("SESSION_SECRET must be hex or base64 encoded (32 bytes)");
        });

    if secret.len() < RECOMMENDED_SECRET_BYTES {
        tracing::warn!(
            "SESSION_SECRET decodes to only {} bytes; at least {} bytes are recommended for HMAC-SHA256.",
            secret.len(),
            RECOMMENDED_SECRET_BYTES
        );
    }

    secret
}
/// Serialize `payload` to JSON, compute an HMAC-SHA256 tag over those JSON
/// bytes with `secret`, and return `base64_url(json).base64_url(tag)`.
fn sign_token(payload: &SessionPayload, secret: &[u8]) -> String {
    let body = serde_json::to_vec(payload).expect("Failed to serialize payload");

    // The tag covers the raw JSON bytes, not their base64 form.
    let mut signer = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
    signer.update(&body);
    let tag = signer.finalize().into_bytes();

    [URL_SAFE.encode(&body), URL_SAFE.encode(tag)].join(".")
}
/// Verify a signed token and return the decoded payload if valid.
///
/// The token must have the form `base64_url(payload).base64_url(signature)`.
///
/// # Errors
///
/// * `TokenError::InvalidFormat` – the token is not two `.`-separated parts,
///   or either part is not valid URL-safe base64.
/// * `TokenError::InvalidSignature` – the signature does not match the HMAC
///   of the payload bytes, including when it has the wrong length.
/// * `TokenError::InvalidPayload` – the signed bytes are not a valid
///   `SessionPayload` JSON document.
///
/// This function does not check `exp`; callers decide what expiry means.
fn verify_signed_token(token: &str, secret: &[u8]) -> Result<SessionPayload, TokenError> {
    let (payload_b64, signature_b64) = token.split_once('.').ok_or(TokenError::InvalidFormat)?;
    // Exactly two parts are allowed; a further '.' means a malformed token.
    if signature_b64.contains('.') {
        return Err(TokenError::InvalidFormat);
    }

    let json = URL_SAFE.decode(payload_b64).map_err(|_| TokenError::InvalidFormat)?;
    let signature = URL_SAFE.decode(signature_b64).map_err(|_| TokenError::InvalidFormat)?;

    // Verify HMAC
    let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
    mac.update(&json);
    // `verify_slice` compares in constant time and returns an error for a tag
    // of the wrong length. Converting the attacker-controlled bytes with
    // `GenericArray::from_slice` would panic on any length other than 32.
    mac.verify_slice(&signature).map_err(|_| TokenError::InvalidSignature)?;

    // Deserialize payload only after the signature has been accepted.
    serde_json::from_slice(&json).map_err(|_| TokenError::InvalidPayload)
}
/// Reasons `verify_signed_token` rejects a token.
#[derive(Debug)]
enum TokenError {
/// The token is not two `.`-separated parts, or a part is not valid URL-safe base64.
InvalidFormat,
/// The HMAC check over the payload bytes failed.
InvalidSignature,
/// The payload bytes could not be deserialized into a `SessionPayload`.
InvalidPayload,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Signing key shared by the token tests.
    const TEST_SECRET: &[u8] = b"test-secret-must-be-32-bytes-long!";

    /// Fixed payload used as the signing input in the token tests.
    fn sample_payload() -> SessionPayload {
        SessionPayload {
            session_id: "test-session".to_string(),
            username: "testuser".to_string(),
            role: "user".to_string(),
            iat: 1000,
            exp: 2000,
            version: TOKEN_VERSION.to_string(),
        }
    }

    /// A freshly signed token verifies and round-trips every field.
    #[test]
    fn test_sign_and_verify_token() {
        let original = sample_payload();
        let token = sign_token(&original, TEST_SECRET);

        let decoded = verify_signed_token(&token, TEST_SECRET).unwrap();

        assert_eq!(decoded.session_id, original.session_id);
        assert_eq!(decoded.username, original.username);
        assert_eq!(decoded.role, original.role);
        assert_eq!(decoded.iat, original.iat);
        assert_eq!(decoded.exp, original.exp);
        assert_eq!(decoded.version, original.version);
    }

    /// Changing the payload bytes while keeping the signature must be rejected.
    #[test]
    fn test_tampered_token() {
        let token = sign_token(&sample_payload(), TEST_SECRET);
        let (payload_part, signature_part) = token.split_once('.').unwrap();

        // Tamper with payload part
        let mut raw = URL_SAFE.decode(payload_part).unwrap();
        raw[0] ^= 0xFF; // flip some bits
        let forged = format!("{}.{}", URL_SAFE.encode(raw), signature_part);

        assert!(verify_signed_token(&forged, TEST_SECRET).is_err());
    }

    /// A hex-encoded SESSION_SECRET is decoded to the raw key bytes.
    #[test]
    fn test_load_session_secret_from_env() {
        unsafe {
            env::set_var("SESSION_SECRET", hex::encode([0xAA; 32]));
        }

        let loaded = load_session_secret();
        assert_eq!(loaded, vec![0xAA; 32]);

        unsafe {
            env::remove_var("SESSION_SECRET");
        }
    }
}

View File

@@ -279,9 +279,10 @@ pub(super) async fn handle_system_backup(
State(state): State<DashboardState>, State(state): State<DashboardState>,
headers: axum::http::HeaderMap, headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> { ) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await { let (session, _) = match super::auth::require_admin(&state, &headers).await {
return e; Ok((session, new_token)) => (session, new_token),
} Err(e) => return e,
};
let pool = &state.app_state.db_pool; let pool = &state.app_state.db_pool;
let backup_id = format!("backup-{}", chrono::Utc::now().timestamp()); let backup_id = format!("backup-{}", chrono::Utc::now().timestamp());
@@ -341,9 +342,10 @@ pub(super) async fn handle_update_settings(
State(state): State<DashboardState>, State(state): State<DashboardState>,
headers: axum::http::HeaderMap, headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> { ) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await { let (session, _) = match super::auth::require_admin(&state, &headers).await {
return e; Ok((session, new_token)) => (session, new_token),
} Err(e) => return e,
};
Json(ApiResponse::error( Json(ApiResponse::error(
"Changing settings at runtime is not yet supported. Please update your config file and restart the server." "Changing settings at runtime is not yet supported. Please update your config file and restart the server."

View File

@@ -14,9 +14,10 @@ pub(super) async fn handle_get_users(
State(state): State<DashboardState>, State(state): State<DashboardState>,
headers: axum::http::HeaderMap, headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> { ) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = auth::require_admin(&state, &headers).await { let (session, _) = match auth::require_admin(&state, &headers).await {
return e; Ok((session, new_token)) => (session, new_token),
} Err(e) => return e,
};
let pool = &state.app_state.db_pool; let pool = &state.app_state.db_pool;
@@ -66,9 +67,10 @@ pub(super) async fn handle_create_user(
headers: axum::http::HeaderMap, headers: axum::http::HeaderMap,
Json(payload): Json<CreateUserRequest>, Json(payload): Json<CreateUserRequest>,
) -> Json<ApiResponse<serde_json::Value>> { ) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = auth::require_admin(&state, &headers).await { let (session, _) = match auth::require_admin(&state, &headers).await {
return e; Ok((session, new_token)) => (session, new_token),
} Err(e) => return e,
};
let pool = &state.app_state.db_pool; let pool = &state.app_state.db_pool;
@@ -147,9 +149,10 @@ pub(super) async fn handle_update_user(
Path(id): Path<i64>, Path(id): Path<i64>,
Json(payload): Json<UpdateUserRequest>, Json(payload): Json<UpdateUserRequest>,
) -> Json<ApiResponse<serde_json::Value>> { ) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = auth::require_admin(&state, &headers).await { let (session, _) = match auth::require_admin(&state, &headers).await {
return e; Ok((session, new_token)) => (session, new_token),
} Err(e) => return e,
};
let pool = &state.app_state.db_pool; let pool = &state.app_state.db_pool;
@@ -249,8 +252,8 @@ pub(super) async fn handle_delete_user(
headers: axum::http::HeaderMap, headers: axum::http::HeaderMap,
Path(id): Path<i64>, Path(id): Path<i64>,
) -> Json<ApiResponse<serde_json::Value>> { ) -> Json<ApiResponse<serde_json::Value>> {
let session = match auth::require_admin(&state, &headers).await { let (session, _) = match auth::require_admin(&state, &headers).await {
Ok(s) => s, Ok((session, new_token)) => (session, new_token),
Err(e) => return e, Err(e) => return e,
}; };

View File

@@ -18,7 +18,9 @@ pub async fn init(config: &DatabaseConfig) -> Result<DbPool> {
let database_path = config.path.to_string_lossy().to_string(); let database_path = config.path.to_string_lossy().to_string();
info!("Connecting to database at {}", database_path); info!("Connecting to database at {}", database_path);
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?.create_if_missing(true); let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?
.create_if_missing(true)
.pragma("foreign_keys", "ON");
let pool = SqlitePool::connect_with(options).await?; let pool = SqlitePool::connect_with(options).await?;
@@ -29,7 +31,7 @@ pub async fn init(config: &DatabaseConfig) -> Result<DbPool> {
Ok(pool) Ok(pool)
} }
async fn run_migrations(pool: &DbPool) -> Result<()> { pub async fn run_migrations(pool: &DbPool) -> Result<()> {
// Create clients table if it doesn't exist // Create clients table if it doesn't exist
sqlx::query( sqlx::query(
r#" r#"
@@ -88,6 +90,8 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
api_key TEXT, api_key TEXT,
credit_balance REAL DEFAULT 0.0, credit_balance REAL DEFAULT 0.0,
low_credit_threshold REAL DEFAULT 5.0, low_credit_threshold REAL DEFAULT 5.0,
billing_mode TEXT,
api_key_encrypted BOOLEAN DEFAULT FALSE,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
) )
"#, "#,
@@ -167,6 +171,15 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
.execute(pool) .execute(pool)
.await; .await;
// Add billing_mode column if it doesn't exist (migration for existing DBs)
let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN billing_mode TEXT")
.execute(pool)
.await;
// Add api_key_encrypted column if it doesn't exist (migration for existing DBs)
let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN api_key_encrypted BOOLEAN DEFAULT FALSE")
.execute(pool)
.await;
// Insert default admin user if none exists (default password: admin) // Insert default admin user if none exists (default password: admin)
let user_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users").fetch_one(pool).await?; let user_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users").fetch_one(pool).await?;
@@ -216,6 +229,19 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
.execute(pool) .execute(pool)
.await?; .await?;
// Composite indexes for performance
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id)")
.execute(pool)
.await?;
// Insert default client if none exists // Insert default client if none exists
sqlx::query( sqlx::query(
r#" r#"

View File

@@ -41,23 +41,18 @@ pub use state::AppState;
pub mod test_utils { pub mod test_utils {
use std::sync::Arc; use std::sync::Arc;
use crate::{client::ClientManager, providers::ProviderManager, rate_limiting::RateLimitManager, state::AppState}; use crate::{client::ClientManager, providers::ProviderManager, rate_limiting::RateLimitManager, state::AppState, utils::crypto, database::run_migrations};
use sqlx::sqlite::SqlitePool; use sqlx::sqlite::SqlitePool;
/// Create a test application state /// Create a test application state
pub async fn create_test_state() -> Arc<AppState> { pub async fn create_test_state() -> AppState {
// Create in-memory database // Create in-memory database
let pool = SqlitePool::connect("sqlite::memory:") let pool = SqlitePool::connect("sqlite::memory:")
.await .await
.expect("Failed to create test database"); .expect("Failed to create test database");
// Run migrations // Run migrations on the pool
crate::database::init(&crate::config::DatabaseConfig { run_migrations(&pool).await.expect("Failed to run migrations");
path: std::path::PathBuf::from(":memory:"),
max_connections: 5,
})
.await
.expect("Failed to initialize test database");
let rate_limit_manager = RateLimitManager::new( let rate_limit_manager = RateLimitManager::new(
crate::rate_limiting::RateLimiterConfig::default(), crate::rate_limiting::RateLimiterConfig::default(),
@@ -73,7 +68,7 @@ pub mod test_utils {
providers: std::collections::HashMap::new(), providers: std::collections::HashMap::new(),
}; };
let (dashboard_tx, _) = tokio::sync::broadcast::channel(100); let (dashboard_tx, _) = tokio::sync::broadcast::channel::<serde_json::Value>(100);
let config = Arc::new(crate::config::AppConfig { let config = Arc::new(crate::config::AppConfig {
server: crate::config::ServerConfig { server: crate::config::ServerConfig {
@@ -125,20 +120,20 @@ pub mod test_utils {
ollama: vec![], ollama: vec![],
}, },
config_path: None, config_path: None,
encryption_key: "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f".to_string(),
}); });
Arc::new(AppState { // Initialize encryption with the test key
crypto::init_with_key(&config.encryption_key).expect("failed to initialize crypto");
AppState::new(
config, config,
provider_manager, provider_manager,
db_pool: pool.clone(), pool,
rate_limit_manager: Arc::new(rate_limit_manager), rate_limit_manager,
client_manager, model_registry,
request_logger: Arc::new(crate::logging::RequestLogger::new(pool.clone(), dashboard_tx.clone())), vec![], // auth_tokens
model_registry: Arc::new(model_registry), )
model_config_cache: crate::state::ModelConfigCache::new(pool.clone()),
dashboard_tx,
auth_tokens: vec![],
})
} }
/// Create a test HTTP client /// Create a test HTTP client
@@ -149,3 +144,185 @@ pub mod test_utils {
.expect("Failed to create test HTTP client") .expect("Failed to create test HTTP client")
} }
} }
#[cfg(test)]
mod integration_tests {
    use super::test_utils::*;
    use crate::{
        models::{ChatCompletionRequest, ChatMessage},
        server::router,
        utils::crypto,
    };
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use mockito::Server;
    use serde_json::json;
    use sqlx::Row;
    use tower::util::ServiceExt;

    /// End-to-end check of the encrypted key lifecycle: a provider key is
    /// stored encrypted, the provider is initialised from the database, and a
    /// proxied chat completion must reach the upstream with the decrypted key
    /// in its `authorization` header. Usage logging is verified afterwards.
    #[tokio::test]
    async fn test_encrypted_provider_key_integration() {
        // Step 1: Setup test database and state
        let state = create_test_state().await;
        let pool = state.db_pool.clone();

        // Step 2: Insert provider with encrypted API key
        let test_api_key = "test-openai-key-12345";
        let encrypted_key = crypto::encrypt(test_api_key).expect("Failed to encrypt test key");
        sqlx::query(
            r#"
            INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            "#,
        )
        .bind("openai")
        .bind("OpenAI")
        .bind(true)
        .bind("http://localhost:1234") // Placeholder URL, replaced once the mock server is up
        .bind(&encrypted_key)
        .bind(true) // api_key_encrypted flag
        .bind(100.0)
        .bind(5.0)
        .execute(&pool)
        .await
        .expect("Failed to insert provider config");

        // Step 3: Initialize provider from the stored (encrypted) configuration
        state
            .provider_manager
            .initialize_provider("openai", &state.config, &pool)
            .await
            .expect("Failed to initialize provider");

        // Step 4: Mock OpenAI API server
        let mut server = Server::new_async().await;
        let mock = server
            .mock("POST", "/chat/completions")
            // Only matches when the proxy sends the decrypted plaintext key.
            .match_header("authorization", format!("Bearer {}", test_api_key).as_str())
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "id": "chatcmpl-test",
                    "object": "chat.completion",
                    "created": 1234567890,
                    "model": "gpt-3.5-turbo",
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "Hello, world!"
                        },
                        "finish_reason": "stop"
                    }],
                    "usage": {
                        "prompt_tokens": 10,
                        "completion_tokens": 5,
                        "total_tokens": 15
                    }
                })
                .to_string(),
            )
            .create_async()
            .await;

        // Update provider base URL to use mock server
        sqlx::query("UPDATE provider_configs SET base_url = ? WHERE id = 'openai'")
            .bind(&server.url())
            .execute(&pool)
            .await
            .expect("Failed to update provider URL");

        // Re-initialize provider with new URL
        state
            .provider_manager
            .initialize_provider("openai", &state.config, &pool)
            .await
            .expect("Failed to re-initialize provider");

        // Step 5: Create test router and make request
        let app = router(state);
        let request_body = ChatCompletionRequest {
            model: "gpt-3.5-turbo".to_string(),
            messages: vec![ChatMessage {
                role: "user".to_string(),
                content: crate::models::MessageContent::Text {
                    content: "Hello".to_string(),
                },
                reasoning_content: None,
                tool_calls: None,
                name: None,
                tool_call_id: None,
            }],
            temperature: None,
            top_p: None,
            top_k: None,
            n: None,
            stop: None,
            max_tokens: Some(100),
            presence_penalty: None,
            frequency_penalty: None,
            stream: Some(false),
            tools: None,
            tool_choice: None,
        };
        let request = Request::builder()
            .method("POST")
            .uri("/v1/chat/completions")
            .header("content-type", "application/json")
            .header("authorization", "Bearer test-token")
            .body(Body::from(serde_json::to_string(&request_body).unwrap()))
            .unwrap();

        // Step 6: Execute request through proxy
        let response = app
            .oneshot(request)
            .await
            .expect("Failed to execute request");
        let status = response.status();
        println!("Response status: {}", status);
        if status != StatusCode::OK {
            // Print the body before failing so the cause shows in test output.
            let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
            let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
            println!("Response body: {}", body_str);
            panic!("Response status is not OK: {}", status);
        }
        assert_eq!(status, StatusCode::OK);

        // Verify the mock was called
        mock.assert_async().await;

        // Give the async logging task time to complete
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        // Step 7: Verify usage was logged in database
        let log_row = sqlx::query("SELECT * FROM llm_requests WHERE client_id = 'client_test-tok' ORDER BY id DESC LIMIT 1")
            .fetch_one(&pool)
            .await
            .expect("Request log not found");
        assert_eq!(log_row.get::<String, _>("provider"), "openai");
        assert_eq!(log_row.get::<String, _>("model"), "gpt-3.5-turbo");
        assert_eq!(log_row.get::<i64, _>("prompt_tokens"), 10);
        assert_eq!(log_row.get::<i64, _>("completion_tokens"), 5);
        assert_eq!(log_row.get::<i64, _>("total_tokens"), 15);
        assert_eq!(log_row.get::<String, _>("status"), "success");

        // Verify client usage was updated
        let client_row = sqlx::query("SELECT total_requests, total_tokens, total_cost FROM clients WHERE client_id = 'client_test-tok'")
            .fetch_one(&pool)
            .await
            .expect("Client not found");
        assert_eq!(client_row.get::<i64, _>("total_requests"), 1);
        assert_eq!(client_row.get::<i64, _>("total_tokens"), 15);
    }
}

View File

@@ -82,9 +82,9 @@ impl RequestLogger {
"#, "#,
) )
.bind(log.timestamp) .bind(log.timestamp)
.bind(log.client_id) .bind(&log.client_id)
.bind(&log.provider) .bind(&log.provider)
.bind(log.model) .bind(&log.model)
.bind(log.prompt_tokens as i64) .bind(log.prompt_tokens as i64)
.bind(log.completion_tokens as i64) .bind(log.completion_tokens as i64)
.bind(log.total_tokens as i64) .bind(log.total_tokens as i64)
@@ -92,7 +92,7 @@ impl RequestLogger {
.bind(log.cache_write_tokens as i64) .bind(log.cache_write_tokens as i64)
.bind(log.cost) .bind(log.cost)
.bind(log.has_images) .bind(log.has_images)
.bind(log.status) .bind(&log.status)
.bind(log.error_message) .bind(log.error_message)
.bind(log.duration_ms as i64) .bind(log.duration_ms as i64)
.bind(None::<String>) // request_body - optional, not stored to save disk space .bind(None::<String>) // request_body - optional, not stored to save disk space
@@ -100,6 +100,23 @@ impl RequestLogger {
.execute(&mut *tx) .execute(&mut *tx)
.await?; .await?;
// Update client usage statistics
sqlx::query(
r#"
UPDATE clients SET
total_requests = total_requests + 1,
total_tokens = total_tokens + ?,
total_cost = total_cost + ?,
updated_at = CURRENT_TIMESTAMP
WHERE client_id = ?
"#,
)
.bind(log.total_tokens as i64)
.bind(log.cost)
.bind(&log.client_id)
.execute(&mut *tx)
.await?;
// Deduct from provider balance if successful. // Deduct from provider balance if successful.
// Providers configured with billing_mode = 'postpaid' will not have their // Providers configured with billing_mode = 'postpaid' will not have their
// credit_balance decremented. Use a conditional UPDATE so we don't need // credit_balance decremented. Use a conditional UPDATE so we don't need

View File

@@ -10,6 +10,7 @@ use llm_proxy::{
rate_limiting::{CircuitBreakerConfig, RateLimitManager, RateLimiterConfig}, rate_limiting::{CircuitBreakerConfig, RateLimitManager, RateLimiterConfig},
server, server,
state::AppState, state::AppState,
utils::crypto,
}; };
#[tokio::main] #[tokio::main]
@@ -26,6 +27,10 @@ async fn main() -> Result<()> {
let config = AppConfig::load().await?; let config = AppConfig::load().await?;
info!("Configuration loaded from {:?}", config.config_path); info!("Configuration loaded from {:?}", config.config_path);
// Initialize encryption
crypto::init_with_key(&config.encryption_key)?;
info!("Encryption initialized");
// Initialize database connection pool // Initialize database connection pool
let db_pool = database::init(&config.database).await?; let db_pool = database::init(&config.database).await?;
info!("Database initialized at {:?}", config.database.path); info!("Database initialized at {:?}", config.database.path);

View File

@@ -7,6 +7,7 @@ use std::sync::Arc;
use crate::errors::AppError; use crate::errors::AppError;
use crate::models::UnifiedRequest; use crate::models::UnifiedRequest;
pub mod deepseek; pub mod deepseek;
pub mod gemini; pub mod gemini;
pub mod grok; pub mod grok;
@@ -125,17 +126,35 @@ impl ProviderManager {
db_pool: &crate::database::DbPool, db_pool: &crate::database::DbPool,
) -> Result<()> { ) -> Result<()> {
// Load override from database // Load override from database
let db_config = sqlx::query("SELECT enabled, base_url, api_key FROM provider_configs WHERE id = ?") let db_config = sqlx::query("SELECT enabled, base_url, api_key, api_key_encrypted FROM provider_configs WHERE id = ?")
.bind(name) .bind(name)
.fetch_optional(db_pool) .fetch_optional(db_pool)
.await?; .await?;
let (enabled, base_url, api_key) = if let Some(row) = db_config { let (enabled, base_url, api_key) = if let Some(row) = db_config {
( let enabled = row.get::<bool, _>("enabled");
row.get::<bool, _>("enabled"), let base_url = row.get::<Option<String>, _>("base_url");
row.get::<Option<String>, _>("base_url"), let api_key_encrypted = row.get::<bool, _>("api_key_encrypted");
row.get::<Option<String>, _>("api_key"), let api_key = row.get::<Option<String>, _>("api_key");
) // Decrypt API key if encrypted
let api_key = match (api_key, api_key_encrypted) {
(Some(key), true) => {
match crate::utils::crypto::decrypt(&key) {
Ok(decrypted) => Some(decrypted),
Err(e) => {
tracing::error!("Failed to decrypt API key for provider {}: {}", name, e);
None
}
}
}
(Some(key), false) => {
// Plaintext key - optionally encrypt and update database (lazy migration)
// For now, just use plaintext
Some(key)
}
(None, _) => None,
};
(enabled, base_url, api_key)
} else { } else {
// No database override, use defaults from AppConfig // No database override, use defaults from AppConfig
match name { match name {

View File

@@ -6,12 +6,15 @@
//! 3. Global rate limiting for overall system protection //! 3. Global rate limiting for overall system protection
use anyhow::Result; use anyhow::Result;
use governor::{Quota, RateLimiter, DefaultDirectRateLimiter};
use std::collections::HashMap; use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tracing::{info, warn}; use tracing::{info, warn};
type GovRateLimiter = DefaultDirectRateLimiter;
/// Rate limiter configuration /// Rate limiter configuration
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct RateLimiterConfig { pub struct RateLimiterConfig {
@@ -65,45 +68,7 @@ impl Default for CircuitBreakerConfig {
} }
} }
/// Token bucket used to rate-limit a single client.
///
/// The bucket starts full, gains `refill_rate` tokens per elapsed second up
/// to `capacity`, and a request succeeds only when enough tokens are present.
#[derive(Debug)]
struct TokenBucket {
    tokens: f64,
    capacity: f64,
    refill_rate: f64, // tokens per second
    last_refill: Instant,
}

impl TokenBucket {
    /// Create a bucket that starts full.
    fn new(capacity: f64, refill_rate: f64) -> Self {
        let last_refill = Instant::now();
        Self {
            tokens: capacity,
            capacity,
            refill_rate,
            last_refill,
        }
    }

    /// Credit the tokens earned since the last refill, capped at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let earned = now.duration_since(self.last_refill).as_secs_f64() * self.refill_rate;
        self.tokens = f64::min(self.tokens + earned, self.capacity);
        self.last_refill = now;
    }

    /// Refill, then take `tokens` from the bucket if that many are available.
    /// Returns whether the tokens were taken.
    fn try_acquire(&mut self, tokens: f64) -> bool {
        self.refill();
        let granted = self.tokens >= tokens;
        if granted {
            self.tokens -= tokens;
        }
        granted
    }
}
/// Circuit breaker for a provider /// Circuit breaker for a provider
#[derive(Debug)] #[derive(Debug)]
@@ -209,8 +174,8 @@ impl ProviderCircuitBreaker {
/// Rate limiting and circuit breaking manager /// Rate limiting and circuit breaking manager
#[derive(Debug)] #[derive(Debug)]
pub struct RateLimitManager { pub struct RateLimitManager {
client_buckets: Arc<RwLock<HashMap<String, TokenBucket>>>, client_buckets: Arc<RwLock<HashMap<String, GovRateLimiter>>>,
global_bucket: Arc<RwLock<TokenBucket>>, global_bucket: Arc<GovRateLimiter>,
circuit_breakers: Arc<RwLock<HashMap<String, ProviderCircuitBreaker>>>, circuit_breakers: Arc<RwLock<HashMap<String, ProviderCircuitBreaker>>>,
config: RateLimiterConfig, config: RateLimiterConfig,
circuit_config: CircuitBreakerConfig, circuit_config: CircuitBreakerConfig,
@@ -218,15 +183,16 @@ pub struct RateLimitManager {
impl RateLimitManager { impl RateLimitManager {
pub fn new(config: RateLimiterConfig, circuit_config: CircuitBreakerConfig) -> Self { pub fn new(config: RateLimiterConfig, circuit_config: CircuitBreakerConfig) -> Self {
// Convert requests per minute to tokens per second // Create global rate limiter quota
let global_refill_rate = config.global_requests_per_minute as f64 / 60.0; let global_quota = Quota::per_minute(
NonZeroU32::new(config.global_requests_per_minute).expect("global_requests_per_minute must be positive")
)
.allow_burst(NonZeroU32::new(config.burst_size).expect("burst_size must be positive"));
let global_bucket = RateLimiter::direct(global_quota);
Self { Self {
client_buckets: Arc::new(RwLock::new(HashMap::new())), client_buckets: Arc::new(RwLock::new(HashMap::new())),
global_bucket: Arc::new(RwLock::new(TokenBucket::new( global_bucket: Arc::new(global_bucket),
config.burst_size as f64,
global_refill_rate,
))),
circuit_breakers: Arc::new(RwLock::new(HashMap::new())), circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
config, config,
circuit_config, circuit_config,
@@ -236,24 +202,22 @@ impl RateLimitManager {
/// Check if a client request is allowed /// Check if a client request is allowed
pub async fn check_client_request(&self, client_id: &str) -> Result<bool> { pub async fn check_client_request(&self, client_id: &str) -> Result<bool> {
// Check global rate limit first (1 token per request) // Check global rate limit first (1 token per request)
{ if self.global_bucket.check().is_err() {
let mut global_bucket = self.global_bucket.write().await;
if !global_bucket.try_acquire(1.0) {
warn!("Global rate limit exceeded"); warn!("Global rate limit exceeded");
return Ok(false); return Ok(false);
} }
}
// Check client-specific rate limit // Check client-specific rate limit
let mut buckets = self.client_buckets.write().await; let mut buckets = self.client_buckets.write().await;
let bucket = buckets.entry(client_id.to_string()).or_insert_with(|| { let bucket = buckets.entry(client_id.to_string()).or_insert_with(|| {
TokenBucket::new( let quota = Quota::per_minute(
self.config.burst_size as f64, NonZeroU32::new(self.config.requests_per_minute).expect("requests_per_minute must be positive")
self.config.requests_per_minute as f64 / 60.0,
) )
.allow_burst(NonZeroU32::new(self.config.burst_size).expect("burst_size must be positive"));
RateLimiter::direct(quota)
}); });
Ok(bucket.try_acquire(1.0)) Ok(bucket.check().is_ok())
} }
/// Check if provider requests are allowed (circuit breaker) /// Check if provider requests are allowed (circuit breaker)

View File

@@ -5,6 +5,11 @@ use axum::{
response::sse::{Event, Sse}, response::sse::{Event, Sse},
routing::{get, post}, routing::{get, post},
}; };
use axum::http::{header, HeaderValue};
use tower_http::{
limit::RequestBodyLimitLayer,
set_header::SetResponseHeaderLayer,
};
use futures::StreamExt; use futures::StreamExt;
use std::sync::Arc; use std::sync::Arc;
@@ -23,9 +28,34 @@ use crate::{
}; };
pub fn router(state: AppState) -> Router { pub fn router(state: AppState) -> Router {
// Security headers
let csp_header: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::CONTENT_SECURITY_POLICY,
"default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws:;"
.parse()
.unwrap(),
);
let x_frame_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_FRAME_OPTIONS,
"DENY".parse().unwrap(),
);
let x_content_type_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_CONTENT_TYPE_OPTIONS,
"nosniff".parse().unwrap(),
);
let strict_transport_security: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::STRICT_TRANSPORT_SECURITY,
"max-age=31536000; includeSubDomains".parse().unwrap(),
);
Router::new() Router::new()
.route("/v1/chat/completions", post(chat_completions)) .route("/v1/chat/completions", post(chat_completions))
.route("/v1/models", get(list_models)) .route("/v1/models", get(list_models))
.layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit
.layer(csp_header)
.layer(x_frame_options)
.layer(x_content_type_options)
.layer(strict_transport_security)
.layer(axum::middleware::from_fn_with_state( .layer(axum::middleware::from_fn_with_state(
state.clone(), state.clone(),
rate_limiting::middleware::rate_limit_middleware, rate_limiting::middleware::rate_limit_middleware,
@@ -219,7 +249,6 @@ async fn chat_completions(
prompt_tokens, prompt_tokens,
has_images, has_images,
logger: state.request_logger.clone(), logger: state.request_logger.clone(),
client_manager: state.client_manager.clone(),
model_registry: state.model_registry.clone(), model_registry: state.model_registry.clone(),
model_config_cache: state.model_config_cache.clone(), model_config_cache: state.model_config_cache.clone(),
}, },
@@ -341,15 +370,6 @@ async fn chat_completions(
duration_ms: duration.as_millis() as u64, duration_ms: duration.as_millis() as u64,
}); });
// Update client usage (fire-and-forget, don't block response)
{
let cm = state.client_manager.clone();
let cid = client_id.clone();
tokio::spawn(async move {
let _ = cm.update_client_usage(&cid, response.total_tokens as i64, cost).await;
});
}
// Convert ProviderResponse to ChatCompletionResponse // Convert ProviderResponse to ChatCompletionResponse
let finish_reason = if response.tool_calls.is_some() { let finish_reason = if response.tool_calls.is_some() {
"tool_calls".to_string() "tool_calls".to_string()

171
src/utils/crypto.rs Normal file
View File

@@ -0,0 +1,171 @@
use aes_gcm::{
aead::{Aead, AeadCore, KeyInit, OsRng},
Aes256Gcm, Key, Nonce,
};
use anyhow::{anyhow, Context, Result};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use std::env;
use std::sync::OnceLock;
static RAW_KEY: OnceLock<[u8; 32]> = OnceLock::new();
/// Initialize the process-wide encryption key from a hex- or base64-encoded string.
///
/// Surrounding whitespace is ignored. Hex decoding is attempted first; if that
/// fails, standard (padded) base64 is tried. The decoded key must be exactly
/// 32 bytes (256 bits).
///
/// Must be called before any encryption/decryption operations. Calling it again
/// with the same key bytes (in either encoding) is a no-op that returns `Ok(())`.
///
/// # Errors
///
/// Returns an error if the string is neither valid hex nor valid base64, if the
/// decoded key is not 32 bytes long, or if a *different* key has already been
/// installed.
pub fn init_with_key(key_str: &str) -> Result<()> {
    // Keys read from env files or secret stores often carry a trailing newline.
    let key_str = key_str.trim();

    let key_bytes = hex::decode(key_str)
        .or_else(|_| BASE64.decode(key_str))
        .context("Encryption key must be hex or base64 encoded")?;

    // The conversion only succeeds for exactly 32 bytes; on failure the
    // original vector is handed back so its length can be reported.
    let key_array = match <[u8; 32]>::try_from(key_bytes) {
        Ok(array) => array,
        Err(rejected) => anyhow::bail!(
            "Encryption key must be 32 bytes (256 bits), got {} bytes",
            rejected.len()
        ),
    };

    // `get_or_init` installs the key atomically. A separate `get()` followed by
    // `set()` would let two threads initializing with the *same* key race, with
    // the loser reporting a spurious "already initialized" failure.
    let installed = RAW_KEY.get_or_init(|| key_array);
    if installed != &key_array {
        anyhow::bail!("Encryption key already initialized with a different key");
    }

    Ok(())
}
/// Initialize the encryption key from the `LLM_PROXY__ENCRYPTION_KEY` environment variable.
///
/// Must be called before any encryption/decryption operations.
///
/// # Errors
///
/// Returns an error (it does not panic) when the variable cannot be read, or
/// when `init_with_key` rejects its value.
pub fn init_from_env() -> Result<()> {
    env::var("LLM_PROXY__ENCRYPTION_KEY")
        .context("LLM_PROXY__ENCRYPTION_KEY environment variable not set")
        .and_then(|encoded_key| init_with_key(&encoded_key))
}
/// Borrow the installed 32-byte encryption key.
///
/// # Panics
///
/// Panics if no key has been installed in `RAW_KEY` yet.
fn get_key() -> &'static [u8; 32] {
    match RAW_KEY.get() {
        Some(key) => key,
        None => panic!("Encryption key not initialized. Call crypto::init_with_key() or crypto::init_from_env() first."),
    }
}
/// Encrypt `plaintext` with AES-256-GCM under the installed key.
///
/// A fresh random nonce is generated for every call, so encrypting the same
/// input twice yields different outputs. The returned string is the base64
/// encoding of `nonce || ciphertext || tag`.
///
/// # Errors
///
/// Returns an error if the cipher reports an encryption failure.
///
/// # Panics
///
/// Panics (via `get_key`) if the encryption key has not been initialized.
pub fn encrypt(plaintext: &str) -> Result<String> {
    let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(get_key()));
    let nonce = Aes256Gcm::generate_nonce(&mut OsRng); // 12 bytes

    let sealed = cipher
        .encrypt(&nonce, plaintext.as_bytes())
        .map_err(|e| anyhow!("Encryption failed: {}", e))?;

    // Prefix the nonce so `decrypt` can recover it; `sealed` already ends with the tag.
    let mut payload = nonce.to_vec();
    payload.extend(sealed);

    Ok(BASE64.encode(payload))
}
/// Decrypt a base64 string laid out as `nonce || ciphertext || tag` back into plaintext.
///
/// # Errors
///
/// Returns an error if the input is not valid base64, is shorter than the
/// 12-byte nonce, fails AES-256-GCM decryption/authentication, or decrypts to
/// bytes that are not valid UTF-8.
///
/// # Panics
///
/// Panics (via `get_key`) if the encryption key has not been initialized.
pub fn decrypt(ciphertext_b64: &str) -> Result<String> {
    // Length of the nonce prefix at the front of the decoded payload.
    const NONCE_LEN: usize = 12;

    let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(get_key()));

    let payload = BASE64.decode(ciphertext_b64).context("Invalid base64 ciphertext")?;
    if payload.len() < NONCE_LEN {
        anyhow::bail!("Ciphertext too short");
    }

    let (nonce_bytes, sealed) = payload.split_at(NONCE_LEN);
    let opened = cipher
        .decrypt(Nonce::from_slice(nonce_bytes), sealed)
        .map_err(|e| anyhow!("Decryption failed (invalid key or corrupted ciphertext): {}", e))?;

    String::from_utf8(opened).context("Decrypted bytes are not valid UTF-8")
}
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_KEY_HEX: &str = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f";

    /// Serializes the tests that mutate `LLM_PROXY__ENCRYPTION_KEY`.
    ///
    /// The test harness runs tests on parallel threads by default; without this
    /// lock one test can remove the variable between another test's `set_var`
    /// and its subsequent read, making both tests flaky.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquire `ENV_LOCK`, ignoring poisoning so that one failed env test does
    /// not cascade into failures of the others.
    fn lock_env() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_encrypt_decrypt() {
        init_with_key(TEST_KEY_HEX).unwrap();
        let plaintext = "super secret api key";
        let ciphertext = encrypt(plaintext).unwrap();
        assert_ne!(ciphertext, plaintext);
        let decrypted = decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_different_inputs_produce_different_ciphertexts() {
        init_with_key(TEST_KEY_HEX).unwrap();
        let plaintext = "same";
        let cipher1 = encrypt(plaintext).unwrap();
        let cipher2 = encrypt(plaintext).unwrap();
        assert_ne!(cipher1, cipher2, "Nonce should make ciphertexts differ");
        assert_eq!(decrypt(&cipher1).unwrap(), plaintext);
        assert_eq!(decrypt(&cipher2).unwrap(), plaintext);
    }

    #[test]
    fn test_invalid_key_length() {
        let result = init_with_key("tooshort");
        assert!(result.is_err());
    }

    #[test]
    fn test_init_from_env() {
        // Hold the lock for the whole test so the variable cannot be removed
        // by `test_missing_env_key` before `init_from_env` reads it.
        let _env_guard = lock_env();
        unsafe { std::env::set_var("LLM_PROXY__ENCRYPTION_KEY", TEST_KEY_HEX) };
        let result = init_from_env();
        assert!(result.is_ok());
        // Ensure encryption works
        let ciphertext = encrypt("test").unwrap();
        let decrypted = decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, "test");
    }

    #[test]
    fn test_missing_env_key() {
        // Hold the lock so `test_init_from_env` cannot re-set the variable
        // between the removal and the read below.
        let _env_guard = lock_env();
        unsafe { std::env::remove_var("LLM_PROXY__ENCRYPTION_KEY") };
        let result = init_from_env();
        assert!(result.is_err());
    }

    #[test]
    fn test_key_hex_and_base64() {
        // Hex key works
        init_with_key(TEST_KEY_HEX).unwrap();
        // Base64 key (same bytes encoded as base64)
        let base64_key = BASE64.encode(hex::decode(TEST_KEY_HEX).unwrap());
        // Re-initialization with same key (different encoding) is allowed
        let result = init_with_key(&base64_key);
        assert!(result.is_ok());
        // Encryption should still work
        let ciphertext = encrypt("test").unwrap();
        let decrypted = decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, "test");
    }

    #[test]
    #[ignore] // conflicts with global state from other tests
    fn test_already_initialized() {
        init_with_key(TEST_KEY_HEX).unwrap();
        let result = init_with_key(TEST_KEY_HEX);
        assert!(result.is_ok()); // same key allowed
    }

    #[test]
    #[ignore] // conflicts with global state from other tests
    fn test_already_initialized_different_key() {
        init_with_key(TEST_KEY_HEX).unwrap();
        let different_key = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e20";
        let result = init_with_key(different_key);
        assert!(result.is_err());
    }
}

View File

@@ -1,3 +1,4 @@
pub mod crypto;
pub mod registry; pub mod registry;
pub mod streaming; pub mod streaming;
pub mod tokens; pub mod tokens;

View File

@@ -1,4 +1,4 @@
use crate::client::ClientManager;
use crate::errors::AppError; use crate::errors::AppError;
use crate::logging::{RequestLog, RequestLogger}; use crate::logging::{RequestLog, RequestLogger};
use crate::models::ToolCall; use crate::models::ToolCall;
@@ -18,7 +18,6 @@ pub struct StreamConfig {
pub prompt_tokens: u32, pub prompt_tokens: u32,
pub has_images: bool, pub has_images: bool,
pub logger: Arc<RequestLogger>, pub logger: Arc<RequestLogger>,
pub client_manager: Arc<ClientManager>,
pub model_registry: Arc<crate::models::registry::ModelRegistry>, pub model_registry: Arc<crate::models::registry::ModelRegistry>,
pub model_config_cache: ModelConfigCache, pub model_config_cache: ModelConfigCache,
} }
@@ -36,7 +35,6 @@ pub struct AggregatingStream<S> {
/// Real usage data from the provider's final stream chunk (when available). /// Real usage data from the provider's final stream chunk (when available).
real_usage: Option<StreamUsage>, real_usage: Option<StreamUsage>,
logger: Arc<RequestLogger>, logger: Arc<RequestLogger>,
client_manager: Arc<ClientManager>,
model_registry: Arc<crate::models::registry::ModelRegistry>, model_registry: Arc<crate::models::registry::ModelRegistry>,
model_config_cache: ModelConfigCache, model_config_cache: ModelConfigCache,
start_time: std::time::Instant, start_time: std::time::Instant,
@@ -60,7 +58,6 @@ where
accumulated_tool_calls: Vec::new(), accumulated_tool_calls: Vec::new(),
real_usage: None, real_usage: None,
logger: config.logger, logger: config.logger,
client_manager: config.client_manager,
model_registry: config.model_registry, model_registry: config.model_registry,
model_config_cache: config.model_config_cache, model_config_cache: config.model_config_cache,
start_time: std::time::Instant::now(), start_time: std::time::Instant::now(),
@@ -79,7 +76,6 @@ where
let provider_name = self.provider.name().to_string(); let provider_name = self.provider.name().to_string();
let model = self.model.clone(); let model = self.model.clone();
let logger = self.logger.clone(); let logger = self.logger.clone();
let client_manager = self.client_manager.clone();
let provider = self.provider.clone(); let provider = self.provider.clone();
let estimated_prompt_tokens = self.prompt_tokens; let estimated_prompt_tokens = self.prompt_tokens;
let has_images = self.has_images; let has_images = self.has_images;
@@ -162,11 +158,6 @@ where
error_message: None, error_message: None,
duration_ms: duration.as_millis() as u64, duration_ms: duration.as_millis() as u64,
}); });
// Update client usage
let _ = client_manager
.update_client_usage(&client_id, total_tokens as i64, cost)
.await;
}); });
} }
} }
@@ -304,7 +295,6 @@ mod tests {
let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
let (dashboard_tx, _) = tokio::sync::broadcast::channel(16); let (dashboard_tx, _) = tokio::sync::broadcast::channel(16);
let logger = Arc::new(RequestLogger::new(pool.clone(), dashboard_tx)); let logger = Arc::new(RequestLogger::new(pool.clone(), dashboard_tx));
let client_manager = Arc::new(ClientManager::new(pool.clone()));
let registry = Arc::new(crate::models::registry::ModelRegistry { let registry = Arc::new(crate::models::registry::ModelRegistry {
providers: std::collections::HashMap::new(), providers: std::collections::HashMap::new(),
}); });
@@ -318,7 +308,6 @@ mod tests {
prompt_tokens: 10, prompt_tokens: 10,
has_images: false, has_images: false,
logger, logger,
client_manager,
model_registry: registry, model_registry: registry,
model_config_cache: ModelConfigCache::new(pool.clone()), model_config_cache: ModelConfigCache::new(pool.clone()),
}, },

View File

@@ -35,6 +35,14 @@ class ApiClient {
throw new Error(result.error || `HTTP error! status: ${response.status}`); throw new Error(result.error || `HTTP error! status: ${response.status}`);
} }
// Handling X-Refreshed-Token header
if (response.headers.get('X-Refreshed-Token') && window.authManager) {
window.authManager.token = response.headers.get('X-Refreshed-Token');
if (window.authManager.setToken) {
window.authManager.setToken(window.authManager.token);
}
}
return result.data; return result.data;
} }
@@ -87,6 +95,17 @@ class ApiClient {
const date = luxon.DateTime.fromISO(dateStr); const date = luxon.DateTime.fromISO(dateStr);
return date.toRelative(); return date.toRelative();
} }
// Helper for escaping HTML
escapeHtml(unsafe) {
if (unsafe === undefined || unsafe === null) return '';
return unsafe.toString()
.replace(/&/g, "&amp;")
.replace(/</g, "&lt;")
.replace(/>/g, "&gt;")
.replace(/"/g, "&quot;")
.replace(/'/g, "&#039;");
}
} }
window.api = new ApiClient(); window.api = new ApiClient();

View File

@@ -50,6 +50,12 @@ class AuthManager {
}); });
} }
setToken(newToken) {
if (!newToken) return;
this.token = newToken;
localStorage.setItem('dashboard_token', this.token);
}
async login(username, password) { async login(username, password) {
const errorElement = document.getElementById('login-error'); const errorElement = document.getElementById('login-error');
const loginBtn = document.querySelector('.login-btn'); const loginBtn = document.querySelector('.login-btn');

View File

@@ -42,12 +42,15 @@ class ClientsPage {
const statusIcon = client.status === 'active' ? 'check-circle' : 'clock'; const statusIcon = client.status === 'active' ? 'check-circle' : 'clock';
const created = luxon.DateTime.fromISO(client.created_at).toFormat('MMM dd, yyyy'); const created = luxon.DateTime.fromISO(client.created_at).toFormat('MMM dd, yyyy');
const escapedId = window.api.escapeHtml(client.id);
const escapedName = window.api.escapeHtml(client.name);
return ` return `
<tr> <tr>
<td><span class="badge-client">${client.id}</span></td> <td><span class="badge-client">${escapedId}</span></td>
<td><strong>${client.name}</strong></td> <td><strong>${escapedName}</strong></td>
<td> <td>
<code class="token-display">sk-••••${client.id.substring(client.id.length - 4)}</code> <code class="token-display">sk-••••${escapedId.substring(escapedId.length - 4)}</code>
</td> </td>
<td>${created}</td> <td>${created}</td>
<td>${client.last_used ? window.api.formatTimeAgo(client.last_used) : 'Never'}</td> <td>${client.last_used ? window.api.formatTimeAgo(client.last_used) : 'Never'}</td>
@@ -55,16 +58,16 @@ class ClientsPage {
<td> <td>
<span class="status-badge ${statusClass}"> <span class="status-badge ${statusClass}">
<i class="fas fa-${statusIcon}"></i> <i class="fas fa-${statusIcon}"></i>
${client.status} ${window.api.escapeHtml(client.status)}
</span> </span>
</td> </td>
<td> <td>
${window._userRole === 'admin' ? ` ${window._userRole === 'admin' ? `
<div class="action-buttons"> <div class="action-buttons">
<button class="btn-action" title="Edit" onclick="window.clientsPage.editClient('${client.id}')"> <button class="btn-action" title="Edit" onclick="window.clientsPage.editClient('${escapedId}')">
<i class="fas fa-edit"></i> <i class="fas fa-edit"></i>
</button> </button>
<button class="btn-action danger" title="Delete" onclick="window.clientsPage.deleteClient('${client.id}')"> <button class="btn-action danger" title="Delete" onclick="window.clientsPage.deleteClient('${escapedId}')">
<i class="fas fa-trash"></i> <i class="fas fa-trash"></i>
</button> </button>
</div> </div>
@@ -188,10 +191,13 @@ class ClientsPage {
showTokenRevealModal(clientName, token) { showTokenRevealModal(clientName, token) {
const modal = document.createElement('div'); const modal = document.createElement('div');
modal.className = 'modal active'; modal.className = 'modal active';
const escapedName = window.api.escapeHtml(clientName);
const escapedToken = window.api.escapeHtml(token);
modal.innerHTML = ` modal.innerHTML = `
<div class="modal-content"> <div class="modal-content">
<div class="modal-header"> <div class="modal-header">
<h3 class="modal-title">Client Created: ${clientName}</h3> <h3 class="modal-title">Client Created: ${escapedName}</h3>
</div> </div>
<div class="modal-body"> <div class="modal-body">
<p style="margin-bottom: 0.75rem; color: var(--yellow);"> <p style="margin-bottom: 0.75rem; color: var(--yellow);">
@@ -201,7 +207,7 @@ class ClientsPage {
<div class="form-control"> <div class="form-control">
<label>API Token</label> <label>API Token</label>
<div style="display: flex; gap: 0.5rem;"> <div style="display: flex; gap: 0.5rem;">
<input type="text" id="revealed-token" value="${token}" readonly <input type="text" id="revealed-token" value="${escapedToken}" readonly
style="font-family: monospace; font-size: 0.85rem;"> style="font-family: monospace; font-size: 0.85rem;">
<button class="btn btn-secondary" id="copy-token-btn" title="Copy"> <button class="btn btn-secondary" id="copy-token-btn" title="Copy">
<i class="fas fa-copy"></i> <i class="fas fa-copy"></i>
@@ -248,10 +254,16 @@ class ClientsPage {
showEditClientModal(client) { showEditClientModal(client) {
const modal = document.createElement('div'); const modal = document.createElement('div');
modal.className = 'modal active'; modal.className = 'modal active';
const escapedId = window.api.escapeHtml(client.id);
const escapedName = window.api.escapeHtml(client.name);
const escapedDescription = window.api.escapeHtml(client.description);
const escapedRateLimit = window.api.escapeHtml(client.rate_limit_per_minute);
modal.innerHTML = ` modal.innerHTML = `
<div class="modal-content"> <div class="modal-content">
<div class="modal-header"> <div class="modal-header">
<h3 class="modal-title">Edit Client: ${client.id}</h3> <h3 class="modal-title">Edit Client: ${escapedId}</h3>
<button class="modal-close" onclick="this.closest('.modal').remove()"> <button class="modal-close" onclick="this.closest('.modal').remove()">
<i class="fas fa-times"></i> <i class="fas fa-times"></i>
</button> </button>
@@ -259,15 +271,15 @@ class ClientsPage {
<div class="modal-body"> <div class="modal-body">
<div class="form-control"> <div class="form-control">
<label for="edit-client-name">Display Name</label> <label for="edit-client-name">Display Name</label>
<input type="text" id="edit-client-name" value="${client.name || ''}" placeholder="e.g. My Coding Assistant"> <input type="text" id="edit-client-name" value="${escapedName}" placeholder="e.g. My Coding Assistant">
</div> </div>
<div class="form-control"> <div class="form-control">
<label for="edit-client-description">Description</label> <label for="edit-client-description">Description</label>
<textarea id="edit-client-description" rows="3" placeholder="Optional description">${client.description || ''}</textarea> <textarea id="edit-client-description" rows="3" placeholder="Optional description">${escapedDescription}</textarea>
</div> </div>
<div class="form-control"> <div class="form-control">
<label for="edit-client-rate-limit">Rate Limit (requests/minute)</label> <label for="edit-client-rate-limit">Rate Limit (requests/minute)</label>
<input type="number" id="edit-client-rate-limit" min="0" value="${client.rate_limit_per_minute || ''}" placeholder="Leave empty for unlimited"> <input type="number" id="edit-client-rate-limit" min="0" value="${escapedRateLimit}" placeholder="Leave empty for unlimited">
</div> </div>
<div class="form-control"> <div class="form-control">
<label class="toggle-label"> <label class="toggle-label">
@@ -357,12 +369,16 @@ class ClientsPage {
const lastUsed = t.last_used_at const lastUsed = t.last_used_at
? luxon.DateTime.fromISO(t.last_used_at).toRelative() ? luxon.DateTime.fromISO(t.last_used_at).toRelative()
: 'Never'; : 'Never';
const escapedMaskedToken = window.api.escapeHtml(t.token_masked);
const escapedClientId = window.api.escapeHtml(clientId);
const tokenId = parseInt(t.id); // Assuming ID is numeric
return ` return `
<div style="display: flex; align-items: center; gap: 0.5rem; padding: 0.4rem 0; border-bottom: 1px solid var(--bg3);"> <div style="display: flex; align-items: center; gap: 0.5rem; padding: 0.4rem 0; border-bottom: 1px solid var(--bg3);">
<code style="flex: 1; font-size: 0.8rem; color: var(--fg2);">${t.token_masked}</code> <code style="flex: 1; font-size: 0.8rem; color: var(--fg2);">${escapedMaskedToken}</code>
<span style="font-size: 0.75rem; color: var(--fg4);" title="Last used">${lastUsed}</span> <span style="font-size: 0.75rem; color: var(--fg4);" title="Last used">${lastUsed}</span>
<button class="btn-action danger" title="Revoke" style="padding: 0.2rem 0.4rem;" <button class="btn-action danger" title="Revoke" style="padding: 0.2rem 0.4rem;"
onclick="window.clientsPage.revokeToken('${clientId}', ${t.id}, this)"> onclick="window.clientsPage.revokeToken('${escapedClientId}', ${tokenId}, this)">
<i class="fas fa-trash" style="font-size: 0.75rem;"></i> <i class="fas fa-trash" style="font-size: 0.75rem;"></i>
</button> </button>
</div> </div>

View File

@@ -47,16 +47,21 @@ class ProvidersPage {
const isLowBalance = provider.credit_balance <= provider.low_credit_threshold && provider.id !== 'ollama'; const isLowBalance = provider.credit_balance <= provider.low_credit_threshold && provider.id !== 'ollama';
const balanceColor = isLowBalance ? 'var(--red-light)' : 'var(--green-light)'; const balanceColor = isLowBalance ? 'var(--red-light)' : 'var(--green-light)';
const escapedId = window.api.escapeHtml(provider.id);
const escapedName = window.api.escapeHtml(provider.name);
const escapedStatus = window.api.escapeHtml(provider.status);
const billingMode = provider.billing_mode ? provider.billing_mode.toUpperCase() : 'PREPAID';
return ` return `
<div class="provider-card ${provider.status}"> <div class="provider-card ${escapedStatus}">
<div class="provider-card-header"> <div class="provider-card-header">
<div class="provider-info"> <div class="provider-info">
<h4 class="provider-name">${provider.name}</h4> <h4 class="provider-name">${escapedName}</h4>
<span class="provider-id">${provider.id}</span> <span class="provider-id">${escapedId}</span>
</div> </div>
<span class="status-badge ${statusClass}"> <span class="status-badge ${statusClass}">
<i class="fas fa-circle"></i> <i class="fas fa-circle"></i>
${provider.status} ${escapedStatus}
</span> </span>
</div> </div>
<div class="provider-card-body"> <div class="provider-card-body">
@@ -67,12 +72,12 @@ class ProvidersPage {
</div> </div>
<div class="meta-item" style="color: ${balanceColor}; font-weight: 700;"> <div class="meta-item" style="color: ${balanceColor}; font-weight: 700;">
<i class="fas fa-wallet"></i> <i class="fas fa-wallet"></i>
<span>Balance: ${provider.id === 'ollama' ? 'FREE' : window.api.formatCurrency(provider.credit_balance)}</span> <span>Balance: ${escapedId === 'ollama' ? 'FREE' : window.api.formatCurrency(provider.credit_balance)}</span>
${isLowBalance ? '<i class="fas fa-exclamation-triangle" title="Low Balance"></i>' : ''} ${isLowBalance ? '<i class="fas fa-exclamation-triangle" title="Low Balance"></i>' : ''}
</div> </div>
<div class="meta-item"> <div class="meta-item">
<i class="fas fa-exchange-alt"></i> <i class="fas fa-exchange-alt"></i>
<span>Billing: ${provider.billing_mode ? provider.billing_mode.toUpperCase() : 'PREPAID'}</span> <span>Billing: ${window.api.escapeHtml(billingMode)}</span>
</div> </div>
<div class="meta-item"> <div class="meta-item">
<i class="fas fa-clock"></i> <i class="fas fa-clock"></i>
@@ -80,16 +85,16 @@ class ProvidersPage {
</div> </div>
</div> </div>
<div class="model-tags"> <div class="model-tags">
${(provider.models || []).slice(0, 5).map(m => `<span class="model-tag">${m}</span>`).join('')} ${(provider.models || []).slice(0, 5).map(m => `<span class="model-tag">${window.api.escapeHtml(m)}</span>`).join('')}
${modelCount > 5 ? `<span class="model-tag more">+${modelCount - 5} more</span>` : ''} ${modelCount > 5 ? `<span class="model-tag more">+${modelCount - 5} more</span>` : ''}
</div> </div>
</div> </div>
<div class="provider-card-footer"> <div class="provider-card-footer">
<button class="btn btn-secondary btn-sm" onclick="window.providersPage.testProvider('${provider.id}')"> <button class="btn btn-secondary btn-sm" onclick="window.providersPage.testProvider('${escapedId}')">
<i class="fas fa-vial"></i> Test <i class="fas fa-vial"></i> Test
</button> </button>
${window._userRole === 'admin' ? ` ${window._userRole === 'admin' ? `
<button class="btn btn-primary btn-sm" onclick="window.providersPage.configureProvider('${provider.id}')"> <button class="btn btn-primary btn-sm" onclick="window.providersPage.configureProvider('${escapedId}')">
<i class="fas fa-cog"></i> Config <i class="fas fa-cog"></i> Config
</button> </button>
` : ''} ` : ''}
@@ -144,10 +149,17 @@ class ProvidersPage {
const modal = document.createElement('div'); const modal = document.createElement('div');
modal.className = 'modal active'; modal.className = 'modal active';
const escapedId = window.api.escapeHtml(provider.id);
const escapedName = window.api.escapeHtml(provider.name);
const escapedBaseUrl = window.api.escapeHtml(provider.base_url);
const escapedBalance = window.api.escapeHtml(provider.credit_balance);
const escapedThreshold = window.api.escapeHtml(provider.low_credit_threshold);
modal.innerHTML = ` modal.innerHTML = `
<div class="modal-content"> <div class="modal-content">
<div class="modal-header"> <div class="modal-header">
<h3 class="modal-title">Configure ${provider.name}</h3> <h3 class="modal-title">Configure ${escapedName}</h3>
<button class="modal-close" onclick="this.closest('.modal').remove()"> <button class="modal-close" onclick="this.closest('.modal').remove()">
<i class="fas fa-times"></i> <i class="fas fa-times"></i>
</button> </button>
@@ -161,7 +173,7 @@ class ProvidersPage {
</div> </div>
<div class="form-control"> <div class="form-control">
<label for="provider-base-url">Base URL</label> <label for="provider-base-url">Base URL</label>
<input type="text" id="provider-base-url" value="${provider.base_url || ''}" placeholder="Default API URL"> <input type="text" id="provider-base-url" value="${escapedBaseUrl}" placeholder="Default API URL">
</div> </div>
<div class="form-control"> <div class="form-control">
<label for="provider-api-key">API Key (Optional / Overwrite)</label> <label for="provider-api-key">API Key (Optional / Overwrite)</label>
@@ -170,11 +182,11 @@ class ProvidersPage {
<div class="grid-2"> <div class="grid-2">
<div class="form-control"> <div class="form-control">
<label for="provider-balance">Current Credit Balance ($)</label> <label for="provider-balance">Current Credit Balance ($)</label>
<input type="number" id="provider-balance" value="${provider.credit_balance}" step="0.01"> <input type="number" id="provider-balance" value="${escapedBalance}" step="0.01">
</div> </div>
<div class="form-control"> <div class="form-control">
<label for="provider-threshold">Low Balance Alert ($)</label> <label for="provider-threshold">Low Balance Alert ($)</label>
<input type="number" id="provider-threshold" value="${provider.low_credit_threshold}" step="0.50"> <input type="number" id="provider-threshold" value="${escapedThreshold}" step="0.50">
</div> </div>
</div> </div>
<div class="form-control"> <div class="form-control">

View File

@@ -280,8 +280,6 @@
// ── Helpers ──────────────────────────────────────────────────── // ── Helpers ────────────────────────────────────────────────────
function escapeHtml(str) { function escapeHtml(str) {
const div = document.createElement('div'); return window.api.escapeHtml(str);
div.textContent = str;
return div.innerHTML;
} }
})(); })();