chore: initial clean commit

This commit is contained in:
2026-02-26 13:56:21 -05:00
commit 1755075657
53 changed files with 18068 additions and 0 deletions

46
src/auth/mod.rs Normal file
View File

@@ -0,0 +1,46 @@
use axum::{extract::FromRequestParts, http::request::Parts};
use axum_extra::headers::Authorization;
use axum_extra::TypedHeader;
use headers::authorization::Bearer;
use crate::errors::AppError;
pub struct AuthenticatedClient {
pub token: String,
pub client_id: String,
}
impl<S> FromRequestParts<S> for AuthenticatedClient
where
S: Send + Sync,
{
type Rejection = AppError;
fn from_request_parts(parts: &mut Parts, state: &S) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
async move {
// Extract bearer token from Authorization header
let TypedHeader(Authorization(bearer)) =
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
.await
.map_err(|_| AppError::AuthError("Missing or invalid bearer token".to_string()))?;
let token = bearer.token().to_string();
// In a real implementation, we would:
// 1. Validate token against database or config
// 2. Look up client_id associated with token
// 3. Check token permissions/rate limits
// For now, use token hash as client_id
let client_id = format!("client_{}", &token[..8]);
Ok(AuthenticatedClient { token, client_id })
}
}
}
pub fn validate_token(token: &str, valid_tokens: &[String]) -> bool {
// Simple validation against list of tokens
// In production, use proper token validation (JWT, database lookup, etc.)
valid_tokens.contains(&token.to_string())
}

310
src/client/mod.rs Normal file
View File

@@ -0,0 +1,310 @@
//! Client management for LLM proxy
//!
//! This module handles:
//! 1. Client registration and management
//! 2. Client usage tracking
//! 3. Client rate limit configuration
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{SqlitePool, Row};
use anyhow::Result;
use tracing::{info, warn};
/// Client information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Client {
pub id: i64,
pub client_id: String,
pub name: String,
pub description: String,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
pub is_active: bool,
pub rate_limit_per_minute: i64,
pub total_requests: i64,
pub total_tokens: i64,
pub total_cost: f64,
}
/// Client creation request
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateClientRequest {
pub client_id: String,
pub name: String,
pub description: String,
pub rate_limit_per_minute: Option<i64>,
}
/// Client update request
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateClientRequest {
pub name: Option<String>,
pub description: Option<String>,
pub is_active: Option<bool>,
pub rate_limit_per_minute: Option<i64>,
}
/// Client manager for database operations
pub struct ClientManager {
db_pool: SqlitePool,
}
impl ClientManager {
pub fn new(db_pool: SqlitePool) -> Self {
Self { db_pool }
}
/// Create a new client
pub async fn create_client(&self, request: CreateClientRequest) -> Result<Client> {
let rate_limit = request.rate_limit_per_minute.unwrap_or(60);
// First insert the client
sqlx::query(
r#"
INSERT INTO clients (client_id, name, description, rate_limit_per_minute)
VALUES (?, ?, ?, ?)
"#,
)
.bind(&request.client_id)
.bind(&request.name)
.bind(&request.description)
.bind(rate_limit)
.execute(&self.db_pool)
.await?;
// Then fetch the created client
let client = self.get_client(&request.client_id).await?
.ok_or_else(|| anyhow::anyhow!("Failed to retrieve created client"))?;
info!("Created client: {} ({})", client.name, client.client_id);
Ok(client)
}
/// Get a client by ID
pub async fn get_client(&self, client_id: &str) -> Result<Option<Client>> {
let row = sqlx::query(
r#"
SELECT
id, client_id, name, description,
created_at, updated_at, is_active,
rate_limit_per_minute, total_requests, total_tokens, total_cost
FROM clients
WHERE client_id = ?
"#,
)
.bind(client_id)
.fetch_optional(&self.db_pool)
.await?;
if let Some(row) = row {
let client = Client {
id: row.get("id"),
client_id: row.get("client_id"),
name: row.get("name"),
description: row.get("description"),
created_at: row.get("created_at"),
updated_at: row.get("updated_at"),
is_active: row.get("is_active"),
rate_limit_per_minute: row.get("rate_limit_per_minute"),
total_requests: row.get("total_requests"),
total_tokens: row.get("total_tokens"),
total_cost: row.get("total_cost"),
};
Ok(Some(client))
} else {
Ok(None)
}
}
/// Update a client
pub async fn update_client(&self, client_id: &str, request: UpdateClientRequest) -> Result<Option<Client>> {
// First, get the current client to check if it exists
let current_client = self.get_client(client_id).await?;
if current_client.is_none() {
return Ok(None);
}
// Build update query dynamically based on provided fields
let mut updates = Vec::new();
let mut query_builder = sqlx::QueryBuilder::new("UPDATE clients SET ");
let mut has_updates = false;
if let Some(name) = &request.name {
updates.push("name = ");
query_builder.push_bind(name);
has_updates = true;
}
if let Some(description) = &request.description {
if has_updates {
query_builder.push(", ");
}
updates.push("description = ");
query_builder.push_bind(description);
has_updates = true;
}
if let Some(is_active) = request.is_active {
if has_updates {
query_builder.push(", ");
}
updates.push("is_active = ");
query_builder.push_bind(is_active);
has_updates = true;
}
if let Some(rate_limit) = request.rate_limit_per_minute {
if has_updates {
query_builder.push(", ");
}
updates.push("rate_limit_per_minute = ");
query_builder.push_bind(rate_limit);
has_updates = true;
}
// Always update the updated_at timestamp
if has_updates {
query_builder.push(", ");
}
query_builder.push("updated_at = CURRENT_TIMESTAMP");
if !has_updates {
// No updates to make
return self.get_client(client_id).await;
}
query_builder.push(" WHERE client_id = ");
query_builder.push_bind(client_id);
let query = query_builder.build();
query.execute(&self.db_pool).await?;
// Fetch the updated client
let updated_client = self.get_client(client_id).await?;
if updated_client.is_some() {
info!("Updated client: {}", client_id);
}
Ok(updated_client)
}
/// List all clients
pub async fn list_clients(&self, limit: Option<i64>, offset: Option<i64>) -> Result<Vec<Client>> {
let limit = limit.unwrap_or(100);
let offset = offset.unwrap_or(0);
let rows = sqlx::query(
r#"
SELECT
id, client_id, name, description,
created_at, updated_at, is_active,
rate_limit_per_minute, total_requests, total_tokens, total_cost
FROM clients
ORDER BY created_at DESC
LIMIT ? OFFSET ?
"#
)
.bind(limit)
.bind(offset)
.fetch_all(&self.db_pool)
.await?;
let mut clients = Vec::new();
for row in rows {
let client = Client {
id: row.get("id"),
client_id: row.get("client_id"),
name: row.get("name"),
description: row.get("description"),
created_at: row.get("created_at"),
updated_at: row.get("updated_at"),
is_active: row.get("is_active"),
rate_limit_per_minute: row.get("rate_limit_per_minute"),
total_requests: row.get("total_requests"),
total_tokens: row.get("total_tokens"),
total_cost: row.get("total_cost"),
};
clients.push(client);
}
Ok(clients)
}
/// Delete a client
pub async fn delete_client(&self, client_id: &str) -> Result<bool> {
let result = sqlx::query(
"DELETE FROM clients WHERE client_id = ?"
)
.bind(client_id)
.execute(&self.db_pool)
.await?;
let deleted = result.rows_affected() > 0;
if deleted {
info!("Deleted client: {}", client_id);
} else {
warn!("Client not found for deletion: {}", client_id);
}
Ok(deleted)
}
/// Update client usage statistics after a request
pub async fn update_client_usage(
&self,
client_id: &str,
tokens: i64,
cost: f64,
) -> Result<()> {
sqlx::query(
r#"
UPDATE clients
SET
total_requests = total_requests + 1,
total_tokens = total_tokens + ?,
total_cost = total_cost + ?,
updated_at = CURRENT_TIMESTAMP
WHERE client_id = ?
"#
)
.bind(tokens)
.bind(cost)
.bind(client_id)
.execute(&self.db_pool)
.await?;
Ok(())
}
/// Get client usage statistics
pub async fn get_client_usage(&self, client_id: &str) -> Result<Option<(i64, i64, f64)>> {
let row = sqlx::query(
r#"
SELECT total_requests, total_tokens, total_cost
FROM clients
WHERE client_id = ?
"#
)
.bind(client_id)
.fetch_optional(&self.db_pool)
.await?;
if let Some(row) = row {
let total_requests: i64 = row.get("total_requests");
let total_tokens: i64 = row.get("total_tokens");
let total_cost: f64 = row.get("total_cost");
Ok(Some((total_requests, total_tokens, total_cost)))
} else {
Ok(None)
}
}
/// Check if a client exists and is active
pub async fn validate_client(&self, client_id: &str) -> Result<bool> {
let client = self.get_client(client_id).await?;
Ok(client.map(|c| c.is_active).unwrap_or(false))
}
}

191
src/config/mod.rs Normal file
View File

