feat: implement web UI for provider and model configuration

- Added 'provider_configs' and 'model_configs' tables to database.
- Refactored ProviderManager to support thread-safe dynamic updates and database overrides.
- Implemented 'Models' tab in dashboard to manage model visibility, mapping, and pricing.
- Added provider configuration modal to 'Providers' tab.
- Integrated database overrides into chat completion logic (enabled state, mapping, and cost).
This commit is contained in:
2026-02-26 18:13:04 -05:00
parent c5fb2357ff
commit 3165aa1859
14 changed files with 707 additions and 103 deletions

View File

@@ -3,7 +3,7 @@
use axum::{
extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State},
response::{IntoResponse, Json},
routing::{get, post},
routing::{get, post, put},
Router,
};
use serde::{Deserialize, Serialize};
@@ -67,6 +67,8 @@ pub fn router(state: AppState) -> Router {
.route("/api/usage/time-series", get(handle_time_series))
.route("/api/usage/clients", get(handle_clients_usage))
.route("/api/usage/providers", get(handle_providers_usage))
.route("/api/models", get(handle_get_models))
.route("/api/models/{id}", put(handle_update_model))
.route("/api/clients", get(handle_get_clients).post(handle_create_client))
.route("/api/clients/{id}", get(handle_get_client).delete(handle_delete_client))
.route("/api/clients/{id}/usage", get(handle_client_usage))
@@ -531,19 +533,47 @@ async fn handle_client_usage(
async fn handle_get_providers(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let registry = &state.app_state.model_registry;
let config = &state.app_state.config;
let pool = &state.app_state.db_pool;
// Load all overrides from database
let db_configs_result = sqlx::query("SELECT id, enabled, base_url FROM provider_configs")
.fetch_all(pool)
.await;
let mut db_configs = HashMap::new();
if let Ok(rows) = db_configs_result {
for row in rows {
let id: String = row.get("id");
let enabled: bool = row.get("enabled");
let base_url: Option<String> = row.get("base_url");
db_configs.insert(id, (enabled, base_url));
}
}
let mut providers_json = Vec::new();
// Define the list of providers we support
let provider_configs = vec![
("openai", "OpenAI", config.providers.openai.enabled),
("gemini", "Google Gemini", config.providers.gemini.enabled),
("deepseek", "DeepSeek", config.providers.deepseek.enabled),
("grok", "xAI Grok", config.providers.grok.enabled),
("ollama", "Ollama", config.providers.ollama.enabled),
];
let provider_ids = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
for id in provider_ids {
// Get base config
let (mut enabled, mut base_url, display_name) = match id {
"openai" => (config.providers.openai.enabled, config.providers.openai.base_url.clone(), "OpenAI"),
"gemini" => (config.providers.gemini.enabled, config.providers.gemini.base_url.clone(), "Google Gemini"),
"deepseek" => (config.providers.deepseek.enabled, config.providers.deepseek.base_url.clone(), "DeepSeek"),
"grok" => (config.providers.grok.enabled, config.providers.grok.base_url.clone(), "xAI Grok"),
"ollama" => (config.providers.ollama.enabled, config.providers.ollama.base_url.clone(), "Ollama"),
_ => (false, "".to_string(), "Unknown"),
};
// Apply database overrides
if let Some((db_enabled, db_url)) = db_configs.get(id) {
enabled = *db_enabled;
if let Some(url) = db_url {
base_url = url.clone();
}
}
for (id, display_name, enabled) in provider_configs {
// Find models for this provider in registry
let mut models = Vec::new();
if let Some(p_info) = registry.providers.get(id) {
@@ -557,7 +587,7 @@ async fn handle_get_providers(State(state): State<DashboardState>) -> Json<ApiRe
"disabled"
} else {
// Check if it's actually initialized in the provider manager
if state.app_state.provider_manager.get_provider(id).is_some() {
if state.app_state.provider_manager.get_provider(id).await.is_some() {
// Check circuit breaker
if state.app_state.rate_limit_manager.check_provider_request(id).await.unwrap_or(true) {
"online"
@@ -575,6 +605,7 @@ async fn handle_get_providers(State(state): State<DashboardState>) -> Json<ApiRe
"enabled": enabled,
"status": status,
"models": models,
"base_url": base_url,
"last_used": None::<String>,
}));
}
@@ -589,14 +620,55 @@ async fn handle_get_provider(
Json(ApiResponse::error("Not implemented".to_string()))
}
/// JSON body accepted by `handle_update_provider`.
#[derive(Deserialize)]
struct UpdateProviderRequest {
// Desired enabled state; written to `provider_configs.enabled`.
enabled: bool,
// Base URL override; written as-is, so `None` stores NULL in `provider_configs.base_url`.
base_url: Option<String>,
// API key override; `None` keeps whatever key is already stored (the upsert uses COALESCE).
api_key: Option<String>,
}
async fn handle_update_provider(
State(_state): State<DashboardState>,
axum::extract::Path(_name): axum::extract::Path<String>,
State(state): State<DashboardState>,
axum::extract::Path(name): axum::extract::Path<String>,
Json(payload): Json<UpdateProviderRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
Json(ApiResponse::success(serde_json::json!({
"success": true,
"message": "Provider updated"
})))
let pool = &state.app_state.db_pool;
// Update or insert into database
let result = sqlx::query(
r#"
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
enabled = excluded.enabled,
base_url = excluded.base_url,
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
updated_at = CURRENT_TIMESTAMP
"#
)
.bind(&name)
.bind(name.to_uppercase())
.bind(payload.enabled)
.bind(&payload.base_url)
.bind(&payload.api_key)
.execute(pool)
.await;
match result {
Ok(_) => {
// Re-initialize provider in manager
if let Err(e) = state.app_state.provider_manager.initialize_provider(&name, &state.app_state.config, &state.app_state.db_pool).await {
warn!("Failed to re-initialize provider {}: {}", name, e);
return Json(ApiResponse::error(format!("Provider settings saved but initialization failed: {}", e)));
}
Json(ApiResponse::success(serde_json::json!({ "message": "Provider updated and re-initialized" })))
}
Err(e) => {
warn!("Failed to update provider config: {}", e);
Json(ApiResponse::error(format!("Failed to update provider: {}", e)))
}
}
}
async fn handle_test_provider(
@@ -610,6 +682,104 @@ async fn handle_test_provider(
})))
}
// Model handlers
async fn handle_get_models(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let registry = &state.app_state.model_registry;
let pool = &state.app_state.db_pool;
// Load overrides from database
let db_models_result = sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
.fetch_all(pool)
.await;
let mut db_models = HashMap::new();
if let Ok(rows) = db_models_result {
for row in rows {
let id: String = row.get("id");
db_models.insert(id, row);
}
}
let mut models_json = Vec::new();
for (p_id, p_info) in &registry.providers {
for (m_id, m_meta) in &p_info.models {
let mut enabled = true;
let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
let mut mapping = None::<String>;
if let Some(row) = db_models.get(m_id) {
enabled = row.get("enabled");
if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") { prompt_cost = p; }
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") { completion_cost = c; }
mapping = row.get("mapping");
}
models_json.push(serde_json::json!({
"id": m_id,
"provider": p_id,
"name": m_meta.name,
"enabled": enabled,
"prompt_cost": prompt_cost,
"completion_cost": completion_cost,
"mapping": mapping,
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
}));
}
}
Json(ApiResponse::success(serde_json::json!(models_json)))
}
/// JSON body accepted by `handle_update_model`.
#[derive(Deserialize)]
struct UpdateModelRequest {
// Desired enabled state; written to `model_configs.enabled`.
enabled: bool,
// Prompt price override, bound to `prompt_cost_per_m`; `None` stores NULL.
prompt_cost: Option<f64>,
// Completion price override, bound to `completion_cost_per_m`; `None` stores NULL.
completion_cost: Option<f64>,
// Optional target model id, bound to `model_configs.mapping`; `None` stores NULL.
mapping: Option<String>,
}
async fn handle_update_model(
State(state): State<DashboardState>,
axum::extract::Path(id): axum::extract::Path<String>,
Json(payload): Json<UpdateModelRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
// Find provider_id for this model in registry
let provider_id = state.app_state.model_registry.providers.iter()
.find(|(_, p)| p.models.contains_key(&id))
.map(|(id, _)| id.clone())
.unwrap_or_else(|| "unknown".to_string());
let result = sqlx::query(
r#"
INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
enabled = excluded.enabled,
prompt_cost_per_m = excluded.prompt_cost_per_m,
completion_cost_per_m = excluded.completion_cost_per_m,
mapping = excluded.mapping,
updated_at = CURRENT_TIMESTAMP
"#
)
.bind(&id)
.bind(provider_id)
.bind(payload.enabled)
.bind(payload.prompt_cost)
.bind(payload.completion_cost)
.bind(payload.mapping)
.execute(pool)
.await;
match result {
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" }))),
Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))),
}
}
// System handlers
async fn handle_system_health(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let mut components = HashMap::new();
@@ -617,7 +787,7 @@ async fn handle_system_health(State(state): State<DashboardState>) -> Json<ApiRe
components.insert("database".to_string(), "online".to_string());
// Check provider health via circuit breakers
let provider_ids: Vec<String> = state.app_state.provider_manager.get_all_providers()
let provider_ids: Vec<String> = state.app_state.provider_manager.get_all_providers().await
.iter()
.map(|p| p.name().to_string())
.collect();

View File

@@ -78,6 +78,41 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
.execute(pool)
.await?;
// Create provider_configs table
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS provider_configs (
id TEXT PRIMARY KEY,
display_name TEXT NOT NULL,
enabled BOOLEAN DEFAULT TRUE,
base_url TEXT,
api_key TEXT,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
"#
)
.execute(pool)
.await?;
// Create model_configs table
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS model_configs (
id TEXT PRIMARY KEY,
provider_id TEXT NOT NULL,
display_name TEXT,
enabled BOOLEAN DEFAULT TRUE,
prompt_cost_per_m REAL,
completion_cost_per_m REAL,
mapping TEXT,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE
)
"#
)
.execute(pool)
.await?;
// Create indices
sqlx::query(
"CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)"

View File

@@ -1,20 +1,12 @@
use anyhow::Result;
use axum::{Router, routing::get};
use std::net::SocketAddr;
use std::sync::Arc;
use tracing::{info, error};
use llm_proxy::{
config::AppConfig,
state::AppState,
providers::{
ProviderManager,
openai::OpenAIProvider,
gemini::GeminiProvider,
deepseek::DeepSeekProvider,
grok::GrokProvider,
ollama::OllamaProvider,
},
providers::ProviderManager,
database,
server,
dashboard,
@@ -40,60 +32,13 @@ async fn main() -> Result<()> {
info!("Database initialized at {:?}", config.database.path);
// Initialize provider manager with configured providers
let mut provider_manager = ProviderManager::new();
let provider_manager = ProviderManager::new();
// Initialize OpenAI
if config.providers.openai.enabled {
match OpenAIProvider::new(&config.providers.openai, &config) {
Ok(p) => {
provider_manager.add_provider(Arc::new(p));
info!("OpenAI provider initialized");
}
Err(e) => error!("Failed to initialize OpenAI provider: {}", e),
}
}
// Initialize Gemini
if config.providers.gemini.enabled {
match GeminiProvider::new(&config.providers.gemini, &config) {
Ok(p) => {
provider_manager.add_provider(Arc::new(p));
info!("Gemini provider initialized");
}
Err(e) => error!("Failed to initialize Gemini provider: {}", e),
}
}
// Initialize DeepSeek
if config.providers.deepseek.enabled {
match DeepSeekProvider::new(&config.providers.deepseek, &config) {
Ok(p) => {
provider_manager.add_provider(Arc::new(p));
info!("DeepSeek provider initialized");
}
Err(e) => error!("Failed to initialize DeepSeek provider: {}", e),
}
}
// Initialize Grok
if config.providers.grok.enabled {
match GrokProvider::new(&config.providers.grok, &config) {
Ok(p) => {
provider_manager.add_provider(Arc::new(p));
info!("Grok provider initialized");
}
Err(e) => error!("Failed to initialize Grok provider: {}", e),
}
}
// Initialize Ollama
if config.providers.ollama.enabled {
match OllamaProvider::new(&config.providers.ollama, &config) {
Ok(p) => {
provider_manager.add_provider(Arc::new(p));
info!("Ollama provider initialized at {}", config.providers.ollama.base_url);
}
Err(e) => error!("Failed to initialize Ollama provider: {}", e),
// Initialize all supported providers (they handle their own enabled check)
let supported_providers = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
for name in supported_providers {
if let Err(e) = provider_manager.initialize_provider(name, &config, &db_pool).await {
error!("Failed to initialize provider {}: {}", name, e);
}
}

View File

@@ -20,7 +20,10 @@ pub struct DeepSeekProvider {
impl DeepSeekProvider {
/// Builds a DeepSeek provider using the key returned by `app_config.get_api_key("deepseek")`.
///
/// # Errors
/// Returns the key-lookup error, or whatever `new_with_key` reports.
pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result<Self> {
    Self::new_with_key(config, app_config, app_config.get_api_key("deepseek")?)
}
pub fn new_with_key(config: &crate::config::DeepSeekConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
Ok(Self {
client: reqwest::Client::new(),
config: config.clone(),

View File

@@ -73,7 +73,10 @@ pub struct GeminiProvider {
impl GeminiProvider {
/// Builds a Gemini provider using the key returned by `app_config.get_api_key("gemini")`.
///
/// # Errors
/// Returns the key-lookup error, or whatever `new_with_key` reports.
pub fn new(config: &crate::config::GeminiConfig, app_config: &AppConfig) -> Result<Self> {
    Self::new_with_key(config, app_config, app_config.get_api_key("gemini")?)
}
pub fn new_with_key(config: &crate::config::GeminiConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()?;

View File

@@ -20,7 +20,10 @@ pub struct GrokProvider {
impl GrokProvider {
/// Builds a Grok provider using the key returned by `app_config.get_api_key("grok")`.
///
/// # Errors
/// Returns the key-lookup error, or whatever `new_with_key` reports.
pub fn new(config: &crate::config::GrokConfig, app_config: &AppConfig) -> Result<Self> {
    Self::new_with_key(config, app_config, app_config.get_api_key("grok")?)
}
pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
Ok(Self {
client: reqwest::Client::new(),
_config: config.clone(),

View File

@@ -2,6 +2,7 @@ use async_trait::async_trait;
use anyhow::Result;
use std::sync::Arc;
use futures::stream::BoxStream;
use sqlx::Row;
use crate::models::UnifiedRequest;
use crate::errors::AppError;
@@ -59,36 +60,149 @@ pub struct ProviderStreamChunk {
pub model: String,
}
use tokio::sync::RwLock;
use crate::config::AppConfig;
use crate::providers::{
openai::OpenAIProvider,
gemini::GeminiProvider,
deepseek::DeepSeekProvider,
grok::GrokProvider,
ollama::OllamaProvider,
};
#[derive(Clone)]
pub struct ProviderManager {
providers: Vec<Arc<dyn Provider>>,
providers: Arc<RwLock<Vec<Arc<dyn Provider>>>>,
}
impl ProviderManager {
pub fn new() -> Self {
Self {
providers: Vec::new(),
providers: Arc::new(RwLock::new(Vec::new())),
}
}
pub fn add_provider(&mut self, provider: Arc<dyn Provider>) {
self.providers.push(provider);
/// Initialize a provider by name using config and database overrides
pub async fn initialize_provider(&self, name: &str, app_config: &AppConfig, db_pool: &crate::database::DbPool) -> Result<()> {
// Load override from database
let db_config = sqlx::query("SELECT enabled, base_url, api_key FROM provider_configs WHERE id = ?")
.bind(name)
.fetch_optional(db_pool)
.await?;
let (enabled, base_url, api_key) = if let Some(row) = db_config {
(
row.get::<bool, _>("enabled"),
row.get::<Option<String>, _>("base_url"),
row.get::<Option<String>, _>("api_key"),
)
} else {
// No database override, use defaults from AppConfig
match name {
"openai" => (app_config.providers.openai.enabled, Some(app_config.providers.openai.base_url.clone()), None),
"gemini" => (app_config.providers.gemini.enabled, Some(app_config.providers.gemini.base_url.clone()), None),
"deepseek" => (app_config.providers.deepseek.enabled, Some(app_config.providers.deepseek.base_url.clone()), None),
"grok" => (app_config.providers.grok.enabled, Some(app_config.providers.grok.base_url.clone()), None),
"ollama" => (app_config.providers.ollama.enabled, Some(app_config.providers.ollama.base_url.clone()), None),
_ => (false, None, None),
}
};
if !enabled {
self.remove_provider(name).await;
return Ok(());
}
pub fn get_provider_for_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
self.providers.iter()
// Create provider instance with merged config
let provider: Arc<dyn Provider> = match name {
"openai" => {
let mut cfg = app_config.providers.openai.clone();
if let Some(url) = base_url { cfg.base_url = url; }
// Handle API key override if present
let p = if let Some(key) = api_key {
// A key stored in the database takes precedence; otherwise the provider
// falls back to the key resolved from AppConfig.
OpenAIProvider::new_with_key(&cfg, app_config, key)?
} else {
OpenAIProvider::new(&cfg, app_config)?
};
Arc::new(p)
},
"ollama" => {
let mut cfg = app_config.providers.ollama.clone();
if let Some(url) = base_url { cfg.base_url = url; }
Arc::new(OllamaProvider::new(&cfg, app_config)?)
},
"gemini" => {
let mut cfg = app_config.providers.gemini.clone();
if let Some(url) = base_url { cfg.base_url = url; }
let p = if let Some(key) = api_key {
GeminiProvider::new_with_key(&cfg, app_config, key)?
} else {
GeminiProvider::new(&cfg, app_config)?
};
Arc::new(p)
},
"deepseek" => {
let mut cfg = app_config.providers.deepseek.clone();
if let Some(url) = base_url { cfg.base_url = url; }
let p = if let Some(key) = api_key {
DeepSeekProvider::new_with_key(&cfg, app_config, key)?
} else {
DeepSeekProvider::new(&cfg, app_config)?
};
Arc::new(p)
},
"grok" => {
let mut cfg = app_config.providers.grok.clone();
if let Some(url) = base_url { cfg.base_url = url; }
let p = if let Some(key) = api_key {
GrokProvider::new_with_key(&cfg, app_config, key)?
} else {
GrokProvider::new(&cfg, app_config)?
};
Arc::new(p)
},
_ => return Err(anyhow::anyhow!("Unknown provider: {}", name)),
};
self.add_provider(provider).await;
Ok(())
}
/// Registers `provider`; a previously registered provider reporting the same
/// name is overwritten in place, otherwise the new one is appended.
pub async fn add_provider(&self, provider: Arc<dyn Provider>) {
    let mut registered = self.providers.write().await;
    let existing_slot = registered
        .iter()
        .position(|candidate| candidate.name() == provider.name());
    match existing_slot {
        Some(slot) => registered[slot] = provider,
        None => registered.push(provider),
    }
}
/// Drops every registered provider whose name equals `name`; a no-op if none match.
pub async fn remove_provider(&self, name: &str) {
    self.providers
        .write()
        .await
        .retain(|candidate| candidate.name() != name);
}
/// Returns the first registered provider whose `supports_model(model)` is true, if any.
pub async fn get_provider_for_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
    let registered = self.providers.read().await;
    for candidate in registered.iter() {
        if candidate.supports_model(model) {
            return Some(Arc::clone(candidate));
        }
    }
    None
}
pub fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
self.providers.iter()
/// Looks up a registered provider by its exact name.
pub async fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
    let registered = self.providers.read().await;
    let found = registered.iter().find(|candidate| candidate.name() == name);
    found.cloned()
}
pub fn get_all_providers(&self) -> Vec<Arc<dyn Provider>> {
self.providers.clone()
/// Returns a snapshot of the currently registered providers (the `Arc`s are cloned, not the providers).
pub async fn get_all_providers(&self) -> Vec<Arc<dyn Provider>> {
    let registered = self.providers.read().await;
    registered.iter().map(Arc::clone).collect()
}
}

View File

@@ -20,7 +20,10 @@ pub struct OpenAIProvider {
impl OpenAIProvider {
/// Builds an OpenAI provider using the key returned by `app_config.get_api_key("openai")`.
///
/// # Errors
/// Returns the key-lookup error, or whatever `new_with_key` reports.
pub fn new(config: &crate::config::OpenAIConfig, app_config: &AppConfig) -> Result<Self> {
    Self::new_with_key(config, app_config, app_config.get_api_key("openai")?)
}
pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
Ok(Self {
client: reqwest::Client::new(),
_config: config.clone(),

View File

@@ -1,3 +1,5 @@
use std::sync::Arc;
use sqlx::Row;
use uuid::Uuid;
use axum::{
extract::State,
@@ -27,10 +29,37 @@ pub fn router(state: AppState) -> Router {
.with_state(state)
}
async fn get_model_cost(
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
provider: &Arc<dyn crate::providers::Provider>,
state: &AppState,
) -> f64 {
// Check database for cost overrides
let db_cost = sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?")
.bind(model)
.fetch_optional(&state.db_pool)
.await
.unwrap_or(None);
if let Some(row) = db_cost {
let prompt_rate = row.get::<Option<f64>, _>("prompt_cost_per_m");
let completion_rate = row.get::<Option<f64>, _>("completion_cost_per_m");
if let (Some(p), Some(c)) = (prompt_rate, completion_rate) {
return (prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0);
}
}
// Fallback to provider's registry-based calculation
provider.calculate_cost(model, prompt_tokens, completion_tokens, &state.model_registry)
}
async fn chat_completions(
State(state): State<AppState>,
auth: AuthenticatedClient,
Json(request): Json<ChatCompletionRequest>,
Json(mut request): Json<ChatCompletionRequest>,
) -> Result<axum::response::Response, AppError> {
// Validate token against configured auth tokens
if !state.auth_tokens.is_empty() && !state.auth_tokens.contains(&auth.token) {
@@ -43,8 +72,30 @@ async fn chat_completions(
info!("Chat completion request from client {} for model {}", client_id, model);
// Check if model is enabled in database and get potential mapping
let model_config = sqlx::query("SELECT enabled, mapping FROM model_configs WHERE id = ?")
.bind(&model)
.fetch_optional(&state.db_pool)
.await
.unwrap_or(None);
let (model_enabled, model_mapping) = match model_config {
Some(row) => (row.get::<bool, _>("enabled"), row.get::<Option<String>, _>("mapping")),
None => (true, None),
};
if !model_enabled {
return Err(AppError::ValidationError(format!("Model {} is currently disabled", model)));
}
// Apply mapping if present
if let Some(target_model) = model_mapping {
info!("Mapping model {} to {}", model, target_model);
request.model = target_model;
}
// Find appropriate provider for the model
let provider = state.provider_manager.get_provider_for_model(&request.model)
let provider = state.provider_manager.get_provider_for_model(&request.model).await
.ok_or_else(|| AppError::ProviderError(format!("No provider found for model: {}", request.model)))?;
let provider_name = provider.name().to_string();
@@ -90,6 +141,7 @@ async fn chat_completions(
state.request_logger.clone(),
state.client_manager.clone(),
state.model_registry.clone(),
state.db_pool.clone(),
);
// Create SSE stream from aggregating stream
@@ -145,8 +197,7 @@ async fn chat_completions(
state.rate_limit_manager.record_provider_success(&provider_name).await;
let duration = start_time.elapsed();
let cost = provider.calculate_cost(&response.model, response.prompt_tokens, response.completion_tokens, &state.model_registry);
let cost = get_model_cost(&response.model, response.prompt_tokens, response.completion_tokens, &provider, &state).await;
// Log request to database
state.request_logger.log_request(crate::logging::RequestLog {
timestamp: chrono::Utc::now(),

View File

@@ -2,6 +2,7 @@ use futures::stream::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::sync::Arc;
use sqlx::Row;
use crate::logging::{RequestLogger, RequestLog};
use crate::client::ClientManager;
use crate::providers::{Provider, ProviderStreamChunk};
@@ -20,6 +21,7 @@ pub struct AggregatingStream<S> {
logger: Arc<RequestLogger>,
client_manager: Arc<ClientManager>,
model_registry: Arc<crate::models::registry::ModelRegistry>,
db_pool: crate::database::DbPool,
start_time: std::time::Instant,
has_logged: bool,
}
@@ -38,6 +40,7 @@ where
logger: Arc<RequestLogger>,
client_manager: Arc<ClientManager>,
model_registry: Arc<crate::models::registry::ModelRegistry>,
db_pool: crate::database::DbPool,
) -> Self {
Self {
inner,
@@ -51,6 +54,7 @@ where
logger,
client_manager,
model_registry,
db_pool,
start_time: std::time::Instant::now(),
has_logged: false,
}
@@ -72,6 +76,7 @@ where
let prompt_tokens = self.prompt_tokens;
let has_images = self.has_images;
let registry = self.model_registry.clone();
let pool = self.db_pool.clone();
// Estimate completion tokens (including reasoning if present)
let content_tokens = estimate_completion_tokens(&self.accumulated_content, &model);
@@ -83,10 +88,29 @@ where
let completion_tokens = content_tokens + reasoning_tokens;
let total_tokens = prompt_tokens + completion_tokens;
let cost = provider.calculate_cost(&model, prompt_tokens, completion_tokens, &registry);
// Spawn a background task to log the completion
tokio::spawn(async move {
// Check database for cost overrides
let db_cost = sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?")
.bind(&model)
.fetch_optional(&pool)
.await
.unwrap_or(None);
let cost = if let Some(row) = db_cost {
let prompt_rate = row.get::<Option<f64>, _>("prompt_cost_per_m");
let completion_rate = row.get::<Option<f64>, _>("completion_cost_per_m");
if let (Some(p), Some(c)) = (prompt_rate, completion_rate) {
(prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0)
} else {
provider.calculate_cost(&model, prompt_tokens, completion_tokens, &registry)
}
} else {
provider.calculate_cost(&model, prompt_tokens, completion_tokens, &registry)
};
// Log to database
logger.log_request(RequestLog {
timestamp: chrono::Utc::now(),
@@ -188,6 +212,7 @@ mod tests {
logger,
client_manager,
registry,
pool.clone(),
);
while let Some(item) = agg_stream.next().await {

View File

@@ -89,6 +89,10 @@
<i class="fas fa-server"></i>
<span>Providers</span>
</a>
<a href="#models" class="menu-item" data-page="models" data-tooltip="Manage Models">
<i class="fas fa-cube"></i>
<span>Models</span>
</a>
<a href="#monitoring" class="menu-item" data-page="monitoring" data-tooltip="Live Monitoring">
<i class="fas fa-heartbeat"></i>
<span>Real-time Monitoring</span>
@@ -168,6 +172,7 @@
<script src="/js/pages/costs.js"></script>
<script src="/js/pages/clients.js"></script>
<script src="/js/pages/providers.js"></script>
<script src="/js/pages/models.js"></script>
<script src="/js/pages/monitoring.js"></script>
<script src="/js/pages/settings.js"></script>
<script src="/js/pages/logs.js"></script>

View File

@@ -137,6 +137,7 @@ class Dashboard {
case 'overview': return this.getOverviewTemplate();
case 'clients': return this.getClientsTemplate();
case 'providers': return this.getProvidersTemplate();
case 'models': return this.getModelsTemplate();
case 'logs': return this.getLogsTemplate();
case 'monitoring': return this.getMonitoringTemplate();
case 'settings': return '<div class="loading-placeholder">Loading settings...</div>';
@@ -253,6 +254,30 @@ class Dashboard {
`;
}
/**
 * Returns the static HTML skeleton for the "models" page: a card with a
 * search box (#model-search) and an empty seven-column table (#models-table)
 * whose body is left empty in this markup.
 */
getModelsTemplate() {
return `
<div class="card">
<div class="card-header">
<div>
<h3 class="card-title">Model Registry</h3>
<p class="card-subtitle">Manage model availability and custom pricing</p>
</div>
<div class="card-actions">
<input type="text" id="model-search" placeholder="Search models..." class="form-control" style="margin-bottom: 0; padding: 4px 8px; width: 250px;">
</div>
</div>
<div class="table-container">
<table class="table" id="models-table">
<thead>
<tr><th>ID</th><th>Display Name</th><th>Provider</th><th>Pricing (In/Out)</th><th>Context</th><th>Status</th><th>Actions</th></tr>
</thead>
<tbody></tbody>
</table>
</div>
</div>
`;
}
getLogsTemplate() {
return `
<div class="card">

165
static/js/pages/models.js Normal file
View File

@@ -0,0 +1,165 @@
// Models Page Module
/**
 * Controller for the dashboard "Models" page.
 *
 * Loads the model list from `GET /models`, renders it into `#models-table`,
 * filters it through `#model-search`, and saves per-model overrides
 * (enabled flag, pricing, mapping) with `PUT /models/{id}`.
 *
 * All server-provided values are HTML-escaped before being placed into
 * markup, and row buttons are wired with listeners instead of inline
 * `onclick` strings, so a model id, name or mapping cannot inject script.
 */
class ModelsPage {
    constructor() {
        this.models = [];
        this.init();
    }

    /**
     * Escapes a value for use in HTML text and quoted attribute values.
     * `null`/`undefined` become the empty string.
     */
    static escapeHtml(value) {
        if (value === null || value === undefined) return '';
        return String(value)
            .replace(/&/g, '&amp;')
            .replace(/</g, '&lt;')
            .replace(/>/g, '&gt;')
            .replace(/"/g, '&quot;')
            .replace(/'/g, '&#39;');
    }

    async init() {
        await this.loadModels();
        this.setupEventListeners();
    }

    /** Fetches the model list and re-renders the table; shows a toast on failure. */
    async loadModels() {
        try {
            const data = await window.api.get('/models');
            // Guard against a non-array payload so rendering cannot throw.
            this.models = Array.isArray(data) ? data : [];
            this.renderModelsTable();
        } catch (error) {
            console.error('Error loading models:', error);
            window.authManager.showToast('Failed to load models', 'error');
        }
    }

    /**
     * Renders `models` (default: the full list) sorted by provider, then name.
     * The input array is not mutated.
     */
    renderModelsTable(models = this.models) {
        const tableBody = document.querySelector('#models-table tbody');
        if (!tableBody) return;

        if (models.length === 0) {
            tableBody.innerHTML = '<tr><td colspan="7" class="text-center">No models found in registry</td></tr>';
            return;
        }

        const esc = ModelsPage.escapeHtml;

        // Sort by provider then name; tolerate missing fields.
        const sorted = [...models].sort((a, b) => {
            const byProvider = String(a.provider || '').localeCompare(String(b.provider || ''));
            if (byProvider !== 0) return byProvider;
            return String(a.name || '').localeCompare(String(b.name || ''));
        });

        tableBody.innerHTML = sorted.map(model => {
            const statusClass = model.enabled ? 'success' : 'secondary';
            const statusIcon = model.enabled ? 'check-circle' : 'ban';

            return `
                <tr>
                    <td><code class="code-sm">${esc(model.id)}</code></td>
                    <td><strong>${esc(model.name)}</strong></td>
                    <td><span class="badge-client">${esc(String(model.provider || '').toUpperCase())}</span></td>
                    <td>${esc(window.api.formatCurrency(model.prompt_cost))} / ${esc(window.api.formatCurrency(model.completion_cost))}</td>
                    <td>${model.context_limit ? esc(model.context_limit / 1000) + 'k' : 'Unknown'}</td>
                    <td>
                        <span class="status-badge ${statusClass}">
                            <i class="fas fa-${statusIcon}"></i>
                            ${model.enabled ? 'Active' : 'Disabled'}
                        </span>
                    </td>
                    <td>
                        <div class="action-buttons">
                            <button class="btn-action" title="Edit Access/Pricing" data-model-id="${esc(model.id)}">
                                <i class="fas fa-cog"></i>
                            </button>
                        </div>
                    </td>
                </tr>
            `;
        }).join('');

        // Wire the row buttons here rather than interpolating the id into an inline handler.
        tableBody.querySelectorAll('button[data-model-id]').forEach(button => {
            button.onclick = () => this.configureModel(button.dataset.modelId);
        });
    }

    /** Opens the edit modal for the model with the given id (no-op if unknown). */
    configureModel(id) {
        const model = this.models.find(m => m.id === id);
        if (!model) return;

        const esc = ModelsPage.escapeHtml;
        const modal = document.createElement('div');
        modal.className = 'modal active';
        modal.innerHTML = `
            <div class="modal-content">
                <div class="modal-header">
                    <h3 class="modal-title">Manage Model: ${esc(model.name)}</h3>
                    <button class="modal-close" onclick="this.closest('.modal').remove()">
                        <i class="fas fa-times"></i>
                    </button>
                </div>
                <div class="modal-body">
                    <div class="form-control">
                        <label class="checkbox-label" style="display: flex; align-items: center; gap: 0.5rem; cursor: pointer;">
                            <input type="checkbox" id="model-enabled" ${model.enabled ? 'checked' : ''} style="width: auto;">
                            <span>Enable this model for proxying</span>
                        </label>
                    </div>
                    <div class="grid-2">
                        <div class="form-control">
                            <label for="model-prompt-cost">Input Cost (per 1M tokens)</label>
                            <input type="number" id="model-prompt-cost" value="${esc(model.prompt_cost)}" step="0.01">
                        </div>
                        <div class="form-control">
                            <label for="model-completion-cost">Output Cost (per 1M tokens)</label>
                            <input type="number" id="model-completion-cost" value="${esc(model.completion_cost)}" step="0.01">
                        </div>
                    </div>
                    <div class="form-control">
                        <label for="model-mapping">Internal Mapping (Optional)</label>
                        <input type="text" id="model-mapping" value="${esc(model.mapping || '')}" placeholder="e.g. gpt-4o-2024-05-13">
                        <small>Route this model ID to a different specific provider ID</small>
                    </div>
                </div>
                <div class="modal-footer">
                    <button class="btn btn-secondary" onclick="this.closest('.modal').remove()">Cancel</button>
                    <button class="btn btn-primary" id="save-model-config">Save Changes</button>
                </div>
            </div>
        `;

        document.body.appendChild(modal);

        modal.querySelector('#save-model-config').onclick = async () => {
            const enabled = modal.querySelector('#model-enabled').checked;
            const promptCost = parseFloat(modal.querySelector('#model-prompt-cost').value);
            const completionCost = parseFloat(modal.querySelector('#model-completion-cost').value);
            const mapping = modal.querySelector('#model-mapping').value.trim();

            if (promptCost < 0 || completionCost < 0) {
                window.authManager.showToast('Costs must not be negative', 'error');
                return;
            }

            try {
                // Model ids may contain characters such as "/" that would otherwise split the path.
                await window.api.put(`/models/${encodeURIComponent(id)}`, {
                    enabled,
                    // A blank/invalid number means "no override" and is sent as null explicitly.
                    prompt_cost: Number.isFinite(promptCost) ? promptCost : null,
                    completion_cost: Number.isFinite(completionCost) ? completionCost : null,
                    mapping: mapping || null
                });

                window.authManager.showToast(`Model ${model.id} updated`, 'success');
                modal.remove();
                this.loadModels();
            } catch (error) {
                window.authManager.showToast(error.message, 'error');
            }
        };
    }

    setupEventListeners() {
        const searchInput = document.getElementById('model-search');
        if (searchInput) {
            searchInput.oninput = (e) => this.filterModels(e.target.value);
        }
    }

    /** Re-renders the table showing only models whose id, name or provider contains `query`. */
    filterModels(query) {
        if (!query) {
            this.renderModelsTable();
            return;
        }

        const q = query.toLowerCase();
        const contains = (field) => String(field || '').toLowerCase().includes(q);
        // Render a filtered view without touching the full list in this.models.
        this.renderModelsTable(
            this.models.filter(m => contains(m.id) || contains(m.name) || contains(m.provider))
        );
    }
}
// Entry point for the Models page: builds the page controller and exposes
// it globally as `window.modelsPage`.
window.initModels = async function () {
    const page = new ModelsPage();
    window.modelsPage = page;
};

View File

@@ -124,7 +124,64 @@ class ProvidersPage {
}
configureProvider(id) {
window.authManager.showToast('Provider configuration via UI not yet implemented', 'info');
const provider = this.providers.find(p => p.id === id);
if (!provider) return;
const modal = document.createElement('div');
modal.className = 'modal active';
modal.innerHTML = `
<div class="modal-content">
<div class="modal-header">
<h3 class="modal-title">Configure ${provider.name}</h3>
<button class="modal-close" onclick="this.closest('.modal').remove()">
<i class="fas fa-times"></i>
</button>
</div>
<div class="modal-body">
<div class="form-control">
<label class="checkbox-label" style="display: flex; align-items: center; gap: 0.5rem; cursor: pointer;">
<input type="checkbox" id="provider-enabled" ${provider.enabled ? 'checked' : ''} style="width: auto;">
<span>Enable Provider</span>
</label>
</div>
<div class="form-control">
<label for="provider-base-url">Base URL</label>
<input type="text" id="provider-base-url" value="${provider.base_url || ''}" placeholder="Default: ${provider.id === 'ollama' ? 'http://localhost:11434/v1' : 'Standard API URL'}">
</div>
<div class="form-control">
<label for="provider-api-key">API Key (Optional / Overwrite)</label>
<input type="password" id="provider-api-key" placeholder="••••••••••••••••">
<small>Leave blank to keep existing key from .env or config.toml</small>
</div>
</div>
<div class="modal-footer">
<button class="btn btn-secondary" onclick="this.closest('.modal').remove()">Cancel</button>
<button class="btn btn-primary" id="save-provider-config">Save Configuration</button>
</div>
</div>
`;
document.body.appendChild(modal);
modal.querySelector('#save-provider-config').onclick = async () => {
const enabled = modal.querySelector('#provider-enabled').checked;
const baseUrl = modal.querySelector('#provider-base-url').value;
const apiKey = modal.querySelector('#provider-api-key').value;
try {
await window.api.put(`/providers/${id}`, {
enabled,
base_url: baseUrl || null,
api_key: apiKey || null
});
window.authManager.showToast(`${provider.name} configuration saved`, 'success');
modal.remove();
this.loadProviders();
} catch (error) {
window.authManager.showToast(error.message, 'error');
}
};
}
setupEventListeners() {