feat(models): add filtering and sorting to model registry and GET /api/models
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled

Add ModelFilter, ModelSortBy, SortOrder structs and list_models() method
to ModelRegistry. The /api/models endpoint now accepts query params:
provider, search, modality, tool_call, reasoning, has_cost, sort_by,
sort_order. Response also enriched with provider_name, output_limit,
modalities, tool_call, and reasoning fields.
This commit is contained in:
2026-03-02 08:51:33 -05:00
parent 2aad813ccd
commit 942aa23f88
2 changed files with 227 additions and 28 deletions

View File

@@ -1,5 +1,5 @@
use axum::{
extract::{Path, State},
extract::{Path, Query, State},
response::Json,
};
use serde::Deserialize;
@@ -8,6 +8,7 @@ use sqlx::Row;
use std::collections::HashMap;
use super::{ApiResponse, DashboardState};
use crate::models::registry::{ModelFilter, ModelSortBy, SortOrder};
#[derive(Deserialize)]
pub(super) struct UpdateModelRequest {
@@ -17,10 +18,49 @@ pub(super) struct UpdateModelRequest {
pub(super) mapping: Option<String>,
}
pub(super) async fn handle_get_models(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
/// Query parameters for `GET /api/models`.
#[derive(Debug, Deserialize, Default)]
pub(super) struct ModelListParams {
/// Filter by provider ID.
pub provider: Option<String>,
/// Text search on model ID or name.
pub search: Option<String>,
/// Filter by input modality (e.g. "image").
pub modality: Option<String>,
/// Only models that support tool calling.
pub tool_call: Option<bool>,
/// Only models that support reasoning.
pub reasoning: Option<bool>,
/// Only models that have pricing data.
pub has_cost: Option<bool>,
/// Sort field (name, id, provider, context_limit, input_cost, output_cost).
pub sort_by: Option<ModelSortBy>,
/// Sort direction (asc, desc).
pub sort_order: Option<SortOrder>,
}
pub(super) async fn handle_get_models(
State(state): State<DashboardState>,
Query(params): Query<ModelListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
let registry = &state.app_state.model_registry;
let pool = &state.app_state.db_pool;
// Build filter from query params
let filter = ModelFilter {
provider: params.provider,
search: params.search,
modality: params.modality,
tool_call: params.tool_call,
reasoning: params.reasoning,
has_cost: params.has_cost,
};
let sort_by = params.sort_by.unwrap_or_default();
let sort_order = params.sort_order.unwrap_or_default();
// Get filtered and sorted model entries
let entries = registry.list_models(&filter, &sort_by, &sort_order);
// Load overrides from database
let db_models_result =
sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
@@ -37,35 +77,44 @@ pub(super) async fn handle_get_models(State(state): State<DashboardState>) -> Js
let mut models_json = Vec::new();
for (p_id, p_info) in &registry.providers {
for (m_id, m_meta) in &p_info.models {
let mut enabled = true;
let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
let mut mapping = None::<String>;
for entry in &entries {
let m_key = entry.model_key;
let m_meta = entry.metadata;
if let Some(row) = db_models.get(m_id) {
enabled = row.get("enabled");
if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
prompt_cost = p;
}
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
completion_cost = c;
}
mapping = row.get("mapping");
let mut enabled = true;
let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
let mut mapping = None::<String>;
if let Some(row) = db_models.get(m_key) {
enabled = row.get("enabled");
if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
prompt_cost = p;
}
models_json.push(serde_json::json!({
"id": m_id,
"provider": p_id,
"name": m_meta.name,
"enabled": enabled,
"prompt_cost": prompt_cost,
"completion_cost": completion_cost,
"mapping": mapping,
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
}));
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
completion_cost = c;
}
mapping = row.get("mapping");
}
models_json.push(serde_json::json!({
"id": m_key,
"provider": entry.provider_id,
"provider_name": entry.provider_name,
"name": m_meta.name,
"enabled": enabled,
"prompt_cost": prompt_cost,
"completion_cost": completion_cost,
"mapping": mapping,
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
"output_limit": m_meta.limit.as_ref().map(|l| l.output).unwrap_or(0),
"modalities": m_meta.modalities.as_ref().map(|m| serde_json::json!({
"input": m.input,
"output": m.output,
})),
"tool_call": m_meta.tool_call,
"reasoning": m_meta.reasoning,
}));
}
Json(ApiResponse::success(serde_json::json!(models_json)))