refactor: comprehensive audit — fix bugs, harden security, deduplicate providers, add CI/Docker
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled

Phase 1: Fix compilation (config_path Option<PathBuf>, streaming test, stale test cleanup)
Phase 2: Fix critical bugs (remove block_on deadlocks in 4 providers, fix broken SQL query builder)
Phase 3: Security hardening (session manager, real auth, token masking, Gemini key to header, password policy)
Phase 4: Implement stubs (real provider test, /proc health metrics, client/provider/backup endpoints, has_images)
Phase 5: Code quality (shared provider helpers, explicit re-exports, all Clippy warnings fixed, unwrap removal, 6 unused deps removed, dashboard split into 7 sub-modules)
Phase 6: Infrastructure (GitHub Actions CI, multi-stage Dockerfile, rustfmt.toml, clippy.toml, script fixes)
This commit is contained in:
2026-03-02 00:35:45 -05:00
parent ba643dd2b0
commit 2cdc49d7f2
42 changed files with 2800 additions and 2747 deletions

View File

@@ -1,17 +1,18 @@
use async_trait::async_trait;
use anyhow::Result;
use std::sync::Arc;
use async_trait::async_trait;
use futures::stream::BoxStream;
use sqlx::Row;
use std::sync::Arc;
use crate::models::UnifiedRequest;
use crate::errors::AppError;
use crate::models::UnifiedRequest;
pub mod openai;
pub mod gemini;
pub mod deepseek;
pub mod gemini;
pub mod grok;
pub mod helpers;
pub mod ollama;
pub mod openai;
#[async_trait]
pub trait Provider: Send + Sync {
@@ -25,10 +26,7 @@ pub trait Provider: Send + Sync {
fn supports_multimodal(&self) -> bool;
/// Process a chat completion request
async fn chat_completion(
&self,
request: UnifiedRequest,
) -> Result<ProviderResponse, AppError>;
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError>;
/// Process a streaming chat completion request
async fn chat_completion_stream(
@@ -40,7 +38,13 @@ pub trait Provider: Send + Sync {
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32>;
/// Calculate cost based on token usage and model using the registry
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64;
fn calculate_cost(
&self,
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
registry: &crate::models::registry::ModelRegistry,
) -> f64;
}
pub struct ProviderResponse {
@@ -64,11 +68,8 @@ use tokio::sync::RwLock;
use crate::config::AppConfig;
use crate::providers::{
deepseek::DeepSeekProvider, gemini::GeminiProvider, grok::GrokProvider, ollama::OllamaProvider,
openai::OpenAIProvider,
gemini::GeminiProvider,
deepseek::DeepSeekProvider,
grok::GrokProvider,
ollama::OllamaProvider,
};
#[derive(Clone)]
@@ -76,6 +77,12 @@ pub struct ProviderManager {
providers: Arc<RwLock<Vec<Arc<dyn Provider>>>>,
}
impl Default for ProviderManager {
fn default() -> Self {
Self::new()
}
}
impl ProviderManager {
pub fn new() -> Self {
Self {
@@ -84,7 +91,12 @@ impl ProviderManager {
}
/// Initialize a provider by name using config and database overrides
pub async fn initialize_provider(&self, name: &str, app_config: &AppConfig, db_pool: &crate::database::DbPool) -> Result<()> {
pub async fn initialize_provider(
&self,
name: &str,
app_config: &AppConfig,
db_pool: &crate::database::DbPool,
) -> Result<()> {
// Load override from database
let db_config = sqlx::query("SELECT enabled, base_url, api_key FROM provider_configs WHERE id = ?")
.bind(name)
@@ -100,11 +112,31 @@ impl ProviderManager {
} else {
// No database override, use defaults from AppConfig
match name {
"openai" => (app_config.providers.openai.enabled, Some(app_config.providers.openai.base_url.clone()), None),
"gemini" => (app_config.providers.gemini.enabled, Some(app_config.providers.gemini.base_url.clone()), None),
"deepseek" => (app_config.providers.deepseek.enabled, Some(app_config.providers.deepseek.base_url.clone()), None),
"grok" => (app_config.providers.grok.enabled, Some(app_config.providers.grok.base_url.clone()), None),
"ollama" => (app_config.providers.ollama.enabled, Some(app_config.providers.ollama.base_url.clone()), None),
"openai" => (
app_config.providers.openai.enabled,
Some(app_config.providers.openai.base_url.clone()),
None,
),
"gemini" => (
app_config.providers.gemini.enabled,
Some(app_config.providers.gemini.base_url.clone()),
None,
),
"deepseek" => (
app_config.providers.deepseek.enabled,
Some(app_config.providers.deepseek.base_url.clone()),
None,
),
"grok" => (
app_config.providers.grok.enabled,
Some(app_config.providers.grok.base_url.clone()),
None,
),
"ollama" => (
app_config.providers.ollama.enabled,
Some(app_config.providers.ollama.base_url.clone()),
None,
),
_ => (false, None, None),
}
};
@@ -118,7 +150,9 @@ impl ProviderManager {
let provider: Arc<dyn Provider> = match name {
"openai" => {
let mut cfg = app_config.providers.openai.clone();
if let Some(url) = base_url { cfg.base_url = url; }
if let Some(url) = base_url {
cfg.base_url = url;
}
// Handle API key override if present
let p = if let Some(key) = api_key {
// We need a way to create a provider with an explicit key
@@ -128,42 +162,50 @@ impl ProviderManager {
OpenAIProvider::new(&cfg, app_config)?
};
Arc::new(p)
},
}
"ollama" => {
let mut cfg = app_config.providers.ollama.clone();
if let Some(url) = base_url { cfg.base_url = url; }
if let Some(url) = base_url {
cfg.base_url = url;
}
Arc::new(OllamaProvider::new(&cfg, app_config)?)
},
}
"gemini" => {
let mut cfg = app_config.providers.gemini.clone();
if let Some(url) = base_url { cfg.base_url = url; }
if let Some(url) = base_url {
cfg.base_url = url;
}
let p = if let Some(key) = api_key {
GeminiProvider::new_with_key(&cfg, app_config, key)?
} else {
GeminiProvider::new(&cfg, app_config)?
};
Arc::new(p)
},
}
"deepseek" => {
let mut cfg = app_config.providers.deepseek.clone();
if let Some(url) = base_url { cfg.base_url = url; }
if let Some(url) = base_url {
cfg.base_url = url;
}
let p = if let Some(key) = api_key {
DeepSeekProvider::new_with_key(&cfg, app_config, key)?
} else {
DeepSeekProvider::new(&cfg, app_config)?
};
Arc::new(p)
},
}
"grok" => {
let mut cfg = app_config.providers.grok.clone();
if let Some(url) = base_url { cfg.base_url = url; }
if let Some(url) = base_url {
cfg.base_url = url;
}
let p = if let Some(key) = api_key {
GrokProvider::new_with_key(&cfg, app_config, key)?
} else {
GrokProvider::new(&cfg, app_config)?
};
Arc::new(p)
},
}
_ => return Err(anyhow::anyhow!("Unknown provider: {}", name)),
};
@@ -188,16 +230,12 @@ impl ProviderManager {
pub async fn get_provider_for_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
let providers = self.providers.read().await;
providers.iter()
.find(|p| p.supports_model(model))
.map(|p| Arc::clone(p))
providers.iter().find(|p| p.supports_model(model)).map(Arc::clone)
}
pub async fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
let providers = self.providers.read().await;
providers.iter()
.find(|p| p.name() == name)
.map(|p| Arc::clone(p))
providers.iter().find(|p| p.name() == name).map(Arc::clone)
}
pub async fn get_all_providers(&self) -> Vec<Arc<dyn Provider>> {
@@ -238,22 +276,30 @@ pub mod placeholder {
&self,
_request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
Err(AppError::ProviderError("Streaming not supported for placeholder provider".to_string()))
Err(AppError::ProviderError(
"Streaming not supported for placeholder provider".to_string(),
))
}
async fn chat_completion(
&self,
_request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> {
Err(AppError::ProviderError(format!("Provider {} not implemented", self.name)))
async fn chat_completion(&self, _request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
Err(AppError::ProviderError(format!(
"Provider {} not implemented",
self.name
)))
}
fn estimate_tokens(&self, _request: &UnifiedRequest) -> Result<u32> {
Ok(0)
}
fn calculate_cost(&self, _model: &str, _prompt_tokens: u32, _completion_tokens: u32, _registry: &crate::models::registry::ModelRegistry) -> f64 {
fn calculate_cost(
&self,
_model: &str,
_prompt_tokens: u32,
_completion_tokens: u32,
_registry: &crate::models::registry::ModelRegistry,
) -> f64 {
0.0
}
}
}
}