use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// All known providers, keyed by provider ID.
///
/// `#[serde(flatten)]` means the serialized form has no `providers` wrapper:
/// the entries of the top-level map are the providers themselves.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelRegistry {
    #[serde(flatten)]
    pub providers: HashMap<String, ProviderInfo>,
}

/// A single provider and the models it offers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProviderInfo {
    pub id: String,
    /// Provider name; copied into `ModelEntry::provider_name` by listings.
    pub name: String,
    /// Models keyed by model key. The key is not guaranteed to equal
    /// [`ModelMetadata::id`], which is why `ModelRegistry::find_model` checks both.
    pub models: HashMap<String, ModelMetadata>,
}

/// Metadata describing one model. With serde's derive, absent `Option`
/// fields deserialize as `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelMetadata {
    pub id: String,
    pub name: String,
    /// Pricing; `None` when no pricing data is available.
    pub cost: Option<ModelCost>,
    /// Context / output size limits, if known.
    pub limit: Option<ModelLimit>,
    /// Supported input/output modalities, if known.
    pub modalities: Option<ModelModalities>,
    /// Tool-calling support; listing filters treat `None` as `false`.
    pub tool_call: Option<bool>,
    /// Reasoning support; listing filters treat `None` as `false`.
    pub reasoning: Option<bool>,
}

/// Model pricing.
///
/// NOTE(review): the unit (e.g. currency per N tokens) is not established in
/// this file — confirm against the data source.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelCost {
    pub input: f64,
    pub output: f64,
    pub cache_read: Option<f64>,
    pub cache_write: Option<f64>,
}

/// Size limits of a model. NOTE(review): presumably measured in tokens — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelLimit {
    /// Context window size.
    pub context: u32,
    /// Maximum output size.
    pub output: u32,
}

/// Modalities a model accepts and produces (e.g. `"text"`, `"image"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelModalities {
    /// Input modalities; `ModelFilter::modality` matches these ASCII case-insensitively.
    pub input: Vec<String>,
    pub output: Vec<String>,
}

/// A model paired with the provider it belongs to, as returned by
/// `ModelRegistry::list_models`. All fields borrow from the registry.
#[derive(Debug, Clone, Copy)]
pub struct ModelEntry<'a> {
    /// Key of the model in [`ProviderInfo::models`].
    pub model_key: &'a str,
    /// Key of the provider in [`ModelRegistry::providers`].
    pub provider_id: &'a str,
    /// The provider's [`ProviderInfo::name`].
    pub provider_name: &'a str,
    pub metadata: &'a ModelMetadata,
}

/// Filter criteria for `ModelRegistry::list_models`.
///
/// Every field is optional; `None` disables that criterion. All criteria that
/// are set must match for a model to be listed.
#[derive(Debug, Default, Clone, PartialEq, Eq, Deserialize)]
pub struct ModelFilter {
    /// Provider ID (exact, case-sensitive match).
    pub provider: Option<String>,
    /// Case-insensitive substring matched against the model key, model ID, or model name.
    pub search: Option<String>,
    /// Input modality such as `"image"` or `"text"` (ASCII case-insensitive).
    pub modality: Option<String>,
    /// `Some(true)` keeps only models with tool calling, `Some(false)` only
    /// models without it. A missing flag counts as `false`.
    pub tool_call: Option<bool>,
    /// `Some(true)` keeps only reasoning models, `Some(false)` only
    /// non-reasoning models. A missing flag counts as `false`.
    pub reasoning: Option<bool>,
    /// `Some(true)` keeps only models with pricing data, `Some(false)` only
    /// models without it.
    pub has_cost: Option<bool>,
}

/// Sort key for model listings, deserialized from snake_case names
/// (`"name"`, `"id"`, `"provider"`, `"context_limit"`, `"input_cost"`, `"output_cost"`).
#[derive(Debug, Clone, Copy, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ModelSortBy {
    /// Model name, compared case-insensitively. This is the default.
    #[default]
    Name,
    /// Model key within its provider's `models` map.
    Id,
    /// Provider ID.
    Provider,
    /// Context limit; models without a limit rank as 0.
    ContextLimit,
    /// Input price; models without pricing rank as 0.
    InputCost,
    /// Output price; models without pricing rank as 0.
    OutputCost,
}

/// Sort direction.
#[derive(Debug, Clone, Deserialize, Default, PartialEq)] #[serde(rename_all = "snake_case")] pub enum SortOrder { #[default] Asc, Desc, } impl ModelRegistry { /// Find a model by its ID (searching across all providers) pub fn find_model(&self, model_id: &str) -> Option<&ModelMetadata> { // First try exact match if the key in models map matches the ID for provider in self.providers.values() { if let Some(model) = provider.models.get(model_id) { return Some(model); } } // Try searching for the model ID inside the metadata if the key was different for provider in self.providers.values() { for model in provider.models.values() { if model.id == model_id { return Some(model); } } } None } /// List all models with optional filtering and sorting. pub fn list_models( &self, filter: &ModelFilter, sort_by: &ModelSortBy, sort_order: &SortOrder, ) -> Vec> { let mut entries: Vec> = Vec::new(); for (p_id, p_info) in &self.providers { // Provider filter if let Some(ref prov) = filter.provider { if p_id != prov { continue; } } for (m_key, m_meta) in &p_info.models { // Text search filter if let Some(ref search) = filter.search { let search_lower = search.to_lowercase(); if !m_meta.id.to_lowercase().contains(&search_lower) && !m_meta.name.to_lowercase().contains(&search_lower) && !m_key.to_lowercase().contains(&search_lower) { continue; } } // Modality filter if let Some(ref modality) = filter.modality { let has_modality = m_meta .modalities .as_ref() .is_some_and(|m| m.input.iter().any(|i| i.eq_ignore_ascii_case(modality))); if !has_modality { continue; } } // Tool call filter if let Some(tc) = filter.tool_call { if m_meta.tool_call.unwrap_or(false) != tc { continue; } } // Reasoning filter if let Some(r) = filter.reasoning { if m_meta.reasoning.unwrap_or(false) != r { continue; } } // Has cost filter if let Some(hc) = filter.has_cost { if hc != m_meta.cost.is_some() { continue; } } entries.push(ModelEntry { model_key: m_key, provider_id: p_id, provider_name: &p_info.name, 
metadata: m_meta, }); } } // Sort entries.sort_by(|a, b| { let cmp = match sort_by { ModelSortBy::Name => a.metadata.name.to_lowercase().cmp(&b.metadata.name.to_lowercase()), ModelSortBy::Id => a.model_key.cmp(b.model_key), ModelSortBy::Provider => a.provider_id.cmp(b.provider_id), ModelSortBy::ContextLimit => { let a_ctx = a.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0); let b_ctx = b.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0); a_ctx.cmp(&b_ctx) } ModelSortBy::InputCost => { let a_cost = a.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0); let b_cost = b.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0); a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal) } ModelSortBy::OutputCost => { let a_cost = a.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0); let b_cost = b.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0); a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal) } }; match sort_order { SortOrder::Asc => cmp, SortOrder::Desc => cmp.reverse(), } }); entries } }