use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{broadcast, RwLock}; use tracing::warn; use crate::{ client::ClientManager, config::AppConfig, database::DbPool, logging::RequestLogger, models::registry::ModelRegistry, providers::ProviderManager, rate_limiting::RateLimitManager, }; /// Cached model configuration entry #[derive(Debug, Clone)] pub struct CachedModelConfig { pub enabled: bool, pub mapping: Option, pub prompt_cost_per_m: Option, pub completion_cost_per_m: Option, } /// In-memory cache for model_configs table. /// Refreshes periodically to avoid hitting SQLite on every request. #[derive(Clone)] pub struct ModelConfigCache { cache: Arc>>, db_pool: DbPool, } impl ModelConfigCache { pub fn new(db_pool: DbPool) -> Self { Self { cache: Arc::new(RwLock::new(HashMap::new())), db_pool, } } /// Load all model configs from the database into cache pub async fn refresh(&self) { match sqlx::query_as::<_, (String, bool, Option, Option, Option)>( "SELECT id, enabled, mapping, prompt_cost_per_m, completion_cost_per_m FROM model_configs", ) .fetch_all(&self.db_pool) .await { Ok(rows) => { let mut map = HashMap::with_capacity(rows.len()); for (id, enabled, mapping, prompt_cost, completion_cost) in rows { map.insert( id, CachedModelConfig { enabled, mapping, prompt_cost_per_m: prompt_cost, completion_cost_per_m: completion_cost, }, ); } *self.cache.write().await = map; } Err(e) => { warn!("Failed to refresh model config cache: {}", e); } } } /// Get a cached model config. Returns None if not in cache (model is unconfigured). pub async fn get(&self, model: &str) -> Option { self.cache.read().await.get(model).cloned() } /// Invalidate cache — call this after dashboard writes to model_configs pub async fn invalidate(&self) { self.refresh().await; } /// Start a background task that refreshes the cache every `interval` seconds pub fn start_refresh_task(self, interval_secs: u64) { tokio::spawn(async move { let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs)); loop { interval.tick().await; self.refresh().await; } }); } } /// Shared application state #[derive(Clone)] pub struct AppState { pub config: Arc, pub provider_manager: ProviderManager, pub db_pool: DbPool, pub rate_limit_manager: Arc, pub client_manager: Arc, pub request_logger: Arc, pub model_registry: Arc, pub model_config_cache: ModelConfigCache, pub dashboard_tx: broadcast::Sender, pub auth_tokens: Vec, } impl AppState { pub fn new( config: Arc, provider_manager: ProviderManager, db_pool: DbPool, rate_limit_manager: RateLimitManager, model_registry: ModelRegistry, auth_tokens: Vec, ) -> Self { let client_manager = Arc::new(ClientManager::new(db_pool.clone())); let (dashboard_tx, _) = broadcast::channel(100); let request_logger = Arc::new(RequestLogger::new(db_pool.clone(), dashboard_tx.clone())); let model_config_cache = ModelConfigCache::new(db_pool.clone()); Self { config, provider_manager, db_pool, rate_limit_manager: Arc::new(rate_limit_manager), client_manager, request_logger, model_registry: Arc::new(model_registry), model_config_cache, dashboard_tx, auth_tokens, } } }