diff --git a/src/dashboard/models.rs b/src/dashboard/models.rs
index 93f8e499..b75c3944 100644
--- a/src/dashboard/models.rs
+++ b/src/dashboard/models.rs
@@ -159,7 +159,11 @@ pub(super) async fn handle_update_model(
         .await;

     match result {
-        Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" }))),
+        Ok(_) => {
+            // Invalidate the in-memory cache so the proxy picks up the change immediately
+            state.app_state.model_config_cache.invalidate().await;
+            Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" })))
+        }
         Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))),
     }
 }
diff --git a/src/lib.rs b/src/lib.rs
index 6a7645f5..a2c6aca9 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -135,6 +135,7 @@ pub mod test_utils {
             client_manager,
             request_logger: Arc::new(crate::logging::RequestLogger::new(pool.clone(), dashboard_tx.clone())),
             model_registry: Arc::new(model_registry),
+            model_config_cache: crate::state::ModelConfigCache::new(pool.clone()),
             dashboard_tx,
             auth_tokens: vec![],
         })
diff --git a/src/main.rs b/src/main.rs
index 5aefdc18..992b3a5a 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -65,6 +65,11 @@ async fn main() -> Result<()> {
         config.server.auth_tokens.clone(),
     );

+    // Initialize model config cache and start background refresh (every 30s)
+    state.model_config_cache.refresh().await;
+    state.model_config_cache.clone().start_refresh_task(30);
+    info!("Model config cache initialized");
+
     // Create application router
     let app = Router::new()
         .route("/health", get(health_check))
diff --git a/src/multimodal/mod.rs b/src/multimodal/mod.rs
index cefedf57..94fc09fc 100644
--- a/src/multimodal/mod.rs
+++ b/src/multimodal/mod.rs
@@ -8,8 +8,20 @@
 use anyhow::{Context, Result};
 use base64::{Engine as _, engine::general_purpose};
+use std::sync::LazyLock;
 use tracing::{info, warn};

