335 lines
11 KiB
Rust
335 lines
11 KiB
Rust
use anyhow::Result;
|
|
use async_trait::async_trait;
|
|
use futures::stream::BoxStream;
|
|
use sqlx::Row;
|
|
use std::sync::Arc;
|
|
|
|
use crate::errors::AppError;
|
|
use crate::models::UnifiedRequest;
|
|
|
|
pub mod deepseek;
|
|
pub mod gemini;
|
|
pub mod grok;
|
|
pub mod helpers;
|
|
pub mod ollama;
|
|
pub mod openai;
|
|
|
|
#[async_trait]
/// Common interface implemented by every upstream LLM provider
/// (OpenAI, Gemini, DeepSeek, Grok, Ollama).
///
/// `Send + Sync` is required so trait objects can be stored in
/// `ProviderManager`'s shared `Arc<RwLock<Vec<Arc<dyn Provider>>>>`.
pub trait Provider: Send + Sync {
    /// Get provider name (e.g., "openai", "gemini").
    /// Used by `ProviderManager` as the unique key for add/remove/lookup.
    fn name(&self) -> &str;

    /// Check if provider supports a specific model.
    /// `ProviderManager::get_provider_for_model` routes to the first
    /// provider for which this returns `true`.
    fn supports_model(&self, model: &str) -> bool;

    /// Check if provider supports multimodal (images, etc.)
    fn supports_multimodal(&self) -> bool;

    /// Process a chat completion request against the provider's API.
    async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError>;

    /// Process a chat request using provider-specific "responses" style endpoint.
    /// Default implementation falls back to `chat_completion` for providers
    /// that do not implement a dedicated responses endpoint.
    async fn chat_responses(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
        self.chat_completion(request).await
    }

