From 3165aa185902d841f33b6ae881092f220ef7cfa7 Mon Sep 17 00:00:00 2001 From: hobokenchicken Date: Thu, 26 Feb 2026 18:13:04 -0500 Subject: [PATCH] feat: implement web UI for provider and model configuration - Added 'provider_configs' and 'model_configs' tables to database. - Refactored ProviderManager to support thread-safe dynamic updates and database overrides. - Implemented 'Models' tab in dashboard to manage model visibility, mapping, and pricing. - Added provider configuration modal to 'Providers' tab. - Integrated database overrides into chat completion logic (enabled state, mapping, and cost). --- src/dashboard/mod.rs | 204 ++++++++++++++++++++++++++++++++--- src/database/mod.rs | 35 ++++++ src/main.rs | 69 ++---------- src/providers/deepseek.rs | 5 +- src/providers/gemini.rs | 5 +- src/providers/grok.rs | 5 +- src/providers/mod.rs | 134 +++++++++++++++++++++-- src/providers/openai.rs | 5 +- src/server/mod.rs | 67 ++++++++++-- src/utils/streaming.rs | 27 ++++- static/index.html | 5 + static/js/dashboard.js | 25 +++++ static/js/pages/models.js | 165 ++++++++++++++++++++++++++++ static/js/pages/providers.js | 59 +++++++++- 14 files changed, 707 insertions(+), 103 deletions(-) create mode 100644 static/js/pages/models.js diff --git a/src/dashboard/mod.rs b/src/dashboard/mod.rs index da63dbf0..8f8cd178 100644 --- a/src/dashboard/mod.rs +++ b/src/dashboard/mod.rs @@ -3,7 +3,7 @@ use axum::{ extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State}, response::{IntoResponse, Json}, - routing::{get, post}, + routing::{get, post, put}, Router, }; use serde::{Deserialize, Serialize}; @@ -67,6 +67,8 @@ pub fn router(state: AppState) -> Router { .route("/api/usage/time-series", get(handle_time_series)) .route("/api/usage/clients", get(handle_clients_usage)) .route("/api/usage/providers", get(handle_providers_usage)) + .route("/api/models", get(handle_get_models)) + .route("/api/models/{id}", put(handle_update_model)) .route("/api/clients", get(handle_get_clients).post(handle_create_client)) .route("/api/clients/{id}", get(handle_get_client).delete(handle_delete_client)) .route("/api/clients/{id}/usage", get(handle_client_usage)) @@ -531,19 +533,47 @@ async fn handle_client_usage( async fn handle_get_providers(State(state): State) -> Json> { let registry = &state.app_state.model_registry; let config = &state.app_state.config; + let pool = &state.app_state.db_pool; + // Load all overrides from database + let db_configs_result = sqlx::query("SELECT id, enabled, base_url FROM provider_configs") + .fetch_all(pool) + .await; + + let mut db_configs = HashMap::new(); + if let Ok(rows) = db_configs_result { + for row in rows { + let id: String = row.get("id"); + let enabled: bool = row.get("enabled"); + let base_url: Option = row.get("base_url"); + db_configs.insert(id, (enabled, base_url)); + } + } + let mut providers_json = Vec::new(); // Define the list of providers we support - let provider_configs = vec![ - ("openai", "OpenAI", config.providers.openai.enabled), - ("gemini", "Google Gemini", config.providers.gemini.enabled), - ("deepseek", "DeepSeek", config.providers.deepseek.enabled), - ("grok", "xAI Grok", config.providers.grok.enabled), - ("ollama", "Ollama", config.providers.ollama.enabled), - ]; + let provider_ids = vec!["openai", "gemini", "deepseek", "grok", "ollama"]; + + for id in provider_ids { + // Get base config + let (mut enabled, mut base_url, display_name) = match id { + "openai" => (config.providers.openai.enabled, config.providers.openai.base_url.clone(), "OpenAI"), + "gemini" => (config.providers.gemini.enabled, config.providers.gemini.base_url.clone(), "Google Gemini"), + "deepseek" => (config.providers.deepseek.enabled, config.providers.deepseek.base_url.clone(), "DeepSeek"), + "grok" => (config.providers.grok.enabled, config.providers.grok.base_url.clone(), "xAI Grok"), + "ollama" => (config.providers.ollama.enabled, config.providers.ollama.base_url.clone(), "Ollama"), + _ => (false, "".to_string(), "Unknown"), + }; + + // Apply database overrides + if let Some((db_enabled, db_url)) = db_configs.get(id) { + enabled = *db_enabled; + if let Some(url) = db_url { + base_url = url.clone(); + } + } - for (id, display_name, enabled) in provider_configs { // Find models for this provider in registry let mut models = Vec::new(); if let Some(p_info) = registry.providers.get(id) { @@ -557,7 +587,7 @@ async fn handle_get_providers(State(state): State) -> Json) -> Json, })); } @@ -589,14 +620,55 @@ async fn handle_get_provider( Json(ApiResponse::error("Not implemented".to_string())) } +#[derive(Deserialize)] +struct UpdateProviderRequest { + enabled: bool, + base_url: Option, + api_key: Option, +} + async fn handle_update_provider( - State(_state): State, - axum::extract::Path(_name): axum::extract::Path, + State(state): State, + axum::extract::Path(name): axum::extract::Path, + Json(payload): Json, ) -> Json> { - Json(ApiResponse::success(serde_json::json!({ - "success": true, - "message": "Provider updated" - }))) + let pool = &state.app_state.db_pool; + + // Update or insert into database + let result = sqlx::query( + r#" + INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + enabled = excluded.enabled, + base_url = excluded.base_url, + api_key = COALESCE(excluded.api_key, provider_configs.api_key), + updated_at = CURRENT_TIMESTAMP + "# + ) + .bind(&name) + .bind(name.to_uppercase()) + .bind(payload.enabled) + .bind(&payload.base_url) + .bind(&payload.api_key) + .execute(pool) + .await; + + match result { + Ok(_) => { + // Re-initialize provider in manager + if let Err(e) = state.app_state.provider_manager.initialize_provider(&name, &state.app_state.config, &state.app_state.db_pool).await { + warn!("Failed to re-initialize provider {}: {}", name, e); + return Json(ApiResponse::error(format!("Provider settings saved but initialization failed: {}", e))); + } + + Json(ApiResponse::success(serde_json::json!({ "message": "Provider updated and re-initialized" }))) + } + Err(e) => { + warn!("Failed to update provider config: {}", e); + Json(ApiResponse::error(format!("Failed to update provider: {}", e))) + } + } } async fn handle_test_provider( @@ -610,6 +682,104 @@ async fn handle_test_provider( }))) } +// Model handlers +async fn handle_get_models(State(state): State) -> Json> { + let registry = &state.app_state.model_registry; + let pool = &state.app_state.db_pool; + + // Load overrides from database + let db_models_result = sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs") + .fetch_all(pool) + .await; + + let mut db_models = HashMap::new(); + if let Ok(rows) = db_models_result { + for row in rows { + let id: String = row.get("id"); + db_models.insert(id, row); + } + } + + let mut models_json = Vec::new(); + + for (p_id, p_info) in ®istry.providers { + for (m_id, m_meta) in &p_info.models { + let mut enabled = true; + let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0); + let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0); + let mut mapping = None::; + + if let Some(row) = db_models.get(m_id) { + enabled = row.get("enabled"); + if let Some(p) = row.get::, _>("prompt_cost_per_m") { prompt_cost = p; } + if let Some(c) = row.get::, _>("completion_cost_per_m") { completion_cost = c; } + mapping = row.get("mapping"); + } + + models_json.push(serde_json::json!({ + "id": m_id, + "provider": p_id, + "name": m_meta.name, + "enabled": enabled, + "prompt_cost": prompt_cost, + "completion_cost": completion_cost, + "mapping": mapping, + "context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0), + })); + } + } + + Json(ApiResponse::success(serde_json::json!(models_json))) +} + +#[derive(Deserialize)] +struct UpdateModelRequest { + enabled: bool, + prompt_cost: Option, + completion_cost: Option, + mapping: Option, +} + +async fn handle_update_model( + State(state): State, + axum::extract::Path(id): axum::extract::Path, + Json(payload): Json, +) -> Json> { + let pool = &state.app_state.db_pool; + + // Find provider_id for this model in registry + let provider_id = state.app_state.model_registry.providers.iter() + .find(|(_, p)| p.models.contains_key(&id)) + .map(|(id, _)| id.clone()) + .unwrap_or_else(|| "unknown".to_string()); + + let result = sqlx::query( + r#" + INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + enabled = excluded.enabled, + prompt_cost_per_m = excluded.prompt_cost_per_m, + completion_cost_per_m = excluded.completion_cost_per_m, + mapping = excluded.mapping, + updated_at = CURRENT_TIMESTAMP + "# + ) + .bind(&id) + .bind(provider_id) + .bind(payload.enabled) + .bind(payload.prompt_cost) + .bind(payload.completion_cost) + .bind(payload.mapping) + .execute(pool) + .await; + + match result { + Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" }))), + Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))), + } +} + // System handlers async fn handle_system_health(State(state): State) -> Json> { let mut components = HashMap::new(); @@ -617,7 +787,7 @@ async fn handle_system_health(State(state): State) -> Json = state.app_state.provider_manager.get_all_providers() + let provider_ids: Vec = state.app_state.provider_manager.get_all_providers().await .iter() .map(|p| p.name().to_string()) .collect(); diff --git a/src/database/mod.rs b/src/database/mod.rs index 07954a50..33940244 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -78,6 +78,41 @@ async fn run_migrations(pool: &DbPool) -> Result<()> { .execute(pool) .await?; + // Create provider_configs table + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS provider_configs ( + id TEXT PRIMARY KEY, + display_name TEXT NOT NULL, + enabled BOOLEAN DEFAULT TRUE, + base_url TEXT, + api_key TEXT, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + "# + ) + .execute(pool) + .await?; + + // Create model_configs table + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS model_configs ( + id TEXT PRIMARY KEY, + provider_id TEXT NOT NULL, + display_name TEXT, + enabled BOOLEAN DEFAULT TRUE, + prompt_cost_per_m REAL, + completion_cost_per_m REAL, + mapping TEXT, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE + ) + "# + ) + .execute(pool) + .await?; + // Create indices sqlx::query( "CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)" diff --git a/src/main.rs b/src/main.rs index 1ed38d7a..68d1d388 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,20 +1,12 @@ use anyhow::Result; use axum::{Router, routing::get}; use std::net::SocketAddr; -use std::sync::Arc; use tracing::{info, error}; use llm_proxy::{ config::AppConfig, state::AppState, - providers::{ - ProviderManager, - openai::OpenAIProvider, - gemini::GeminiProvider, - deepseek::DeepSeekProvider, - grok::GrokProvider, - ollama::OllamaProvider, - }, + providers::ProviderManager, database, server, dashboard, @@ -40,60 +32,13 @@ async fn main() -> Result<()> { info!("Database initialized at {:?}", config.database.path); // Initialize provider manager with configured providers - let mut provider_manager = ProviderManager::new(); + let provider_manager = ProviderManager::new(); - // Initialize OpenAI - if config.providers.openai.enabled { - match OpenAIProvider::new(&config.providers.openai, &config) { - Ok(p) => { - provider_manager.add_provider(Arc::new(p)); - info!("OpenAI provider initialized"); - } - Err(e) => error!("Failed to initialize OpenAI provider: {}", e), - } - } - - // Initialize Gemini - if config.providers.gemini.enabled { - match GeminiProvider::new(&config.providers.gemini, &config) { - Ok(p) => { - provider_manager.add_provider(Arc::new(p)); - info!("Gemini provider initialized"); - } - Err(e) => error!("Failed to initialize Gemini provider: {}", e), - } - } - - // Initialize DeepSeek - if config.providers.deepseek.enabled { - match DeepSeekProvider::new(&config.providers.deepseek, &config) { - Ok(p) => { - provider_manager.add_provider(Arc::new(p)); - info!("DeepSeek provider initialized"); - } - Err(e) => error!("Failed to initialize DeepSeek provider: {}", e), - } - } - - // Initialize Grok - if config.providers.grok.enabled { - match GrokProvider::new(&config.providers.grok, &config) { - Ok(p) => { - provider_manager.add_provider(Arc::new(p)); - info!("Grok provider initialized"); - } - Err(e) => error!("Failed to initialize Grok provider: {}", e), - } - } - - // Initialize Ollama - if config.providers.ollama.enabled { - match OllamaProvider::new(&config.providers.ollama, &config) { - Ok(p) => { - provider_manager.add_provider(Arc::new(p)); - info!("Ollama provider initialized at {}", config.providers.ollama.base_url); - } - Err(e) => error!("Failed to initialize Ollama provider: {}", e), + // Initialize all supported providers (they handle their own enabled check) + let supported_providers = vec!["openai", "gemini", "deepseek", "grok", "ollama"]; + for name in supported_providers { + if let Err(e) = provider_manager.initialize_provider(name, &config, &db_pool).await { + error!("Failed to initialize provider {}: {}", name, e); } } diff --git a/src/providers/deepseek.rs b/src/providers/deepseek.rs index 63d6b91f..cc1e69e3 100644 --- a/src/providers/deepseek.rs +++ b/src/providers/deepseek.rs @@ -20,7 +20,10 @@ pub struct DeepSeekProvider { impl DeepSeekProvider { pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result { let api_key = app_config.get_api_key("deepseek")?; - + Self::new_with_key(config, app_config, api_key) + } + + pub fn new_with_key(config: &crate::config::DeepSeekConfig, app_config: &AppConfig, api_key: String) -> Result { Ok(Self { client: reqwest::Client::new(), config: config.clone(), diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs index ddf77fa9..3cd51529 100644 --- a/src/providers/gemini.rs +++ b/src/providers/gemini.rs @@ -73,7 +73,10 @@ pub struct GeminiProvider { impl GeminiProvider { pub fn new(config: &crate::config::GeminiConfig, app_config: &AppConfig) -> Result { let api_key = app_config.get_api_key("gemini")?; - + Self::new_with_key(config, app_config, api_key) + } + + pub fn new_with_key(config: &crate::config::GeminiConfig, app_config: &AppConfig, api_key: String) -> Result { let client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(30)) .build()?; diff --git a/src/providers/grok.rs b/src/providers/grok.rs index ab827024..023b7872 100644 --- a/src/providers/grok.rs +++ b/src/providers/grok.rs @@ -20,7 +20,10 @@ pub struct GrokProvider { impl GrokProvider { pub fn new(config: &crate::config::GrokConfig, app_config: &AppConfig) -> Result { let api_key = app_config.get_api_key("grok")?; - + Self::new_with_key(config, app_config, api_key) + } + + pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result { Ok(Self { client: reqwest::Client::new(), _config: config.clone(), diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 6a7c3434..39a10bbc 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -2,6 +2,7 @@ use async_trait::async_trait; use anyhow::Result; use std::sync::Arc; use futures::stream::BoxStream; +use sqlx::Row; use crate::models::UnifiedRequest; use crate::errors::AppError; @@ -59,36 +60,149 @@ pub struct ProviderStreamChunk { pub model: String, } +use tokio::sync::RwLock; + +use crate::config::AppConfig; +use crate::providers::{ + openai::OpenAIProvider, + gemini::GeminiProvider, + deepseek::DeepSeekProvider, + grok::GrokProvider, + ollama::OllamaProvider, +}; + #[derive(Clone)] pub struct ProviderManager { - providers: Vec>, + providers: Arc>>>, } impl ProviderManager { pub fn new() -> Self { Self { - providers: Vec::new(), + providers: Arc::new(RwLock::new(Vec::new())), } } - pub fn add_provider(&mut self, provider: Arc) { - self.providers.push(provider); + /// Initialize a provider by name using config and database overrides + pub async fn initialize_provider(&self, name: &str, app_config: &AppConfig, db_pool: &crate::database::DbPool) -> Result<()> { + // Load override from database + let db_config = sqlx::query("SELECT enabled, base_url, api_key FROM provider_configs WHERE id = ?") + .bind(name) + .fetch_optional(db_pool) + .await?; + + let (enabled, base_url, api_key) = if let Some(row) = db_config { + ( + row.get::("enabled"), + row.get::, _>("base_url"), + row.get::, _>("api_key"), + ) + } else { + // No database override, use defaults from AppConfig + match name { + "openai" => (app_config.providers.openai.enabled, Some(app_config.providers.openai.base_url.clone()), None), + "gemini" => (app_config.providers.gemini.enabled, Some(app_config.providers.gemini.base_url.clone()), None), + "deepseek" => (app_config.providers.deepseek.enabled, Some(app_config.providers.deepseek.base_url.clone()), None), + "grok" => (app_config.providers.grok.enabled, Some(app_config.providers.grok.base_url.clone()), None), + "ollama" => (app_config.providers.ollama.enabled, Some(app_config.providers.ollama.base_url.clone()), None), + _ => (false, None, None), + } + }; + + if !enabled { + self.remove_provider(name).await; + return Ok(()); + } + + // Create provider instance with merged config + let provider: Arc = match name { + "openai" => { + let mut cfg = app_config.providers.openai.clone(); + if let Some(url) = base_url { cfg.base_url = url; } + // Handle API key override if present + let p = if let Some(key) = api_key { + // We need a way to create a provider with an explicit key + // Let's modify the providers to allow this + OpenAIProvider::new_with_key(&cfg, app_config, key)? + } else { + OpenAIProvider::new(&cfg, app_config)? + }; + Arc::new(p) + }, + "ollama" => { + let mut cfg = app_config.providers.ollama.clone(); + if let Some(url) = base_url { cfg.base_url = url; } + Arc::new(OllamaProvider::new(&cfg, app_config)?) + }, + "gemini" => { + let mut cfg = app_config.providers.gemini.clone(); + if let Some(url) = base_url { cfg.base_url = url; } + let p = if let Some(key) = api_key { + GeminiProvider::new_with_key(&cfg, app_config, key)? + } else { + GeminiProvider::new(&cfg, app_config)? + }; + Arc::new(p) + }, + "deepseek" => { + let mut cfg = app_config.providers.deepseek.clone(); + if let Some(url) = base_url { cfg.base_url = url; } + let p = if let Some(key) = api_key { + DeepSeekProvider::new_with_key(&cfg, app_config, key)? + } else { + DeepSeekProvider::new(&cfg, app_config)? + }; + Arc::new(p) + }, + "grok" => { + let mut cfg = app_config.providers.grok.clone(); + if let Some(url) = base_url { cfg.base_url = url; } + let p = if let Some(key) = api_key { + GrokProvider::new_with_key(&cfg, app_config, key)? + } else { + GrokProvider::new(&cfg, app_config)? + }; + Arc::new(p) + }, + _ => return Err(anyhow::anyhow!("Unknown provider: {}", name)), + }; + + self.add_provider(provider).await; + Ok(()) } - pub fn get_provider_for_model(&self, model: &str) -> Option> { - self.providers.iter() + pub async fn add_provider(&self, provider: Arc) { + let mut providers = self.providers.write().await; + // If provider with same name exists, replace it + if let Some(index) = providers.iter().position(|p| p.name() == provider.name()) { + providers[index] = provider; + } else { + providers.push(provider); + } + } + + pub async fn remove_provider(&self, name: &str) { + let mut providers = self.providers.write().await; + providers.retain(|p| p.name() != name); + } + + pub async fn get_provider_for_model(&self, model: &str) -> Option> { + let providers = self.providers.read().await; + providers.iter() .find(|p| p.supports_model(model)) .map(|p| Arc::clone(p)) } - pub fn get_provider(&self, name: &str) -> Option> { - self.providers.iter() + pub async fn get_provider(&self, name: &str) -> Option> { + let providers = self.providers.read().await; + providers.iter() .find(|p| p.name() == name) .map(|p| Arc::clone(p)) } - pub fn get_all_providers(&self) -> Vec> { - self.providers.clone() + pub async fn get_all_providers(&self) -> Vec> { + let providers = self.providers.read().await; + providers.clone() } } diff --git a/src/providers/openai.rs b/src/providers/openai.rs index e1f53a59..8da5c022 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -20,7 +20,10 @@ pub struct OpenAIProvider { impl OpenAIProvider { pub fn new(config: &crate::config::OpenAIConfig, app_config: &AppConfig) -> Result { let api_key = app_config.get_api_key("openai")?; - + Self::new_with_key(config, app_config, api_key) + } + + pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result { Ok(Self { client: reqwest::Client::new(), _config: config.clone(), diff --git a/src/server/mod.rs b/src/server/mod.rs index 354e4336..94a0fa33 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; +use sqlx::Row; use uuid::Uuid; use axum::{ extract::State, @@ -27,10 +29,37 @@ pub fn router(state: AppState) -> Router { .with_state(state) } +async fn get_model_cost( + model: &str, + prompt_tokens: u32, + completion_tokens: u32, + provider: &Arc, + state: &AppState, +) -> f64 { + // Check database for cost overrides + let db_cost = sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?") + .bind(model) + .fetch_optional(&state.db_pool) + .await + .unwrap_or(None); + + if let Some(row) = db_cost { + let prompt_rate = row.get::, _>("prompt_cost_per_m"); + let completion_rate = row.get::, _>("completion_cost_per_m"); + + if let (Some(p), Some(c)) = (prompt_rate, completion_rate) { + return (prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0); + } + } + + // Fallback to provider's registry-based calculation + provider.calculate_cost(model, prompt_tokens, completion_tokens, &state.model_registry) +} + async fn chat_completions( State(state): State, auth: AuthenticatedClient, - Json(request): Json, + Json(mut request): Json, ) -> Result { // Validate token against configured auth tokens if !state.auth_tokens.is_empty() && !state.auth_tokens.contains(&auth.token) { @@ -43,8 +72,30 @@ async fn chat_completions( info!("Chat completion request from client {} for model {}", client_id, model); + // Check if model is enabled in database and get potential mapping + let model_config = sqlx::query("SELECT enabled, mapping FROM model_configs WHERE id = ?") + .bind(&model) + .fetch_optional(&state.db_pool) + .await + .unwrap_or(None); + + let (model_enabled, model_mapping) = match model_config { + Some(row) => (row.get::("enabled"), row.get::, _>("mapping")), + None => (true, None), + }; + + if !model_enabled { + return Err(AppError::ValidationError(format!("Model {} is currently disabled", model))); + } + + // Apply mapping if present + if let Some(target_model) = model_mapping { + info!("Mapping model {} to {}", model, target_model); + request.model = target_model; + } + // Find appropriate provider for the model - let provider = state.provider_manager.get_provider_for_model(&request.model) + let provider = state.provider_manager.get_provider_for_model(&request.model).await .ok_or_else(|| AppError::ProviderError(format!("No provider found for model: {}", request.model)))?; let provider_name = provider.name().to_string(); @@ -90,6 +141,7 @@ async fn chat_completions( state.request_logger.clone(), state.client_manager.clone(), state.model_registry.clone(), + state.db_pool.clone(), ); // Create SSE stream from aggregating stream @@ -141,13 +193,12 @@ async fn chat_completions( match result { Ok(response) => { - // Record provider success - state.rate_limit_manager.record_provider_success(&provider_name).await; + // Record provider success + state.rate_limit_manager.record_provider_success(&provider_name).await; - let duration = start_time.elapsed(); - let cost = provider.calculate_cost(&response.model, response.prompt_tokens, response.completion_tokens, &state.model_registry); - - // Log request to database + let duration = start_time.elapsed(); + let cost = get_model_cost(&response.model, response.prompt_tokens, response.completion_tokens, &provider, &state).await; + // Log request to database state.request_logger.log_request(crate::logging::RequestLog { timestamp: chrono::Utc::now(), client_id: client_id.clone(), diff --git a/src/utils/streaming.rs b/src/utils/streaming.rs index ad5d4bcb..b97e58ee 100644 --- a/src/utils/streaming.rs +++ b/src/utils/streaming.rs @@ -2,6 +2,7 @@ use futures::stream::Stream; use std::pin::Pin; use std::task::{Context, Poll}; use std::sync::Arc; +use sqlx::Row; use crate::logging::{RequestLogger, RequestLog}; use crate::client::ClientManager; use crate::providers::{Provider, ProviderStreamChunk}; @@ -20,6 +21,7 @@ pub struct AggregatingStream { logger: Arc, client_manager: Arc, model_registry: Arc, + db_pool: crate::database::DbPool, start_time: std::time::Instant, has_logged: bool, } @@ -38,6 +40,7 @@ where logger: Arc, client_manager: Arc, model_registry: Arc, + db_pool: crate::database::DbPool, ) -> Self { Self { inner, @@ -51,6 +54,7 @@ where logger, client_manager, model_registry, + db_pool, start_time: std::time::Instant::now(), has_logged: false, } @@ -72,6 +76,7 @@ where let prompt_tokens = self.prompt_tokens; let has_images = self.has_images; let registry = self.model_registry.clone(); + let pool = self.db_pool.clone(); // Estimate completion tokens (including reasoning if present) let content_tokens = estimate_completion_tokens(&self.accumulated_content, &model); @@ -83,10 +88,29 @@ where let completion_tokens = content_tokens + reasoning_tokens; let total_tokens = prompt_tokens + completion_tokens; - let cost = provider.calculate_cost(&model, prompt_tokens, completion_tokens, ®istry); // Spawn a background task to log the completion tokio::spawn(async move { + // Check database for cost overrides + let db_cost = sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?") + .bind(&model) + .fetch_optional(&pool) + .await + .unwrap_or(None); + + let cost = if let Some(row) = db_cost { + let prompt_rate = row.get::, _>("prompt_cost_per_m"); + let completion_rate = row.get::, _>("completion_cost_per_m"); + + if let (Some(p), Some(c)) = (prompt_rate, completion_rate) { + (prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0) + } else { + provider.calculate_cost(&model, prompt_tokens, completion_tokens, ®istry) + } + } else { + provider.calculate_cost(&model, prompt_tokens, completion_tokens, ®istry) + }; + // Log to database logger.log_request(RequestLog { timestamp: chrono::Utc::now(), @@ -188,6 +212,7 @@ mod tests { logger, client_manager, registry, + pool.clone(), ); while let Some(item) = agg_stream.next().await { diff --git a/static/index.html b/static/index.html index 25232407..82c7ead9 100644 --- a/static/index.html +++ b/static/index.html @@ -89,6 +89,10 @@ Providers + + + Models + Real-time Monitoring @@ -168,6 +172,7 @@ + diff --git a/static/js/dashboard.js b/static/js/dashboard.js index 3797c367..15cb01e9 100644 --- a/static/js/dashboard.js +++ b/static/js/dashboard.js @@ -137,6 +137,7 @@ class Dashboard { case 'overview': return this.getOverviewTemplate(); case 'clients': return this.getClientsTemplate(); case 'providers': return this.getProvidersTemplate(); + case 'models': return this.getModelsTemplate(); case 'logs': return this.getLogsTemplate(); case 'monitoring': return this.getMonitoringTemplate(); case 'settings': return '
Loading settings...
'; @@ -253,6 +254,30 @@ class Dashboard { `; } + getModelsTemplate() { + return ` +
+
+
+

Model Registry

+

Manage model availability and custom pricing

+
+
+ +
+
+
+ + + + + +
IDDisplay NameProviderPricing (In/Out)ContextStatusActions
+
+
+ `; + } + getLogsTemplate() { return `
diff --git a/static/js/pages/models.js b/static/js/pages/models.js new file mode 100644 index 00000000..1ac466d4 --- /dev/null +++ b/static/js/pages/models.js @@ -0,0 +1,165 @@ +// Models Page Module + +class ModelsPage { + constructor() { + this.models = []; + this.init(); + } + + async init() { + await this.loadModels(); + this.setupEventListeners(); + } + + async loadModels() { + try { + const data = await window.api.get('/models'); + this.models = data; + this.renderModelsTable(); + } catch (error) { + console.error('Error loading models:', error); + window.authManager.showToast('Failed to load models', 'error'); + } + } + + renderModelsTable() { + const tableBody = document.querySelector('#models-table tbody'); + if (!tableBody) return; + + if (this.models.length === 0) { + tableBody.innerHTML = 'No models found in registry'; + return; + } + + // Sort by provider then name + this.models.sort((a, b) => { + if (a.provider !== b.provider) return a.provider.localeCompare(b.provider); + return a.name.localeCompare(b.name); + }); + + tableBody.innerHTML = this.models.map(model => { + const statusClass = model.enabled ? 'success' : 'secondary'; + const statusIcon = model.enabled ? 'check-circle' : 'ban'; + + return ` + + ${model.id} + ${model.name} + ${model.provider.toUpperCase()} + ${window.api.formatCurrency(model.prompt_cost)} / ${window.api.formatCurrency(model.completion_cost)} + ${model.context_limit ? (model.context_limit / 1000) + 'k' : 'Unknown'} + + + + ${model.enabled ? 'Active' : 'Disabled'} + + + +
+ +
+ + + `; + }).join(''); + } + + configureModel(id) { + const model = this.models.find(m => m.id === id); + if (!model) return; + + const modal = document.createElement('div'); + modal.className = 'modal active'; + modal.innerHTML = ` + + `; + + document.body.appendChild(modal); + + modal.querySelector('#save-model-config').onclick = async () => { + const enabled = modal.querySelector('#model-enabled').checked; + const promptCost = parseFloat(modal.querySelector('#model-prompt-cost').value); + const completionCost = parseFloat(modal.querySelector('#model-completion-cost').value); + const mapping = modal.querySelector('#model-mapping').value; + + try { + await window.api.put(`/models/${id}`, { + enabled, + prompt_cost: promptCost, + completion_cost: completionCost, + mapping: mapping || null + }); + + window.authManager.showToast(`Model ${model.id} updated`, 'success'); + modal.remove(); + this.loadModels(); + } catch (error) { + window.authManager.showToast(error.message, 'error'); + } + }; + } + + setupEventListeners() { + const searchInput = document.getElementById('model-search'); + if (searchInput) { + searchInput.oninput = (e) => this.filterModels(e.target.value); + } + } + + filterModels(query) { + if (!query) { + this.renderModelsTable(); + return; + } + + const q = query.toLowerCase(); + const originalModels = this.models; + this.models = this.models.filter(m => + m.id.toLowerCase().includes(q) || + m.name.toLowerCase().includes(q) || + m.provider.toLowerCase().includes(q) + ); + this.renderModelsTable(); + this.models = originalModels; + } +} + +window.initModels = async () => { + window.modelsPage = new ModelsPage(); +}; diff --git a/static/js/pages/providers.js b/static/js/pages/providers.js index 4d9e834f..b4d311fc 100644 --- a/static/js/pages/providers.js +++ b/static/js/pages/providers.js @@ -124,7 +124,64 @@ class ProvidersPage { } configureProvider(id) { - window.authManager.showToast('Provider configuration via UI not yet implemented', 'info'); + const provider = this.providers.find(p => p.id === id); + if (!provider) return; + + const modal = document.createElement('div'); + modal.className = 'modal active'; + modal.innerHTML = ` + + `; + + document.body.appendChild(modal); + + modal.querySelector('#save-provider-config').onclick = async () => { + const enabled = modal.querySelector('#provider-enabled').checked; + const baseUrl = modal.querySelector('#provider-base-url').value; + const apiKey = modal.querySelector('#provider-api-key').value; + + try { + await window.api.put(`/providers/${id}`, { + enabled, + base_url: baseUrl || null, + api_key: apiKey || null + }); + + window.authManager.showToast(`${provider.name} configuration saved`, 'success'); + modal.remove(); + this.loadProviders(); + } catch (error) { + window.authManager.showToast(error.message, 'error'); + } + }; } setupEventListeners() {