Init repo

This commit is contained in:
2026-02-26 11:51:36 -05:00
commit 5400d82acd
50 changed files with 17748 additions and 0 deletions

303
src/providers/deepseek.rs Normal file
View File

@@ -0,0 +1,303 @@
use async_trait::async_trait;
use anyhow::Result;
use async_openai::{Client, config::OpenAIConfig};
use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
use futures::stream::{BoxStream, StreamExt};
use crate::{
models::UnifiedRequest,
errors::AppError,
config::AppConfig,
};
use super::{ProviderResponse, ProviderStreamChunk};
pub struct DeepSeekProvider {
client: Client<OpenAIConfig>, // DeepSeek uses OpenAI-compatible API
_config: crate::config::DeepSeekConfig,
pricing: Vec<crate::config::ModelPricing>,
}
impl DeepSeekProvider {
pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("deepseek")?;
// Create OpenAIConfig with api key and base url
let openai_config = OpenAIConfig::default()
.with_api_key(api_key)
.with_api_base(&config.base_url);
let client = Client::with_config(openai_config);
Ok(Self {
client,
_config: config.clone(),
pricing: app_config.pricing.deepseek.clone(),
})
}
}
#[async_trait]
impl super::Provider for DeepSeekProvider {
fn name(&self) -> &str {
"deepseek"
}
fn supports_model(&self, model: &str) -> bool {
model.starts_with("deepseek-") || model.contains("deepseek")
}
fn supports_multimodal(&self) -> bool {
false // DeepSeek doesn't support general vision (only OCR)
}
async fn chat_completion(
&self,
request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
// Convert UnifiedRequest messages to OpenAI-compatible messages
let mut messages = Vec::with_capacity(request.messages.len());
for msg in request.messages {
let mut parts = Vec::with_capacity(msg.content.len());
for part in msg.content {
match part {
crate::models::ContentPart::Text { text } => {
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
text,
}));
}
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
}
}));
}
}
}
let message = match msg.role.as_str() {
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(request.model.clone());
builder.messages(messages);
// Add optional parameters
if let Some(temp) = request.temperature {
builder.temperature(temp as f32);
}
if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16);
}
// Execute API call
let response = self.client
.chat()
.create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Extract content from response
let content = response
.choices
.first()
.and_then(|choice| choice.message.content.clone())
.unwrap_or_default();
// Extract token usage
let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
Ok(ProviderResponse {
content,
prompt_tokens,
completion_tokens,
total_tokens,
model: request.model,
})
}
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
if let Some(metadata) = registry.find_model(model) {
if let Some(cost) = &metadata.cost {
return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
(completion_tokens as f64 * cost.output / 1_000_000.0);
}
}
let (prompt_rate, completion_rate) = self.pricing.iter()
.find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
.unwrap_or((0.14, 0.28)); // Default to DeepSeek V3 price if not found
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
}
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
// Convert UnifiedRequest messages to OpenAI-compatible messages
let mut messages = Vec::with_capacity(request.messages.len());
for msg in request.messages {
let mut parts = Vec::with_capacity(msg.content.len());
for part in msg.content {
match part {
crate::models::ContentPart::Text { text } => {
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
text,
}));
}
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
}
}));
}
}
}
let message = match msg.role.as_str() {
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(request.model.clone());
builder.messages(messages);
builder.stream(true); // Enable streaming
// Add optional parameters
if let Some(temp) = request.temperature {
builder.temperature(temp as f32);
}
if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16);
}
// Execute streaming API call
let stream = self.client
.chat()
.create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Convert OpenAI stream to our stream format
let model = request.model.clone();
let stream = stream.map(move |chunk_result| {
match chunk_result {
Ok(chunk) => {
// Extract content from chunk
let content = chunk.choices.first()
.and_then(|choice| choice.delta.content.clone())
.unwrap_or_default();
let finish_reason = chunk.choices.first()
.and_then(|choice| choice.finish_reason.clone())
.map(|reason| format!("{:?}", reason));
Ok(ProviderStreamChunk {
content,
finish_reason,
model: model.clone(),
})
}
Err(e) => Err(AppError::ProviderError(e.to_string())),
}
});
Ok(Box::pin(stream))
}
}

342
src/providers/gemini.rs Normal file
View File

