refactor: comprehensive audit — fix bugs, harden security, deduplicate providers, add CI/Docker
Phase 1: Fix compilation (config_path Option<PathBuf>, streaming test, stale test cleanup). Phase 2: Fix critical bugs (remove block_on deadlocks in 4 providers, fix broken SQL query builder). Phase 3: Security hardening (session manager, real auth, token masking, Gemini key to header, password policy). Phase 4: Implement stubs (real provider test, /proc health metrics, client/provider/backup endpoints, has_images). Phase 5: Code quality (shared provider helpers, explicit re-exports, all Clippy warnings fixed, unwrap removal, 6 unused deps removed, dashboard split into 7 sub-modules). Phase 6: Infrastructure (GitHub Actions CI, multi-stage Dockerfile, rustfmt.toml, clippy.toml, script fixes).
This commit is contained in:
@@ -1,17 +1,18 @@
|
||||
use async_trait::async_trait;
|
||||
use anyhow::Result;
|
||||
use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
use sqlx::Row;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::models::UnifiedRequest;
|
||||
use crate::errors::AppError;
|
||||
use crate::models::UnifiedRequest;
|
||||
|
||||
pub mod openai;
|
||||
pub mod gemini;
|
||||
pub mod deepseek;
|
||||
pub mod gemini;
|
||||
pub mod grok;
|
||||
pub mod helpers;
|
||||
pub mod ollama;
|
||||
pub mod openai;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Provider: Send + Sync {
|
||||
@@ -25,10 +26,7 @@ pub trait Provider: Send + Sync {
|
||||
fn supports_multimodal(&self) -> bool;
|
||||
|
||||
/// Process a chat completion request
|
||||
async fn chat_completion(
|
||||
&self,
|
||||
request: UnifiedRequest,
|
||||
) -> Result<ProviderResponse, AppError>;
|
||||
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError>;
|
||||
|
||||
/// Process a streaming chat completion request
|
||||
async fn chat_completion_stream(
|
||||
@@ -40,7 +38,13 @@ pub trait Provider: Send + Sync {
|
||||
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32>;
|
||||
|
||||
/// Calculate cost based on token usage and model using the registry
|
||||
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64;
|
||||
fn calculate_cost(
|
||||
&self,
|
||||
model: &str,
|
||||
prompt_tokens: u32,
|
||||
completion_tokens: u32,
|
||||
registry: &crate::models::registry::ModelRegistry,
|
||||
) -> f64;
|
||||
}
|
||||
|
||||
pub struct ProviderResponse {
|
||||
@@ -64,11 +68,8 @@ use tokio::sync::RwLock;
|
||||
|
||||
use crate::config::AppConfig;
|
||||
use crate::providers::{
|
||||
deepseek::DeepSeekProvider, gemini::GeminiProvider, grok::GrokProvider, ollama::OllamaProvider,
|
||||
openai::OpenAIProvider,
|
||||
gemini::GeminiProvider,
|
||||
deepseek::DeepSeekProvider,
|
||||
grok::GrokProvider,
|
||||
ollama::OllamaProvider,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -76,6 +77,12 @@ pub struct ProviderManager {
|
||||
providers: Arc<RwLock<Vec<Arc<dyn Provider>>>>,
|
||||
}
|
||||
|
||||
impl Default for ProviderManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
@@ -84,7 +91,12 @@ impl ProviderManager {
|
||||
}
|
||||
|
||||
/// Initialize a provider by name using config and database overrides
|
||||
pub async fn initialize_provider(&self, name: &str, app_config: &AppConfig, db_pool: &crate::database::DbPool) -> Result<()> {
|
||||
pub async fn initialize_provider(
|
||||
&self,
|
||||
name: &str,
|
||||
app_config: &AppConfig,
|
||||
db_pool: &crate::database::DbPool,
|
||||
) -> Result<()> {
|
||||
// Load override from database
|
||||
let db_config = sqlx::query("SELECT enabled, base_url, api_key FROM provider_configs WHERE id = ?")
|
||||
.bind(name)
|
||||
@@ -100,11 +112,31 @@ impl ProviderManager {
|
||||
} else {
|
||||
// No database override, use defaults from AppConfig
|
||||
match name {
|
||||
"openai" => (app_config.providers.openai.enabled, Some(app_config.providers.openai.base_url.clone()), None),
|
||||
"gemini" => (app_config.providers.gemini.enabled, Some(app_config.providers.gemini.base_url.clone()), None),
|
||||
"deepseek" => (app_config.providers.deepseek.enabled, Some(app_config.providers.deepseek.base_url.clone()), None),
|
||||
"grok" => (app_config.providers.grok.enabled, Some(app_config.providers.grok.base_url.clone()), None),
|
||||
"ollama" => (app_config.providers.ollama.enabled, Some(app_config.providers.ollama.base_url.clone()), None),
|
||||
"openai" => (
|
||||
app_config.providers.openai.enabled,
|
||||
Some(app_config.providers.openai.base_url.clone()),
|
||||
None,
|
||||
),
|
||||
"gemini" => (
|
||||
app_config.providers.gemini.enabled,
|
||||
Some(app_config.providers.gemini.base_url.clone()),
|
||||
None,
|
||||
),
|
||||
"deepseek" => (
|
||||
app_config.providers.deepseek.enabled,
|
||||
Some(app_config.providers.deepseek.base_url.clone()),
|
||||
None,
|
||||
),
|
||||
"grok" => (
|
||||
app_config.providers.grok.enabled,
|
||||
Some(app_config.providers.grok.base_url.clone()),
|
||||
None,
|
||||
),
|
||||
"ollama" => (
|
||||
app_config.providers.ollama.enabled,
|
||||
Some(app_config.providers.ollama.base_url.clone()),
|
||||
None,
|
||||
),
|
||||
_ => (false, None, None),
|
||||
}
|
||||
};
|
||||
@@ -118,7 +150,9 @@ impl ProviderManager {
|
||||
let provider: Arc<dyn Provider> = match name {
|
||||
"openai" => {
|
||||
let mut cfg = app_config.providers.openai.clone();
|
||||
if let Some(url) = base_url { cfg.base_url = url; }
|
||||
if let Some(url) = base_url {
|
||||
cfg.base_url = url;
|
||||
}
|
||||
// Handle API key override if present
|
||||
let p = if let Some(key) = api_key {
|
||||
// We need a way to create a provider with an explicit key
|
||||
@@ -128,42 +162,50 @@ impl ProviderManager {
|
||||
OpenAIProvider::new(&cfg, app_config)?
|
||||
};
|
||||
Arc::new(p)
|
||||
},
|
||||
}
|
||||
"ollama" => {
|
||||
let mut cfg = app_config.providers.ollama.clone();
|
||||
if let Some(url) = base_url { cfg.base_url = url; }
|
||||
if let Some(url) = base_url {
|
||||
cfg.base_url = url;
|
||||
}
|
||||
Arc::new(OllamaProvider::new(&cfg, app_config)?)
|
||||
},
|
||||
}
|
||||
"gemini" => {
|
||||
let mut cfg = app_config.providers.gemini.clone();
|
||||
if let Some(url) = base_url { cfg.base_url = url; }
|
||||
if let Some(url) = base_url {
|
||||
cfg.base_url = url;
|
||||
}
|
||||
let p = if let Some(key) = api_key {
|
||||
GeminiProvider::new_with_key(&cfg, app_config, key)?
|
||||
} else {
|
||||
GeminiProvider::new(&cfg, app_config)?
|
||||
};
|
||||
Arc::new(p)
|
||||
},
|
||||
}
|
||||
"deepseek" => {
|
||||
let mut cfg = app_config.providers.deepseek.clone();
|
||||
if let Some(url) = base_url { cfg.base_url = url; }
|
||||
if let Some(url) = base_url {
|
||||
cfg.base_url = url;
|
||||
}
|
||||
let p = if let Some(key) = api_key {
|
||||
DeepSeekProvider::new_with_key(&cfg, app_config, key)?
|
||||
} else {
|
||||
DeepSeekProvider::new(&cfg, app_config)?
|
||||
};
|
||||
Arc::new(p)
|
||||
},
|
||||
}
|
||||
"grok" => {
|
||||
let mut cfg = app_config.providers.grok.clone();
|
||||
if let Some(url) = base_url { cfg.base_url = url; }
|
||||
if let Some(url) = base_url {
|
||||
cfg.base_url = url;
|
||||
}
|
||||
let p = if let Some(key) = api_key {
|
||||
GrokProvider::new_with_key(&cfg, app_config, key)?
|
||||
} else {
|
||||
GrokProvider::new(&cfg, app_config)?
|
||||
};
|
||||
Arc::new(p)
|
||||
},
|
||||
}
|
||||
_ => return Err(anyhow::anyhow!("Unknown provider: {}", name)),
|
||||
};
|
||||
|
||||
@@ -188,16 +230,12 @@ impl ProviderManager {
|
||||
|
||||
pub async fn get_provider_for_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
|
||||
let providers = self.providers.read().await;
|
||||
providers.iter()
|
||||
.find(|p| p.supports_model(model))
|
||||
.map(|p| Arc::clone(p))
|
||||
providers.iter().find(|p| p.supports_model(model)).map(Arc::clone)
|
||||
}
|
||||
|
||||
pub async fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
|
||||
let providers = self.providers.read().await;
|
||||
providers.iter()
|
||||
.find(|p| p.name() == name)
|
||||
.map(|p| Arc::clone(p))
|
||||
providers.iter().find(|p| p.name() == name).map(Arc::clone)
|
||||
}
|
||||
|
||||
pub async fn get_all_providers(&self) -> Vec<Arc<dyn Provider>> {
|
||||
@@ -238,22 +276,30 @@ pub mod placeholder {
|
||||
&self,
|
||||
_request: UnifiedRequest,
|
||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||
Err(AppError::ProviderError("Streaming not supported for placeholder provider".to_string()))
|
||||
Err(AppError::ProviderError(
|
||||
"Streaming not supported for placeholder provider".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn chat_completion(
|
||||
&self,
|
||||
_request: UnifiedRequest,
|
||||
) -> Result<ProviderResponse, AppError> {
|
||||
Err(AppError::ProviderError(format!("Provider {} not implemented", self.name)))
|
||||
async fn chat_completion(&self, _request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
||||
Err(AppError::ProviderError(format!(
|
||||
"Provider {} not implemented",
|
||||
self.name
|
||||
)))
|
||||
}
|
||||
|
||||
fn estimate_tokens(&self, _request: &UnifiedRequest) -> Result<u32> {
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
fn calculate_cost(&self, _model: &str, _prompt_tokens: u32, _completion_tokens: u32, _registry: &crate::models::registry::ModelRegistry) -> f64 {
|
||||
fn calculate_cost(
|
||||
&self,
|
||||
_model: &str,
|
||||
_prompt_tokens: u32,
|
||||
_completion_tokens: u32,
|
||||
_registry: &crate::models::registry::ModelRegistry,
|
||||
) -> f64 {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user