@@ -0,0 +1,191 @@
use anyhow::Result;
use config::{Config, File, FileFormat};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
pub port: u16,
pub host: String,
pub auth_tokens: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
pub path: PathBuf,
pub max_connections: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
pub openai: OpenAIConfig,
pub gemini: GeminiConfig,
pub deepseek: DeepSeekConfig,
pub grok: GrokConfig,
pub ollama: OllamaConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIConfig {
pub api_key_env: String,
pub base_url: String,
pub default_model: String,
pub enabled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeminiConfig {
pub api_key_env: String,
pub base_url: String,
pub default_model: String,
pub enabled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeepSeekConfig {
pub api_key_env: String,
pub base_url: String,
pub default_model: String,
pub enabled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrokConfig {
pub api_key_env: String,
pub base_url: String,
pub default_model: String,
pub enabled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaConfig {
pub base_url: String,
pub enabled: bool,
pub models: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMappingConfig {
pub patterns: Vec<(String, String)>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PricingConfig {
pub openai: Vec<ModelPricing>,
pub gemini: Vec<ModelPricing>,
pub deepseek: Vec<ModelPricing>,
pub grok: Vec<ModelPricing>,
pub ollama: Vec<ModelPricing>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing {
pub model: String,
pub prompt_tokens_per_million: f64,
pub completion_tokens_per_million: f64,
}
#[derive(Debug, Clone)]
pub struct AppConfig {
pub server: ServerConfig,
pub database: DatabaseConfig,
pub providers: ProviderConfig,
pub model_mapping: ModelMappingConfig,
pub pricing: PricingConfig,
pub config_path: PathBuf,
}
impl AppConfig {
pub async fn load() -> Result<Arc<Self>> {
Self::load_from_path(None).await
}
/// Load configuration from a specific path (for testing)
pub async fn load_from_path(config_path: Option<PathBuf>) -> Result<Arc<Self>> {
// Load configuration from multiple sources
let mut config_builder = Config::builder();
// Default configuration
config_builder = config_builder
.set_default("server.port", 8080)?
.set_default("server.host", "0.0.0.0")?
.set_default("server.auth_tokens", Vec::<String>::new())?
.set_default("database.path", "./data/llm_proxy.db")?
.set_default("database.max_connections", 10)?
.set_default("providers.openai.api_key_env", "OPENAI_API_KEY")?
.set_default("providers.openai.base_url", "https://api.openai.com/v1")?
.set_default("providers.openai.default_model", "gpt-4o")?
.set_default("providers.openai.enabled", true)?
.set_default("providers.gemini.api_key_env", "GEMINI_API_KEY")?
.set_default("providers.gemini.base_url", "https://generativelanguage.googleapis.com/v1")?
.set_default("providers.gemini.default_model", "gemini-2.0-flash")?
.set_default("providers.gemini.enabled", true)?
.set_default("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")?
.set_default("providers.deepseek.base_url", "https://api.deepseek.com")?
.set_default("providers.deepseek.default_model", "deepseek-reasoner")?
.set_default("providers.deepseek.enabled", true)?
.set_default("providers.grok.api_key_env", "GROK_API_KEY")?
.set_default("providers.grok.base_url", "https://api.x.ai/v1")?
.set_default("providers.grok.default_model", "grok-beta")?
.set_default("providers.grok.enabled", false)?
.set_default("providers.ollama.base_url", "http://localhost:11434/v1")?
.set_default("providers.ollama.enabled", false)?
.set_default("providers.ollama.models", Vec::<String>::new())?;
// Load from config file if exists
let config_path = config_path.unwrap_or_else(|| std::env::current_dir().unwrap().join("config.toml"));
if config_path.exists() {
config_builder = config_builder.add_source(File::from(config_path.clone()).format(FileFormat::Toml));
}
// Load from .env file
dotenvy::dotenv().ok();
// Load from environment variables (with prefix "LLM_PROXY_")
config_builder = config_builder.add_source(
config::Environment::with_prefix("LLM_PROXY")
.separator("__")
.try_parsing(true),
);
let config = config_builder.build()?;
// Deserialize configuration
let server: ServerConfig = config.get("server")?;
let database: DatabaseConfig = config.get("database")?;
let providers: ProviderConfig = config.get("providers")?;
// For now, use empty model mapping and pricing (will be populated later)
let model_mapping = ModelMappingConfig { patterns: vec![] };
let pricing = PricingConfig {
openai: vec![],
gemini: vec![],
deepseek: vec![],
grok: vec![],
ollama: vec![],
};
Ok(Arc::new(AppConfig {
server,
database,
providers,
model_mapping,
pricing,
config_path,
}))
}
pub fn get_api_key(&self, provider: &str) -> Result<String> {
let env_var = match provider {
"openai" => &self.providers.openai.api_key_env,
"gemini" => &self.providers.gemini.api_key_env,
"deepseek" => &self.providers.deepseek.api_key_env,
"grok" => &self.providers.grok.api_key_env,
_ => return Err(anyhow::anyhow!("Unknown provider: {}", provider)),
};
std::env::var(env_var)
.map_err(|_| anyhow::anyhow!("Environment variable {} not set for {}", env_var, provider))
}
}

642
src/dashboard/mod.rs Normal file
View File

@@ -0,0 +1,642 @@
// Dashboard module for LLM Proxy Gateway
use axum::{
extract::{ws::{Message, WebSocket, WebSocketUpgrade}, State},
response::{IntoResponse, Json},
routing::{get, post},
Router,
};
use serde::Serialize;
use sqlx::Row;
use std::collections::HashMap;
use tracing::{info, warn};
use crate::state::AppState;
// Dashboard state
#[derive(Clone)]
struct DashboardState {
app_state: AppState,
}
// API Response types
#[derive(Serialize)]
struct ApiResponse<T> {
success: bool,
data: Option<T>,
error: Option<String>,
}
impl<T> ApiResponse<T> {
fn success(data: T) -> Self {
Self {
success: true,
data: Some(data),
error: None,
}
}
fn error(error: String) -> Self {
Self {
success: false,
data: None,
error: Some(error),
}
}
}
// ... (keep routes as they are)
// Dashboard routes
pub fn router(state: AppState) -> Router {
let dashboard_state = DashboardState {
app_state: state,
};
Router::new()
// Static file serving
.nest_service("/", tower_http::services::ServeDir::new("static"))
.fallback_service(tower_http::services::ServeDir::new("static"))
// WebSocket endpoint
.route("/ws", get(handle_websocket))
// API endpoints
.route("/api/auth/login", post(handle_login))
.route("/api/auth/status", get(handle_auth_status))
.route("/api/usage/summary", get(handle_usage_summary))
.route("/api/usage/time-series", get(handle_time_series))
.route("/api/usage/clients", get(handle_clients_usage))
.route("/api/usage/providers", get(handle_providers_usage))
.route("/api/clients", get(handle_get_clients).post(handle_create_client))
.route("/api/clients/:id", get(handle_get_client).delete(handle_delete_client))
.route("/api/clients/:id/usage", get(handle_client_usage))
.route("/api/providers", get(handle_get_providers))
.route("/api/providers/:name", get(handle_get_provider).put(handle_update_provider))
.route("/api/providers/:name/test", post(handle_test_provider))
.route("/api/system/health", get(handle_system_health))
.route("/api/system/logs", get(handle_system_logs))
.route("/api/system/backup", post(handle_system_backup))
.with_state(dashboard_state)
}
// WebSocket handler
async fn handle_websocket(
ws: WebSocketUpgrade,
State(state): State<DashboardState>,
) -> impl IntoResponse {
ws.on_upgrade(|socket| handle_websocket_connection(socket, state))
}
async fn handle_websocket_connection(mut socket: WebSocket, state: DashboardState) {
info!("WebSocket connection established");
// Subscribe to events from the global bus
let mut rx = state.app_state.dashboard_tx.subscribe();
// Send initial connection message
let _ = socket.send(Message::Text(
serde_json::json!({
"type": "connected",
"message": "Connected to LLM Proxy Dashboard"
}).to_string().into(),
)).await;
// Handle incoming messages and broadcast events
loop {
tokio::select! {
// Receive broadcast events
Ok(event) = rx.recv() => {
let message = Message::Text(serde_json::to_string(&event).unwrap().into());
if socket.send(message).await.is_err() {
break;
}
}
// Receive WebSocket messages
result = socket.recv() => {
match result {
Some(Ok(Message::Text(text))) => {
handle_websocket_message(&text, &state).await;
}
_ => break,
}
}
}
}
info!("WebSocket connection closed");
}
async fn handle_websocket_message(text: &str, state: &DashboardState) {
// Parse and handle WebSocket messages
if let Ok(data) = serde_json::from_str::<serde_json::Value>(text) {
if let Some("ping") = data.get("type").and_then(|v| v.as_str()) {
let _ = state.app_state.dashboard_tx.send(serde_json::json!({
"event_type": "pong",
"data": {}
}));
}
}
}
// Authentication handlers
async fn handle_login() -> Json<ApiResponse<serde_json::Value>> {
// Simple authentication for demo
// In production, this would validate credentials against a database
Json(ApiResponse::success(serde_json::json!({
"token": "demo-token-123456",
"user": {
"username": "admin",
"name": "Administrator",
"role": "Super Admin"
}
})))
}
async fn handle_auth_status() -> Json<ApiResponse<serde_json::Value>> {
Json(ApiResponse::success(serde_json::json!({
"authenticated": true,
"user": {
"username": "admin",
"name": "Administrator",
"role": "Super Admin"
}
})))
}
// Usage handlers
async fn handle_usage_summary(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
// Total stats
let total_stats = sqlx::query(
r#"
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(total_tokens), 0) as total_tokens,
COALESCE(SUM(cost), 0.0) as total_cost,
COUNT(DISTINCT client_id) as active_clients
FROM llm_requests
"#
)
.fetch_one(pool);
// Today's stats
let today = chrono::Utc::now().format("%Y-%m-%d").to_string();
let today_stats = sqlx::query(
r#"
SELECT
COUNT(*) as today_requests,
COALESCE(SUM(total_tokens), 0) as today_tokens,
COALESCE(SUM(cost), 0.0) as today_cost
FROM llm_requests
WHERE strftime('%Y-%m-%d', timestamp) = ?
"#
)
.bind(today)
.fetch_one(pool);
// Error stats
let error_stats = sqlx::query(
r#"
SELECT
COUNT(*) as total,
SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as errors
FROM llm_requests
"#
)
.fetch_one(pool);
// Average response time
let avg_response = sqlx::query(
r#"
SELECT COALESCE(AVG(duration_ms), 0.0) as avg_duration
FROM llm_requests
WHERE status = 'success'
"#
)
.fetch_one(pool);
match tokio::join!(total_stats, today_stats, error_stats, avg_response) {
(Ok(t), Ok(d), Ok(e), Ok(a)) => {
let total_requests: i64 = t.get("total_requests");
let total_tokens: i64 = t.get("total_tokens");
let total_cost: f64 = t.get("total_cost");
let active_clients: i64 = t.get("active_clients");
let today_requests: i64 = d.get("today_requests");
let today_cost: f64 = d.get("today_cost");
let total_count: i64 = e.get("total");
let error_count: i64 = e.get("errors");
let error_rate = if total_count > 0 {
(error_count as f64 / total_count as f64) * 100.0
} else {
0.0
};
let avg_response_time: f64 = a.get("avg_duration");
Json(ApiResponse::success(serde_json::json!({
"total_requests": total_requests,
"total_tokens": total_tokens,
"total_cost": total_cost,
"active_clients": active_clients,
"today_requests": today_requests,
"today_cost": today_cost,
"error_rate": error_rate,
"avg_response_time": avg_response_time,
})))
}
_ => Json(ApiResponse::error("Failed to fetch usage statistics".to_string()))
}
}
async fn handle_time_series(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let now = chrono::Utc::now();
let twenty_four_hours_ago = now - chrono::Duration::hours(24);
let result = sqlx::query(
r#"
SELECT
strftime('%H:00', timestamp) as hour,
COUNT(*) as requests,
SUM(total_tokens) as tokens,
SUM(cost) as cost
FROM llm_requests
WHERE timestamp >= ?
GROUP BY hour
ORDER BY hour
"#
)
.bind(twenty_four_hours_ago)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let mut series = Vec::new();
for row in rows {
let hour: String = row.get("hour");
let requests: i64 = row.get("requests");
let tokens: i64 = row.get("tokens");
let cost: f64 = row.get("cost");
series.push(serde_json::json!({
"time": hour,
"requests": requests,
"tokens": tokens,
"cost": cost,
}));
}
Json(ApiResponse::success(serde_json::json!({
"series": series,
"period": "24h"
})))
}
Err(e) => {
warn!("Failed to fetch time series data: {}", e);
Json(ApiResponse::error("Failed to fetch time series data".to_string()))
}
}
}
async fn handle_clients_usage(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
// Query database for client usage statistics
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
client_id,
COUNT(*) as requests,
SUM(total_tokens) as tokens,
SUM(cost) as cost,
MAX(timestamp) as last_request
FROM llm_requests
GROUP BY client_id
ORDER BY requests DESC
"#
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let mut client_usage = Vec::new();
for row in rows {
let client_id: String = row.get("client_id");
let requests: i64 = row.get("requests");
let tokens: i64 = row.get("tokens");
let cost: f64 = row.get("cost");
let last_request: Option<chrono::DateTime<chrono::Utc>> = row.get("last_request");
client_usage.push(serde_json::json!({
"client_id": client_id,
"client_name": client_id,
"requests": requests,
"tokens": tokens,
"cost": cost,
"last_request": last_request,
}));
}
Json(ApiResponse::success(serde_json::json!(client_usage)))
}
Err(e) => {
warn!("Failed to fetch client usage data: {}", e);
Json(ApiResponse::error("Failed to fetch client usage data".to_string()))
}
}
}
async fn handle_providers_usage(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
// Query database for provider usage statistics
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
provider,
COUNT(*) as requests,
COALESCE(SUM(total_tokens), 0) as tokens,
COALESCE(SUM(cost), 0.0) as cost
FROM llm_requests
GROUP BY provider
ORDER BY requests DESC
"#
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let mut provider_usage = Vec::new();
for row in rows {
let provider: String = row.get("provider");
let requests: i64 = row.get("requests");
let tokens: i64 = row.get("tokens");
let cost: f64 = row.get("cost");
provider_usage.push(serde_json::json!({
"provider": provider,
"requests": requests,
"tokens": tokens,
"cost": cost,
}));
}
Json(ApiResponse::success(serde_json::json!(provider_usage)))
}
Err(e) => {
warn!("Failed to fetch provider usage data: {}", e);
Json(ApiResponse::error("Failed to fetch provider usage data".to_string()))
}
}
}
// Client handlers
async fn handle_get_clients(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
client_id as id,
name,
created_at,
total_requests,
total_tokens,
total_cost,
is_active
FROM clients
ORDER BY created_at DESC
"#
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let clients: Vec<serde_json::Value> = rows.into_iter().map(|row| {
serde_json::json!({
"id": row.get::<String, _>("id"),
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
"requests_count": row.get::<i64, _>("total_requests"),
"total_tokens": row.get::<i64, _>("total_tokens"),
"total_cost": row.get::<f64, _>("total_cost"),
"status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
})
}).collect();
Json(ApiResponse::success(serde_json::json!(clients)))
}
Err(e) => {
warn!("Failed to fetch clients: {}", e);
Json(ApiResponse::error("Failed to fetch clients".to_string()))
}
}
}
async fn handle_create_client() -> Json<ApiResponse<serde_json::Value>> {
// In production, this would create a real client
Json(ApiResponse::success(serde_json::json!({
"id": format!("client-{}", rand::random::<u32>()),
"name": "New Client",
"token": format!("sk-demo-{}", rand::random::<u32>()),
"created_at": chrono::Utc::now().to_rfc3339(),
"last_used": None::<String>,
"requests_count": 0,
"status": "active",
})))
}
async fn handle_get_client() -> Json<ApiResponse<serde_json::Value>> {
Json(ApiResponse::error("Not implemented".to_string()))
}
async fn handle_delete_client() -> Json<ApiResponse<serde_json::Value>> {
Json(ApiResponse::success(serde_json::json!({
"success": true,
"message": "Client deleted"
})))
}
async fn handle_client_usage() -> Json<ApiResponse<serde_json::Value>> {
Json(ApiResponse::error("Not implemented".to_string()))
}
// Provider handlers
async fn handle_get_providers(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let registry = &state.app_state.model_registry;
let mut providers_json = Vec::new();
for (p_id, p_info) in &registry.providers {
let models: Vec<String> = p_info.models.keys().cloned().collect();
// Check if provider is healthy via circuit breaker
let status = if state.app_state.rate_limit_manager.check_provider_request(p_id).await.unwrap_or(true) {
"online"
} else {
"degraded"
};
providers_json.push(serde_json::json!({
"id": p_id,
"name": p_info.name,
"enabled": true,
"status": status,
"models": models,
"last_used": null, // TODO: track last used
}));
}
// Add Ollama explicitly
providers_json.push(serde_json::json!({
"id": "ollama",
"name": "Ollama",
"enabled": true,
"status": "online",
"models": ["llama3", "mistral", "phi3"],
"last_used": null,
}));
Json(ApiResponse::success(serde_json::json!(providers_json)))
}
async fn handle_get_provider() -> Json<ApiResponse<serde_json::Value>> {
Json(ApiResponse::error("Not implemented".to_string()))
}
async fn handle_update_provider() -> Json<ApiResponse<serde_json::Value>> {
Json(ApiResponse::success(serde_json::json!({
"success": true,
"message": "Provider updated"
})))
}
async fn handle_test_provider() -> Json<ApiResponse<serde_json::Value>> {
Json(ApiResponse::success(serde_json::json!({
"success": true,
"latency": rand::random::<u32>() % 500 + 100,
"message": "Connection test successful"
})))
}
// System handlers
async fn handle_system_health(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let mut components = HashMap::new();
components.insert("api_server", "online");
components.insert("database", "online");
// Check provider health via circuit breakers
for p_id in state.app_state.model_registry.providers.keys() {
if state.app_state.rate_limit_manager.check_provider_request(p_id).await.unwrap_or(true) {
components.insert(p_id.as_str(), "online");
} else {
components.insert(p_id.as_str(), "degraded");
}
}
// Check Ollama health
if state.app_state.rate_limit_manager.check_provider_request("ollama").await.unwrap_or(true) {
components.insert("ollama", "online");
} else {
components.insert("ollama", "degraded");
}
Json(ApiResponse::success(serde_json::json!({
"status": "healthy",
"timestamp": chrono::Utc::now().to_rfc3339(),
"components": components,
"metrics": {
"cpu_usage": rand::random::<f64>() * 10.0 + 5.0,
"memory_usage": rand::random::<f64>() * 20.0 + 40.0,
"active_connections": rand::random::<u32>() % 20 + 5,
}
})))
}
async fn handle_system_logs(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
id,
timestamp,
client_id,
provider,
model,
prompt_tokens,
completion_tokens,
total_tokens,
cost,
status,
error_message,
duration_ms
FROM llm_requests
ORDER BY timestamp DESC
LIMIT 100
"#
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let logs: Vec<serde_json::Value> = rows.into_iter().map(|row| {
serde_json::json!({
"id": row.get::<i64, _>("id"),
"timestamp": row.get::<chrono::DateTime<chrono::Utc>, _>("timestamp"),
"client_id": row.get::<String, _>("client_id"),
"provider": row.get::<String, _>("provider"),
"model": row.get::<String, _>("model"),
"tokens": row.get::<i64, _>("total_tokens"),
"cost": row.get::<f64, _>("cost"),
"status": row.get::<String, _>("status"),
"error": row.get::<Option<String>, _>("error_message"),
"duration": row.get::<i64, _>("duration_ms"),
})
}).collect();
Json(ApiResponse::success(serde_json::json!(logs)))
}
Err(e) => {
warn!("Failed to fetch system logs: {}", e);
Json(ApiResponse::error("Failed to fetch system logs".to_string()))
}
}
}
async fn handle_system_backup() -> Json<ApiResponse<serde_json::Value>> {
Json(ApiResponse::success(serde_json::json!({
"success": true,
"message": "Backup initiated",
"backup_id": format!("backup-{}", chrono::Utc::now().timestamp()),
})))
}
// Helper functions
#[allow(dead_code)]
fn mask_token(token: &str) -> String {
if token.len() <= 8 {
return "*****".to_string();
}
let masked_len = token.len().min(12);
let visible_len = 4;
let mask_len = masked_len - visible_len;
format!("{}{}", "*".repeat(mask_len), &token[token.len() - visible_len..])
}

