feat(models): add filtering and sorting to model registry and GET /api/models
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled

Add ModelFilter, ModelSortBy, SortOrder structs and list_models() method
to ModelRegistry. The /api/models endpoint now accepts query params:
provider, search, modality, tool_call, reasoning, has_cost, sort_by,
sort_order. Response also enriched with provider_name, output_limit,
modalities, tool_call, and reasoning fields.
This commit is contained in:
2026-03-02 08:51:33 -05:00
parent 2aad813ccd
commit 942aa23f88
2 changed files with 227 additions and 28 deletions

View File

@@ -1,5 +1,5 @@
use axum::{ use axum::{
extract::{Path, State}, extract::{Path, Query, State},
response::Json, response::Json,
}; };
use serde::Deserialize; use serde::Deserialize;
@@ -8,6 +8,7 @@ use sqlx::Row;
use std::collections::HashMap; use std::collections::HashMap;
use super::{ApiResponse, DashboardState}; use super::{ApiResponse, DashboardState};
use crate::models::registry::{ModelFilter, ModelSortBy, SortOrder};
#[derive(Deserialize)] #[derive(Deserialize)]
pub(super) struct UpdateModelRequest { pub(super) struct UpdateModelRequest {
@@ -17,10 +18,49 @@ pub(super) struct UpdateModelRequest {
pub(super) mapping: Option<String>, pub(super) mapping: Option<String>,
} }
pub(super) async fn handle_get_models(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> { /// Query parameters for `GET /api/models`.
#[derive(Debug, Deserialize, Default)]
pub(super) struct ModelListParams {
/// Filter by provider ID.
pub provider: Option<String>,
/// Text search on model ID or name.
pub search: Option<String>,
/// Filter by input modality (e.g. "image").
pub modality: Option<String>,
/// Only models that support tool calling.
pub tool_call: Option<bool>,
/// Only models that support reasoning.
pub reasoning: Option<bool>,
/// Only models that have pricing data.
pub has_cost: Option<bool>,
/// Sort field (name, id, provider, context_limit, input_cost, output_cost).
pub sort_by: Option<ModelSortBy>,
/// Sort direction (asc, desc).
pub sort_order: Option<SortOrder>,
}
pub(super) async fn handle_get_models(
State(state): State<DashboardState>,
Query(params): Query<ModelListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
let registry = &state.app_state.model_registry; let registry = &state.app_state.model_registry;
let pool = &state.app_state.db_pool; let pool = &state.app_state.db_pool;
// Build filter from query params
let filter = ModelFilter {
provider: params.provider,
search: params.search,
modality: params.modality,
tool_call: params.tool_call,
reasoning: params.reasoning,
has_cost: params.has_cost,
};
let sort_by = params.sort_by.unwrap_or_default();
let sort_order = params.sort_order.unwrap_or_default();
// Get filtered and sorted model entries
let entries = registry.list_models(&filter, &sort_by, &sort_order);
// Load overrides from database // Load overrides from database
let db_models_result = let db_models_result =
sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs") sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
@@ -37,35 +77,44 @@ pub(super) async fn handle_get_models(State(state): State<DashboardState>) -> Js
let mut models_json = Vec::new(); let mut models_json = Vec::new();
for (p_id, p_info) in &registry.providers { for entry in &entries {
for (m_id, m_meta) in &p_info.models { let m_key = entry.model_key;
let mut enabled = true; let m_meta = entry.metadata;
let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
let mut mapping = None::<String>;
if let Some(row) = db_models.get(m_id) { let mut enabled = true;
enabled = row.get("enabled"); let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") { let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
prompt_cost = p; let mut mapping = None::<String>;
}
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") { if let Some(row) = db_models.get(m_key) {
completion_cost = c; enabled = row.get("enabled");
} if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
mapping = row.get("mapping"); prompt_cost = p;
} }
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
models_json.push(serde_json::json!({ completion_cost = c;
"id": m_id, }
"provider": p_id, mapping = row.get("mapping");
"name": m_meta.name,
"enabled": enabled,
"prompt_cost": prompt_cost,
"completion_cost": completion_cost,
"mapping": mapping,
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
}));
} }
models_json.push(serde_json::json!({
"id": m_key,
"provider": entry.provider_id,
"provider_name": entry.provider_name,
"name": m_meta.name,
"enabled": enabled,
"prompt_cost": prompt_cost,
"completion_cost": completion_cost,
"mapping": mapping,
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
"output_limit": m_meta.limit.as_ref().map(|l| l.output).unwrap_or(0),
"modalities": m_meta.modalities.as_ref().map(|m| serde_json::json!({
"input": m.input,
"output": m.output,
})),
"tool_call": m_meta.tool_call,
"reasoning": m_meta.reasoning,
}));
} }
Json(ApiResponse::success(serde_json::json!(models_json))) Json(ApiResponse::success(serde_json::json!(models_json)))