@@ -0,0 +1,342 @@
use async_trait::async_trait;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use futures::stream::BoxStream;
use crate::{
models::UnifiedRequest,
errors::AppError,
config::AppConfig,
};
use super::{ProviderResponse, ProviderStreamChunk};
#[derive(Debug, Serialize)]
struct GeminiRequest {
contents: Vec<GeminiContent>,
generation_config: Option<GeminiGenerationConfig>,
}
#[derive(Debug, Serialize, Deserialize)]
struct GeminiContent {
parts: Vec<GeminiPart>,
role: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct GeminiPart {
#[serde(skip_serializing_if = "Option::is_none")]
text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
inline_data: Option<GeminiInlineData>,
}
#[derive(Debug, Serialize, Deserialize)]
struct GeminiInlineData {
mime_type: String,
data: String,
}
#[derive(Debug, Serialize)]
struct GeminiGenerationConfig {
temperature: Option<f64>,
max_output_tokens: Option<u32>,
}
#[derive(Debug, Deserialize)]
struct GeminiCandidate {
content: GeminiContent,
_finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
struct GeminiUsageMetadata {
prompt_token_count: u32,
candidates_token_count: u32,
total_token_count: u32,
}
#[derive(Debug, Deserialize)]
struct GeminiResponse {
candidates: Vec<GeminiCandidate>,
usage_metadata: Option<GeminiUsageMetadata>,
}
pub struct GeminiProvider {
client: reqwest::Client,
config: crate::config::GeminiConfig,
api_key: String,
pricing: Vec<crate::config::ModelPricing>,
}
impl GeminiProvider {
pub fn new(config: &crate::config::GeminiConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("gemini")?;
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()?;
Ok(Self {
client,
config: config.clone(),
api_key,
pricing: app_config.pricing.gemini.clone(),
})
}
}
#[async_trait]
impl super::Provider for GeminiProvider {
fn name(&self) -> &str {
"gemini"
}
fn supports_model(&self, model: &str) -> bool {
model.starts_with("gemini-")
}
fn supports_multimodal(&self) -> bool {
true // Gemini supports vision
}
async fn chat_completion(
&self,
request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> {
// Convert UnifiedRequest to Gemini request
let mut contents = Vec::with_capacity(request.messages.len());
for msg in request.messages {
let mut parts = Vec::with_capacity(msg.content.len());
for part in msg.content {
match part {
crate::models::ContentPart::Text { text } => {
parts.push(GeminiPart {
text: Some(text),
inline_data: None,
});
}
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
parts.push(GeminiPart {
text: None,
inline_data: Some(GeminiInlineData {
mime_type,
data: base64_data,
}),
});
}
}
}
// Map role: "user" -> "user", "assistant" -> "model", "system" -> "user"
let role = match msg.role.as_str() {
"assistant" => "model".to_string(),
_ => "user".to_string(),
};
contents.push(GeminiContent {
parts,
role,
});
}
if contents.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build generation config
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
Some(GeminiGenerationConfig {
temperature: request.temperature,
max_output_tokens: request.max_tokens,
})
} else {
None
};
let gemini_request = GeminiRequest {
contents,
generation_config,
};
// Build URL
let url = format!("{}/models/{}:generateContent?key={}",
self.config.base_url,
request.model,
self.api_key
);
// Send request
let response = self.client
.post(&url)
.json(&gemini_request)
.send()
.await
.map_err(|e| AppError::ProviderError(format!("HTTP request failed: {}", e)))?;
// Check status
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, error_text)));
}
let gemini_response: GeminiResponse = response
.json()
.await
.map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?;
// Extract content from first candidate
let content = gemini_response.candidates
.first()
.and_then(|c| c.content.parts.first())
.and_then(|p| p.text.clone())
.unwrap_or_default();
// Extract token usage
let prompt_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.prompt_token_count).unwrap_or(0);
let completion_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.candidates_token_count).unwrap_or(0);
let total_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.total_token_count).unwrap_or(0);
Ok(ProviderResponse {
content,
prompt_tokens,
completion_tokens,
total_tokens,
model: request.model,
})
}
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
if let Some(metadata) = registry.find_model(model) {
if let Some(cost) = &metadata.cost {
return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
(completion_tokens as f64 * cost.output / 1_000_000.0);
}
}
let (prompt_rate, completion_rate) = self.pricing.iter()
.find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
.unwrap_or((0.075, 0.30)); // Default to Gemini 2.0 Flash price if not found
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
}
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
// Convert UnifiedRequest to Gemini request
let mut contents = Vec::with_capacity(request.messages.len());
for msg in request.messages {
let mut parts = Vec::with_capacity(msg.content.len());
for part in msg.content {
match part {
crate::models::ContentPart::Text { text } => {
parts.push(GeminiPart {
text: Some(text),
inline_data: None,
});
}
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
parts.push(GeminiPart {
text: None,
inline_data: Some(GeminiInlineData {
mime_type,
data: base64_data,
}),
});
}
}
}
// Map role
let role = match msg.role.as_str() {
"assistant" => "model".to_string(),
_ => "user".to_string(),
};
contents.push(GeminiContent {
parts,
role,
});
}
// Build generation config
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
Some(GeminiGenerationConfig {
temperature: request.temperature,
max_output_tokens: request.max_tokens,
})
} else {
None
};
let gemini_request = GeminiRequest {
contents,
generation_config,
};
// Build URL for streaming
let url = format!("{}/models/{}:streamGenerateContent?alt=sse&key={}",
self.config.base_url,
request.model,
self.api_key
);
// Create eventsource stream
use reqwest_eventsource::{EventSource, Event};
use futures::StreamExt;
let es = EventSource::new(self.client.post(&url).json(&gemini_request))
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let model = request.model.clone();
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
let gemini_response: GeminiResponse = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(candidate) = gemini_response.candidates.first() {
let content = candidate.content.parts.first()
.and_then(|p| p.text.clone())
.unwrap_or_default();
yield ProviderStreamChunk {
content,
finish_reason: None, // Will be set in the last chunk
model: model.clone(),
};
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Ok(Box::pin(stream))
}
}