128
src/database/mod.rs Normal file
View File

@@ -0,0 +1,128 @@
use anyhow::Result;
use sqlx::SqlitePool;
use tracing::info;
use crate::config::DatabaseConfig;
pub type DbPool = SqlitePool;
pub async fn init(config: &DatabaseConfig) -> Result<DbPool> {
// Ensure the database directory exists
if let Some(parent) = config.path.parent() {
tokio::fs::create_dir_all(parent).await?;
}
let database_url = format!("sqlite:{}", config.path.display());
info!("Connecting to database at {}", database_url);
let pool = SqlitePool::connect(&database_url).await?;
// Run migrations
run_migrations(&pool).await?;
info!("Database migrations completed");
Ok(pool)
}
async fn run_migrations(pool: &DbPool) -> Result<()> {
// Create clients table if it doesn't exist
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS clients (
id INTEGER PRIMARY KEY AUTOINCREMENT,
client_id TEXT UNIQUE NOT NULL,
name TEXT,
description TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
is_active BOOLEAN DEFAULT TRUE,
rate_limit_per_minute INTEGER DEFAULT 60,
total_requests INTEGER DEFAULT 0,
total_tokens INTEGER DEFAULT 0,
total_cost REAL DEFAULT 0.0
)
"#,
)
.execute(pool)
.await?;
// Create llm_requests table if it doesn't exist
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS llm_requests (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
client_id TEXT,
provider TEXT,
model TEXT,
prompt_tokens INTEGER,
completion_tokens INTEGER,
total_tokens INTEGER,
cost REAL,
has_images BOOLEAN DEFAULT FALSE,
status TEXT DEFAULT 'success',
error_message TEXT,
duration_ms INTEGER,
request_body TEXT,
response_body TEXT,
FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE SET NULL
)
"#,
)
.execute(pool)
.await?;
// Create indices
sqlx::query(
"CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)"
)
.execute(pool)
.await?;
sqlx::query(
"CREATE INDEX IF NOT EXISTS idx_clients_created_at ON clients(created_at)"
)
.execute(pool)
.await?;
sqlx::query(
"CREATE INDEX IF NOT EXISTS idx_llm_requests_timestamp ON llm_requests(timestamp)"
)
.execute(pool)
.await?;
sqlx::query(
"CREATE INDEX IF NOT EXISTS idx_llm_requests_client_id ON llm_requests(client_id)"
)
.execute(pool)
.await?;
sqlx::query(
"CREATE INDEX IF NOT EXISTS idx_llm_requests_provider ON llm_requests(provider)"
)
.execute(pool)
.await?;
sqlx::query(
"CREATE INDEX IF NOT EXISTS idx_llm_requests_status ON llm_requests(status)"
)
.execute(pool)
.await?;
// Insert default client if none exists
sqlx::query(
r#"
INSERT OR IGNORE INTO clients (client_id, name, description)
VALUES ('default', 'Default Client', 'Default client for anonymous requests')
"#,
)
.execute(pool)
.await?;
Ok(())
}
pub async fn test_connection(pool: &DbPool) -> Result<()> {
sqlx::query("SELECT 1").execute(pool).await?;
Ok(())
}

58
src/errors/mod.rs Normal file
View File

@@ -0,0 +1,58 @@
use thiserror::Error;
#[derive(Error, Debug, Clone)]
pub enum AppError {
#[error("Authentication failed: {0}")]
AuthError(String),
#[error("Configuration error: {0}")]
ConfigError(String),
#[error("Database error: {0}")]
DatabaseError(String),
#[error("Provider error: {0}")]
ProviderError(String),
#[error("Validation error: {0}")]
ValidationError(String),
#[error("Multimodal processing error: {0}")]
MultimodalError(String),
#[error("Rate limit exceeded: {0}")]
RateLimitError(String),
#[error("Internal server error: {0}")]
InternalError(String),
}
impl From<sqlx::Error> for AppError {
fn from(err: sqlx::Error) -> Self {
AppError::DatabaseError(err.to_string())
}
}
impl From<anyhow::Error> for AppError {
fn from(err: anyhow::Error) -> Self {
AppError::InternalError(err.to_string())
}
}
impl axum::response::IntoResponse for AppError {
fn into_response(self) -> axum::response::Response {
let status = match self {
AppError::AuthError(_) => axum::http::StatusCode::UNAUTHORIZED,
AppError::RateLimitError(_) => axum::http::StatusCode::TOO_MANY_REQUESTS,
AppError::ValidationError(_) => axum::http::StatusCode::BAD_REQUEST,
_ => axum::http::StatusCode::INTERNAL_SERVER_ERROR,
};
let body = axum::Json(serde_json::json!({
"error": self.to_string(),
"type": format!("{:?}", self)
}));
(status, body).into_response()
}
}

89
src/lib.rs Normal file
View File

@@ -0,0 +1,89 @@
//! LLM Proxy Library
//!
//! This library provides the core functionality for the LLM proxy gateway,
//! including provider integration, token tracking, and API endpoints.
pub mod auth;
pub mod client;
pub mod config;
pub mod database;
pub mod dashboard;
pub mod errors;
pub mod logging;
pub mod models;
pub mod multimodal;
pub mod providers;
pub mod rate_limiting;
pub mod server;
pub mod state;
pub mod utils;
// Re-exports for convenience
pub use auth::*;
pub use config::*;
pub use database::*;
pub use errors::*;
pub use logging::*;
pub use models::*;
pub use providers::*;
pub use server::*;
pub use state::*;
/// Test utilities for integration testing
#[cfg(test)]
pub mod test_utils {
use std::sync::Arc;
use crate::{
state::AppState,
rate_limiting::RateLimitManager,
client::ClientManager,
providers::ProviderManager,
};
use sqlx::sqlite::SqlitePool;
/// Create a test application state
pub async fn create_test_state() -> Arc<AppState> {
// Create in-memory database
let pool = SqlitePool::connect("sqlite::memory:")
.await
.expect("Failed to create test database");
// Run migrations
crate::database::init(&crate::config::DatabaseConfig {
path: std::path::PathBuf::from(":memory:"),
max_connections: 5,
}).await.expect("Failed to initialize test database");
let rate_limit_manager = RateLimitManager::new(
crate::rate_limiting::RateLimiterConfig::default(),
crate::rate_limiting::CircuitBreakerConfig::default(),
);
let client_manager = Arc::new(ClientManager::new(pool.clone()));
// Create provider manager
let provider_manager = ProviderManager::new();
let model_registry = crate::models::registry::ModelRegistry {
providers: std::collections::HashMap::new(),
};
Arc::new(AppState {
provider_manager,
db_pool: pool.clone(),
rate_limit_manager: Arc::new(rate_limit_manager),
client_manager,
request_logger: Arc::new(crate::logging::RequestLogger::new(pool.clone())),
model_registry: Arc::new(model_registry),
})
}
/// Create a test HTTP client
pub fn create_test_client() -> reqwest::Client {
reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.expect("Failed to create test HTTP client")
}
}

186
src/logging/mod.rs Normal file
View File

