//! Chat-provider abstraction layer.
//!
//! Declares the per-backend submodules (`deepseek`, `gemini`, `grok`,
//! `ollama`, `openai`, plus shared `helpers`), the [`Provider`] trait they
//! implement, the response / stream-chunk value types, and
//! [`ProviderManager`], a shared registry of the currently enabled providers.

use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use sqlx::Row;
use std::sync::Arc;

use crate::errors::AppError;
use crate::models::UnifiedRequest;

pub mod deepseek;
pub mod gemini;
pub mod grok;
pub mod helpers;
pub mod ollama;
pub mod openai;

/// Common interface implemented by every chat backend.
///
/// Implementors must be `Send + Sync`; [`ProviderManager`] stores them behind
/// `Arc` and hands out shared references.
#[async_trait]
pub trait Provider: Send + Sync {
    /// Get provider name (e.g., "openai", "gemini").
    ///
    /// [`ProviderManager`] treats this string as the provider's identity:
    /// registering a provider replaces any existing one with the same name,
    /// and lookup / removal by name compare against it.
    fn name(&self) -> &str;

    /// Check if provider supports a specific model.
    ///
    /// [`ProviderManager::get_provider_for_model`] returns the first
    /// registered provider for which this returns `true`, so overlapping
    /// claims are resolved by registration order.
    fn supports_model(&self, model: &str) -> bool;

    /// Check if provider supports multimodal input (images, etc.).
    fn supports_multimodal(&self) -> bool;

    /// Process a (non-streaming) chat completion request.
    async fn chat_completion(&self, request: UnifiedRequest) -> Result;

    /// Process a chat request using a provider-specific "responses" style endpoint.
    ///
    /// The default implementation falls back to `chat_completion` for
    /// providers that do not implement a dedicated responses endpoint.
    async fn chat_responses(&self, request: UnifiedRequest) -> Result {
        self.chat_completion(request).await
    }

    /// Process a streaming chat completion request.
    async fn chat_completion_stream(
        &self,
        request: UnifiedRequest,
    ) -> Result>, AppError>;

    /// Process a streaming chat request using a provider-specific "responses" style endpoint.
    ///
    /// The default implementation falls back to `chat_completion_stream` for
    /// providers that do not implement a dedicated responses endpoint.
    async fn chat_responses_stream(
        &self,
        request: UnifiedRequest,
    ) -> Result>, AppError> {
        self.chat_completion_stream(request).await
    }

    /// Estimate the token count for a request (used for cost calculation).
    fn estimate_tokens(&self, request: &UnifiedRequest) -> Result;

    /// Calculate cost based on token usage and model using the registry.
    ///
    /// `cache_read_tokens` / `cache_write_tokens` allow cache-aware pricing
    /// when the registry provides `cache_read` / `cache_write` rates.
    fn calculate_cost(
        &self,
        model: &str,
        prompt_tokens: u32,
        completion_tokens: u32,
        cache_read_tokens: u32,
        cache_write_tokens: u32,
        registry: &crate::models::registry::ModelRegistry,
    ) -> f64;
}

/// Complete (non-streaming) result of a chat request, including token usage.
///
/// NOTE(review): presumably the success value of
/// [`Provider::chat_completion`] / [`Provider::chat_responses`] — confirm
/// against the trait's return type.
pub struct ProviderResponse {
    pub content: String,
    pub reasoning_content: Option,
    pub tool_calls: Option>,
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub reasoning_tokens: u32,
    pub total_tokens: u32,
    // Same names as the cache arguments of `Provider::calculate_cost`;
    // presumably forwarded there by callers — confirm.
    pub cache_read_tokens: u32,
    pub cache_write_tokens: u32,
    pub model: String,
}

/// Usage data from the final streaming chunk (when providers report real token counts).
///
/// `Default` yields all-zero counts.
#[derive(Debug, Clone, Default)]
pub struct StreamUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub reasoning_tokens: u32,
    pub total_tokens: u32,
    pub cache_read_tokens: u32,
    pub cache_write_tokens: u32,
}

/// One incremental piece of a streaming chat response.
///
/// NOTE(review): presumably the item type yielded by the stream returned from
/// [`Provider::chat_completion_stream`] — confirm against the trait signature.
#[derive(Debug, Clone)]
pub struct ProviderStreamChunk {
    pub content: String,
    pub reasoning_content: Option,
    pub finish_reason: Option,
    pub tool_calls: Option>,
    pub model: String,
    /// Populated only on the final chunk when providers report usage (e.g. stream_options.include_usage).
    pub usage: Option,
}

use tokio::sync::RwLock;

use crate::config::AppConfig;
use crate::providers::{
    deepseek::DeepSeekProvider, gemini::GeminiProvider, grok::GrokProvider,
    ollama::OllamaProvider, openai::OpenAIProvider,
};