92
src/providers/grok.rs Normal file
View File

@@ -0,0 +1,92 @@
use async_trait::async_trait;
use anyhow::Result;
use futures::stream::BoxStream;
use crate::{
models::UnifiedRequest,
errors::AppError,
config::AppConfig,
};
use super::{ProviderResponse, ProviderStreamChunk};
pub struct GrokProvider {
_client: reqwest::Client,
_config: crate::config::GrokConfig,
_api_key: String,
pricing: Vec<crate::config::ModelPricing>,
}
impl GrokProvider {
pub fn new(config: &crate::config::GrokConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("grok")?;
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()?;
Ok(Self {
_client: client,
_config: config.clone(),
_api_key: api_key,
pricing: app_config.pricing.grok.clone(),
})
}
}
#[async_trait]
impl super::Provider for GrokProvider {
fn name(&self) -> &str {
"grok"
}
fn supports_model(&self, model: &str) -> bool {
model.starts_with("grok-") || model.contains("grok")
}
fn supports_multimodal(&self) -> bool {
false // Unknown - assume false until API is researched
}
async fn chat_completion(
&self,
request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> {
// TODO: Implement actual Grok API call (once API is available)
// For now, return placeholder response
Ok(ProviderResponse {
content: "Grok provider not yet implemented (API not researched)".to_string(),
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
model: request.model,
})
}
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
if let Some(metadata) = registry.find_model(model) {
if let Some(cost) = &metadata.cost {
return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
(completion_tokens as f64 * cost.output / 1_000_000.0);
}
}
let (prompt_rate, completion_rate) = self.pricing.iter()
.find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
.unwrap_or((1.0, 3.0)); // Default to some reasonable Grok price if not found
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
}
async fn chat_completion_stream(
&self,
_request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
// Grok API not yet implemented
Err(AppError::ProviderError("Streaming not supported for Grok provider (API not implemented)".to_string()))
}
}

138
src/providers/mod.rs Normal file
View File

