From 942aa23f88808a0d46e578a6a219f61a3d03a5fc Mon Sep 17 00:00:00 2001 From: hobokenchicken Date: Mon, 2 Mar 2026 08:51:33 -0500 Subject: [PATCH] feat(models): add filtering and sorting to model registry and GET /api/models Add ModelFilter, ModelSortBy, SortOrder structs and list_models() method to ModelRegistry. The /api/models endpoint now accepts query params: provider, search, modality, tool_call, reasoning, has_cost, sort_by, sort_order. Response also enriched with provider_name, output_limit, modalities, tool_call, and reasoning fields. --- src/dashboard/models.rs | 105 ++++++++++++++++++++-------- src/models/registry.rs | 150 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 227 insertions(+), 28 deletions(-) diff --git a/src/dashboard/models.rs b/src/dashboard/models.rs index 1c326392..93f8e499 100644 --- a/src/dashboard/models.rs +++ b/src/dashboard/models.rs @@ -1,5 +1,5 @@ use axum::{ - extract::{Path, State}, + extract::{Path, Query, State}, response::Json, }; use serde::Deserialize; @@ -8,6 +8,7 @@ use sqlx::Row; use std::collections::HashMap; use super::{ApiResponse, DashboardState}; +use crate::models::registry::{ModelFilter, ModelSortBy, SortOrder}; #[derive(Deserialize)] pub(super) struct UpdateModelRequest { @@ -17,10 +18,49 @@ pub(super) struct UpdateModelRequest { pub(super) mapping: Option, } -pub(super) async fn handle_get_models(State(state): State) -> Json> { +/// Query parameters for `GET /api/models`. +#[derive(Debug, Deserialize, Default)] +pub(super) struct ModelListParams { + /// Filter by provider ID. + pub provider: Option, + /// Text search on model ID or name. + pub search: Option, + /// Filter by input modality (e.g. "image"). + pub modality: Option, + /// Only models that support tool calling. + pub tool_call: Option, + /// Only models that support reasoning. + pub reasoning: Option, + /// Only models that have pricing data. + pub has_cost: Option, + /// Sort field (name, id, provider, context_limit, input_cost, output_cost). + pub sort_by: Option, + /// Sort direction (asc, desc). + pub sort_order: Option, +} + +pub(super) async fn handle_get_models( + State(state): State, + Query(params): Query, +) -> Json> { let registry = &state.app_state.model_registry; let pool = &state.app_state.db_pool; + // Build filter from query params + let filter = ModelFilter { + provider: params.provider, + search: params.search, + modality: params.modality, + tool_call: params.tool_call, + reasoning: params.reasoning, + has_cost: params.has_cost, + }; + let sort_by = params.sort_by.unwrap_or_default(); + let sort_order = params.sort_order.unwrap_or_default(); + + // Get filtered and sorted model entries + let entries = registry.list_models(&filter, &sort_by, &sort_order); + // Load overrides from database let db_models_result = sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs") @@ -37,35 +77,44 @@ pub(super) async fn handle_get_models(State(state): State) -> Js let mut models_json = Vec::new(); - for (p_id, p_info) in ®istry.providers { - for (m_id, m_meta) in &p_info.models { - let mut enabled = true; - let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0); - let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0); - let mut mapping = None::; + for entry in &entries { + let m_key = entry.model_key; + let m_meta = entry.metadata; - if let Some(row) = db_models.get(m_id) { - enabled = row.get("enabled"); - if let Some(p) = row.get::, _>("prompt_cost_per_m") { - prompt_cost = p; - } - if let Some(c) = row.get::, _>("completion_cost_per_m") { - completion_cost = c; - } - mapping = row.get("mapping"); + let mut enabled = true; + let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0); + let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0); + let mut mapping = None::; + + if let Some(row) = db_models.get(m_key) { + enabled = row.get("enabled"); + if let Some(p) = row.get::, _>("prompt_cost_per_m") { + prompt_cost = p; } - - models_json.push(serde_json::json!({ - "id": m_id, - "provider": p_id, - "name": m_meta.name, - "enabled": enabled, - "prompt_cost": prompt_cost, - "completion_cost": completion_cost, - "mapping": mapping, - "context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0), - })); + if let Some(c) = row.get::, _>("completion_cost_per_m") { + completion_cost = c; + } + mapping = row.get("mapping"); } + + models_json.push(serde_json::json!({ + "id": m_key, + "provider": entry.provider_id, + "provider_name": entry.provider_name, + "name": m_meta.name, + "enabled": enabled, + "prompt_cost": prompt_cost, + "completion_cost": completion_cost, + "mapping": mapping, + "context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0), + "output_limit": m_meta.limit.as_ref().map(|l| l.output).unwrap_or(0), + "modalities": m_meta.modalities.as_ref().map(|m| serde_json::json!({ + "input": m.input, + "output": m.output, + })), + "tool_call": m_meta.tool_call, + "reasoning": m_meta.reasoning, + })); } Json(ApiResponse::success(serde_json::json!(models_json))) diff --git a/src/models/registry.rs b/src/models/registry.rs index f793e8e6..40f1430d 100644 --- a/src/models/registry.rs +++ b/src/models/registry.rs @@ -45,6 +45,54 @@ pub struct ModelModalities { pub output: Vec, } +/// A model entry paired with its provider ID, returned by listing/filtering methods. +#[derive(Debug, Clone)] +pub struct ModelEntry<'a> { + pub model_key: &'a str, + pub provider_id: &'a str, + pub provider_name: &'a str, + pub metadata: &'a ModelMetadata, +} + +/// Filter criteria for listing models. All fields are optional; `None` means no filter. +#[derive(Debug, Default, Clone, Deserialize)] +pub struct ModelFilter { + /// Filter by provider ID (exact match). + pub provider: Option, + /// Text search on model ID or name (case-insensitive substring). + pub search: Option, + /// Filter by input modality (e.g. "image", "text"). + pub modality: Option, + /// Only models that support tool calling. + pub tool_call: Option, + /// Only models that support reasoning. + pub reasoning: Option, + /// Only models that have pricing data. + pub has_cost: Option, +} + +/// Sort field for model listings. +#[derive(Debug, Clone, Deserialize, Default, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum ModelSortBy { + #[default] + Name, + Id, + Provider, + ContextLimit, + InputCost, + OutputCost, +} + +/// Sort direction. +#[derive(Debug, Clone, Deserialize, Default, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum SortOrder { + #[default] + Asc, + Desc, +} + impl ModelRegistry { /// Find a model by its ID (searching across all providers) pub fn find_model(&self, model_id: &str) -> Option<&ModelMetadata> { @@ -66,4 +114,106 @@ impl ModelRegistry { None } + + /// List all models with optional filtering and sorting. + pub fn list_models( + &self, + filter: &ModelFilter, + sort_by: &ModelSortBy, + sort_order: &SortOrder, + ) -> Vec> { + let mut entries: Vec> = Vec::new(); + + for (p_id, p_info) in &self.providers { + // Provider filter + if let Some(ref prov) = filter.provider { + if p_id != prov { + continue; + } + } + + for (m_key, m_meta) in &p_info.models { + // Text search filter + if let Some(ref search) = filter.search { + let search_lower = search.to_lowercase(); + if !m_meta.id.to_lowercase().contains(&search_lower) + && !m_meta.name.to_lowercase().contains(&search_lower) + && !m_key.to_lowercase().contains(&search_lower) + { + continue; + } + } + + // Modality filter + if let Some(ref modality) = filter.modality { + let has_modality = m_meta + .modalities + .as_ref() + .is_some_and(|m| m.input.iter().any(|i| i.eq_ignore_ascii_case(modality))); + if !has_modality { + continue; + } + } + + // Tool call filter + if let Some(tc) = filter.tool_call { + if m_meta.tool_call.unwrap_or(false) != tc { + continue; + } + } + + // Reasoning filter + if let Some(r) = filter.reasoning { + if m_meta.reasoning.unwrap_or(false) != r { + continue; + } + } + + // Has cost filter + if let Some(hc) = filter.has_cost { + if hc != m_meta.cost.is_some() { + continue; + } + } + + entries.push(ModelEntry { + model_key: m_key, + provider_id: p_id, + provider_name: &p_info.name, + metadata: m_meta, + }); + } + } + + // Sort + entries.sort_by(|a, b| { + let cmp = match sort_by { + ModelSortBy::Name => a.metadata.name.to_lowercase().cmp(&b.metadata.name.to_lowercase()), + ModelSortBy::Id => a.model_key.cmp(b.model_key), + ModelSortBy::Provider => a.provider_id.cmp(b.provider_id), + ModelSortBy::ContextLimit => { + let a_ctx = a.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0); + let b_ctx = b.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0); + a_ctx.cmp(&b_ctx) + } + ModelSortBy::InputCost => { + let a_cost = a.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0); + let b_cost = b.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0); + a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal) + } + ModelSortBy::OutputCost => { + let a_cost = a.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0); + let b_cost = b.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0); + a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal) + } + }; + + match sort_order { + SortOrder::Asc => cmp, + SortOrder::Desc => cmp.reverse(), + } + }); + + entries + } }