    /// Process a streaming chat completion request.
    /// Returns a `'static` boxed stream of incremental chunks; each item may
    /// itself be an error so mid-stream failures can be surfaced to the caller.
    async fn chat_completion_stream(
        &self,
        request: UnifiedRequest,
    ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError>;

    /// Estimate token count for a request (for cost calculation).
    /// NOTE(review): presumably a heuristic when the provider reports no
    /// usage — confirm against each implementation.
    fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32>;

    /// Calculate cost based on token usage and model using the registry.
    /// `cache_read_tokens` / `cache_write_tokens` allow cache-aware pricing
    /// when the registry provides `cache_read` / `cache_write` rates.
    fn calculate_cost(
        &self,
        model: &str,
        prompt_tokens: u32,
        completion_tokens: u32,
        cache_read_tokens: u32,
        cache_write_tokens: u32,
        registry: &crate::models::registry::ModelRegistry,
    ) -> f64;
}
|
|
|
|
/// Fully-assembled (non-streaming) result of a chat completion,
/// normalized across providers.
pub struct ProviderResponse {
    /// Assistant message text.
    pub content: String,
    /// Chain-of-thought / reasoning text, for providers that emit it
    /// separately (e.g. DeepSeek reasoner models).
    pub reasoning_content: Option<String>,
    /// Tool/function calls requested by the model, if any.
    pub tool_calls: Option<Vec<crate::models::ToolCall>>,
    /// Tokens consumed by the prompt, as reported by the provider.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
    /// Total token count (prompt + completion per provider accounting).
    pub total_tokens: u32,
    /// Tokens served from the provider's prompt cache (cache-aware pricing).
    pub cache_read_tokens: u32,
    /// Tokens written into the provider's prompt cache.
    pub cache_write_tokens: u32,
    /// Model identifier that actually served the request.
    pub model: String,
}
|
|
|
|
/// Usage data from the final streaming chunk (when providers report real token counts).
|
|
#[derive(Debug, Clone, Default)]
|
|
pub struct StreamUsage {
|
|
pub prompt_tokens: u32,
|
|
pub completion_tokens: u32,
|
|
pub total_tokens: u32,
|
|
pub cache_read_tokens: u32,
|
|
pub cache_write_tokens: u32,
|
|
}
|
|
|
|
/// One incremental delta from a streaming chat completion, normalized
/// across providers.
#[derive(Debug, Clone)]
pub struct ProviderStreamChunk {
    /// Content text delta for this chunk (may be empty).
    pub content: String,
    /// Reasoning-text delta, for providers that stream it separately.
    pub reasoning_content: Option<String>,
    /// Why generation stopped (e.g. "stop", "length"); set only on the
    /// terminal chunk.
    pub finish_reason: Option<String>,
    /// Incremental tool-call fragments to be accumulated by the caller.
    pub tool_calls: Option<Vec<crate::models::ToolCallDelta>>,
    /// Model identifier serving the stream.
    pub model: String,
    /// Populated only on the final chunk when providers report usage (e.g. stream_options.include_usage).
    pub usage: Option<StreamUsage>,
}
|
|
|
|
use tokio::sync::RwLock;
|
|
|
|
use crate::config::AppConfig;
|
|
use crate::providers::{
|
|
deepseek::DeepSeekProvider, gemini::GeminiProvider, grok::GrokProvider, ollama::OllamaProvider,
|
|
openai::OpenAIProvider,
|
|
};
|
|
|
|
/// Registry of active providers, shared across the application.
///
/// Cloning is cheap: only the `Arc` handle is duplicated, so all clones
/// observe the same provider list.
#[derive(Clone)]
pub struct ProviderManager {
    // tokio RwLock: lookups take read locks concurrently; add/remove
    // take an exclusive write lock.
    providers: Arc<RwLock<Vec<Arc<dyn Provider>>>>,
}
|
|
|
|
impl Default for ProviderManager {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl ProviderManager {
    /// Create an empty manager; providers are registered later via
    /// `initialize_provider` / `add_provider`.
    pub fn new() -> Self {
        Self {
            providers: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Initialize a provider by name using config and database overrides.
    ///
    /// Resolution order:
    /// 1. A row in `provider_configs` (keyed by `name`) overrides everything.
    /// 2. Otherwise, fall back to the static defaults in `AppConfig`
    ///    (no API-key default — key overrides come only from the database).
    ///
    /// If the resolved config is disabled, the provider is removed from the
    /// registry (idempotent) and `Ok(())` is returned. An unknown `name`
    /// with no database row resolves to disabled and is likewise a no-op;
    /// an unknown name that is *enabled* is an error.
    ///
    /// # Errors
    /// Propagates database errors, provider-construction errors, and
    /// "Unknown provider" for unrecognized enabled names.
    pub async fn initialize_provider(
        &self,
        name: &str,
        app_config: &AppConfig,
        db_pool: &crate::database::DbPool,
    ) -> Result<()> {
        // Load override from database
        let db_config = sqlx::query("SELECT enabled, base_url, api_key FROM provider_configs WHERE id = ?")
            .bind(name)
            .fetch_optional(db_pool)
            .await?;

        let (enabled, base_url, api_key) = if let Some(row) = db_config {
            (
                row.get::<bool, _>("enabled"),
                row.get::<Option<String>, _>("base_url"),
                row.get::<Option<String>, _>("api_key"),
            )
        } else {
            // No database override, use defaults from AppConfig
            match name {
                "openai" => (
                    app_config.providers.openai.enabled,
                    Some(app_config.providers.openai.base_url.clone()),
                    None,
                ),
                "gemini" => (
                    app_config.providers.gemini.enabled,
                    Some(app_config.providers.gemini.base_url.clone()),
                    None,
                ),
                "deepseek" => (
                    app_config.providers.deepseek.enabled,
                    Some(app_config.providers.deepseek.base_url.clone()),
                    None,
                ),
                "grok" => (
                    app_config.providers.grok.enabled,
                    Some(app_config.providers.grok.base_url.clone()),
                    None,
                ),
                "ollama" => (
                    app_config.providers.ollama.enabled,
                    Some(app_config.providers.ollama.base_url.clone()),
                    None,
                ),
                // Unknown provider without a DB row: treat as disabled.
                _ => (false, None, None),
            }
        };

        if !enabled {
            // Disabled: make sure any previously registered instance is gone.
            self.remove_provider(name).await;
            return Ok(());
        }

        // Create provider instance with merged config
        let provider: Arc<dyn Provider> = match name {
            "openai" => {
                let mut cfg = app_config.providers.openai.clone();
                if let Some(url) = base_url {
                    cfg.base_url = url;
                }
                // A DB-supplied API key uses the `new_with_key` constructor;
                // otherwise the provider falls back to its own key source.
                let p = if let Some(key) = api_key {
                    OpenAIProvider::new_with_key(&cfg, app_config, key)?
                } else {
                    OpenAIProvider::new(&cfg, app_config)?
                };
                Arc::new(p)
            }
            "ollama" => {
                let mut cfg = app_config.providers.ollama.clone();
                if let Some(url) = base_url {
                    cfg.base_url = url;
                }
                // Ollama is local and keyless, so no api_key handling here.
                Arc::new(OllamaProvider::new(&cfg, app_config)?)
            }
            "gemini" => {
                let mut cfg = app_config.providers.gemini.clone();
                if let Some(url) = base_url {
                    cfg.base_url = url;
                }
                let p = if let Some(key) = api_key {
                    GeminiProvider::new_with_key(&cfg, app_config, key)?
                } else {
                    GeminiProvider::new(&cfg, app_config)?
                };
                Arc::new(p)
            }
            "deepseek" => {
                let mut cfg = app_config.providers.deepseek.clone();
                if let Some(url) = base_url {
                    cfg.base_url = url;
                }
                let p = if let Some(key) = api_key {
                    DeepSeekProvider::new_with_key(&cfg, app_config, key)?
                } else {
                    DeepSeekProvider::new(&cfg, app_config)?
                };
                Arc::new(p)
            }
            "grok" => {
                let mut cfg = app_config.providers.grok.clone();
                if let Some(url) = base_url {
                    cfg.base_url = url;
                }
                let p = if let Some(key) = api_key {
                    GrokProvider::new_with_key(&cfg, app_config, key)?
                } else {
                    GrokProvider::new(&cfg, app_config)?
                };
                Arc::new(p)
            }
            _ => return Err(anyhow::anyhow!("Unknown provider: {}", name)),
        };

        self.add_provider(provider).await;
        Ok(())
    }

    /// Register a provider, replacing any existing one with the same name.
    pub async fn add_provider(&self, provider: Arc<dyn Provider>) {
        let mut providers = self.providers.write().await;
        // If provider with same name exists, replace it
        if let Some(index) = providers.iter().position(|p| p.name() == provider.name()) {
            providers[index] = provider;
        } else {
            providers.push(provider);
        }
    }

    /// Remove the provider with the given name (no-op if absent).
    pub async fn remove_provider(&self, name: &str) {
        let mut providers = self.providers.write().await;
        providers.retain(|p| p.name() != name);
    }

    /// Find the first registered provider that claims support for `model`.
    /// Registration order therefore decides ties.
    pub async fn get_provider_for_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
        let providers = self.providers.read().await;
        providers.iter().find(|p| p.supports_model(model)).map(Arc::clone)
    }

    /// Look up a provider by its exact name.
    pub async fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
        let providers = self.providers.read().await;
        providers.iter().find(|p| p.name() == name).map(Arc::clone)
    }

