use crate::client::ClientManager; use crate::errors::AppError; use crate::logging::{RequestLog, RequestLogger}; use crate::models::ToolCall; use crate::providers::{Provider, ProviderStreamChunk, StreamUsage}; use crate::state::ModelConfigCache; use crate::utils::tokens::estimate_completion_tokens; use futures::stream::Stream; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; /// Configuration for creating an AggregatingStream. pub struct StreamConfig { pub client_id: String, pub provider: Arc, pub model: String, pub prompt_tokens: u32, pub has_images: bool, pub logger: Arc, pub client_manager: Arc, pub model_registry: Arc, pub model_config_cache: ModelConfigCache, } pub struct AggregatingStream { inner: S, client_id: String, provider: Arc, model: String, prompt_tokens: u32, has_images: bool, accumulated_content: String, accumulated_reasoning: String, accumulated_tool_calls: Vec, /// Real usage data from the provider's final stream chunk (when available). real_usage: Option, logger: Arc, client_manager: Arc, model_registry: Arc, model_config_cache: ModelConfigCache, start_time: std::time::Instant, has_logged: bool, } impl AggregatingStream where S: Stream> + Unpin, { pub fn new(inner: S, config: StreamConfig) -> Self { Self { inner, client_id: config.client_id, provider: config.provider, model: config.model, prompt_tokens: config.prompt_tokens, has_images: config.has_images, accumulated_content: String::new(), accumulated_reasoning: String::new(), accumulated_tool_calls: Vec::new(), real_usage: None, logger: config.logger, client_manager: config.client_manager, model_registry: config.model_registry, model_config_cache: config.model_config_cache, start_time: std::time::Instant::now(), has_logged: false, } } fn finalize(&mut self) { if self.has_logged { return; } self.has_logged = true; let duration = self.start_time.elapsed(); let client_id = self.client_id.clone(); let provider_name = self.provider.name().to_string(); let model = self.model.clone(); let logger = self.logger.clone(); let client_manager = self.client_manager.clone(); let provider = self.provider.clone(); let estimated_prompt_tokens = self.prompt_tokens; let has_images = self.has_images; let registry = self.model_registry.clone(); let config_cache = self.model_config_cache.clone(); let real_usage = self.real_usage.take(); // Estimate completion tokens (including reasoning if present) let estimated_content_tokens = estimate_completion_tokens(&self.accumulated_content, &model); let estimated_reasoning_tokens = if !self.accumulated_reasoning.is_empty() { estimate_completion_tokens(&self.accumulated_reasoning, &model) } else { 0 }; let estimated_completion = estimated_content_tokens + estimated_reasoning_tokens; // Spawn a background task to log the completion tokio::spawn(async move { // Use real usage from the provider when available, otherwise fall back to estimates let (prompt_tokens, completion_tokens, total_tokens, cache_read_tokens, cache_write_tokens) = if let Some(usage) = &real_usage { ( usage.prompt_tokens, usage.completion_tokens, usage.total_tokens, usage.cache_read_tokens, usage.cache_write_tokens, ) } else { ( estimated_prompt_tokens, estimated_completion, estimated_prompt_tokens + estimated_completion, 0u32, 0u32, ) }; // Check in-memory cache for cost overrides (no SQLite hit) let cost = if let Some(cached) = config_cache.get(&model).await { if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) { // Cost override doesn't have cache-aware pricing, use simple formula (prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0) } else { provider.calculate_cost( &model, prompt_tokens, completion_tokens, cache_read_tokens, cache_write_tokens, ®istry, ) } } else { provider.calculate_cost( &model, prompt_tokens, completion_tokens, cache_read_tokens, cache_write_tokens, ®istry, ) }; // Log to database logger.log_request(RequestLog { timestamp: chrono::Utc::now(), client_id: client_id.clone(), provider: provider_name, model, prompt_tokens, completion_tokens, total_tokens, cache_read_tokens, cache_write_tokens, cost, has_images, status: "success".to_string(), error_message: None, duration_ms: duration.as_millis() as u64, }); // Update client usage let _ = client_manager .update_client_usage(&client_id, total_tokens as i64, cost) .await; }); } } impl Stream for AggregatingStream where S: Stream> + Unpin, { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let result = Pin::new(&mut self.inner).poll_next(cx); match &result { Poll::Ready(Some(Ok(chunk))) => { self.accumulated_content.push_str(&chunk.content); if let Some(reasoning) = &chunk.reasoning_content { self.accumulated_reasoning.push_str(reasoning); } // Capture real usage from the provider when present (typically on the final chunk) if let Some(usage) = &chunk.usage { self.real_usage = Some(usage.clone()); } // Accumulate tool call deltas into complete tool calls if let Some(deltas) = &chunk.tool_calls { for delta in deltas { let idx = delta.index as usize; // Grow the accumulated_tool_calls vec if needed while self.accumulated_tool_calls.len() <= idx { self.accumulated_tool_calls.push(ToolCall { id: String::new(), call_type: "function".to_string(), function: crate::models::FunctionCall { name: String::new(), arguments: String::new(), }, }); } let tc = &mut self.accumulated_tool_calls[idx]; if let Some(id) = &delta.id { tc.id.clone_from(id); } if let Some(ct) = &delta.call_type { tc.call_type.clone_from(ct); } if let Some(f) = &delta.function { if let Some(name) = &f.name { tc.function.name.push_str(name); } if let Some(args) = &f.arguments { tc.function.arguments.push_str(args); } } } } } Poll::Ready(Some(Err(_))) => { // If there's an error, we might still want to log what we got so far? // For now, just finalize if we have content if !self.accumulated_content.is_empty() { self.finalize(); } } Poll::Ready(None) => { self.finalize(); } Poll::Pending => {} } result } } #[cfg(test)] mod tests { use super::*; use anyhow::Result; use futures::stream::{self, StreamExt}; // Simple mock provider for testing struct MockProvider; #[async_trait::async_trait] impl Provider for MockProvider { fn name(&self) -> &str { "mock" } fn supports_model(&self, _model: &str) -> bool { true } fn supports_multimodal(&self) -> bool { false } async fn chat_completion( &self, _req: crate::models::UnifiedRequest, ) -> Result { unimplemented!() } async fn chat_completion_stream( &self, _req: crate::models::UnifiedRequest, ) -> Result>, AppError> { unimplemented!() } fn estimate_tokens(&self, _req: &crate::models::UnifiedRequest) -> Result { Ok(10) } fn calculate_cost(&self, _model: &str, _p: u32, _c: u32, _cr: u32, _cw: u32, _r: &crate::models::registry::ModelRegistry) -> f64 { 0.05 } } #[tokio::test] async fn test_aggregating_stream() { let chunks = vec![ Ok(ProviderStreamChunk { content: "Hello".to_string(), reasoning_content: None, finish_reason: None, tool_calls: None, model: "test".to_string(), usage: None, }), Ok(ProviderStreamChunk { content: " World".to_string(), reasoning_content: None, finish_reason: Some("stop".to_string()), tool_calls: None, model: "test".to_string(), usage: None, }), ]; let inner_stream = stream::iter(chunks); let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); let (dashboard_tx, _) = tokio::sync::broadcast::channel(16); let logger = Arc::new(RequestLogger::new(pool.clone(), dashboard_tx)); let client_manager = Arc::new(ClientManager::new(pool.clone())); let registry = Arc::new(crate::models::registry::ModelRegistry { providers: std::collections::HashMap::new(), }); let mut agg_stream = AggregatingStream::new( inner_stream, StreamConfig { client_id: "client_1".to_string(), provider: Arc::new(MockProvider), model: "test".to_string(), prompt_tokens: 10, has_images: false, logger, client_manager, model_registry: registry, model_config_cache: ModelConfigCache::new(pool.clone()), }, ); while let Some(item) = agg_stream.next().await { assert!(item.is_ok()); } assert_eq!(agg_stream.accumulated_content, "Hello World"); assert!(agg_stream.has_logged); } }