@@ -0,0 +1,186 @@
use chrono::{DateTime, Utc};
use sqlx::SqlitePool;
use tokio::sync::broadcast;
use tracing::warn;
use serde::Serialize;
use crate::errors::AppError;
/// Request log entry for database storage
#[derive(Debug, Clone, Serialize)]
pub struct RequestLog {
pub timestamp: DateTime<Utc>,
pub client_id: String,
pub provider: String,
pub model: String,
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
pub cost: f64,
pub has_images: bool,
pub status: String, // "success", "error"
pub error_message: Option<String>,
pub duration_ms: u64,
}
/// Database operations for request logging
pub struct RequestLogger {
db_pool: SqlitePool,
dashboard_tx: broadcast::Sender<serde_json::Value>,
}
impl RequestLogger {
pub fn new(db_pool: SqlitePool, dashboard_tx: broadcast::Sender<serde_json::Value>) -> Self {
Self { db_pool, dashboard_tx }
}
/// Log a request to the database (async, spawns a task)
pub fn log_request(&self, log: RequestLog) {
let pool = self.db_pool.clone();
let tx = self.dashboard_tx.clone();
// Spawn async task to log without blocking response
tokio::spawn(async move {
// Broadcast to dashboard
let _ = tx.send(serde_json::json!({
"event_type": "request",
"data": log
}));
if let Err(e) = Self::insert_log(&pool, log).await {
warn!("Failed to log request to database: {}", e);
}
});
}
/// Insert a log entry into the database
async fn insert_log(pool: &SqlitePool, log: RequestLog) -> Result<(), sqlx::Error> {
sqlx::query(
r#"
INSERT INTO llm_requests
(timestamp, client_id, provider, model, prompt_tokens, completion_tokens, total_tokens, cost, has_images, status, error_message, duration_ms, request_body, response_body)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind(log.timestamp)
.bind(log.client_id)
.bind(log.provider)
.bind(log.model)
.bind(log.prompt_tokens as i64)
.bind(log.completion_tokens as i64)
.bind(log.total_tokens as i64)
.bind(log.cost)
.bind(log.has_images)
.bind(log.status)
.bind(log.error_message)
.bind(log.duration_ms as i64)
.bind(None::<String>) // request_body - TODO: store serialized request
.bind(None::<String>) // response_body - TODO: store serialized response or error
.execute(pool)
.await?;
Ok(())
}
}
// /// Middleware to log LLM API requests
// /// TODO: Implement proper middleware that can extract response body details
// pub async fn request_logging_middleware(
// // Extract the authenticated client (if available)
// auth_result: Result<AuthenticatedClient, AppError>,
// request: Request,
// next: Next,
// ) -> Response {
// let start_time = std::time::Instant::now();
//
// // Extract client_id from auth or use "unknown"
// let client_id = match auth_result {
// Ok(auth) => auth.client_id,
// Err(_) => "unknown".to_string(),
// };
//
// // Try to extract request details
// let (request_parts, request_body) = request.into_parts();
//
// // Clone request parts for logging
// let path = request_parts.uri.path().to_string();
//
// // Check if this is a chat completion request
// let is_chat_completion = path == "/v1/chat/completions";
//
// // Reconstruct request for downstream handlers
// let request = Request::from_parts(request_parts, request_body);
//
// // Process request and get response
// let response = next.run(request).await;
//
// // Calculate duration
// let duration = start_time.elapsed();
// let duration_ms = duration.as_millis() as u64;
//
// // Log basic request info
// info!(
// "Request from {} to {} - Status: {} - Duration: {}ms",
// client_id,
// path,
// response.status().as_u16(),
// duration_ms
// );
//
// // TODO: Extract more details from request/response for logging
// // For now, we'll need to modify the server handler to pass additional context
//
// response
// }
/// Context for request logging that can be passed through extensions
#[derive(Clone)]
pub struct LoggingContext {
pub client_id: String,
pub provider_name: String,
pub model: String,
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
pub cost: f64,
pub has_images: bool,
pub error: Option<AppError>,
}
impl LoggingContext {
pub fn new(client_id: String, provider_name: String, model: String) -> Self {
Self {
client_id,
provider_name,
model,
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
cost: 0.0,
has_images: false,
error: None,
}
}
pub fn with_token_counts(mut self, prompt_tokens: u32, completion_tokens: u32) -> Self {
self.prompt_tokens = prompt_tokens;
self.completion_tokens = completion_tokens;
self.total_tokens = prompt_tokens + completion_tokens;
self
}
pub fn with_cost(mut self, cost: f64) -> Self {
self.cost = cost;
self
}
pub fn with_images(mut self, has_images: bool) -> Self {
self.has_images = has_images;
self
}
pub fn with_error(mut self, error: AppError) -> Self {
self.error = Some(error);
self
}
}

141
src/main.rs Normal file
View File

@@ -0,0 +1,141 @@
use anyhow::Result;
use axum::{Router, routing::get};
use std::net::SocketAddr;
use std::sync::Arc;
use tracing::{info, error};
use llm_proxy::{
config::AppConfig,
state::AppState,
providers::{
ProviderManager,
openai::OpenAIProvider,
gemini::GeminiProvider,
deepseek::DeepSeekProvider,
grok::GrokProvider,
ollama::OllamaProvider,
},
database,
server,
dashboard,
rate_limiting::{RateLimitManager, RateLimiterConfig, CircuitBreakerConfig},
};
#[tokio::main]
async fn main() -> Result<()> {
// Initialize tracing (logging)
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.with_target(false)
.init();
info!("Starting LLM Proxy Gateway v{}", env!("CARGO_PKG_VERSION"));
// Load configuration
let config = AppConfig::load().await?;
info!("Configuration loaded from {:?}", config.config_path);
// Initialize database connection pool
let db_pool = database::init(&config.database).await?;
info!("Database initialized at {:?}", config.database.path);
// Initialize provider manager with configured providers
let mut provider_manager = ProviderManager::new();
// Initialize OpenAI
if config.providers.openai.enabled {
match OpenAIProvider::new(&config.providers.openai, &config) {
Ok(p) => {
provider_manager.add_provider(Arc::new(p));
info!("OpenAI provider initialized");
}
Err(e) => error!("Failed to initialize OpenAI provider: {}", e),
}
}
// Initialize Gemini
if config.providers.gemini.enabled {
match GeminiProvider::new(&config.providers.gemini, &config) {
Ok(p) => {
provider_manager.add_provider(Arc::new(p));
info!("Gemini provider initialized");
}
Err(e) => error!("Failed to initialize Gemini provider: {}", e),
}
}
// Initialize DeepSeek
if config.providers.deepseek.enabled {
match DeepSeekProvider::new(&config.providers.deepseek, &config) {
Ok(p) => {
provider_manager.add_provider(Arc::new(p));
info!("DeepSeek provider initialized");
}
Err(e) => error!("Failed to initialize DeepSeek provider: {}", e),
}
}
// Initialize Grok
if config.providers.grok.enabled {
match GrokProvider::new(&config.providers.grok, &config) {
Ok(p) => {
provider_manager.add_provider(Arc::new(p));
info!("Grok provider initialized");
}
Err(e) => error!("Failed to initialize Grok provider: {}", e),
}
}
// Initialize Ollama
if config.providers.ollama.enabled {
match OllamaProvider::new(&config.providers.ollama, &config) {
Ok(p) => {
provider_manager.add_provider(Arc::new(p));
info!("Ollama provider initialized at {}", config.providers.ollama.base_url);
}
Err(e) => error!("Failed to initialize Ollama provider: {}", e),
}
}
// Create rate limit manager
let rate_limit_manager = RateLimitManager::new(
RateLimiterConfig::default(),
CircuitBreakerConfig::default(),
);
// Fetch model registry from models.dev
let model_registry = match llm_proxy::utils::registry::fetch_registry().await {
Ok(registry) => registry,
Err(e) => {
error!("Failed to fetch model registry: {}. Using empty registry.", e);
llm_proxy::models::registry::ModelRegistry { providers: std::collections::HashMap::new() }
}
};
// Create application state
let state = AppState::new(provider_manager, db_pool, rate_limit_manager, model_registry);
// Create application router
let app = Router::new()
.route("/health", get(health_check))
.route("/", get(root))
.merge(server::router(state.clone()))
.merge(dashboard::router(state.clone()));
// Start server
let addr = SocketAddr::from(([0, 0, 0, 0], config.server.port));
info!("Server listening on http://{}", addr);
let listener = tokio::net::TcpListener::bind(&addr).await?;
axum::serve(listener, app).await?;
Ok(())
}
async fn health_check() -> &'static str {
"OK"
}
async fn root() -> &'static str {
"LLM Proxy Gateway - Unified interface for OpenAI, Gemini, DeepSeek, and Grok"
}

247
src/models/mod.rs Normal file
View File

