use axum::{ extract::{Path, Query, State}, response::Json, }; use serde::Deserialize; use serde_json; use sqlx::Row; use std::collections::HashMap; use super::{ApiResponse, DashboardState}; use crate::models::registry::{ModelFilter, ModelSortBy, SortOrder}; #[derive(Deserialize)] pub(super) struct UpdateModelRequest { pub(super) enabled: bool, pub(super) prompt_cost: Option, pub(super) completion_cost: Option, pub(super) mapping: Option, } /// Query parameters for `GET /api/models`. #[derive(Debug, Deserialize, Default)] pub(super) struct ModelListParams { /// Filter by provider ID. pub provider: Option, /// Text search on model ID or name. pub search: Option, /// Filter by input modality (e.g. "image"). pub modality: Option, /// Only models that support tool calling. pub tool_call: Option, /// Only models that support reasoning. pub reasoning: Option, /// Only models that have pricing data. pub has_cost: Option, /// Sort field (name, id, provider, context_limit, input_cost, output_cost). pub sort_by: Option, /// Sort direction (asc, desc). pub sort_order: Option, } pub(super) async fn handle_get_models( State(state): State, Query(params): Query, ) -> Json> { let registry = &state.app_state.model_registry; let pool = &state.app_state.db_pool; // Build filter from query params let filter = ModelFilter { provider: params.provider, search: params.search, modality: params.modality, tool_call: params.tool_call, reasoning: params.reasoning, has_cost: params.has_cost, }; let sort_by = params.sort_by.unwrap_or_default(); let sort_order = params.sort_order.unwrap_or_default(); // Get filtered and sorted model entries let entries = registry.list_models(&filter, &sort_by, &sort_order); // Load overrides from database let db_models_result = sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs") .fetch_all(pool) .await; let mut db_models = HashMap::new(); if let Ok(rows) = db_models_result { for row in rows { let id: String = row.get("id"); db_models.insert(id, row); } } let mut models_json = Vec::new(); for entry in &entries { let m_key = entry.model_key; let m_meta = entry.metadata; let mut enabled = true; let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0); let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0); let cache_read_cost = m_meta.cost.as_ref().and_then(|c| c.cache_read); let cache_write_cost = m_meta.cost.as_ref().and_then(|c| c.cache_write); let mut mapping = None::; if let Some(row) = db_models.get(m_key) { enabled = row.get("enabled"); if let Some(p) = row.get::, _>("prompt_cost_per_m") { prompt_cost = p; } if let Some(c) = row.get::, _>("completion_cost_per_m") { completion_cost = c; } mapping = row.get("mapping"); } models_json.push(serde_json::json!({ "id": m_key, "provider": entry.provider_id, "provider_name": entry.provider_name, "name": m_meta.name, "enabled": enabled, "prompt_cost": prompt_cost, "completion_cost": completion_cost, "cache_read_cost": cache_read_cost, "cache_write_cost": cache_write_cost, "mapping": mapping, "context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0), "output_limit": m_meta.limit.as_ref().map(|l| l.output).unwrap_or(0), "modalities": m_meta.modalities.as_ref().map(|m| serde_json::json!({ "input": m.input, "output": m.output, })), "tool_call": m_meta.tool_call, "reasoning": m_meta.reasoning, })); } Json(ApiResponse::success(serde_json::json!(models_json))) } pub(super) async fn handle_update_model( State(state): State, Path(id): Path, Json(payload): Json, ) -> Json> { let pool = &state.app_state.db_pool; // Find provider_id for this model in registry let provider_id = state .app_state .model_registry .providers .iter() .find(|(_, p)| p.models.contains_key(&id)) .map(|(id, _)| id.clone()) .unwrap_or_else(|| "unknown".to_string()); let result = sqlx::query( r#" INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping) VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET enabled = excluded.enabled, prompt_cost_per_m = excluded.prompt_cost_per_m, completion_cost_per_m = excluded.completion_cost_per_m, mapping = excluded.mapping, updated_at = CURRENT_TIMESTAMP "#, ) .bind(&id) .bind(provider_id) .bind(payload.enabled) .bind(payload.prompt_cost) .bind(payload.completion_cost) .bind(payload.mapping) .execute(pool) .await; match result { Ok(_) => { // Invalidate the in-memory cache so the proxy picks up the change immediately state.app_state.model_config_cache.invalidate().await; Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" }))) } Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))), } }