refactor: comprehensive audit — fix bugs, harden security, deduplicate providers, add CI/Docker
Phase 1: Fix compilation (config_path Option<PathBuf>, streaming test, stale test cleanup) Phase 2: Fix critical bugs (remove block_on deadlocks in 4 providers, fix broken SQL query builder) Phase 3: Security hardening (session manager, real auth, token masking, Gemini key to header, password policy) Phase 4: Implement stubs (real provider test, /proc health metrics, client/provider/backup endpoints, has_images) Phase 5: Code quality (shared provider helpers, explicit re-exports, all Clippy warnings fixed, unwrap removal, 6 unused deps removed, dashboard split into 7 sub-modules) Phase 6: Infrastructure (GitHub Actions CI, multi-stage Dockerfile, rustfmt.toml, clippy.toml, script fixes)
This commit is contained in:
346
src/dashboard/providers.rs
Normal file
346
src/dashboard/providers.rs
Normal file
@@ -0,0 +1,346 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::Json,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json;
|
||||
use sqlx::Row;
|
||||
use std::collections::HashMap;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct UpdateProviderRequest {
|
||||
pub(super) enabled: bool,
|
||||
pub(super) base_url: Option<String>,
|
||||
pub(super) api_key: Option<String>,
|
||||
pub(super) credit_balance: Option<f64>,
|
||||
pub(super) low_credit_threshold: Option<f64>,
|
||||
}
|
||||
|
||||
pub(super) async fn handle_get_providers(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let registry = &state.app_state.model_registry;
|
||||
let config = &state.app_state.config;
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Load all overrides from database
|
||||
let db_configs_result =
|
||||
sqlx::query("SELECT id, enabled, base_url, credit_balance, low_credit_threshold FROM provider_configs")
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
let mut db_configs = HashMap::new();
|
||||
if let Ok(rows) = db_configs_result {
|
||||
for row in rows {
|
||||
let id: String = row.get("id");
|
||||
let enabled: bool = row.get("enabled");
|
||||
let base_url: Option<String> = row.get("base_url");
|
||||
let balance: f64 = row.get("credit_balance");
|
||||
let threshold: f64 = row.get("low_credit_threshold");
|
||||
db_configs.insert(id, (enabled, base_url, balance, threshold));
|
||||
}
|
||||
}
|
||||
|
||||
let mut providers_json = Vec::new();
|
||||
|
||||
// Define the list of providers we support
|
||||
let provider_ids = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
|
||||
|
||||
for id in provider_ids {
|
||||
// Get base config
|
||||
let (mut enabled, mut base_url, display_name) = match id {
|
||||
"openai" => (
|
||||
config.providers.openai.enabled,
|
||||
config.providers.openai.base_url.clone(),
|
||||
"OpenAI",
|
||||
),
|
||||
"gemini" => (
|
||||
config.providers.gemini.enabled,
|
||||
config.providers.gemini.base_url.clone(),
|
||||
"Google Gemini",
|
||||
),
|
||||
"deepseek" => (
|
||||
config.providers.deepseek.enabled,
|
||||
config.providers.deepseek.base_url.clone(),
|
||||
"DeepSeek",
|
||||
),
|
||||
"grok" => (
|
||||
config.providers.grok.enabled,
|
||||
config.providers.grok.base_url.clone(),
|
||||
"xAI Grok",
|
||||
),
|
||||
"ollama" => (
|
||||
config.providers.ollama.enabled,
|
||||
config.providers.ollama.base_url.clone(),
|
||||
"Ollama",
|
||||
),
|
||||
_ => (false, "".to_string(), "Unknown"),
|
||||
};
|
||||
|
||||
let mut balance = 0.0;
|
||||
let mut threshold = 5.0;
|
||||
|
||||
// Apply database overrides
|
||||
if let Some((db_enabled, db_url, db_balance, db_threshold)) = db_configs.get(id) {
|
||||
enabled = *db_enabled;
|
||||
if let Some(url) = db_url {
|
||||
base_url = url.clone();
|
||||
}
|
||||
balance = *db_balance;
|
||||
threshold = *db_threshold;
|
||||
}
|
||||
|
||||
// Find models for this provider in registry
|
||||
let mut models = Vec::new();
|
||||
if let Some(p_info) = registry.providers.get(id) {
|
||||
models = p_info.models.keys().cloned().collect();
|
||||
} else if id == "ollama" {
|
||||
models = config.providers.ollama.models.clone();
|
||||
}
|
||||
|
||||
// Determine status
|
||||
let status = if !enabled {
|
||||
"disabled"
|
||||
} else {
|
||||
// Check if it's actually initialized in the provider manager
|
||||
if state.app_state.provider_manager.get_provider(id).await.is_some() {
|
||||
// Check circuit breaker
|
||||
if state
|
||||
.app_state
|
||||
.rate_limit_manager
|
||||
.check_provider_request(id)
|
||||
.await
|
||||
.unwrap_or(true)
|
||||
{
|
||||
"online"
|
||||
} else {
|
||||
"degraded"
|
||||
}
|
||||
} else {
|
||||
"error" // Enabled but failed to initialize (e.g. missing API key)
|
||||
}
|
||||
};
|
||||
|
||||
providers_json.push(serde_json::json!({
|
||||
"id": id,
|
||||
"name": display_name,
|
||||
"enabled": enabled,
|
||||
"status": status,
|
||||
"models": models,
|
||||
"base_url": base_url,
|
||||
"credit_balance": balance,
|
||||
"low_credit_threshold": threshold,
|
||||
"last_used": None::<String>,
|
||||
}));
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(providers_json)))
|
||||
}
|
||||
|
||||
pub(super) async fn handle_get_provider(
|
||||
State(state): State<DashboardState>,
|
||||
Path(name): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let registry = &state.app_state.model_registry;
|
||||
let config = &state.app_state.config;
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Validate provider name
|
||||
let (mut enabled, mut base_url, display_name) = match name.as_str() {
|
||||
"openai" => (
|
||||
config.providers.openai.enabled,
|
||||
config.providers.openai.base_url.clone(),
|
||||
"OpenAI",
|
||||
),
|
||||
"gemini" => (
|
||||
config.providers.gemini.enabled,
|
||||
config.providers.gemini.base_url.clone(),
|
||||
"Google Gemini",
|
||||
),
|
||||
"deepseek" => (
|
||||
config.providers.deepseek.enabled,
|
||||
config.providers.deepseek.base_url.clone(),
|
||||
"DeepSeek",
|
||||
),
|
||||
"grok" => (
|
||||
config.providers.grok.enabled,
|
||||
config.providers.grok.base_url.clone(),
|
||||
"xAI Grok",
|
||||
),
|
||||
"ollama" => (
|
||||
config.providers.ollama.enabled,
|
||||
config.providers.ollama.base_url.clone(),
|
||||
"Ollama",
|
||||
),
|
||||
_ => return Json(ApiResponse::error(format!("Unknown provider '{}'", name))),
|
||||
};
|
||||
|
||||
let mut balance = 0.0;
|
||||
let mut threshold = 5.0;
|
||||
|
||||
// Apply database overrides
|
||||
let db_config = sqlx::query(
|
||||
"SELECT enabled, base_url, credit_balance, low_credit_threshold FROM provider_configs WHERE id = ?",
|
||||
)
|
||||
.bind(&name)
|
||||
.fetch_optional(pool)
|
||||
.await;
|
||||
|
||||
if let Ok(Some(row)) = db_config {
|
||||
enabled = row.get::<bool, _>("enabled");
|
||||
if let Some(url) = row.get::<Option<String>, _>("base_url") {
|
||||
base_url = url;
|
||||
}
|
||||
balance = row.get::<f64, _>("credit_balance");
|
||||
threshold = row.get::<f64, _>("low_credit_threshold");
|
||||
}
|
||||
|
||||
// Find models for this provider
|
||||
let mut models = Vec::new();
|
||||
if let Some(p_info) = registry.providers.get(name.as_str()) {
|
||||
models = p_info.models.keys().cloned().collect();
|
||||
} else if name == "ollama" {
|
||||
models = config.providers.ollama.models.clone();
|
||||
}
|
||||
|
||||
// Determine status
|
||||
let status = if !enabled {
|
||||
"disabled"
|
||||
} else if state.app_state.provider_manager.get_provider(&name).await.is_some() {
|
||||
if state
|
||||
.app_state
|
||||
.rate_limit_manager
|
||||
.check_provider_request(&name)
|
||||
.await
|
||||
.unwrap_or(true)
|
||||
{
|
||||
"online"
|
||||
} else {
|
||||
"degraded"
|
||||
}
|
||||
} else {
|
||||
"error"
|
||||
};
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"id": name,
|
||||
"name": display_name,
|
||||
"enabled": enabled,
|
||||
"status": status,
|
||||
"models": models,
|
||||
"base_url": base_url,
|
||||
"credit_balance": balance,
|
||||
"low_credit_threshold": threshold,
|
||||
"last_used": None::<String>,
|
||||
})))
|
||||
}
|
||||
|
||||
pub(super) async fn handle_update_provider(
|
||||
State(state): State<DashboardState>,
|
||||
Path(name): Path<String>,
|
||||
Json(payload): Json<UpdateProviderRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Update or insert into database
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
enabled = excluded.enabled,
|
||||
base_url = excluded.base_url,
|
||||
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
|
||||
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
|
||||
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
"#
|
||||
)
|
||||
.bind(&name)
|
||||
.bind(name.to_uppercase())
|
||||
.bind(payload.enabled)
|
||||
.bind(&payload.base_url)
|
||||
.bind(&payload.api_key)
|
||||
.bind(payload.credit_balance)
|
||||
.bind(payload.low_credit_threshold)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => {
|
||||
// Re-initialize provider in manager
|
||||
if let Err(e) = state
|
||||
.app_state
|
||||
.provider_manager
|
||||
.initialize_provider(&name, &state.app_state.config, &state.app_state.db_pool)
|
||||
.await
|
||||
{
|
||||
warn!("Failed to re-initialize provider {}: {}", name, e);
|
||||
return Json(ApiResponse::error(format!(
|
||||
"Provider settings saved but initialization failed: {}",
|
||||
e
|
||||
)));
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(
|
||||
serde_json::json!({ "message": "Provider updated and re-initialized" }),
|
||||
))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to update provider config: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to update provider: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_test_provider(
|
||||
State(state): State<DashboardState>,
|
||||
Path(name): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let provider = match state.app_state.provider_manager.get_provider(&name).await {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
return Json(ApiResponse::error(format!(
|
||||
"Provider '{}' not found or not enabled",
|
||||
name
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
// Pick a real model for this provider from the registry
|
||||
let test_model = state
|
||||
.app_state
|
||||
.model_registry
|
||||
.providers
|
||||
.get(&name)
|
||||
.and_then(|p| p.models.keys().next().cloned())
|
||||
.unwrap_or_else(|| name.clone());
|
||||
|
||||
let test_request = crate::models::UnifiedRequest {
|
||||
client_id: "system-test".to_string(),
|
||||
model: test_model,
|
||||
messages: vec![crate::models::UnifiedMessage {
|
||||
role: "user".to_string(),
|
||||
content: vec![crate::models::ContentPart::Text { text: "Hi".to_string() }],
|
||||
}],
|
||||
temperature: None,
|
||||
max_tokens: Some(5),
|
||||
stream: false,
|
||||
has_images: false,
|
||||
};
|
||||
|
||||
match provider.chat_completion(test_request).await {
|
||||
Ok(_) => {
|
||||
let latency = start.elapsed().as_millis();
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"success": true,
|
||||
"latency": latency,
|
||||
"message": "Connection test successful"
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::error(format!("Provider test failed: {}", e))),
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user