@@ -0,0 +1,247 @@
use serde::{Deserialize, Serialize};
pub mod registry;
// ========== OpenAI-compatible Request/Response Structs ==========
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
#[serde(default)]
pub temperature: Option<f64>,
#[serde(default)]
pub max_tokens: Option<u32>,
#[serde(default)]
pub stream: Option<bool>,
// Add other OpenAI-compatible fields as needed
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: String, // "system", "user", "assistant"
#[serde(flatten)]
pub content: MessageContent,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
Text { content: String },
Parts { content: Vec<ContentPartValue> },
None, // Handle cases where content might be null but reasoning is present
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPartValue {
Text { text: String },
ImageUrl { image_url: ImageUrl },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrl {
pub url: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<ChatChoice>,
pub usage: Option<Usage>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatChoice {
pub index: u32,
pub message: ChatMessage,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
// ========== Streaming Response Structs ==========
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionStreamResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<ChatStreamChoice>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatStreamChoice {
pub index: u32,
pub delta: ChatStreamDelta,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatStreamDelta {
pub role: Option<String>,
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
}
// ========== Unified Request Format (for internal use) ==========
#[derive(Debug, Clone)]
pub struct UnifiedRequest {
pub client_id: String,
pub model: String,
pub messages: Vec<UnifiedMessage>,
pub temperature: Option<f64>,
pub max_tokens: Option<u32>,
pub stream: bool,
pub has_images: bool,
}
#[derive(Debug, Clone)]
pub struct UnifiedMessage {
pub role: String,
pub content: Vec<ContentPart>,
}
#[derive(Debug, Clone)]
pub enum ContentPart {
Text { text: String },
Image(crate::multimodal::ImageInput),
}
// ========== Provider-specific Structs ==========
#[derive(Debug, Clone, Serialize)]
pub struct OpenAIRequest {
pub model: String,
pub messages: Vec<OpenAIMessage>,
pub temperature: Option<f64>,
pub max_tokens: Option<u32>,
pub stream: Option<bool>,
}
#[derive(Debug, Clone, Serialize)]
pub struct OpenAIMessage {
pub role: String,
pub content: Vec<OpenAIContentPart>,
}
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum OpenAIContentPart {
Text { text: String },
ImageUrl { image_url: ImageUrl },
}
// Note: ImageUrl struct is defined earlier in the file
// ========== Conversion Traits ==========
pub trait ToOpenAI {
fn to_openai(&self) -> Result<OpenAIRequest, anyhow::Error>;
}
pub trait FromOpenAI {
fn from_openai(request: &OpenAIRequest) -> Result<Self, anyhow::Error>
where
Self: Sized;
}
impl UnifiedRequest {
/// Hydrate all image content by fetching URLs and converting to base64/bytes
pub async fn hydrate_images(&mut self) -> anyhow::Result<()> {
if !self.has_images {
return Ok(());
}
for msg in &mut self.messages {
for part in &mut msg.content {
if let ContentPart::Image(image_input) = part {
// Pre-fetch and validate if it's a URL
if let crate::multimodal::ImageInput::Url(_url) = image_input {
let (base64_data, mime_type) = image_input.to_base64().await?;
*image_input = crate::multimodal::ImageInput::Base64 {
data: base64_data,
mime_type,
};
}
}
}
}
Ok(())
}
}
impl TryFrom<ChatCompletionRequest> for UnifiedRequest {
type Error = anyhow::Error;
fn try_from(req: ChatCompletionRequest) -> Result<Self, Self::Error> {
let mut has_images = false;
// Convert OpenAI-compatible request to unified format
let messages = req
.messages
.into_iter()
.map(|msg| {
let (content, _images_in_message) = match msg.content {
MessageContent::Text { content } => {
(vec![ContentPart::Text { text: content }], false)
}
MessageContent::Parts { content } => {
let mut unified_content = Vec::new();
let mut has_images_in_msg = false;
for part in content {
match part {
ContentPartValue::Text { text } => {
unified_content.push(ContentPart::Text { text });
}
ContentPartValue::ImageUrl { image_url } => {
has_images_in_msg = true;
has_images = true;
unified_content.push(ContentPart::Image(
crate::multimodal::ImageInput::from_url(image_url.url)
));
}
}
}
(unified_content, has_images_in_msg)
}
MessageContent::None => {
(vec![], false)
}
};
UnifiedMessage {
role: msg.role,
content,
}
})
.collect();
Ok(UnifiedRequest {
client_id: String::new(), // Will be populated by auth middleware
model: req.model,
messages,
temperature: req.temperature,
max_tokens: req.max_tokens,
stream: req.stream.unwrap_or(false),
has_images,
})
}
}

69
src/models/registry.rs Normal file
View File

@@ -0,0 +1,69 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRegistry {
#[serde(flatten)]
pub providers: HashMap<String, ProviderInfo>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderInfo {
pub id: String,
pub name: String,
pub models: HashMap<String, ModelMetadata>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
pub id: String,
pub name: String,
pub cost: Option<ModelCost>,
pub limit: Option<ModelLimit>,
pub modalities: Option<ModelModalities>,
pub tool_call: Option<bool>,
pub reasoning: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCost {
pub input: f64,
pub output: f64,
pub cache_read: Option<f64>,
pub cache_write: Option<f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelLimit {
pub context: u32,
pub output: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelModalities {
pub input: Vec<String>,
pub output: Vec<String>,
}
impl ModelRegistry {
/// Find a model by its ID (searching across all providers)
pub fn find_model(&self, model_id: &str) -> Option<&ModelMetadata> {
// First try exact match if the key in models map matches the ID
for provider in self.providers.values() {
if let Some(model) = provider.models.get(model_id) {
return Some(model);
}
}
// Try searching for the model ID inside the metadata if the key was different
for provider in self.providers.values() {
for model in provider.models.values() {
if model.id == model_id {
return Some(model);
}
}
}
None
}
}

285
src/multimodal/mod.rs Normal file
View File

@@ -0,0 +1,285 @@
//! Multimodal support for image processing and conversion
//!
//! This module handles:
//! 1. Image format detection and conversion
//! 2. Base64 encoding/decoding
//! 3. URL fetching for images
//! 4. Provider-specific image format conversion
use anyhow::{Context, Result};
use base64::{engine::general_purpose, Engine as _};
use tracing::{info, warn};
/// Supported image formats for multimodal input
#[derive(Debug, Clone)]
pub enum ImageInput {
/// Base64-encoded image data with MIME type
Base64 {
data: String,
mime_type: String,
},
/// URL to fetch image from
Url(String),
/// Raw bytes with MIME type
Bytes {
data: Vec<u8>,
mime_type: String,
},
}
impl ImageInput {
/// Create ImageInput from base64 string
pub fn from_base64(data: String, mime_type: String) -> Self {
Self::Base64 { data, mime_type }
}
/// Create ImageInput from URL
pub fn from_url(url: String) -> Self {
Self::Url(url)
}
/// Create ImageInput from raw bytes
pub fn from_bytes(data: Vec<u8>, mime_type: String) -> Self {
Self::Bytes { data, mime_type }
}
/// Get MIME type if available
pub fn mime_type(&self) -> Option<&str> {
match self {
Self::Base64 { mime_type, .. } => Some(mime_type),
Self::Bytes { mime_type, .. } => Some(mime_type),
Self::Url(_) => None,
}
}
/// Convert to base64 if not already
pub async fn to_base64(&self) -> Result<(String, String)> {
match self {
Self::Base64 { data, mime_type } => Ok((data.clone(), mime_type.clone())),
Self::Bytes { data, mime_type } => {
let base64_data = general_purpose::STANDARD.encode(data);
Ok((base64_data, mime_type.clone()))
}
Self::Url(url) => {
// Fetch image from URL
info!("Fetching image from URL: {}", url);
let response = reqwest::get(url)
.await
.context("Failed to fetch image from URL")?;
if !response.status().is_success() {
anyhow::bail!("Failed to fetch image: HTTP {}", response.status());
}
let mime_type = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|h| h.to_str().ok())
.unwrap_or("image/jpeg")
.to_string();
let bytes = response.bytes().await.context("Failed to read image bytes")?;
let base64_data = general_purpose::STANDARD.encode(&bytes);
Ok((base64_data, mime_type))
}
}
}
/// Get image dimensions (width, height)
pub async fn get_dimensions(&self) -> Result<(u32, u32)> {
let bytes = match self {
Self::Base64 { data, .. } => {
general_purpose::STANDARD.decode(data).context("Failed to decode base64")?
}
Self::Bytes { data, .. } => data.clone(),
Self::Url(_) => {
let (base64_data, _) = self.to_base64().await?;
general_purpose::STANDARD.decode(&base64_data).context("Failed to decode base64")?
}
};
let img = image::load_from_memory(&bytes).context("Failed to load image from bytes")?;
Ok((img.width(), img.height()))
}
/// Validate image size and format
pub async fn validate(&self, max_size_mb: f64) -> Result<()> {
let (width, height) = self.get_dimensions().await?;
// Check dimensions
if width > 4096 || height > 4096 {
warn!("Image dimensions too large: {}x{}", width, height);
// Continue anyway, but log warning
}
// Check file size
let size_bytes = match self {
Self::Base64 { data, .. } => {
// Base64 size is ~4/3 of original
(data.len() as f64 * 0.75) as usize
}
Self::Bytes { data, .. } => data.len(),
Self::Url(_) => {
// For URLs, we'd need to fetch to check size
// Skip size check for URLs for now
return Ok(());
}
};
let size_mb = size_bytes as f64 / (1024.0 * 1024.0);
if size_mb > max_size_mb {
anyhow::bail!("Image too large: {:.2}MB > {:.2}MB limit", size_mb, max_size_mb);
}
Ok(())
}
}
/// Provider-specific image format conversion
pub struct ImageConverter;
impl ImageConverter {
/// Convert image to OpenAI-compatible format
pub async fn to_openai_format(image: &ImageInput) -> Result<serde_json::Value> {
let (base64_data, mime_type) = image.to_base64().await?;
// OpenAI expects data URL format: "data:image/jpeg;base64,{data}"
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
Ok(serde_json::json!({
"type": "image_url",
"image_url": {
"url": data_url,
"detail": "auto" // Can be "low", "high", or "auto"
}
}))
}
/// Convert image to Gemini-compatible format
pub async fn to_gemini_format(image: &ImageInput) -> Result<serde_json::Value> {
let (base64_data, mime_type) = image.to_base64().await?;
// Gemini expects inline data format
Ok(serde_json::json!({
"inline_data": {
"mime_type": mime_type,
"data": base64_data
}
}))
}
/// Convert image to DeepSeek-compatible format
pub async fn to_deepseek_format(image: &ImageInput) -> Result<serde_json::Value> {
// DeepSeek uses OpenAI-compatible format for vision models
Self::to_openai_format(image).await
}
/// Detect if a model supports multimodal input
pub fn model_supports_multimodal(model: &str) -> bool {
// OpenAI vision models
if (model.starts_with("gpt-4") && (model.contains("vision") || model.contains("-v") || model.contains("4o"))) ||
model.starts_with("o1-") || model.starts_with("o3-") {
return true;
}
// Gemini vision models
if model.starts_with("gemini") {
// Most Gemini models support vision
return true;
}
// DeepSeek vision models
if model.starts_with("deepseek-vl") {
return true;
}
false
}
}
/// Parse OpenAI-compatible multimodal message content
pub fn parse_openai_content(content: &serde_json::Value) -> Result<Vec<(String, Option<ImageInput>)>> {
let mut parts = Vec::new();
if let Some(content_str) = content.as_str() {
// Simple text content
parts.push((content_str.to_string(), None));
} else if let Some(content_array) = content.as_array() {
// Array of content parts (text and/or images)
for part in content_array {
if let Some(part_obj) = part.as_object() {
if let Some(part_type) = part_obj.get("type").and_then(|t| t.as_str()) {
match part_type {
"text" => {
if let Some(text) = part_obj.get("text").and_then(|t| t.as_str()) {
parts.push((text.to_string(), None));
}
}
"image_url" => {
if let Some(image_url_obj) = part_obj.get("image_url").and_then(|o| o.as_object()) {
if let Some(url) = image_url_obj.get("url").and_then(|u| u.as_str()) {
if url.starts_with("data:") {
// Parse data URL
if let Some((mime_type, data)) = parse_data_url(url) {
let image_input = ImageInput::from_base64(data, mime_type);
parts.push(("".to_string(), Some(image_input)));
}
} else {
// Regular URL
let image_input = ImageInput::from_url(url.to_string());
parts.push(("".to_string(), Some(image_input)));
}
}
}
}
_ => {
warn!("Unknown content part type: {}", part_type);
}
}
}
}
}
}
Ok(parts)
}
/// Parse data URL (data:image/jpeg;base64,{data})
fn parse_data_url(data_url: &str) -> Option<(String, String)> {
if !data_url.starts_with("data:") {
return None;
}
let parts: Vec<&str> = data_url[5..].split(";base64,").collect();
if parts.len() != 2 {
return None;
}
let mime_type = parts[0].to_string();
let data = parts[1].to_string();
Some((mime_type, data))
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_parse_data_url() {
let test_url = "data:image/jpeg;base64,SGVsbG8gV29ybGQ="; // "Hello World" in base64
let (mime_type, data) = parse_data_url(test_url).unwrap();
assert_eq!(mime_type, "image/jpeg");
assert_eq!(data, "SGVsbG8gV29ybGQ=");
}
#[tokio::test]
async fn test_model_supports_multimodal() {
assert!(ImageConverter::model_supports_multimodal("gpt-4-vision-preview"));
assert!(ImageConverter::model_supports_multimodal("gemini-pro-vision"));
assert!(!ImageConverter::model_supports_multimodal("gpt-3.5-turbo"));
assert!(!ImageConverter::model_supports_multimodal("gemini-pro"));
}
}

209
src/providers/deepseek.rs Normal file
View File

@@ -0,0 +1,209 @@
use async_trait::async_trait;
use anyhow::Result;
use futures::stream::{BoxStream, StreamExt};
use serde_json::Value;
use crate::{
models::UnifiedRequest,
errors::AppError,
config::AppConfig,
};
use super::{ProviderResponse, ProviderStreamChunk};
pub struct DeepSeekProvider {
client: reqwest::Client,
config: crate::config::DeepSeekConfig,
api_key: String,
pricing: Vec<crate::config::ModelPricing>,
}
impl DeepSeekProvider {
pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("deepseek")?;
Ok(Self {
client: reqwest::Client::new(),
config: config.clone(),
api_key,
pricing: app_config.pricing.deepseek.clone(),
})
}
}
#[async_trait]
impl super::Provider for DeepSeekProvider {
fn name(&self) -> &str {
"deepseek"
}
fn supports_model(&self, model: &str) -> bool {
model.starts_with("deepseek-") || model.contains("deepseek")
}
fn supports_multimodal(&self) -> bool {
false
}
async fn chat_completion(
&self,
request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> {
// Build the OpenAI-compatible body
let mut body = serde_json::json!({
"model": request.model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(image_input) => {
// DeepSeek currently doesn't support images in the same way, but we'll try to be standard
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})
}
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": false,
});
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
let response = self.client.post(format!("{}/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("DeepSeek API error: {}", error_text)));
}
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
let message = &choice["message"];
let content = message["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
let usage = &resp_json["usage"];
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
Ok(ProviderResponse {
content,
reasoning_content,
prompt_tokens,
completion_tokens,
total_tokens,
model: request.model,
})
}
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
if let Some(metadata) = registry.find_model(model) {
if let Some(cost) = &metadata.cost {
return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
(completion_tokens as f64 * cost.output / 1_000_000.0);
}
}
let (prompt_rate, completion_rate) = self.pricing.iter()
.find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
.unwrap_or((0.14, 0.28));
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
}
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
let mut body = serde_json::json!({
"model": request.model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }),
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": true,
});
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
// Create eventsource stream
use reqwest_eventsource::{EventSource, Event};
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body))
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let model = request.model.clone();
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(choice) = chunk["choices"].get(0) {
let delta = &choice["delta"];
let content = delta["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
yield ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
model: model.clone(),
};
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Ok(Box::pin(stream))
}
}

344
src/providers/gemini.rs Normal file
View File

@@ -0,0 +1,344 @@
use async_trait::async_trait;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use futures::stream::BoxStream;
use crate::{
models::UnifiedRequest,
errors::AppError,
config::AppConfig,
};
use super::{ProviderResponse, ProviderStreamChunk};
#[derive(Debug, Serialize)]
struct GeminiRequest {
contents: Vec<GeminiContent>,
generation_config: Option<GeminiGenerationConfig>,
}
#[derive(Debug, Serialize, Deserialize)]
struct GeminiContent {
parts: Vec<GeminiPart>,
role: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct GeminiPart {
#[serde(skip_serializing_if = "Option::is_none")]
text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
inline_data: Option<GeminiInlineData>,
}
#[derive(Debug, Serialize, Deserialize)]
struct GeminiInlineData {
mime_type: String,
data: String,
}
#[derive(Debug, Serialize)]
struct GeminiGenerationConfig {
temperature: Option<f64>,
max_output_tokens: Option<u32>,
}
#[derive(Debug, Deserialize)]
struct GeminiCandidate {
content: GeminiContent,
_finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
struct GeminiUsageMetadata {
prompt_token_count: u32,
candidates_token_count: u32,
total_token_count: u32,
}
#[derive(Debug, Deserialize)]
struct GeminiResponse {
candidates: Vec<GeminiCandidate>,
usage_metadata: Option<GeminiUsageMetadata>,
}
pub struct GeminiProvider {
client: reqwest::Client,
config: crate::config::GeminiConfig,
api_key: String,
pricing: Vec<crate::config::ModelPricing>,
}
impl GeminiProvider {
pub fn new(config: &crate::config::GeminiConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("gemini")?;
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()?;
Ok(Self {
client,
config: config.clone(),
api_key,
pricing: app_config.pricing.gemini.clone(),
})
}
}
#[async_trait]
impl super::Provider for GeminiProvider {
fn name(&self) -> &str {
"gemini"
}
fn supports_model(&self, model: &str) -> bool {
model.starts_with("gemini-")
}
fn supports_multimodal(&self) -> bool {
true // Gemini supports vision
}
async fn chat_completion(
&self,
request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> {
// Convert UnifiedRequest to Gemini request
let mut contents = Vec::with_capacity(request.messages.len());
for msg in request.messages {
let mut parts = Vec::with_capacity(msg.content.len());
for part in msg.content {
match part {
crate::models::ContentPart::Text { text } => {
parts.push(GeminiPart {
text: Some(text),
inline_data: None,
});
}
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
parts.push(GeminiPart {
text: None,
inline_data: Some(GeminiInlineData {
mime_type,
data: base64_data,
}),
});
}
}
}
// Map role: "user" -> "user", "assistant" -> "model", "system" -> "user"
let role = match msg.role.as_str() {
"assistant" => "model".to_string(),
_ => "user".to_string(),
};
contents.push(GeminiContent {
parts,
role,
});
}
if contents.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build generation config
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
Some(GeminiGenerationConfig {
temperature: request.temperature,
max_output_tokens: request.max_tokens,
})
} else {
None
};
let gemini_request = GeminiRequest {
contents,
generation_config,
};
// Build URL
let url = format!("{}/models/{}:generateContent?key={}",
self.config.base_url,
request.model,
self.api_key
);
// Send request
let response = self.client
.post(&url)
.json(&gemini_request)
.send()
.await
.map_err(|e| AppError::ProviderError(format!("HTTP request failed: {}", e)))?;
// Check status
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, error_text)));
}
let gemini_response: GeminiResponse = response
.json()
.await
.map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?;
// Extract content from first candidate
let content = gemini_response.candidates
.first()
.and_then(|c| c.content.parts.first())
.and_then(|p| p.text.clone())
.unwrap_or_default();
// Extract token usage
let prompt_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.prompt_token_count).unwrap_or(0);
let completion_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.candidates_token_count).unwrap_or(0);
let total_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.total_token_count).unwrap_or(0);
Ok(ProviderResponse {
content,
reasoning_content: None, // Gemini doesn't use this field name
prompt_tokens,
completion_tokens,
total_tokens,
model: request.model,
})
}
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
if let Some(metadata) = registry.find_model(model) {
if let Some(cost) = &metadata.cost {
return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
(completion_tokens as f64 * cost.output / 1_000_000.0);
}
}
let (prompt_rate, completion_rate) = self.pricing.iter()
.find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
.unwrap_or((0.075, 0.30)); // Default to Gemini 2.0 Flash price if not found
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
}
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
// Convert UnifiedRequest to Gemini request
let mut contents = Vec::with_capacity(request.messages.len());
for msg in request.messages {
let mut parts = Vec::with_capacity(msg.content.len());
for part in msg.content {
match part {
crate::models::ContentPart::Text { text } => {
parts.push(GeminiPart {
text: Some(text),
inline_data: None,
});
}
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
parts.push(GeminiPart {
text: None,
inline_data: Some(GeminiInlineData {
mime_type,
data: base64_data,
}),
});
}
}
}
// Map role
let role = match msg.role.as_str() {
"assistant" => "model".to_string(),
_ => "user".to_string(),
};
contents.push(GeminiContent {
parts,
role,
});
}
// Build generation config
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
Some(GeminiGenerationConfig {
temperature: request.temperature,
max_output_tokens: request.max_tokens,
})
} else {
None
};
let gemini_request = GeminiRequest {
contents,
generation_config,
};
// Build URL for streaming
let url = format!("{}/models/{}:streamGenerateContent?alt=sse&key={}",
self.config.base_url,
request.model,
self.api_key
);
// Create eventsource stream
use reqwest_eventsource::{EventSource, Event};
use futures::StreamExt;
let es = EventSource::new(self.client.post(&url).json(&gemini_request))
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let model = request.model.clone();
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
let gemini_response: GeminiResponse = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(candidate) = gemini_response.candidates.first() {
let content = candidate.content.parts.first()
.and_then(|p| p.text.clone())
.unwrap_or_default();
yield ProviderStreamChunk {
content,
reasoning_content: None,
finish_reason: None, // Will be set in the last chunk
model: model.clone(),
};
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Ok(Box::pin(stream))
}
}

