feat(security): implement AES-256-GCM encryption for API keys and HMAC-signed session tokens

This commit introduces:
- AES-256-GCM encryption for LLM provider API keys in the database.
- HMAC-SHA256 signed session tokens with activity-based refresh logic.
- Standardized frontend XSS protection using a global escapeHtml utility.
- Hardened security headers and request body size limits.
- Improved database integrity with foreign key enforcement and atomic transactions.
- Integration tests for the full encrypted key storage and proxy usage lifecycle.
This commit is contained in:
2026-03-06 14:17:56 -05:00
parent 149a7c3a29
commit 9b8483e797
28 changed files with 1260 additions and 227 deletions

View File

@@ -26,3 +26,6 @@ LLM_PROXY__SERVER__PORT=8080
# Database path (optional)
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
# Session secret for HMAC-signed tokens: 32 random bytes, hex- or base64-encoded (e.g. generate with `openssl rand -hex 32`)
SESSION_SECRET=your_session_secret_here_32_bytes

201
Cargo.lock generated
View File

@@ -8,6 +8,41 @@ version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa"
[[package]]
name = "aead"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0"
dependencies = [
"crypto-common",
"generic-array",
]
[[package]]
name = "aes"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0"
dependencies = [
"cfg-if",
"cipher",
"cpufeatures",
]
[[package]]
name = "aes-gcm"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1"
dependencies = [
"aead",
"aes",
"cipher",
"ctr",
"ghash",
"subtle",
]
[[package]]
name = "ahash"
version = "0.7.8"
@@ -541,9 +576,33 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a"
dependencies = [
"generic-array",
"rand_core 0.6.4",
"typenum",
]
[[package]]
name = "ctr"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835"
dependencies = [
"cipher",
]
[[package]]
name = "dashmap"
version = "6.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf"
dependencies = [
"cfg-if",
"crossbeam-utils",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]]
name = "data-encoding"
version = "2.10.0"
@@ -895,6 +954,37 @@ dependencies = [
"wasip3",
]
[[package]]
name = "ghash"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1"
dependencies = [
"opaque-debug",
"polyval",
]
[[package]]
name = "governor"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0746aa765db78b521451ef74221663b57ba595bf83f75d0ce23cc09447c8139f"
dependencies = [
"cfg-if",
"dashmap",
"futures-sink",
"futures-timer",
"futures-util",
"no-std-compat",
"nonzero_ext",
"parking_lot",
"portable-atomic",
"quanta",
"rand 0.8.5",
"smallvec",
"spinning_top",
]
[[package]]
name = "h2"
version = "0.4.13"
@@ -923,6 +1013,12 @@ dependencies = [
"ahash",
]
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]]
name = "hashbrown"
version = "0.15.5"
@@ -1431,6 +1527,7 @@ checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77"
name = "llm-proxy"
version = "0.1.0"
dependencies = [
"aes-gcm",
"anyhow",
"assert_cmd",
"async-stream",
@@ -1443,8 +1540,10 @@ dependencies = [
"config",
"dotenvy",
"futures",
"governor",
"headers",
"hex",
"hmac",
"image",
"insta",
"mime",
@@ -1454,6 +1553,7 @@ dependencies = [
"reqwest-eventsource",
"serde",
"serde_json",
"sha2",
"sqlx",
"tempfile",
"thiserror 1.0.69",
@@ -1598,6 +1698,12 @@ dependencies = [
"pxfm",
]
[[package]]
name = "no-std-compat"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
[[package]]
name = "nom"
version = "7.1.3"
@@ -1608,6 +1714,12 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "nonzero_ext"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
[[package]]
name = "nu-ansi-term"
version = "0.50.3"
@@ -1669,6 +1781,12 @@ version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "opaque-debug"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
[[package]]
name = "ordered-multimap"
version = "0.4.3"
@@ -1824,6 +1942,24 @@ dependencies = [
"miniz_oxide",
]
[[package]]
name = "polyval"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25"
dependencies = [
"cfg-if",
"cpufeatures",
"opaque-debug",
"universal-hash",
]
[[package]]
name = "portable-atomic"
version = "1.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
[[package]]
name = "potential_utf"
version = "0.1.4"
@@ -1897,6 +2033,21 @@ dependencies = [
"num-traits",
]
[[package]]
name = "quanta"
version = "0.12.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7"
dependencies = [
"crossbeam-utils",
"libc",
"once_cell",
"raw-cpuid",
"wasi",
"web-sys",
"winapi",
]
[[package]]
name = "quick-error"
version = "2.0.1"
@@ -2032,6 +2183,15 @@ dependencies = [
"getrandom 0.3.4",
]
[[package]]
name = "raw-cpuid"
version = "11.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186"
dependencies = [
"bitflags 2.11.0",
]
[[package]]
name = "redox_syscall"
version = "0.5.18"
@@ -2453,6 +2613,15 @@ dependencies = [
"lock_api",
]
[[package]]
name = "spinning_top"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
dependencies = [
"lock_api",
]
[[package]]
name = "spki"
version = "0.7.3"
@@ -3158,6 +3327,16 @@ version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "universal-hash"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea"
dependencies = [
"crypto-common",
"subtle",
]
[[package]]
name = "untrusted"
version = "0.9.0"
@@ -3411,6 +3590,28 @@ dependencies = [
"wasite",
]
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-core"
version = "0.62.2"

View File

@@ -13,7 +13,8 @@ repository = ""
axum = { version = "0.8", features = ["macros", "ws"] }
tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "net", "time", "signal", "fs"] }
tower = "0.5"
tower-http = { version = "0.6", features = ["trace", "cors", "compression-gzip", "fs"] }
tower-http = { version = "0.6", features = ["trace", "cors", "compression-gzip", "fs", "set-header", "limit"] }
governor = "0.7"
# ========== HTTP Clients ==========
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
@@ -46,6 +47,9 @@ mime = "0.3"
anyhow = "1.0"
thiserror = "1.0"
bcrypt = "0.15"
aes-gcm = "0.10"
hmac = "0.12"
sha2 = "0.10"
chrono = { version = "0.4", features = ["serde"] }
uuid = { version = "1.0", features = ["v4", "serde"] }
futures = "0.3"

View File

@@ -0,0 +1,13 @@
-- Migration: add composite indexes for query performance
-- Adds three indexes (the first two are composite, the third is single-column):
-- 1. idx_llm_requests_client_timestamp on llm_requests(client_id, timestamp)
-- 2. idx_llm_requests_provider_timestamp on llm_requests(provider, timestamp)
-- 3. idx_model_configs_provider_id on model_configs(provider_id)
-- Every statement uses IF NOT EXISTS, so re-running this script does not fail on
-- indexes that are already present.
-- NOTE(review): the explicit BEGIN/COMMIT assumes the migration runner does not
-- already wrap this file in its own transaction (SQLite rejects a nested BEGIN) — confirm.
BEGIN TRANSACTION;
CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp);
CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp);
CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id);
COMMIT;

View File

@@ -1,5 +1,7 @@
use anyhow::Result;
use base64::{Engine as _};
use config::{Config, File, FileFormat};
use hex;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
@@ -96,6 +98,7 @@ pub struct AppConfig {
pub model_mapping: ModelMappingConfig,
pub pricing: PricingConfig,
pub config_path: Option<PathBuf>,
pub encryption_key: String,
}
impl AppConfig {
@@ -136,7 +139,8 @@ impl AppConfig {
.set_default("providers.grok.enabled", true)?
.set_default("providers.ollama.base_url", "http://localhost:11434/v1")?
.set_default("providers.ollama.enabled", false)?
.set_default("providers.ollama.models", Vec::<String>::new())?;
.set_default("providers.ollama.models", Vec::<String>::new())?
.set_default("encryption_key", "")?;
// Load from config file if exists
// Priority: explicit path arg > LLM_PROXY__CONFIG_PATH env var > ./config.toml
@@ -167,6 +171,19 @@ impl AppConfig {
let server: ServerConfig = config.get("server")?;
let database: DatabaseConfig = config.get("database")?;
let providers: ProviderConfig = config.get("providers")?;
let encryption_key: String = config.get("encryption_key")?;
// Validate encryption key length (must be 32 bytes after hex or base64 decoding)
if encryption_key.is_empty() {
anyhow::bail!("Encryption key is required (LLM_PROXY__ENCRYPTION_KEY environment variable)");
}
// Try hex decode first, then base64
let key_bytes = hex::decode(&encryption_key)
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(&encryption_key))
.map_err(|e| anyhow::anyhow!("Encryption key must be hex or base64 encoded: {}", e))?;
if key_bytes.len() != 32 {
anyhow::bail!("Encryption key must be 32 bytes (256 bits), got {} bytes", key_bytes.len());
}
// For now, use empty model mapping and pricing (will be populated later)
let model_mapping = ModelMappingConfig { patterns: vec![] };
@@ -185,6 +202,7 @@ impl AppConfig {
model_mapping,
pricing,
config_path: Some(config_path),
encryption_key,
}))
}

View File

