use std::sync::Arc; use sqlx::Row; use uuid::Uuid; use axum::{ extract::State, routing::post, Json, Router, response::sse::{Event, Sse}, response::IntoResponse, }; use futures::stream::StreamExt; use tracing::{info, warn}; use crate::{ auth::AuthenticatedClient, errors::AppError, models::{ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatStreamChoice, ChatStreamDelta, ChatMessage, ChatChoice, Usage}, state::AppState, rate_limiting, }; pub fn router(state: AppState) -> Router { Router::new() .route("/v1/chat/completions", post(chat_completions)) .layer(axum::middleware::from_fn_with_state( state.clone(), rate_limiting::middleware::rate_limit_middleware, )) .with_state(state) } async fn get_model_cost( model: &str, prompt_tokens: u32, completion_tokens: u32, provider: &Arc, state: &AppState, ) -> f64 { // Check database for cost overrides let db_cost = sqlx::query("SELECT prompt_cost_per_m, completion_cost_per_m FROM model_configs WHERE id = ?") .bind(model) .fetch_optional(&state.db_pool) .await .unwrap_or(None); if let Some(row) = db_cost { let prompt_rate = row.get::, _>("prompt_cost_per_m"); let completion_rate = row.get::, _>("completion_cost_per_m"); if let (Some(p), Some(c)) = (prompt_rate, completion_rate) { return (prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0); } } // Fallback to provider's registry-based calculation provider.calculate_cost(model, prompt_tokens, completion_tokens, &state.model_registry) } async fn chat_completions( State(state): State, auth: AuthenticatedClient, Json(mut request): Json, ) -> Result { // Validate token against configured auth tokens if !state.auth_tokens.is_empty() && !state.auth_tokens.contains(&auth.token) { return Err(AppError::AuthError("Invalid authentication token".to_string())); } let start_time = std::time::Instant::now(); let client_id = auth.client_id.clone(); let model = request.model.clone(); info!("Chat completion request from client {} for model {}", client_id, model); // Check if model is enabled in database and get potential mapping let model_config = sqlx::query("SELECT enabled, mapping FROM model_configs WHERE id = ?") .bind(&model) .fetch_optional(&state.db_pool) .await .unwrap_or(None); let (model_enabled, model_mapping) = match model_config { Some(row) => (row.get::("enabled"), row.get::, _>("mapping")), None => (true, None), }; if !model_enabled { return Err(AppError::ValidationError(format!("Model {} is currently disabled", model))); } // Apply mapping if present if let Some(target_model) = model_mapping { info!("Mapping model {} to {}", model, target_model); request.model = target_model; } // Find appropriate provider for the model let provider = state.provider_manager.get_provider_for_model(&request.model).await .ok_or_else(|| AppError::ProviderError(format!("No provider found for model: {}", request.model)))?; let provider_name = provider.name().to_string(); // Check circuit breaker for this provider rate_limiting::middleware::circuit_breaker_middleware(&provider_name, &state).await?; // Convert to unified request format let mut unified_request = crate::models::UnifiedRequest::try_from(request) .map_err(|e| AppError::ValidationError(e.to_string()))?; // Set client_id from authentication unified_request.client_id = client_id.clone(); // Hydrate images if present if unified_request.has_images { unified_request.hydrate_images().await .map_err(|e| AppError::ValidationError(format!("Failed to process images: {}", e)))?; } // Check if streaming is requested if unified_request.stream { // Estimate prompt tokens for logging later let prompt_tokens = crate::utils::tokens::estimate_request_tokens(&model, &unified_request); let has_images = unified_request.has_images; // Handle streaming response let stream_result = provider.chat_completion_stream(unified_request).await; match stream_result { Ok(stream) => { // Record provider success state.rate_limit_manager.record_provider_success(&provider_name).await; // Wrap with AggregatingStream for token counting and database logging let aggregating_stream = crate::utils::streaming::AggregatingStream::new( stream, client_id.clone(), provider.clone(), model.clone(), prompt_tokens, has_images, state.request_logger.clone(), state.client_manager.clone(), state.model_registry.clone(), state.db_pool.clone(), ); // Create SSE stream from aggregating stream let sse_stream = aggregating_stream.map(move |chunk_result| { match chunk_result { Ok(chunk) => { // Convert provider chunk to OpenAI-compatible SSE event let response = ChatCompletionStreamResponse { id: format!("chatcmpl-{}", Uuid::new_v4()), object: "chat.completion.chunk".to_string(), created: chrono::Utc::now().timestamp() as u64, model: chunk.model.clone(), choices: vec![ChatStreamChoice { index: 0, delta: ChatStreamDelta { role: None, content: Some(chunk.content), reasoning_content: chunk.reasoning_content, }, finish_reason: chunk.finish_reason, }], }; Ok(Event::default().json_data(response).unwrap()) } Err(e) => { warn!("Error in streaming response: {}", e); Err(e) } } }); Ok(Sse::new(sse_stream).into_response()) } Err(e) => { // Record provider failure state.rate_limit_manager.record_provider_failure(&provider_name).await; // Log failed request let duration = start_time.elapsed(); warn!("Streaming request failed after {:?}: {}", duration, e); Err(e) } } } else { // Handle non-streaming response let result = provider.chat_completion(unified_request).await; match result { Ok(response) => { // Record provider success state.rate_limit_manager.record_provider_success(&provider_name).await; let duration = start_time.elapsed(); let cost = get_model_cost(&response.model, response.prompt_tokens, response.completion_tokens, &provider, &state).await; // Log request to database state.request_logger.log_request(crate::logging::RequestLog { timestamp: chrono::Utc::now(), client_id: client_id.clone(), provider: provider_name.clone(), model: response.model.clone(), prompt_tokens: response.prompt_tokens, completion_tokens: response.completion_tokens, total_tokens: response.total_tokens, cost, has_images: false, // TODO: check images status: "success".to_string(), error_message: None, duration_ms: duration.as_millis() as u64, }); // Update client usage let _ = state.client_manager.update_client_usage( &client_id, response.total_tokens as i64, cost, ).await; // Convert ProviderResponse to ChatCompletionResponse let chat_response = ChatCompletionResponse { id: format!("chatcmpl-{}", Uuid::new_v4()), object: "chat.completion".to_string(), created: chrono::Utc::now().timestamp() as u64, model: response.model, choices: vec![ChatChoice { index: 0, message: ChatMessage { role: "assistant".to_string(), content: crate::models::MessageContent::Text { content: response.content }, reasoning_content: response.reasoning_content, }, finish_reason: Some("stop".to_string()), }], usage: Some(Usage { prompt_tokens: response.prompt_tokens, completion_tokens: response.completion_tokens, total_tokens: response.total_tokens, }), }; // Log successful request info!("Request completed successfully in {:?}", duration); Ok(Json(chat_response).into_response()) } Err(e) => { // Record provider failure state.rate_limit_manager.record_provider_failure(&provider_name).await; // Log failed request to database let duration = start_time.elapsed(); state.request_logger.log_request(crate::logging::RequestLog { timestamp: chrono::Utc::now(), client_id: client_id.clone(), provider: provider_name.clone(), model: model.clone(), prompt_tokens: 0, completion_tokens: 0, total_tokens: 0, cost: 0.0, has_images: false, status: "error".to_string(), error_message: Some(e.to_string()), duration_ms: duration.as_millis() as u64, }); warn!("Request failed after {:?}: {}", duration, e); Err(e) } } } }