Files
GopherGate/src/dashboard/providers.rs
hobokenchicken 9b8483e797 feat(security): implement AES-256-GCM encryption for API keys and HMAC-signed session tokens
This commit introduces:
- AES-256-GCM encryption for LLM provider API keys in the database.
- HMAC-SHA256 signed session tokens with activity-based refresh logic.
- Standardized frontend XSS protection using a global escapeHtml utility.
- Hardened security headers and request body size limits.
- Improved database integrity with foreign key enforcement and atomic transactions.
- Integration tests for the full encrypted key storage and proxy usage lifecycle.
2026-03-06 14:17:56 -05:00

421 lines
14 KiB
Rust

use axum::{
extract::{Path, State},
response::Json,
};
use serde::Deserialize;
use serde_json;
use sqlx::Row;
use std::collections::HashMap;
use tracing::warn;
use super::{ApiResponse, DashboardState};
use crate::utils::crypto;
/// JSON body accepted when an administrator updates a provider's settings.
///
/// Every field except `enabled` is optional.
#[derive(Deserialize)]
pub(super) struct UpdateProviderRequest {
    /// Whether the provider should be enabled; always written to the database.
    pub(super) enabled: bool,
    /// Base URL override. It is written as-is, so `None` stores NULL and the
    /// configured default URL is reported instead.
    pub(super) base_url: Option<String>,
    /// Plaintext API key; encrypted with `crypto::encrypt` before it is stored.
    /// Omit it to keep the stored key; an empty string requests that the stored
    /// key be cleared.
    pub(super) api_key: Option<String>,
    /// Remaining credit for this provider, if it should be changed.
    pub(super) credit_balance: Option<f64>,
    /// Balance below which the credit is considered low, if it should be changed.
    pub(super) low_credit_threshold: Option<f64>,
    /// Billing mode label, if it should be changed.
    pub(super) billing_mode: Option<String>,
}
pub(super) async fn handle_get_providers(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let registry = &state.app_state.model_registry;
let config = &state.app_state.config;
let pool = &state.app_state.db_pool;
// Load all overrides from database (including billing_mode)
let db_configs_result = sqlx::query(
"SELECT id, enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs",
)
.fetch_all(pool)
.await;
let mut db_configs = HashMap::new();
if let Ok(rows) = db_configs_result {
for row in rows {
let id: String = row.get("id");
let enabled: bool = row.get("enabled");
let base_url: Option<String> = row.get("base_url");
let balance: f64 = row.get("credit_balance");
let threshold: f64 = row.get("low_credit_threshold");
let billing_mode: Option<String> = row.get("billing_mode");
db_configs.insert(id, (enabled, base_url, balance, threshold, billing_mode));
}
}
let mut providers_json = Vec::new();
// Define the list of providers we support
let provider_ids = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
for id in provider_ids {
// Get base config
let (mut enabled, mut base_url, display_name) = match id {
"openai" => (
config.providers.openai.enabled,
config.providers.openai.base_url.clone(),
"OpenAI",
),
"gemini" => (
config.providers.gemini.enabled,
config.providers.gemini.base_url.clone(),
"Google Gemini",
),
"deepseek" => (
config.providers.deepseek.enabled,
config.providers.deepseek.base_url.clone(),
"DeepSeek",
),
"grok" => (
config.providers.grok.enabled,
config.providers.grok.base_url.clone(),
"xAI Grok",
),
"ollama" => (
config.providers.ollama.enabled,
config.providers.ollama.base_url.clone(),
"Ollama",
),
_ => (false, "".to_string(), "Unknown"),
};
let mut balance = 0.0;
let mut threshold = 5.0;
let mut billing_mode: Option<String> = None;
// Apply database overrides
if let Some((db_enabled, db_url, db_balance, db_threshold, db_billing)) = db_configs.get(id) {
enabled = *db_enabled;
if let Some(url) = db_url {
base_url = url.clone();
}
balance = *db_balance;
threshold = *db_threshold;
billing_mode = db_billing.clone();
}
// Find models for this provider in registry
// NOTE: registry provider IDs differ from internal IDs for some providers.
let registry_key = match id {
"gemini" => "google",
"grok" => "xai",
_ => id,
};
let mut models = Vec::new();
if let Some(p_info) = registry.providers.get(registry_key) {
models = p_info.models.keys().cloned().collect();
} else if id == "ollama" {
models = config.providers.ollama.models.clone();
}
// Determine status
let status = if !enabled {
"disabled"
} else {
// Check if it's actually initialized in the provider manager
if state.app_state.provider_manager.get_provider(id).await.is_some() {
// Check circuit breaker
if state
.app_state
.rate_limit_manager
.check_provider_request(id)
.await
.unwrap_or(true)
{
"online"
} else {
"degraded"
}
} else {
"error" // Enabled but failed to initialize (e.g. missing API key)
}
};
providers_json.push(serde_json::json!({
"id": id,
"name": display_name,
"enabled": enabled,
"status": status,
"models": models,
"base_url": base_url,
"credit_balance": balance,
"low_credit_threshold": threshold,
"billing_mode": billing_mode,
"last_used": None::<String>,
}));
}
Json(ApiResponse::success(serde_json::json!(providers_json)))
}
pub(super) async fn handle_get_provider(
State(state): State<DashboardState>,
Path(name): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let registry = &state.app_state.model_registry;
let config = &state.app_state.config;
let pool = &state.app_state.db_pool;
// Validate provider name
let (mut enabled, mut base_url, display_name) = match name.as_str() {
"openai" => (
config.providers.openai.enabled,
config.providers.openai.base_url.clone(),
"OpenAI",
),
"gemini" => (
config.providers.gemini.enabled,
config.providers.gemini.base_url.clone(),
"Google Gemini",
),
"deepseek" => (
config.providers.deepseek.enabled,
config.providers.deepseek.base_url.clone(),
"DeepSeek",
),
"grok" => (
config.providers.grok.enabled,
config.providers.grok.base_url.clone(),
"xAI Grok",
),
"ollama" => (
config.providers.ollama.enabled,
config.providers.ollama.base_url.clone(),
"Ollama",
),
_ => return Json(ApiResponse::error(format!("Unknown provider '{}'", name))),
};
let mut balance = 0.0;
let mut threshold = 5.0;
let mut billing_mode: Option<String> = None;
// Apply database overrides
let db_config = sqlx::query(
"SELECT enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs WHERE id = ?",
)
.bind(&name)
.fetch_optional(pool)
.await;
if let Ok(Some(row)) = db_config {
enabled = row.get::<bool, _>("enabled");
if let Some(url) = row.get::<Option<String>, _>("base_url") {
base_url = url;
}
balance = row.get::<f64, _>("credit_balance");
threshold = row.get::<f64, _>("low_credit_threshold");
billing_mode = row.get::<Option<String>, _>("billing_mode");
}
// Find models for this provider
// NOTE: registry provider IDs differ from internal IDs for some providers.
let registry_key = match name.as_str() {
"gemini" => "google",
"grok" => "xai",
_ => name.as_str(),
};
let mut models = Vec::new();
if let Some(p_info) = registry.providers.get(registry_key) {
models = p_info.models.keys().cloned().collect();
} else if name == "ollama" {
models = config.providers.ollama.models.clone();
}
// Determine status
let status = if !enabled {
"disabled"
} else if state.app_state.provider_manager.get_provider(&name).await.is_some() {
if state
.app_state
.rate_limit_manager
.check_provider_request(&name)
.await
.unwrap_or(true)
{
"online"
} else {
"degraded"
}
} else {
"error"
};
Json(ApiResponse::success(serde_json::json!({
"id": name,
"name": display_name,
"enabled": enabled,
"status": status,
"models": models,
"base_url": base_url,
"credit_balance": balance,
"low_credit_threshold": threshold,
"billing_mode": billing_mode,
"last_used": None::<String>,
})))
}
pub(super) async fn handle_update_provider(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(name): Path<String>,
Json(payload): Json<UpdateProviderRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let (session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Prepare API key encryption if provided
let (api_key_to_store, api_key_encrypted_flag) = match &payload.api_key {
Some(key) if !key.is_empty() => {
match crypto::encrypt(key) {
Ok(encrypted) => (Some(encrypted), Some(true)),
Err(e) => {
warn!("Failed to encrypt API key for provider {}: {}", name, e);
return Json(ApiResponse::error(format!("Failed to encrypt API key: {}", e)));
}
}
}
Some(_) => {
// Empty string means clear the key
(None, Some(false))
}
None => {
// Keep existing key, we'll rely on COALESCE in SQL
(None, None)
}
};
// Update or insert into database (include billing_mode and api_key_encrypted)
let result = sqlx::query(
r#"
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold, billing_mode)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
enabled = excluded.enabled,
base_url = excluded.base_url,
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
api_key_encrypted = COALESCE(excluded.api_key_encrypted, provider_configs.api_key_encrypted),
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
updated_at = CURRENT_TIMESTAMP
"#,
)
.bind(&name)
.bind(name.to_uppercase())
.bind(payload.enabled)
.bind(&payload.base_url)
.bind(&api_key_to_store)
.bind(api_key_encrypted_flag)
.bind(payload.credit_balance)
.bind(payload.low_credit_threshold)
.bind(payload.billing_mode)
.execute(pool)
.await;
match result {
Ok(_) => {
// Re-initialize provider in manager
if let Err(e) = state
.app_state
.provider_manager
.initialize_provider(&name, &state.app_state.config, &state.app_state.db_pool)
.await
{
warn!("Failed to re-initialize provider {}: {}", name, e);
return Json(ApiResponse::error(format!(
"Provider settings saved but initialization failed: {}",
e
)));
}
Json(ApiResponse::success(
serde_json::json!({ "message": "Provider updated and re-initialized" }),
))
}
Err(e) => {
warn!("Failed to update provider config: {}", e);
Json(ApiResponse::error(format!("Failed to update provider: {}", e)))
}
}
}
/// Sends a minimal chat completion ("Hi", at most 5 tokens) through an
/// initialized provider and reports the round-trip latency in milliseconds.
///
/// Which model is used:
/// - the first model the registry lists for the provider;
/// - for `ollama`, when the registry has no entry, the first model in
///   `config.providers.ollama.models` (matching how the GET handlers list
///   models);
/// - otherwise, the provider name itself.
///
/// Returns an error response if the provider is not initialized or the
/// request fails.
///
/// NOTE(review): this handler performs no admin check of its own and issues a
/// real (possibly billed) request — confirm the route is protected upstream.
pub(super) async fn handle_test_provider(
    State(state): State<DashboardState>,
    Path(name): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
    let provider = match state.app_state.provider_manager.get_provider(&name).await {
        Some(p) => p,
        None => {
            return Json(ApiResponse::error(format!(
                "Provider '{}' not found or not enabled",
                name
            )));
        }
    };

    // Registry provider IDs differ from internal IDs for some providers.
    let registry_key = match name.as_str() {
        "gemini" => "google",
        "grok" => "xai",
        other => other,
    };
    let test_model = state
        .app_state
        .model_registry
        .providers
        .get(registry_key)
        .and_then(|p| p.models.keys().next().cloned())
        .or_else(|| {
            // Local Ollama models may only be listed in the configuration.
            if name == "ollama" {
                state.app_state.config.providers.ollama.models.first().cloned()
            } else {
                None
            }
        })
        .unwrap_or_else(|| name.clone());

    let test_request = crate::models::UnifiedRequest {
        client_id: "system-test".to_string(),
        model: test_model,
        messages: vec![crate::models::UnifiedMessage {
            role: "user".to_string(),
            content: vec![crate::models::ContentPart::Text { text: "Hi".to_string() }],
            reasoning_content: None,
            tool_calls: None,
            name: None,
            tool_call_id: None,
        }],
        temperature: None,
        top_p: None,
        top_k: None,
        n: None,
        stop: None,
        // Keep the test as cheap as possible.
        max_tokens: Some(5),
        presence_penalty: None,
        frequency_penalty: None,
        stream: false,
        has_images: false,
        tools: None,
        tool_choice: None,
    };

    // Start timing only now so the reported latency covers the provider call alone.
    let start = std::time::Instant::now();
    match provider.chat_completion(test_request).await {
        Ok(_) => {
            let latency = start.elapsed().as_millis();
            Json(ApiResponse::success(serde_json::json!({
                "success": true,
                "latency": latency,
                "message": "Connection test successful"
            })))
        }
        Err(e) => Json(ApiResponse::error(format!("Provider test failed: {}", e))),
    }
}