@@ -1,4 +1,4 @@
use axum::{extract::State, response::Json};
use axum::{extract::State, http::{HeaderMap, HeaderValue}, response::{Json, IntoResponse}};
use bcrypt;
use serde::Deserialize;
use sqlx::Row;
@@ -64,14 +64,14 @@ pub(super) async fn handle_login(
pub(super) async fn handle_auth_status(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
) -> impl IntoResponse {
let token = headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
if let Some(token) = token
&& let Some(session) = state.session_manager.validate_session(token).await
&& let Some((session, new_token)) = state.session_manager.validate_session_with_refresh(token).await
{
// Look up display_name from DB
let display_name = sqlx::query_scalar::<_, Option<String>>(
@@ -85,17 +85,23 @@ pub(super) async fn handle_auth_status(
.flatten()
.unwrap_or_else(|| session.username.clone());
return Json(ApiResponse::success(serde_json::json!({
let mut headers = HeaderMap::new();
if let Some(refreshed_token) = new_token {
if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) {
headers.insert("X-Refreshed-Token", header_value);
}
}
return (headers, Json(ApiResponse::success(serde_json::json!({
"authenticated": true,
"user": {
"username": session.username,
"name": display_name,
"role": session.role
}
})));
}))));
}
Json(ApiResponse::error("Not authenticated".to_string()))
(HeaderMap::new(), Json(ApiResponse::error("Not authenticated".to_string())))
}
#[derive(Deserialize)]
@@ -108,7 +114,7 @@ pub(super) async fn handle_change_password(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Json(payload): Json<ChangePasswordRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
) -> impl IntoResponse {
let pool = &state.app_state.db_pool;
// Extract the authenticated user from the session token
@@ -117,14 +123,24 @@ pub(super) async fn handle_change_password(
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
let session = match token {
Some(t) => state.session_manager.validate_session(t).await,
None => None,
let (session, new_token) = match token {
Some(t) => match state.session_manager.validate_session_with_refresh(t).await {
Some((session, new_token)) => (Some(session), new_token),
None => (None, None),
},
None => (None, None),
};
let mut response_headers = HeaderMap::new();
if let Some(refreshed_token) = new_token {
if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) {
response_headers.insert("X-Refreshed-Token", header_value);
}
}
let username = match session {
Some(s) => s.username,
None => return Json(ApiResponse::error("Not authenticated".to_string())),
None => return (response_headers, Json(ApiResponse::error("Not authenticated".to_string()))),
};
let user_result = sqlx::query("SELECT password_hash FROM users WHERE username = ?")
@@ -138,7 +154,7 @@ pub(super) async fn handle_change_password(
if bcrypt::verify(&payload.current_password, &hash).unwrap_or(false) {
let new_hash = match bcrypt::hash(&payload.new_password, 12) {
Ok(h) => h,
Err(_) => return Json(ApiResponse::error("Failed to hash new password".to_string())),
Err(_) => return (response_headers, Json(ApiResponse::error("Failed to hash new password".to_string()))),
};
let update_result = sqlx::query(
@@ -150,16 +166,16 @@ pub(super) async fn handle_change_password(
.await;
match update_result {
Ok(_) => Json(ApiResponse::success(
Ok(_) => (response_headers, Json(ApiResponse::success(
serde_json::json!({ "message": "Password updated successfully" }),
)),
Err(e) => Json(ApiResponse::error(format!("Failed to update database: {}", e))),
))),
Err(e) => (response_headers, Json(ApiResponse::error(format!("Failed to update database: {}", e)))),
}
} else {
Json(ApiResponse::error("Current password incorrect".to_string()))
(response_headers, Json(ApiResponse::error("Current password incorrect".to_string())))
}
}
Err(e) => Json(ApiResponse::error(format!("User not found: {}", e))),
Err(e) => (response_headers, Json(ApiResponse::error(format!("User not found: {}", e)))),
}
}
@@ -180,19 +196,19 @@ pub(super) async fn handle_logout(
}
/// Helper: Extract and validate a session from the Authorization header.
/// Returns the Session if valid, or an error response.
/// Returns the Session and optional new token if refreshed, or an error response.
pub(super) async fn extract_session(
state: &DashboardState,
headers: &axum::http::HeaderMap,
) -> Result<super::sessions::Session, Json<ApiResponse<serde_json::Value>>> {
) -> Result<(super::sessions::Session, Option<String>), Json<ApiResponse<serde_json::Value>>> {
let token = headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
match token {
Some(t) => match state.session_manager.validate_session(t).await {
Some(session) => Ok(session),
Some(t) => match state.session_manager.validate_session_with_refresh(t).await {
Some((session, new_token)) => Ok((session, new_token)),
None => Err(Json(ApiResponse::error("Session expired or invalid".to_string()))),
},
None => Err(Json(ApiResponse::error("Not authenticated".to_string()))),
@@ -200,13 +216,14 @@ pub(super) async fn extract_session(
}
/// Helper: Extract session and require admin role.
/// Returns session and optional new token if refreshed.
pub(super) async fn require_admin(
state: &DashboardState,
headers: &axum::http::HeaderMap,
) -> Result<super::sessions::Session, Json<ApiResponse<serde_json::Value>>> {
let session = extract_session(state, headers).await?;
) -> Result<(super::sessions::Session, Option<String>), Json<ApiResponse<serde_json::Value>>> {
let (session, new_token) = extract_session(state, headers).await?;
if session.role != "admin" {
return Err(Json(ApiResponse::error("Admin access required".to_string())));
}
Ok(session)
Ok((session, new_token))
}

View File

@@ -88,9 +88,10 @@ pub(super) async fn handle_create_client(
headers: axum::http::HeaderMap,
Json(payload): Json<CreateClientRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
@@ -198,9 +199,10 @@ pub(super) async fn handle_update_client(
Path(id): Path<String>,
Json(payload): Json<UpdateClientPayload>,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
@@ -294,9 +296,10 @@ pub(super) async fn handle_delete_client(
headers: axum::http::HeaderMap,
Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
@@ -437,9 +440,10 @@ pub(super) async fn handle_create_client_token(
Path(id): Path<String>,
Json(payload): Json<CreateTokenRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
@@ -485,9 +489,10 @@ pub(super) async fn handle_delete_client_token(
headers: axum::http::HeaderMap,
Path((client_id, token_id)): Path<(String, i64)>,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;

View File

@@ -11,10 +11,18 @@ mod users;
mod websocket;
use axum::{
extract::{Request, State},
middleware::Next,
response::Response,
Router,
routing::{delete, get, post, put},
};
use axum::http::{header, HeaderValue};
use serde::Serialize;
use tower_http::{
limit::RequestBodyLimitLayer,
set_header::SetResponseHeaderLayer,
};
use crate::state::AppState;
use sessions::SessionManager;
@@ -52,6 +60,21 @@ impl<T> ApiResponse<T> {
}
}
/// Dashboard-side adapter for the shared rate limiting middleware.
///
/// Dashboard routes are parameterised over `DashboardState`, whereas the shared
/// middleware is called with the inner `AppState`. This function unwraps the
/// inner state and forwards the request and `Next` handle untouched, returning
/// whatever the shared middleware returns.
async fn dashboard_rate_limit_middleware(
    State(dashboard_state): State<DashboardState>,
    request: Request,
    next: Next,
) -> Result<Response, crate::errors::AppError> {
    let app_state = dashboard_state.app_state;
    let limited = crate::rate_limiting::middleware::rate_limit_middleware(State(app_state), request, next);
    limited.await
}
// Dashboard routes
pub fn router(state: AppState) -> Router {
let session_manager = SessionManager::new(24); // 24-hour session TTL
@@ -60,6 +83,26 @@ pub fn router(state: AppState) -> Router {
session_manager,
};
// Security headers
let csp_header: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::CONTENT_SECURITY_POLICY,
"default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws:;"
.parse()
.unwrap(),
);
let x_frame_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_FRAME_OPTIONS,
"DENY".parse().unwrap(),
);
let x_content_type_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_CONTENT_TYPE_OPTIONS,
"nosniff".parse().unwrap(),
);
let strict_transport_security: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::STRICT_TRANSPORT_SECURITY,
"max-age=31536000; includeSubDomains".parse().unwrap(),
);
Router::new()
// Static file serving
.fallback_service(tower_http::services::ServeDir::new("static"))
@@ -119,5 +162,16 @@ pub fn router(state: AppState) -> Router {
"/api/system/settings",
get(system::handle_get_settings).post(system::handle_update_settings),
)
// Security layers
.layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit
.layer(csp_header)
.layer(x_frame_options)
.layer(x_content_type_options)
.layer(strict_transport_security)
// Rate limiting middleware
.layer(axum::middleware::from_fn_with_state(
dashboard_state.clone(),
dashboard_rate_limit_middleware,
))
.with_state(dashboard_state)
}

View File

@@ -156,9 +156,10 @@ pub(super) async fn handle_update_model(
Path(id): Path<String>,
Json(payload): Json<UpdateModelRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;

View File

@@ -9,6 +9,7 @@ use std::collections::HashMap;
use tracing::warn;
use super::{ApiResponse, DashboardState};
use crate::utils::crypto;
#[derive(Deserialize)]
pub(super) struct UpdateProviderRequest {
@@ -265,21 +266,44 @@ pub(super) async fn handle_update_provider(
Path(name): Path<String>,
Json(payload): Json<UpdateProviderRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Update or insert into database (include billing_mode)
// Prepare API key encryption if provided
let (api_key_to_store, api_key_encrypted_flag) = match &payload.api_key {
Some(key) if !key.is_empty() => {
match crypto::encrypt(key) {
Ok(encrypted) => (Some(encrypted), Some(true)),
Err(e) => {
warn!("Failed to encrypt API key for provider {}: {}", name, e);
return Json(ApiResponse::error(format!("Failed to encrypt API key: {}", e)));
}
}
}
Some(_) => {
// Empty string means clear the key
(None, Some(false))
}
None => {
// Keep existing key, we'll rely on COALESCE in SQL
(None, None)
}
};
// Update or insert into database (include billing_mode and api_key_encrypted)
let result = sqlx::query(
r#"
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold, billing_mode)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
enabled = excluded.enabled,
base_url = excluded.base_url,
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
api_key_encrypted = COALESCE(excluded.api_key_encrypted, provider_configs.api_key_encrypted),
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
@@ -290,7 +314,8 @@ pub(super) async fn handle_update_provider(
.bind(name.to_uppercase())
.bind(payload.enabled)
.bind(&payload.base_url)
.bind(&payload.api_key)
.bind(&api_key_to_store)
.bind(api_key_encrypted_flag)
.bind(payload.credit_balance)
.bind(payload.low_credit_threshold)
.bind(payload.billing_mode)

View File

@@ -1,7 +1,17 @@
use chrono::{DateTime, Duration, Utc};
use hmac::{Hmac, Mac};
use serde::{Deserialize, Serialize};
use sha2::{Sha256, digest::generic_array::GenericArray};
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
use base64::{engine::general_purpose::URL_SAFE, Engine as _};
const TOKEN_VERSION: &str = "v2";
const REFRESH_WINDOW_MINUTES: i64 = 15; // refresh if token expires within 15 minutes
#[derive(Clone, Debug)]
pub struct Session {
@@ -9,51 +19,136 @@ pub struct Session {
pub role: String,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
pub session_id: String, // unique identifier for the session (UUID)
}
#[derive(Clone)]
pub struct SessionManager {
sessions: Arc<RwLock<HashMap<String, Session>>>,
sessions: Arc<RwLock<HashMap<String, Session>>>, // key = session_id
ttl_hours: i64,
secret: Vec<u8>,
}
/// Claims embedded in a signed session token.
///
/// `sign_token` serializes this struct to JSON and signs the resulting bytes;
/// `verify_signed_token` deserializes it again after the signature check.
#[derive(Debug, Serialize, Deserialize)]
struct SessionPayload {
// Key of the corresponding entry in the SessionManager session store.
session_id: String,
// Compared against the stored session's username when a token is validated.
username: String,
// Compared against the stored session's role when a token is validated.
role: String,
iat: i64, // issued at (Unix timestamp)
exp: i64, // expiry (Unix timestamp)
// Set to TOKEN_VERSION when a token is created.
// NOTE(review): no visible code checks this field during verification — confirm.
version: String,
}
impl SessionManager {
/// Build a session manager with an empty session store.
///
/// `ttl_hours` is stored as the session lifetime in hours. The token signing
/// secret is obtained once, here, from `load_session_secret()`.
pub fn new(ttl_hours: i64) -> Self {
    Self {
        secret: load_session_secret(),
        sessions: Arc::new(RwLock::new(HashMap::new())),
        ttl_hours,
    }
}
/// Create a new session and return the session token.
/// Create a new session and return a signed session token.
pub async fn create_session(&self, username: String, role: String) -> String {
let token = format!("session-{}", uuid::Uuid::new_v4());
let session_id = Uuid::new_v4().to_string();
let now = Utc::now();
let expires_at = now + Duration::hours(self.ttl_hours);
let session = Session {
username,
role,
username: username.clone(),
role: role.clone(),
created_at: now,
expires_at: now + Duration::hours(self.ttl_hours),
expires_at,
session_id: session_id.clone(),
};
self.sessions.write().await.insert(token.clone(), session);
token
// Store session by session_id
self.sessions.write().await.insert(session_id.clone(), session);
// Create signed token
self.create_signed_token(&session_id, &username, &role, now.timestamp(), expires_at.timestamp())
}
/// Validate a session token and return the session if it is valid and not expired.
///
/// This is a thin wrapper over `validate_session_with_refresh`: any refreshed
/// token that call produces is discarded here. Callers that need to hand a
/// replacement token back to the client should call
/// `validate_session_with_refresh` directly.
pub async fn validate_session(&self, token: &str) -> Option<Session> {
    let (session, _refreshed) = self.validate_session_with_refresh(token).await?;
    Some(session)
}
/// Validate a session token and return (session, optional new token if refreshed).
pub async fn validate_session_with_refresh(&self, token: &str) -> Option<(Session, Option<String>)> {
// Legacy token format (UUID)
if token.starts_with("session-") {
let sessions = self.sessions.read().await;
sessions.get(token).and_then(|s| {
return sessions.get(token).and_then(|s| {
if s.expires_at > Utc::now() {
Some(s.clone())
Some((s.clone(), None))
} else {
None
}
})
});
}
// Signed token format
let payload = match verify_signed_token(token, &self.secret) {
Ok(p) => p,
Err(_) => return None,
};
// Check expiry
let now = Utc::now().timestamp();
if payload.exp <= now {
return None;
}
// Look up session by session_id
let sessions = self.sessions.read().await;
let session = match sessions.get(&payload.session_id) {
Some(s) => s.clone(),
None => return None, // session revoked or not found
};
// Ensure session username/role matches (should always match)
if session.username != payload.username || session.role != payload.role {
return None;
}
// Check if token is within refresh window (last REFRESH_WINDOW_MINUTES of validity)
let refresh_threshold = payload.exp - REFRESH_WINDOW_MINUTES * 60;
let new_token = if now >= refresh_threshold {
// Generate a new token with same session data but updated iat/exp?
// According to activity-based refresh, we should extend the session expiry.
// We'll extend from now by ttl_hours (or keep original expiry?).
// Let's extend from now by ttl_hours (sliding window).
let new_exp = Utc::now() + Duration::hours(self.ttl_hours);
// Update session expiry in store
drop(sessions); // release read lock before acquiring write lock
self.update_session_expiry(&payload.session_id, new_exp).await;
// Create new token with updated iat/exp
let new_token = self.create_signed_token(
&payload.session_id,
&payload.username,
&payload.role,
now,
new_exp.timestamp(),
);
Some(new_token)
} else {
None
};
Some((session, new_token))
}
/// Revoke (delete) the session a token refers to.
///
/// Legacy tokens (prefixed with `session-`) are themselves the store key.
/// Signed tokens are verified first and the `session_id` from their payload is
/// used as the key. A signed token that fails verification is ignored and
/// nothing is removed.
pub async fn revoke_session(&self, token: &str) {
    // Holds the session_id extracted from a signed token so `key` can borrow it.
    let extracted_id;
    let key: &str = if token.starts_with("session-") {
        token
    } else {
        match verify_signed_token(token, &self.secret) {
            Ok(claims) => {
                extracted_id = claims.session_id;
                &extracted_id
            }
            Err(_) => return,
        }
    };
    self.sessions.write().await.remove(key);
}
/// Remove all expired sessions from the store.
@@ -61,4 +156,156 @@ impl SessionManager {
let now = Utc::now();
self.sessions.write().await.retain(|_, s| s.expires_at > now);
}
// --- Private helpers ---
/// Assemble the token claims from the given values and sign them with this
/// manager's secret, returning the encoded token string.
///
/// `iat` and `exp` are Unix timestamps; the `version` claim is always
/// `TOKEN_VERSION`.
fn create_signed_token(&self, session_id: &str, username: &str, role: &str, iat: i64, exp: i64) -> String {
    let claims = SessionPayload {
        session_id: session_id.to_owned(),
        username: username.to_owned(),
        role: role.to_owned(),
        iat,
        exp,
        version: TOKEN_VERSION.to_owned(),
    };
    sign_token(&claims, &self.secret)
}
/// Overwrite the stored expiry of the session keyed by `session_id`.
///
/// Takes the write lock on the session store. If no session with that id is
/// present, nothing is changed.
async fn update_session_expiry(&self, session_id: &str, new_expires_at: DateTime<Utc>) {
    let mut store = self.sessions.write().await;
    match store.get_mut(session_id) {
        Some(entry) => entry.expires_at = new_expires_at,
        None => {}
    }
}
}
/// Load the HMAC session secret from the environment.
///
/// Looks at `SESSION_SECRET` first, then `LLM_PROXY__SESSION_SECRET`. A variable
/// that is unset, or set to an empty / whitespace-only value, is treated as
/// absent. The value is trimmed and decoded as hex, then URL-safe base64, then
/// standard base64.
///
/// If neither variable provides a value, a random 32-byte secret is generated
/// and a warning is logged; sessions signed with it will not survive a restart.
///
/// # Panics
///
/// Panics if the configured value is not valid hex or base64, or if it decodes
/// to fewer than 32 bytes. Without these checks an empty or very short value
/// would silently become the HMAC key, making session tokens forgeable.
fn load_session_secret() -> Vec<u8> {
    const MIN_SECRET_BYTES: usize = 32;

    // First non-blank value wins; the second variable is only read when the
    // first one is missing or blank.
    let configured = ["SESSION_SECRET", "LLM_PROXY__SESSION_SECRET"]
        .iter()
        .filter_map(|name| env::var(name).ok())
        .map(|value| value.trim().to_string())
        .find(|value| !value.is_empty());

    let secret_str = match configured {
        Some(value) => value,
        None => {
            // Generate a random secret (32 bytes)
            use rand::RngCore;
            let mut bytes = [0u8; 32];
            rand::rng().fill_bytes(&mut bytes);
            tracing::warn!(
                "SESSION_SECRET environment variable not set. Using a randomly generated secret. \
                 This will invalidate all sessions on restart. Set SESSION_SECRET to a fixed hex or base64 encoded 32-byte value."
            );
            return bytes.to_vec();
        }
    };

    // Decode hex or base64
    let secret = hex::decode(&secret_str)
        .or_else(|_| URL_SAFE.decode(&secret_str))
        .or_else(|_| base64::engine::general_purpose::STANDARD.decode(&secret_str))
        .unwrap_or_else(|_| {
            panic!("SESSION_SECRET must be hex or base64 encoded (32 bytes)");
        });

    if secret.len() < MIN_SECRET_BYTES {
        panic!(
            "SESSION_SECRET must decode to at least {} bytes, got {} bytes",
            MIN_SECRET_BYTES,
            secret.len()
        );
    }

    secret
}
/// Serialize `payload` to JSON, compute an HMAC-SHA256 tag over those JSON
/// bytes with `secret`, and return `base64_url(payload).base64_url(signature)`.
///
/// Panics if the payload cannot be serialized to JSON.
fn sign_token(payload: &SessionPayload, secret: &[u8]) -> String {
    let body = serde_json::to_vec(payload).expect("Failed to serialize payload");

    let mut signer = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
    signer.update(&body);
    let tag = signer.finalize().into_bytes();

    [URL_SAFE.encode(&body), URL_SAFE.encode(tag)].join(".")
}
/// Verify a signed token and return the decoded payload if valid.
///
/// The token must be `base64_url(json).base64_url(signature)` with exactly one
/// `.` separator. The HMAC-SHA256 of the decoded JSON bytes is compared against
/// the decoded signature in constant time.
///
/// This function only checks integrity and shape; it does not compare the
/// payload's `exp` against the current time.
///
/// # Errors
/// * `TokenError::InvalidFormat` - wrong number of parts, or a part is not
///   valid URL-safe base64.
/// * `TokenError::InvalidSignature` - the signature has the wrong length or
///   does not match.
/// * `TokenError::InvalidPayload` - the signed bytes are not a valid
///   `SessionPayload`.
fn verify_signed_token(token: &str, secret: &[u8]) -> Result<SessionPayload, TokenError> {
    // Exactly two parts: one separator, and none left in the remainder.
    let (payload_b64, signature_b64) = token.split_once('.').ok_or(TokenError::InvalidFormat)?;
    if signature_b64.contains('.') {
        return Err(TokenError::InvalidFormat);
    }

    let json = URL_SAFE.decode(payload_b64).map_err(|_| TokenError::InvalidFormat)?;
    let signature = URL_SAFE.decode(signature_b64).map_err(|_| TokenError::InvalidFormat)?;

    // Verify HMAC. `verify_slice` rejects a tag of the wrong length with an
    // error; building a fixed-size array from the untrusted bytes would panic
    // on any signature that is not exactly 32 bytes long.
    let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
    mac.update(&json);
    mac.verify_slice(&signature).map_err(|_| TokenError::InvalidSignature)?;

    // Deserialize only after the signature has been accepted.
    serde_json::from_slice(&json).map_err(|_| TokenError::InvalidPayload)
}
/// Reasons `verify_signed_token` can reject a token.
#[derive(Debug)]
enum TokenError {
/// The token is not exactly two `.`-separated parts, or a part is not valid URL-safe base64.
InvalidFormat,
/// The HMAC over the payload bytes does not match the supplied signature.
InvalidSignature,
/// The signature matched but the payload bytes did not deserialize into a `SessionPayload`.
InvalidPayload,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Fixed payload shared by the signing tests.
    fn sample_payload() -> SessionPayload {
        SessionPayload {
            session_id: "test-session".to_string(),
            username: "testuser".to_string(),
            role: "user".to_string(),
            iat: 1000,
            exp: 2000,
            version: TOKEN_VERSION.to_string(),
        }
    }

    /// A freshly signed token verifies and round-trips every payload field.
    #[test]
    fn test_sign_and_verify_token() {
        let secret = b"test-secret-must-be-32-bytes-long!";
        let expected = sample_payload();

        let token = sign_token(&expected, secret);
        let actual = verify_signed_token(&token, secret).unwrap();

        assert_eq!(actual.session_id, expected.session_id);
        assert_eq!(actual.username, expected.username);
        assert_eq!(actual.role, expected.role);
        assert_eq!(actual.iat, expected.iat);
        assert_eq!(actual.exp, expected.exp);
        assert_eq!(actual.version, expected.version);
    }

    /// Changing the payload bytes while keeping the signature must be rejected.
    #[test]
    fn test_tampered_token() {
        let secret = b"test-secret-must-be-32-bytes-long!";
        let token = sign_token(&sample_payload(), secret);

        // Tamper with payload part
        let (payload_b64, signature_b64) = token.split_once('.').unwrap();
        let mut payload_bytes = URL_SAFE.decode(payload_b64).unwrap();
        payload_bytes[0] ^= 0xFF; // flip some bits
        let tampered = format!("{}.{}", URL_SAFE.encode(payload_bytes), signature_b64);

        assert!(verify_signed_token(&tampered, secret).is_err());
    }

    /// A hex-encoded `SESSION_SECRET` is decoded to the raw bytes.
    #[test]
    fn test_load_session_secret_from_env() {
        let raw = [0xAA; 32];
        unsafe {
            env::set_var("SESSION_SECRET", hex::encode(raw));
        }
        let loaded = load_session_secret();
        assert_eq!(loaded, raw.to_vec());
        unsafe {
            env::remove_var("SESSION_SECRET");
        }
    }
}

View File

@@ -279,9 +279,10 @@ pub(super) async fn handle_system_backup(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let backup_id = format!("backup-{}", chrono::Utc::now().timestamp());
@@ -341,9 +342,10 @@ pub(super) async fn handle_update_settings(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
Json(ApiResponse::error(
"Changing settings at runtime is not yet supported. Please update your config file and restart the server."

View File

@@ -14,9 +14,10 @@ pub(super) async fn handle_get_users(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
@@ -66,9 +67,10 @@ pub(super) async fn handle_create_user(
headers: axum::http::HeaderMap,
Json(payload): Json<CreateUserRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
@@ -147,9 +149,10 @@ pub(super) async fn handle_update_user(
Path(id): Path<i64>,
Json(payload): Json<UpdateUserRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
@@ -249,8 +252,8 @@ pub(super) async fn handle_delete_user(
headers: axum::http::HeaderMap,
Path(id): Path<i64>,
) -> Json<ApiResponse<serde_json::Value>> {
let session = match auth::require_admin(&state, &headers).await {
Ok(s) => s,
let (session, _) = match auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};

View File

@@ -18,7 +18,9 @@ pub async fn init(config: &DatabaseConfig) -> Result<DbPool> {
let database_path = config.path.to_string_lossy().to_string();
info!("Connecting to database at {}", database_path);
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?.create_if_missing(true);
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?
.create_if_missing(true)
.pragma("foreign_keys", "ON");
let pool = SqlitePool::connect_with(options).await?;
@@ -29,7 +31,7 @@ pub async fn init(config: &DatabaseConfig) -> Result<DbPool> {
Ok(pool)
}
async fn run_migrations(pool: &DbPool) -> Result<()> {
pub async fn run_migrations(pool: &DbPool) -> Result<()> {
// Create clients table if it doesn't exist
sqlx::query(
r#"
@@ -88,6 +90,8 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
api_key TEXT,
credit_balance REAL DEFAULT 0.0,
low_credit_threshold REAL DEFAULT 5.0,
billing_mode TEXT,
api_key_encrypted BOOLEAN DEFAULT FALSE,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
"#,
@@ -167,6 +171,15 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
.execute(pool)
.await;
// Add billing_mode column if it doesn't exist (migration for existing DBs)
let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN billing_mode TEXT")
.execute(pool)
.await;
// Add api_key_encrypted column if it doesn't exist (migration for existing DBs)
let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN api_key_encrypted BOOLEAN DEFAULT FALSE")
.execute(pool)
.await;
// Insert default admin user if none exists (default password: admin)
let user_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users").fetch_one(pool).await?;
@@ -216,6 +229,19 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
.execute(pool)
.await?;
// Composite indexes for performance
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id)")
.execute(pool)
.await?;
// Insert default client if none exists
sqlx::query(
r#"

View File

@@ -41,23 +41,18 @@ pub use state::AppState;
pub mod test_utils {
use std::sync::Arc;
use crate::{client::ClientManager, providers::ProviderManager, rate_limiting::RateLimitManager, state::AppState};
use crate::{client::ClientManager, providers::ProviderManager, rate_limiting::RateLimitManager, state::AppState, utils::crypto, database::run_migrations};
use sqlx::sqlite::SqlitePool;
/// Create a test application state
pub async fn create_test_state() -> Arc<AppState> {
pub async fn create_test_state() -> AppState {
// Create in-memory database
let pool = SqlitePool::connect("sqlite::memory:")
.await
.expect("Failed to create test database");
// Run migrations
crate::database::init(&crate::config::DatabaseConfig {
path: std::path::PathBuf::from(":memory:"),
max_connections: 5,
})
.await
.expect("Failed to initialize test database");
// Run migrations on the pool
run_migrations(&pool).await.expect("Failed to run migrations");
let rate_limit_manager = RateLimitManager::new(
crate::rate_limiting::RateLimiterConfig::default(),
@@ -73,7 +68,7 @@ pub mod test_utils {
providers: std::collections::HashMap::new(),
};
let (dashboard_tx, _) = tokio::sync::broadcast::channel(100);
let (dashboard_tx, _) = tokio::sync::broadcast::channel::<serde_json::Value>(100);
let config = Arc::new(crate::config::AppConfig {
server: crate::config::ServerConfig {
@@ -125,20 +120,20 @@ pub mod test_utils {
ollama: vec![],
},
config_path: None,
encryption_key: "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f".to_string(),
});
Arc::new(AppState {
// Initialize encryption with the test key
crypto::init_with_key(&config.encryption_key).expect("failed to initialize crypto");
AppState::new(
config,
provider_manager,
db_pool: pool.clone(),
rate_limit_manager: Arc::new(rate_limit_manager),
client_manager,
request_logger: Arc::new(crate::logging::RequestLogger::new(pool.clone(), dashboard_tx.clone())),
model_registry: Arc::new(model_registry),
model_config_cache: crate::state::ModelConfigCache::new(pool.clone()),
dashboard_tx,
auth_tokens: vec![],
})
pool,
rate_limit_manager,
model_registry,
vec![], // auth_tokens
)
}
/// Create a test HTTP client
@@ -149,3 +144,185 @@ pub mod test_utils {
.expect("Failed to create test HTTP client")
}
}
#[cfg(test)]
mod integration_tests {
    use super::test_utils::*;
    use crate::{
        models::{ChatCompletionRequest, ChatMessage},
        server::router,
        utils::crypto,
    };
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use mockito::Server;
    use serde_json::json;
    use sqlx::Row;
    use tower::util::ServiceExt;

    /// End-to-end check of the encrypted-key lifecycle: a provider row is
    /// stored with an encrypted API key, the provider is initialized from that
    /// row, and a chat completion request sent through the router must reach
    /// the mock upstream carrying the *plaintext* key in its `Authorization`
    /// header. Afterwards the request log and client usage rows are checked.
    #[tokio::test]
    async fn test_encrypted_provider_key_integration() {
        // Step 1: Setup test database and state
        let state = create_test_state().await;
        let pool = state.db_pool.clone();

        // Step 2: Insert provider with encrypted API key
        let test_api_key = "test-openai-key-12345";
        let encrypted_key = crypto::encrypt(test_api_key).expect("Failed to encrypt test key");
        sqlx::query(
            r#"
            INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            "#,
        )
        .bind("openai")
        .bind("OpenAI")
        .bind(true)
        .bind("http://localhost:1234") // Placeholder; replaced with the mock server URL below
        .bind(&encrypted_key)
        .bind(true) // api_key_encrypted flag
        .bind(100.0)
        .bind(5.0)
        .execute(&pool)
        .await
        .expect("Failed to insert provider config");

        // Step 3: Initialize the provider from the freshly inserted row
        state
            .provider_manager
            .initialize_provider("openai", &state.config, &pool)
            .await
            .expect("Failed to initialize provider");

        // Step 4: Mock OpenAI API server. The header matcher is the actual
        // assertion that the key was decrypted before being sent upstream.
        let mut server = Server::new_async().await;
        let mock = server
            .mock("POST", "/chat/completions")
            .match_header("authorization", format!("Bearer {}", test_api_key).as_str())
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "id": "chatcmpl-test",
                    "object": "chat.completion",
                    "created": 1234567890,
                    "model": "gpt-3.5-turbo",
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "Hello, world!"
                        },
                        "finish_reason": "stop"
                    }],
                    "usage": {
                        "prompt_tokens": 10,
                        "completion_tokens": 5,
                        "total_tokens": 15
                    }
                })
                .to_string(),
            )
            .create_async()
            .await;

        // Update provider base URL to use mock server
        sqlx::query("UPDATE provider_configs SET base_url = ? WHERE id = 'openai'")
            .bind(&server.url())
            .execute(&pool)
            .await
            .expect("Failed to update provider URL");

        // Re-initialize provider with new URL
        state
            .provider_manager
            .initialize_provider("openai", &state.config, &pool)
            .await
            .expect("Failed to re-initialize provider");

        // Step 5: Create test router and make request
        let app = router(state);
        let request_body = ChatCompletionRequest {
            model: "gpt-3.5-turbo".to_string(),
            messages: vec![ChatMessage {
                role: "user".to_string(),
                content: crate::models::MessageContent::Text {
                    content: "Hello".to_string(),
                },
                reasoning_content: None,
                tool_calls: None,
                name: None,
                tool_call_id: None,
            }],
            temperature: None,
            top_p: None,
            top_k: None,
            n: None,
            stop: None,
            max_tokens: Some(100),
            presence_penalty: None,
            frequency_penalty: None,
            stream: Some(false),
            tools: None,
            tool_choice: None,
        };
        let request = Request::builder()
            .method("POST")
            .uri("/v1/chat/completions")
            .header("content-type", "application/json")
            .header("authorization", "Bearer test-token")
            .body(Body::from(serde_json::to_string(&request_body).unwrap()))
            .unwrap();

        // Step 6: Execute request through proxy
        let response = app
            .oneshot(request)
            .await
            .expect("Failed to execute request");
        let status = response.status();
        println!("Response status: {}", status);
        if status != StatusCode::OK {
            // Print the body before failing so the cause is visible in test output.
            let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
            let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
            println!("Response body: {}", body_str);
            panic!("Response status is not OK: {}", status);
        }
        assert_eq!(status, StatusCode::OK);

        // Verify the mock was called
        mock.assert_async().await;

        // Give the async logging task time to complete
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        // Step 7: Verify usage was logged in database
        let log_row = sqlx::query("SELECT * FROM llm_requests WHERE client_id = 'client_test-tok' ORDER BY id DESC LIMIT 1")
            .fetch_one(&pool)
            .await
            .expect("Request log not found");
        assert_eq!(log_row.get::<String, _>("provider"), "openai");
        assert_eq!(log_row.get::<String, _>("model"), "gpt-3.5-turbo");
        assert_eq!(log_row.get::<i64, _>("prompt_tokens"), 10);
        assert_eq!(log_row.get::<i64, _>("completion_tokens"), 5);
        assert_eq!(log_row.get::<i64, _>("total_tokens"), 15);
        assert_eq!(log_row.get::<String, _>("status"), "success");

        // Verify client usage was updated
        let client_row = sqlx::query("SELECT total_requests, total_tokens, total_cost FROM clients WHERE client_id = 'client_test-tok'")
            .fetch_one(&pool)
            .await
            .expect("Client not found");
        assert_eq!(client_row.get::<i64, _>("total_requests"), 1);
        assert_eq!(client_row.get::<i64, _>("total_tokens"), 15);
    }
}

View File

@@ -82,9 +82,9 @@ impl RequestLogger {
"#,
)
.bind(log.timestamp)
.bind(log.client_id)
.bind(&log.client_id)
.bind(&log.provider)
.bind(log.model)
.bind(&log.model)
.bind(log.prompt_tokens as i64)
.bind(log.completion_tokens as i64)
.bind(log.total_tokens as i64)
@@ -92,7 +92,7 @@ impl RequestLogger {
.bind(log.cache_write_tokens as i64)
.bind(log.cost)
.bind(log.has_images)
.bind(log.status)
.bind(&log.status)
.bind(log.error_message)
.bind(log.duration_ms as i64)
.bind(None::<String>) // request_body - optional, not stored to save disk space
@@ -100,6 +100,23 @@ impl RequestLogger {
.execute(&mut *tx)
.await?;
// Update client usage statistics
sqlx::query(
r#"
UPDATE clients SET
total_requests = total_requests + 1,
total_tokens = total_tokens + ?,
total_cost = total_cost + ?,
updated_at = CURRENT_TIMESTAMP
WHERE client_id = ?
"#,
)
.bind(log.total_tokens as i64)
.bind(log.cost)
.bind(&log.client_id)
.execute(&mut *tx)
.await?;
// Deduct from provider balance if successful.
// Providers configured with billing_mode = 'postpaid' will not have their
// credit_balance decremented. Use a conditional UPDATE so we don't need

View File

@@ -10,6 +10,7 @@ use llm_proxy::{
rate_limiting::{CircuitBreakerConfig, RateLimitManager, RateLimiterConfig},
server,
state::AppState,
utils::crypto,
};
#[tokio::main]
@@ -26,6 +27,10 @@ async fn main() -> Result<()> {
let config = AppConfig::load().await?;
info!("Configuration loaded from {:?}", config.config_path);
// Initialize encryption
crypto::init_with_key(&config.encryption_key)?;
info!("Encryption initialized");
// Initialize database connection pool
let db_pool = database::init(&config.database).await?;
info!("Database initialized at {:?}", config.database.path);

View File

@@ -7,6 +7,7 @@ use std::sync::Arc;
use crate::errors::AppError;
use crate::models::UnifiedRequest;
pub mod deepseek;
pub mod gemini;
pub mod grok;
@@ -125,17 +126,35 @@ impl ProviderManager {
db_pool: &crate::database::DbPool,
) -> Result<()> {
// Load override from database
let db_config = sqlx::query("SELECT enabled, base_url, api_key FROM provider_configs WHERE id = ?")
let db_config = sqlx::query("SELECT enabled, base_url, api_key, api_key_encrypted FROM provider_configs WHERE id = ?")
.bind(name)
.fetch_optional(db_pool)
.await?;
let (enabled, base_url, api_key) = if let Some(row) = db_config {
(
row.get::<bool, _>("enabled"),
row.get::<Option<String>, _>("base_url"),
row.get::<Option<String>, _>("api_key"),
)
let enabled = row.get::<bool, _>("enabled");
let base_url = row.get::<Option<String>, _>("base_url");
let api_key_encrypted = row.get::<bool, _>("api_key_encrypted");
let api_key = row.get::<Option<String>, _>("api_key");
// Decrypt API key if encrypted
let api_key = match (api_key, api_key_encrypted) {
(Some(key), true) => {
match crate::utils::crypto::decrypt(&key) {
Ok(decrypted) => Some(decrypted),
Err(e) => {
tracing::error!("Failed to decrypt API key for provider {}: {}", name, e);
None
}
}
}
(Some(key), false) => {
// Plaintext key - optionally encrypt and update database (lazy migration)
// For now, just use plaintext
Some(key)
}
(None, _) => None,
};
(enabled, base_url, api_key)
} else {
// No database override, use defaults from AppConfig
match name {

View File

@@ -6,12 +6,15 @@
//! 3. Global rate limiting for overall system protection
use anyhow::Result;
use governor::{Quota, RateLimiter, DefaultDirectRateLimiter};
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tracing::{info, warn};
type GovRateLimiter = DefaultDirectRateLimiter;
/// Rate limiter configuration
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
@@ -65,45 +68,7 @@ impl Default for CircuitBreakerConfig {
}
}
/// Token bucket for a single client: holds up to `capacity` tokens and regains
/// `refill_rate` tokens per elapsed second.
#[derive(Debug)]
struct TokenBucket {
    tokens: f64,
    capacity: f64,
    refill_rate: f64, // tokens per second
    last_refill: Instant,
}

impl TokenBucket {
    /// Create a bucket that starts full.
    fn new(capacity: f64, refill_rate: f64) -> Self {
        let last_refill = Instant::now();
        Self {
            tokens: capacity,
            capacity,
            refill_rate,
            last_refill,
        }
    }

    /// Credit the tokens earned since the previous refill, capped at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let earned = now.duration_since(self.last_refill).as_secs_f64() * self.refill_rate;
        self.tokens = f64::min(self.tokens + earned, self.capacity);
        self.last_refill = now;
    }

    /// Refill, then take `tokens` from the bucket if that many are available.
    /// Returns whether the tokens were taken.
    fn try_acquire(&mut self, tokens: f64) -> bool {
        self.refill();
        let granted = self.tokens >= tokens;
        if granted {
            self.tokens -= tokens;
        }
        granted
    }
}
/// Circuit breaker for a provider
#[derive(Debug)]
@@ -209,8 +174,8 @@ impl ProviderCircuitBreaker {
/// Rate limiting and circuit breaking manager
#[derive(Debug)]
pub struct RateLimitManager {
client_buckets: Arc<RwLock<HashMap<String, TokenBucket>>>,
global_bucket: Arc<RwLock<TokenBucket>>,
client_buckets: Arc<RwLock<HashMap<String, GovRateLimiter>>>,
global_bucket: Arc<GovRateLimiter>,
circuit_breakers: Arc<RwLock<HashMap<String, ProviderCircuitBreaker>>>,
config: RateLimiterConfig,
circuit_config: CircuitBreakerConfig,
@@ -218,15 +183,16 @@ pub struct RateLimitManager {
impl RateLimitManager {
pub fn new(config: RateLimiterConfig, circuit_config: CircuitBreakerConfig) -> Self {
// Convert requests per minute to tokens per second
let global_refill_rate = config.global_requests_per_minute as f64 / 60.0;
// Create global rate limiter quota
let global_quota = Quota::per_minute(
NonZeroU32::new(config.global_requests_per_minute).expect("global_requests_per_minute must be positive")
)
.allow_burst(NonZeroU32::new(config.burst_size).expect("burst_size must be positive"));
let global_bucket = RateLimiter::direct(global_quota);
Self {
client_buckets: Arc::new(RwLock::new(HashMap::new())),
global_bucket: Arc::new(RwLock::new(TokenBucket::new(
config.burst_size as f64,
global_refill_rate,
))),
global_bucket: Arc::new(global_bucket),
circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
config,
circuit_config,
@@ -236,24 +202,22 @@ impl RateLimitManager {
/// Check if a client request is allowed
pub async fn check_client_request(&self, client_id: &str) -> Result<bool> {
// Check global rate limit first (1 token per request)
{
let mut global_bucket = self.global_bucket.write().await;
if !global_bucket.try_acquire(1.0) {
if self.global_bucket.check().is_err() {
warn!("Global rate limit exceeded");
return Ok(false);
}
}
// Check client-specific rate limit
let mut buckets = self.client_buckets.write().await;
let bucket = buckets.entry(client_id.to_string()).or_insert_with(|| {
TokenBucket::new(
self.config.burst_size as f64,
self.config.requests_per_minute as f64 / 60.0,
let quota = Quota::per_minute(
NonZeroU32::new(self.config.requests_per_minute).expect("requests_per_minute must be positive")
)
.allow_burst(NonZeroU32::new(self.config.burst_size).expect("burst_size must be positive"));
RateLimiter::direct(quota)
});
Ok(bucket.try_acquire(1.0))
Ok(bucket.check().is_ok())
}
/// Check if provider requests are allowed (circuit breaker)

View File

@@ -5,6 +5,11 @@ use axum::{
response::sse::{Event, Sse},
routing::{get, post},
};
use axum::http::{header, HeaderValue};
use tower_http::{
limit::RequestBodyLimitLayer,
set_header::SetResponseHeaderLayer,
};
use futures::StreamExt;
use std::sync::Arc;
@@ -23,9 +28,34 @@ use crate::{
};
pub fn router(state: AppState) -> Router {
// Security headers
let csp_header: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::CONTENT_SECURITY_POLICY,
"default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws:;"
.parse()
.unwrap(),
);
let x_frame_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_FRAME_OPTIONS,
"DENY".parse().unwrap(),
);
let x_content_type_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_CONTENT_TYPE_OPTIONS,
"nosniff".parse().unwrap(),
);
let strict_transport_security: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::STRICT_TRANSPORT_SECURITY,
"max-age=31536000; includeSubDomains".parse().unwrap(),
);
Router::new()
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/models", get(list_models))
.layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit
.layer(csp_header)
.layer(x_frame_options)
.layer(x_content_type_options)
.layer(strict_transport_security)
.layer(axum::middleware::from_fn_with_state(
state.clone(),
rate_limiting::middleware::rate_limit_middleware,
@@ -219,7 +249,6 @@ async fn chat_completions(
prompt_tokens,
has_images,
logger: state.request_logger.clone(),
client_manager: state.client_manager.clone(),
model_registry: state.model_registry.clone(),
model_config_cache: state.model_config_cache.clone(),
},
@@ -341,15 +370,6 @@ async fn chat_completions(
duration_ms: duration.as_millis() as u64,
});
// Update client usage (fire-and-forget, don't block response)
{
let cm = state.client_manager.clone();
let cid = client_id.clone();
tokio::spawn(async move {
let _ = cm.update_client_usage(&cid, response.total_tokens as i64, cost).await;
});
}
// Convert ProviderResponse to ChatCompletionResponse
let finish_reason = if response.tool_calls.is_some() {
"tool_calls".to_string()

171
src/utils/crypto.rs Normal file
View File

@@ -0,0 +1,171 @@
use aes_gcm::{
aead::{Aead, AeadCore, KeyInit, OsRng},
Aes256Gcm, Key, Nonce,
};
use anyhow::{anyhow, Context, Result};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use std::env;
use std::sync::OnceLock;
/// Process-wide 32-byte encryption key; written once by `init_with_key` and read by `get_key`.
static RAW_KEY: OnceLock<[u8; 32]> = OnceLock::new();
/// Initialize the encryption key from a hex or base64 encoded string.
/// Must be called before any encryption/decryption operations.
/// Returns error if the key is invalid or already initialized with a different key.
pub fn init_with_key(key_str: &str) -> Result<()> {
let key_bytes = hex::decode(key_str)
.or_else(|_| BASE64.decode(key_str))
.context("Encryption key must be hex or base64 encoded")?;
if key_bytes.len() != 32 {
anyhow::bail!(
"Encryption key must be 32 bytes (256 bits), got {} bytes",
key_bytes.len()
);
}
let key_array: [u8; 32] = key_bytes.try_into().unwrap(); // safe due to length check
// Check if already initialized with same key
if let Some(existing) = RAW_KEY.get() {
if existing == &key_array {
// Same key already initialized, okay
return Ok(());
} else {
anyhow::bail!("Encryption key already initialized with a different key");
}
}
// Store raw key bytes
RAW_KEY
.set(key_array)
.map_err(|_| anyhow::anyhow!("Encryption key already initialized"))?;
Ok(())
}
/// Initialize the encryption key from the environment variable `LLM_PROXY__ENCRYPTION_KEY`.
///
/// Must be called before any encryption/decryption operations.
///
/// # Errors
/// Returns an error if the variable cannot be read, or if `init_with_key`
/// rejects its value.
pub fn init_from_env() -> Result<()> {
    env::var("LLM_PROXY__ENCRYPTION_KEY")
        .context("LLM_PROXY__ENCRYPTION_KEY environment variable not set")
        .and_then(|key_str| init_with_key(&key_str))
}
/// Borrow the installed key bytes.
///
/// # Panics
/// Panics if no key has been installed yet.
fn get_key() -> &'static [u8; 32] {
    match RAW_KEY.get() {
        Some(key) => key,
        None => panic!("Encryption key not initialized. Call crypto::init_with_key() or crypto::init_from_env() first."),
    }
}
/// Encrypt `plaintext` with AES-256-GCM under the installed key and return
/// `base64(nonce || ciphertext || tag)`.
///
/// A fresh 12-byte nonce is generated for every call.
///
/// # Panics
/// Panics (via `get_key`) if no key has been installed.
pub fn encrypt(plaintext: &str) -> Result<String> {
    let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(get_key()));
    let nonce = Aes256Gcm::generate_nonce(&mut OsRng); // 12 bytes

    // The AEAD output already carries the authentication tag at its end.
    let sealed = cipher
        .encrypt(&nonce, plaintext.as_bytes())
        .map_err(|e| anyhow!("Encryption failed: {}", e))?;

    // Prefix the nonce so `decrypt` can recover it.
    let combined = [nonce.as_slice(), sealed.as_slice()].concat();
    Ok(BASE64.encode(combined))
}
/// Reverse `encrypt`: decode `base64(nonce || ciphertext || tag)`, split off
/// the 12-byte nonce, and authenticate and decrypt the remainder into a UTF-8 string.
///
/// # Errors
/// Fails on invalid base64, input shorter than the nonce, authentication
/// failure (wrong key or corrupted data), or non-UTF-8 plaintext.
///
/// # Panics
/// Panics (via `get_key`) if no key has been installed.
pub fn decrypt(ciphertext_b64: &str) -> Result<String> {
    const NONCE_LEN: usize = 12;

    // Build the cipher first so a missing key is reported before input errors.
    let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(get_key()));

    let combined = BASE64.decode(ciphertext_b64).context("Invalid base64 ciphertext")?;
    if combined.len() < NONCE_LEN {
        anyhow::bail!("Ciphertext too short");
    }

    let nonce = Nonce::from_slice(&combined[..NONCE_LEN]);
    let sealed = &combined[NONCE_LEN..];
    let opened = cipher
        .decrypt(nonce, sealed)
        .map_err(|e| anyhow!("Decryption failed (invalid key or corrupted ciphertext): {}", e))?;

    String::from_utf8(opened).context("Decrypted bytes are not valid UTF-8")
}
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_KEY_HEX: &str = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f";

    /// Round trip: ciphertext differs from the plaintext and decrypts back to it.
    #[test]
    fn test_encrypt_decrypt() {
        init_with_key(TEST_KEY_HEX).unwrap();
        let plaintext = "super secret api key";
        let ciphertext = encrypt(plaintext).unwrap();
        assert_ne!(ciphertext, plaintext);
        let decrypted = decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    /// Encrypting the same input twice yields different ciphertexts.
    #[test]
    fn test_different_inputs_produce_different_ciphertexts() {
        init_with_key(TEST_KEY_HEX).unwrap();
        let plaintext = "same";
        let cipher1 = encrypt(plaintext).unwrap();
        let cipher2 = encrypt(plaintext).unwrap();
        assert_ne!(cipher1, cipher2, "Nonce should make ciphertexts differ");
        assert_eq!(decrypt(&cipher1).unwrap(), plaintext);
        assert_eq!(decrypt(&cipher2).unwrap(), plaintext);
    }

    /// A key that does not decode to 32 bytes is rejected.
    #[test]
    fn test_invalid_key_length() {
        let result = init_with_key("tooshort");
        assert!(result.is_err());
    }

    /// Covers both the "variable missing" and "variable present" cases.
    ///
    /// These are deliberately one test: both cases mutate the same
    /// process-wide environment variable, and the test harness runs separate
    /// tests on parallel threads by default, so as two tests they could
    /// interleave and fail intermittently.
    #[test]
    fn test_init_from_env() {
        // Missing variable -> error.
        unsafe { std::env::remove_var("LLM_PROXY__ENCRYPTION_KEY") };
        assert!(init_from_env().is_err());

        // Present variable -> key installed.
        unsafe { std::env::set_var("LLM_PROXY__ENCRYPTION_KEY", TEST_KEY_HEX) };
        let result = init_from_env();
        assert!(result.is_ok());

        // Ensure encryption works
        let ciphertext = encrypt("test").unwrap();
        let decrypted = decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, "test");
    }

    /// The same key bytes are accepted in hex and in base64.
    #[test]
    fn test_key_hex_and_base64() {
        // Hex key works
        init_with_key(TEST_KEY_HEX).unwrap();
        // Base64 key (same bytes encoded as base64)
        let base64_key = BASE64.encode(hex::decode(TEST_KEY_HEX).unwrap());
        // Re-initialization with same key (different encoding) is allowed
        let result = init_with_key(&base64_key);
        assert!(result.is_ok());
        // Encryption should still work
        let ciphertext = encrypt("test").unwrap();
        let decrypted = decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, "test");
    }

    #[test]
    #[ignore] // conflicts with global state from other tests
    fn test_already_initialized() {
        init_with_key(TEST_KEY_HEX).unwrap();
        let result = init_with_key(TEST_KEY_HEX);
        assert!(result.is_ok()); // same key allowed
    }

    #[test]
    #[ignore] // conflicts with global state from other tests
    fn test_already_initialized_different_key() {
        init_with_key(TEST_KEY_HEX).unwrap();
        let different_key = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e20";
        let result = init_with_key(different_key);
        assert!(result.is_err());
    }
}

View File

@@ -1,3 +1,4 @@
pub mod crypto;
pub mod registry;
pub mod streaming;
pub mod tokens;

View File

@@ -1,4 +1,4 @@
use crate::client::ClientManager;
use crate::errors::AppError;
use crate::logging::{RequestLog, RequestLogger};
use crate::models::ToolCall;
@@ -18,7 +18,6 @@ pub struct StreamConfig {
pub prompt_tokens: u32,
pub has_images: bool,
pub logger: Arc<RequestLogger>,
pub client_manager: Arc<ClientManager>,
pub model_registry: Arc<crate::models::registry::ModelRegistry>,
pub model_config_cache: ModelConfigCache,
}
@@ -36,7 +35,6 @@ pub struct AggregatingStream<S> {
/// Real usage data from the provider's final stream chunk (when available).
real_usage: Option<StreamUsage>,
logger: Arc<RequestLogger>,
client_manager: Arc<ClientManager>,
model_registry: Arc<crate::models::registry::ModelRegistry>,
model_config_cache: ModelConfigCache,
start_time: std::time::Instant,
@@ -60,7 +58,6 @@ where
accumulated_tool_calls: Vec::new(),
real_usage: None,
logger: config.logger,
client_manager: config.client_manager,
model_registry: config.model_registry,
model_config_cache: config.model_config_cache,
start_time: std::time::Instant::now(),
@@ -79,7 +76,6 @@ where
let provider_name = self.provider.name().to_string();
let model = self.model.clone();
let logger = self.logger.clone();
let client_manager = self.client_manager.clone();
let provider = self.provider.clone();
let estimated_prompt_tokens = self.prompt_tokens;
let has_images = self.has_images;
@@ -162,11 +158,6 @@ where
error_message: None,
duration_ms: duration.as_millis() as u64,
});
// Update client usage
let _ = client_manager
.update_client_usage(&client_id, total_tokens as i64, cost)
.await;
});
}
}
@@ -304,7 +295,6 @@ mod tests {
let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
let (dashboard_tx, _) = tokio::sync::broadcast::channel(16);
let logger = Arc::new(RequestLogger::new(pool.clone(), dashboard_tx));
let client_manager = Arc::new(ClientManager::new(pool.clone()));
let registry = Arc::new(crate::models::registry::ModelRegistry {
providers: std::collections::HashMap::new(),
});
@@ -318,7 +308,6 @@ mod tests {
prompt_tokens: 10,
has_images: false,
logger,
client_manager,
model_registry: registry,
model_config_cache: ModelConfigCache::new(pool.clone()),
},

View File

@@ -35,6 +35,14 @@ class ApiClient {
throw new Error(result.error || `HTTP error! status: ${response.status}`);
}
// Handling X-Refreshed-Token header
if (response.headers.get('X-Refreshed-Token') && window.authManager) {
window.authManager.token = response.headers.get('X-Refreshed-Token');
if (window.authManager.setToken) {
window.authManager.setToken(window.authManager.token);
}
}
return result.data;
}
@@ -87,6 +95,17 @@ class ApiClient {
const date = luxon.DateTime.fromISO(dateStr);
return date.toRelative();
}
// Helper for escaping HTML
escapeHtml(unsafe) {
if (unsafe === undefined || unsafe === null) return '';
return unsafe.toString()
.replace(/&/g, "&amp;")
.replace(/</g, "&lt;")
.replace(/>/g, "&gt;")
.replace(/"/g, "&quot;")
.replace(/'/g, "&#039;");
}
}
window.api = new ApiClient();

View File

@@ -50,6 +50,12 @@ class AuthManager {
});
}
setToken(newToken) {
if (!newToken) return;
this.token = newToken;
localStorage.setItem('dashboard_token', this.token);
}
async login(username, password) {
const errorElement = document.getElementById('login-error');
const loginBtn = document.querySelector('.login-btn');

View File

@@ -42,12 +42,15 @@ class ClientsPage {
const statusIcon = client.status === 'active' ? 'check-circle' : 'clock';
const created = luxon.DateTime.fromISO(client.created_at).toFormat('MMM dd, yyyy');
const escapedId = window.api.escapeHtml(client.id);
const escapedName = window.api.escapeHtml(client.name);
return `
<tr>
<td><span class="badge-client">${client.id}</span></td>
<td><strong>${client.name}</strong></td>
<td><span class="badge-client">${escapedId}</span></td>
<td><strong>${escapedName}</strong></td>
<td>
<code class="token-display">sk-••••${client.id.substring(client.id.length - 4)}</code>
<code class="token-display">sk-••••${escapedId.substring(escapedId.length - 4)}</code>
</td>
<td>${created}</td>
<td>${client.last_used ? window.api.formatTimeAgo(client.last_used) : 'Never'}</td>
@@ -55,16 +58,16 @@ class ClientsPage {
<td>
<span class="status-badge ${statusClass}">
<i class="fas fa-${statusIcon}"></i>
${client.status}
${window.api.escapeHtml(client.status)}
</span>
</td>
<td>
${window._userRole === 'admin' ? `
<div class="action-buttons">
<button class="btn-action" title="Edit" onclick="window.clientsPage.editClient('${client.id}')">
<button class="btn-action" title="Edit" onclick="window.clientsPage.editClient('${escapedId}')">
<i class="fas fa-edit"></i>
</button>
<button class="btn-action danger" title="Delete" onclick="window.clientsPage.deleteClient('${client.id}')">
<button class="btn-action danger" title="Delete" onclick="window.clientsPage.deleteClient('${escapedId}')">
<i class="fas fa-trash"></i>
</button>
</div>
@@ -188,10 +191,13 @@ class ClientsPage {
showTokenRevealModal(clientName, token) {
const modal = document.createElement('div');
modal.className = 'modal active';
const escapedName = window.api.escapeHtml(clientName);
const escapedToken = window.api.escapeHtml(token);
modal.innerHTML = `
<div class="modal-content">
<div class="modal-header">
<h3 class="modal-title">Client Created: ${clientName}</h3>
<h3 class="modal-title">Client Created: ${escapedName}</h3>
</div>
<div class="modal-body">
<p style="margin-bottom: 0.75rem; color: var(--yellow);">
@@ -201,7 +207,7 @@ class ClientsPage {
<div class="form-control">
<label>API Token</label>
<div style="display: flex; gap: 0.5rem;">
<input type="text" id="revealed-token" value="${token}" readonly
<input type="text" id="revealed-token" value="${escapedToken}" readonly
style="font-family: monospace; font-size: 0.85rem;">
<button class="btn btn-secondary" id="copy-token-btn" title="Copy">
<i class="fas fa-copy"></i>
@@ -248,10 +254,16 @@ class ClientsPage {
showEditClientModal(client) {
const modal = document.createElement('div');
modal.className = 'modal active';
const escapedId = window.api.escapeHtml(client.id);
const escapedName = window.api.escapeHtml(client.name);
const escapedDescription = window.api.escapeHtml(client.description);
const escapedRateLimit = window.api.escapeHtml(client.rate_limit_per_minute);
modal.innerHTML = `
<div class="modal-content">
<div class="modal-header">
<h3 class="modal-title">Edit Client: ${client.id}</h3>
<h3 class="modal-title">Edit Client: ${escapedId}</h3>
<button class="modal-close" onclick="this.closest('.modal').remove()">
<i class="fas fa-times"></i>
</button>
@@ -259,15 +271,15 @@ class ClientsPage {
<div class="modal-body">
<div class="form-control">
<label for="edit-client-name">Display Name</label>
<input type="text" id="edit-client-name" value="${client.name || ''}" placeholder="e.g. My Coding Assistant">
<input type="text" id="edit-client-name" value="${escapedName}" placeholder="e.g. My Coding Assistant">
</div>
<div class="form-control">
<label for="edit-client-description">Description</label>
<textarea id="edit-client-description" rows="3" placeholder="Optional description">${client.description || ''}</textarea>
<textarea id="edit-client-description" rows="3" placeholder="Optional description">${escapedDescription}</textarea>
</div>
<div class="form-control">
<label for="edit-client-rate-limit">Rate Limit (requests/minute)</label>
<input type="number" id="edit-client-rate-limit" min="0" value="${client.rate_limit_per_minute || ''}" placeholder="Leave empty for unlimited">
<input type="number" id="edit-client-rate-limit" min="0" value="${escapedRateLimit}" placeholder="Leave empty for unlimited">
</div>
<div class="form-control">
<label class="toggle-label">
@@ -357,12 +369,16 @@ class ClientsPage {
const lastUsed = t.last_used_at
? luxon.DateTime.fromISO(t.last_used_at).toRelative()
: 'Never';
const escapedMaskedToken = window.api.escapeHtml(t.token_masked);
const escapedClientId = window.api.escapeHtml(clientId);
const tokenId = parseInt(t.id); // Assuming ID is numeric
return `
<div style="display: flex; align-items: center; gap: 0.5rem; padding: 0.4rem 0; border-bottom: 1px solid var(--bg3);">
<code style="flex: 1; font-size: 0.8rem; color: var(--fg2);">${t.token_masked}</code>
<code style="flex: 1; font-size: 0.8rem; color: var(--fg2);">${escapedMaskedToken}</code>
<span style="font-size: 0.75rem; color: var(--fg4);" title="Last used">${lastUsed}</span>
<button class="btn-action danger" title="Revoke" style="padding: 0.2rem 0.4rem;"
onclick="window.clientsPage.revokeToken('${clientId}', ${t.id}, this)">
onclick="window.clientsPage.revokeToken('${escapedClientId}', ${tokenId}, this)">
<i class="fas fa-trash" style="font-size: 0.75rem;"></i>
</button>
</div>

View File

@@ -47,16 +47,21 @@ class ProvidersPage {
const isLowBalance = provider.credit_balance <= provider.low_credit_threshold && provider.id !== 'ollama';
const balanceColor = isLowBalance ? 'var(--red-light)' : 'var(--green-light)';
const escapedId = window.api.escapeHtml(provider.id);
const escapedName = window.api.escapeHtml(provider.name);
const escapedStatus = window.api.escapeHtml(provider.status);
const billingMode = provider.billing_mode ? provider.billing_mode.toUpperCase() : 'PREPAID';
return `
<div class="provider-card ${provider.status}">
<div class="provider-card ${escapedStatus}">
<div class="provider-card-header">
<div class="provider-info">
<h4 class="provider-name">${provider.name}</h4>
<span class="provider-id">${provider.id}</span>
<h4 class="provider-name">${escapedName}</h4>
<span class="provider-id">${escapedId}</span>
</div>
<span class="status-badge ${statusClass}">
<i class="fas fa-circle"></i>
${provider.status}
${escapedStatus}
</span>
</div>
<div class="provider-card-body">
@@ -67,12 +72,12 @@ class ProvidersPage {
</div>
<div class="meta-item" style="color: ${balanceColor}; font-weight: 700;">
<i class="fas fa-wallet"></i>
<span>Balance: ${provider.id === 'ollama' ? 'FREE' : window.api.formatCurrency(provider.credit_balance)}</span>
<span>Balance: ${escapedId === 'ollama' ? 'FREE' : window.api.formatCurrency(provider.credit_balance)}</span>
${isLowBalance ? '<i class="fas fa-exclamation-triangle" title="Low Balance"></i>' : ''}
</div>
<div class="meta-item">
<i class="fas fa-exchange-alt"></i>
<span>Billing: ${provider.billing_mode ? provider.billing_mode.toUpperCase() : 'PREPAID'}</span>
<span>Billing: ${window.api.escapeHtml(billingMode)}</span>
</div>
<div class="meta-item">
<i class="fas fa-clock"></i>
@@ -80,16 +85,16 @@ class ProvidersPage {
</div>
</div>
<div class="model-tags">
${(provider.models || []).slice(0, 5).map(m => `<span class="model-tag">${m}</span>`).join('')}
${(provider.models || []).slice(0, 5).map(m => `<span class="model-tag">${window.api.escapeHtml(m)}</span>`).join('')}
${modelCount > 5 ? `<span class="model-tag more">+${modelCount - 5} more</span>` : ''}
</div>
</div>
<div class="provider-card-footer">
<button class="btn btn-secondary btn-sm" onclick="window.providersPage.testProvider('${provider.id}')">
<button class="btn btn-secondary btn-sm" onclick="window.providersPage.testProvider('${escapedId}')">
<i class="fas fa-vial"></i> Test
</button>
${window._userRole === 'admin' ? `
<button class="btn btn-primary btn-sm" onclick="window.providersPage.configureProvider('${provider.id}')">
<button class="btn btn-primary btn-sm" onclick="window.providersPage.configureProvider('${escapedId}')">
<i class="fas fa-cog"></i> Config
</button>
` : ''}
@@ -144,10 +149,17 @@ class ProvidersPage {
const modal = document.createElement('div');
modal.className = 'modal active';
const escapedId = window.api.escapeHtml(provider.id);
const escapedName = window.api.escapeHtml(provider.name);
const escapedBaseUrl = window.api.escapeHtml(provider.base_url);
const escapedBalance = window.api.escapeHtml(provider.credit_balance);
const escapedThreshold = window.api.escapeHtml(provider.low_credit_threshold);
modal.innerHTML = `
<div class="modal-content">
<div class="modal-header">
<h3 class="modal-title">Configure ${provider.name}</h3>
<h3 class="modal-title">Configure ${escapedName}</h3>
<button class="modal-close" onclick="this.closest('.modal').remove()">
<i class="fas fa-times"></i>
</button>
@@ -161,7 +173,7 @@ class ProvidersPage {
</div>
<div class="form-control">
<label for="provider-base-url">Base URL</label>
<input type="text" id="provider-base-url" value="${provider.base_url || ''}" placeholder="Default API URL">
<input type="text" id="provider-base-url" value="${escapedBaseUrl}" placeholder="Default API URL">
</div>
<div class="form-control">
<label for="provider-api-key">API Key (Optional / Overwrite)</label>
@@ -170,11 +182,11 @@ class ProvidersPage {
<div class="grid-2">
<div class="form-control">
<label for="provider-balance">Current Credit Balance ($)</label>
<input type="number" id="provider-balance" value="${provider.credit_balance}" step="0.01">
<input type="number" id="provider-balance" value="${escapedBalance}" step="0.01">
</div>
<div class="form-control">
<label for="provider-threshold">Low Balance Alert ($)</label>
<input type="number" id="provider-threshold" value="${provider.low_credit_threshold}" step="0.50">
<input type="number" id="provider-threshold" value="${escapedThreshold}" step="0.50">
</div>
</div>
<div class="form-control">

View File

@@ -279,9 +279,7 @@
// ── Helpers ────────────────────────────────────────────────────
/**
 * Escape text for HTML by round-tripping it through a detached <div>:
 * the value is assigned as the element's textContent and the element's
 * innerHTML is returned.
 *
 * @param {*} str - Value to escape.
 * @returns {string} The innerHTML produced for that text.
 */
function escapeHtml(str) {
    const scratch = document.createElement('div');
    scratch.textContent = str;
    const { innerHTML: markup } = scratch;
    return markup;
}
// Local shorthand: forwards the value to window.api.escapeHtml and returns its result.
function escapeHtml(str) {
return window.api.escapeHtml(str);
}
})();