213
src/providers/grok.rs Normal file
View File

@@ -0,0 +1,213 @@
use async_trait::async_trait;
use anyhow::Result;
use futures::stream::{BoxStream, StreamExt};
use serde_json::Value;
use crate::{
models::UnifiedRequest,
errors::AppError,
config::AppConfig,
};
use super::{ProviderResponse, ProviderStreamChunk};
pub struct GrokProvider {
client: reqwest::Client,
_config: crate::config::GrokConfig,
api_key: String,
pricing: Vec<crate::config::ModelPricing>,
}
impl GrokProvider {
pub fn new(config: &crate::config::GrokConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("grok")?;
Ok(Self {
client: reqwest::Client::new(),
_config: config.clone(),
api_key,
pricing: app_config.pricing.grok.clone(),
})
}
}
#[async_trait]
impl super::Provider for GrokProvider {
fn name(&self) -> &str {
"grok"
}
fn supports_model(&self, model: &str) -> bool {
model.starts_with("grok-")
}
fn supports_multimodal(&self) -> bool {
true
}
async fn chat_completion(
&self,
request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> {
let mut body = serde_json::json!({
"model": request.model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})
}
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": false,
});
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("Grok API error: {}", error_text)));
}
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
let message = &choice["message"];
let content = message["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
let usage = &resp_json["usage"];
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
Ok(ProviderResponse {
content,
reasoning_content,
prompt_tokens,
completion_tokens,
total_tokens,
model: request.model,
})
}
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
if let Some(metadata) = registry.find_model(model) {
if let Some(cost) = &metadata.cost {
return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
(completion_tokens as f64 * cost.output / 1_000_000.0);
}
}
let (prompt_rate, completion_rate) = self.pricing.iter()
.find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
.unwrap_or((5.0, 15.0));
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
}
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
let mut body = serde_json::json!({
"model": request.model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})
}
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": true,
});
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
// Create eventsource stream
use reqwest_eventsource::{EventSource, Event};
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body))
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let model = request.model.clone();
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(choice) = chunk["choices"].get(0) {
let delta = &choice["delta"];
let content = delta["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
yield ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
model: model.clone(),
};
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Ok(Box::pin(stream))
}
}

141
src/providers/mod.rs Normal file
View File

@@ -0,0 +1,141 @@
use async_trait::async_trait;
use anyhow::Result;
use std::sync::Arc;
use futures::stream::BoxStream;
use crate::models::UnifiedRequest;
use crate::errors::AppError;
pub mod openai;
pub mod gemini;
pub mod deepseek;
pub mod grok;
pub mod ollama;
#[async_trait]
pub trait Provider: Send + Sync {
/// Get provider name (e.g., "openai", "gemini")
fn name(&self) -> &str;
/// Check if provider supports a specific model
fn supports_model(&self, model: &str) -> bool;
/// Check if provider supports multimodal (images, etc.)
fn supports_multimodal(&self) -> bool;
/// Process a chat completion request
async fn chat_completion(
&self,
request: UnifiedRequest,
) -> Result<ProviderResponse, AppError>;
/// Process a streaming chat completion request
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError>;
/// Estimate token count for a request (for cost calculation)
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32>;
/// Calculate cost based on token usage and model using the registry
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64;
}
pub struct ProviderResponse {
pub content: String,
pub reasoning_content: Option<String>,
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
pub model: String,
}
#[derive(Debug, Clone)]
pub struct ProviderStreamChunk {
pub content: String,
pub reasoning_content: Option<String>,
pub finish_reason: Option<String>,
pub model: String,
}
#[derive(Clone)]
pub struct ProviderManager {
providers: Vec<Arc<dyn Provider>>,
}
impl ProviderManager {
pub fn new() -> Self {
Self {
providers: Vec::new(),
}
}
pub fn add_provider(&mut self, provider: Arc<dyn Provider>) {
self.providers.push(provider);
}
pub fn get_provider_for_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
self.providers.iter()
.find(|p| p.supports_model(model))
.map(|p| Arc::clone(p))
}
pub fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
self.providers.iter()
.find(|p| p.name() == name)
.map(|p| Arc::clone(p))
}
}
// Create placeholder provider implementations
pub mod placeholder {
use super::*;
pub struct PlaceholderProvider {
name: String,
}
impl PlaceholderProvider {
pub fn new(name: &str) -> Self {
Self { name: name.to_string() }
}
}
#[async_trait]
impl Provider for PlaceholderProvider {
fn name(&self) -> &str {
&self.name
}
fn supports_model(&self, _model: &str) -> bool {
false
}
fn supports_multimodal(&self) -> bool {
false
}
async fn chat_completion_stream(
&self,
_request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
Err(AppError::ProviderError("Streaming not supported for placeholder provider".to_string()))
}
async fn chat_completion(
&self,
_request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> {
Err(AppError::ProviderError(format!("Provider {} not implemented", self.name)))
}
fn estimate_tokens(&self, _request: &UnifiedRequest) -> Result<u32> {
Ok(0)
}
fn calculate_cost(&self, _model: &str, _prompt_tokens: u32, _completion_tokens: u32, _registry: &crate::models::registry::ModelRegistry) -> f64 {
0.0
}
}
}

205
src/providers/ollama.rs Normal file
View File

@@ -0,0 +1,205 @@
use async_trait::async_trait;
use anyhow::Result;
use futures::stream::{BoxStream, StreamExt};
use serde_json::Value;
use crate::{
models::UnifiedRequest,
errors::AppError,
config::AppConfig,
};
use super::{ProviderResponse, ProviderStreamChunk};
pub struct OllamaProvider {
client: reqwest::Client,
_config: crate::config::OllamaConfig,
pricing: Vec<crate::config::ModelPricing>,
}
impl OllamaProvider {
pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
Ok(Self {
client: reqwest::Client::new(),
_config: config.clone(),
pricing: app_config.pricing.ollama.clone(),
})
}
}
#[async_trait]
impl super::Provider for OllamaProvider {
fn name(&self) -> &str {
"ollama"
}
fn supports_model(&self, model: &str) -> bool {
self._config.models.iter().any(|m| m == model) || model.starts_with("ollama/")
}
fn supports_multimodal(&self) -> bool {
true
}
async fn chat_completion(
&self,
request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> {
let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
let mut body = serde_json::json!({
"model": model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})
}
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": false,
});
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
.json(&body)
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("Ollama API error: {}", error_text)));
}
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
let message = &choice["message"];
let content = message["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = message["reasoning_content"].as_str().or_else(|| message["thought"].as_str()).map(|s| s.to_string());
let usage = &resp_json["usage"];
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
Ok(ProviderResponse {
content,
reasoning_content,
prompt_tokens,
completion_tokens,
total_tokens,
model: request.model,
})
}
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
if let Some(metadata) = registry.find_model(model) {
if let Some(cost) = &metadata.cost {
return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
(completion_tokens as f64 * cost.output / 1_000_000.0);
}
}
let (prompt_rate, completion_rate) = self.pricing.iter()
.find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
.unwrap_or((0.0, 0.0));
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
}
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
let mut body = serde_json::json!({
"model": model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }),
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": true,
});
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
// Create eventsource stream
use reqwest_eventsource::{EventSource, Event};
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
.json(&body))
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let model_name = request.model.clone();
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(choice) = chunk["choices"].get(0) {
let delta = &choice["delta"];
let content = delta["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = delta["reasoning_content"].as_str().or_else(|| delta["thought"].as_str()).map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
yield ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
model: model_name.clone(),
};
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Ok(Box::pin(stream))
}
}

213
src/providers/openai.rs Normal file
View File

@@ -0,0 +1,213 @@
use async_trait::async_trait;
use anyhow::Result;
use futures::stream::{BoxStream, StreamExt};
use serde_json::Value;
use crate::{
models::UnifiedRequest,
errors::AppError,
config::AppConfig,
};
use super::{ProviderResponse, ProviderStreamChunk};
pub struct OpenAIProvider {
client: reqwest::Client,
_config: crate::config::OpenAIConfig,
api_key: String,
pricing: Vec<crate::config::ModelPricing>,
}
impl OpenAIProvider {
pub fn new(config: &crate::config::OpenAIConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("openai")?;
Ok(Self {
client: reqwest::Client::new(),
_config: config.clone(),
api_key,
pricing: app_config.pricing.openai.clone(),
})
}
}
#[async_trait]
impl super::Provider for OpenAIProvider {
fn name(&self) -> &str {
"openai"
}
fn supports_model(&self, model: &str) -> bool {
model.starts_with("gpt-") || model.starts_with("o1-") || model.starts_with("o3-")
}
fn supports_multimodal(&self) -> bool {
true
}
async fn chat_completion(
&self,
request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> {
let mut body = serde_json::json!({
"model": request.model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})
}
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": false,
});
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("OpenAI API error: {}", error_text)));
}
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
let message = &choice["message"];
let content = message["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
let usage = &resp_json["usage"];
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
Ok(ProviderResponse {
content,
reasoning_content,
prompt_tokens,
completion_tokens,
total_tokens,
model: request.model,
})
}
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
if let Some(metadata) = registry.find_model(model) {
if let Some(cost) = &metadata.cost {
return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
(completion_tokens as f64 * cost.output / 1_000_000.0);
}
}
let (prompt_rate, completion_rate) = self.pricing.iter()
.find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
.unwrap_or((0.15, 0.60));
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
}
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
let mut body = serde_json::json!({
"model": request.model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})
}
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": true,
});
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
// Create eventsource stream
use reqwest_eventsource::{EventSource, Event};
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body))
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let model = request.model.clone();
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(choice) = chunk["choices"].get(0) {
let delta = &choice["delta"];
let content = delta["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
yield ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
model: model.clone(),
};
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Ok(Box::pin(stream))
}
}

