feat(models): add filtering and sorting to model registry and GET /api/models
Add ModelFilter, ModelSortBy, SortOrder structs and list_models() method to ModelRegistry. The /api/models endpoint now accepts query params: provider, search, modality, tool_call, reasoning, has_cost, sort_by, sort_order. Response also enriched with provider_name, output_limit, modalities, tool_call, and reasoning fields.
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
use axum::{
|
use axum::{
|
||||||
extract::{Path, State},
|
extract::{Path, Query, State},
|
||||||
response::Json,
|
response::Json,
|
||||||
};
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
@@ -8,6 +8,7 @@ use sqlx::Row;
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use super::{ApiResponse, DashboardState};
|
use super::{ApiResponse, DashboardState};
|
||||||
|
use crate::models::registry::{ModelFilter, ModelSortBy, SortOrder};
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
pub(super) struct UpdateModelRequest {
|
pub(super) struct UpdateModelRequest {
|
||||||
@@ -17,10 +18,49 @@ pub(super) struct UpdateModelRequest {
|
|||||||
pub(super) mapping: Option<String>,
|
pub(super) mapping: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) async fn handle_get_models(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
/// Query parameters for `GET /api/models`.
|
||||||
|
#[derive(Debug, Deserialize, Default)]
|
||||||
|
pub(super) struct ModelListParams {
|
||||||
|
/// Filter by provider ID.
|
||||||
|
pub provider: Option<String>,
|
||||||
|
/// Text search on model ID or name.
|
||||||
|
pub search: Option<String>,
|
||||||
|
/// Filter by input modality (e.g. "image").
|
||||||
|
pub modality: Option<String>,
|
||||||
|
/// Only models that support tool calling.
|
||||||
|
pub tool_call: Option<bool>,
|
||||||
|
/// Only models that support reasoning.
|
||||||
|
pub reasoning: Option<bool>,
|
||||||
|
/// Only models that have pricing data.
|
||||||
|
pub has_cost: Option<bool>,
|
||||||
|
/// Sort field (name, id, provider, context_limit, input_cost, output_cost).
|
||||||
|
pub sort_by: Option<ModelSortBy>,
|
||||||
|
/// Sort direction (asc, desc).
|
||||||
|
pub sort_order: Option<SortOrder>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) async fn handle_get_models(
|
||||||
|
State(state): State<DashboardState>,
|
||||||
|
Query(params): Query<ModelListParams>,
|
||||||
|
) -> Json<ApiResponse<serde_json::Value>> {
|
||||||
let registry = &state.app_state.model_registry;
|
let registry = &state.app_state.model_registry;
|
||||||
let pool = &state.app_state.db_pool;
|
let pool = &state.app_state.db_pool;
|
||||||
|
|
||||||
|
// Build filter from query params
|
||||||
|
let filter = ModelFilter {
|
||||||
|
provider: params.provider,
|
||||||
|
search: params.search,
|
||||||
|
modality: params.modality,
|
||||||
|
tool_call: params.tool_call,
|
||||||
|
reasoning: params.reasoning,
|
||||||
|
has_cost: params.has_cost,
|
||||||
|
};
|
||||||
|
let sort_by = params.sort_by.unwrap_or_default();
|
||||||
|
let sort_order = params.sort_order.unwrap_or_default();
|
||||||
|
|
||||||
|
// Get filtered and sorted model entries
|
||||||
|
let entries = registry.list_models(&filter, &sort_by, &sort_order);
|
||||||
|
|
||||||
// Load overrides from database
|
// Load overrides from database
|
||||||
let db_models_result =
|
let db_models_result =
|
||||||
sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
|
sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
|
||||||
@@ -37,35 +77,44 @@ pub(super) async fn handle_get_models(State(state): State<DashboardState>) -> Js
|
|||||||
|
|
||||||
let mut models_json = Vec::new();
|
let mut models_json = Vec::new();
|
||||||
|
|
||||||
for (p_id, p_info) in ®istry.providers {
|
for entry in &entries {
|
||||||
for (m_id, m_meta) in &p_info.models {
|
let m_key = entry.model_key;
|
||||||
let mut enabled = true;
|
let m_meta = entry.metadata;
|
||||||
let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
|
||||||
let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
|
||||||
let mut mapping = None::<String>;
|
|
||||||
|
|
||||||
if let Some(row) = db_models.get(m_id) {
|
let mut enabled = true;
|
||||||
enabled = row.get("enabled");
|
let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
||||||
if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
|
let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
||||||
prompt_cost = p;
|
let mut mapping = None::<String>;
|
||||||
}
|
|
||||||
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
|
if let Some(row) = db_models.get(m_key) {
|
||||||
completion_cost = c;
|
enabled = row.get("enabled");
|
||||||
}
|
if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
|
||||||
mapping = row.get("mapping");
|
prompt_cost = p;
|
||||||
}
|
}
|
||||||
|
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
|
||||||
models_json.push(serde_json::json!({
|
completion_cost = c;
|
||||||
"id": m_id,
|
}
|
||||||
"provider": p_id,
|
mapping = row.get("mapping");
|
||||||
"name": m_meta.name,
|
|
||||||
"enabled": enabled,
|
|
||||||
"prompt_cost": prompt_cost,
|
|
||||||
"completion_cost": completion_cost,
|
|
||||||
"mapping": mapping,
|
|
||||||
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
|
|
||||||
}));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
models_json.push(serde_json::json!({
|
||||||
|
"id": m_key,
|
||||||
|
"provider": entry.provider_id,
|
||||||
|
"provider_name": entry.provider_name,
|
||||||
|
"name": m_meta.name,
|
||||||
|
"enabled": enabled,
|
||||||
|
"prompt_cost": prompt_cost,
|
||||||
|
"completion_cost": completion_cost,
|
||||||
|
"mapping": mapping,
|
||||||
|
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
|
||||||
|
"output_limit": m_meta.limit.as_ref().map(|l| l.output).unwrap_or(0),
|
||||||
|
"modalities": m_meta.modalities.as_ref().map(|m| serde_json::json!({
|
||||||
|
"input": m.input,
|
||||||
|
"output": m.output,
|
||||||
|
})),
|
||||||
|
"tool_call": m_meta.tool_call,
|
||||||
|
"reasoning": m_meta.reasoning,
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
Json(ApiResponse::success(serde_json::json!(models_json)))
|
Json(ApiResponse::success(serde_json::json!(models_json)))
|
||||||
|
|||||||
@@ -45,6 +45,54 @@ pub struct ModelModalities {
|
|||||||
pub output: Vec<String>,
|
pub output: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A model entry paired with its provider ID, returned by listing/filtering methods.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ModelEntry<'a> {
|
||||||
|
pub model_key: &'a str,
|
||||||
|
pub provider_id: &'a str,
|
||||||
|
pub provider_name: &'a str,
|
||||||
|
pub metadata: &'a ModelMetadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Filter criteria for listing models. All fields are optional; `None` means no filter.
|
||||||
|
#[derive(Debug, Default, Clone, Deserialize)]
|
||||||
|
pub struct ModelFilter {
|
||||||
|
/// Filter by provider ID (exact match).
|
||||||
|
pub provider: Option<String>,
|
||||||
|
/// Text search on model ID or name (case-insensitive substring).
|
||||||
|
pub search: Option<String>,
|
||||||
|
/// Filter by input modality (e.g. "image", "text").
|
||||||
|
pub modality: Option<String>,
|
||||||
|
/// Only models that support tool calling.
|
||||||
|
pub tool_call: Option<bool>,
|
||||||
|
/// Only models that support reasoning.
|
||||||
|
pub reasoning: Option<bool>,
|
||||||
|
/// Only models that have pricing data.
|
||||||
|
pub has_cost: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sort field for model listings.
|
||||||
|
#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum ModelSortBy {
|
||||||
|
#[default]
|
||||||
|
Name,
|
||||||
|
Id,
|
||||||
|
Provider,
|
||||||
|
ContextLimit,
|
||||||
|
InputCost,
|
||||||
|
OutputCost,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sort direction.
|
||||||
|
#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
pub enum SortOrder {
|
||||||
|
#[default]
|
||||||
|
Asc,
|
||||||
|
Desc,
|
||||||
|
}
|
||||||
|
|
||||||
impl ModelRegistry {
|
impl ModelRegistry {
|
||||||
/// Find a model by its ID (searching across all providers)
|
/// Find a model by its ID (searching across all providers)
|
||||||
pub fn find_model(&self, model_id: &str) -> Option<&ModelMetadata> {
|
pub fn find_model(&self, model_id: &str) -> Option<&ModelMetadata> {
|
||||||
@@ -66,4 +114,106 @@ impl ModelRegistry {
|
|||||||
|
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// List all models with optional filtering and sorting.
|
||||||
|
pub fn list_models(
|
||||||
|
&self,
|
||||||
|
filter: &ModelFilter,
|
||||||
|
sort_by: &ModelSortBy,
|
||||||
|
sort_order: &SortOrder,
|
||||||
|
) -> Vec<ModelEntry<'_>> {
|
||||||
|
let mut entries: Vec<ModelEntry<'_>> = Vec::new();
|
||||||
|
|
||||||
|
for (p_id, p_info) in &self.providers {
|
||||||
|
// Provider filter
|
||||||
|
if let Some(ref prov) = filter.provider {
|
||||||
|
if p_id != prov {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (m_key, m_meta) in &p_info.models {
|
||||||
|
// Text search filter
|
||||||
|
if let Some(ref search) = filter.search {
|
||||||
|
let search_lower = search.to_lowercase();
|
||||||
|
if !m_meta.id.to_lowercase().contains(&search_lower)
|
||||||
|
&& !m_meta.name.to_lowercase().contains(&search_lower)
|
||||||
|
&& !m_key.to_lowercase().contains(&search_lower)
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modality filter
|
||||||
|
if let Some(ref modality) = filter.modality {
|
||||||
|
let has_modality = m_meta
|
||||||
|
.modalities
|
||||||
|
.as_ref()
|
||||||
|
.is_some_and(|m| m.input.iter().any(|i| i.eq_ignore_ascii_case(modality)));
|
||||||
|
if !has_modality {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tool call filter
|
||||||
|
if let Some(tc) = filter.tool_call {
|
||||||
|
if m_meta.tool_call.unwrap_or(false) != tc {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reasoning filter
|
||||||
|
if let Some(r) = filter.reasoning {
|
||||||
|
if m_meta.reasoning.unwrap_or(false) != r {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Has cost filter
|
||||||
|
if let Some(hc) = filter.has_cost {
|
||||||
|
if hc != m_meta.cost.is_some() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
entries.push(ModelEntry {
|
||||||
|
model_key: m_key,
|
||||||
|
provider_id: p_id,
|
||||||
|
provider_name: &p_info.name,
|
||||||
|
metadata: m_meta,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort
|
||||||
|
entries.sort_by(|a, b| {
|
||||||
|
let cmp = match sort_by {
|
||||||
|
ModelSortBy::Name => a.metadata.name.to_lowercase().cmp(&b.metadata.name.to_lowercase()),
|
||||||
|
ModelSortBy::Id => a.model_key.cmp(b.model_key),
|
||||||
|
ModelSortBy::Provider => a.provider_id.cmp(b.provider_id),
|
||||||
|
ModelSortBy::ContextLimit => {
|
||||||
|
let a_ctx = a.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0);
|
||||||
|
let b_ctx = b.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0);
|
||||||
|
a_ctx.cmp(&b_ctx)
|
||||||
|
}
|
||||||
|
ModelSortBy::InputCost => {
|
||||||
|
let a_cost = a.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
||||||
|
let b_cost = b.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
||||||
|
a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
}
|
||||||
|
ModelSortBy::OutputCost => {
|
||||||
|
let a_cost = a.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
||||||
|
let b_cost = b.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
||||||
|
a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match sort_order {
|
||||||
|
SortOrder::Asc => cmp,
|
||||||
|
SortOrder::Desc => cmp.reverse(),
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
entries
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user