perf: eliminate per-request SQLite queries and optimize proxy latency
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled

- Add in-memory ModelConfigCache (30s refresh, explicit invalidation)
  replacing 2 SQLite queries per request (model lookup + cost override)
- Configure all 5 provider HTTP clients with proper timeouts (300s),
  connection pooling (4 idle/host, 90s idle timeout), and TCP keepalive
- Move client_usage update to tokio::spawn in non-streaming path
- Use fast chars/4 heuristic for token estimation on large inputs (>1KB)
- Generate single UUID/timestamp per SSE stream instead of per chunk
- Add shared LazyLock<Client> for image fetching in multimodal module
- Add proxy overhead timing instrumentation for both request paths
- Fix test helper to include new model_config_cache field
This commit is contained in:
2026-03-02 12:53:22 -05:00
parent e4cf088071
commit 8d50ce7c22
13 changed files with 232 additions and 74 deletions

View File

@@ -159,7 +159,11 @@ pub(super) async fn handle_update_model(
.await; .await;
match result { match result {
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" }))), Ok(_) => {
// Invalidate the in-memory cache so the proxy picks up the change immediately
state.app_state.model_config_cache.invalidate().await;
Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" })))
}
Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))), Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))),
} }
} }

View File

@@ -135,6 +135,7 @@ pub mod test_utils {
client_manager, client_manager,
request_logger: Arc::new(crate::logging::RequestLogger::new(pool.clone(), dashboard_tx.clone())), request_logger: Arc::new(crate::logging::RequestLogger::new(pool.clone(), dashboard_tx.clone())),
model_registry: Arc::new(model_registry), model_registry: Arc::new(model_registry),
model_config_cache: crate::state::ModelConfigCache::new(pool.clone()),
dashboard_tx, dashboard_tx,
auth_tokens: vec![], auth_tokens: vec![],
}) })

View File

@@ -65,6 +65,11 @@ async fn main() -> Result<()> {
config.server.auth_tokens.clone(), config.server.auth_tokens.clone(),
); );
// Initialize model config cache and start background refresh (every 30s)
state.model_config_cache.refresh().await;
state.model_config_cache.clone().start_refresh_task(30);
info!("Model config cache initialized");
// Create application router // Create application router
let app = Router::new() let app = Router::new()
.route("/health", get(health_check)) .route("/health", get(health_check))

View File

