use axum::{
    Json, Router,
    extract::State,
    response::IntoResponse,
    response::sse::{Event, Sse},
    routing::{get, post},
};
use axum::http::{header, HeaderValue};
use tower_http::{
    limit::RequestBodyLimitLayer,
    set_header::SetResponseHeaderLayer,
};
use futures::StreamExt;
use std::sync::Arc;
use uuid::Uuid;
use tracing::{info, warn};

use crate::{
    auth::AuthenticatedClient,
    errors::AppError,
    models::{
        ChatChoice, ChatCompletionRequest, ChatCompletionResponse,
        ChatCompletionStreamResponse, ChatMessage, ChatStreamChoice, ChatStreamDelta, Usage,
    },
    rate_limiting,
    state::AppState,
};

pub fn router(state: AppState) -> Router {
    // Security headers
    let csp_header: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
        header::CONTENT_SECURITY_POLICY,
        "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws:;"
            .parse()
            .unwrap(),
    );
    let x_frame_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
        header::X_FRAME_OPTIONS,
        "DENY".parse().unwrap(),
    );
    let x_content_type_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
        header::X_CONTENT_TYPE_OPTIONS,
        "nosniff".parse().unwrap(),
    );
    let strict_transport_security: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
        header::STRICT_TRANSPORT_SECURITY,
        "max-age=31536000; includeSubDomains".parse().unwrap(),
    );

    Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit
        .layer(csp_header)
        .layer(x_frame_options)
        .layer(x_content_type_options)
        .layer(strict_transport_security)
        .layer(axum::middleware::from_fn_with_state(
            state.clone(),
            rate_limiting::middleware::rate_limit_middleware,
        ))
        .with_state(state)
}

/// GET /v1/models: OpenAI-compatible model listing.
/// Returns all models from enabled providers so clients like Open WebUI can
/// discover which models are available through the proxy.
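/// An abridged example of the response shape (illustrative values):
///
/// ```json
/// {
///   "object": "list",
///   "data": [
///     { "id": "gpt-4o", "object": "model", "created": 0, "owned_by": "openai", "name": "GPT-4o" }
///   ]
/// }
/// ```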
async fn list_models(
    State(state): State<AppState>,
    _auth: AuthenticatedClient,
) -> Result<Json<serde_json::Value>, AppError> {
    let registry = &state.model_registry;
    let providers = state.provider_manager.get_all_providers().await;

    let mut models = Vec::new();
    for provider in &providers {
        let provider_name = provider.name();
        // Map internal provider names to registry provider IDs
        let registry_key = match provider_name {
            "gemini" => "google",
            "grok" => "xai",
            _ => provider_name,
        };

        // Find this provider's models in the registry
        if let Some(provider_info) = registry.providers.get(registry_key) {
            for (model_id, meta) in &provider_info.models {
                // Skip disabled models via the config cache
                if let Some(cfg) = state.model_config_cache.get(model_id).await {
                    if !cfg.enabled {
                        continue;
                    }
                }
                models.push(serde_json::json!({
                    "id": model_id,
                    "object": "model",
                    "created": 0,
                    "owned_by": provider_name,
                    "name": meta.name,
                }));
            }
        }

        // For Ollama, models are configured in the TOML, not the registry
        if provider_name == "ollama" {
            for model_id in &state.config.providers.ollama.models {
                models.push(serde_json::json!({
                    "id": model_id,
                    "object": "model",
                    "created": 0,
                    "owned_by": "ollama",
                }));
            }
        }
    }

    Ok(Json(serde_json::json!({ "object": "list", "data": models })))
}

async fn get_model_cost(
    model: &str,
    prompt_tokens: u32,
    completion_tokens: u32,
    cache_read_tokens: u32,
    cache_write_tokens: u32,
    provider: &Arc<dyn crate::providers::Provider>,
    state: &AppState,
) -> f64 {
    // Check the in-memory cache for cost overrides (no SQLite hit)
    if let Some(cached) = state.model_config_cache.get(model).await {
        if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
            // Manual override logic: if cache rates are provided, use the cache-aware formula:
            // (non_cached_prompt * input_rate) + (cache_read * read_rate)
            //   + (cache_write * write_rate) + (completion * output_rate)
            let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
            let mut total = (non_cached_prompt as f64 * p / 1_000_000.0)
                + (completion_tokens as f64 * c / 1_000_000.0);
            if let Some(cr) = cached.cache_read_cost_per_m {
                total += cache_read_tokens as f64 * cr / 1_000_000.0;
            } else {
                // No manual cache_read rate: charge cached tokens at the full
                // input rate (backwards compatibility)
                total += cache_read_tokens as f64 * p / 1_000_000.0;
            }
            if let Some(cw) = cached.cache_write_cost_per_m {
                total += cache_write_tokens as f64 * cw / 1_000_000.0;
            }
            return total;
        }
    }

    // Fall back to the provider's registry-based calculation (cache-aware)
    provider.calculate_cost(
        model,
        prompt_tokens,
        completion_tokens,
        cache_read_tokens,
        cache_write_tokens,
        &state.model_registry,
    )
}
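// Worked example for the manual-override path above (hypothetical rates, not
// taken from any real registry entry): with prompt_cost_per_m = 3.0,
// completion_cost_per_m = 15.0, and cache_read_cost_per_m = 0.3, a request
// with 10_000 prompt tokens (4_000 of them cache reads) and 500 completion
// tokens costs:
//   (6_000 * 3.0 + 4_000 * 0.3 + 500 * 15.0) / 1_000_000 = 0.0267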
async fn chat_completions(
    State(state): State<AppState>,
    auth: AuthenticatedClient,
    Json(mut request): Json<ChatCompletionRequest>,
) -> Result<axum::response::Response, AppError> {
    let client_id = auth.client_id.clone();
    let token = auth.token.clone();

    // Verify the token if env tokens are configured
    if !state.auth_tokens.is_empty() && !state.auth_tokens.contains(&token) {
        // If it is not an env token it may be a DB token; DB tokens do not
        // produce client_ids with the client_XXXX prefix
        if client_id.starts_with("client_") {
            return Err(AppError::AuthError("Invalid authentication token".to_string()));
        }
    }

    let start_time = std::time::Instant::now();
    let model = request.model.clone();

    info!("Chat completion request from client {} for model {}", client_id, model);

    // Check whether the model is enabled via the in-memory cache (no SQLite hit)
    let cached_config = state.model_config_cache.get(&model).await;
    let (model_enabled, model_mapping) = match cached_config {
        Some(cfg) => (cfg.enabled, cfg.mapping),
        None => (true, None),
    };
    if !model_enabled {
        return Err(AppError::ValidationError(format!(
            "Model {} is currently disabled",
            model
        )));
    }

    // Apply the mapping if present
    if let Some(target_model) = model_mapping {
        info!("Mapping model {} to {}", model, target_model);
        request.model = target_model;
    }

    // Find the appropriate provider for the model
    let provider = state
        .provider_manager
        .get_provider_for_model(&request.model)
        .await
        .ok_or_else(|| {
            AppError::ProviderError(format!("No provider found for model: {}", request.model))
        })?;
    let provider_name = provider.name().to_string();

    // Check the circuit breaker for this provider
    rate_limiting::middleware::circuit_breaker_middleware(&provider_name, &state).await?;

    // Convert to the unified request format
    let mut unified_request = crate::models::UnifiedRequest::try_from(request)
        .map_err(|e| AppError::ValidationError(e.to_string()))?;

    // Set client_id from authentication
    unified_request.client_id = client_id.clone();

    // Hydrate images if present
    if unified_request.has_images {
        unified_request
            .hydrate_images()
            .await
            .map_err(|e| AppError::ValidationError(format!("Failed to process images: {}", e)))?;
    }
    let has_images = unified_request.has_images;

    // Measure proxy overhead (time spent before sending to the upstream provider)
    let proxy_overhead = start_time.elapsed();

    // Check if streaming is requested
    if unified_request.stream {
        // Estimate prompt tokens for logging later
        let prompt_tokens = crate::utils::tokens::estimate_request_tokens(&model, &unified_request);

        // Handle the streaming response.
        // Allow provider-specific routing for streaming too.
        let use_responses = provider.name() == "openai"
            && crate::utils::registry::model_prefers_responses(&state.model_registry, &unified_request.model);
        let stream_result = if use_responses {
            provider.chat_responses_stream(unified_request).await
        } else {
            provider.chat_completion_stream(unified_request).await
        };

        match stream_result {
            Ok(stream) => {
                // Record provider success
                state.rate_limit_manager.record_provider_success(&provider_name).await;
                info!(
                    "Streaming started for {} (proxy overhead: {}ms)",
                    model,
                    proxy_overhead.as_millis()
                );

                // Wrap with AggregatingStream for token counting and database logging
                let aggregating_stream = crate::utils::streaming::AggregatingStream::new(
                    stream,
                    crate::utils::streaming::StreamConfig {
                        client_id: client_id.clone(),
                        provider: provider.clone(),
                        model: model.clone(),
                        prompt_tokens,
                        has_images,
                        logger: state.request_logger.clone(),
                        model_registry: state.model_registry.clone(),
                        model_config_cache: state.model_config_cache.clone(),
                    },
                );

                // Create the SSE stream
                let stream_id = format!("chatcmpl-{}", Uuid::new_v4());
                let stream_created = chrono::Utc::now().timestamp() as u64;
                let stream_id_sse = stream_id.clone();
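                // The stream below re-serializes each chunk as an OpenAI-style
                // `chat.completion.chunk` SSE event. An illustrative (abridged)
                // wire sequence:
                //   data: {"object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant","content":"Hel"},"finish_reason":null}]}
                //   data: {"object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"lo"},"finish_reason":"stop"}]}
                //   data: [DONE]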
                // Build a stream that yields events wrapped in Result
                let stream = async_stream::stream! {
                    let mut aggregator = Box::pin(aggregating_stream);
                    let mut first_chunk = true;

                    while let Some(chunk_result) = aggregator.next().await {
                        match chunk_result {
                            Ok(chunk) => {
                                let role = if first_chunk {
                                    first_chunk = false;
                                    Some("assistant".to_string())
                                } else {
                                    None
                                };

                                let response = ChatCompletionStreamResponse {
                                    id: stream_id_sse.clone(),
                                    object: "chat.completion.chunk".to_string(),
                                    created: stream_created,
                                    model: chunk.model.clone(),
                                    choices: vec![ChatStreamChoice {
                                        index: 0,
                                        delta: ChatStreamDelta {
                                            role,
                                            content: Some(chunk.content),
                                            reasoning_content: chunk.reasoning_content,
                                            tool_calls: chunk.tool_calls,
                                        },
                                        finish_reason: chunk.finish_reason,
                                    }],
                                    usage: chunk.usage.as_ref().map(|u| crate::models::Usage {
                                        prompt_tokens: u.prompt_tokens,
                                        completion_tokens: u.completion_tokens,
                                        total_tokens: u.total_tokens,
                                        reasoning_tokens: if u.reasoning_tokens > 0 {
                                            Some(u.reasoning_tokens)
                                        } else {
                                            None
                                        },
                                        cache_read_tokens: if u.cache_read_tokens > 0 {
                                            Some(u.cache_read_tokens)
                                        } else {
                                            None
                                        },
                                        cache_write_tokens: if u.cache_write_tokens > 0 {
                                            Some(u.cache_write_tokens)
                                        } else {
                                            None
                                        },
                                    }),
                                };

                                // Use axum's Event directly, wrapped in Ok
                                match Event::default().json_data(response) {
                                    Ok(event) => yield Ok::<_, crate::errors::AppError>(event),
                                    Err(e) => {
                                        warn!("Failed to serialize SSE: {}", e);
                                    }
                                }
                            }
                            Err(e) => {
                                warn!("Stream error: {}", e);
                            }
                        }
                    }

                    // Yield [DONE] at the end
                    yield Ok::<_, crate::errors::AppError>(Event::default().data("[DONE]"));
                };

                Ok(Sse::new(stream).into_response())
            }
            Err(e) => {
                // Record provider failure
                state.rate_limit_manager.record_provider_failure(&provider_name).await;

                // Log the failed request
                let duration = start_time.elapsed();
                warn!("Streaming request failed after {:?}: {}", duration, e);
                Err(e)
            }
        }
    } else {
        // Handle the non-streaming response.
        // Allow provider-specific routing: for OpenAI, some models prefer the
        // Responses API (/v1/responses). Use the model registry heuristic to
        // choose chat_responses vs chat_completion automatically.
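        // For reference, an abridged example of the JSON body assembled below
        // (illustrative values):
        //   {
        //     "id": "chatcmpl-<uuid>",
        //     "object": "chat.completion",
        //     "choices": [{"index": 0,
        //                  "message": {"role": "assistant", "content": "..."},
        //                  "finish_reason": "stop"}],
        //     "usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}
        //   }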
        let use_responses = provider.name() == "openai"
            && crate::utils::registry::model_prefers_responses(&state.model_registry, &unified_request.model);
        let result = if use_responses {
            provider.chat_responses(unified_request).await
        } else {
            provider.chat_completion(unified_request).await
        };

        match result {
            Ok(response) => {
                // Record provider success
                state.rate_limit_manager.record_provider_success(&provider_name).await;

                let duration = start_time.elapsed();
                let cost = get_model_cost(
                    &response.model,
                    response.prompt_tokens,
                    response.completion_tokens,
                    response.cache_read_tokens,
                    response.cache_write_tokens,
                    &provider,
                    &state,
                )
                .await;

                // Log the request to the database
                state.request_logger.log_request(crate::logging::RequestLog {
                    timestamp: chrono::Utc::now(),
                    client_id: client_id.clone(),
                    provider: provider_name.clone(),
                    model: response.model.clone(),
                    prompt_tokens: response.prompt_tokens,
                    completion_tokens: response.completion_tokens,
                    reasoning_tokens: response.reasoning_tokens,
                    total_tokens: response.total_tokens,
                    cache_read_tokens: response.cache_read_tokens,
                    cache_write_tokens: response.cache_write_tokens,
                    cost,
                    has_images,
                    status: "success".to_string(),
                    error_message: None,
                    duration_ms: duration.as_millis() as u64,
                });

                // Convert the ProviderResponse to a ChatCompletionResponse
                let finish_reason = if response.tool_calls.is_some() {
                    "tool_calls".to_string()
                } else {
                    "stop".to_string()
                };
                let chat_response = ChatCompletionResponse {
                    id: format!("chatcmpl-{}", Uuid::new_v4()),
                    object: "chat.completion".to_string(),
                    created: chrono::Utc::now().timestamp() as u64,
                    model: response.model,
                    choices: vec![ChatChoice {
                        index: 0,
                        message: ChatMessage {
                            role: "assistant".to_string(),
                            content: crate::models::MessageContent::Text {
                                content: response.content,
                            },
                            reasoning_content: response.reasoning_content,
                            tool_calls: response.tool_calls,
                            name: None,
                            tool_call_id: None,
                        },
                        finish_reason: Some(finish_reason),
                    }],
                    usage: Some(Usage {
                        prompt_tokens: response.prompt_tokens,
                        completion_tokens: response.completion_tokens,
                        total_tokens: response.total_tokens,
                        reasoning_tokens: if response.reasoning_tokens > 0 {
                            Some(response.reasoning_tokens)
                        } else {
                            None
                        },
                        cache_read_tokens: if response.cache_read_tokens > 0 {
                            Some(response.cache_read_tokens)
                        } else {
                            None
                        },
                        cache_write_tokens: if response.cache_write_tokens > 0 {
                            Some(response.cache_write_tokens)
                        } else {
                            None
                        },
                    }),
                };

                // Log the successful request with a proxy-overhead breakdown
                let upstream_ms = duration.as_millis() as u64 - proxy_overhead.as_millis() as u64;
                info!(
                    "Request completed in {:?} (proxy: {}ms, upstream: {}ms)",
                    duration,
                    proxy_overhead.as_millis(),
                    upstream_ms
                );

                Ok(Json(chat_response).into_response())
            }
            Err(e) => {
                // Record provider failure
                state.rate_limit_manager.record_provider_failure(&provider_name).await;

                // Log the failed request to the database
                let duration = start_time.elapsed();
                state.request_logger.log_request(crate::logging::RequestLog {
                    timestamp: chrono::Utc::now(),
                    client_id: client_id.clone(),
                    provider: provider_name.clone(),
                    model: model.clone(),
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    reasoning_tokens: 0,
                    total_tokens: 0,
                    cache_read_tokens: 0,
                    cache_write_tokens: 0,
                    cost: 0.0,
                    has_images: false,
                    status: "error".to_string(),
                    error_message: Some(e.to_string()),
                    duration_ms: duration.as_millis() as u64,
                });
                warn!("Request failed after {:?}: {}", duration, e);
                Err(e)
            }
        }
    }
}
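// A quick smoke test against a running instance (host, port, and token are
// illustrative; the Bearer scheme is assumed to follow the usual
// OpenAI-compatible convention):
//
//   curl -s http://localhost:8080/v1/chat/completions \
//     -H "Authorization: Bearer <token>" \
//     -H "Content-Type: application/json" \
//     -d '{"model": "gpt-4o", "messages": [{"role": "user", "content": "Hi"}]}'
//
// Adding "stream": true to the body switches the endpoint to the SSE chunk
// format shown in chat_completions, terminated by `data: [DONE]`.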