359
src/rate_limiting/mod.rs Normal file
View File

@@ -0,0 +1,359 @@
//! Rate limiting and circuit breaking for LLM proxy
//!
//! This module provides:
//! 1. Per-client rate limiting using governor crate
//! 2. Provider circuit breaking to handle API failures
//! 3. Global rate limiting for overall system protection
use std::sync::Arc;
use std::collections::HashMap;
use std::time::Instant;
use tokio::sync::RwLock;
use tracing::{info, warn};
use anyhow::Result;
/// Rate limiter configuration
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
/// Requests per minute per client
pub requests_per_minute: u32,
/// Burst size (maximum burst capacity)
pub burst_size: u32,
/// Global requests per minute (across all clients)
pub global_requests_per_minute: u32,
}
impl Default for RateLimiterConfig {
fn default() -> Self {
Self {
requests_per_minute: 60, // 1 request per second per client
burst_size: 10, // Allow bursts of up to 10 requests
global_requests_per_minute: 600, // 10 requests per second globally
}
}
}
/// Circuit breaker state
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CircuitState {
Closed, // Normal operation
Open, // Circuit is open, requests fail fast
HalfOpen, // Testing if service has recovered
}
/// Circuit breaker configuration
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
/// Failure threshold to open circuit
pub failure_threshold: u32,
/// Time window for failure counting (seconds)
pub failure_window_secs: u64,
/// Time to wait before trying half-open state (seconds)
pub reset_timeout_secs: u64,
/// Success threshold to close circuit
pub success_threshold: u32,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5, // 5 failures
failure_window_secs: 60, // within 60 seconds
reset_timeout_secs: 30, // wait 30 seconds before half-open
success_threshold: 3, // 3 successes to close circuit
}
}
}
/// Simple token bucket rate limiter for a single client
#[derive(Debug)]
struct TokenBucket {
tokens: f64,
capacity: f64,
refill_rate: f64, // tokens per second
last_refill: Instant,
}
impl TokenBucket {
fn new(capacity: f64, refill_rate: f64) -> Self {
Self {
tokens: capacity,
capacity,
refill_rate,
last_refill: Instant::now(),
}
}
fn refill(&mut self) {
let now = Instant::now();
let elapsed = now.duration_since(self.last_refill).as_secs_f64();
let new_tokens = elapsed * self.refill_rate;
self.tokens = (self.tokens + new_tokens).min(self.capacity);
self.last_refill = now;
}
fn try_acquire(&mut self, tokens: f64) -> bool {
self.refill();
if self.tokens >= tokens {
self.tokens -= tokens;
true
} else {
false
}
}
}
/// Circuit breaker for a provider
#[derive(Debug)]
pub struct ProviderCircuitBreaker {
state: CircuitState,
failure_count: u32,
success_count: u32,
last_failure_time: Option<std::time::Instant>,
last_state_change: std::time::Instant,
config: CircuitBreakerConfig,
}
impl ProviderCircuitBreaker {
pub fn new(config: CircuitBreakerConfig) -> Self {
Self {
state: CircuitState::Closed,
failure_count: 0,
success_count: 0,
last_failure_time: None,
last_state_change: std::time::Instant::now(),
config,
}
}
/// Check if request is allowed
pub fn allow_request(&mut self) -> bool {
match self.state {
CircuitState::Closed => true,
CircuitState::Open => {
// Check if reset timeout has passed
let elapsed = self.last_state_change.elapsed();
if elapsed.as_secs() >= self.config.reset_timeout_secs {
self.state = CircuitState::HalfOpen;
self.last_state_change = std::time::Instant::now();
info!("Circuit breaker transitioning to half-open state");
true
} else {
false
}
}
CircuitState::HalfOpen => true,
}
}
/// Record a successful request
pub fn record_success(&mut self) {
match self.state {
CircuitState::Closed => {
// Reset failure count on success
self.failure_count = 0;
self.last_failure_time = None;
}
CircuitState::HalfOpen => {
self.success_count += 1;
if self.success_count >= self.config.success_threshold {
self.state = CircuitState::Closed;
self.success_count = 0;
self.failure_count = 0;
self.last_state_change = std::time::Instant::now();
info!("Circuit breaker closed after successful requests");
}
}
CircuitState::Open => {
// Should not happen, but handle gracefully
}
}
}
/// Record a failed request
pub fn record_failure(&mut self) {
let now = std::time::Instant::now();
// Check if failure window has expired
if let Some(last_failure) = self.last_failure_time {
if now.duration_since(last_failure).as_secs() > self.config.failure_window_secs {
// Reset failure count if window expired
self.failure_count = 0;
}
}
self.failure_count += 1;
self.last_failure_time = Some(now);
if self.failure_count >= self.config.failure_threshold && self.state == CircuitState::Closed {
self.state = CircuitState::Open;
self.last_state_change = now;
warn!("Circuit breaker opened due to {} failures", self.failure_count);
} else if self.state == CircuitState::HalfOpen {
// Failure in half-open state, go back to open
self.state = CircuitState::Open;
self.success_count = 0;
self.last_state_change = now;
warn!("Circuit breaker re-opened after failure in half-open state");
}
}
/// Get current state
pub fn state(&self) -> CircuitState {
self.state
}
}
/// Rate limiting and circuit breaking manager
#[derive(Debug)]
pub struct RateLimitManager {
client_buckets: Arc<RwLock<HashMap<String, TokenBucket>>>,
global_bucket: Arc<RwLock<TokenBucket>>,
circuit_breakers: Arc<RwLock<HashMap<String, ProviderCircuitBreaker>>>,
config: RateLimiterConfig,
circuit_config: CircuitBreakerConfig,
}
impl RateLimitManager {
pub fn new(config: RateLimiterConfig, circuit_config: CircuitBreakerConfig) -> Self {
// Convert requests per minute to tokens per second
let global_refill_rate = config.global_requests_per_minute as f64 / 60.0;
Self {
client_buckets: Arc::new(RwLock::new(HashMap::new())),
global_bucket: Arc::new(RwLock::new(TokenBucket::new(
config.burst_size as f64,
global_refill_rate,
))),
circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
config,
circuit_config,
}
}
/// Check if a client request is allowed
pub async fn check_client_request(&self, client_id: &str) -> Result<bool> {
// Check global rate limit first (1 token per request)
{
let mut global_bucket = self.global_bucket.write().await;
if !global_bucket.try_acquire(1.0) {
warn!("Global rate limit exceeded");
return Ok(false);
}
}
// Check client-specific rate limit
let mut buckets = self.client_buckets.write().await;
let bucket = buckets
.entry(client_id.to_string())
.or_insert_with(|| {
TokenBucket::new(
self.config.burst_size as f64,
self.config.requests_per_minute as f64 / 60.0,
)
});
Ok(bucket.try_acquire(1.0))
}
/// Check if provider requests are allowed (circuit breaker)
pub async fn check_provider_request(&self, provider_name: &str) -> Result<bool> {
let mut breakers = self.circuit_breakers.write().await;
let breaker = breakers
.entry(provider_name.to_string())
.or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone()));
Ok(breaker.allow_request())
}
/// Record provider success
pub async fn record_provider_success(&self, provider_name: &str) {
let mut breakers = self.circuit_breakers.write().await;
if let Some(breaker) = breakers.get_mut(provider_name) {
breaker.record_success();
}
}
/// Record provider failure
pub async fn record_provider_failure(&self, provider_name: &str) {
let mut breakers = self.circuit_breakers.write().await;
let breaker = breakers
.entry(provider_name.to_string())
.or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone()));
breaker.record_failure();
}
/// Get provider circuit state
pub async fn get_provider_state(&self, provider_name: &str) -> CircuitState {
let breakers = self.circuit_breakers.read().await;
breakers
.get(provider_name)
.map(|b| b.state())
.unwrap_or(CircuitState::Closed)
}
}
/// Axum middleware for rate limiting
pub mod middleware {
use super::*;
use axum::{
extract::{Request, State},
middleware::Next,
response::Response,
};
use crate::errors::AppError;
use crate::state::AppState;
/// Rate limiting middleware
pub async fn rate_limit_middleware(
State(state): State<AppState>,
request: Request,
next: Next,
) -> Result<Response, AppError> {
// Extract client ID from authentication header
let client_id = extract_client_id_from_request(&request);
// Check rate limits
if !state.rate_limit_manager.check_client_request(&client_id).await? {
return Err(AppError::RateLimitError(
"Rate limit exceeded".to_string()
));
}
Ok(next.run(request).await)
}
/// Extract client ID from request (helper function)
fn extract_client_id_from_request(request: &Request) -> String {
// Try to extract from Authorization header
if let Some(auth_header) = request.headers().get("Authorization") {
if let Ok(auth_str) = auth_header.to_str() {
if auth_str.starts_with("Bearer ") {
let token = &auth_str[7..];
// Use token hash as client ID (same logic as auth module)
return format!("client_{}", &token[..8.min(token.len())]);
}
}
}
// Fallback to anonymous
"anonymous".to_string()
}
/// Circuit breaker middleware for provider requests
pub async fn circuit_breaker_middleware(
provider_name: &str,
state: &AppState,
) -> Result<(), AppError> {
if !state.rate_limit_manager.check_provider_request(provider_name).await? {
return Err(AppError::ProviderError(
format!("Provider {} is currently unavailable (circuit breaker open)", provider_name)
));
}
Ok(())
}
}

224
src/server/mod.rs Normal file
View File

