This commit introduces:
- AES-256-GCM encryption for LLM provider API keys in the database.
- HMAC-SHA256 signed session tokens with activity-based refresh logic.
- Standardized frontend XSS protection using a global `escapeHtml` utility.
- Hardened security headers and request body size limits.
- Improved database integrity with foreign key enforcement and atomic transactions.
- Integration tests for the full encrypted key storage and proxy usage lifecycle.
421 lines
14 KiB
Rust
421 lines
14 KiB
Rust
use axum::{
|
|
extract::{Path, State},
|
|
response::Json,
|
|
};
|
|
use serde::Deserialize;
|
|
use serde_json;
|
|
use sqlx::Row;
|
|
use std::collections::HashMap;
|
|
use tracing::warn;
|
|
|
|
use super::{ApiResponse, DashboardState};
|
|
use crate::utils::crypto;
|
|
|
|
#[derive(Deserialize)]
|
|
pub(super) struct UpdateProviderRequest {
|
|
pub(super) enabled: bool,
|
|
pub(super) base_url: Option<String>,
|
|
pub(super) api_key: Option<String>,
|
|
pub(super) credit_balance: Option<f64>,
|
|
pub(super) low_credit_threshold: Option<f64>,
|
|
pub(super) billing_mode: Option<String>,
|
|
}
|
|
|
|
pub(super) async fn handle_get_providers(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
|
let registry = &state.app_state.model_registry;
|
|
let config = &state.app_state.config;
|
|
let pool = &state.app_state.db_pool;
|
|
|
|
// Load all overrides from database (including billing_mode)
|
|
let db_configs_result = sqlx::query(
|
|
"SELECT id, enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs",
|
|
)
|
|
.fetch_all(pool)
|
|
.await;
|
|
|
|
let mut db_configs = HashMap::new();
|
|
if let Ok(rows) = db_configs_result {
|
|
for row in rows {
|
|
let id: String = row.get("id");
|
|
let enabled: bool = row.get("enabled");
|
|
let base_url: Option<String> = row.get("base_url");
|
|
let balance: f64 = row.get("credit_balance");
|
|
let threshold: f64 = row.get("low_credit_threshold");
|
|
let billing_mode: Option<String> = row.get("billing_mode");
|
|
db_configs.insert(id, (enabled, base_url, balance, threshold, billing_mode));
|
|
}
|
|
}
|
|
|
|
let mut providers_json = Vec::new();
|
|
|
|
// Define the list of providers we support
|
|
let provider_ids = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
|
|
|
|
for id in provider_ids {
|
|
// Get base config
|
|
let (mut enabled, mut base_url, display_name) = match id {
|
|
"openai" => (
|
|
config.providers.openai.enabled,
|
|
config.providers.openai.base_url.clone(),
|
|
"OpenAI",
|
|
),
|
|
"gemini" => (
|
|
config.providers.gemini.enabled,
|
|
config.providers.gemini.base_url.clone(),
|
|
"Google Gemini",
|
|
),
|
|
"deepseek" => (
|
|
config.providers.deepseek.enabled,
|
|
config.providers.deepseek.base_url.clone(),
|
|
"DeepSeek",
|
|
),
|
|
"grok" => (
|
|
config.providers.grok.enabled,
|
|
config.providers.grok.base_url.clone(),
|
|
"xAI Grok",
|
|
),
|
|
"ollama" => (
|
|
config.providers.ollama.enabled,
|
|
config.providers.ollama.base_url.clone(),
|
|
"Ollama",
|
|
),
|
|
_ => (false, "".to_string(), "Unknown"),
|
|
};
|
|
|
|
let mut balance = 0.0;
|
|
let mut threshold = 5.0;
|
|
let mut billing_mode: Option<String> = None;
|
|
|
|
// Apply database overrides
|
|
if let Some((db_enabled, db_url, db_balance, db_threshold, db_billing)) = db_configs.get(id) {
|
|
enabled = *db_enabled;
|
|
if let Some(url) = db_url {
|
|
base_url = url.clone();
|
|
}
|
|
balance = *db_balance;
|
|
threshold = *db_threshold;
|
|
billing_mode = db_billing.clone();
|
|
}
|
|
|
|
// Find models for this provider in registry
|
|
// NOTE: registry provider IDs differ from internal IDs for some providers.
|
|
let registry_key = match id {
|
|
"gemini" => "google",
|
|
"grok" => "xai",
|
|
_ => id,
|
|
};
|
|
|
|
let mut models = Vec::new();
|
|
if let Some(p_info) = registry.providers.get(registry_key) {
|
|
models = p_info.models.keys().cloned().collect();
|
|
} else if id == "ollama" {
|
|
models = config.providers.ollama.models.clone();
|
|
}
|
|
|
|
// Determine status
|
|
let status = if !enabled {
|
|
"disabled"
|
|
} else {
|
|
// Check if it's actually initialized in the provider manager
|
|
if state.app_state.provider_manager.get_provider(id).await.is_some() {
|
|
// Check circuit breaker
|
|
if state
|
|
.app_state
|
|
.rate_limit_manager
|
|
.check_provider_request(id)
|
|
.await
|
|
.unwrap_or(true)
|
|
{
|
|
"online"
|
|
} else {
|
|
"degraded"
|
|
}
|
|
} else {
|
|
"error" // Enabled but failed to initialize (e.g. missing API key)
|
|
}
|
|
};
|
|
|
|
providers_json.push(serde_json::json!({
|
|
"id": id,
|
|
"name": display_name,
|
|
"enabled": enabled,
|
|
"status": status,
|
|
"models": models,
|
|
"base_url": base_url,
|
|
"credit_balance": balance,
|
|
"low_credit_threshold": threshold,
|
|
"billing_mode": billing_mode,
|
|
"last_used": None::<String>,
|
|
}));
|
|
}
|
|
|
|
Json(ApiResponse::success(serde_json::json!(providers_json)))
|
|
}
|
|
|
|
pub(super) async fn handle_get_provider(
|
|
State(state): State<DashboardState>,
|
|
Path(name): Path<String>,
|
|
) -> Json<ApiResponse<serde_json::Value>> {
|
|
let registry = &state.app_state.model_registry;
|
|
let config = &state.app_state.config;
|
|
let pool = &state.app_state.db_pool;
|
|
|
|
// Validate provider name
|
|
let (mut enabled, mut base_url, display_name) = match name.as_str() {
|
|
"openai" => (
|
|
config.providers.openai.enabled,
|
|
config.providers.openai.base_url.clone(),
|
|
"OpenAI",
|
|
),
|
|
"gemini" => (
|
|
config.providers.gemini.enabled,
|
|
config.providers.gemini.base_url.clone(),
|
|
"Google Gemini",
|
|
),
|
|
"deepseek" => (
|
|
config.providers.deepseek.enabled,
|
|
config.providers.deepseek.base_url.clone(),
|
|
"DeepSeek",
|
|
),
|
|
"grok" => (
|
|
config.providers.grok.enabled,
|
|
config.providers.grok.base_url.clone(),
|
|
"xAI Grok",
|
|
),
|
|
"ollama" => (
|
|
config.providers.ollama.enabled,
|
|
config.providers.ollama.base_url.clone(),
|
|
"Ollama",
|
|
),
|
|
_ => return Json(ApiResponse::error(format!("Unknown provider '{}'", name))),
|
|
};
|
|
|
|
let mut balance = 0.0;
|
|
let mut threshold = 5.0;
|
|
let mut billing_mode: Option<String> = None;
|
|
|
|
// Apply database overrides
|
|
let db_config = sqlx::query(
|
|
"SELECT enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs WHERE id = ?",
|
|
)
|
|
.bind(&name)
|
|
.fetch_optional(pool)
|
|
.await;
|
|
|
|
if let Ok(Some(row)) = db_config {
|
|
enabled = row.get::<bool, _>("enabled");
|
|
if let Some(url) = row.get::<Option<String>, _>("base_url") {
|
|
base_url = url;
|
|
}
|
|
balance = row.get::<f64, _>("credit_balance");
|
|
threshold = row.get::<f64, _>("low_credit_threshold");
|
|
billing_mode = row.get::<Option<String>, _>("billing_mode");
|
|
}
|
|
|
|
// Find models for this provider
|
|
// NOTE: registry provider IDs differ from internal IDs for some providers.
|
|
let registry_key = match name.as_str() {
|
|
"gemini" => "google",
|
|
"grok" => "xai",
|
|
_ => name.as_str(),
|
|
};
|
|
|
|
let mut models = Vec::new();
|
|
if let Some(p_info) = registry.providers.get(registry_key) {
|
|
models = p_info.models.keys().cloned().collect();
|
|
} else if name == "ollama" {
|
|
models = config.providers.ollama.models.clone();
|
|
}
|
|
|
|
// Determine status
|
|
let status = if !enabled {
|
|
"disabled"
|
|
} else if state.app_state.provider_manager.get_provider(&name).await.is_some() {
|
|
if state
|
|
.app_state
|
|
.rate_limit_manager
|
|
.check_provider_request(&name)
|
|
.await
|
|
.unwrap_or(true)
|
|
{
|
|
"online"
|
|
} else {
|
|
"degraded"
|
|
}
|
|
} else {
|
|
"error"
|
|
};
|
|
|
|
Json(ApiResponse::success(serde_json::json!({
|
|
"id": name,
|
|
"name": display_name,
|
|
"enabled": enabled,
|
|
"status": status,
|
|
"models": models,
|
|
"base_url": base_url,
|
|
"credit_balance": balance,
|
|
"low_credit_threshold": threshold,
|
|
"billing_mode": billing_mode,
|
|
"last_used": None::<String>,
|
|
})))
|
|
}
|
|
|
|
pub(super) async fn handle_update_provider(
|
|
State(state): State<DashboardState>,
|
|
headers: axum::http::HeaderMap,
|
|
Path(name): Path<String>,
|
|
Json(payload): Json<UpdateProviderRequest>,
|
|
) -> Json<ApiResponse<serde_json::Value>> {
|
|
let (session, _) = match super::auth::require_admin(&state, &headers).await {
|
|
Ok((session, new_token)) => (session, new_token),
|
|
Err(e) => return e,
|
|
};
|
|
|
|
let pool = &state.app_state.db_pool;
|
|
|
|
// Prepare API key encryption if provided
|
|
let (api_key_to_store, api_key_encrypted_flag) = match &payload.api_key {
|
|
Some(key) if !key.is_empty() => {
|
|
match crypto::encrypt(key) {
|
|
Ok(encrypted) => (Some(encrypted), Some(true)),
|
|
Err(e) => {
|
|
warn!("Failed to encrypt API key for provider {}: {}", name, e);
|
|
return Json(ApiResponse::error(format!("Failed to encrypt API key: {}", e)));
|
|
}
|
|
}
|
|
}
|
|
Some(_) => {
|
|
// Empty string means clear the key
|
|
(None, Some(false))
|
|
}
|
|
None => {
|
|
// Keep existing key, we'll rely on COALESCE in SQL
|
|
(None, None)
|
|
}
|
|
};
|
|
|
|
// Update or insert into database (include billing_mode and api_key_encrypted)
|
|
let result = sqlx::query(
|
|
r#"
|
|
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold, billing_mode)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(id) DO UPDATE SET
|
|
enabled = excluded.enabled,
|
|
base_url = excluded.base_url,
|
|
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
|
|
api_key_encrypted = COALESCE(excluded.api_key_encrypted, provider_configs.api_key_encrypted),
|
|
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
|
|
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
|
|
billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
|
|
updated_at = CURRENT_TIMESTAMP
|
|
"#,
|
|
)
|
|
.bind(&name)
|
|
.bind(name.to_uppercase())
|
|
.bind(payload.enabled)
|
|
.bind(&payload.base_url)
|
|
.bind(&api_key_to_store)
|
|
.bind(api_key_encrypted_flag)
|
|
.bind(payload.credit_balance)
|
|
.bind(payload.low_credit_threshold)
|
|
.bind(payload.billing_mode)
|
|
.execute(pool)
|
|
.await;
|
|
|
|
match result {
|
|
Ok(_) => {
|
|
// Re-initialize provider in manager
|
|
if let Err(e) = state
|
|
.app_state
|
|
.provider_manager
|
|
.initialize_provider(&name, &state.app_state.config, &state.app_state.db_pool)
|
|
.await
|
|
{
|
|
warn!("Failed to re-initialize provider {}: {}", name, e);
|
|
return Json(ApiResponse::error(format!(
|
|
"Provider settings saved but initialization failed: {}",
|
|
e
|
|
)));
|
|
}
|
|
|
|
Json(ApiResponse::success(
|
|
serde_json::json!({ "message": "Provider updated and re-initialized" }),
|
|
))
|
|
}
|
|
Err(e) => {
|
|
warn!("Failed to update provider config: {}", e);
|
|
Json(ApiResponse::error(format!("Failed to update provider: {}", e)))
|
|
}
|
|
}
|
|
}
|
|
|
|
pub(super) async fn handle_test_provider(
|
|
State(state): State<DashboardState>,
|
|
Path(name): Path<String>,
|
|
) -> Json<ApiResponse<serde_json::Value>> {
|
|
let start = std::time::Instant::now();
|
|
|
|
let provider = match state.app_state.provider_manager.get_provider(&name).await {
|
|
Some(p) => p,
|
|
None => {
|
|
return Json(ApiResponse::error(format!(
|
|
"Provider '{}' not found or not enabled",
|
|
name
|
|
)));
|
|
}
|
|
};
|
|
|
|
// Pick a real model for this provider from the registry
|
|
// NOTE: registry provider IDs differ from internal IDs for some providers.
|
|
let registry_key = match name.as_str() {
|
|
"gemini" => "google",
|
|
"grok" => "xai",
|
|
_ => name.as_str(),
|
|
};
|
|
|
|
let test_model = state
|
|
.app_state
|
|
.model_registry
|
|
.providers
|
|
.get(registry_key)
|
|
.and_then(|p| p.models.keys().next().cloned())
|
|
.unwrap_or_else(|| name.clone());
|
|
|
|
let test_request = crate::models::UnifiedRequest {
|
|
client_id: "system-test".to_string(),
|
|
model: test_model,
|
|
messages: vec![crate::models::UnifiedMessage {
|
|
role: "user".to_string(),
|
|
content: vec![crate::models::ContentPart::Text { text: "Hi".to_string() }],
|
|
reasoning_content: None,
|
|
tool_calls: None,
|
|
name: None,
|
|
tool_call_id: None,
|
|
}],
|
|
temperature: None,
|
|
top_p: None,
|
|
top_k: None,
|
|
n: None,
|
|
stop: None,
|
|
max_tokens: Some(5),
|
|
presence_penalty: None,
|
|
frequency_penalty: None,
|
|
stream: false,
|
|
has_images: false,
|
|
tools: None,
|
|
tool_choice: None,
|
|
};
|
|
|
|
match provider.chat_completion(test_request).await {
|
|
Ok(_) => {
|
|
let latency = start.elapsed().as_millis();
|
|
Json(ApiResponse::success(serde_json::json!({
|
|
"success": true,
|
|
"latency": latency,
|
|
"message": "Connection test successful"
|
|
})))
|
|
}
|
|
Err(e) => Json(ApiResponse::error(format!("Provider test failed: {}", e))),
|
|
}
|
|
}
|