/// Registry of the currently enabled providers.
///
/// Cloning is cheap, and every clone shares the same underlying list (an
/// `Arc` around a `tokio` `RwLock`). Providers added or removed through one
/// handle are therefore visible through all of them.
#[derive(Clone)]
pub struct ProviderManager {
    // Registered providers in registration order. Order matters:
    // `get_provider_for_model` returns the first match.
    providers: Arc>>>,
}

impl Default for ProviderManager {
    /// Equivalent to [`ProviderManager::new`]: an empty registry.
    fn default() -> Self {
        Self::new()
    }
}

impl ProviderManager {
    /// Create an empty registry with no providers registered.
    pub fn new() -> Self {
        Self {
            providers: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Initialize a provider by name using config and database overrides.
    ///
    /// Settings are resolved as follows:
    /// 1. If `provider_configs` has a row whose `id` equals `name`, that
    ///    row's `enabled`, `base_url` and `api_key` columns are used.
    ///    - A NULL `base_url` keeps the base URL from `app_config`.
    ///    - When `api_key_encrypted` is true, the key is decrypted first. If
    ///      decryption fails, the error is logged and the key is treated as
    ///      absent.
    /// 2. Otherwise, `enabled` and `base_url` come from `app_config`, and no
    ///    explicit API key is supplied.
    ///
    /// If the resolved provider is disabled, any provider already registered
    /// under `name` is removed and `Ok(())` is returned. An unrecognised
    /// `name` with no database row takes this path.
    ///
    /// Otherwise, a new provider instance is built from a copy of its
    /// `app_config` section with the overrides applied. It is registered via
    /// [`Self::add_provider`], replacing any previous instance with the same
    /// name.
    ///
    /// # Errors
    /// - The database query fails.
    /// - The provider's constructor (`new` / `new_with_key`) fails.
    /// - `name` is not a known provider but its database row enables it.
    pub async fn initialize_provider(
        &self,
        name: &str,
        app_config: &AppConfig,
        db_pool: &crate::database::DbPool,
    ) -> Result<()> {
        // Load override from database (`name` is bound as a parameter, not
        // interpolated into the SQL text).
        let db_config = sqlx::query("SELECT enabled, base_url, api_key, api_key_encrypted FROM provider_configs WHERE id = ?")
            .bind(name)
            .fetch_optional(db_pool)
            .await?;

        let (enabled, base_url, api_key) = if let Some(row) = db_config {
            let enabled = row.get::("enabled");
            let base_url = row.get::, _>("base_url");
            let api_key_encrypted = row.get::("api_key_encrypted");
            let api_key = row.get::, _>("api_key");

            // Decrypt API key if encrypted
            let api_key = match (api_key, api_key_encrypted) {
                (Some(key), true) => {
                    match crate::utils::crypto::decrypt(&key) {
                        Ok(decrypted) => Some(decrypted),
                        Err(e) => {
                            // Best effort: log and continue without an explicit
                            // key instead of failing initialization.
                            tracing::error!("Failed to decrypt API key for provider {}: {}", name, e);
                            None
                        }
                    }
                }
                (Some(key), false) => {
                    // Plaintext key stored in the database: used as-is.
                    // NOTE(review): encrypting it and writing it back
                    // (lazy migration) is not implemented.
                    Some(key)
                }
                (None, _) => None,
            };

            (enabled, base_url, api_key)
        } else {
            // No database override, use defaults from AppConfig
            match name {
                "openai" => (
                    app_config.providers.openai.enabled,
                    Some(app_config.providers.openai.base_url.clone()),
                    None,
                ),
                "gemini" => (
                    app_config.providers.gemini.enabled,
                    Some(app_config.providers.gemini.base_url.clone()),
                    None,
                ),
                "deepseek" => (
                    app_config.providers.deepseek.enabled,
                    Some(app_config.providers.deepseek.base_url.clone()),
                    None,
                ),
                "grok" => (
                    app_config.providers.grok.enabled,
                    Some(app_config.providers.grok.base_url.clone()),
                    None,
                ),
                "ollama" => (
                    app_config.providers.ollama.enabled,
                    Some(app_config.providers.ollama.base_url.clone()),
                    None,
                ),
                // Unknown name without a DB row: treated as disabled below.
                _ => (false, None, None),
            }
        };

        if !enabled {
            // Disabling also unregisters any previously active instance.
            self.remove_provider(name).await;
            return Ok(());
        }

        // Create provider instance with merged config
        let provider: Arc = match name {
            "openai" => {
                let mut cfg = app_config.providers.openai.clone();
                if let Some(url) = base_url {
                    cfg.base_url = url;
                }
                // A database-supplied key is passed explicitly. Otherwise
                // `new` is used; presumably it takes the key from
                // `cfg`/`app_config` (confirm).
                let p = if let Some(key) = api_key {
                    OpenAIProvider::new_with_key(&cfg, app_config, key)?
                } else {
                    OpenAIProvider::new(&cfg, app_config)?
                };
                Arc::new(p)
            }
            "ollama" => {
                let mut cfg = app_config.providers.ollama.clone();
                if let Some(url) = base_url {
                    cfg.base_url = url;
                }
                // Only the base URL can be overridden for Ollama; any
                // database API key is ignored here.
                Arc::new(OllamaProvider::new(&cfg, app_config)?)
            }
            "gemini" => {
                let mut cfg = app_config.providers.gemini.clone();
                if let Some(url) = base_url {
                    cfg.base_url = url;
                }
                let p = if let Some(key) = api_key {
                    GeminiProvider::new_with_key(&cfg, app_config, key)?
                } else {
                    GeminiProvider::new(&cfg, app_config)?
                };
                Arc::new(p)
            }
            "deepseek" => {
                let mut cfg = app_config.providers.deepseek.clone();
                if let Some(url) = base_url {
                    cfg.base_url = url;
                }
                let p = if let Some(key) = api_key {
                    DeepSeekProvider::new_with_key(&cfg, app_config, key)?
                } else {
                    DeepSeekProvider::new(&cfg, app_config)?
                };
                Arc::new(p)
            }
            "grok" => {
                let mut cfg = app_config.providers.grok.clone();
                if let Some(url) = base_url {
                    cfg.base_url = url;
                }
                let p = if let Some(key) = api_key {
                    GrokProvider::new_with_key(&cfg, app_config, key)?
                } else {
                    GrokProvider::new(&cfg, app_config)?
                };
                Arc::new(p)
            }
            // Reached only when a DB row enables a name with no constructor.
            _ => return Err(anyhow::anyhow!("Unknown provider: {}", name)),
        };

        self.add_provider(provider).await;
        Ok(())
    }

    /// Register `provider`.
    ///
    /// If a provider with the same [`Provider::name`] is already registered,
    /// it is replaced in place, keeping its position in lookup order.
    /// Otherwise the provider is appended to the end.
    pub async fn add_provider(&self, provider: Arc) {
        let mut providers = self.providers.write().await;
        // If provider with same name exists, replace it
        if let Some(index) = providers.iter().position(|p| p.name() == provider.name()) {
            providers[index] = provider;
        } else {
            providers.push(provider);
        }
    }

    /// Unregister every provider whose name equals `name`; a no-op if none is registered.
    pub async fn remove_provider(&self, name: &str) {
        let mut providers = self.providers.write().await;
        providers.retain(|p| p.name() != name);
    }

    /// Return the first provider, in registration order, whose
    /// `supports_model(model)` is `true`, or `None` if no provider claims it.
    pub async fn get_provider_for_model(&self, model: &str) -> Option> {
        let providers = self.providers.read().await;
        providers.iter().find(|p| p.supports_model(model)).map(Arc::clone)
    }

    /// Look up a registered provider by exact name.
    pub async fn get_provider(&self, name: &str) -> Option> {
        let providers = self.providers.read().await;
        providers.iter().find(|p| p.name() == name).map(Arc::clone)
    }

    /// Return a snapshot of all registered providers, in registration order.
    ///
    /// The returned `Vec` is an independent copy of the handles; later
    /// registrations or removals do not affect it.
    pub async fn get_all_providers(&self) -> Vec> {
        let providers = self.providers.read().await;
        providers.clone()
    }
}

/// Placeholder [`Provider`] implementation for backends that are not implemented.
pub mod placeholder {
    use super::*;

    /// A named provider that claims no models and rejects every chat request.
    ///
    /// Its token estimates are always `0` and its costs are always `0.0`.
    pub struct PlaceholderProvider {
        name: String,
    }

    impl PlaceholderProvider {
        /// Create a placeholder that reports `name` as its provider name.
        pub fn new(name: &str) -> Self {
            Self { name: name.to_string() }
        }
    }

    #[async_trait]
    impl Provider for PlaceholderProvider {
        fn name(&self) -> &str {
            &self.name
        }

        // Never selected by `ProviderManager::get_provider_for_model`.
        fn supports_model(&self, _model: &str) -> bool {
            false
        }

        fn supports_multimodal(&self) -> bool {
            false
        }

        async fn chat_completion_stream(
            &self,
            _request: UnifiedRequest,
        ) -> Result>, AppError> {
            Err(AppError::ProviderError(
                "Streaming not supported for placeholder provider".to_string(),
            ))
        }

        async fn chat_completion(&self, _request: UnifiedRequest) -> Result {
            Err(AppError::ProviderError(format!(
                "Provider {} not implemented",
                self.name
            )))
        }

        fn estimate_tokens(&self, _request: &UnifiedRequest) -> Result {
            Ok(0)
        }

        fn calculate_cost(
            &self,
            _model: &str,
            _prompt_tokens: u32,
            _completion_tokens: u32,
            _cache_read_tokens: u32,
            _cache_write_tokens: u32,
            _registry: &crate::models::registry::ModelRegistry,
        ) -> f64 {
            0.0
        }
    }
}