Files
GopherGate/src/providers/mod.rs
hobokenchicken 5a8510bf1e
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled
feat(providers): model-registry routing + Responses API support and streaming fallbacks for OpenAI/Gemini
2026-03-04 13:36:03 -05:00

335 lines
11 KiB
Rust

use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use sqlx::Row;
use std::sync::Arc;
use crate::errors::AppError;
use crate::models::UnifiedRequest;
pub mod deepseek;
pub mod gemini;
pub mod grok;
pub mod helpers;
pub mod ollama;
pub mod openai;
/// Abstraction over one upstream LLM backend (e.g. OpenAI, Gemini, DeepSeek,
/// Grok, Ollama).
///
/// Implementations must be `Send + Sync` so a single instance can be shared
/// behind `Arc<dyn Provider>` by `ProviderManager` across async tasks.
#[async_trait]
pub trait Provider: Send + Sync {
/// Get provider name (e.g., "openai", "gemini")
fn name(&self) -> &str;
/// Check if provider supports a specific model
fn supports_model(&self, model: &str) -> bool;
/// Check if provider supports multimodal (images, etc.)
fn supports_multimodal(&self) -> bool;
/// Process a chat completion request
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError>;
/// Process a chat request using provider-specific "responses" style endpoint
/// Default implementation falls back to `chat_completion` for providers
/// that do not implement a dedicated responses endpoint.
async fn chat_responses(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
self.chat_completion(request).await
}
/// Process a streaming chat completion request
///
/// Returns a `'static` boxed stream of incremental chunks; each item is a
/// `ProviderStreamChunk` or a provider error surfaced mid-stream.
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError>;
/// Estimate token count for a request (for cost calculation)
/// NOTE(review): presumably an approximation used before the provider
/// reports real usage — confirm against call sites.
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32>;
/// Calculate cost based on token usage and model using the registry.
/// `cache_read_tokens` / `cache_write_tokens` allow cache-aware pricing
/// when the registry provides `cache_read` / `cache_write` rates.
fn calculate_cost(
&self,
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
cache_read_tokens: u32,
cache_write_tokens: u32,
registry: &crate::models::registry::ModelRegistry,
) -> f64;
}
/// Fully-assembled (non-streaming) response from a provider, normalized
/// across backends.
pub struct ProviderResponse {
// Final assistant message text.
pub content: String,
// Chain-of-thought / reasoning text, for providers that return it separately.
pub reasoning_content: Option<String>,
// Tool/function invocations requested by the model, if any.
pub tool_calls: Option<Vec<crate::models::ToolCall>>,
// Token accounting as reported (or estimated) for this request.
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
// Cache-aware token counts; 0 when the provider reports none.
pub cache_read_tokens: u32,
pub cache_write_tokens: u32,
// Model identifier the provider actually served (may differ from the alias requested).
pub model: String,
}
/// Usage data from the final streaming chunk (when providers report real token counts).
///
/// `Default` (all zeros) represents "no usage reported".
#[derive(Debug, Clone, Default)]
pub struct StreamUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
// Cache-aware counts mirroring `ProviderResponse`; 0 when unreported.
pub cache_read_tokens: u32,
pub cache_write_tokens: u32,
}
/// One incremental delta emitted by `Provider::chat_completion_stream`.
#[derive(Debug, Clone)]
pub struct ProviderStreamChunk {
// Text delta for this chunk (may be empty on control chunks).
pub content: String,
// Reasoning-text delta, for providers that stream it separately.
pub reasoning_content: Option<String>,
// Set on the terminating chunk (e.g. "stop", "tool_calls"); `None` mid-stream.
pub finish_reason: Option<String>,
// Partial tool-call fragments to be accumulated by the caller.
pub tool_calls: Option<Vec<crate::models::ToolCallDelta>>,
pub model: String,
/// Populated only on the final chunk when providers report usage (e.g. stream_options.include_usage).
pub usage: Option<StreamUsage>,
}
use tokio::sync::RwLock;
use crate::config::AppConfig;
use crate::providers::{
deepseek::DeepSeekProvider, gemini::GeminiProvider, grok::GrokProvider, ollama::OllamaProvider,
openai::OpenAIProvider,
};
/// Thread-safe registry of active providers.
///
/// `Clone` is cheap: clones share the same `Arc<RwLock<...>>`, so all clones
/// observe the same provider set. Lookup order in the inner `Vec` matters for
/// `get_provider_for_model` (first match wins).
#[derive(Clone)]
pub struct ProviderManager {
providers: Arc<RwLock<Vec<Arc<dyn Provider>>>>,
}
impl Default for ProviderManager {
fn default() -> Self {
Self::new()
}
}
impl ProviderManager {
/// Create an empty manager with no registered providers.
pub fn new() -> Self {
Self {
providers: Arc::new(RwLock::new(Vec::new())),
}
}
/// Initialize a provider by name using config and database overrides
///
/// Resolution order: a row in `provider_configs` (keyed by provider id)
/// overrides the static defaults in [`AppConfig`]. A provider resolved as
/// disabled is removed from the registry and the call succeeds.
///
/// # Errors
/// Fails if the database query fails, if `name` is unknown while a database
/// row enabled it, or if the provider constructor returns an error.
pub async fn initialize_provider(
&self,
name: &str,
app_config: &AppConfig,
db_pool: &crate::database::DbPool,
) -> Result<()> {
// Load override from database
let db_config = sqlx::query("SELECT enabled, base_url, api_key FROM provider_configs WHERE id = ?")
.bind(name)
.fetch_optional(db_pool)
.await?;
let (enabled, base_url, api_key) = if let Some(row) = db_config {
// NOTE(review): `Row::get` panics on a type mismatch or missing column —
// assumes the `provider_configs` schema matches these three columns.
(
row.get::<bool, _>("enabled"),
row.get::<Option<String>, _>("base_url"),
row.get::<Option<String>, _>("api_key"),
)
} else {
// No database override, use defaults from AppConfig
match name {
"openai" => (
app_config.providers.openai.enabled,
Some(app_config.providers.openai.base_url.clone()),
None,
),
"gemini" => (
app_config.providers.gemini.enabled,
Some(app_config.providers.gemini.base_url.clone()),
None,
),
"deepseek" => (
app_config.providers.deepseek.enabled,
Some(app_config.providers.deepseek.base_url.clone()),
None,
),
"grok" => (
app_config.providers.grok.enabled,
Some(app_config.providers.grok.base_url.clone()),
None,
),
"ollama" => (
app_config.providers.ollama.enabled,
Some(app_config.providers.ollama.base_url.clone()),
None,
),
// NOTE(review): with no DB row an unknown name is treated as
// disabled (silent no-op), whereas an enabled DB row for an
// unknown name errors below — confirm the asymmetry is intended.
_ => (false, None, None),
}
};
if !enabled {
// Disabled: drop any previously registered instance so routing stops.
self.remove_provider(name).await;
return Ok(());
}
// Create provider instance with merged config
let provider: Arc<dyn Provider> = match name {
"openai" => {
// Start from the static config, then layer DB overrides on top.
let mut cfg = app_config.providers.openai.clone();
if let Some(url) = base_url {
cfg.base_url = url;
}
// Handle API key override if present
let p = if let Some(key) = api_key {
// We need a way to create a provider with an explicit key
// Let's modify the providers to allow this
OpenAIProvider::new_with_key(&cfg, app_config, key)?
} else {
OpenAIProvider::new(&cfg, app_config)?
};
Arc::new(p)
}
"ollama" => {
// Ollama is local; no API-key override path.
let mut cfg = app_config.providers.ollama.clone();
if let Some(url) = base_url {
cfg.base_url = url;
}
Arc::new(OllamaProvider::new(&cfg, app_config)?)
}
"gemini" => {
let mut cfg = app_config.providers.gemini.clone();
if let Some(url) = base_url {
cfg.base_url = url;
}
let p = if let Some(key) = api_key {
GeminiProvider::new_with_key(&cfg, app_config, key)?
} else {
GeminiProvider::new(&cfg, app_config)?
};
Arc::new(p)
}
"deepseek" => {
let mut cfg = app_config.providers.deepseek.clone();
if let Some(url) = base_url {
cfg.base_url = url;
}
let p = if let Some(key) = api_key {
DeepSeekProvider::new_with_key(&cfg, app_config, key)?
} else {
DeepSeekProvider::new(&cfg, app_config)?
};
Arc::new(p)
}
"grok" => {
let mut cfg = app_config.providers.grok.clone();
if let Some(url) = base_url {
cfg.base_url = url;
}
let p = if let Some(key) = api_key {
GrokProvider::new_with_key(&cfg, app_config, key)?
} else {
GrokProvider::new(&cfg, app_config)?
};
Arc::new(p)
}
_ => return Err(anyhow::anyhow!("Unknown provider: {}", name)),
};
// Register (or replace) the freshly built instance.
self.add_provider(provider).await;
Ok(())
}
/// Register `provider`, replacing in place any existing provider with the
/// same name. In-place replacement preserves registration order, which
/// `get_provider_for_model` relies on (first match wins).
pub async fn add_provider(&self, provider: Arc<dyn Provider>) {
let mut providers = self.providers.write().await;
// If provider with same name exists, replace it
if let Some(index) = providers.iter().position(|p| p.name() == provider.name()) {
providers[index] = provider;
} else {
providers.push(provider);
}
}
/// Unregister every provider named `name` (no-op if absent).
pub async fn remove_provider(&self, name: &str) {
let mut providers = self.providers.write().await;
providers.retain(|p| p.name() != name);
}
/// Find the first registered provider claiming support for `model`.
pub async fn get_provider_for_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
let providers = self.providers.read().await;
providers.iter().find(|p| p.supports_model(model)).map(Arc::clone)
}
/// Look up a provider by its exact name.
pub async fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
let providers = self.providers.read().await;
providers.iter().find(|p| p.name() == name).map(Arc::clone)
}
/// Snapshot of all registered providers (cheap: clones `Arc`s only).
pub async fn get_all_providers(&self) -> Vec<Arc<dyn Provider>> {
let providers = self.providers.read().await;
providers.clone()
}
}
// Create placeholder provider implementations
pub mod placeholder {
    use super::*;