View File

@@ -45,6 +45,54 @@ pub struct ModelModalities {
pub output: Vec<String>, pub output: Vec<String>,
} }
/// A model entry paired with its provider ID, returned by listing/filtering methods.
#[derive(Debug, Clone)]
pub struct ModelEntry<'a> {
pub model_key: &'a str,
pub provider_id: &'a str,
pub provider_name: &'a str,
pub metadata: &'a ModelMetadata,
}
/// Filter criteria for listing models. All fields are optional; `None` means no filter.
#[derive(Debug, Default, Clone, Deserialize)]
pub struct ModelFilter {
/// Filter by provider ID (exact match).
pub provider: Option<String>,
/// Text search on model ID or name (case-insensitive substring).
pub search: Option<String>,
/// Filter by input modality (e.g. "image", "text").
pub modality: Option<String>,
/// Only models that support tool calling.
pub tool_call: Option<bool>,
/// Only models that support reasoning.
pub reasoning: Option<bool>,
/// Only models that have pricing data.
pub has_cost: Option<bool>,
}
/// Sort field for model listings.
#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ModelSortBy {
#[default]
Name,
Id,
Provider,
ContextLimit,
InputCost,
OutputCost,
}
/// Sort direction.
#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum SortOrder {
#[default]
Asc,
Desc,
}
impl ModelRegistry { impl ModelRegistry {
/// Find a model by its ID (searching across all providers) /// Find a model by its ID (searching across all providers)
pub fn find_model(&self, model_id: &str) -> Option<&ModelMetadata> { pub fn find_model(&self, model_id: &str) -> Option<&ModelMetadata> {
@@ -66,4 +114,106 @@ impl ModelRegistry {
None None
} }
/// List all models with optional filtering and sorting.
pub fn list_models(
&self,
filter: &ModelFilter,
sort_by: &ModelSortBy,
sort_order: &SortOrder,
) -> Vec<ModelEntry<'_>> {
let mut entries: Vec<ModelEntry<'_>> = Vec::new();
for (p_id, p_info) in &self.providers {
// Provider filter
if let Some(ref prov) = filter.provider {
if p_id != prov {
continue;
}
}
for (m_key, m_meta) in &p_info.models {
// Text search filter
if let Some(ref search) = filter.search {
let search_lower = search.to_lowercase();
if !m_meta.id.to_lowercase().contains(&search_lower)
&& !m_meta.name.to_lowercase().contains(&search_lower)
&& !m_key.to_lowercase().contains(&search_lower)
{
continue;
}
}
// Modality filter
if let Some(ref modality) = filter.modality {
let has_modality = m_meta
.modalities
.as_ref()
.is_some_and(|m| m.input.iter().any(|i| i.eq_ignore_ascii_case(modality)));
if !has_modality {
continue;
}
}
// Tool call filter
if let Some(tc) = filter.tool_call {
if m_meta.tool_call.unwrap_or(false) != tc {
continue;
}
}
// Reasoning filter
if let Some(r) = filter.reasoning {
if m_meta.reasoning.unwrap_or(false) != r {
continue;
}
}
// Has cost filter
if let Some(hc) = filter.has_cost {
if hc != m_meta.cost.is_some() {
continue;
}
}
entries.push(ModelEntry {
model_key: m_key,
provider_id: p_id,
provider_name: &p_info.name,
metadata: m_meta,
});
}
}
// Sort
entries.sort_by(|a, b| {
let cmp = match sort_by {
ModelSortBy::Name => a.metadata.name.to_lowercase().cmp(&b.metadata.name.to_lowercase()),
ModelSortBy::Id => a.model_key.cmp(b.model_key),
ModelSortBy::Provider => a.provider_id.cmp(b.provider_id),
ModelSortBy::ContextLimit => {
let a_ctx = a.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0);
let b_ctx = b.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0);
a_ctx.cmp(&b_ctx)
}
ModelSortBy::InputCost => {
let a_cost = a.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
let b_cost = b.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
}
ModelSortBy::OutputCost => {
let a_cost = a.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
let b_cost = b.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
}
};
match sort_order {
SortOrder::Asc => cmp,
SortOrder::Desc => cmp.reverse(),
}
});
entries
}
} }