use chrono::{DateTime, Utc}; use serde::Serialize; use sqlx::SqlitePool; use tokio::sync::broadcast; use tracing::{info, warn}; use crate::errors::AppError; /// Request log entry for database storage #[derive(Debug, Clone, Serialize)] pub struct RequestLog { pub timestamp: DateTime, pub client_id: String, pub provider: String, pub model: String, pub prompt_tokens: u32, pub completion_tokens: u32, pub total_tokens: u32, pub cache_read_tokens: u32, pub cache_write_tokens: u32, pub cost: f64, pub has_images: bool, pub status: String, // "success", "error" pub error_message: Option, pub duration_ms: u64, } /// Database operations for request logging pub struct RequestLogger { db_pool: SqlitePool, dashboard_tx: broadcast::Sender, } impl RequestLogger { pub fn new(db_pool: SqlitePool, dashboard_tx: broadcast::Sender) -> Self { Self { db_pool, dashboard_tx } } /// Log a request to the database (async, spawns a task) pub fn log_request(&self, log: RequestLog) { let pool = self.db_pool.clone(); let tx = self.dashboard_tx.clone(); // Spawn async task to log without blocking response tokio::spawn(async move { // Broadcast to dashboard let broadcast_result = tx.send(serde_json::json!({ "type": "request", "channel": "requests", "payload": log })); match broadcast_result { Ok(receivers) => info!("Broadcast request log to {} dashboard listeners", receivers), Err(_) => {} // No active WebSocket clients — expected when dashboard isn't open } match Self::insert_log(&pool, log).await { Ok(()) => info!("Request logged to database successfully"), Err(e) => warn!("Failed to log request to database: {}", e), } }); } /// Insert a log entry into the database async fn insert_log(pool: &SqlitePool, log: RequestLog) -> Result<(), sqlx::Error> { let mut tx = pool.begin().await?; // Ensure the client row exists (FK constraint requires it) sqlx::query( "INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')", ) .bind(&log.client_id) .bind(&log.client_id) .execute(&mut *tx) .await?; sqlx::query( r#" INSERT INTO llm_requests (timestamp, client_id, provider, model, prompt_tokens, completion_tokens, total_tokens, cache_read_tokens, cache_write_tokens, cost, has_images, status, error_message, duration_ms, request_body, response_body) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) "#, ) .bind(log.timestamp) .bind(log.client_id) .bind(&log.provider) .bind(log.model) .bind(log.prompt_tokens as i64) .bind(log.completion_tokens as i64) .bind(log.total_tokens as i64) .bind(log.cache_read_tokens as i64) .bind(log.cache_write_tokens as i64) .bind(log.cost) .bind(log.has_images) .bind(log.status) .bind(log.error_message) .bind(log.duration_ms as i64) .bind(None::) // request_body - optional, not stored to save disk space .bind(None::) // response_body - optional, not stored to save disk space .execute(&mut *tx) .await?; // Deduct from provider balance if successful. // Providers configured with billing_mode = 'postpaid' will not have their // credit_balance decremented. Use a conditional UPDATE so we don't need // a prior SELECT and avoid extra round-trips. if log.cost > 0.0 { sqlx::query( "UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')", ) .bind(log.cost) .bind(&log.provider) .execute(&mut *tx) .await?; } tx.commit().await?; Ok(()) } } // /// Middleware to log LLM API requests // /// TODO: Implement proper middleware that can extract response body details // pub async fn request_logging_middleware( // // Extract the authenticated client (if available) // auth_result: Result, // request: Request, // next: Next, // ) -> Response { // let start_time = std::time::Instant::now(); // // // Extract client_id from auth or use "unknown" // let client_id = match auth_result { // Ok(auth) => auth.client_id, // Err(_) => "unknown".to_string(), // }; // // // Try to extract request details // let (request_parts, request_body) = request.into_parts(); // // // Clone request parts for logging // let path = request_parts.uri.path().to_string(); // // // Check if this is a chat completion request // let is_chat_completion = path == "/v1/chat/completions"; // // // Reconstruct request for downstream handlers // let request = Request::from_parts(request_parts, request_body); // // // Process request and get response // let response = next.run(request).await; // // // Calculate duration // let duration = start_time.elapsed(); // let duration_ms = duration.as_millis() as u64; // // // Log basic request info // info!( // "Request from {} to {} - Status: {} - Duration: {}ms", // client_id, // path, // response.status().as_u16(), // duration_ms // ); // // // TODO: Extract more details from request/response for logging // // For now, we'll need to modify the server handler to pass additional context // // response // } /// Context for request logging that can be passed through extensions #[derive(Clone)] pub struct LoggingContext { pub client_id: String, pub provider_name: String, pub model: String, pub prompt_tokens: u32, pub completion_tokens: u32, pub total_tokens: u32, pub cost: f64, pub has_images: bool, pub error: Option, } impl LoggingContext { pub fn new(client_id: String, provider_name: String, model: String) -> Self { Self { client_id, provider_name, model, prompt_tokens: 0, completion_tokens: 0, total_tokens: 0, cost: 0.0, has_images: false, error: None, } } pub fn with_token_counts(mut self, prompt_tokens: u32, completion_tokens: u32) -> Self { self.prompt_tokens = prompt_tokens; self.completion_tokens = completion_tokens; self.total_tokens = prompt_tokens + completion_tokens; self } pub fn with_cost(mut self, cost: f64) -> Self { self.cost = cost; self } pub fn with_images(mut self, has_images: bool) -> Self { self.has_images = has_images; self } pub fn with_error(mut self, error: AppError) -> Self { self.error = Some(error); self } }