    /// Snapshot of all registered providers (cheap: clones `Arc` handles).
    pub async fn get_all_providers(&self) -> Vec<Arc<dyn Provider>> {
        let providers = self.providers.read().await;
        providers.clone()
    }
}
|
|
|
|
// Create placeholder provider implementations
|
|
pub mod placeholder {
|
|
use super::*;
|
|
|
|
pub struct PlaceholderProvider {
|
|
name: String,
|
|
}
|
|
|
|
impl PlaceholderProvider {
|
|
pub fn new(name: &str) -> Self {
|
|
Self { name: name.to_string() }
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl Provider for PlaceholderProvider {
|
|
fn name(&self) -> &str {
|
|
&self.name
|
|
}
|
|
|
|
fn supports_model(&self, _model: &str) -> bool {
|
|
false
|
|
}
|
|
|
|
fn supports_multimodal(&self) -> bool {
|
|
false
|
|
}
|
|
|
|
async fn chat_completion_stream(
|
|
&self,
|
|
_request: UnifiedRequest,
|
|
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
|
Err(AppError::ProviderError(
|
|
"Streaming not supported for placeholder provider".to_string(),
|
|
))
|
|
}
|
|
|
|
async fn chat_completion(&self, _request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
|
Err(AppError::ProviderError(format!(
|
|
"Provider {} not implemented",
|
|
self.name
|
|
)))
|
|
}
|
|
|
|
fn estimate_tokens(&self, _request: &UnifiedRequest) -> Result<u32> {
|
|
Ok(0)
|
|
}
|
|
|
|
fn calculate_cost(
|
|
&self,
|
|
_model: &str,
|
|
_prompt_tokens: u32,
|
|
_completion_tokens: u32,
|
|
_cache_read_tokens: u32,
|
|
_cache_write_tokens: u32,
|
|
_registry: &crate::models::registry::ModelRegistry,
|
|
) -> f64 {
|
|
0.0
|
|
}
|
|
}
|
|
}
|