Files
GopherGate/src/dashboard/models.rs
hobokenchicken 0d32d953d2
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled
fix(dashboard): accurately map used models to actual providers
This commit modifies the /api/models endpoint so that when fetching 'used models' for the Cost Management view, it accurately pairs each model with the exact provider it was routed through (by querying SELECT DISTINCT provider, model FROM llm_requests). Previously, it relied on the global registry's mapping, which could falsely attribute usage to unconfigured or alternate providers.
2026-03-07 01:12:41 +00:00

255 lines
9.6 KiB
Rust

use axum::{
extract::{Path, Query, State},
response::Json,
};
use serde::Deserialize;
use serde_json;
use sqlx::Row;
use std::collections::HashMap;
use super::{ApiResponse, DashboardState};
use crate::models::registry::{ModelFilter, ModelSortBy, SortOrder};
/// JSON body accepted by `handle_update_model`.
///
/// The values are persisted to the `model_configs` table.
#[derive(Deserialize)]
pub(super) struct UpdateModelRequest {
    /// Whether the model is enabled.
    pub(super) enabled: bool,
    /// Prompt price override, written to the `prompt_cost_per_m` column.
    pub(super) prompt_cost: Option<f64>,
    /// Completion price override, written to the `completion_cost_per_m` column.
    pub(super) completion_cost: Option<f64>,
    /// Optional value written to the `mapping` column.
    pub(super) mapping: Option<String>,
}
/// Query-string parameters accepted by `GET /api/models`.
///
/// When `used_only` is `true`, the handler lists models from recorded request
/// history and does not consult the remaining filter and sort fields.
#[derive(Debug, Deserialize, Default)]
pub(super) struct ModelListParams {
    /// Restrict results to this provider ID.
    pub provider: Option<String>,
    /// Free-text match against the model ID or display name.
    pub search: Option<String>,
    /// Required input modality, such as `"image"`.
    pub modality: Option<String>,
    /// Restrict to models that support tool calling.
    pub tool_call: Option<bool>,
    /// Restrict to models that support reasoning.
    pub reasoning: Option<bool>,
    /// Restrict to models that carry pricing data.
    pub has_cost: Option<bool>,
    /// List only models that appear in recorded requests.
    pub used_only: Option<bool>,
    /// Field to sort by: name, id, provider, context_limit, input_cost, or output_cost.
    pub sort_by: Option<ModelSortBy>,
    /// Sort direction: asc or desc.
    pub sort_order: Option<SortOrder>,
}
pub(super) async fn handle_get_models(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Query(params): Query<ModelListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let registry = &state.app_state.model_registry;
let pool = &state.app_state.db_pool;
// Load overrides from database
let db_models_result =
sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
.fetch_all(pool)
.await;
let mut db_models = HashMap::new();
if let Ok(rows) = db_models_result {
for row in rows {
let id: String = row.get("id");
db_models.insert(id, row);
}
}
let mut models_json = Vec::new();
if params.used_only.unwrap_or(false) {
// EXACT USED MODELS LOGIC
let used_pairs_result = sqlx::query(
"SELECT DISTINCT provider, model FROM llm_requests",
)
.fetch_all(pool)
.await;
if let Ok(rows) = used_pairs_result {
for row in rows {
let provider: String = row.get("provider");
let m_key: String = row.get("model");
let provider_name = match provider.as_str() {
"openai" => "OpenAI",
"gemini" => "Google Gemini",
"deepseek" => "DeepSeek",
"grok" => "xAI Grok",
"ollama" => "Ollama",
_ => provider.as_str(),
}.to_string();
let m_meta = registry.find_model(&m_key);
let mut enabled = true;
let mut prompt_cost = m_meta.and_then(|m| m.cost.as_ref().map(|c| c.input)).unwrap_or(0.0);
let mut completion_cost = m_meta.and_then(|m| m.cost.as_ref().map(|c| c.output)).unwrap_or(0.0);
let cache_read_cost = m_meta.and_then(|m| m.cost.as_ref().and_then(|c| c.cache_read));
let cache_write_cost = m_meta.and_then(|m| m.cost.as_ref().and_then(|c| c.cache_write));
let mut mapping = None::<String>;
if let Some(db_row) = db_models.get(&m_key) {
enabled = db_row.get("enabled");
if let Some(p) = db_row.get::<Option<f64>, _>("prompt_cost_per_m") {
prompt_cost = p;
}
if let Some(c) = db_row.get::<Option<f64>, _>("completion_cost_per_m") {
completion_cost = c;
}
mapping = db_row.get("mapping");
}
models_json.push(serde_json::json!({
"id": m_key,
"provider": provider,
"provider_name": provider_name,
"name": m_meta.map(|m| m.name.clone()).unwrap_or_else(|| m_key.clone()),
"enabled": enabled,
"prompt_cost": prompt_cost,
"completion_cost": completion_cost,
"cache_read_cost": cache_read_cost,
"cache_write_cost": cache_write_cost,
"mapping": mapping,
"context_limit": m_meta.and_then(|m| m.limit.as_ref().map(|l| l.context)).unwrap_or(0),
"output_limit": m_meta.and_then(|m| m.limit.as_ref().map(|l| l.output)).unwrap_or(0),
"modalities": m_meta.and_then(|m| m.modalities.as_ref().map(|mo| serde_json::json!({
"input": mo.input,
"output": mo.output,
}))),
"tool_call": m_meta.and_then(|m| m.tool_call),
"reasoning": m_meta.and_then(|m| m.reasoning),
}));
}
}
} else {
// REGISTRY LISTING LOGIC
// Build filter from query params
let filter = ModelFilter {
provider: params.provider,
search: params.search,
modality: params.modality,
tool_call: params.tool_call,
reasoning: params.reasoning,
has_cost: params.has_cost,
};
let sort_by = params.sort_by.unwrap_or_default();
let sort_order = params.sort_order.unwrap_or_default();
// Get filtered and sorted model entries
let entries = registry.list_models(&filter, &sort_by, &sort_order);
for entry in &entries {
let m_key = entry.model_key;
let m_meta = entry.metadata;
let mut enabled = true;
let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
let cache_read_cost = m_meta.cost.as_ref().and_then(|c| c.cache_read);
let cache_write_cost = m_meta.cost.as_ref().and_then(|c| c.cache_write);
let mut mapping = None::<String>;
if let Some(row) = db_models.get(m_key) {
enabled = row.get("enabled");
if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
prompt_cost = p;
}
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
completion_cost = c;
}
mapping = row.get("mapping");
}
models_json.push(serde_json::json!({
"id": m_key,
"provider": entry.provider_id,
"provider_name": entry.provider_name,
"name": m_meta.name,
"enabled": enabled,
"prompt_cost": prompt_cost,
"completion_cost": completion_cost,
"cache_read_cost": cache_read_cost,
"cache_write_cost": cache_write_cost,
"mapping": mapping,
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
"output_limit": m_meta.limit.as_ref().map(|l| l.output).unwrap_or(0),
"modalities": m_meta.modalities.as_ref().map(|m| serde_json::json!({
"input": m.input,
"output": m.output,
})),
"tool_call": m_meta.tool_call,
"reasoning": m_meta.reasoning,
}));
}
}
Json(ApiResponse::success(serde_json::json!(models_json)))
}
pub(super) async fn handle_update_model(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(id): Path<String>,
Json(payload): Json<UpdateModelRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Find provider_id for this model in registry
let provider_id = state
.app_state
.model_registry
.providers
.iter()
.find(|(_, p)| p.models.contains_key(&id))
.map(|(id, _)| id.clone())
.unwrap_or_else(|| "unknown".to_string());
let result = sqlx::query(
r#"
INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
enabled = excluded.enabled,
prompt_cost_per_m = excluded.prompt_cost_per_m,
completion_cost_per_m = excluded.completion_cost_per_m,
mapping = excluded.mapping,
updated_at = CURRENT_TIMESTAMP
"#,
)
.bind(&id)
.bind(provider_id)
.bind(payload.enabled)
.bind(payload.prompt_cost)
.bind(payload.completion_cost)
.bind(payload.mapping)
.execute(pool)
.await;
match result {
Ok(_) => {
// Invalidate the in-memory cache so the proxy picks up the change immediately
state.app_state.model_config_cache.invalidate().await;
Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" })))
}
Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))),
}
}