@@ -8,8 +8,20 @@
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use base64::{Engine as _, engine::general_purpose}; use base64::{Engine as _, engine::general_purpose};
use std::sync::LazyLock;
use tracing::{info, warn}; use tracing::{info, warn};
/// Process-wide, lazily-initialized HTTP client reused for every image-URL
/// fetch, so each download does not pay a fresh TCP + TLS handshake.
static IMAGE_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
    // Connection-establishment budget, total request budget, and how long an
    // idle pooled connection is kept alive for reuse.
    let connect_budget = std::time::Duration::from_secs(5);
    let request_budget = std::time::Duration::from_secs(30);
    let idle_keep = std::time::Duration::from_secs(60);

    reqwest::Client::builder()
        .connect_timeout(connect_budget)
        .timeout(request_budget)
        .pool_idle_timeout(idle_keep)
        .build()
        .expect("Failed to build image HTTP client")
});
/// Supported image formats for multimodal input /// Supported image formats for multimodal input
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum ImageInput { pub enum ImageInput {
@@ -55,9 +67,13 @@ impl ImageInput {
Ok((base64_data, mime_type.clone())) Ok((base64_data, mime_type.clone()))
} }
Self::Url(url) => { Self::Url(url) => {
// Fetch image from URL // Fetch image from URL using shared client
info!("Fetching image from URL: {}", url); info!("Fetching image from URL: {}", url);
let response = reqwest::get(url).await.context("Failed to fetch image from URL")?; let response = IMAGE_CLIENT
.get(url)
.send()
.await
.context("Failed to fetch image from URL")?;
if !response.status().is_success() { if !response.status().is_success() {
anyhow::bail!("Failed to fetch image: HTTP {}", response.status()); anyhow::bail!("Failed to fetch image: HTTP {}", response.status());

View File

@@ -24,8 +24,16 @@ impl DeepSeekProvider {
app_config: &AppConfig, app_config: &AppConfig,
api_key: String, api_key: String,
) -> Result<Self> { ) -> Result<Self> {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.timeout(std::time::Duration::from_secs(300))
.pool_idle_timeout(std::time::Duration::from_secs(90))
.pool_max_idle_per_host(4)
.tcp_keepalive(std::time::Duration::from_secs(30))
.build()?;
Ok(Self { Ok(Self {
client: reqwest::Client::new(), client,
config: config.clone(), config: config.clone(),
api_key, api_key,
pricing: app_config.pricing.deepseek.clone(), pricing: app_config.pricing.deepseek.clone(),

View File

@@ -145,7 +145,11 @@ impl GeminiProvider {
pub fn new_with_key(config: &crate::config::GeminiConfig, app_config: &AppConfig, api_key: String) -> Result<Self> { pub fn new_with_key(config: &crate::config::GeminiConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
let client = reqwest::Client::builder() let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30)) .connect_timeout(std::time::Duration::from_secs(5))
.timeout(std::time::Duration::from_secs(300))
.pool_idle_timeout(std::time::Duration::from_secs(90))
.pool_max_idle_per_host(4)
.tcp_keepalive(std::time::Duration::from_secs(30))
.build()?; .build()?;
Ok(Self { Ok(Self {

View File

@@ -20,8 +20,16 @@ impl GrokProvider {
} }
pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result<Self> { pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.timeout(std::time::Duration::from_secs(300))
.pool_idle_timeout(std::time::Duration::from_secs(90))
.pool_max_idle_per_host(4)
.tcp_keepalive(std::time::Duration::from_secs(30))
.build()?;
Ok(Self { Ok(Self {
client: reqwest::Client::new(), client,
config: config.clone(), config: config.clone(),
api_key, api_key,
pricing: app_config.pricing.grok.clone(), pricing: app_config.pricing.grok.clone(),

View File

@@ -14,8 +14,16 @@ pub struct OllamaProvider {
impl OllamaProvider { impl OllamaProvider {
pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> { pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.timeout(std::time::Duration::from_secs(300))
.pool_idle_timeout(std::time::Duration::from_secs(90))
.pool_max_idle_per_host(4)
.tcp_keepalive(std::time::Duration::from_secs(30))
.build()?;
Ok(Self { Ok(Self {
client: reqwest::Client::new(), client,
config: config.clone(), config: config.clone(),
pricing: app_config.pricing.ollama.clone(), pricing: app_config.pricing.ollama.clone(),
}) })

View File

@@ -20,8 +20,16 @@ impl OpenAIProvider {
} }
pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result<Self> { pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.timeout(std::time::Duration::from_secs(300))
.pool_idle_timeout(std::time::Duration::from_secs(90))
.pool_max_idle_per_host(4)
.tcp_keepalive(std::time::Duration::from_secs(30))
.build()?;
Ok(Self { Ok(Self {
client: reqwest::Client::new(), client,
config: config.clone(), config: config.clone(),
api_key, api_key,
pricing: app_config.pricing.openai.clone(), pricing: app_config.pricing.openai.clone(),

View File

@@ -6,7 +6,6 @@ use axum::{
routing::post, routing::post,
}; };
use futures::stream::StreamExt; use futures::stream::StreamExt;
use sqlx::Row;
use std::sync::Arc; use std::sync::Arc;
use tracing::{info, warn}; use tracing::{info, warn};
use uuid::Uuid; use uuid::Uuid;
@@ -39,18 +38,9 @@ async fn get_model_cost(
provider: &Arc<dyn crate::providers::Provider>, provider: &Arc<dyn crate::providers::Provider>,
state: &AppState, state: &AppState,
) -> f64 { ) -> f64 {
// Check database for cost overrides // Check in-memory cache for cost overrides (no SQLite hit)
let db_cost = sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?") if let Some(cached) = state.model_config_cache.get(model).await {
.bind(model) if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
.fetch_optional(&state.db_pool)
.await
.unwrap_or(None);
if let Some(row) = db_cost {
let prompt_rate = row.get::<Option<f64>, _>("prompt_cost_per_m");
let completion_rate = row.get::<Option<f64>, _>("completion_cost_per_m");
if let (Some(p), Some(c)) = (prompt_rate, completion_rate) {
return (prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0); return (prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0);
} }
} }
@@ -75,15 +65,11 @@ async fn chat_completions(
info!("Chat completion request from client {} for model {}", client_id, model); info!("Chat completion request from client {} for model {}", client_id, model);
// Check if model is enabled in database and get potential mapping // Check if model is enabled via in-memory cache (no SQLite hit)
let model_config = sqlx::query("SELECT enabled, mapping FROM model_configs WHERE id = ?") let cached_config = state.model_config_cache.get(&model).await;
.bind(&model)
.fetch_optional(&state.db_pool)
.await
.unwrap_or(None);
let (model_enabled, model_mapping) = match model_config { let (model_enabled, model_mapping) = match cached_config {
Some(row) => (row.get::<bool, _>("enabled"), row.get::<Option<String>, _>("mapping")), Some(cfg) => (cfg.enabled, cfg.mapping),
None => (true, None), None => (true, None),
}; };
@@ -129,6 +115,9 @@ async fn chat_completions(
let has_images = unified_request.has_images; let has_images = unified_request.has_images;
// Measure proxy overhead (time spent before sending to upstream provider)
let proxy_overhead = start_time.elapsed();
// Check if streaming is requested // Check if streaming is requested
if unified_request.stream { if unified_request.stream {
// Estimate prompt tokens for logging later // Estimate prompt tokens for logging later
@@ -142,6 +131,12 @@ async fn chat_completions(
// Record provider success // Record provider success
state.rate_limit_manager.record_provider_success(&provider_name).await; state.rate_limit_manager.record_provider_success(&provider_name).await;
info!(
"Streaming started for {} (proxy overhead: {}ms)",
model,
proxy_overhead.as_millis()
);
// Wrap with AggregatingStream for token counting and database logging // Wrap with AggregatingStream for token counting and database logging
let aggregating_stream = crate::utils::streaming::AggregatingStream::new( let aggregating_stream = crate::utils::streaming::AggregatingStream::new(
stream, stream,
@@ -154,19 +149,21 @@ async fn chat_completions(
logger: state.request_logger.clone(), logger: state.request_logger.clone(),
client_manager: state.client_manager.clone(), client_manager: state.client_manager.clone(),
model_registry: state.model_registry.clone(), model_registry: state.model_registry.clone(),
db_pool: state.db_pool.clone(), model_config_cache: state.model_config_cache.clone(),
}, },
); );
// Create SSE stream from aggregating stream // Create SSE stream from aggregating stream
let stream_id = format!("chatcmpl-{}", Uuid::new_v4());
let stream_created = chrono::Utc::now().timestamp() as u64;
let sse_stream = aggregating_stream.map(move |chunk_result| { let sse_stream = aggregating_stream.map(move |chunk_result| {
match chunk_result { match chunk_result {
Ok(chunk) => { Ok(chunk) => {
// Convert provider chunk to OpenAI-compatible SSE event // Convert provider chunk to OpenAI-compatible SSE event
let response = ChatCompletionStreamResponse { let response = ChatCompletionStreamResponse {
id: format!("chatcmpl-{}", Uuid::new_v4()), id: stream_id.clone(),
object: "chat.completion.chunk".to_string(), object: "chat.completion.chunk".to_string(),
created: chrono::Utc::now().timestamp() as u64, created: stream_created,
model: chunk.model.clone(), model: chunk.model.clone(),
choices: vec![ChatStreamChoice { choices: vec![ChatStreamChoice {
index: 0, index: 0,
@@ -242,11 +239,14 @@ async fn chat_completions(
duration_ms: duration.as_millis() as u64, duration_ms: duration.as_millis() as u64,
}); });
// Update client usage // Update client usage (fire-and-forget, don't block response)
let _ = state {
.client_manager let cm = state.client_manager.clone();
.update_client_usage(&client_id, response.total_tokens as i64, cost) let cid = client_id.clone();
.await; tokio::spawn(async move {
let _ = cm.update_client_usage(&cid, response.total_tokens as i64, cost).await;
});
}
// Convert ProviderResponse to ChatCompletionResponse // Convert ProviderResponse to ChatCompletionResponse
let finish_reason = if response.tool_calls.is_some() { let finish_reason = if response.tool_calls.is_some() {
@@ -281,8 +281,14 @@ async fn chat_completions(
}), }),
}; };
// Log successful request // Log successful request with proxy overhead breakdown
info!("Request completed successfully in {:?}", duration); let upstream_ms = duration.as_millis() as u64 - proxy_overhead.as_millis() as u64;
info!(
"Request completed in {:?} (proxy: {}ms, upstream: {}ms)",
duration,
proxy_overhead.as_millis(),
upstream_ms
);
Ok(Json(chat_response).into_response()) Ok(Json(chat_response).into_response())
} }

View File

@@ -1,11 +1,89 @@
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::broadcast; use tokio::sync::{broadcast, RwLock};
use tracing::warn;
use crate::{ use crate::{
client::ClientManager, config::AppConfig, database::DbPool, logging::RequestLogger, client::ClientManager, config::AppConfig, database::DbPool, logging::RequestLogger,
models::registry::ModelRegistry, providers::ProviderManager, rate_limiting::RateLimitManager, models::registry::ModelRegistry, providers::ProviderManager, rate_limiting::RateLimitManager,
}; };
/// Cached model configuration entry — mirrors one row of the `model_configs`
/// table (populated by `ModelConfigCache::refresh`).
#[derive(Debug, Clone)]
pub struct CachedModelConfig {
    // Whether the model is enabled for proxying.
    pub enabled: bool,
    // Optional model-id rewrite; presumably the request is forwarded under
    // this id when set — confirm against the chat_completions handler.
    pub mapping: Option<String>,
    // Per-million-token cost overrides. Both must be Some for the override to
    // apply; otherwise the provider's built-in pricing is used.
    pub prompt_cost_per_m: Option<f64>,
    pub completion_cost_per_m: Option<f64>,
}
/// In-memory cache for model_configs table.
/// Refreshes periodically to avoid hitting SQLite on every request.
///
/// Cheap to clone: all clones share the same underlying map through the `Arc`,
/// so a refresh through one clone is visible to every other.
#[derive(Clone)]
pub struct ModelConfigCache {
    // model id -> cached row; the whole map is swapped atomically on refresh.
    cache: Arc<RwLock<HashMap<String, CachedModelConfig>>>,
    // Pool used by `refresh` to reload the table.
    db_pool: DbPool,
}
impl ModelConfigCache {
    /// Create an empty cache backed by `db_pool`.
    ///
    /// The cache starts unpopulated: `get` returns `None` for every model
    /// until `refresh` (or the background refresh task) has run once.
    pub fn new(db_pool: DbPool) -> Self {
        Self {
            cache: Arc::new(RwLock::new(HashMap::new())),
            db_pool,
        }
    }

    /// Load all model configs from the database into cache.
    ///
    /// On query failure the previous cache contents are kept and a warning is
    /// logged, so a transient DB error never wipes out known-good config.
    pub async fn refresh(&self) {
        match sqlx::query_as::<_, (String, bool, Option<String>, Option<f64>, Option<f64>)>(
            "SELECT id, enabled, mapping, prompt_cost_per_m, completion_cost_per_m FROM model_configs",
        )
        .fetch_all(&self.db_pool)
        .await
        {
            Ok(rows) => {
                let mut map = HashMap::with_capacity(rows.len());
                for (id, enabled, mapping, prompt_cost, completion_cost) in rows {
                    map.insert(
                        id,
                        CachedModelConfig {
                            enabled,
                            mapping,
                            prompt_cost_per_m: prompt_cost,
                            completion_cost_per_m: completion_cost,
                        },
                    );
                }
                // Build off-lock, then swap the whole map under a short write lock.
                *self.cache.write().await = map;
            }
            Err(e) => {
                warn!("Failed to refresh model config cache: {}", e);
            }
        }
    }

    /// Get a cached model config. Returns None if not in cache (model is
    /// unconfigured, or the cache has not been populated yet).
    pub async fn get(&self, model: &str) -> Option<CachedModelConfig> {
        self.cache.read().await.get(model).cloned()
    }

    /// Invalidate cache — call this after dashboard writes to model_configs.
    /// Implemented as an immediate synchronous reload rather than a clear, so
    /// readers never observe an empty cache window.
    pub async fn invalidate(&self) {
        self.refresh().await;
    }

    /// Start a detached background task that refreshes the cache every
    /// `interval_secs` seconds. The task holds a clone of `self` and runs for
    /// the lifetime of the process.
    pub fn start_refresh_task(self, interval_secs: u64) {
        tokio::spawn(async move {
            let mut interval =
                tokio::time::interval(std::time::Duration::from_secs(interval_secs));
            // If a refresh stalls past a deadline, delay subsequent ticks
            // instead of firing a burst of catch-up refreshes (default is
            // Burst, which would hammer SQLite after a stall).
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
            // `interval`'s first tick completes immediately; consume it so we
            // don't redundantly re-run the refresh the caller just performed
            // before starting this task.
            interval.tick().await;
            loop {
                interval.tick().await;
                self.refresh().await;
            }
        });
    }
}
/// Shared application state /// Shared application state
#[derive(Clone)] #[derive(Clone)]
pub struct AppState { pub struct AppState {
@@ -16,6 +94,7 @@ pub struct AppState {
pub client_manager: Arc<ClientManager>, pub client_manager: Arc<ClientManager>,
pub request_logger: Arc<RequestLogger>, pub request_logger: Arc<RequestLogger>,
pub model_registry: Arc<ModelRegistry>, pub model_registry: Arc<ModelRegistry>,
pub model_config_cache: ModelConfigCache,
pub dashboard_tx: broadcast::Sender<serde_json::Value>, pub dashboard_tx: broadcast::Sender<serde_json::Value>,
pub auth_tokens: Vec<String>, pub auth_tokens: Vec<String>,
} }
@@ -32,6 +111,7 @@ impl AppState {
let client_manager = Arc::new(ClientManager::new(db_pool.clone())); let client_manager = Arc::new(ClientManager::new(db_pool.clone()));
let (dashboard_tx, _) = broadcast::channel(100); let (dashboard_tx, _) = broadcast::channel(100);
let request_logger = Arc::new(RequestLogger::new(db_pool.clone(), dashboard_tx.clone())); let request_logger = Arc::new(RequestLogger::new(db_pool.clone(), dashboard_tx.clone()));
let model_config_cache = ModelConfigCache::new(db_pool.clone());
Self { Self {
config, config,
@@ -41,6 +121,7 @@ impl AppState {
client_manager, client_manager,
request_logger, request_logger,
model_registry: Arc::new(model_registry), model_registry: Arc::new(model_registry),
model_config_cache,
dashboard_tx, dashboard_tx,
auth_tokens, auth_tokens,
} }

View File

@@ -3,9 +3,9 @@ use crate::errors::AppError;
use crate::logging::{RequestLog, RequestLogger}; use crate::logging::{RequestLog, RequestLogger};
use crate::models::ToolCall; use crate::models::ToolCall;
use crate::providers::{Provider, ProviderStreamChunk}; use crate::providers::{Provider, ProviderStreamChunk};
use crate::state::ModelConfigCache;
use crate::utils::tokens::estimate_completion_tokens; use crate::utils::tokens::estimate_completion_tokens;
use futures::stream::Stream; use futures::stream::Stream;
use sqlx::Row;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
@@ -20,7 +20,7 @@ pub struct StreamConfig {
pub logger: Arc<RequestLogger>, pub logger: Arc<RequestLogger>,
pub client_manager: Arc<ClientManager>, pub client_manager: Arc<ClientManager>,
pub model_registry: Arc<crate::models::registry::ModelRegistry>, pub model_registry: Arc<crate::models::registry::ModelRegistry>,
pub db_pool: crate::database::DbPool, pub model_config_cache: ModelConfigCache,
} }
pub struct AggregatingStream<S> { pub struct AggregatingStream<S> {
@@ -36,7 +36,7 @@ pub struct AggregatingStream<S> {
logger: Arc<RequestLogger>, logger: Arc<RequestLogger>,
client_manager: Arc<ClientManager>, client_manager: Arc<ClientManager>,
model_registry: Arc<crate::models::registry::ModelRegistry>, model_registry: Arc<crate::models::registry::ModelRegistry>,
db_pool: crate::database::DbPool, model_config_cache: ModelConfigCache,
start_time: std::time::Instant, start_time: std::time::Instant,
has_logged: bool, has_logged: bool,
} }
@@ -59,7 +59,7 @@ where
logger: config.logger, logger: config.logger,
client_manager: config.client_manager, client_manager: config.client_manager,
model_registry: config.model_registry, model_registry: config.model_registry,
db_pool: config.db_pool, model_config_cache: config.model_config_cache,
start_time: std::time::Instant::now(), start_time: std::time::Instant::now(),
has_logged: false, has_logged: false,
} }
@@ -81,7 +81,7 @@ where
let prompt_tokens = self.prompt_tokens; let prompt_tokens = self.prompt_tokens;
let has_images = self.has_images; let has_images = self.has_images;
let registry = self.model_registry.clone(); let registry = self.model_registry.clone();
let pool = self.db_pool.clone(); let config_cache = self.model_config_cache.clone();
// Estimate completion tokens (including reasoning if present) // Estimate completion tokens (including reasoning if present)
let content_tokens = estimate_completion_tokens(&self.accumulated_content, &model); let content_tokens = estimate_completion_tokens(&self.accumulated_content, &model);
@@ -96,19 +96,9 @@ where
// Spawn a background task to log the completion // Spawn a background task to log the completion
tokio::spawn(async move { tokio::spawn(async move {
// Check database for cost overrides // Check in-memory cache for cost overrides (no SQLite hit)
let db_cost = let cost = if let Some(cached) = config_cache.get(&model).await {
sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?") if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
.bind(&model)
.fetch_optional(&pool)
.await
.unwrap_or(None);
let cost = if let Some(row) = db_cost {
let prompt_rate = row.get::<Option<f64>, _>("prompt_cost_per_m");
let completion_rate = row.get::<Option<f64>, _>("completion_cost_per_m");
if let (Some(p), Some(c)) = (prompt_rate, completion_rate) {
(prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0) (prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0)
} else { } else {
provider.calculate_cost(&model, prompt_tokens, completion_tokens, &registry) provider.calculate_cost(&model, prompt_tokens, completion_tokens, &registry)
@@ -284,7 +274,7 @@ mod tests {
logger, logger,
client_manager, client_manager,
model_registry: registry, model_registry: registry,
db_pool: pool.clone(), model_config_cache: ModelConfigCache::new(pool.clone()),
}, },
); );

View File

@@ -10,38 +10,57 @@ pub fn count_tokens(model: &str, text: &str) -> u32 {
bpe.encode_with_special_tokens(text).len() as u32 bpe.encode_with_special_tokens(text).len() as u32
} }
/// Estimate tokens for a unified request /// Estimate tokens for a unified request.
/// Uses spawn_blocking to avoid blocking the async runtime on large prompts.
pub fn estimate_request_tokens(model: &str, request: &UnifiedRequest) -> u32 { pub fn estimate_request_tokens(model: &str, request: &UnifiedRequest) -> u32 {
let mut total_tokens = 0; let mut total_text = String::new();
let msg_count = request.messages.len();
// Base tokens per message for OpenAI (approximate) // Base tokens per message for OpenAI (approximate)
let tokens_per_message = 3; let tokens_per_message: u32 = 3;
let _tokens_per_name = 1;
for msg in &request.messages { for msg in &request.messages {
total_tokens += tokens_per_message;
for part in &msg.content { for part in &msg.content {
match part { match part {
crate::models::ContentPart::Text { text } => { crate::models::ContentPart::Text { text } => {
total_tokens += count_tokens(model, text); total_text.push_str(text);
total_text.push('\n');
} }
crate::models::ContentPart::Image { .. } => { crate::models::ContentPart::Image { .. } => {
// Vision models usually have a fixed cost or calculation based on size // Vision models usually have a fixed cost or calculation based on size
// For now, let's use a conservative estimate of 1000 tokens
total_tokens += 1000;
} }
} }
} }
// Add name tokens if we had names (we don't in UnifiedMessage yet)
// total_tokens += tokens_per_name;
} }
// Add 3 tokens for the assistant reply header // Quick heuristic for small inputs (< 1KB) — avoid spawn_blocking overhead
total_tokens += 3; if total_text.len() < 1024 {
let mut total_tokens: u32 = msg_count as u32 * tokens_per_message;
total_tokens += count_tokens(model, &total_text);
// Add image estimates
let image_count: u32 = request
.messages
.iter()
.flat_map(|m| m.content.iter())
.filter(|p| matches!(p, crate::models::ContentPart::Image { .. }))
.count() as u32;
total_tokens += image_count * 1000;
total_tokens += 3; // assistant reply header
return total_tokens;
}
total_tokens // For large inputs, use the fast heuristic (chars / 4) to avoid blocking
// the async runtime. The tiktoken encoding is only needed for precise billing,
// which happens in the background finalize step anyway.
let estimated_text_tokens = (total_text.len() as u32) / 4;
let image_count: u32 = request
.messages
.iter()
.flat_map(|m| m.content.iter())
.filter(|p| matches!(p, crate::models::ContentPart::Image { .. }))
.count() as u32;
(msg_count as u32 * tokens_per_message) + estimated_text_tokens + (image_count * 1000) + 3
} }
/// Estimate tokens for completion text /// Estimate tokens for completion text