+/// Shared HTTP client for image fetching — avoids creating a new TCP+TLS
+/// connection for every image URL.
+static IMAGE_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
+    reqwest::Client::builder()
+        .connect_timeout(std::time::Duration::from_secs(5))
+        .timeout(std::time::Duration::from_secs(30))
+        .pool_idle_timeout(std::time::Duration::from_secs(60))
+        .build()
+        .expect("Failed to build image HTTP client")
+});
+
 /// Supported image formats for multimodal input
 #[derive(Debug, Clone)]
 pub enum ImageInput {
@@ -55,9 +67,13 @@ impl ImageInput {
                 Ok((base64_data, mime_type.clone()))
             }
             Self::Url(url) => {
-                // Fetch image from URL
+                // Fetch image from URL using shared client
                 info!("Fetching image from URL: {}", url);
-                let response = reqwest::get(url).await.context("Failed to fetch image from URL")?;
+                let response = IMAGE_CLIENT
+                    .get(url)
+                    .send()
+                    .await
+                    .context("Failed to fetch image from URL")?;

                 if !response.status().is_success() {
                     anyhow::bail!("Failed to fetch image: HTTP {}", response.status());
diff --git a/src/providers/deepseek.rs b/src/providers/deepseek.rs
index 21ca26a6..905452c7 100644
--- a/src/providers/deepseek.rs
+++ b/src/providers/deepseek.rs
@@ -24,8 +24,16 @@ impl DeepSeekProvider {
         app_config: &AppConfig,
         api_key: String,
     ) -> Result<Self> {
+        let client = reqwest::Client::builder()
+            .connect_timeout(std::time::Duration::from_secs(5))
+            .timeout(std::time::Duration::from_secs(300))
+            .pool_idle_timeout(std::time::Duration::from_secs(90))
+            .pool_max_idle_per_host(4)
+            .tcp_keepalive(std::time::Duration::from_secs(30))
+            .build()?;
+
         Ok(Self {
-            client: reqwest::Client::new(),
+            client,
             config: config.clone(),
             api_key,
             pricing: app_config.pricing.deepseek.clone(),
diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs
index 7939d078..53ad336a 100644
--- a/src/providers/gemini.rs
+++ b/src/providers/gemini.rs
@@ -145,7 +145,11 @@ impl GeminiProvider {
     pub fn new_with_key(config: &crate::config::GeminiConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
         let client = reqwest::Client::builder()
-            .timeout(std::time::Duration::from_secs(30))
+            .connect_timeout(std::time::Duration::from_secs(5))
+            .timeout(std::time::Duration::from_secs(300))
+            .pool_idle_timeout(std::time::Duration::from_secs(90))
+            .pool_max_idle_per_host(4)
+            .tcp_keepalive(std::time::Duration::from_secs(30))
             .build()?;

         Ok(Self {
diff --git a/src/providers/grok.rs b/src/providers/grok.rs
index 2c81e77b..858c83b9 100644
--- a/src/providers/grok.rs
+++ b/src/providers/grok.rs
@@ -20,8 +20,16 @@ impl GrokProvider {
     }

     pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
+        let client = reqwest::Client::builder()
+            .connect_timeout(std::time::Duration::from_secs(5))
+            .timeout(std::time::Duration::from_secs(300))
+            .pool_idle_timeout(std::time::Duration::from_secs(90))
+            .pool_max_idle_per_host(4)
+            .tcp_keepalive(std::time::Duration::from_secs(30))
+            .build()?;
+
         Ok(Self {
-            client: reqwest::Client::new(),
+            client,
             config: config.clone(),
             api_key,
             pricing: app_config.pricing.grok.clone(),
diff --git a/src/providers/ollama.rs b/src/providers/ollama.rs
index 850c6330..8f5fed90 100644
--- a/src/providers/ollama.rs
+++ b/src/providers/ollama.rs
@@ -14,8 +14,16 @@ pub struct OllamaProvider {

 impl OllamaProvider {
     pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
+        let client = reqwest::Client::builder()
+            .connect_timeout(std::time::Duration::from_secs(5))
+            .timeout(std::time::Duration::from_secs(300))
+            .pool_idle_timeout(std::time::Duration::from_secs(90))
+            .pool_max_idle_per_host(4)
+            .tcp_keepalive(std::time::Duration::from_secs(30))
+            .build()?;
+
         Ok(Self {
-            client: reqwest::Client::new(),
+            client,
             config: config.clone(),
             pricing: app_config.pricing.ollama.clone(),
         })
diff --git a/src/providers/openai.rs b/src/providers/openai.rs
index 08458d3f..b424065d 100644
--- a/src/providers/openai.rs
+++ b/src/providers/openai.rs
@@ -20,8 +20,16 @@ impl OpenAIProvider {
     }

     pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
+        let client = reqwest::Client::builder()
+            .connect_timeout(std::time::Duration::from_secs(5))
+            .timeout(std::time::Duration::from_secs(300))
+            .pool_idle_timeout(std::time::Duration::from_secs(90))
+            .pool_max_idle_per_host(4)
+            .tcp_keepalive(std::time::Duration::from_secs(30))
+            .build()?;
+
         Ok(Self {
-            client: reqwest::Client::new(),
+            client,
             config: config.clone(),
             api_key,
             pricing: app_config.pricing.openai.clone(),
diff --git a/src/server/mod.rs b/src/server/mod.rs
index 08aee338..5541ed0c 100644
--- a/src/server/mod.rs
+++ b/src/server/mod.rs
@@ -6,7 +6,6 @@ use axum::{
     routing::post,
 };
 use futures::stream::StreamExt;
-use sqlx::Row;
 use std::sync::Arc;
 use tracing::{info, warn};
 use uuid::Uuid;
@@ -39,18 +38,9 @@ async fn get_model_cost(
     provider: &Arc<dyn Provider>,
     state: &AppState,
 ) -> f64 {
-    // Check database for cost overrides
-    let db_cost = sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?")
-        .bind(model)
-        .fetch_optional(&state.db_pool)
-        .await
-        .unwrap_or(None);
-
-    if let Some(row) = db_cost {
-        let prompt_rate = row.get::<Option<f64>, _>("prompt_cost_per_m");
-        let completion_rate = row.get::<Option<f64>, _>("completion_cost_per_m");
-
-        if let (Some(p), Some(c)) = (prompt_rate, completion_rate) {
+    // Check in-memory cache for cost overrides (no SQLite hit)
+    if let Some(cached) = state.model_config_cache.get(model).await {
+        if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
             return (prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0);
         }
     }
@@ -75,15 +65,11 @@ async fn chat_completions(
     info!("Chat completion request from client {} for model {}", client_id, model);

-    // Check if model is enabled in database and get potential mapping
-    let model_config = sqlx::query("SELECT enabled, mapping FROM model_configs WHERE id = ?")
-        .bind(&model)
-        .fetch_optional(&state.db_pool)
-        .await
-        .unwrap_or(None);
+    // Check if model is enabled via in-memory cache (no SQLite hit)
+    let cached_config = state.model_config_cache.get(&model).await;

-    let (model_enabled, model_mapping) = match model_config {
-        Some(row) => (row.get::<bool, _>("enabled"), row.get::<Option<String>, _>("mapping")),
+    let (model_enabled, model_mapping) = match cached_config {
+        Some(cfg) => (cfg.enabled, cfg.mapping),
         None => (true, None),
     };
@@ -129,6 +115,9 @@ async fn chat_completions(

     let has_images = unified_request.has_images;

+    // Measure proxy overhead (time spent before sending to upstream provider)
+    let proxy_overhead = start_time.elapsed();
+
     // Check if streaming is requested
     if unified_request.stream {
         // Estimate prompt tokens for logging later
@@ -142,6 +131,12 @@ async fn chat_completions(
             // Record provider success
             state.rate_limit_manager.record_provider_success(&provider_name).await;

+            info!(
+                "Streaming started for {} (proxy overhead: {}ms)",
+                model,
+                proxy_overhead.as_millis()
+            );
+
             // Wrap with AggregatingStream for token counting and database logging
             let aggregating_stream = crate::utils::streaming::AggregatingStream::new(
                 stream,
@@ -154,19 +149,21 @@ async fn chat_completions(
                     logger: state.request_logger.clone(),
                     client_manager: state.client_manager.clone(),
                     model_registry: state.model_registry.clone(),
-                    db_pool: state.db_pool.clone(),
+                    model_config_cache: state.model_config_cache.clone(),
                 },
             );

             // Create SSE stream from aggregating stream
+            let stream_id = format!("chatcmpl-{}", Uuid::new_v4());
+            let stream_created = chrono::Utc::now().timestamp() as u64;
             let sse_stream = aggregating_stream.map(move |chunk_result| {
                 match chunk_result {
                     Ok(chunk) => {
                         // Convert provider chunk to OpenAI-compatible SSE event
                         let response = ChatCompletionStreamResponse {
-                            id: format!("chatcmpl-{}", Uuid::new_v4()),
+                            id: stream_id.clone(),
                             object: "chat.completion.chunk".to_string(),
-                            created: chrono::Utc::now().timestamp() as u64,
+                            created: stream_created,
                             model: chunk.model.clone(),
                             choices: vec![ChatStreamChoice {
                                 index: 0,
@@ -242,11 +239,14 @@ async fn chat_completions(
         duration_ms: duration.as_millis() as u64,
     });

-    // Update client usage
-    let _ = state
-        .client_manager
-        .update_client_usage(&client_id, response.total_tokens as i64, cost)
-        .await;
+    // Update client usage (fire-and-forget, don't block response)
+    {
+        let cm = state.client_manager.clone();
+        let cid = client_id.clone();
+        tokio::spawn(async move {
+            let _ = cm.update_client_usage(&cid, response.total_tokens as i64, cost).await;
+        });
+    }

     // Convert ProviderResponse to ChatCompletionResponse
     let finish_reason = if response.tool_calls.is_some() {
@@ -281,8 +281,14 @@ async fn chat_completions(
         }),
     };

-    // Log successful request
-    info!("Request completed successfully in {:?}", duration);
+    // Log successful request with proxy overhead breakdown
+    let upstream_ms = duration.as_millis() as u64 - proxy_overhead.as_millis() as u64;
+    info!(
+        "Request completed in {:?} (proxy: {}ms, upstream: {}ms)",
+        duration,
+        proxy_overhead.as_millis(),
+        upstream_ms
+    );

     Ok(Json(chat_response).into_response())
 }
diff --git a/src/state/mod.rs b/src/state/mod.rs
index 3412c6a4..71dc867a 100644
--- a/src/state/mod.rs
+++ b/src/state/mod.rs
@@ -1,11 +1,89 @@
+use std::collections::HashMap;
 use std::sync::Arc;
-use tokio::sync::broadcast;
+use tokio::sync::{broadcast, RwLock};
+use tracing::warn;

 use crate::{
     client::ClientManager, config::AppConfig, database::DbPool, logging::RequestLogger,
     models::registry::ModelRegistry, providers::ProviderManager, rate_limiting::RateLimitManager,
 };

+/// Cached model configuration entry
+#[derive(Debug, Clone)]
+pub struct CachedModelConfig {
+    pub enabled: bool,
+    pub mapping: Option<String>,
+    pub prompt_cost_per_m: Option<f64>,
+    pub completion_cost_per_m: Option<f64>,
+}
+
+/// In-memory cache for model_configs table.
+/// Refreshes periodically to avoid hitting SQLite on every request.
+#[derive(Clone)]
+pub struct ModelConfigCache {
+    cache: Arc<RwLock<HashMap<String, CachedModelConfig>>>,
+    db_pool: DbPool,
+}
+
+impl ModelConfigCache {
+    pub fn new(db_pool: DbPool) -> Self {
+        Self {
+            cache: Arc::new(RwLock::new(HashMap::new())),
+            db_pool,
+        }
+    }
+
+    /// Load all model configs from the database into cache
+    pub async fn refresh(&self) {
+        match sqlx::query_as::<_, (String, bool, Option<String>, Option<f64>, Option<f64>)>(
+            "SELECT id, enabled, mapping, prompt_cost_per_m, completion_cost_per_m FROM model_configs",
+        )
+        .fetch_all(&self.db_pool)
+        .await
+        {
+            Ok(rows) => {
+                let mut map = HashMap::with_capacity(rows.len());
+                for (id, enabled, mapping, prompt_cost, completion_cost) in rows {
+                    map.insert(
+                        id,
+                        CachedModelConfig {
+                            enabled,
+                            mapping,
+                            prompt_cost_per_m: prompt_cost,
+                            completion_cost_per_m: completion_cost,
+                        },
+                    );
+                }
+                *self.cache.write().await = map;
+            }
+            Err(e) => {
+                warn!("Failed to refresh model config cache: {}", e);
+            }
+        }
+    }
+
+    /// Get a cached model config. Returns None if not in cache (model is unconfigured).
+    pub async fn get(&self, model: &str) -> Option<CachedModelConfig> {
+        self.cache.read().await.get(model).cloned()
+    }
+
+    /// Invalidate cache — call this after dashboard writes to model_configs
+    pub async fn invalidate(&self) {
+        self.refresh().await;
+    }
+
+    /// Start a background task that refreshes the cache every `interval_secs` seconds
+    pub fn start_refresh_task(self, interval_secs: u64) {
+        tokio::spawn(async move {
+            let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs));
+            loop {
+                interval.tick().await;
+                self.refresh().await;
+            }
+        });
+    }
+}
+
 /// Shared application state
 #[derive(Clone)]
 pub struct AppState {
@@ -16,6 +94,7 @@ pub struct AppState {
     pub client_manager: Arc<ClientManager>,
     pub request_logger: Arc<RequestLogger>,
     pub model_registry: Arc<ModelRegistry>,
+    pub model_config_cache: ModelConfigCache,
     pub dashboard_tx: broadcast::Sender<String>,
     pub auth_tokens: Vec<String>,
 }
@@ -32,6 +111,7 @@ impl AppState {
         let client_manager = Arc::new(ClientManager::new(db_pool.clone()));
         let (dashboard_tx, _) = broadcast::channel(100);
         let request_logger = Arc::new(RequestLogger::new(db_pool.clone(), dashboard_tx.clone()));
+        let model_config_cache = ModelConfigCache::new(db_pool.clone());

         Self {
             config,
@@ -41,6 +121,7 @@ impl AppState {
             client_manager,
             request_logger,
             model_registry: Arc::new(model_registry),
+            model_config_cache,
             dashboard_tx,
             auth_tokens,
         }
diff --git a/src/utils/streaming.rs b/src/utils/streaming.rs
index eaa20467..db1cb893 100644
--- a/src/utils/streaming.rs
+++ b/src/utils/streaming.rs
@@ -3,9 +3,9 @@ use crate::errors::AppError;
 use crate::logging::{RequestLog, RequestLogger};
 use crate::models::ToolCall;
 use crate::providers::{Provider, ProviderStreamChunk};
+use crate::state::ModelConfigCache;
 use crate::utils::tokens::estimate_completion_tokens;
 use futures::stream::Stream;
-use sqlx::Row;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
@@ -20,7 +20,7 @@ pub struct StreamConfig {
     pub logger: Arc<RequestLogger>,
     pub client_manager: Arc<ClientManager>,
     pub model_registry: Arc<ModelRegistry>,
-    pub db_pool: crate::database::DbPool,
+    pub model_config_cache: ModelConfigCache,
 }

 pub struct AggregatingStream<S> {
@@ -36,7 +36,7 @@ pub struct AggregatingStream<S> {
     logger: Arc<RequestLogger>,
     client_manager: Arc<ClientManager>,
     model_registry: Arc<ModelRegistry>,
-    db_pool: crate::database::DbPool,
+    model_config_cache: ModelConfigCache,
     start_time: std::time::Instant,
     has_logged: bool,
 }
@@ -59,7 +59,7 @@ where
             logger: config.logger,
             client_manager: config.client_manager,
             model_registry: config.model_registry,
-            db_pool: config.db_pool,
+            model_config_cache: config.model_config_cache,
             start_time: std::time::Instant::now(),
             has_logged: false,
         }
@@ -81,7 +81,7 @@ where
         let prompt_tokens = self.prompt_tokens;
         let has_images = self.has_images;
         let registry = self.model_registry.clone();
-        let pool = self.db_pool.clone();
+        let config_cache = self.model_config_cache.clone();

         // Estimate completion tokens (including reasoning if present)
         let content_tokens = estimate_completion_tokens(&self.accumulated_content, &model);
@@ -96,19 +96,9 @@ where

         // Spawn a background task to log the completion
         tokio::spawn(async move {
-            // Check database for cost overrides
-            let db_cost =
-                sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?")
-                    .bind(&model)
-                    .fetch_optional(&pool)
-                    .await
-                    .unwrap_or(None);
-
-            let cost = if let Some(row) = db_cost {
-                let prompt_rate = row.get::<Option<f64>, _>("prompt_cost_per_m");
-                let completion_rate = row.get::<Option<f64>, _>("completion_cost_per_m");
-
-                if let (Some(p), Some(c)) = (prompt_rate, completion_rate) {
+            // Check in-memory cache for cost overrides (no SQLite hit)
+            let cost = if let Some(cached) = config_cache.get(&model).await {
+                if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
                     (prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0)
                 } else {
                     provider.calculate_cost(&model, prompt_tokens, completion_tokens, &registry)
@@ -284,7 +274,7 @@ mod tests {
                 logger,
                 client_manager,
                 model_registry: registry,
-                db_pool: pool.clone(),
+                model_config_cache: ModelConfigCache::new(pool.clone()),
             },
         );
diff --git a/src/utils/tokens.rs b/src/utils/tokens.rs
index dbf020d2..469ab02a 100644
--- a/src/utils/tokens.rs
+++ b/src/utils/tokens.rs
@@ -10,38 +10,57 @@ pub fn count_tokens(model: &str, text: &str) -> u32 {
     bpe.encode_with_special_tokens(text).len() as u32
 }

-/// Estimate tokens for a unified request
+/// Estimate tokens for a unified request.
+/// Large prompts use a fast chars/4 heuristic to avoid blocking the async runtime.
 pub fn estimate_request_tokens(model: &str, request: &UnifiedRequest) -> u32 {
-    let mut total_tokens = 0;
+    let mut total_text = String::new();
+    let msg_count = request.messages.len();

     // Base tokens per message for OpenAI (approximate)
-    let tokens_per_message = 3;
-    let _tokens_per_name = 1;
+    let tokens_per_message: u32 = 3;

     for msg in &request.messages {
-        total_tokens += tokens_per_message;
-
         for part in &msg.content {
             match part {
                 crate::models::ContentPart::Text { text } => {
-                    total_tokens += count_tokens(model, text);
+                    total_text.push_str(text);
+                    total_text.push('\n');
                 }
                 crate::models::ContentPart::Image { .. } => {
                     // Vision models usually have a fixed cost or calculation based on size
-                    // For now, let's use a conservative estimate of 1000 tokens
-                    total_tokens += 1000;
                 }
             }
         }
-
-        // Add name tokens if we had names (we don't in UnifiedMessage yet)
-        // total_tokens += tokens_per_name;
     }

-    // Add 3 tokens for the assistant reply header
-    total_tokens += 3;
+    // Small inputs (< 1KB): exact tiktoken count is cheap enough to do inline
+    if total_text.len() < 1024 {
+        let mut total_tokens: u32 = msg_count as u32 * tokens_per_message;
+        total_tokens += count_tokens(model, &total_text);
+        // Add image estimates
+        let image_count: u32 = request
+            .messages
+            .iter()
+            .flat_map(|m| m.content.iter())
+            .filter(|p| matches!(p, crate::models::ContentPart::Image { .. }))
+            .count() as u32;
+        total_tokens += image_count * 1000;
+        total_tokens += 3; // assistant reply header
+        return total_tokens;
+    }

-    total_tokens
+    // For large inputs, use the fast heuristic (chars / 4) to avoid blocking
+    // the async runtime. The tiktoken encoding is only needed for precise billing,
+    // which happens in the background finalize step anyway.
+    let estimated_text_tokens = (total_text.len() as u32) / 4;
+    let image_count: u32 = request
+        .messages
+        .iter()
+        .flat_map(|m| m.content.iter())
+        .filter(|p| matches!(p, crate::models::ContentPart::Image { .. }))
+        .count() as u32;
+
+    (msg_count as u32 * tokens_per_message) + estimated_text_tokens + (image_count * 1000) + 3
 }

 /// Estimate tokens for completion text