feat(models): add filtering and sorting to model registry and GET /api/models
Add ModelFilter, ModelSortBy, SortOrder structs and list_models() method to ModelRegistry. The /api/models endpoint now accepts query params: provider, search, modality, tool_call, reasoning, has_cost, sort_by, sort_order. Response also enriched with provider_name, output_limit, modalities, tool_call, and reasoning fields.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
extract::{Path, Query, State},
|
||||
response::Json,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
@@ -8,6 +8,7 @@ use sqlx::Row;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
use crate::models::registry::{ModelFilter, ModelSortBy, SortOrder};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct UpdateModelRequest {
|
||||
@@ -17,10 +18,49 @@ pub(super) struct UpdateModelRequest {
|
||||
pub(super) mapping: Option<String>,
|
||||
}
|
||||
|
||||
pub(super) async fn handle_get_models(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
/// Query parameters for `GET /api/models`.
|
||||
#[derive(Debug, Deserialize, Default)]
|
||||
pub(super) struct ModelListParams {
|
||||
/// Filter by provider ID.
|
||||
pub provider: Option<String>,
|
||||
/// Text search on model ID or name.
|
||||
pub search: Option<String>,
|
||||
/// Filter by input modality (e.g. "image").
|
||||
pub modality: Option<String>,
|
||||
/// Only models that support tool calling.
|
||||
pub tool_call: Option<bool>,
|
||||
/// Only models that support reasoning.
|
||||
pub reasoning: Option<bool>,
|
||||
/// Only models that have pricing data.
|
||||
pub has_cost: Option<bool>,
|
||||
/// Sort field (name, id, provider, context_limit, input_cost, output_cost).
|
||||
pub sort_by: Option<ModelSortBy>,
|
||||
/// Sort direction (asc, desc).
|
||||
pub sort_order: Option<SortOrder>,
|
||||
}
|
||||
|
||||
pub(super) async fn handle_get_models(
|
||||
State(state): State<DashboardState>,
|
||||
Query(params): Query<ModelListParams>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let registry = &state.app_state.model_registry;
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Build filter from query params
|
||||
let filter = ModelFilter {
|
||||
provider: params.provider,
|
||||
search: params.search,
|
||||
modality: params.modality,
|
||||
tool_call: params.tool_call,
|
||||
reasoning: params.reasoning,
|
||||
has_cost: params.has_cost,
|
||||
};
|
||||
let sort_by = params.sort_by.unwrap_or_default();
|
||||
let sort_order = params.sort_order.unwrap_or_default();
|
||||
|
||||
// Get filtered and sorted model entries
|
||||
let entries = registry.list_models(&filter, &sort_by, &sort_order);
|
||||
|
||||
// Load overrides from database
|
||||
let db_models_result =
|
||||
sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
|
||||
@@ -37,35 +77,44 @@ pub(super) async fn handle_get_models(State(state): State<DashboardState>) -> Js
|
||||
|
||||
let mut models_json = Vec::new();
|
||||
|
||||
for (p_id, p_info) in ®istry.providers {
|
||||
for (m_id, m_meta) in &p_info.models {
|
||||
let mut enabled = true;
|
||||
let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
||||
let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
||||
let mut mapping = None::<String>;
|
||||
for entry in &entries {
|
||||
let m_key = entry.model_key;
|
||||
let m_meta = entry.metadata;
|
||||
|
||||
if let Some(row) = db_models.get(m_id) {
|
||||
enabled = row.get("enabled");
|
||||
if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
|
||||
prompt_cost = p;
|
||||
}
|
||||
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
|
||||
completion_cost = c;
|
||||
}
|
||||
mapping = row.get("mapping");
|
||||
let mut enabled = true;
|
||||
let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
||||
let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
||||
let mut mapping = None::<String>;
|
||||
|
||||
if let Some(row) = db_models.get(m_key) {
|
||||
enabled = row.get("enabled");
|
||||
if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
|
||||
prompt_cost = p;
|
||||
}
|
||||
|
||||
models_json.push(serde_json::json!({
|
||||
"id": m_id,
|
||||
"provider": p_id,
|
||||
"name": m_meta.name,
|
||||
"enabled": enabled,
|
||||
"prompt_cost": prompt_cost,
|
||||
"completion_cost": completion_cost,
|
||||
"mapping": mapping,
|
||||
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
|
||||
}));
|
||||
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
|
||||
completion_cost = c;
|
||||
}
|
||||
mapping = row.get("mapping");
|
||||
}
|
||||
|
||||
models_json.push(serde_json::json!({
|
||||
"id": m_key,
|
||||
"provider": entry.provider_id,
|
||||
"provider_name": entry.provider_name,
|
||||
"name": m_meta.name,
|
||||
"enabled": enabled,
|
||||
"prompt_cost": prompt_cost,
|
||||
"completion_cost": completion_cost,
|
||||
"mapping": mapping,
|
||||
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
|
||||
"output_limit": m_meta.limit.as_ref().map(|l| l.output).unwrap_or(0),
|
||||
"modalities": m_meta.modalities.as_ref().map(|m| serde_json::json!({
|
||||
"input": m.input,
|
||||
"output": m.output,
|
||||
})),
|
||||
"tool_call": m_meta.tool_call,
|
||||
"reasoning": m_meta.reasoning,
|
||||
}));
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(models_json)))
|
||||
|
||||
@@ -45,6 +45,54 @@ pub struct ModelModalities {
|
||||
pub output: Vec<String>,
|
||||
}
|
||||
|
||||
/// A model entry paired with its provider ID, returned by listing/filtering methods.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelEntry<'a> {
|
||||
pub model_key: &'a str,
|
||||
pub provider_id: &'a str,
|
||||
pub provider_name: &'a str,
|
||||
pub metadata: &'a ModelMetadata,
|
||||
}
|
||||
|
||||
/// Filter criteria for listing models. All fields are optional; `None` means no filter.
|
||||
#[derive(Debug, Default, Clone, Deserialize)]
|
||||
pub struct ModelFilter {
|
||||
/// Filter by provider ID (exact match).
|
||||
pub provider: Option<String>,
|
||||
/// Text search on model ID or name (case-insensitive substring).
|
||||
pub search: Option<String>,
|
||||
/// Filter by input modality (e.g. "image", "text").
|
||||
pub modality: Option<String>,
|
||||
/// Only models that support tool calling.
|
||||
pub tool_call: Option<bool>,
|
||||
/// Only models that support reasoning.
|
||||
pub reasoning: Option<bool>,
|
||||
/// Only models that have pricing data.
|
||||
pub has_cost: Option<bool>,
|
||||
}
|
||||
|
||||
/// Sort field for model listings.
|
||||
#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ModelSortBy {
|
||||
#[default]
|
||||
Name,
|
||||
Id,
|
||||
Provider,
|
||||
ContextLimit,
|
||||
InputCost,
|
||||
OutputCost,
|
||||
}
|
||||
|
||||
/// Sort direction.
|
||||
#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SortOrder {
|
||||
#[default]
|
||||
Asc,
|
||||
Desc,
|
||||
}
|
||||
|
||||
impl ModelRegistry {
|
||||
/// Find a model by its ID (searching across all providers)
|
||||
pub fn find_model(&self, model_id: &str) -> Option<&ModelMetadata> {
|
||||
@@ -66,4 +114,106 @@ impl ModelRegistry {
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// List all models with optional filtering and sorting.
|
||||
pub fn list_models(
|
||||
&self,
|
||||
filter: &ModelFilter,
|
||||
sort_by: &ModelSortBy,
|
||||
sort_order: &SortOrder,
|
||||
) -> Vec<ModelEntry<'_>> {
|
||||
let mut entries: Vec<ModelEntry<'_>> = Vec::new();
|
||||
|
||||
for (p_id, p_info) in &self.providers {
|
||||
// Provider filter
|
||||
if let Some(ref prov) = filter.provider {
|
||||
if p_id != prov {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
for (m_key, m_meta) in &p_info.models {
|
||||
// Text search filter
|
||||
if let Some(ref search) = filter.search {
|
||||
let search_lower = search.to_lowercase();
|
||||
if !m_meta.id.to_lowercase().contains(&search_lower)
|
||||
&& !m_meta.name.to_lowercase().contains(&search_lower)
|
||||
&& !m_key.to_lowercase().contains(&search_lower)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Modality filter
|
||||
if let Some(ref modality) = filter.modality {
|
||||
let has_modality = m_meta
|
||||
.modalities
|
||||
.as_ref()
|
||||
.is_some_and(|m| m.input.iter().any(|i| i.eq_ignore_ascii_case(modality)));
|
||||
if !has_modality {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Tool call filter
|
||||
if let Some(tc) = filter.tool_call {
|
||||
if m_meta.tool_call.unwrap_or(false) != tc {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Reasoning filter
|
||||
if let Some(r) = filter.reasoning {
|
||||
if m_meta.reasoning.unwrap_or(false) != r {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Has cost filter
|
||||
if let Some(hc) = filter.has_cost {
|
||||
if hc != m_meta.cost.is_some() {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
entries.push(ModelEntry {
|
||||
model_key: m_key,
|
||||
provider_id: p_id,
|
||||
provider_name: &p_info.name,
|
||||
metadata: m_meta,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Sort
|
||||
entries.sort_by(|a, b| {
|
||||
let cmp = match sort_by {
|
||||
ModelSortBy::Name => a.metadata.name.to_lowercase().cmp(&b.metadata.name.to_lowercase()),
|
||||
ModelSortBy::Id => a.model_key.cmp(b.model_key),
|
||||
ModelSortBy::Provider => a.provider_id.cmp(b.provider_id),
|
||||
ModelSortBy::ContextLimit => {
|
||||
let a_ctx = a.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0);
|
||||
let b_ctx = b.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0);
|
||||
a_ctx.cmp(&b_ctx)
|
||||
}
|
||||
ModelSortBy::InputCost => {
|
||||
let a_cost = a.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
||||
let b_cost = b.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
||||
a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
|
||||
}
|
||||
ModelSortBy::OutputCost => {
|
||||
let a_cost = a.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
||||
let b_cost = b.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
||||
a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
|
||||
}
|
||||
};
|
||||
|
||||
match sort_order {
|
||||
SortOrder::Asc => cmp,
|
||||
SortOrder::Desc => cmp.reverse(),
|
||||
}
|
||||
});
|
||||
|
||||
entries
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user