@@ -0,0 +1,138 @@
use async_trait::async_trait;
use anyhow::Result;
use std::sync::Arc;
use futures::stream::BoxStream;
use crate::models::UnifiedRequest;
use crate::errors::AppError;
pub mod openai;
pub mod gemini;
pub mod deepseek;
pub mod grok;
#[async_trait]
pub trait Provider: Send + Sync {
/// Get provider name (e.g., "openai", "gemini")
fn name(&self) -> &str;
/// Check if provider supports a specific model
fn supports_model(&self, model: &str) -> bool;
/// Check if provider supports multimodal (images, etc.)
fn supports_multimodal(&self) -> bool;
/// Process a chat completion request
async fn chat_completion(
&self,
request: UnifiedRequest,
) -> Result<ProviderResponse, AppError>;
/// Process a streaming chat completion request
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError>;
/// Estimate token count for a request (for cost calculation)
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32>;
/// Calculate cost based on token usage and model using the registry
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64;
}
pub struct ProviderResponse {
pub content: String,
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
pub model: String,
}
#[derive(Debug, Clone)]
pub struct ProviderStreamChunk {
pub content: String,
pub finish_reason: Option<String>,
pub model: String,
}
#[derive(Clone)]
pub struct ProviderManager {
providers: Vec<Arc<dyn Provider>>,
}
impl ProviderManager {
pub fn new() -> Self {
Self {
providers: Vec::new(),
}
}
pub fn add_provider(&mut self, provider: Arc<dyn Provider>) {
self.providers.push(provider);
}
pub fn get_provider_for_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
self.providers.iter()
.find(|p| p.supports_model(model))
.map(|p| Arc::clone(p))
}
pub fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
self.providers.iter()
.find(|p| p.name() == name)
.map(|p| Arc::clone(p))
}
}
// Create placeholder provider implementations
pub mod placeholder {
use super::*;
pub struct PlaceholderProvider {
name: String,
}
impl PlaceholderProvider {
pub fn new(name: &str) -> Self {
Self { name: name.to_string() }
}
}
#[async_trait]
impl Provider for PlaceholderProvider {
fn name(&self) -> &str {
&self.name
}
fn supports_model(&self, _model: &str) -> bool {
false
}
fn supports_multimodal(&self) -> bool {
false
}
async fn chat_completion_stream(
&self,
_request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
Err(AppError::ProviderError("Streaming not supported for placeholder provider".to_string()))
}
async fn chat_completion(
&self,
_request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> {
Err(AppError::ProviderError(format!("Provider {} not implemented", self.name)))
}
fn estimate_tokens(&self, _request: &UnifiedRequest) -> Result<u32> {
Ok(0)
}
fn calculate_cost(&self, _model: &str, _prompt_tokens: u32, _completion_tokens: u32, _registry: &crate::models::registry::ModelRegistry) -> f64 {
0.0
}
}
}

304
src/providers/openai.rs Normal file
View File

@@ -0,0 +1,304 @@
use async_trait::async_trait;
use anyhow::Result;
use async_openai::{Client, config::OpenAIConfig};
use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
use futures::stream::{BoxStream, StreamExt};
use crate::{
models::UnifiedRequest,
errors::AppError,
config::AppConfig,
};
use super::{ProviderResponse, ProviderStreamChunk};
pub struct OpenAIProvider {
client: Client<OpenAIConfig>,
_config: crate::config::OpenAIConfig,
pricing: Vec<crate::config::ModelPricing>,
}
impl OpenAIProvider {
pub fn new(config: &crate::config::OpenAIConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("openai")?;
// Create OpenAIConfig with api key and base url
let openai_config = OpenAIConfig::default()
.with_api_key(api_key)
.with_api_base(&config.base_url);
let client = Client::with_config(openai_config);
Ok(Self {
client,
_config: config.clone(),
pricing: app_config.pricing.openai.clone(),
})
}
}
#[async_trait]
impl super::Provider for OpenAIProvider {
fn name(&self) -> &str {
"openai"
}
fn supports_model(&self, model: &str) -> bool {
model.starts_with("gpt-") || model.starts_with("o1-") || model.starts_with("o3-")
}
fn supports_multimodal(&self) -> bool {
true // OpenAI supports vision models
}
async fn chat_completion(
&self,
request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
// Convert UnifiedRequest messages to OpenAI messages
let mut messages = Vec::with_capacity(request.messages.len());
for msg in request.messages {
let mut parts = Vec::with_capacity(msg.content.len());
for part in msg.content {
match part {
crate::models::ContentPart::Text { text } => {
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
text,
}));
}
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
}
}));
}
}
}
let message = match msg.role.as_str() {
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(request.model.clone());
builder.messages(messages);
// Add optional parameters
if let Some(temp) = request.temperature {
builder.temperature(temp as f32);
}
if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16);
}
// Execute API call
let response = self.client
.chat()
.create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Extract content from response
let content = response
.choices
.first()
.and_then(|choice| choice.message.content.clone())
.unwrap_or_default();
// Extract token usage
let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
Ok(ProviderResponse {
content,
prompt_tokens,
completion_tokens,
total_tokens,
model: request.model,
})
}
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
if let Some(metadata) = registry.find_model(model) {
if let Some(cost) = &metadata.cost {
return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
(completion_tokens as f64 * cost.output / 1_000_000.0);
}
}
// Fallback to static pricing if not in registry
let (prompt_rate, completion_rate) = self.pricing.iter()
.find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
.unwrap_or((0.15, 0.60));
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
}
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
// Convert UnifiedRequest messages to OpenAI messages
let mut messages = Vec::with_capacity(request.messages.len());
for msg in request.messages {
let mut parts = Vec::with_capacity(msg.content.len());
for part in msg.content {
match part {
crate::models::ContentPart::Text { text } => {
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
text,
}));
}
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
}
}));
}
}
}
let message = match msg.role.as_str() {
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(request.model.clone());
builder.messages(messages);
builder.stream(true); // Enable streaming
// Add optional parameters
if let Some(temp) = request.temperature {
builder.temperature(temp as f32);
}
if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16);
}
// Execute streaming API call
let stream = self.client
.chat()
.create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Convert OpenAI stream to our stream format
let model = request.model.clone();
let stream = stream.map(move |chunk_result| {
match chunk_result {
Ok(chunk) => {
// Extract content from chunk
let content = chunk.choices.first()
.and_then(|choice| choice.delta.content.clone())
.unwrap_or_default();
let finish_reason = chunk.choices.first()
.and_then(|choice| choice.finish_reason.clone())
.map(|reason| format!("{:?}", reason));
Ok(ProviderStreamChunk {
content,
finish_reason,
model: model.clone(),
})
}
Err(e) => Err(AppError::ProviderError(e.to_string())),
}
});
Ok(Box::pin(stream))
}
}