feat: implement web UI for provider and model configuration
- Added 'provider_configs' and 'model_configs' tables to the database.
- Refactored ProviderManager to support thread-safe dynamic updates and database overrides.
- Implemented a 'Models' tab in the dashboard to manage model visibility, mapping, and pricing.
- Added a provider configuration modal to the 'Providers' tab.
- Integrated database overrides into the chat completion logic (enabled state, mapping, and cost).
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
use std::sync::Arc;
|
||||
use sqlx::Row;
|
||||
use uuid::Uuid;
|
||||
use axum::{
|
||||
extract::State,
|
||||
@@ -27,10 +29,37 @@ pub fn router(state: AppState) -> Router {
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
async fn get_model_cost(
|
||||
model: &str,
|
||||
prompt_tokens: u32,
|
||||
completion_tokens: u32,
|
||||
provider: &Arc<dyn crate::providers::Provider>,
|
||||
state: &AppState,
|
||||
) -> f64 {
|
||||
// Check database for cost overrides
|
||||
let db_cost = sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?")
|
||||
.bind(model)
|
||||
.fetch_optional(&state.db_pool)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
if let Some(row) = db_cost {
|
||||
let prompt_rate = row.get::<Option<f64>, _>("prompt_cost_per_m");
|
||||
let completion_rate = row.get::<Option<f64>, _>("completion_cost_per_m");
|
||||
|
||||
if let (Some(p), Some(c)) = (prompt_rate, completion_rate) {
|
||||
return (prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0);
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to provider's registry-based calculation
|
||||
provider.calculate_cost(model, prompt_tokens, completion_tokens, &state.model_registry)
|
||||
}
|
||||
|
||||
async fn chat_completions(
|
||||
State(state): State<AppState>,
|
||||
auth: AuthenticatedClient,
|
||||
Json(request): Json<ChatCompletionRequest>,
|
||||
Json(mut request): Json<ChatCompletionRequest>,
|
||||
) -> Result<axum::response::Response, AppError> {
|
||||
// Validate token against configured auth tokens
|
||||
if !state.auth_tokens.is_empty() && !state.auth_tokens.contains(&auth.token) {
|
||||
@@ -43,8 +72,30 @@ async fn chat_completions(
|
||||
|
||||
info!("Chat completion request from client {} for model {}", client_id, model);
|
||||
|
||||
// Check if model is enabled in database and get potential mapping
|
||||
let model_config = sqlx::query("SELECT enabled, mapping FROM model_configs WHERE id = ?")
|
||||
.bind(&model)
|
||||
.fetch_optional(&state.db_pool)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let (model_enabled, model_mapping) = match model_config {
|
||||
Some(row) => (row.get::<bool, _>("enabled"), row.get::<Option<String>, _>("mapping")),
|
||||
None => (true, None),
|
||||
};
|
||||
|
||||
if !model_enabled {
|
||||
return Err(AppError::ValidationError(format!("Model {} is currently disabled", model)));
|
||||
}
|
||||
|
||||
// Apply mapping if present
|
||||
if let Some(target_model) = model_mapping {
|
||||
info!("Mapping model {} to {}", model, target_model);
|
||||
request.model = target_model;
|
||||
}
|
||||
|
||||
// Find appropriate provider for the model
|
||||
let provider = state.provider_manager.get_provider_for_model(&request.model)
|
||||
let provider = state.provider_manager.get_provider_for_model(&request.model).await
|
||||
.ok_or_else(|| AppError::ProviderError(format!("No provider found for model: {}", request.model)))?;
|
||||
|
||||
let provider_name = provider.name().to_string();
|
||||
@@ -90,6 +141,7 @@ async fn chat_completions(
|
||||
state.request_logger.clone(),
|
||||
state.client_manager.clone(),
|
||||
state.model_registry.clone(),
|
||||
state.db_pool.clone(),
|
||||
);
|
||||
|
||||
// Create SSE stream from aggregating stream
|
||||
@@ -141,13 +193,12 @@ async fn chat_completions(
|
||||
|
||||
match result {
|
||||
Ok(response) => {
|
||||
// Record provider success
|
||||
state.rate_limit_manager.record_provider_success(&provider_name).await;
|
||||
// Record provider success
|
||||
state.rate_limit_manager.record_provider_success(&provider_name).await;
|
||||
|
||||
let duration = start_time.elapsed();
|
||||
let cost = provider.calculate_cost(&response.model, response.prompt_tokens, response.completion_tokens, &state.model_registry);
|
||||
|
||||
// Log request to database
|
||||
let duration = start_time.elapsed();
|
||||
let cost = get_model_cost(&response.model, response.prompt_tokens, response.completion_tokens, &provider, &state).await;
|
||||
// Log request to database
|
||||
state.request_logger.log_request(crate::logging::RequestLog {
|
||||
timestamp: chrono::Utc::now(),
|
||||
client_id: client_id.clone(),
|
||||
|
||||
Reference in New Issue
Block a user