    /// Inert provider stub: advertises no models and rejects every request.
    pub struct PlaceholderProvider {
        name: String,
    }

    impl PlaceholderProvider {
        /// Build a placeholder that only carries the given provider name.
        pub fn new(name: &str) -> Self {
            Self {
                name: name.to_owned(),
            }
        }
    }

    #[async_trait]
    impl Provider for PlaceholderProvider {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        /// Never claims any model, so routing skips this provider.
        fn supports_model(&self, _model: &str) -> bool {
            false
        }

        fn supports_multimodal(&self) -> bool {
            false
        }

        /// Always fails: a placeholder has no backend to call.
        async fn chat_completion(&self, _request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
            Err(AppError::ProviderError(format!(
                "Provider {} not implemented",
                self.name
            )))
        }

        /// Always fails: streaming is likewise unavailable.
        async fn chat_completion_stream(
            &self,
            _request: UnifiedRequest,
        ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
            Err(AppError::ProviderError(
                "Streaming not supported for placeholder provider".to_string(),
            ))
        }

        fn estimate_tokens(&self, _request: &UnifiedRequest) -> Result<u32> {
            Ok(0)
        }

        /// A placeholder never incurs cost.
        fn calculate_cost(
            &self,
            _model: &str,
            _prompt_tokens: u32,
            _completion_tokens: u32,
            _cache_read_tokens: u32,
            _cache_write_tokens: u32,
            _registry: &crate::models::registry::ModelRegistry,
        ) -> f64 {
            0.0
        }
    }
}