- Add an in-memory ModelConfigCache (30s refresh, explicit invalidation) replacing 2 SQLite queries per request (model lookup + cost override)
- Configure all 5 provider HTTP clients with proper timeouts (300s), connection pooling (4 idle/host, 90s idle timeout), and TCP keepalive
- Move the client_usage update to tokio::spawn in the non-streaming path
- Use a fast chars/4 heuristic for token estimation on large inputs (>1KB)
- Generate a single UUID/timestamp per SSE stream instead of per chunk
- Add a shared LazyLock<Client> for image fetching in the multimodal module
- Add proxy overhead timing instrumentation for both request paths
- Fix the test helper to include the new model_config_cache field
130 lines
4.2 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use tokio::sync::{broadcast, RwLock};
|
|
use tracing::warn;
|
|
|
|
use crate::{
|
|
client::ClientManager, config::AppConfig, database::DbPool, logging::RequestLogger,
|
|
models::registry::ModelRegistry, providers::ProviderManager, rate_limiting::RateLimitManager,
|
|
};
|
|
|
|
/// One cached row of the `model_configs` table.
///
/// Each `Option` field mirrors a nullable column: `None` means the column was
/// NULL for this model (no mapping / no cost override configured).
#[derive(Debug, Clone, PartialEq)]
pub struct CachedModelConfig {
    /// Value of the `enabled` column.
    pub enabled: bool,
    /// Value of the `mapping` column.
    /// NOTE(review): presumably the upstream model id to route to — confirm.
    pub mapping: Option<String>,
    /// Prompt cost override from the `prompt_cost_per_m` column
    /// (presumably per million tokens), if set.
    pub prompt_cost_per_m: Option<f64>,
    /// Completion cost override from the `completion_cost_per_m` column
    /// (presumably per million tokens), if set.
    pub completion_cost_per_m: Option<f64>,
}
|
|
|
|
/// In-memory cache for model_configs table.
|
|
/// Refreshes periodically to avoid hitting SQLite on every request.
|
|
#[derive(Clone)]
|
|
pub struct ModelConfigCache {
|
|
cache: Arc<RwLock<HashMap<String, CachedModelConfig>>>,
|
|
db_pool: DbPool,
|
|
}
|
|
|
|
impl ModelConfigCache {
|
|
pub fn new(db_pool: DbPool) -> Self {
|
|
Self {
|
|
cache: Arc::new(RwLock::new(HashMap::new())),
|
|
db_pool,
|
|
}
|
|
}
|
|
|
|
/// Load all model configs from the database into cache
|
|
pub async fn refresh(&self) {
|
|
match sqlx::query_as::<_, (String, bool, Option<String>, Option<f64>, Option<f64>)>(
|
|
"SELECT id, enabled, mapping, prompt_cost_per_m, completion_cost_per_m FROM model_configs",
|
|
)
|
|
.fetch_all(&self.db_pool)
|
|
.await
|
|
{
|
|
Ok(rows) => {
|
|
let mut map = HashMap::with_capacity(rows.len());
|
|
for (id, enabled, mapping, prompt_cost, completion_cost) in rows {
|
|
map.insert(
|
|
id,
|
|
CachedModelConfig {
|
|
enabled,
|
|
mapping,
|
|
prompt_cost_per_m: prompt_cost,
|
|
completion_cost_per_m: completion_cost,
|
|
},
|
|
);
|
|
}
|
|
*self.cache.write().await = map;
|
|
}
|
|
Err(e) => {
|
|
warn!("Failed to refresh model config cache: {}", e);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Get a cached model config. Returns None if not in cache (model is unconfigured).
|
|
pub async fn get(&self, model: &str) -> Option<CachedModelConfig> {
|
|
self.cache.read().await.get(model).cloned()
|
|
}
|
|
|
|
/// Invalidate cache — call this after dashboard writes to model_configs
|
|
pub async fn invalidate(&self) {
|
|
self.refresh().await;
|
|
}
|
|
|
|
/// Start a background task that refreshes the cache every `interval` seconds
|
|
pub fn start_refresh_task(self, interval_secs: u64) {
|
|
tokio::spawn(async move {
|
|
let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs));
|
|
loop {
|
|
interval.tick().await;
|
|
self.refresh().await;
|
|
}
|
|
});
|
|
}
|
|
}
|
|
|
|
/// Shared application state
|
|
#[derive(Clone)]
|
|
pub struct AppState {
|
|
pub config: Arc<AppConfig>,
|
|
pub provider_manager: ProviderManager,
|
|
pub db_pool: DbPool,
|
|
pub rate_limit_manager: Arc<RateLimitManager>,
|
|
pub client_manager: Arc<ClientManager>,
|
|
pub request_logger: Arc<RequestLogger>,
|
|
pub model_registry: Arc<ModelRegistry>,
|
|
pub model_config_cache: ModelConfigCache,
|
|
pub dashboard_tx: broadcast::Sender<serde_json::Value>,
|
|
pub auth_tokens: Vec<String>,
|
|
}
|
|
|
|
impl AppState {
|
|
pub fn new(
|
|
config: Arc<AppConfig>,
|
|
provider_manager: ProviderManager,
|
|
db_pool: DbPool,
|
|
rate_limit_manager: RateLimitManager,
|
|
model_registry: ModelRegistry,
|
|
auth_tokens: Vec<String>,
|
|
) -> Self {
|
|
let client_manager = Arc::new(ClientManager::new(db_pool.clone()));
|
|
let (dashboard_tx, _) = broadcast::channel(100);
|
|
let request_logger = Arc::new(RequestLogger::new(db_pool.clone(), dashboard_tx.clone()));
|
|
let model_config_cache = ModelConfigCache::new(db_pool.clone());
|
|
|
|
Self {
|
|
config,
|
|
provider_manager,
|
|
db_pool,
|
|
rate_limit_manager: Arc::new(rate_limit_manager),
|
|
client_manager,
|
|
request_logger,
|
|
model_registry: Arc::new(model_registry),
|
|
model_config_cache,
|
|
dashboard_tx,
|
|
auth_tokens,
|
|
}
|
|
}
|
|
}
|