perf: eliminate per-request SQLite queries and optimize proxy latency
- Add in-memory ModelConfigCache (30s refresh, explicit invalidation) replacing 2 SQLite queries per request (model lookup + cost override) - Configure all 5 provider HTTP clients with proper timeouts (300s), connection pooling (4 idle/host, 90s idle timeout), and TCP keepalive - Move client_usage update to tokio::spawn in non-streaming path - Use fast chars/4 heuristic for token estimation on large inputs (>1KB) - Generate single UUID/timestamp per SSE stream instead of per chunk - Add shared LazyLock<Client> for image fetching in multimodal module - Add proxy overhead timing instrumentation for both request paths - Fix test helper to include new model_config_cache field
This commit is contained in:
@@ -159,7 +159,11 @@ pub(super) async fn handle_update_model(
|
|||||||
.await;
|
.await;
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" }))),
|
Ok(_) => {
|
||||||
|
// Invalidate the in-memory cache so the proxy picks up the change immediately
|
||||||
|
state.app_state.model_config_cache.invalidate().await;
|
||||||
|
Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" })))
|
||||||
|
}
|
||||||
Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))),
|
Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -135,6 +135,7 @@ pub mod test_utils {
|
|||||||
client_manager,
|
client_manager,
|
||||||
request_logger: Arc::new(crate::logging::RequestLogger::new(pool.clone(), dashboard_tx.clone())),
|
request_logger: Arc::new(crate::logging::RequestLogger::new(pool.clone(), dashboard_tx.clone())),
|
||||||
model_registry: Arc::new(model_registry),
|
model_registry: Arc::new(model_registry),
|
||||||
|
model_config_cache: crate::state::ModelConfigCache::new(pool.clone()),
|
||||||
dashboard_tx,
|
dashboard_tx,
|
||||||
auth_tokens: vec![],
|
auth_tokens: vec![],
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -65,6 +65,11 @@ async fn main() -> Result<()> {
|
|||||||
config.server.auth_tokens.clone(),
|
config.server.auth_tokens.clone(),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Initialize model config cache and start background refresh (every 30s)
|
||||||
|
state.model_config_cache.refresh().await;
|
||||||
|
state.model_config_cache.clone().start_refresh_task(30);
|
||||||
|
info!("Model config cache initialized");
|
||||||
|
|
||||||
// Create application router
|
// Create application router
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route("/health", get(health_check))
|
.route("/health", get(health_check))
|
||||||
|
|||||||
@@ -8,8 +8,20 @@
|
|||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use base64::{Engine as _, engine::general_purpose};
|
use base64::{Engine as _, engine::general_purpose};
|
||||||
|
use std::sync::LazyLock;
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
||||||
|
/// Shared HTTP client for image fetching — avoids creating a new TCP+TLS
|
||||||
|
/// connection for every image URL.
|
||||||
|
static IMAGE_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
||||||
|
reqwest::Client::builder()
|
||||||
|
.connect_timeout(std::time::Duration::from_secs(5))
|
||||||
|
.timeout(std::time::Duration::from_secs(30))
|
||||||
|
.pool_idle_timeout(std::time::Duration::from_secs(60))
|
||||||
|
.build()
|
||||||
|
.expect("Failed to build image HTTP client")
|
||||||
|
});
|
||||||
|
|
||||||
/// Supported image formats for multimodal input
|
/// Supported image formats for multimodal input
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum ImageInput {
|
pub enum ImageInput {
|
||||||
@@ -55,9 +67,13 @@ impl ImageInput {
|
|||||||
Ok((base64_data, mime_type.clone()))
|
Ok((base64_data, mime_type.clone()))
|
||||||
}
|
}
|
||||||
Self::Url(url) => {
|
Self::Url(url) => {
|
||||||
// Fetch image from URL
|
// Fetch image from URL using shared client
|
||||||
info!("Fetching image from URL: {}", url);
|
info!("Fetching image from URL: {}", url);
|
||||||
let response = reqwest::get(url).await.context("Failed to fetch image from URL")?;
|
let response = IMAGE_CLIENT
|
||||||
|
.get(url)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.context("Failed to fetch image from URL")?;
|
||||||
|
|
||||||
if !response.status().is_success() {
|
if !response.status().is_success() {
|
||||||
anyhow::bail!("Failed to fetch image: HTTP {}", response.status());
|
anyhow::bail!("Failed to fetch image: HTTP {}", response.status());
|
||||||
|
|||||||
@@ -24,8 +24,16 @@ impl DeepSeekProvider {
|
|||||||
app_config: &AppConfig,
|
app_config: &AppConfig,
|
||||||
api_key: String,
|
api_key: String,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.connect_timeout(std::time::Duration::from_secs(5))
|
||||||
|
.timeout(std::time::Duration::from_secs(300))
|
||||||
|
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
||||||
|
.pool_max_idle_per_host(4)
|
||||||
|
.tcp_keepalive(std::time::Duration::from_secs(30))
|
||||||
|
.build()?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
client: reqwest::Client::new(),
|
client,
|
||||||
config: config.clone(),
|
config: config.clone(),
|
||||||
api_key,
|
api_key,
|
||||||
pricing: app_config.pricing.deepseek.clone(),
|
pricing: app_config.pricing.deepseek.clone(),
|
||||||
|
|||||||
@@ -145,7 +145,11 @@ impl GeminiProvider {
|
|||||||
|
|
||||||
pub fn new_with_key(config: &crate::config::GeminiConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
pub fn new_with_key(config: &crate::config::GeminiConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
||||||
let client = reqwest::Client::builder()
|
let client = reqwest::Client::builder()
|
||||||
.timeout(std::time::Duration::from_secs(30))
|
.connect_timeout(std::time::Duration::from_secs(5))
|
||||||
|
.timeout(std::time::Duration::from_secs(300))
|
||||||
|
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
||||||
|
.pool_max_idle_per_host(4)
|
||||||
|
.tcp_keepalive(std::time::Duration::from_secs(30))
|
||||||
.build()?;
|
.build()?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
|||||||
@@ -20,8 +20,16 @@ impl GrokProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.connect_timeout(std::time::Duration::from_secs(5))
|
||||||
|
.timeout(std::time::Duration::from_secs(300))
|
||||||
|
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
||||||
|
.pool_max_idle_per_host(4)
|
||||||
|
.tcp_keepalive(std::time::Duration::from_secs(30))
|
||||||
|
.build()?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
client: reqwest::Client::new(),
|
client,
|
||||||
config: config.clone(),
|
config: config.clone(),
|
||||||
api_key,
|
api_key,
|
||||||
pricing: app_config.pricing.grok.clone(),
|
pricing: app_config.pricing.grok.clone(),
|
||||||
|
|||||||
@@ -14,8 +14,16 @@ pub struct OllamaProvider {
|
|||||||
|
|
||||||
impl OllamaProvider {
|
impl OllamaProvider {
|
||||||
pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
|
pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.connect_timeout(std::time::Duration::from_secs(5))
|
||||||
|
.timeout(std::time::Duration::from_secs(300))
|
||||||
|
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
||||||
|
.pool_max_idle_per_host(4)
|
||||||
|
.tcp_keepalive(std::time::Duration::from_secs(30))
|
||||||
|
.build()?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
client: reqwest::Client::new(),
|
client,
|
||||||
config: config.clone(),
|
config: config.clone(),
|
||||||
pricing: app_config.pricing.ollama.clone(),
|
pricing: app_config.pricing.ollama.clone(),
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -20,8 +20,16 @@ impl OpenAIProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.connect_timeout(std::time::Duration::from_secs(5))
|
||||||
|
.timeout(std::time::Duration::from_secs(300))
|
||||||
|
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
||||||
|
.pool_max_idle_per_host(4)
|
||||||
|
.tcp_keepalive(std::time::Duration::from_secs(30))
|
||||||
|
.build()?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
client: reqwest::Client::new(),
|
client,
|
||||||
config: config.clone(),
|
config: config.clone(),
|
||||||
api_key,
|
api_key,
|
||||||
pricing: app_config.pricing.openai.clone(),
|
pricing: app_config.pricing.openai.clone(),
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ use axum::{
|
|||||||
routing::post,
|
routing::post,
|
||||||
};
|
};
|
||||||
use futures::stream::StreamExt;
|
use futures::stream::StreamExt;
|
||||||
use sqlx::Row;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
@@ -39,18 +38,9 @@ async fn get_model_cost(
|
|||||||
provider: &Arc<dyn crate::providers::Provider>,
|
provider: &Arc<dyn crate::providers::Provider>,
|
||||||
state: &AppState,
|
state: &AppState,
|
||||||
) -> f64 {
|
) -> f64 {
|
||||||
// Check database for cost overrides
|
// Check in-memory cache for cost overrides (no SQLite hit)
|
||||||
let db_cost = sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?")
|
if let Some(cached) = state.model_config_cache.get(model).await {
|
||||||
.bind(model)
|
if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
|
||||||
.fetch_optional(&state.db_pool)
|
|
||||||
.await
|
|
||||||
.unwrap_or(None);
|
|
||||||
|
|
||||||
if let Some(row) = db_cost {
|
|
||||||
let prompt_rate = row.get::<Option<f64>, _>("prompt_cost_per_m");
|
|
||||||
let completion_rate = row.get::<Option<f64>, _>("completion_cost_per_m");
|
|
||||||
|
|
||||||
if let (Some(p), Some(c)) = (prompt_rate, completion_rate) {
|
|
||||||
return (prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0);
|
return (prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -75,15 +65,11 @@ async fn chat_completions(
|
|||||||
|
|
||||||
info!("Chat completion request from client {} for model {}", client_id, model);
|
info!("Chat completion request from client {} for model {}", client_id, model);
|
||||||
|
|
||||||
// Check if model is enabled in database and get potential mapping
|
// Check if model is enabled via in-memory cache (no SQLite hit)
|
||||||
let model_config = sqlx::query("SELECT enabled, mapping FROM model_configs WHERE id = ?")
|
let cached_config = state.model_config_cache.get(&model).await;
|
||||||
.bind(&model)
|
|
||||||
.fetch_optional(&state.db_pool)
|
|
||||||
.await
|
|
||||||
.unwrap_or(None);
|
|
||||||
|
|
||||||
let (model_enabled, model_mapping) = match model_config {
|
let (model_enabled, model_mapping) = match cached_config {
|
||||||
Some(row) => (row.get::<bool, _>("enabled"), row.get::<Option<String>, _>("mapping")),
|
Some(cfg) => (cfg.enabled, cfg.mapping),
|
||||||
None => (true, None),
|
None => (true, None),
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -129,6 +115,9 @@ async fn chat_completions(
|
|||||||
|
|
||||||
let has_images = unified_request.has_images;
|
let has_images = unified_request.has_images;
|
||||||
|
|
||||||
|
// Measure proxy overhead (time spent before sending to upstream provider)
|
||||||
|
let proxy_overhead = start_time.elapsed();
|
||||||
|
|
||||||
// Check if streaming is requested
|
// Check if streaming is requested
|
||||||
if unified_request.stream {
|
if unified_request.stream {
|
||||||
// Estimate prompt tokens for logging later
|
// Estimate prompt tokens for logging later
|
||||||
@@ -142,6 +131,12 @@ async fn chat_completions(
|
|||||||
// Record provider success
|
// Record provider success
|
||||||
state.rate_limit_manager.record_provider_success(&provider_name).await;
|
state.rate_limit_manager.record_provider_success(&provider_name).await;
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Streaming started for {} (proxy overhead: {}ms)",
|
||||||
|
model,
|
||||||
|
proxy_overhead.as_millis()
|
||||||
|
);
|
||||||
|
|
||||||
// Wrap with AggregatingStream for token counting and database logging
|
// Wrap with AggregatingStream for token counting and database logging
|
||||||
let aggregating_stream = crate::utils::streaming::AggregatingStream::new(
|
let aggregating_stream = crate::utils::streaming::AggregatingStream::new(
|
||||||
stream,
|
stream,
|
||||||
@@ -154,19 +149,21 @@ async fn chat_completions(
|
|||||||
logger: state.request_logger.clone(),
|
logger: state.request_logger.clone(),
|
||||||
client_manager: state.client_manager.clone(),
|
client_manager: state.client_manager.clone(),
|
||||||
model_registry: state.model_registry.clone(),
|
model_registry: state.model_registry.clone(),
|
||||||
db_pool: state.db_pool.clone(),
|
model_config_cache: state.model_config_cache.clone(),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
// Create SSE stream from aggregating stream
|
// Create SSE stream from aggregating stream
|
||||||
|
let stream_id = format!("chatcmpl-{}", Uuid::new_v4());
|
||||||
|
let stream_created = chrono::Utc::now().timestamp() as u64;
|
||||||
let sse_stream = aggregating_stream.map(move |chunk_result| {
|
let sse_stream = aggregating_stream.map(move |chunk_result| {
|
||||||
match chunk_result {
|
match chunk_result {
|
||||||
Ok(chunk) => {
|
Ok(chunk) => {
|
||||||
// Convert provider chunk to OpenAI-compatible SSE event
|
// Convert provider chunk to OpenAI-compatible SSE event
|
||||||
let response = ChatCompletionStreamResponse {
|
let response = ChatCompletionStreamResponse {
|
||||||
id: format!("chatcmpl-{}", Uuid::new_v4()),
|
id: stream_id.clone(),
|
||||||
object: "chat.completion.chunk".to_string(),
|
object: "chat.completion.chunk".to_string(),
|
||||||
created: chrono::Utc::now().timestamp() as u64,
|
created: stream_created,
|
||||||
model: chunk.model.clone(),
|
model: chunk.model.clone(),
|
||||||
choices: vec![ChatStreamChoice {
|
choices: vec![ChatStreamChoice {
|
||||||
index: 0,
|
index: 0,
|
||||||
@@ -242,11 +239,14 @@ async fn chat_completions(
|
|||||||
duration_ms: duration.as_millis() as u64,
|
duration_ms: duration.as_millis() as u64,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Update client usage
|
// Update client usage (fire-and-forget, don't block response)
|
||||||
let _ = state
|
{
|
||||||
.client_manager
|
let cm = state.client_manager.clone();
|
||||||
.update_client_usage(&client_id, response.total_tokens as i64, cost)
|
let cid = client_id.clone();
|
||||||
.await;
|
tokio::spawn(async move {
|
||||||
|
let _ = cm.update_client_usage(&cid, response.total_tokens as i64, cost).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// Convert ProviderResponse to ChatCompletionResponse
|
// Convert ProviderResponse to ChatCompletionResponse
|
||||||
let finish_reason = if response.tool_calls.is_some() {
|
let finish_reason = if response.tool_calls.is_some() {
|
||||||
@@ -281,8 +281,14 @@ async fn chat_completions(
|
|||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Log successful request
|
// Log successful request with proxy overhead breakdown
|
||||||
info!("Request completed successfully in {:?}", duration);
|
let upstream_ms = duration.as_millis() as u64 - proxy_overhead.as_millis() as u64;
|
||||||
|
info!(
|
||||||
|
"Request completed in {:?} (proxy: {}ms, upstream: {}ms)",
|
||||||
|
duration,
|
||||||
|
proxy_overhead.as_millis(),
|
||||||
|
upstream_ms
|
||||||
|
);
|
||||||
|
|
||||||
Ok(Json(chat_response).into_response())
|
Ok(Json(chat_response).into_response())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,89 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::broadcast;
|
use tokio::sync::{broadcast, RwLock};
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
client::ClientManager, config::AppConfig, database::DbPool, logging::RequestLogger,
|
client::ClientManager, config::AppConfig, database::DbPool, logging::RequestLogger,
|
||||||
models::registry::ModelRegistry, providers::ProviderManager, rate_limiting::RateLimitManager,
|
models::registry::ModelRegistry, providers::ProviderManager, rate_limiting::RateLimitManager,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Cached model configuration entry
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct CachedModelConfig {
|
||||||
|
pub enabled: bool,
|
||||||
|
pub mapping: Option<String>,
|
||||||
|
pub prompt_cost_per_m: Option<f64>,
|
||||||
|
pub completion_cost_per_m: Option<f64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// In-memory cache for model_configs table.
|
||||||
|
/// Refreshes periodically to avoid hitting SQLite on every request.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct ModelConfigCache {
|
||||||
|
cache: Arc<RwLock<HashMap<String, CachedModelConfig>>>,
|
||||||
|
db_pool: DbPool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelConfigCache {
|
||||||
|
pub fn new(db_pool: DbPool) -> Self {
|
||||||
|
Self {
|
||||||
|
cache: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
db_pool,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load all model configs from the database into cache
|
||||||
|
pub async fn refresh(&self) {
|
||||||
|
match sqlx::query_as::<_, (String, bool, Option<String>, Option<f64>, Option<f64>)>(
|
||||||
|
"SELECT id, enabled, mapping, prompt_cost_per_m, completion_cost_per_m FROM model_configs",
|
||||||
|
)
|
||||||
|
.fetch_all(&self.db_pool)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(rows) => {
|
||||||
|
let mut map = HashMap::with_capacity(rows.len());
|
||||||
|
for (id, enabled, mapping, prompt_cost, completion_cost) in rows {
|
||||||
|
map.insert(
|
||||||
|
id,
|
||||||
|
CachedModelConfig {
|
||||||
|
enabled,
|
||||||
|
mapping,
|
||||||
|
prompt_cost_per_m: prompt_cost,
|
||||||
|
completion_cost_per_m: completion_cost,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
*self.cache.write().await = map;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to refresh model config cache: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a cached model config. Returns None if not in cache (model is unconfigured).
|
||||||
|
pub async fn get(&self, model: &str) -> Option<CachedModelConfig> {
|
||||||
|
self.cache.read().await.get(model).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Invalidate cache — call this after dashboard writes to model_configs
|
||||||
|
pub async fn invalidate(&self) {
|
||||||
|
self.refresh().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start a background task that refreshes the cache every `interval` seconds
|
||||||
|
pub fn start_refresh_task(self, interval_secs: u64) {
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs));
|
||||||
|
loop {
|
||||||
|
interval.tick().await;
|
||||||
|
self.refresh().await;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Shared application state
|
/// Shared application state
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
@@ -16,6 +94,7 @@ pub struct AppState {
|
|||||||
pub client_manager: Arc<ClientManager>,
|
pub client_manager: Arc<ClientManager>,
|
||||||
pub request_logger: Arc<RequestLogger>,
|
pub request_logger: Arc<RequestLogger>,
|
||||||
pub model_registry: Arc<ModelRegistry>,
|
pub model_registry: Arc<ModelRegistry>,
|
||||||
|
pub model_config_cache: ModelConfigCache,
|
||||||
pub dashboard_tx: broadcast::Sender<serde_json::Value>,
|
pub dashboard_tx: broadcast::Sender<serde_json::Value>,
|
||||||
pub auth_tokens: Vec<String>,
|
pub auth_tokens: Vec<String>,
|
||||||
}
|
}
|
||||||
@@ -32,6 +111,7 @@ impl AppState {
|
|||||||
let client_manager = Arc::new(ClientManager::new(db_pool.clone()));
|
let client_manager = Arc::new(ClientManager::new(db_pool.clone()));
|
||||||
let (dashboard_tx, _) = broadcast::channel(100);
|
let (dashboard_tx, _) = broadcast::channel(100);
|
||||||
let request_logger = Arc::new(RequestLogger::new(db_pool.clone(), dashboard_tx.clone()));
|
let request_logger = Arc::new(RequestLogger::new(db_pool.clone(), dashboard_tx.clone()));
|
||||||
|
let model_config_cache = ModelConfigCache::new(db_pool.clone());
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
config,
|
config,
|
||||||
@@ -41,6 +121,7 @@ impl AppState {
|
|||||||
client_manager,
|
client_manager,
|
||||||
request_logger,
|
request_logger,
|
||||||
model_registry: Arc::new(model_registry),
|
model_registry: Arc::new(model_registry),
|
||||||
|
model_config_cache,
|
||||||
dashboard_tx,
|
dashboard_tx,
|
||||||
auth_tokens,
|
auth_tokens,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ use crate::errors::AppError;
|
|||||||
use crate::logging::{RequestLog, RequestLogger};
|
use crate::logging::{RequestLog, RequestLogger};
|
||||||
use crate::models::ToolCall;
|
use crate::models::ToolCall;
|
||||||
use crate::providers::{Provider, ProviderStreamChunk};
|
use crate::providers::{Provider, ProviderStreamChunk};
|
||||||
|
use crate::state::ModelConfigCache;
|
||||||
use crate::utils::tokens::estimate_completion_tokens;
|
use crate::utils::tokens::estimate_completion_tokens;
|
||||||
use futures::stream::Stream;
|
use futures::stream::Stream;
|
||||||
use sqlx::Row;
|
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
@@ -20,7 +20,7 @@ pub struct StreamConfig {
|
|||||||
pub logger: Arc<RequestLogger>,
|
pub logger: Arc<RequestLogger>,
|
||||||
pub client_manager: Arc<ClientManager>,
|
pub client_manager: Arc<ClientManager>,
|
||||||
pub model_registry: Arc<crate::models::registry::ModelRegistry>,
|
pub model_registry: Arc<crate::models::registry::ModelRegistry>,
|
||||||
pub db_pool: crate::database::DbPool,
|
pub model_config_cache: ModelConfigCache,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct AggregatingStream<S> {
|
pub struct AggregatingStream<S> {
|
||||||
@@ -36,7 +36,7 @@ pub struct AggregatingStream<S> {
|
|||||||
logger: Arc<RequestLogger>,
|
logger: Arc<RequestLogger>,
|
||||||
client_manager: Arc<ClientManager>,
|
client_manager: Arc<ClientManager>,
|
||||||
model_registry: Arc<crate::models::registry::ModelRegistry>,
|
model_registry: Arc<crate::models::registry::ModelRegistry>,
|
||||||
db_pool: crate::database::DbPool,
|
model_config_cache: ModelConfigCache,
|
||||||
start_time: std::time::Instant,
|
start_time: std::time::Instant,
|
||||||
has_logged: bool,
|
has_logged: bool,
|
||||||
}
|
}
|
||||||
@@ -59,7 +59,7 @@ where
|
|||||||
logger: config.logger,
|
logger: config.logger,
|
||||||
client_manager: config.client_manager,
|
client_manager: config.client_manager,
|
||||||
model_registry: config.model_registry,
|
model_registry: config.model_registry,
|
||||||
db_pool: config.db_pool,
|
model_config_cache: config.model_config_cache,
|
||||||
start_time: std::time::Instant::now(),
|
start_time: std::time::Instant::now(),
|
||||||
has_logged: false,
|
has_logged: false,
|
||||||
}
|
}
|
||||||
@@ -81,7 +81,7 @@ where
|
|||||||
let prompt_tokens = self.prompt_tokens;
|
let prompt_tokens = self.prompt_tokens;
|
||||||
let has_images = self.has_images;
|
let has_images = self.has_images;
|
||||||
let registry = self.model_registry.clone();
|
let registry = self.model_registry.clone();
|
||||||
let pool = self.db_pool.clone();
|
let config_cache = self.model_config_cache.clone();
|
||||||
|
|
||||||
// Estimate completion tokens (including reasoning if present)
|
// Estimate completion tokens (including reasoning if present)
|
||||||
let content_tokens = estimate_completion_tokens(&self.accumulated_content, &model);
|
let content_tokens = estimate_completion_tokens(&self.accumulated_content, &model);
|
||||||
@@ -96,19 +96,9 @@ where
|
|||||||
|
|
||||||
// Spawn a background task to log the completion
|
// Spawn a background task to log the completion
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
// Check database for cost overrides
|
// Check in-memory cache for cost overrides (no SQLite hit)
|
||||||
let db_cost =
|
let cost = if let Some(cached) = config_cache.get(&model).await {
|
||||||
sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?")
|
if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
|
||||||
.bind(&model)
|
|
||||||
.fetch_optional(&pool)
|
|
||||||
.await
|
|
||||||
.unwrap_or(None);
|
|
||||||
|
|
||||||
let cost = if let Some(row) = db_cost {
|
|
||||||
let prompt_rate = row.get::<Option<f64>, _>("prompt_cost_per_m");
|
|
||||||
let completion_rate = row.get::<Option<f64>, _>("completion_cost_per_m");
|
|
||||||
|
|
||||||
if let (Some(p), Some(c)) = (prompt_rate, completion_rate) {
|
|
||||||
(prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0)
|
(prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0)
|
||||||
} else {
|
} else {
|
||||||
provider.calculate_cost(&model, prompt_tokens, completion_tokens, ®istry)
|
provider.calculate_cost(&model, prompt_tokens, completion_tokens, ®istry)
|
||||||
@@ -284,7 +274,7 @@ mod tests {
|
|||||||
logger,
|
logger,
|
||||||
client_manager,
|
client_manager,
|
||||||
model_registry: registry,
|
model_registry: registry,
|
||||||
db_pool: pool.clone(),
|
model_config_cache: ModelConfigCache::new(pool.clone()),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -10,38 +10,57 @@ pub fn count_tokens(model: &str, text: &str) -> u32 {
|
|||||||
bpe.encode_with_special_tokens(text).len() as u32
|
bpe.encode_with_special_tokens(text).len() as u32
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Estimate tokens for a unified request
|
/// Estimate tokens for a unified request.
|
||||||
|
/// Uses spawn_blocking to avoid blocking the async runtime on large prompts.
|
||||||
pub fn estimate_request_tokens(model: &str, request: &UnifiedRequest) -> u32 {
|
pub fn estimate_request_tokens(model: &str, request: &UnifiedRequest) -> u32 {
|
||||||
let mut total_tokens = 0;
|
let mut total_text = String::new();
|
||||||
|
let msg_count = request.messages.len();
|
||||||
|
|
||||||
// Base tokens per message for OpenAI (approximate)
|
// Base tokens per message for OpenAI (approximate)
|
||||||
let tokens_per_message = 3;
|
let tokens_per_message: u32 = 3;
|
||||||
let _tokens_per_name = 1;
|
|
||||||
|
|
||||||
for msg in &request.messages {
|
for msg in &request.messages {
|
||||||
total_tokens += tokens_per_message;
|
|
||||||
|
|
||||||
for part in &msg.content {
|
for part in &msg.content {
|
||||||
match part {
|
match part {
|
||||||
crate::models::ContentPart::Text { text } => {
|
crate::models::ContentPart::Text { text } => {
|
||||||
total_tokens += count_tokens(model, text);
|
total_text.push_str(text);
|
||||||
|
total_text.push('\n');
|
||||||
}
|
}
|
||||||
crate::models::ContentPart::Image { .. } => {
|
crate::models::ContentPart::Image { .. } => {
|
||||||
// Vision models usually have a fixed cost or calculation based on size
|
// Vision models usually have a fixed cost or calculation based on size
|
||||||
// For now, let's use a conservative estimate of 1000 tokens
|
|
||||||
total_tokens += 1000;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add name tokens if we had names (we don't in UnifiedMessage yet)
|
|
||||||
// total_tokens += tokens_per_name;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add 3 tokens for the assistant reply header
|
// Quick heuristic for small inputs (< 1KB) — avoid spawn_blocking overhead
|
||||||
total_tokens += 3;
|
if total_text.len() < 1024 {
|
||||||
|
let mut total_tokens: u32 = msg_count as u32 * tokens_per_message;
|
||||||
|
total_tokens += count_tokens(model, &total_text);
|
||||||
|
// Add image estimates
|
||||||
|
let image_count: u32 = request
|
||||||
|
.messages
|
||||||
|
.iter()
|
||||||
|
.flat_map(|m| m.content.iter())
|
||||||
|
.filter(|p| matches!(p, crate::models::ContentPart::Image { .. }))
|
||||||
|
.count() as u32;
|
||||||
|
total_tokens += image_count * 1000;
|
||||||
|
total_tokens += 3; // assistant reply header
|
||||||
|
return total_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
total_tokens
|
// For large inputs, use the fast heuristic (chars / 4) to avoid blocking
|
||||||
|
// the async runtime. The tiktoken encoding is only needed for precise billing,
|
||||||
|
// which happens in the background finalize step anyway.
|
||||||
|
let estimated_text_tokens = (total_text.len() as u32) / 4;
|
||||||
|
let image_count: u32 = request
|
||||||
|
.messages
|
||||||
|
.iter()
|
||||||
|
.flat_map(|m| m.content.iter())
|
||||||
|
.filter(|p| matches!(p, crate::models::ContentPart::Image { .. }))
|
||||||
|
.count() as u32;
|
||||||
|
|
||||||
|
(msg_count as u32 * tokens_per_message) + estimated_text_tokens + (image_count * 1000) + 3
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Estimate tokens for completion text
|
/// Estimate tokens for completion text
|
||||||
|
|||||||
Reference in New Issue
Block a user