feat: implement web UI for provider and model configuration
- Added 'provider_configs' and 'model_configs' tables to database. - Refactored ProviderManager to support thread-safe dynamic updates and database overrides. - Implemented 'Models' tab in dashboard to manage model visibility, mapping, and pricing. - Added provider configuration modal to 'Providers' tab. - Integrated database overrides into chat completion logic (enabled state, mapping, and cost).
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
use axum::{
|
||||
extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State},
|
||||
response::{IntoResponse, Json},
|
||||
routing::{get, post},
|
||||
routing::{get, post, put},
|
||||
Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -67,6 +67,8 @@ pub fn router(state: AppState) -> Router {
|
||||
.route("/api/usage/time-series", get(handle_time_series))
|
||||
.route("/api/usage/clients", get(handle_clients_usage))
|
||||
.route("/api/usage/providers", get(handle_providers_usage))
|
||||
.route("/api/models", get(handle_get_models))
|
||||
.route("/api/models/{id}", put(handle_update_model))
|
||||
.route("/api/clients", get(handle_get_clients).post(handle_create_client))
|
||||
.route("/api/clients/{id}", get(handle_get_client).delete(handle_delete_client))
|
||||
.route("/api/clients/{id}/usage", get(handle_client_usage))
|
||||
@@ -531,19 +533,47 @@ async fn handle_client_usage(
|
||||
async fn handle_get_providers(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let registry = &state.app_state.model_registry;
|
||||
let config = &state.app_state.config;
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Load all overrides from database
|
||||
let db_configs_result = sqlx::query("SELECT id, enabled, base_url FROM provider_configs")
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
let mut db_configs = HashMap::new();
|
||||
if let Ok(rows) = db_configs_result {
|
||||
for row in rows {
|
||||
let id: String = row.get("id");
|
||||
let enabled: bool = row.get("enabled");
|
||||
let base_url: Option<String> = row.get("base_url");
|
||||
db_configs.insert(id, (enabled, base_url));
|
||||
}
|
||||
}
|
||||
|
||||
let mut providers_json = Vec::new();
|
||||
|
||||
// Define the list of providers we support
|
||||
let provider_configs = vec![
|
||||
("openai", "OpenAI", config.providers.openai.enabled),
|
||||
("gemini", "Google Gemini", config.providers.gemini.enabled),
|
||||
("deepseek", "DeepSeek", config.providers.deepseek.enabled),
|
||||
("grok", "xAI Grok", config.providers.grok.enabled),
|
||||
("ollama", "Ollama", config.providers.ollama.enabled),
|
||||
];
|
||||
let provider_ids = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
|
||||
|
||||
for id in provider_ids {
|
||||
// Get base config
|
||||
let (mut enabled, mut base_url, display_name) = match id {
|
||||
"openai" => (config.providers.openai.enabled, config.providers.openai.base_url.clone(), "OpenAI"),
|
||||
"gemini" => (config.providers.gemini.enabled, config.providers.gemini.base_url.clone(), "Google Gemini"),
|
||||
"deepseek" => (config.providers.deepseek.enabled, config.providers.deepseek.base_url.clone(), "DeepSeek"),
|
||||
"grok" => (config.providers.grok.enabled, config.providers.grok.base_url.clone(), "xAI Grok"),
|
||||
"ollama" => (config.providers.ollama.enabled, config.providers.ollama.base_url.clone(), "Ollama"),
|
||||
_ => (false, "".to_string(), "Unknown"),
|
||||
};
|
||||
|
||||
// Apply database overrides
|
||||
if let Some((db_enabled, db_url)) = db_configs.get(id) {
|
||||
enabled = *db_enabled;
|
||||
if let Some(url) = db_url {
|
||||
base_url = url.clone();
|
||||
}
|
||||
}
|
||||
|
||||
for (id, display_name, enabled) in provider_configs {
|
||||
// Find models for this provider in registry
|
||||
let mut models = Vec::new();
|
||||
if let Some(p_info) = registry.providers.get(id) {
|
||||
@@ -557,7 +587,7 @@ async fn handle_get_providers(State(state): State<DashboardState>) -> Json<ApiRe
|
||||
"disabled"
|
||||
} else {
|
||||
// Check if it's actually initialized in the provider manager
|
||||
if state.app_state.provider_manager.get_provider(id).is_some() {
|
||||
if state.app_state.provider_manager.get_provider(id).await.is_some() {
|
||||
// Check circuit breaker
|
||||
if state.app_state.rate_limit_manager.check_provider_request(id).await.unwrap_or(true) {
|
||||
"online"
|
||||
@@ -575,6 +605,7 @@ async fn handle_get_providers(State(state): State<DashboardState>) -> Json<ApiRe
|
||||
"enabled": enabled,
|
||||
"status": status,
|
||||
"models": models,
|
||||
"base_url": base_url,
|
||||
"last_used": None::<String>,
|
||||
}));
|
||||
}
|
||||
@@ -589,14 +620,55 @@ async fn handle_get_provider(
|
||||
Json(ApiResponse::error("Not implemented".to_string()))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct UpdateProviderRequest {
|
||||
enabled: bool,
|
||||
base_url: Option<String>,
|
||||
api_key: Option<String>,
|
||||
}
|
||||
|
||||
async fn handle_update_provider(
|
||||
State(_state): State<DashboardState>,
|
||||
axum::extract::Path(_name): axum::extract::Path<String>,
|
||||
State(state): State<DashboardState>,
|
||||
axum::extract::Path(name): axum::extract::Path<String>,
|
||||
Json(payload): Json<UpdateProviderRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"success": true,
|
||||
"message": "Provider updated"
|
||||
})))
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Update or insert into database
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
enabled = excluded.enabled,
|
||||
base_url = excluded.base_url,
|
||||
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
"#
|
||||
)
|
||||
.bind(&name)
|
||||
.bind(name.to_uppercase())
|
||||
.bind(payload.enabled)
|
||||
.bind(&payload.base_url)
|
||||
.bind(&payload.api_key)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => {
|
||||
// Re-initialize provider in manager
|
||||
if let Err(e) = state.app_state.provider_manager.initialize_provider(&name, &state.app_state.config, &state.app_state.db_pool).await {
|
||||
warn!("Failed to re-initialize provider {}: {}", name, e);
|
||||
return Json(ApiResponse::error(format!("Provider settings saved but initialization failed: {}", e)));
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({ "message": "Provider updated and re-initialized" })))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to update provider config: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to update provider: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_test_provider(
|
||||
@@ -610,6 +682,104 @@ async fn handle_test_provider(
|
||||
})))
|
||||
}
|
||||
|
||||
// Model handlers
|
||||
async fn handle_get_models(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let registry = &state.app_state.model_registry;
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Load overrides from database
|
||||
let db_models_result = sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
let mut db_models = HashMap::new();
|
||||
if let Ok(rows) = db_models_result {
|
||||
for row in rows {
|
||||
let id: String = row.get("id");
|
||||
db_models.insert(id, row);
|
||||
}
|
||||
}
|
||||
|
||||
let mut models_json = Vec::new();
|
||||
|
||||
for (p_id, p_info) in ®istry.providers {
|
||||
for (m_id, m_meta) in &p_info.models {
|
||||
let mut enabled = true;
|
||||
let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
||||
let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
||||
let mut mapping = None::<String>;
|
||||
|
||||
if let Some(row) = db_models.get(m_id) {
|
||||
enabled = row.get("enabled");
|
||||
if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") { prompt_cost = p; }
|
||||
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") { completion_cost = c; }
|
||||
mapping = row.get("mapping");
|
||||
}
|
||||
|
||||
models_json.push(serde_json::json!({
|
||||
"id": m_id,
|
||||
"provider": p_id,
|
||||
"name": m_meta.name,
|
||||
"enabled": enabled,
|
||||
"prompt_cost": prompt_cost,
|
||||
"completion_cost": completion_cost,
|
||||
"mapping": mapping,
|
||||
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(models_json)))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct UpdateModelRequest {
|
||||
enabled: bool,
|
||||
prompt_cost: Option<f64>,
|
||||
completion_cost: Option<f64>,
|
||||
mapping: Option<String>,
|
||||
}
|
||||
|
||||
async fn handle_update_model(
|
||||
State(state): State<DashboardState>,
|
||||
axum::extract::Path(id): axum::extract::Path<String>,
|
||||
Json(payload): Json<UpdateModelRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Find provider_id for this model in registry
|
||||
let provider_id = state.app_state.model_registry.providers.iter()
|
||||
.find(|(_, p)| p.models.contains_key(&id))
|
||||
.map(|(id, _)| id.clone())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
enabled = excluded.enabled,
|
||||
prompt_cost_per_m = excluded.prompt_cost_per_m,
|
||||
completion_cost_per_m = excluded.completion_cost_per_m,
|
||||
mapping = excluded.mapping,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
"#
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(provider_id)
|
||||
.bind(payload.enabled)
|
||||
.bind(payload.prompt_cost)
|
||||
.bind(payload.completion_cost)
|
||||
.bind(payload.mapping)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" }))),
|
||||
Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))),
|
||||
}
|
||||
}
|
||||
|
||||
// System handlers
|
||||
async fn handle_system_health(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let mut components = HashMap::new();
|
||||
@@ -617,7 +787,7 @@ async fn handle_system_health(State(state): State<DashboardState>) -> Json<ApiRe
|
||||
components.insert("database".to_string(), "online".to_string());
|
||||
|
||||
// Check provider health via circuit breakers
|
||||
let provider_ids: Vec<String> = state.app_state.provider_manager.get_all_providers()
|
||||
let provider_ids: Vec<String> = state.app_state.provider_manager.get_all_providers().await
|
||||
.iter()
|
||||
.map(|p| p.name().to_string())
|
||||
.collect();
|
||||
|
||||
Reference in New Issue
Block a user