diff --git a/src/dashboard/models.rs b/src/dashboard/models.rs
index f06cf46f..ddcb5eaf 100644
--- a/src/dashboard/models.rs
+++ b/src/dashboard/models.rs
@@ -54,37 +54,6 @@ pub(super) async fn handle_get_models(
     let registry = &state.app_state.model_registry;
     let pool = &state.app_state.db_pool;
 
-    // If used_only, fetch the set of models that appear in llm_requests
-    let used_models: Option<std::collections::HashSet<String>> =
-        if params.used_only.unwrap_or(false) {
-            match sqlx::query_scalar::<_, String>(
-                "SELECT DISTINCT model FROM llm_requests",
-            )
-            .fetch_all(pool)
-            .await
-            {
-                Ok(models) => Some(models.into_iter().collect()),
-                Err(_) => Some(std::collections::HashSet::new()),
-            }
-        } else {
-            None
-        };
-
-    // Build filter from query params
-    let filter = ModelFilter {
-        provider: params.provider,
-        search: params.search,
-        modality: params.modality,
-        tool_call: params.tool_call,
-        reasoning: params.reasoning,
-        has_cost: params.has_cost,
-    };
-    let sort_by = params.sort_by.unwrap_or_default();
-    let sort_order = params.sort_order.unwrap_or_default();
-
-    // Get filtered and sorted model entries
-    let entries = registry.list_models(&filter, &sort_by, &sort_order);
-
     // Load overrides from database
     let db_models_result =
         sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
@@ -101,56 +70,130 @@ pub(super) async fn handle_get_models(
     let mut models_json = Vec::new();
 
-    for entry in &entries {
-        let m_key = entry.model_key;
+    if params.used_only.unwrap_or(false) {
+        // EXACT USED MODELS LOGIC
+        let used_pairs_result = sqlx::query(
+            "SELECT DISTINCT provider, model FROM llm_requests",
+        )
+        .fetch_all(pool)
+        .await;
 
-        // Skip models not in the used set (when used_only is active)
-        if let Some(ref used) = used_models {
-            if !used.contains(m_key) {
-                continue;
+        if let Ok(rows) = used_pairs_result {
+            for row in rows {
+                let provider: String = row.get("provider");
+                let m_key: String = row.get("model");
+
+                let provider_name = match provider.as_str() {
+                    "openai" => "OpenAI",
+                    "gemini" => "Google Gemini",
+                    "deepseek" => "DeepSeek",
+                    "grok" => "xAI Grok",
+                    "ollama" => "Ollama",
+                    _ => provider.as_str(),
+                }.to_string();
+
+                let m_meta = registry.find_model(&m_key);
+
+                let mut enabled = true;
+                let mut prompt_cost = m_meta.and_then(|m| m.cost.as_ref().map(|c| c.input)).unwrap_or(0.0);
+                let mut completion_cost = m_meta.and_then(|m| m.cost.as_ref().map(|c| c.output)).unwrap_or(0.0);
+                let cache_read_cost = m_meta.and_then(|m| m.cost.as_ref().and_then(|c| c.cache_read));
+                let cache_write_cost = m_meta.and_then(|m| m.cost.as_ref().and_then(|c| c.cache_write));
+                let mut mapping = None::<String>;
+
+                if let Some(db_row) = db_models.get(&m_key) {
+                    enabled = db_row.get("enabled");
+                    if let Some(p) = db_row.get::<Option<f64>, _>("prompt_cost_per_m") {
+                        prompt_cost = p;
+                    }
+                    if let Some(c) = db_row.get::<Option<f64>, _>("completion_cost_per_m") {
+                        completion_cost = c;
+                    }
+                    mapping = db_row.get("mapping");
+                }
+
+                models_json.push(serde_json::json!({
+                    "id": m_key,
+                    "provider": provider,
+                    "provider_name": provider_name,
+                    "name": m_meta.map(|m| m.name.clone()).unwrap_or_else(|| m_key.clone()),
+                    "enabled": enabled,
+                    "prompt_cost": prompt_cost,
+                    "completion_cost": completion_cost,
+                    "cache_read_cost": cache_read_cost,
+                    "cache_write_cost": cache_write_cost,
+                    "mapping": mapping,
+                    "context_limit": m_meta.and_then(|m| m.limit.as_ref().map(|l| l.context)).unwrap_or(0),
+                    "output_limit": m_meta.and_then(|m| m.limit.as_ref().map(|l| l.output)).unwrap_or(0),
+                    "modalities": m_meta.and_then(|m| m.modalities.as_ref().map(|mo| serde_json::json!({
+                        "input": mo.input,
+                        "output": mo.output,
+                    }))),
+                    "tool_call": m_meta.and_then(|m| m.tool_call),
+                    "reasoning": m_meta.and_then(|m| m.reasoning),
+                }));
             }
         }
+    } else {
+        // REGISTRY LISTING LOGIC
+        // Build filter from query params
+        let filter = ModelFilter {
+            provider: params.provider,
+            search: params.search,
+            modality: params.modality,
+            tool_call: params.tool_call,
+            reasoning: params.reasoning,
+            has_cost: params.has_cost,
+        };
+        let sort_by = params.sort_by.unwrap_or_default();
+        let sort_order = params.sort_order.unwrap_or_default();
 
-        let m_meta = entry.metadata;
+        // Get filtered and sorted model entries
+        let entries = registry.list_models(&filter, &sort_by, &sort_order);
 
-        let mut enabled = true;
-        let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
-        let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
-        let cache_read_cost = m_meta.cost.as_ref().and_then(|c| c.cache_read);
-        let cache_write_cost = m_meta.cost.as_ref().and_then(|c| c.cache_write);
-        let mut mapping = None::<String>;
+        for entry in &entries {
+            let m_key = entry.model_key;
+            let m_meta = entry.metadata;
 
-        if let Some(row) = db_models.get(m_key) {
-            enabled = row.get("enabled");
-            if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
-                prompt_cost = p;
+            let mut enabled = true;
+            let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
+            let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
+            let cache_read_cost = m_meta.cost.as_ref().and_then(|c| c.cache_read);
+            let cache_write_cost = m_meta.cost.as_ref().and_then(|c| c.cache_write);
+            let mut mapping = None::<String>;
+
+            if let Some(row) = db_models.get(m_key) {
+                enabled = row.get("enabled");
+                if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
+                    prompt_cost = p;
+                }
+                if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
+                    completion_cost = c;
+                }
+                mapping = row.get("mapping");
            }
-            if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
-                completion_cost = c;
-            }
-            mapping = row.get("mapping");
+
+            models_json.push(serde_json::json!({
+                "id": m_key,
+                "provider": entry.provider_id,
+                "provider_name": entry.provider_name,
+                "name": m_meta.name,
+                "enabled": enabled,
+                "prompt_cost": prompt_cost,
+                "completion_cost": completion_cost,
+                "cache_read_cost": cache_read_cost,
+                "cache_write_cost": cache_write_cost,
+                "mapping": mapping,
+                "context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
+                "output_limit": m_meta.limit.as_ref().map(|l| l.output).unwrap_or(0),
+                "modalities": m_meta.modalities.as_ref().map(|m| serde_json::json!({
+                    "input": m.input,
+                    "output": m.output,
+                })),
+                "tool_call": m_meta.tool_call,
+                "reasoning": m_meta.reasoning,
+            }));
        }
-
-        models_json.push(serde_json::json!({
-            "id": m_key,
-            "provider": entry.provider_id,
-            "provider_name": entry.provider_name,
-            "name": m_meta.name,
-            "enabled": enabled,
-            "prompt_cost": prompt_cost,
-            "completion_cost": completion_cost,
-            "cache_read_cost": cache_read_cost,
-            "cache_write_cost": cache_write_cost,
-            "mapping": mapping,
-            "context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
-            "output_limit": m_meta.limit.as_ref().map(|l| l.output).unwrap_or(0),
-            "modalities": m_meta.modalities.as_ref().map(|m| serde_json::json!({
-                "input": m.input,
-                "output": m.output,
-            })),
-            "tool_call": m_meta.tool_call,
-            "reasoning": m_meta.reasoning,
-        }));
    }
 
     Json(ApiResponse::success(serde_json::json!(models_json)))