Add the ModelFilter struct, the ModelSortBy and SortOrder enums, and a list_models() method to ModelRegistry. The /api/models endpoint now accepts the query params provider, search, modality, tool_call, reasoning, has_cost, sort_by, and sort_order. Responses are also enriched with the provider_name, output_limit, modalities, tool_call, and reasoning fields.
220 lines
6.9 KiB
Rust
220 lines
6.9 KiB
Rust
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ModelRegistry {
|
|
#[serde(flatten)]
|
|
pub providers: HashMap<String, ProviderInfo>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ProviderInfo {
|
|
pub id: String,
|
|
pub name: String,
|
|
pub models: HashMap<String, ModelMetadata>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ModelMetadata {
|
|
pub id: String,
|
|
pub name: String,
|
|
pub cost: Option<ModelCost>,
|
|
pub limit: Option<ModelLimit>,
|
|
pub modalities: Option<ModelModalities>,
|
|
pub tool_call: Option<bool>,
|
|
pub reasoning: Option<bool>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ModelCost {
|
|
pub input: f64,
|
|
pub output: f64,
|
|
pub cache_read: Option<f64>,
|
|
pub cache_write: Option<f64>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ModelLimit {
|
|
pub context: u32,
|
|
pub output: u32,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ModelModalities {
|
|
pub input: Vec<String>,
|
|
pub output: Vec<String>,
|
|
}
|
|
|
|
/// A model entry paired with its provider ID, returned by listing/filtering methods.
|
|
#[derive(Debug, Clone)]
|
|
pub struct ModelEntry<'a> {
|
|
pub model_key: &'a str,
|
|
pub provider_id: &'a str,
|
|
pub provider_name: &'a str,
|
|
pub metadata: &'a ModelMetadata,
|
|
}
|
|
|
|
/// Filter criteria for listing models. All fields are optional; `None` means no filter.
|
|
#[derive(Debug, Default, Clone, Deserialize)]
|
|
pub struct ModelFilter {
|
|
/// Filter by provider ID (exact match).
|
|
pub provider: Option<String>,
|
|
/// Text search on model ID or name (case-insensitive substring).
|
|
pub search: Option<String>,
|
|
/// Filter by input modality (e.g. "image", "text").
|
|
pub modality: Option<String>,
|
|
/// Only models that support tool calling.
|
|
pub tool_call: Option<bool>,
|
|
/// Only models that support reasoning.
|
|
pub reasoning: Option<bool>,
|
|
/// Only models that have pricing data.
|
|
pub has_cost: Option<bool>,
|
|
}
|
|
|
|
/// Sort field for model listings.
|
|
#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
|
|
#[serde(rename_all = "snake_case")]
|
|
pub enum ModelSortBy {
|
|
#[default]
|
|
Name,
|
|
Id,
|
|
Provider,
|
|
ContextLimit,
|
|
InputCost,
|
|
OutputCost,
|
|
}
|
|
|
|
/// Sort direction.
|
|
#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
|
|
#[serde(rename_all = "snake_case")]
|
|
pub enum SortOrder {
|
|
#[default]
|
|
Asc,
|
|
Desc,
|
|
}
|
|
|
|
impl ModelRegistry {
|
|
/// Find a model by its ID (searching across all providers)
|
|
pub fn find_model(&self, model_id: &str) -> Option<&ModelMetadata> {
|
|
// First try exact match if the key in models map matches the ID
|
|
for provider in self.providers.values() {
|
|
if let Some(model) = provider.models.get(model_id) {
|
|
return Some(model);
|
|
}
|
|
}
|
|
|
|
// Try searching for the model ID inside the metadata if the key was different
|
|
for provider in self.providers.values() {
|
|
for model in provider.models.values() {
|
|
if model.id == model_id {
|
|
return Some(model);
|
|
}
|
|
}
|
|
}
|
|
|
|
None
|
|
}
|
|
|
|
/// List all models with optional filtering and sorting.
|
|
pub fn list_models(
|
|
&self,
|
|
filter: &ModelFilter,
|
|
sort_by: &ModelSortBy,
|
|
sort_order: &SortOrder,
|
|
) -> Vec<ModelEntry<'_>> {
|
|
let mut entries: Vec<ModelEntry<'_>> = Vec::new();
|
|
|
|
for (p_id, p_info) in &self.providers {
|
|
// Provider filter
|
|
if let Some(ref prov) = filter.provider {
|
|
if p_id != prov {
|
|
continue;
|
|
}
|
|
}
|
|
|
|
for (m_key, m_meta) in &p_info.models {
|
|
// Text search filter
|
|
if let Some(ref search) = filter.search {
|
|
let search_lower = search.to_lowercase();
|
|
if !m_meta.id.to_lowercase().contains(&search_lower)
|
|
&& !m_meta.name.to_lowercase().contains(&search_lower)
|
|
&& !m_key.to_lowercase().contains(&search_lower)
|
|
{
|
|
continue;
|
|
}
|
|
}
|
|
|
|
// Modality filter
|
|
if let Some(ref modality) = filter.modality {
|
|
let has_modality = m_meta
|
|
.modalities
|
|
.as_ref()
|
|
.is_some_and(|m| m.input.iter().any(|i| i.eq_ignore_ascii_case(modality)));
|
|
if !has_modality {
|
|
continue;
|
|
}
|
|
}
|
|
|
|
// Tool call filter
|
|
if let Some(tc) = filter.tool_call {
|
|
if m_meta.tool_call.unwrap_or(false) != tc {
|
|
continue;
|
|
}
|
|
}
|
|
|
|
// Reasoning filter
|
|
if let Some(r) = filter.reasoning {
|
|
if m_meta.reasoning.unwrap_or(false) != r {
|
|
continue;
|
|
}
|
|
}
|
|
|
|
// Has cost filter
|
|
if let Some(hc) = filter.has_cost {
|
|
if hc != m_meta.cost.is_some() {
|
|
continue;
|
|
}
|
|
}
|
|
|
|
entries.push(ModelEntry {
|
|
model_key: m_key,
|
|
provider_id: p_id,
|
|
provider_name: &p_info.name,
|
|
metadata: m_meta,
|
|
});
|
|
}
|
|
}
|
|
|
|
// Sort
|
|
entries.sort_by(|a, b| {
|
|
let cmp = match sort_by {
|
|
ModelSortBy::Name => a.metadata.name.to_lowercase().cmp(&b.metadata.name.to_lowercase()),
|
|
ModelSortBy::Id => a.model_key.cmp(b.model_key),
|
|
ModelSortBy::Provider => a.provider_id.cmp(b.provider_id),
|
|
ModelSortBy::ContextLimit => {
|
|
let a_ctx = a.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0);
|
|
let b_ctx = b.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0);
|
|
a_ctx.cmp(&b_ctx)
|
|
}
|
|
ModelSortBy::InputCost => {
|
|
let a_cost = a.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
|
let b_cost = b.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
|
a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
|
|
}
|
|
ModelSortBy::OutputCost => {
|
|
let a_cost = a.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
|
let b_cost = b.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
|
a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
|
|
}
|
|
};
|
|
|
|
match sort_order {
|
|
SortOrder::Asc => cmp,
|
|
SortOrder::Desc => cmp.reverse(),
|
|
}
|
|
});
|
|
|
|
entries
|
|
}
|
|
}
|