@@ -0,0 +1,224 @@
use uuid::Uuid;
use axum::{
extract::State,
routing::post,
Json, Router,
response::sse::{Event, Sse},
response::IntoResponse,
};
use futures::stream::StreamExt;
use tracing::{info, warn};
use crate::{
auth::AuthenticatedClient,
errors::AppError,
models::{ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatStreamChoice, ChatStreamDelta, ChatMessage, ChatChoice, Usage},
state::AppState,
rate_limiting,
};
pub fn router(state: AppState) -> Router {
Router::new()
.route("/v1/chat/completions", post(chat_completions))
.layer(axum::middleware::from_fn_with_state(
state.clone(),
rate_limiting::middleware::rate_limit_middleware,
))
.with_state(state)
}
async fn chat_completions(
State(state): State<AppState>,
auth: AuthenticatedClient,
Json(request): Json<ChatCompletionRequest>,
) -> Result<axum::response::Response, AppError> {
let start_time = std::time::Instant::now();
let client_id = auth.client_id.clone();
let model = request.model.clone();
info!("Chat completion request from client {} for model {}", client_id, model);
// Find appropriate provider for the model
let provider = state.provider_manager.get_provider_for_model(&request.model)
.ok_or_else(|| AppError::ProviderError(format!("No provider found for model: {}", request.model)))?;
let provider_name = provider.name().to_string();
// Check circuit breaker for this provider
rate_limiting::middleware::circuit_breaker_middleware(&provider_name, &state).await?;
// Convert to unified request format
let mut unified_request = crate::models::UnifiedRequest::try_from(request)
.map_err(|e| AppError::ValidationError(e.to_string()))?;
// Set client_id from authentication
unified_request.client_id = client_id.clone();
// Hydrate images if present
if unified_request.has_images {
unified_request.hydrate_images().await
.map_err(|e| AppError::ValidationError(format!("Failed to process images: {}", e)))?;
}
// Check if streaming is requested
if unified_request.stream {
// Estimate prompt tokens for logging later
let prompt_tokens = crate::utils::tokens::estimate_request_tokens(&model, &unified_request);
let has_images = unified_request.has_images;
// Handle streaming response
let stream_result = provider.chat_completion_stream(unified_request).await;
match stream_result {
Ok(stream) => {
// Record provider success
state.rate_limit_manager.record_provider_success(&provider_name).await;
// Wrap with AggregatingStream for token counting and database logging
let aggregating_stream = crate::utils::streaming::AggregatingStream::new(
stream,
client_id.clone(),
provider.clone(),
model.clone(),
prompt_tokens,
has_images,
state.request_logger.clone(),
state.client_manager.clone(),
state.model_registry.clone(),
);
// Create SSE stream from aggregating stream
let sse_stream = aggregating_stream.map(move |chunk_result| {
match chunk_result {
Ok(chunk) => {
// Convert provider chunk to OpenAI-compatible SSE event
let response = ChatCompletionStreamResponse {
id: format!("chatcmpl-{}", Uuid::new_v4()),
object: "chat.completion.chunk".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model: chunk.model.clone(),
choices: vec![ChatStreamChoice {
index: 0,
delta: ChatStreamDelta {
role: None,
content: Some(chunk.content),
reasoning_content: chunk.reasoning_content,
},
finish_reason: chunk.finish_reason,
}],
};
Ok(Event::default().json_data(response).unwrap())
}
Err(e) => {
warn!("Error in streaming response: {}", e);
Err(e)
}
}
});
Ok(Sse::new(sse_stream).into_response())
}
Err(e) => {
// Record provider failure
state.rate_limit_manager.record_provider_failure(&provider_name).await;
// Log failed request
let duration = start_time.elapsed();
warn!("Streaming request failed after {:?}: {}", duration, e);
Err(e)
}
}
} else {
// Handle non-streaming response
let result = provider.chat_completion(unified_request).await;
match result {
Ok(response) => {
// Record provider success
state.rate_limit_manager.record_provider_success(&provider_name).await;
let duration = start_time.elapsed();
let cost = provider.calculate_cost(&response.model, response.prompt_tokens, response.completion_tokens, &state.model_registry);
// Log request to database
state.request_logger.log_request(crate::logging::RequestLog {
timestamp: chrono::Utc::now(),
client_id: client_id.clone(),
provider: provider_name.clone(),
model: response.model.clone(),
prompt_tokens: response.prompt_tokens,
completion_tokens: response.completion_tokens,
total_tokens: response.total_tokens,
cost,
has_images: false, // TODO: check images
status: "success".to_string(),
error_message: None,
duration_ms: duration.as_millis() as u64,
});
// Update client usage
let _ = state.client_manager.update_client_usage(
&client_id,
response.total_tokens as i64,
cost,
).await;
// Convert ProviderResponse to ChatCompletionResponse
let chat_response = ChatCompletionResponse {
id: format!("chatcmpl-{}", Uuid::new_v4()),
object: "chat.completion".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model: response.model,
choices: vec![ChatChoice {
index: 0,
message: ChatMessage {
role: "assistant".to_string(),
content: crate::models::MessageContent::Text {
content: response.content
},
reasoning_content: response.reasoning_content,
},
finish_reason: Some("stop".to_string()),
}],
usage: Some(Usage {
prompt_tokens: response.prompt_tokens,
completion_tokens: response.completion_tokens,
total_tokens: response.total_tokens,
}),
};
// Log successful request
info!("Request completed successfully in {:?}", duration);
Ok(Json(chat_response).into_response())
}
Err(e) => {
// Record provider failure
state.rate_limit_manager.record_provider_failure(&provider_name).await;
// Log failed request to database
let duration = start_time.elapsed();
state.request_logger.log_request(crate::logging::RequestLog {
timestamp: chrono::Utc::now(),
client_id: client_id.clone(),
provider: provider_name.clone(),
model: model.clone(),
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
cost: 0.0,
has_images: false,
status: "error".to_string(),
error_message: Some(e.to_string()),
duration_ms: duration.as_millis() as u64,
});
warn!("Request failed after {:?}: {}", duration, e);
Err(e)
}
}
}
}

43
src/state/mod.rs Normal file
View File

@@ -0,0 +1,43 @@
use std::sync::Arc;
use tokio::sync::broadcast;
use crate::{
client::ClientManager, database::DbPool, providers::ProviderManager,
rate_limiting::RateLimitManager, logging::RequestLogger,
models::registry::ModelRegistry,
};
/// Shared application state
#[derive(Clone)]
pub struct AppState {
pub provider_manager: ProviderManager,
pub db_pool: DbPool,
pub rate_limit_manager: Arc<RateLimitManager>,
pub client_manager: Arc<ClientManager>,
pub request_logger: Arc<RequestLogger>,
pub model_registry: Arc<ModelRegistry>,
pub dashboard_tx: broadcast::Sender<serde_json::Value>,
}
impl AppState {
pub fn new(
provider_manager: ProviderManager,
db_pool: DbPool,
rate_limit_manager: RateLimitManager,
model_registry: ModelRegistry,
) -> Self {
let client_manager = Arc::new(ClientManager::new(db_pool.clone()));
let (dashboard_tx, _) = broadcast::channel(100);
let request_logger = Arc::new(RequestLogger::new(db_pool.clone(), dashboard_tx.clone()));
Self {
provider_manager,
db_pool,
rate_limit_manager: Arc::new(rate_limit_manager),
client_manager,
request_logger,
model_registry: Arc::new(model_registry),
dashboard_tx,
}
}
}

3
src/utils/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
pub mod tokens;
pub mod registry;
pub mod streaming;

24
src/utils/registry.rs Normal file
View File

@@ -0,0 +1,24 @@
use anyhow::Result;
use tracing::info;
use crate::models::registry::ModelRegistry;
const MODELS_DEV_URL: &str = "https://models.dev/api.json";
pub async fn fetch_registry() -> Result<ModelRegistry> {
info!("Fetching model registry from {}", MODELS_DEV_URL);
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.build()?;
let response = client.get(MODELS_DEV_URL).send().await?;
if !response.status().is_success() {
return Err(anyhow::anyhow!("Failed to fetch registry: HTTP {}", response.status()));
}
let registry: ModelRegistry = response.json().await?;
info!("Successfully loaded model registry");
Ok(registry)
}

200
src/utils/streaming.rs Normal file
View File

@@ -0,0 +1,200 @@
use futures::stream::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::sync::Arc;
use crate::logging::{RequestLogger, RequestLog};
use crate::client::ClientManager;
use crate::providers::{Provider, ProviderStreamChunk};
use crate::errors::AppError;
use crate::utils::tokens::estimate_completion_tokens;
pub struct AggregatingStream<S> {
inner: S,
client_id: String,
provider: Arc<dyn Provider>,
model: String,
prompt_tokens: u32,
has_images: bool,
accumulated_content: String,
accumulated_reasoning: String,
logger: Arc<RequestLogger>,
client_manager: Arc<ClientManager>,
model_registry: Arc<crate::models::registry::ModelRegistry>,
start_time: std::time::Instant,
has_logged: bool,
}
impl<S> AggregatingStream<S>
where
S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin
{
pub fn new(
inner: S,
client_id: String,
provider: Arc<dyn Provider>,
model: String,
prompt_tokens: u32,
has_images: bool,
logger: Arc<RequestLogger>,
client_manager: Arc<ClientManager>,
model_registry: Arc<crate::models::registry::ModelRegistry>,
) -> Self {
Self {
inner,
client_id,
provider,
model,
prompt_tokens,
has_images,
accumulated_content: String::new(),
accumulated_reasoning: String::new(),
logger,
client_manager,
model_registry,
start_time: std::time::Instant::now(),
has_logged: false,
}
}
fn finalize(&mut self) {
if self.has_logged {
return;
}
self.has_logged = true;
let duration = self.start_time.elapsed();
let client_id = self.client_id.clone();
let provider_name = self.provider.name().to_string();
let model = self.model.clone();
let logger = self.logger.clone();
let client_manager = self.client_manager.clone();
let provider = self.provider.clone();
let prompt_tokens = self.prompt_tokens;
let has_images = self.has_images;
let registry = self.model_registry.clone();
// Estimate completion tokens (including reasoning if present)
let content_tokens = estimate_completion_tokens(&self.accumulated_content, &model);
let reasoning_tokens = if !self.accumulated_reasoning.is_empty() {
estimate_completion_tokens(&self.accumulated_reasoning, &model)
} else {
0
};
let completion_tokens = content_tokens + reasoning_tokens;
let total_tokens = prompt_tokens + completion_tokens;
let cost = provider.calculate_cost(&model, prompt_tokens, completion_tokens, &registry);
// Spawn a background task to log the completion
tokio::spawn(async move {
// Log to database
logger.log_request(RequestLog {
timestamp: chrono::Utc::now(),
client_id: client_id.clone(),
provider: provider_name,
model,
prompt_tokens,
completion_tokens,
total_tokens,
cost,
has_images,
status: "success".to_string(),
error_message: None,
duration_ms: duration.as_millis() as u64,
});
// Update client usage
let _ = client_manager.update_client_usage(
&client_id,
total_tokens as i64,
cost,
).await;
});
}
}
impl<S> Stream for AggregatingStream<S>
where
S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin
{
type Item = Result<ProviderStreamChunk, AppError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let result = Pin::new(&mut self.inner).poll_next(cx);
match &result {
Poll::Ready(Some(Ok(chunk))) => {
self.accumulated_content.push_str(&chunk.content);
if let Some(reasoning) = &chunk.reasoning_content {
self.accumulated_reasoning.push_str(reasoning);
}
}
Poll::Ready(Some(Err(_))) => {
// If there's an error, we might still want to log what we got so far?
// For now, just finalize if we have content
if !self.accumulated_content.is_empty() {
self.finalize();
}
}
Poll::Ready(None) => {
self.finalize();
}
Poll::Pending => {}
}
result
}
}
#[cfg(test)]
mod tests {
use super::*;
use futures::stream::{self, StreamExt};
use anyhow::Result;
// Simple mock provider for testing
struct MockProvider;
#[async_trait::async_trait]
impl Provider for MockProvider {
fn name(&self) -> &str { "mock" }
fn supports_model(&self, _model: &str) -> bool { true }
fn supports_multimodal(&self) -> bool { false }
async fn chat_completion(&self, _req: crate::models::UnifiedRequest) -> Result<crate::providers::ProviderResponse, AppError> { unimplemented!() }
async fn chat_completion_stream(&self, _req: crate::models::UnifiedRequest) -> Result<futures::stream::BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> { unimplemented!() }
fn estimate_tokens(&self, _req: &crate::models::UnifiedRequest) -> Result<u32> { Ok(10) }
fn calculate_cost(&self, _model: &str, _p: u32, _c: u32, _r: &crate::models::registry::ModelRegistry) -> f64 { 0.05 }
}
#[tokio::test]
async fn test_aggregating_stream() {
let chunks = vec![
Ok(ProviderStreamChunk { content: "Hello".to_string(), finish_reason: None, model: "test".to_string() }),
Ok(ProviderStreamChunk { content: " World".to_string(), finish_reason: Some("stop".to_string()), model: "test".to_string() }),
];
let inner_stream = stream::iter(chunks);
let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
let logger = Arc::new(RequestLogger::new(pool.clone()));
let client_manager = Arc::new(ClientManager::new(pool.clone()));
let registry = Arc::new(crate::models::registry::ModelRegistry { providers: std::collections::HashMap::new() });
let mut agg_stream = AggregatingStream::new(
inner_stream,
"client_1".to_string(),
Arc::new(MockProvider),
"test".to_string(),
10,
false,
logger,
client_manager,
registry,
);
while let Some(item) = agg_stream.next().await {
assert!(item.is_ok());
}
assert_eq!(agg_stream.accumulated_content, "Hello World");
assert!(agg_stream.has_logged);
}
}

51
src/utils/tokens.rs Normal file
View File

@@ -0,0 +1,51 @@
use tiktoken_rs::get_bpe_from_model;
use crate::models::UnifiedRequest;
/// Count tokens for a given model and text
pub fn count_tokens(model: &str, text: &str) -> u32 {
// If we can't get the bpe for the model, fallback to a safe default (cl100k_base for GPT-4/o1)
let bpe = get_bpe_from_model(model).unwrap_or_else(|_| {
tiktoken_rs::cl100k_base().expect("Failed to get cl100k_base encoding")
});
bpe.encode_with_special_tokens(text).len() as u32
}
/// Estimate tokens for a unified request
pub fn estimate_request_tokens(model: &str, request: &UnifiedRequest) -> u32 {
let mut total_tokens = 0;
// Base tokens per message for OpenAI (approximate)
let tokens_per_message = 3;
let _tokens_per_name = 1;
for msg in &request.messages {
total_tokens += tokens_per_message;
for part in &msg.content {
match part {
crate::models::ContentPart::Text { text } => {
total_tokens += count_tokens(model, text);
}
crate::models::ContentPart::Image { .. } => {
// Vision models usually have a fixed cost or calculation based on size
// For now, let's use a conservative estimate of 1000 tokens
total_tokens += 1000;
}
}
}
// Add name tokens if we had names (we don't in UnifiedMessage yet)
// total_tokens += tokens_per_name;
}
// Add 3 tokens for the assistant reply header
total_tokens += 3;
total_tokens
}
/// Estimate tokens for completion text
pub fn estimate_completion_tokens(text: &str, model: &str) -> u32 {
count_tokens(model, text)
}