use axum::{ Json, Router, extract::State, response::IntoResponse, response::sse::{Event, Sse}, routing::{get, post}, }; use futures::stream::StreamExt; use std::time::Duration; use sqlx; use std::sync::Arc; use tracing::{info, warn}; use uuid::Uuid; use crate::{ auth::AuthenticatedClient, errors::AppError, models::{ ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage, ChatStreamChoice, ChatStreamDelta, Usage, }, rate_limiting, state::AppState, }; pub fn router(state: AppState) -> Router { Router::new() .route("/v1/chat/completions", post(chat_completions)) .route("/v1/models", get(list_models)) .layer(axum::middleware::from_fn_with_state( state.clone(), rate_limiting::middleware::rate_limit_middleware, )) .with_state(state) } /// GET /v1/models — OpenAI-compatible model listing. /// Returns all models from enabled providers so clients like Open WebUI can /// discover which models are available through the proxy. async fn list_models( State(state): State, _auth: AuthenticatedClient, ) -> Result, AppError> { let registry = &state.model_registry; let providers = state.provider_manager.get_all_providers().await; let mut models = Vec::new(); for provider in &providers { let provider_name = provider.name(); // Map internal provider names to registry provider IDs let registry_key = match provider_name { "gemini" => "google", "grok" => "xai", _ => provider_name, }; // Find this provider's models in the registry if let Some(provider_info) = registry.providers.get(registry_key) { for (model_id, meta) in &provider_info.models { // Skip disabled models via the config cache if let Some(cfg) = state.model_config_cache.get(model_id).await { if !cfg.enabled { continue; } } models.push(serde_json::json!({ "id": model_id, "object": "model", "created": 0, "owned_by": provider_name, "name": meta.name, })); } } // For Ollama, models are configured in the TOML, not the registry if provider_name == "ollama" { for model_id in &state.config.providers.ollama.models { models.push(serde_json::json!({ "id": model_id, "object": "model", "created": 0, "owned_by": "ollama", })); } } } Ok(Json(serde_json::json!({ "object": "list", "data": models }))) } async fn get_model_cost( model: &str, prompt_tokens: u32, completion_tokens: u32, cache_read_tokens: u32, cache_write_tokens: u32, provider: &Arc, state: &AppState, ) -> f64 { // Check in-memory cache for cost overrides (no SQLite hit) if let Some(cached) = state.model_config_cache.get(model).await { if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) { // Manual overrides don't have cache-specific rates, so use simple formula return (prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0); } } // Fallback to provider's registry-based calculation (cache-aware) provider.calculate_cost(model, prompt_tokens, completion_tokens, cache_read_tokens, cache_write_tokens, &state.model_registry) } async fn chat_completions( State(state): State, auth: AuthenticatedClient, Json(mut request): Json, ) -> Result { // Resolve client_id: try DB token first, then env tokens, then permissive fallback let db_client_id: Option = sqlx::query_scalar::<_, String>( "SELECT client_id FROM client_tokens WHERE token = ? AND is_active = TRUE", ) .bind(&auth.token) .fetch_optional(&state.db_pool) .await .unwrap_or(None); let client_id = if let Some(cid) = db_client_id { // Update last_used_at in background (fire-and-forget) let pool = state.db_pool.clone(); let token = auth.token.clone(); tokio::spawn(async move { let _ = sqlx::query("UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ?") .bind(&token) .execute(&pool) .await; }); cid } else if state.auth_tokens.is_empty() || state.auth_tokens.contains(&auth.token) { // Env token match or permissive mode (no env tokens configured) auth.client_id.clone() } else { return Err(AppError::AuthError("Invalid authentication token".to_string())); }; let start_time = std::time::Instant::now(); let model = request.model.clone(); info!("Chat completion request from client {} for model {}", client_id, model); // Check if model is enabled via in-memory cache (no SQLite hit) let cached_config = state.model_config_cache.get(&model).await; let (model_enabled, model_mapping) = match cached_config { Some(cfg) => (cfg.enabled, cfg.mapping), None => (true, None), }; if !model_enabled { return Err(AppError::ValidationError(format!( "Model {} is currently disabled", model ))); } // Apply mapping if present if let Some(target_model) = model_mapping { info!("Mapping model {} to {}", model, target_model); request.model = target_model; } // Find appropriate provider for the model let provider = state .provider_manager .get_provider_for_model(&request.model) .await .ok_or_else(|| AppError::ProviderError(format!("No provider found for model: {}", request.model)))?; let provider_name = provider.name().to_string(); // Check circuit breaker for this provider rate_limiting::middleware::circuit_breaker_middleware(&provider_name, &state).await?; // Convert to unified request format let mut unified_request = crate::models::UnifiedRequest::try_from(request).map_err(|e| AppError::ValidationError(e.to_string()))?; // Set client_id from authentication unified_request.client_id = client_id.clone(); // Hydrate images if present if unified_request.has_images { unified_request .hydrate_images() .await .map_err(|e| AppError::ValidationError(format!("Failed to process images: {}", e)))?; } let has_images = unified_request.has_images; // Measure proxy overhead (time spent before sending to upstream provider) let proxy_overhead = start_time.elapsed(); // Check if streaming is requested if unified_request.stream { // Estimate prompt tokens for logging later let prompt_tokens = crate::utils::tokens::estimate_request_tokens(&model, &unified_request); // Handle streaming response let stream_result = provider.chat_completion_stream(unified_request).await; match stream_result { Ok(stream) => { // Record provider success state.rate_limit_manager.record_provider_success(&provider_name).await; info!( "Streaming started for {} (proxy overhead: {}ms)", model, proxy_overhead.as_millis() ); // Wrap with AggregatingStream for token counting and database logging let aggregating_stream = crate::utils::streaming::AggregatingStream::new( stream, crate::utils::streaming::StreamConfig { client_id: client_id.clone(), provider: provider.clone(), model: model.clone(), prompt_tokens, has_images, logger: state.request_logger.clone(), client_manager: state.client_manager.clone(), model_registry: state.model_registry.clone(), model_config_cache: state.model_config_cache.clone(), }, ); // Create SSE stream from aggregating stream let stream_id = format!("chatcmpl-{}", Uuid::new_v4()); let stream_created = chrono::Utc::now().timestamp() as u64; let sse_stream = aggregating_stream.map(move |chunk_result| { match chunk_result { Ok(chunk) => { // Convert provider chunk to OpenAI-compatible SSE event let response = ChatCompletionStreamResponse { id: stream_id.clone(), object: "chat.completion.chunk".to_string(), created: stream_created, model: chunk.model.clone(), choices: vec![ChatStreamChoice { index: 0, delta: ChatStreamDelta { role: None, content: Some(chunk.content), reasoning_content: chunk.reasoning_content, tool_calls: chunk.tool_calls, }, finish_reason: chunk.finish_reason, }], }; match Event::default().json_data(response) { Ok(event) => Ok(event), Err(e) => { warn!("Failed to serialize SSE event: {}", e); Err(AppError::InternalError("SSE serialization failed".to_string())) } } } Err(e) => { warn!("Error in streaming response: {}", e); Err(e) } } }); // Many OpenAI-compatible clients expect a terminal [DONE] marker. // Emit it when the upstream stream ends to avoid clients treating // the response as incomplete. <<<<<<< HEAD // Convert to a Vec first, then append [DONE], then stream it ======= >>>>>>> refs/remotes/origin/main let done_event = Ok::(Event::default().data("[DONE]")); let done_stream = futures::stream::iter(vec![done_event]); let out = sse_stream.chain(done_stream); Ok(Sse::new(out).into_response()) } Err(e) => { // Record provider failure state.rate_limit_manager.record_provider_failure(&provider_name).await; // Log failed request let duration = start_time.elapsed(); warn!("Streaming request failed after {:?}: {}", duration, e); Err(e) } } } else { // Handle non-streaming response let result = provider.chat_completion(unified_request).await; match result { Ok(response) => { // Record provider success state.rate_limit_manager.record_provider_success(&provider_name).await; let duration = start_time.elapsed(); let cost = get_model_cost( &response.model, response.prompt_tokens, response.completion_tokens, response.cache_read_tokens, response.cache_write_tokens, &provider, &state, ) .await; // Log request to database state.request_logger.log_request(crate::logging::RequestLog { timestamp: chrono::Utc::now(), client_id: client_id.clone(), provider: provider_name.clone(), model: response.model.clone(), prompt_tokens: response.prompt_tokens, completion_tokens: response.completion_tokens, total_tokens: response.total_tokens, cache_read_tokens: response.cache_read_tokens, cache_write_tokens: response.cache_write_tokens, cost, has_images, status: "success".to_string(), error_message: None, duration_ms: duration.as_millis() as u64, }); // Update client usage (fire-and-forget, don't block response) { let cm = state.client_manager.clone(); let cid = client_id.clone(); tokio::spawn(async move { let _ = cm.update_client_usage(&cid, response.total_tokens as i64, cost).await; }); } // Convert ProviderResponse to ChatCompletionResponse let finish_reason = if response.tool_calls.is_some() { "tool_calls".to_string() } else { "stop".to_string() }; let chat_response = ChatCompletionResponse { id: format!("chatcmpl-{}", Uuid::new_v4()), object: "chat.completion".to_string(), created: chrono::Utc::now().timestamp() as u64, model: response.model, choices: vec![ChatChoice { index: 0, message: ChatMessage { role: "assistant".to_string(), content: crate::models::MessageContent::Text { content: response.content, }, reasoning_content: response.reasoning_content, tool_calls: response.tool_calls, name: None, tool_call_id: None, }, finish_reason: Some(finish_reason), }], usage: Some(Usage { prompt_tokens: response.prompt_tokens, completion_tokens: response.completion_tokens, total_tokens: response.total_tokens, cache_read_tokens: if response.cache_read_tokens > 0 { Some(response.cache_read_tokens) } else { None }, cache_write_tokens: if response.cache_write_tokens > 0 { Some(response.cache_write_tokens) } else { None }, }), }; // Log successful request with proxy overhead breakdown let upstream_ms = duration.as_millis() as u64 - proxy_overhead.as_millis() as u64; info!( "Request completed in {:?} (proxy: {}ms, upstream: {}ms)", duration, proxy_overhead.as_millis(), upstream_ms ); Ok(Json(chat_response).into_response()) } Err(e) => { // Record provider failure state.rate_limit_manager.record_provider_failure(&provider_name).await; // Log failed request to database let duration = start_time.elapsed(); state.request_logger.log_request(crate::logging::RequestLog { timestamp: chrono::Utc::now(), client_id: client_id.clone(), provider: provider_name.clone(), model: model.clone(), prompt_tokens: 0, completion_tokens: 0, total_tokens: 0, cache_read_tokens: 0, cache_write_tokens: 0, cost: 0.0, has_images: false, status: "error".to_string(), error_message: Some(e.to_string()), duration_ms: duration.as_millis() as u64, }); warn!("Request failed after {:?}: {}", duration, e); Err(e) } } } }