perf: eliminate per-request SQLite queries and optimize proxy latency
- Add in-memory ModelConfigCache (30s refresh, explicit invalidation) replacing 2 SQLite queries per request (model lookup + cost override) - Configure all 5 provider HTTP clients with proper timeouts (300s), connection pooling (4 idle/host, 90s idle timeout), and TCP keepalive - Move client_usage update to tokio::spawn in non-streaming path - Use fast chars/4 heuristic for token estimation on large inputs (>1KB) - Generate single UUID/timestamp per SSE stream instead of per chunk - Add shared LazyLock<Client> for image fetching in multimodal module - Add proxy overhead timing instrumentation for both request paths - Fix test helper to include new model_config_cache field
This commit is contained in:
@@ -159,7 +159,11 @@ pub(super) async fn handle_update_model(
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" }))),
|
||||
Ok(_) => {
|
||||
// Invalidate the in-memory cache so the proxy picks up the change immediately
|
||||
state.app_state.model_config_cache.invalidate().await;
|
||||
Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" })))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -135,6 +135,7 @@ pub mod test_utils {
|
||||
client_manager,
|
||||
request_logger: Arc::new(crate::logging::RequestLogger::new(pool.clone(), dashboard_tx.clone())),
|
||||
model_registry: Arc::new(model_registry),
|
||||
model_config_cache: crate::state::ModelConfigCache::new(pool.clone()),
|
||||
dashboard_tx,
|
||||
auth_tokens: vec![],
|
||||
})
|
||||
|
||||
@@ -65,6 +65,11 @@ async fn main() -> Result<()> {
|
||||
config.server.auth_tokens.clone(),
|
||||
);
|
||||
|
||||
// Initialize model config cache and start background refresh (every 30s)
|
||||
state.model_config_cache.refresh().await;
|
||||
state.model_config_cache.clone().start_refresh_task(30);
|
||||
info!("Model config cache initialized");
|
||||
|
||||
// Create application router
|
||||
let app = Router::new()
|
||||
.route("/health", get(health_check))
|
||||
|
||||
@@ -8,8 +8,20 @@
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use base64::{Engine as _, engine::general_purpose};
|
||||
use std::sync::LazyLock;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Shared HTTP client for image fetching — avoids creating a new TCP+TLS
|
||||
/// connection for every image URL.
|
||||
static IMAGE_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
||||
reqwest::Client::builder()
|
||||
.connect_timeout(std::time::Duration::from_secs(5))
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.pool_idle_timeout(std::time::Duration::from_secs(60))
|
||||
.build()
|
||||
.expect("Failed to build image HTTP client")
|
||||
});
|
||||
|
||||
/// Supported image formats for multimodal input
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ImageInput {
|
||||
@@ -55,9 +67,13 @@ impl ImageInput {
|
||||
Ok((base64_data, mime_type.clone()))
|
||||
}
|
||||
Self::Url(url) => {
|
||||
// Fetch image from URL
|
||||
// Fetch image from URL using shared client
|
||||
info!("Fetching image from URL: {}", url);
|
||||
let response = reqwest::get(url).await.context("Failed to fetch image from URL")?;
|
||||
let response = IMAGE_CLIENT
|
||||
.get(url)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to fetch image from URL")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
anyhow::bail!("Failed to fetch image: HTTP {}", response.status());
|
||||
|
||||
@@ -24,8 +24,16 @@ impl DeepSeekProvider {
|
||||
app_config: &AppConfig,
|
||||
api_key: String,
|
||||
) -> Result<Self> {
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(std::time::Duration::from_secs(5))
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
||||
.pool_max_idle_per_host(4)
|
||||
.tcp_keepalive(std::time::Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
Ok(Self {
|
||||
client: reqwest::Client::new(),
|
||||
client,
|
||||
config: config.clone(),
|
||||
api_key,
|
||||
pricing: app_config.pricing.deepseek.clone(),
|
||||
|
||||
@@ -145,7 +145,11 @@ impl GeminiProvider {
|
||||
|
||||
pub fn new_with_key(config: &crate::config::GeminiConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.connect_timeout(std::time::Duration::from_secs(5))
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
||||
.pool_max_idle_per_host(4)
|
||||
.tcp_keepalive(std::time::Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
Ok(Self {
|
||||
|
||||
@@ -20,8 +20,16 @@ impl GrokProvider {
|
||||
}
|
||||
|
||||
pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(std::time::Duration::from_secs(5))
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
||||
.pool_max_idle_per_host(4)
|
||||
.tcp_keepalive(std::time::Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
Ok(Self {
|
||||
client: reqwest::Client::new(),
|
||||
client,
|
||||
config: config.clone(),
|
||||
api_key,
|
||||
pricing: app_config.pricing.grok.clone(),
|
||||
|
||||
@@ -14,8 +14,16 @@ pub struct OllamaProvider {
|
||||
|
||||
impl OllamaProvider {
|
||||
pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(std::time::Duration::from_secs(5))
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
||||
.pool_max_idle_per_host(4)
|
||||
.tcp_keepalive(std::time::Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
Ok(Self {
|
||||
client: reqwest::Client::new(),
|
||||
client,
|
||||
config: config.clone(),
|
||||
pricing: app_config.pricing.ollama.clone(),
|
||||
})
|
||||
|
||||
@@ -20,8 +20,16 @@ impl OpenAIProvider {
|
||||
}
|
||||
|
||||
pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(std::time::Duration::from_secs(5))
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
||||
.pool_max_idle_per_host(4)
|
||||
.tcp_keepalive(std::time::Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
Ok(Self {
|
||||
client: reqwest::Client::new(),
|
||||
client,
|
||||
config: config.clone(),
|
||||
api_key,
|
||||
pricing: app_config.pricing.openai.clone(),
|
||||
|
||||
@@ -6,7 +6,6 @@ use axum::{
|
||||
routing::post,
|
||||
};
|
||||
use futures::stream::StreamExt;
|
||||
use sqlx::Row;
|
||||
use std::sync::Arc;
|
||||
use tracing::{info, warn};
|
||||
use uuid::Uuid;
|
||||
@@ -39,18 +38,9 @@ async fn get_model_cost(
|
||||
provider: &Arc<dyn crate::providers::Provider>,
|
||||
state: &AppState,
|
||||
) -> f64 {
|
||||
// Check database for cost overrides
|
||||
let db_cost = sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?")
|
||||
.bind(model)
|
||||
.fetch_optional(&state.db_pool)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
if let Some(row) = db_cost {
|
||||
let prompt_rate = row.get::<Option<f64>, _>("prompt_cost_per_m");
|
||||
let completion_rate = row.get::<Option<f64>, _>("completion_cost_per_m");
|
||||
|
||||
if let (Some(p), Some(c)) = (prompt_rate, completion_rate) {
|
||||
// Check in-memory cache for cost overrides (no SQLite hit)
|
||||
if let Some(cached) = state.model_config_cache.get(model).await {
|
||||
if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
|
||||
return (prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0);
|
||||
}
|
||||
}
|
||||
@@ -75,15 +65,11 @@ async fn chat_completions(
|
||||
|
||||
info!("Chat completion request from client {} for model {}", client_id, model);
|
||||
|
||||
// Check if model is enabled in database and get potential mapping
|
||||
let model_config = sqlx::query("SELECT enabled, mapping FROM model_configs WHERE id = ?")
|
||||
.bind(&model)
|
||||
.fetch_optional(&state.db_pool)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
// Check if model is enabled via in-memory cache (no SQLite hit)
|
||||
let cached_config = state.model_config_cache.get(&model).await;
|
||||
|
||||
let (model_enabled, model_mapping) = match model_config {
|
||||
Some(row) => (row.get::<bool, _>("enabled"), row.get::<Option<String>, _>("mapping")),
|
||||
let (model_enabled, model_mapping) = match cached_config {
|
||||
Some(cfg) => (cfg.enabled, cfg.mapping),
|
||||
None => (true, None),
|
||||
};
|
||||
|
||||
@@ -129,6 +115,9 @@ async fn chat_completions(
|
||||
|
||||
let has_images = unified_request.has_images;
|
||||
|
||||
// Measure proxy overhead (time spent before sending to upstream provider)
|
||||
let proxy_overhead = start_time.elapsed();
|
||||
|
||||
// Check if streaming is requested
|
||||
if unified_request.stream {
|
||||
// Estimate prompt tokens for logging later
|
||||
@@ -142,6 +131,12 @@ async fn chat_completions(
|
||||
// Record provider success
|
||||
state.rate_limit_manager.record_provider_success(&provider_name).await;
|
||||
|
||||
info!(
|
||||
"Streaming started for {} (proxy overhead: {}ms)",
|
||||
model,
|
||||
proxy_overhead.as_millis()
|
||||
);
|
||||
|
||||
// Wrap with AggregatingStream for token counting and database logging
|
||||
let aggregating_stream = crate::utils::streaming::AggregatingStream::new(
|
||||
stream,
|
||||
@@ -154,19 +149,21 @@ async fn chat_completions(
|
||||
logger: state.request_logger.clone(),
|
||||
client_manager: state.client_manager.clone(),
|
||||
model_registry: state.model_registry.clone(),
|
||||
db_pool: state.db_pool.clone(),
|
||||
model_config_cache: state.model_config_cache.clone(),
|
||||
},
|
||||
);
|
||||
|
||||
// Create SSE stream from aggregating stream
|
||||
let stream_id = format!("chatcmpl-{}", Uuid::new_v4());
|
||||
let stream_created = chrono::Utc::now().timestamp() as u64;
|
||||
let sse_stream = aggregating_stream.map(move |chunk_result| {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
// Convert provider chunk to OpenAI-compatible SSE event
|
||||
let response = ChatCompletionStreamResponse {
|
||||
id: format!("chatcmpl-{}", Uuid::new_v4()),
|
||||
id: stream_id.clone(),
|
||||
object: "chat.completion.chunk".to_string(),
|
||||
created: chrono::Utc::now().timestamp() as u64,
|
||||
created: stream_created,
|
||||
model: chunk.model.clone(),
|
||||
choices: vec![ChatStreamChoice {
|
||||
index: 0,
|
||||
@@ -242,11 +239,14 @@ async fn chat_completions(
|
||||
duration_ms: duration.as_millis() as u64,
|
||||
});
|
||||
|
||||
// Update client usage
|
||||
let _ = state
|
||||
.client_manager
|
||||
.update_client_usage(&client_id, response.total_tokens as i64, cost)
|
||||
.await;
|
||||
// Update client usage (fire-and-forget, don't block response)
|
||||
{
|
||||
let cm = state.client_manager.clone();
|
||||
let cid = client_id.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = cm.update_client_usage(&cid, response.total_tokens as i64, cost).await;
|
||||
});
|
||||
}
|
||||
|
||||
// Convert ProviderResponse to ChatCompletionResponse
|
||||
let finish_reason = if response.tool_calls.is_some() {
|
||||
@@ -281,8 +281,14 @@ async fn chat_completions(
|
||||
}),
|
||||
};
|
||||
|
||||
// Log successful request
|
||||
info!("Request completed successfully in {:?}", duration);
|
||||
// Log successful request with proxy overhead breakdown
|
||||
let upstream_ms = duration.as_millis() as u64 - proxy_overhead.as_millis() as u64;
|
||||
info!(
|
||||
"Request completed in {:?} (proxy: {}ms, upstream: {}ms)",
|
||||
duration,
|
||||
proxy_overhead.as_millis(),
|
||||
upstream_ms
|
||||
);
|
||||
|
||||
Ok(Json(chat_response).into_response())
|
||||
}
|
||||
|
||||
@@ -1,11 +1,89 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::broadcast;
|
||||
use tokio::sync::{broadcast, RwLock};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::{
|
||||
client::ClientManager, config::AppConfig, database::DbPool, logging::RequestLogger,
|
||||
models::registry::ModelRegistry, providers::ProviderManager, rate_limiting::RateLimitManager,
|
||||
};
|
||||
|
||||
/// Cached model configuration entry
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CachedModelConfig {
|
||||
pub enabled: bool,
|
||||
pub mapping: Option<String>,
|
||||
pub prompt_cost_per_m: Option<f64>,
|
||||
pub completion_cost_per_m: Option<f64>,
|
||||
}
|
||||
|
||||
/// In-memory cache for model_configs table.
|
||||
/// Refreshes periodically to avoid hitting SQLite on every request.
|
||||
#[derive(Clone)]
|
||||
pub struct ModelConfigCache {
|
||||
cache: Arc<RwLock<HashMap<String, CachedModelConfig>>>,
|
||||
db_pool: DbPool,
|
||||
}
|
||||
|
||||
impl ModelConfigCache {
|
||||
pub fn new(db_pool: DbPool) -> Self {
|
||||
Self {
|
||||
cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
db_pool,
|
||||
}
|
||||
}
|
||||
|
||||
/// Load all model configs from the database into cache
|
||||
pub async fn refresh(&self) {
|
||||
match sqlx::query_as::<_, (String, bool, Option<String>, Option<f64>, Option<f64>)>(
|
||||
"SELECT id, enabled, mapping, prompt_cost_per_m, completion_cost_per_m FROM model_configs",
|
||||
)
|
||||
.fetch_all(&self.db_pool)
|
||||
.await
|
||||
{
|
||||
Ok(rows) => {
|
||||
let mut map = HashMap::with_capacity(rows.len());
|
||||
for (id, enabled, mapping, prompt_cost, completion_cost) in rows {
|
||||
map.insert(
|
||||
id,
|
||||
CachedModelConfig {
|
||||
enabled,
|
||||
mapping,
|
||||
prompt_cost_per_m: prompt_cost,
|
||||
completion_cost_per_m: completion_cost,
|
||||
},
|
||||
);
|
||||
}
|
||||
*self.cache.write().await = map;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to refresh model config cache: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a cached model config. Returns None if not in cache (model is unconfigured).
|
||||
pub async fn get(&self, model: &str) -> Option<CachedModelConfig> {
|
||||
self.cache.read().await.get(model).cloned()
|
||||
}
|
||||
|
||||
/// Invalidate cache — call this after dashboard writes to model_configs
|
||||
pub async fn invalidate(&self) {
|
||||
self.refresh().await;
|
||||
}
|
||||
|
||||
/// Start a background task that refreshes the cache every `interval` seconds
|
||||
pub fn start_refresh_task(self, interval_secs: u64) {
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
self.refresh().await;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Shared application state
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
@@ -16,6 +94,7 @@ pub struct AppState {
|
||||
pub client_manager: Arc<ClientManager>,
|
||||
pub request_logger: Arc<RequestLogger>,
|
||||
pub model_registry: Arc<ModelRegistry>,
|
||||
pub model_config_cache: ModelConfigCache,
|
||||
pub dashboard_tx: broadcast::Sender<serde_json::Value>,
|
||||
pub auth_tokens: Vec<String>,
|
||||
}
|
||||
@@ -32,6 +111,7 @@ impl AppState {
|
||||
let client_manager = Arc::new(ClientManager::new(db_pool.clone()));
|
||||
let (dashboard_tx, _) = broadcast::channel(100);
|
||||
let request_logger = Arc::new(RequestLogger::new(db_pool.clone(), dashboard_tx.clone()));
|
||||
let model_config_cache = ModelConfigCache::new(db_pool.clone());
|
||||
|
||||
Self {
|
||||
config,
|
||||
@@ -41,6 +121,7 @@ impl AppState {
|
||||
client_manager,
|
||||
request_logger,
|
||||
model_registry: Arc::new(model_registry),
|
||||
model_config_cache,
|
||||
dashboard_tx,
|
||||
auth_tokens,
|
||||
}
|
||||
|
||||
@@ -3,9 +3,9 @@ use crate::errors::AppError;
|
||||
use crate::logging::{RequestLog, RequestLogger};
|
||||
use crate::models::ToolCall;
|
||||
use crate::providers::{Provider, ProviderStreamChunk};
|
||||
use crate::state::ModelConfigCache;
|
||||
use crate::utils::tokens::estimate_completion_tokens;
|
||||
use futures::stream::Stream;
|
||||
use sqlx::Row;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
@@ -20,7 +20,7 @@ pub struct StreamConfig {
|
||||
pub logger: Arc<RequestLogger>,
|
||||
pub client_manager: Arc<ClientManager>,
|
||||
pub model_registry: Arc<crate::models::registry::ModelRegistry>,
|
||||
pub db_pool: crate::database::DbPool,
|
||||
pub model_config_cache: ModelConfigCache,
|
||||
}
|
||||
|
||||
pub struct AggregatingStream<S> {
|
||||
@@ -36,7 +36,7 @@ pub struct AggregatingStream<S> {
|
||||
logger: Arc<RequestLogger>,
|
||||
client_manager: Arc<ClientManager>,
|
||||
model_registry: Arc<crate::models::registry::ModelRegistry>,
|
||||
db_pool: crate::database::DbPool,
|
||||
model_config_cache: ModelConfigCache,
|
||||
start_time: std::time::Instant,
|
||||
has_logged: bool,
|
||||
}
|
||||
@@ -59,7 +59,7 @@ where
|
||||
logger: config.logger,
|
||||
client_manager: config.client_manager,
|
||||
model_registry: config.model_registry,
|
||||
db_pool: config.db_pool,
|
||||
model_config_cache: config.model_config_cache,
|
||||
start_time: std::time::Instant::now(),
|
||||
has_logged: false,
|
||||
}
|
||||
@@ -81,7 +81,7 @@ where
|
||||
let prompt_tokens = self.prompt_tokens;
|
||||
let has_images = self.has_images;
|
||||
let registry = self.model_registry.clone();
|
||||
let pool = self.db_pool.clone();
|
||||
let config_cache = self.model_config_cache.clone();
|
||||
|
||||
// Estimate completion tokens (including reasoning if present)
|
||||
let content_tokens = estimate_completion_tokens(&self.accumulated_content, &model);
|
||||
@@ -96,19 +96,9 @@ where
|
||||
|
||||
// Spawn a background task to log the completion
|
||||
tokio::spawn(async move {
|
||||
// Check database for cost overrides
|
||||
let db_cost =
|
||||
sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?")
|
||||
.bind(&model)
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let cost = if let Some(row) = db_cost {
|
||||
let prompt_rate = row.get::<Option<f64>, _>("prompt_cost_per_m");
|
||||
let completion_rate = row.get::<Option<f64>, _>("completion_cost_per_m");
|
||||
|
||||
if let (Some(p), Some(c)) = (prompt_rate, completion_rate) {
|
||||
// Check in-memory cache for cost overrides (no SQLite hit)
|
||||
let cost = if let Some(cached) = config_cache.get(&model).await {
|
||||
if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
|
||||
(prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0)
|
||||
} else {
|
||||
provider.calculate_cost(&model, prompt_tokens, completion_tokens, ®istry)
|
||||
@@ -284,7 +274,7 @@ mod tests {
|
||||
logger,
|
||||
client_manager,
|
||||
model_registry: registry,
|
||||
db_pool: pool.clone(),
|
||||
model_config_cache: ModelConfigCache::new(pool.clone()),
|
||||
},
|
||||
);
|
||||
|
||||
|
||||
@@ -10,38 +10,57 @@ pub fn count_tokens(model: &str, text: &str) -> u32 {
|
||||
bpe.encode_with_special_tokens(text).len() as u32
|
||||
}
|
||||
|
||||
/// Estimate tokens for a unified request
|
||||
/// Estimate tokens for a unified request.
|
||||
/// Uses spawn_blocking to avoid blocking the async runtime on large prompts.
|
||||
pub fn estimate_request_tokens(model: &str, request: &UnifiedRequest) -> u32 {
|
||||
let mut total_tokens = 0;
|
||||
let mut total_text = String::new();
|
||||
let msg_count = request.messages.len();
|
||||
|
||||
// Base tokens per message for OpenAI (approximate)
|
||||
let tokens_per_message = 3;
|
||||
let _tokens_per_name = 1;
|
||||
let tokens_per_message: u32 = 3;
|
||||
|
||||
for msg in &request.messages {
|
||||
total_tokens += tokens_per_message;
|
||||
|
||||
for part in &msg.content {
|
||||
match part {
|
||||
crate::models::ContentPart::Text { text } => {
|
||||
total_tokens += count_tokens(model, text);
|
||||
total_text.push_str(text);
|
||||
total_text.push('\n');
|
||||
}
|
||||
crate::models::ContentPart::Image { .. } => {
|
||||
// Vision models usually have a fixed cost or calculation based on size
|
||||
// For now, let's use a conservative estimate of 1000 tokens
|
||||
total_tokens += 1000;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add name tokens if we had names (we don't in UnifiedMessage yet)
|
||||
// total_tokens += tokens_per_name;
|
||||
}
|
||||
|
||||
// Add 3 tokens for the assistant reply header
|
||||
total_tokens += 3;
|
||||
// Quick heuristic for small inputs (< 1KB) — avoid spawn_blocking overhead
|
||||
if total_text.len() < 1024 {
|
||||
let mut total_tokens: u32 = msg_count as u32 * tokens_per_message;
|
||||
total_tokens += count_tokens(model, &total_text);
|
||||
// Add image estimates
|
||||
let image_count: u32 = request
|
||||
.messages
|
||||
.iter()
|
||||
.flat_map(|m| m.content.iter())
|
||||
.filter(|p| matches!(p, crate::models::ContentPart::Image { .. }))
|
||||
.count() as u32;
|
||||
total_tokens += image_count * 1000;
|
||||
total_tokens += 3; // assistant reply header
|
||||
return total_tokens;
|
||||
}
|
||||
|
||||
total_tokens
|
||||
// For large inputs, use the fast heuristic (chars / 4) to avoid blocking
|
||||
// the async runtime. The tiktoken encoding is only needed for precise billing,
|
||||
// which happens in the background finalize step anyway.
|
||||
let estimated_text_tokens = (total_text.len() as u32) / 4;
|
||||
let image_count: u32 = request
|
||||
.messages
|
||||
.iter()
|
||||
.flat_map(|m| m.content.iter())
|
||||
.filter(|p| matches!(p, crate::models::ContentPart::Image { .. }))
|
||||
.count() as u32;
|
||||
|
||||
(msg_count as u32 * tokens_per_message) + estimated_text_tokens + (image_count * 1000) + 3
|
||||
}
|
||||
|
||||
/// Estimate tokens for completion text
|
||||
|
||||
Reference in New Issue
Block a user