use axum::{ extract::{Path, State}, response::Json, }; use serde::Deserialize; use serde_json; use sqlx::Row; use std::collections::HashMap; use tracing::warn; use super::{ApiResponse, DashboardState}; use crate::utils::crypto; #[derive(Deserialize)] pub(super) struct UpdateProviderRequest { pub(super) enabled: bool, pub(super) base_url: Option, pub(super) api_key: Option, pub(super) credit_balance: Option, pub(super) low_credit_threshold: Option, pub(super) billing_mode: Option, } pub(super) async fn handle_get_providers(State(state): State) -> Json> { let registry = &state.app_state.model_registry; let config = &state.app_state.config; let pool = &state.app_state.db_pool; // Load all overrides from database (including billing_mode) let db_configs_result = sqlx::query( "SELECT id, enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs", ) .fetch_all(pool) .await; let mut db_configs = HashMap::new(); if let Ok(rows) = db_configs_result { for row in rows { let id: String = row.get("id"); let enabled: bool = row.get("enabled"); let base_url: Option = row.get("base_url"); let balance: f64 = row.get("credit_balance"); let threshold: f64 = row.get("low_credit_threshold"); let billing_mode: Option = row.get("billing_mode"); db_configs.insert(id, (enabled, base_url, balance, threshold, billing_mode)); } } let mut providers_json = Vec::new(); // Define the list of providers we support let provider_ids = vec!["openai", "gemini", "deepseek", "grok", "ollama"]; for id in provider_ids { // Get base config let (mut enabled, mut base_url, display_name) = match id { "openai" => ( config.providers.openai.enabled, config.providers.openai.base_url.clone(), "OpenAI", ), "gemini" => ( config.providers.gemini.enabled, config.providers.gemini.base_url.clone(), "Google Gemini", ), "deepseek" => ( config.providers.deepseek.enabled, config.providers.deepseek.base_url.clone(), "DeepSeek", ), "grok" => ( config.providers.grok.enabled, config.providers.grok.base_url.clone(), "xAI Grok", ), "ollama" => ( config.providers.ollama.enabled, config.providers.ollama.base_url.clone(), "Ollama", ), _ => (false, "".to_string(), "Unknown"), }; let mut balance = 0.0; let mut threshold = 5.0; let mut billing_mode: Option = None; // Apply database overrides if let Some((db_enabled, db_url, db_balance, db_threshold, db_billing)) = db_configs.get(id) { enabled = *db_enabled; if let Some(url) = db_url { base_url = url.clone(); } balance = *db_balance; threshold = *db_threshold; billing_mode = db_billing.clone(); } // Find models for this provider in registry // NOTE: registry provider IDs differ from internal IDs for some providers. let registry_key = match id { "gemini" => "google", "grok" => "xai", _ => id, }; let mut models = Vec::new(); if let Some(p_info) = registry.providers.get(registry_key) { models = p_info.models.keys().cloned().collect(); } else if id == "ollama" { models = config.providers.ollama.models.clone(); } // Determine status let status = if !enabled { "disabled" } else { // Check if it's actually initialized in the provider manager if state.app_state.provider_manager.get_provider(id).await.is_some() { // Check circuit breaker if state .app_state .rate_limit_manager .check_provider_request(id) .await .unwrap_or(true) { "online" } else { "degraded" } } else { "error" // Enabled but failed to initialize (e.g. missing API key) } }; providers_json.push(serde_json::json!({ "id": id, "name": display_name, "enabled": enabled, "status": status, "models": models, "base_url": base_url, "credit_balance": balance, "low_credit_threshold": threshold, "billing_mode": billing_mode, "last_used": None::, })); } Json(ApiResponse::success(serde_json::json!(providers_json))) } pub(super) async fn handle_get_provider( State(state): State, Path(name): Path, ) -> Json> { let registry = &state.app_state.model_registry; let config = &state.app_state.config; let pool = &state.app_state.db_pool; // Validate provider name let (mut enabled, mut base_url, display_name) = match name.as_str() { "openai" => ( config.providers.openai.enabled, config.providers.openai.base_url.clone(), "OpenAI", ), "gemini" => ( config.providers.gemini.enabled, config.providers.gemini.base_url.clone(), "Google Gemini", ), "deepseek" => ( config.providers.deepseek.enabled, config.providers.deepseek.base_url.clone(), "DeepSeek", ), "grok" => ( config.providers.grok.enabled, config.providers.grok.base_url.clone(), "xAI Grok", ), "ollama" => ( config.providers.ollama.enabled, config.providers.ollama.base_url.clone(), "Ollama", ), _ => return Json(ApiResponse::error(format!("Unknown provider '{}'", name))), }; let mut balance = 0.0; let mut threshold = 5.0; let mut billing_mode: Option = None; // Apply database overrides let db_config = sqlx::query( "SELECT enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs WHERE id = ?", ) .bind(&name) .fetch_optional(pool) .await; if let Ok(Some(row)) = db_config { enabled = row.get::("enabled"); if let Some(url) = row.get::, _>("base_url") { base_url = url; } balance = row.get::("credit_balance"); threshold = row.get::("low_credit_threshold"); billing_mode = row.get::, _>("billing_mode"); } // Find models for this provider // NOTE: registry provider IDs differ from internal IDs for some providers. let registry_key = match name.as_str() { "gemini" => "google", "grok" => "xai", _ => name.as_str(), }; let mut models = Vec::new(); if let Some(p_info) = registry.providers.get(registry_key) { models = p_info.models.keys().cloned().collect(); } else if name == "ollama" { models = config.providers.ollama.models.clone(); } // Determine status let status = if !enabled { "disabled" } else if state.app_state.provider_manager.get_provider(&name).await.is_some() { if state .app_state .rate_limit_manager .check_provider_request(&name) .await .unwrap_or(true) { "online" } else { "degraded" } } else { "error" }; Json(ApiResponse::success(serde_json::json!({ "id": name, "name": display_name, "enabled": enabled, "status": status, "models": models, "base_url": base_url, "credit_balance": balance, "low_credit_threshold": threshold, "billing_mode": billing_mode, "last_used": None::, }))) } pub(super) async fn handle_update_provider( State(state): State, headers: axum::http::HeaderMap, Path(name): Path, Json(payload): Json, ) -> Json> { let (session, _) = match super::auth::require_admin(&state, &headers).await { Ok((session, new_token)) => (session, new_token), Err(e) => return e, }; let pool = &state.app_state.db_pool; // Prepare API key encryption if provided let (api_key_to_store, api_key_encrypted_flag) = match &payload.api_key { Some(key) if !key.is_empty() => { match crypto::encrypt(key) { Ok(encrypted) => (Some(encrypted), Some(true)), Err(e) => { warn!("Failed to encrypt API key for provider {}: {}", name, e); return Json(ApiResponse::error(format!("Failed to encrypt API key: {}", e))); } } } Some(_) => { // Empty string means clear the key (None, Some(false)) } None => { // Keep existing key, we'll rely on COALESCE in SQL (None, None) } }; // Update or insert into database (include billing_mode and api_key_encrypted) let result = sqlx::query( r#" INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold, billing_mode) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET enabled = excluded.enabled, base_url = excluded.base_url, api_key = COALESCE(excluded.api_key, provider_configs.api_key), api_key_encrypted = COALESCE(excluded.api_key_encrypted, provider_configs.api_key_encrypted), credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance), low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold), billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode), updated_at = CURRENT_TIMESTAMP "#, ) .bind(&name) .bind(name.to_uppercase()) .bind(payload.enabled) .bind(&payload.base_url) .bind(&api_key_to_store) .bind(api_key_encrypted_flag) .bind(payload.credit_balance) .bind(payload.low_credit_threshold) .bind(payload.billing_mode) .execute(pool) .await; match result { Ok(_) => { // Re-initialize provider in manager if let Err(e) = state .app_state .provider_manager .initialize_provider(&name, &state.app_state.config, &state.app_state.db_pool) .await { warn!("Failed to re-initialize provider {}: {}", name, e); return Json(ApiResponse::error(format!( "Provider settings saved but initialization failed: {}", e ))); } Json(ApiResponse::success( serde_json::json!({ "message": "Provider updated and re-initialized" }), )) } Err(e) => { warn!("Failed to update provider config: {}", e); Json(ApiResponse::error(format!("Failed to update provider: {}", e))) } } } pub(super) async fn handle_test_provider( State(state): State, Path(name): Path, ) -> Json> { let start = std::time::Instant::now(); let provider = match state.app_state.provider_manager.get_provider(&name).await { Some(p) => p, None => { return Json(ApiResponse::error(format!( "Provider '{}' not found or not enabled", name ))); } }; // Pick a real model for this provider from the registry // NOTE: registry provider IDs differ from internal IDs for some providers. let registry_key = match name.as_str() { "gemini" => "google", "grok" => "xai", _ => name.as_str(), }; let test_model = state .app_state .model_registry .providers .get(registry_key) .and_then(|p| p.models.keys().next().cloned()) .unwrap_or_else(|| name.clone()); let test_request = crate::models::UnifiedRequest { client_id: "system-test".to_string(), model: test_model, messages: vec![crate::models::UnifiedMessage { role: "user".to_string(), content: vec![crate::models::ContentPart::Text { text: "Hi".to_string() }], reasoning_content: None, tool_calls: None, name: None, tool_call_id: None, }], temperature: None, top_p: None, top_k: None, n: None, stop: None, max_tokens: Some(5), presence_penalty: None, frequency_penalty: None, stream: false, has_images: false, tools: None, tool_choice: None, }; match provider.chat_completion(test_request).await { Ok(_) => { let latency = start.elapsed().as_millis(); Json(ApiResponse::success(serde_json::json!({ "success": true, "latency": latency, "message": "Connection test successful" }))) } Err(e) => Json(ApiResponse::error(format!("Provider test failed: {}", e))), } }