refactor: comprehensive audit — fix bugs, harden security, deduplicate providers, add CI/Docker
Phase 1: Fix compilation (config_path Option<PathBuf>, streaming test, stale test cleanup)
Phase 2: Fix critical bugs (remove block_on deadlocks in 4 providers, fix broken SQL query builder)
Phase 3: Security hardening (session manager, real auth, token masking, Gemini key to header, password policy)
Phase 4: Implement stubs (real provider test, /proc health metrics, client/provider/backup endpoints, has_images)
Phase 5: Code quality (shared provider helpers, explicit re-exports, all Clippy warnings fixed, unwrap removal, 6 unused deps removed, dashboard split into 7 sub-modules)
Phase 6: Infrastructure (GitHub Actions CI, multi-stage Dockerfile, rustfmt.toml, clippy.toml, script fixes)
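For context on the Phase 2 fix: the old providers called futures::executor::block_on(image_input.to_base64()) from inside handlers already running on the Tokio runtime. block_on parks the current worker thread, and when the awaited future needs that same runtime (timers, sockets) to make progress, the thread ends up waiting on itself. A minimal sketch of the failure mode and the fix, assuming tokio with the macros and time features enabled; to_base64 here is a stand-in, not the repo's function:

    use std::time::Duration;

    // Stand-in for the image-to-base64 conversion, which performs async I/O.
    async fn to_base64() -> String {
        tokio::time::sleep(Duration::from_millis(10)).await;
        "aGVsbG8=".to_string()
    }

    #[tokio::main(flavor = "current_thread")]
    async fn main() {
        // BAD: parks the only runtime thread while waiting on a future that
        // needs the runtime's timer driver, so it can never complete:
        // let data = futures::executor::block_on(to_base64());

        // GOOD: awaiting keeps the runtime free to drive the future.
        let data = to_base64().await;
        println!("{data}");
    }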
src/providers/deepseek.rs

@@ -1,14 +1,10 @@
-use async_trait::async_trait;
 use anyhow::Result;
-use futures::stream::{BoxStream, StreamExt};
-use serde_json::Value;
+use async_trait::async_trait;
+use futures::stream::BoxStream;

-use crate::{
-    models::UnifiedRequest,
-    errors::AppError,
-    config::AppConfig,
-};
+use super::helpers;
 use super::{ProviderResponse, ProviderStreamChunk};
+use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};

 pub struct DeepSeekProvider {
     client: reqwest::Client,
@@ -23,7 +19,11 @@ impl DeepSeekProvider {
         Self::new_with_key(config, app_config, api_key)
     }

-    pub fn new_with_key(config: &crate::config::DeepSeekConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
+    pub fn new_with_key(
+        config: &crate::config::DeepSeekConfig,
+        app_config: &AppConfig,
+        api_key: String,
+    ) -> Result<Self> {
         Ok(Self {
             client: reqwest::Client::new(),
             config: config.clone(),
@@ -47,42 +47,13 @@ impl super::Provider for DeepSeekProvider {
         false
     }

-    async fn chat_completion(
-        &self,
-        request: UnifiedRequest,
-    ) -> Result<ProviderResponse, AppError> {
-        // Build the OpenAI-compatible body
-        let mut body = serde_json::json!({
-            "model": request.model,
-            "messages": request.messages.iter().map(|m| {
-                serde_json::json!({
-                    "role": m.role,
-                    "content": m.content.iter().map(|p| {
-                        match p {
-                            crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
-                            crate::models::ContentPart::Image(image_input) => {
-                                // DeepSeek currently doesn't support images in the same way, but we'll try to be standard
-                                let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
-                                serde_json::json!({
-                                    "type": "image_url",
-                                    "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
-                                })
-                            }
-                        }
-                    }).collect::<Vec<_>>()
-                })
-            }).collect::<Vec<_>>(),
-            "stream": false,
-        });
-
-        if let Some(temp) = request.temperature {
-            body["temperature"] = serde_json::json!(temp);
-        }
-        if let Some(max_tokens) = request.max_tokens {
-            body["max_tokens"] = serde_json::json!(max_tokens);
-        }
-
-        let response = self.client.post(format!("{}/chat/completions", self.config.base_url))
+    async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
+        let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
+        let body = helpers::build_openai_body(&request, messages_json, false);
+
+        let response = self
+            .client
+            .post(format!("{}/chat/completions", self.config.base_url))
             .header("Authorization", format!("Bearer {}", self.api_key))
             .json(&body)
             .send()
@@ -94,119 +65,52 @@ impl super::Provider for DeepSeekProvider {
             return Err(AppError::ProviderError(format!("DeepSeek API error: {}", error_text)));
         }

-        let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
-
-        let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
-        let message = &choice["message"];
-
-        let content = message["content"].as_str().unwrap_or_default().to_string();
-        let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
-
-        let usage = &resp_json["usage"];
-        let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
-        let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
-        let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
+        let resp_json: serde_json::Value = response
+            .json()
+            .await
+            .map_err(|e| AppError::ProviderError(e.to_string()))?;

-        Ok(ProviderResponse {
-            content,
-            reasoning_content,
-            prompt_tokens,
-            completion_tokens,
-            total_tokens,
-            model: request.model,
-        })
+        helpers::parse_openai_response(&resp_json, request.model)
     }

     fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
         Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
     }

-    fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
-        if let Some(metadata) = registry.find_model(model) {
-            if let Some(cost) = &metadata.cost {
-                return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
-                       (completion_tokens as f64 * cost.output / 1_000_000.0);
-            }
-        }
-
-        let (prompt_rate, completion_rate) = self.pricing.iter()
-            .find(|p| model.contains(&p.model))
-            .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
-            .unwrap_or((0.14, 0.28));
-
-        (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
+    fn calculate_cost(
+        &self,
+        model: &str,
+        prompt_tokens: u32,
+        completion_tokens: u32,
+        registry: &crate::models::registry::ModelRegistry,
+    ) -> f64 {
+        helpers::calculate_cost_with_registry(
+            model,
+            prompt_tokens,
+            completion_tokens,
+            registry,
+            &self.pricing,
+            0.14,
+            0.28,
+        )
     }

     async fn chat_completion_stream(
         &self,
         request: UnifiedRequest,
     ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
-        let mut body = serde_json::json!({
-            "model": request.model,
-            "messages": request.messages.iter().map(|m| {
-                serde_json::json!({
-                    "role": m.role,
-                    "content": m.content.iter().map(|p| {
-                        match p {
-                            crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
-                            crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }),
-                        }
-                    }).collect::<Vec<_>>()
-                })
-            }).collect::<Vec<_>>(),
-            "stream": true,
-        });
-
-        if let Some(temp) = request.temperature {
-            body["temperature"] = serde_json::json!(temp);
-        }
-        if let Some(max_tokens) = request.max_tokens {
-            body["max_tokens"] = serde_json::json!(max_tokens);
-        }
-
-        // Create eventsource stream
-        use reqwest_eventsource::{EventSource, Event};
-        let es = EventSource::new(self.client.post(format!("{}/chat/completions", self.config.base_url))
-            .header("Authorization", format!("Bearer {}", self.api_key))
-            .json(&body))
-            .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
-
-        let model = request.model.clone();
-
-        let stream = async_stream::try_stream! {
-            let mut es = es;
-            while let Some(event) = es.next().await {
-                match event {
-                    Ok(Event::Message(msg)) => {
-                        if msg.data == "[DONE]" {
-                            break;
-                        }
-
-                        let chunk: Value = serde_json::from_str(&msg.data)
-                            .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
-
-                        if let Some(choice) = chunk["choices"].get(0) {
-                            let delta = &choice["delta"];
-                            let content = delta["content"].as_str().unwrap_or_default().to_string();
-                            let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
-                            let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
-
-                            yield ProviderStreamChunk {
-                                content,
-                                reasoning_content,
-                                finish_reason,
-                                model: model.clone(),
-                            };
-                        }
-                    }
-                    Ok(_) => continue,
-                    Err(e) => {
-                        Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
-                    }
-                }
-            }
-        };
-
-        Ok(Box::pin(stream))
+        // DeepSeek doesn't support images in streaming, use text-only
+        let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?;
+        let body = helpers::build_openai_body(&request, messages_json, true);
+
+        let es = reqwest_eventsource::EventSource::new(
+            self.client
                .post(format!("{}/chat/completions", self.config.base_url))
                .header("Authorization", format!("Bearer {}", self.api_key))
                .json(&body),
        )
        .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
+
+        Ok(helpers::create_openai_stream(es, request.model, None))
     }
 }
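The streaming half of every OpenAI-compatible provider now ends in helpers::create_openai_stream(...). A sketch of what consuming the returned BoxStream looks like from the caller's side; the function itself is illustrative, while Provider, UnifiedRequest, ProviderStreamChunk, and AppError mirror this repo's types:

    use futures::StreamExt;

    // Illustrative consumer of a provider's streaming API.
    async fn print_stream(provider: &dyn Provider, request: UnifiedRequest) -> Result<(), AppError> {
        let mut stream = provider.chat_completion_stream(request).await?;
        while let Some(chunk) = stream.next().await {
            let chunk = chunk?; // surface stream errors early
            print!("{}", chunk.content);
            if chunk.finish_reason.is_some() {
                break; // the helper also terminates on the "[DONE]" sentinel
            }
        }
        Ok(())
    }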
src/providers/gemini.rs

@@ -1,14 +1,10 @@
-use async_trait::async_trait;
 use anyhow::Result;
-use serde::{Deserialize, Serialize};
+use async_trait::async_trait;
+use futures::stream::BoxStream;
+use serde::{Deserialize, Serialize};

-use crate::{
-    models::UnifiedRequest,
-    errors::AppError,
-    config::AppConfig,
-};
 use super::{ProviderResponse, ProviderStreamChunk};
+use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};

 #[derive(Debug, Serialize)]
 struct GeminiRequest {
@@ -61,8 +57,6 @@ struct GeminiResponse {
     usage_metadata: Option<GeminiUsageMetadata>,
 }

-
-
 pub struct GeminiProvider {
     client: reqwest::Client,
     config: crate::config::GeminiConfig,
@@ -80,7 +74,7 @@ impl GeminiProvider {
         let client = reqwest::Client::builder()
             .timeout(std::time::Duration::from_secs(30))
             .build()?;
-
         Ok(Self {
             client,
             config: config.clone(),
@@ -101,19 +95,16 @@ impl super::Provider for GeminiProvider {
     }

     fn supports_multimodal(&self) -> bool {
-        true  // Gemini supports vision
+        true // Gemini supports vision
     }

-    async fn chat_completion(
-        &self,
-        request: UnifiedRequest,
-    ) -> Result<ProviderResponse, AppError> {
+    async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
         // Convert UnifiedRequest to Gemini request
         let mut contents = Vec::with_capacity(request.messages.len());

         for msg in request.messages {
             let mut parts = Vec::with_capacity(msg.content.len());

             for part in msg.content {
                 match part {
                     crate::models::ContentPart::Text { text } => {
@@ -123,9 +114,11 @@ impl super::Provider for GeminiProvider {
                         });
                     }
                     crate::models::ContentPart::Image(image_input) => {
-                        let (base64_data, mime_type) = image_input.to_base64().await
+                        let (base64_data, mime_type) = image_input
+                            .to_base64()
+                            .await
                             .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;

                         parts.push(GeminiPart {
                             text: None,
                             inline_data: Some(GeminiInlineData {
@@ -136,23 +129,20 @@ impl super::Provider for GeminiProvider {
                 }
             }

             // Map role: "user" -> "user", "assistant" -> "model", "system" -> "user"
             let role = match msg.role.as_str() {
                 "assistant" => "model".to_string(),
                 _ => "user".to_string(),
             };

-            contents.push(GeminiContent {
-                parts,
-                role,
-            });
+            contents.push(GeminiContent { parts, role });
         }

         if contents.is_empty() {
             return Err(AppError::ProviderError("No valid text messages to send".to_string()));
         }

         // Build generation config
         let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
             Some(GeminiGenerationConfig {
@@ -162,51 +152,65 @@ impl super::Provider for GeminiProvider {
         } else {
             None
         };

         let gemini_request = GeminiRequest {
             contents,
             generation_config,
         };

         // Build URL
-        let url = format!("{}/models/{}:generateContent?key={}",
-            self.config.base_url,
-            request.model,
-            self.api_key
-        );
+        let url = format!("{}/models/{}:generateContent", self.config.base_url, request.model,);

         // Send request
-        let response = self.client
+        let response = self
+            .client
             .post(&url)
+            .header("x-goog-api-key", &self.api_key)
             .json(&gemini_request)
             .send()
             .await
             .map_err(|e| AppError::ProviderError(format!("HTTP request failed: {}", e)))?;

         // Check status
         let status = response.status();
         if !status.is_success() {
             let error_text = response.text().await.unwrap_or_default();
-            return Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, error_text)));
+            return Err(AppError::ProviderError(format!(
+                "Gemini API error ({}): {}",
+                status, error_text
+            )));
         }

         let gemini_response: GeminiResponse = response
             .json()
             .await
             .map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?;

         // Extract content from first candidate
-        let content = gemini_response.candidates
+        let content = gemini_response
+            .candidates
             .first()
             .and_then(|c| c.content.parts.first())
             .and_then(|p| p.text.clone())
             .unwrap_or_default();

         // Extract token usage
-        let prompt_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.prompt_token_count).unwrap_or(0);
-        let completion_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.candidates_token_count).unwrap_or(0);
-        let total_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.total_token_count).unwrap_or(0);
+        let prompt_tokens = gemini_response
+            .usage_metadata
+            .as_ref()
+            .map(|u| u.prompt_token_count)
+            .unwrap_or(0);
+        let completion_tokens = gemini_response
+            .usage_metadata
+            .as_ref()
+            .map(|u| u.candidates_token_count)
+            .unwrap_or(0);
+        let total_tokens = gemini_response
+            .usage_metadata
+            .as_ref()
+            .map(|u| u.total_token_count)
+            .unwrap_or(0);

         Ok(ProviderResponse {
             content,
             reasoning_content: None, // Gemini doesn't use this field name
@@ -221,20 +225,22 @@ impl super::Provider for GeminiProvider {
         Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
     }

-    fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
-        if let Some(metadata) = registry.find_model(model) {
-            if let Some(cost) = &metadata.cost {
-                return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
-                       (completion_tokens as f64 * cost.output / 1_000_000.0);
-            }
-        }
-
-        let (prompt_rate, completion_rate) = self.pricing.iter()
-            .find(|p| model.contains(&p.model))
-            .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
-            .unwrap_or((0.075, 0.30)); // Default to Gemini 2.0 Flash price if not found
-
-        (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
+    fn calculate_cost(
+        &self,
+        model: &str,
+        prompt_tokens: u32,
+        completion_tokens: u32,
+        registry: &crate::models::registry::ModelRegistry,
+    ) -> f64 {
+        super::helpers::calculate_cost_with_registry(
+            model,
+            prompt_tokens,
+            completion_tokens,
+            registry,
+            &self.pricing,
+            0.075,
+            0.30,
+        )
     }

     async fn chat_completion_stream(
@@ -243,10 +249,10 @@ impl super::Provider for GeminiProvider {
     ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
         // Convert UnifiedRequest to Gemini request
         let mut contents = Vec::with_capacity(request.messages.len());

         for msg in request.messages {
             let mut parts = Vec::with_capacity(msg.content.len());

             for part in msg.content {
                 match part {
                     crate::models::ContentPart::Text { text } => {
@@ -256,9 +262,11 @@ impl super::Provider for GeminiProvider {
                         });
                     }
                     crate::models::ContentPart::Image(image_input) => {
-                        let (base64_data, mime_type) = image_input.to_base64().await
+                        let (base64_data, mime_type) = image_input
+                            .to_base64()
+                            .await
                             .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;

                         parts.push(GeminiPart {
                             text: None,
                             inline_data: Some(GeminiInlineData {
@@ -269,19 +277,16 @@ impl super::Provider for GeminiProvider {
                 }
             }

             // Map role
             let role = match msg.role.as_str() {
                 "assistant" => "model".to_string(),
                 _ => "user".to_string(),
             };

-            contents.push(GeminiContent {
-                parts,
-                role,
-            });
+            contents.push(GeminiContent { parts, role });
         }

         // Build generation config
         let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
             Some(GeminiGenerationConfig {
@@ -291,28 +296,32 @@ impl super::Provider for GeminiProvider {
         } else {
             None
         };

         let gemini_request = GeminiRequest {
             contents,
             generation_config,
         };

         // Build URL for streaming
-        let url = format!("{}/models/{}:streamGenerateContent?alt=sse&key={}",
-            self.config.base_url,
-            request.model,
-            self.api_key
-        );
+        let url = format!(
+            "{}/models/{}:streamGenerateContent?alt=sse",
+            self.config.base_url, request.model,
+        );

         // Create eventsource stream
-        use reqwest_eventsource::{EventSource, Event};
         use futures::StreamExt;
+        use reqwest_eventsource::{Event, EventSource};

-        let es = EventSource::new(self.client.post(&url).json(&gemini_request))
-            .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
+        let es = EventSource::new(
+            self.client
+                .post(&url)
+                .header("x-goog-api-key", &self.api_key)
+                .json(&gemini_request),
+        )
+        .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;

         let model = request.model.clone();

         let stream = async_stream::try_stream! {
             let mut es = es;
             while let Some(event) = es.next().await {
@@ -320,12 +329,12 @@ impl super::Provider for GeminiProvider {
                     Ok(Event::Message(msg)) => {
                         let gemini_response: GeminiResponse = serde_json::from_str(&msg.data)
                             .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;

                         if let Some(candidate) = gemini_response.candidates.first() {
                             let content = candidate.content.parts.first()
                                 .and_then(|p| p.text.clone())
                                 .unwrap_or_default();

                             yield ProviderStreamChunk {
                                 content,
                                 reasoning_content: None,
@@ -341,7 +350,7 @@ impl super::Provider for GeminiProvider {
                 }
             }
         };

         Ok(Box::pin(stream))
     }
 }
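The security-relevant change in this file is easy to miss in the formatting noise: the API key moves from a ?key= query parameter into the x-goog-api-key header, for both generateContent and streamGenerateContent. Query strings routinely end up in access logs, proxy logs, and tracing spans, while headers are far more commonly redacted. A standalone sketch of the pattern; base, model, and api_key are placeholders, not the repo's code:

    use serde_json::json;

    // Placeholder request demonstrating the header-based key.
    async fn generate(base: &str, model: &str, api_key: &str) -> reqwest::Result<serde_json::Value> {
        let client = reqwest::Client::new();
        client
            // Before: format!("{base}/models/{model}:generateContent?key={api_key}")
            // leaked the key into every logged URL.
            .post(format!("{base}/models/{model}:generateContent"))
            .header("x-goog-api-key", api_key) // After: the key travels only in a header.
            .json(&json!({ "contents": [] }))
            .send()
            .await?
            .json()
            .await
    }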
src/providers/grok.rs

@@ -1,18 +1,14 @@
-use async_trait::async_trait;
 use anyhow::Result;
-use futures::stream::{BoxStream, StreamExt};
-use serde_json::Value;
+use async_trait::async_trait;
+use futures::stream::BoxStream;

-use crate::{
-    models::UnifiedRequest,
-    errors::AppError,
-    config::AppConfig,
-};
+use super::helpers;
 use super::{ProviderResponse, ProviderStreamChunk};
+use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};

 pub struct GrokProvider {
     client: reqwest::Client,
-    _config: crate::config::GrokConfig,
+    config: crate::config::GrokConfig,
     api_key: String,
     pricing: Vec<crate::config::ModelPricing>,
 }
@@ -26,7 +22,7 @@ impl GrokProvider {
     pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
         Ok(Self {
             client: reqwest::Client::new(),
-            _config: config.clone(),
+            config: config.clone(),
             api_key,
             pricing: app_config.pricing.grok.clone(),
         })
@@ -47,40 +43,13 @@ impl super::Provider for GrokProvider {
         true
     }

-    async fn chat_completion(
-        &self,
-        request: UnifiedRequest,
-    ) -> Result<ProviderResponse, AppError> {
-        let mut body = serde_json::json!({
-            "model": request.model,
-            "messages": request.messages.iter().map(|m| {
-                serde_json::json!({
-                    "role": m.role,
-                    "content": m.content.iter().map(|p| {
-                        match p {
-                            crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
-                            crate::models::ContentPart::Image(image_input) => {
-                                let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
-                                serde_json::json!({
-                                    "type": "image_url",
-                                    "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
-                                })
-                            }
-                        }
-                    }).collect::<Vec<_>>()
-                })
-            }).collect::<Vec<_>>(),
-            "stream": false,
-        });
-
-        if let Some(temp) = request.temperature {
-            body["temperature"] = serde_json::json!(temp);
-        }
-        if let Some(max_tokens) = request.max_tokens {
-            body["max_tokens"] = serde_json::json!(max_tokens);
-        }
-
-        let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
+    async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
+        let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
+        let body = helpers::build_openai_body(&request, messages_json, false);
+
+        let response = self
+            .client
+            .post(format!("{}/chat/completions", self.config.base_url))
             .header("Authorization", format!("Bearer {}", self.api_key))
             .json(&body)
             .send()
@@ -92,125 +61,51 @@ impl super::Provider for GrokProvider {
             return Err(AppError::ProviderError(format!("Grok API error: {}", error_text)));
         }

-        let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
-
-        let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
-        let message = &choice["message"];
-
-        let content = message["content"].as_str().unwrap_or_default().to_string();
-        let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
-
-        let usage = &resp_json["usage"];
-        let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
-        let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
-        let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
+        let resp_json: serde_json::Value = response
+            .json()
+            .await
+            .map_err(|e| AppError::ProviderError(e.to_string()))?;

-        Ok(ProviderResponse {
-            content,
-            reasoning_content,
-            prompt_tokens,
-            completion_tokens,
-            total_tokens,
-            model: request.model,
-        })
+        helpers::parse_openai_response(&resp_json, request.model)
     }

     fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
         Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
     }

-    fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
-        if let Some(metadata) = registry.find_model(model) {
-            if let Some(cost) = &metadata.cost {
-                return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
-                       (completion_tokens as f64 * cost.output / 1_000_000.0);
-            }
-        }
-
-        let (prompt_rate, completion_rate) = self.pricing.iter()
-            .find(|p| model.contains(&p.model))
-            .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
-            .unwrap_or((5.0, 15.0));
-
-        (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
+    fn calculate_cost(
+        &self,
+        model: &str,
+        prompt_tokens: u32,
+        completion_tokens: u32,
+        registry: &crate::models::registry::ModelRegistry,
+    ) -> f64 {
+        helpers::calculate_cost_with_registry(
+            model,
+            prompt_tokens,
+            completion_tokens,
+            registry,
+            &self.pricing,
+            5.0,
+            15.0,
+        )
     }

     async fn chat_completion_stream(
         &self,
         request: UnifiedRequest,
     ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
-        let mut body = serde_json::json!({
-            "model": request.model,
-            "messages": request.messages.iter().map(|m| {
-                serde_json::json!({
-                    "role": m.role,
-                    "content": m.content.iter().map(|p| {
-                        match p {
-                            crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
-                            crate::models::ContentPart::Image(image_input) => {
-                                let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
-                                serde_json::json!({
-                                    "type": "image_url",
-                                    "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
-                                })
-                            }
-                        }
-                    }).collect::<Vec<_>>()
-                })
-            }).collect::<Vec<_>>(),
-            "stream": true,
-        });
-
-        if let Some(temp) = request.temperature {
-            body["temperature"] = serde_json::json!(temp);
-        }
-        if let Some(max_tokens) = request.max_tokens {
-            body["max_tokens"] = serde_json::json!(max_tokens);
-        }
-
-        // Create eventsource stream
-        use reqwest_eventsource::{EventSource, Event};
-        let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
-            .header("Authorization", format!("Bearer {}", self.api_key))
-            .json(&body))
-            .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
-
-        let model = request.model.clone();
-
-        let stream = async_stream::try_stream! {
-            let mut es = es;
-            while let Some(event) = es.next().await {
-                match event {
-                    Ok(Event::Message(msg)) => {
-                        if msg.data == "[DONE]" {
-                            break;
-                        }
-
-                        let chunk: Value = serde_json::from_str(&msg.data)
-                            .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
-
-                        if let Some(choice) = chunk["choices"].get(0) {
-                            let delta = &choice["delta"];
-                            let content = delta["content"].as_str().unwrap_or_default().to_string();
-                            let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
-                            let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
-
-                            yield ProviderStreamChunk {
-                                content,
-                                reasoning_content,
-                                finish_reason,
-                                model: model.clone(),
-                            };
-                        }
-                    }
-                    Ok(_) => continue,
-                    Err(e) => {
-                        Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
-                    }
-                }
-            }
-        };
-
-        Ok(Box::pin(stream))
+        let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
+        let body = helpers::build_openai_body(&request, messages_json, true);
+
+        let es = reqwest_eventsource::EventSource::new(
+            self.client
+                .post(format!("{}/chat/completions", self.config.base_url))
+                .header("Authorization", format!("Bearer {}", self.api_key))
+                .json(&body),
+        )
+        .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
+
+        Ok(helpers::create_openai_stream(es, request.model, None))
     }
 }
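The calculate_cost rewrite in each provider delegates to one shared routine: the registry price wins if the model is known, then the provider's pricing config, then a hardcoded per-provider default (5.0/15.0 USD per million tokens for Grok). The fallback arithmetic, worked through as a standalone snippet:

    fn main() {
        // Grok's hardcoded fallback rates, in USD per 1M tokens.
        let (prompt_rate, completion_rate) = (5.0_f64, 15.0_f64);
        let (prompt_tokens, completion_tokens) = (1_000_u32, 500_u32);

        let cost = (prompt_tokens as f64 * prompt_rate / 1_000_000.0)
            + (completion_tokens as f64 * completion_rate / 1_000_000.0);

        // 1_000 * 5.0 / 1e6 = 0.005 and 500 * 15.0 / 1e6 = 0.0075, so 0.0125 total.
        assert!((cost - 0.0125).abs() < 1e-12);
        println!("${cost:.4}");
    }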
src/providers/helpers.rs (new file, 189 additions)

@@ -0,0 +1,189 @@
+use super::{ProviderResponse, ProviderStreamChunk};
+use crate::errors::AppError;
+use crate::models::{ContentPart, UnifiedMessage, UnifiedRequest};
+use futures::stream::{BoxStream, StreamExt};
+use serde_json::Value;
+
+/// Convert messages to OpenAI-compatible JSON, resolving images asynchronously.
+///
+/// This avoids the deadlock caused by `futures::executor::block_on` inside a
+/// Tokio async context. All image base64 conversions are awaited properly.
+pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result<Vec<serde_json::Value>, AppError> {
+    let mut result = Vec::new();
+    for m in messages {
+        let mut parts = Vec::new();
+        for p in &m.content {
+            match p {
+                ContentPart::Text { text } => {
+                    parts.push(serde_json::json!({ "type": "text", "text": text }));
+                }
+                ContentPart::Image(image_input) => {
+                    let (base64_data, mime_type) = image_input
+                        .to_base64()
+                        .await
+                        .map_err(|e| AppError::MultimodalError(e.to_string()))?;
+                    parts.push(serde_json::json!({
+                        "type": "image_url",
+                        "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
+                    }));
+                }
+            }
+        }
+        result.push(serde_json::json!({
+            "role": m.role,
+            "content": parts
+        }));
+    }
+    Ok(result)
+}
+
+/// Convert messages to OpenAI-compatible JSON, but replace images with a
+/// text placeholder "[Image]". Useful for providers that don't support
+/// multimodal in streaming mode or at all.
+pub async fn messages_to_openai_json_text_only(
+    messages: &[UnifiedMessage],
+) -> Result<Vec<serde_json::Value>, AppError> {
+    let mut result = Vec::new();
+    for m in messages {
+        let mut parts = Vec::new();
+        for p in &m.content {
+            match p {
+                ContentPart::Text { text } => {
+                    parts.push(serde_json::json!({ "type": "text", "text": text }));
+                }
+                ContentPart::Image(_) => {
+                    parts.push(serde_json::json!({ "type": "text", "text": "[Image]" }));
+                }
+            }
+        }
+        result.push(serde_json::json!({
+            "role": m.role,
+            "content": parts
+        }));
+    }
+    Ok(result)
+}
+
+/// Build an OpenAI-compatible request body from a UnifiedRequest and pre-converted messages.
+pub fn build_openai_body(
+    request: &UnifiedRequest,
+    messages_json: Vec<serde_json::Value>,
+    stream: bool,
+) -> serde_json::Value {
+    let mut body = serde_json::json!({
+        "model": request.model,
+        "messages": messages_json,
+        "stream": stream,
+    });
+
+    if let Some(temp) = request.temperature {
+        body["temperature"] = serde_json::json!(temp);
+    }
+    if let Some(max_tokens) = request.max_tokens {
+        body["max_tokens"] = serde_json::json!(max_tokens);
+    }
+
+    body
+}
+
+/// Parse an OpenAI-compatible chat completion response JSON into a ProviderResponse.
+pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<ProviderResponse, AppError> {
+    let choice = resp_json["choices"]
+        .get(0)
+        .ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
+    let message = &choice["message"];
+
+    let content = message["content"].as_str().unwrap_or_default().to_string();
+    let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
+
+    let usage = &resp_json["usage"];
+    let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
+    let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
+    let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
+
+    Ok(ProviderResponse {
+        content,
+        reasoning_content,
+        prompt_tokens,
+        completion_tokens,
+        total_tokens,
+        model,
+    })
+}
+
+/// Create an SSE stream that parses OpenAI-compatible streaming chunks.
+///
+/// The optional `reasoning_field` allows overriding the field name for
+/// reasoning content (e.g., "thought" for Ollama).
+pub fn create_openai_stream(
+    es: reqwest_eventsource::EventSource,
+    model: String,
+    reasoning_field: Option<&'static str>,
+) -> BoxStream<'static, Result<ProviderStreamChunk, AppError>> {
+    use reqwest_eventsource::Event;
+
+    let stream = async_stream::try_stream! {
+        let mut es = es;
+        while let Some(event) = es.next().await {
+            match event {
+                Ok(Event::Message(msg)) => {
+                    if msg.data == "[DONE]" {
+                        break;
+                    }
+
+                    let chunk: Value = serde_json::from_str(&msg.data)
+                        .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
+
+                    if let Some(choice) = chunk["choices"].get(0) {
+                        let delta = &choice["delta"];
+                        let content = delta["content"].as_str().unwrap_or_default().to_string();
+                        let reasoning_content = delta["reasoning_content"]
+                            .as_str()
+                            .or_else(|| reasoning_field.and_then(|f| delta[f].as_str()))
+                            .map(|s| s.to_string());
+                        let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
+
+                        yield ProviderStreamChunk {
+                            content,
+                            reasoning_content,
+                            finish_reason,
+                            model: model.clone(),
+                        };
+                    }
+                }
+                Ok(_) => continue,
+                Err(e) => {
+                    Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
+                }
+            }
+        }
+    };
+
+    Box::pin(stream)
+}
+
+/// Calculate cost using the model registry first, then falling back to provider pricing config.
+pub fn calculate_cost_with_registry(
+    model: &str,
+    prompt_tokens: u32,
+    completion_tokens: u32,
+    registry: &crate::models::registry::ModelRegistry,
+    pricing: &[crate::config::ModelPricing],
+    default_prompt_rate: f64,
+    default_completion_rate: f64,
+) -> f64 {
+    if let Some(metadata) = registry.find_model(model)
+        && let Some(cost) = &metadata.cost
+    {
+        return (prompt_tokens as f64 * cost.input / 1_000_000.0)
+            + (completion_tokens as f64 * cost.output / 1_000_000.0);
+    }
+
+    let (prompt_rate, completion_rate) = pricing
+        .iter()
+        .find(|p| model.contains(&p.model))
+        .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
+        .unwrap_or((default_prompt_rate, default_completion_rate));
+
+    (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
+}
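Because parse_openai_response is a pure function over serde_json::Value, it can be exercised without any HTTP machinery. A sketch of such a test, assuming it lives in a #[cfg(test)] module inside helpers.rs; the JSON shape follows the OpenAI chat-completions format the helper expects:

    #[test]
    fn parses_openai_shape() {
        let resp_json = serde_json::json!({
            "choices": [{
                "message": { "content": "hi", "reasoning_content": "thinking..." }
            }],
            "usage": { "prompt_tokens": 3, "completion_tokens": 1, "total_tokens": 4 }
        });

        let parsed = parse_openai_response(&resp_json, "deepseek-chat".to_string()).unwrap();
        assert_eq!(parsed.content, "hi");
        assert_eq!(parsed.reasoning_content.as_deref(), Some("thinking..."));
        assert_eq!(parsed.prompt_tokens, 3);
        assert_eq!(parsed.total_tokens, 4);
    }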
src/providers/mod.rs

@@ -1,17 +1,18 @@
-use async_trait::async_trait;
 use anyhow::Result;
-use std::sync::Arc;
+use async_trait::async_trait;
+use futures::stream::BoxStream;
+use sqlx::Row;
+use std::sync::Arc;

-use crate::models::UnifiedRequest;
 use crate::errors::AppError;
+use crate::models::UnifiedRequest;

-pub mod openai;
-pub mod gemini;
 pub mod deepseek;
+pub mod gemini;
 pub mod grok;
+pub mod helpers;
 pub mod ollama;
+pub mod openai;

 #[async_trait]
 pub trait Provider: Send + Sync {
@@ -25,10 +26,7 @@ pub trait Provider: Send + Sync {
     fn supports_multimodal(&self) -> bool;

     /// Process a chat completion request
-    async fn chat_completion(
-        &self,
-        request: UnifiedRequest,
-    ) -> Result<ProviderResponse, AppError>;
+    async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError>;

     /// Process a streaming chat completion request
     async fn chat_completion_stream(
@@ -40,7 +38,13 @@ pub trait Provider: Send + Sync {
     fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32>;

     /// Calculate cost based on token usage and model using the registry
-    fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64;
+    fn calculate_cost(
+        &self,
+        model: &str,
+        prompt_tokens: u32,
+        completion_tokens: u32,
+        registry: &crate::models::registry::ModelRegistry,
+    ) -> f64;
 }

 pub struct ProviderResponse {
@@ -64,11 +68,8 @@ use tokio::sync::RwLock;

 use crate::config::AppConfig;
 use crate::providers::{
+    deepseek::DeepSeekProvider, gemini::GeminiProvider, grok::GrokProvider, ollama::OllamaProvider,
     openai::OpenAIProvider,
-    gemini::GeminiProvider,
-    deepseek::DeepSeekProvider,
-    grok::GrokProvider,
-    ollama::OllamaProvider,
 };

 #[derive(Clone)]
@@ -76,6 +77,12 @@ pub struct ProviderManager {
     providers: Arc<RwLock<Vec<Arc<dyn Provider>>>>,
 }

+impl Default for ProviderManager {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
 impl ProviderManager {
     pub fn new() -> Self {
         Self {
@@ -84,7 +91,12 @@ impl ProviderManager {
     }

     /// Initialize a provider by name using config and database overrides
-    pub async fn initialize_provider(&self, name: &str, app_config: &AppConfig, db_pool: &crate::database::DbPool) -> Result<()> {
+    pub async fn initialize_provider(
+        &self,
+        name: &str,
+        app_config: &AppConfig,
+        db_pool: &crate::database::DbPool,
+    ) -> Result<()> {
         // Load override from database
         let db_config = sqlx::query("SELECT enabled, base_url, api_key FROM provider_configs WHERE id = ?")
             .bind(name)
@@ -100,11 +112,31 @@ impl ProviderManager {
         } else {
             // No database override, use defaults from AppConfig
             match name {
-                "openai" => (app_config.providers.openai.enabled, Some(app_config.providers.openai.base_url.clone()), None),
-                "gemini" => (app_config.providers.gemini.enabled, Some(app_config.providers.gemini.base_url.clone()), None),
-                "deepseek" => (app_config.providers.deepseek.enabled, Some(app_config.providers.deepseek.base_url.clone()), None),
-                "grok" => (app_config.providers.grok.enabled, Some(app_config.providers.grok.base_url.clone()), None),
-                "ollama" => (app_config.providers.ollama.enabled, Some(app_config.providers.ollama.base_url.clone()), None),
+                "openai" => (
+                    app_config.providers.openai.enabled,
+                    Some(app_config.providers.openai.base_url.clone()),
+                    None,
+                ),
+                "gemini" => (
+                    app_config.providers.gemini.enabled,
+                    Some(app_config.providers.gemini.base_url.clone()),
+                    None,
+                ),
+                "deepseek" => (
+                    app_config.providers.deepseek.enabled,
+                    Some(app_config.providers.deepseek.base_url.clone()),
+                    None,
+                ),
+                "grok" => (
+                    app_config.providers.grok.enabled,
+                    Some(app_config.providers.grok.base_url.clone()),
+                    None,
+                ),
+                "ollama" => (
+                    app_config.providers.ollama.enabled,
+                    Some(app_config.providers.ollama.base_url.clone()),
+                    None,
+                ),
                 _ => (false, None, None),
             }
         };
@@ -118,7 +150,9 @@ impl ProviderManager {
         let provider: Arc<dyn Provider> = match name {
             "openai" => {
                 let mut cfg = app_config.providers.openai.clone();
-                if let Some(url) = base_url { cfg.base_url = url; }
+                if let Some(url) = base_url {
+                    cfg.base_url = url;
+                }
                 // Handle API key override if present
                 let p = if let Some(key) = api_key {
                     // We need a way to create a provider with an explicit key
@@ -128,42 +162,50 @@ impl ProviderManager {
                     OpenAIProvider::new(&cfg, app_config)?
                 };
                 Arc::new(p)
-            },
+            }
             "ollama" => {
                 let mut cfg = app_config.providers.ollama.clone();
-                if let Some(url) = base_url { cfg.base_url = url; }
+                if let Some(url) = base_url {
+                    cfg.base_url = url;
+                }
                 Arc::new(OllamaProvider::new(&cfg, app_config)?)
-            },
+            }
             "gemini" => {
                 let mut cfg = app_config.providers.gemini.clone();
-                if let Some(url) = base_url { cfg.base_url = url; }
+                if let Some(url) = base_url {
+                    cfg.base_url = url;
+                }
                 let p = if let Some(key) = api_key {
                     GeminiProvider::new_with_key(&cfg, app_config, key)?
                 } else {
                     GeminiProvider::new(&cfg, app_config)?
                 };
                 Arc::new(p)
-            },
+            }
             "deepseek" => {
                 let mut cfg = app_config.providers.deepseek.clone();
-                if let Some(url) = base_url { cfg.base_url = url; }
+                if let Some(url) = base_url {
+                    cfg.base_url = url;
+                }
                 let p = if let Some(key) = api_key {
                     DeepSeekProvider::new_with_key(&cfg, app_config, key)?
                 } else {
                     DeepSeekProvider::new(&cfg, app_config)?
                 };
                 Arc::new(p)
-            },
+            }
             "grok" => {
                 let mut cfg = app_config.providers.grok.clone();
-                if let Some(url) = base_url { cfg.base_url = url; }
+                if let Some(url) = base_url {
+                    cfg.base_url = url;
+                }
                 let p = if let Some(key) = api_key {
                     GrokProvider::new_with_key(&cfg, app_config, key)?
                 } else {
                     GrokProvider::new(&cfg, app_config)?
                 };
                 Arc::new(p)
-            },
+            }
             _ => return Err(anyhow::anyhow!("Unknown provider: {}", name)),
         };
@@ -188,16 +230,12 @@ impl ProviderManager {

     pub async fn get_provider_for_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
         let providers = self.providers.read().await;
-        providers.iter()
-            .find(|p| p.supports_model(model))
-            .map(|p| Arc::clone(p))
+        providers.iter().find(|p| p.supports_model(model)).map(Arc::clone)
     }

     pub async fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
         let providers = self.providers.read().await;
-        providers.iter()
-            .find(|p| p.name() == name)
-            .map(|p| Arc::clone(p))
+        providers.iter().find(|p| p.name() == name).map(Arc::clone)
     }

     pub async fn get_all_providers(&self) -> Vec<Arc<dyn Provider>> {
@@ -238,22 +276,30 @@ pub mod placeholder {
             &self,
             _request: UnifiedRequest,
         ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
-            Err(AppError::ProviderError("Streaming not supported for placeholder provider".to_string()))
+            Err(AppError::ProviderError(
+                "Streaming not supported for placeholder provider".to_string(),
+            ))
         }

-        async fn chat_completion(
-            &self,
-            _request: UnifiedRequest,
-        ) -> Result<ProviderResponse, AppError> {
-            Err(AppError::ProviderError(format!("Provider {} not implemented", self.name)))
+        async fn chat_completion(&self, _request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
+            Err(AppError::ProviderError(format!(
+                "Provider {} not implemented",
+                self.name
+            )))
         }

         fn estimate_tokens(&self, _request: &UnifiedRequest) -> Result<u32> {
             Ok(0)
         }

-        fn calculate_cost(&self, _model: &str, _prompt_tokens: u32, _completion_tokens: u32, _registry: &crate::models::registry::ModelRegistry) -> f64 {
+        fn calculate_cost(
+            &self,
+            _model: &str,
+            _prompt_tokens: u32,
+            _completion_tokens: u32,
+            _registry: &crate::models::registry::ModelRegistry,
+        ) -> f64 {
             0.0
         }
     }
 }
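One consequence of the new Default impl and the widened initialize_provider signature: startup code can construct the manager once and loop over the known provider names instead of hand-writing five init calls. A hypothetical sketch, where app_config and db_pool stand in for values built elsewhere in the application:

    // Illustrative startup helper, not code from this commit.
    async fn init_providers(
        manager: &ProviderManager,
        app_config: &AppConfig,
        db_pool: &crate::database::DbPool,
    ) {
        for name in ["openai", "gemini", "deepseek", "grok", "ollama"] {
            // Database overrides (enabled flag, base_url, api_key) are applied inside.
            if let Err(e) = manager.initialize_provider(name, app_config, db_pool).await {
                eprintln!("skipping provider {name}: {e}");
            }
        }
    }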
src/providers/ollama.rs

@@ -1,18 +1,14 @@
-use async_trait::async_trait;
 use anyhow::Result;
-use futures::stream::{BoxStream, StreamExt};
-use serde_json::Value;
+use async_trait::async_trait;
+use futures::stream::BoxStream;

-use crate::{
-    models::UnifiedRequest,
-    errors::AppError,
-    config::AppConfig,
-};
+use super::helpers;
 use super::{ProviderResponse, ProviderStreamChunk};
+use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};

 pub struct OllamaProvider {
     client: reqwest::Client,
-    _config: crate::config::OllamaConfig,
+    config: crate::config::OllamaConfig,
     pricing: Vec<crate::config::ModelPricing>,
 }

@@ -20,7 +16,7 @@ impl OllamaProvider {
     pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
         Ok(Self {
             client: reqwest::Client::new(),
-            _config: config.clone(),
+            config: config.clone(),
             pricing: app_config.pricing.ollama.clone(),
         })
     }
@@ -33,49 +29,29 @@ impl super::Provider for OllamaProvider {
     }

     fn supports_model(&self, model: &str) -> bool {
-        self._config.models.iter().any(|m| m == model) || model.starts_with("ollama/")
+        self.config.models.iter().any(|m| m == model) || model.starts_with("ollama/")
     }

     fn supports_multimodal(&self) -> bool {
         true
     }

-    async fn chat_completion(
-        &self,
-        request: UnifiedRequest,
-    ) -> Result<ProviderResponse, AppError> {
-        let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
-
-        let mut body = serde_json::json!({
-            "model": model,
-            "messages": request.messages.iter().map(|m| {
-                serde_json::json!({
-                    "role": m.role,
-                    "content": m.content.iter().map(|p| {
-                        match p {
-                            crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
-                            crate::models::ContentPart::Image(image_input) => {
-                                let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
-                                serde_json::json!({
-                                    "type": "image_url",
-                                    "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
-                                })
-                            }
-                        }
-                    }).collect::<Vec<_>>()
-                })
-            }).collect::<Vec<_>>(),
-            "stream": false,
-        });
-
-        if let Some(temp) = request.temperature {
-            body["temperature"] = serde_json::json!(temp);
-        }
-        if let Some(max_tokens) = request.max_tokens {
-            body["max_tokens"] = serde_json::json!(max_tokens);
-        }
-
-        let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
+    async fn chat_completion(&self, mut request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
+        // Strip "ollama/" prefix if present for the API call
+        let api_model = request
+            .model
+            .strip_prefix("ollama/")
+            .unwrap_or(&request.model)
+            .to_string();
+        let original_model = request.model.clone();
+        request.model = api_model;
+
+        let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
+        let body = helpers::build_openai_body(&request, messages_json, false);
+
+        let response = self
+            .client
+            .post(format!("{}/chat/completions", self.config.base_url))
             .json(&body)
             .send()
             .await
@@ -86,120 +62,67 @@ impl super::Provider for OllamaProvider {
             return Err(AppError::ProviderError(format!("Ollama API error: {}", error_text)));
         }

-        let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
-
-        let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
-        let message = &choice["message"];
-
-        let content = message["content"].as_str().unwrap_or_default().to_string();
-        let reasoning_content = message["reasoning_content"].as_str().or_else(|| message["thought"].as_str()).map(|s| s.to_string());
-
-        let usage = &resp_json["usage"];
-        let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
-        let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
-        let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
+        let resp_json: serde_json::Value = response
+            .json()
+            .await
+            .map_err(|e| AppError::ProviderError(e.to_string()))?;

-        Ok(ProviderResponse {
-            content,
-            reasoning_content,
-            prompt_tokens,
-            completion_tokens,
-            total_tokens,
-            model: request.model,
-        })
+        // Ollama also supports "thought" as an alias for reasoning_content
+        let mut result = helpers::parse_openai_response(&resp_json, original_model)?;
+        if result.reasoning_content.is_none() {
+            result.reasoning_content = resp_json["choices"]
+                .get(0)
+                .and_then(|c| c["message"]["thought"].as_str())
+                .map(|s| s.to_string());
+        }
+        Ok(result)
     }

     fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
         Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
     }

-    fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
-        if let Some(metadata) = registry.find_model(model) {
-            if let Some(cost) = &metadata.cost {
-                return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
-                       (completion_tokens as f64 * cost.output / 1_000_000.0);
-            }
-        }
-
-        let (prompt_rate, completion_rate) = self.pricing.iter()
-            .find(|p| model.contains(&p.model))
-            .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
-            .unwrap_or((0.0, 0.0));
-
-        (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
+    fn calculate_cost(
+        &self,
+        model: &str,
+        prompt_tokens: u32,
+        completion_tokens: u32,
+        registry: &crate::models::registry::ModelRegistry,
+    ) -> f64 {
+        helpers::calculate_cost_with_registry(
+            model,
+            prompt_tokens,
+            completion_tokens,
+            registry,
+            &self.pricing,
+            0.0,
+            0.0,
+        )
    }

     async fn chat_completion_stream(
         &self,
-        request: UnifiedRequest,
+        mut request: UnifiedRequest,
     ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
-        let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
+        let api_model = request
+            .model
+            .strip_prefix("ollama/")
+            .unwrap_or(&request.model)
+            .to_string();
+        let original_model = request.model.clone();
+        request.model = api_model;

-        let mut body = serde_json::json!({
-            "model": model,
-            "messages": request.messages.iter().map(|m| {
-                serde_json::json!({
-                    "role": m.role,
-                    "content": m.content.iter().map(|p| {
-                        match p {
-                            crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
-                            crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }),
-                        }
-                    }).collect::<Vec<_>>()
-                })
-            }).collect::<Vec<_>>(),
-            "stream": true,
-        });
-
-        if let Some(temp) = request.temperature {
-            body["temperature"] = serde_json::json!(temp);
-        }
-        if let Some(max_tokens) = request.max_tokens {
-            body["max_tokens"] = serde_json::json!(max_tokens);
-        }
-
-        // Create eventsource stream
-        use reqwest_eventsource::{EventSource, Event};
-        let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
-            .json(&body))
-            .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
-
-        let model_name = request.model.clone();
-
-        let stream = async_stream::try_stream! {
-            let mut es = es;
-            while let Some(event) = es.next().await {
-                match event {
-                    Ok(Event::Message(msg)) => {
-                        if msg.data == "[DONE]" {
-                            break;
-                        }
-
-                        let chunk: Value = serde_json::from_str(&msg.data)
-                            .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
-
-                        if let Some(choice) = chunk["choices"].get(0) {
-                            let delta = &choice["delta"];
-                            let content = delta["content"].as_str().unwrap_or_default().to_string();
-                            let reasoning_content = delta["reasoning_content"].as_str().or_else(|| delta["thought"].as_str()).map(|s| s.to_string());
-                            let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
-
-                            yield ProviderStreamChunk {
-                                content,
-                                reasoning_content,
-                                finish_reason,
-                                model: model_name.clone(),
-                            };
-                        }
-                    }
-                    Ok(_) => continue,
-                    Err(e) => {
-                        Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
-                    }
-                }
-            }
-        };
-
-        Ok(Box::pin(stream))
+        let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?;
+        let body = helpers::build_openai_body(&request, messages_json, true);
+
+        let es = reqwest_eventsource::EventSource::new(
+            self.client
+                .post(format!("{}/chat/completions", self.config.base_url))
+                .json(&body),
+        )
+        .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
+
+        // Ollama uses "thought" as an alternative field for reasoning content
+        Ok(helpers::create_openai_stream(es, original_model, Some("thought")))
     }
 }
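The subtle behavioral detail in this file is the split between api_model and original_model: the HTTP request carries the bare name Ollama expects, while the returned usage is attributed to the caller-facing "ollama/..." name. The prefix handling in isolation, with illustrative model names:

    fn main() {
        let requested = "ollama/llama3.2";
        let api_model = requested.strip_prefix("ollama/").unwrap_or(requested);
        assert_eq!(api_model, "llama3.2"); // sent to the Ollama API

        // Names without the prefix pass through unchanged.
        let plain = "llama3.2";
        assert_eq!(plain.strip_prefix("ollama/").unwrap_or(plain), "llama3.2");
    }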
@@ -1,18 +1,14 @@
|
||||
use async_trait::async_trait;
|
||||
use anyhow::Result;
|
||||
use futures::stream::{BoxStream, StreamExt};
|
||||
use serde_json::Value;
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
|
||||
use crate::{
|
||||
models::UnifiedRequest,
|
||||
errors::AppError,
|
||||
config::AppConfig,
|
||||
};
|
||||
use super::helpers;
|
||||
use super::{ProviderResponse, ProviderStreamChunk};
|
||||
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
|
||||
|
||||
pub struct OpenAIProvider {
|
||||
client: reqwest::Client,
|
||||
_config: crate::config::OpenAIConfig,
|
||||
config: crate::config::OpenAIConfig,
|
||||
api_key: String,
|
||||
pricing: Vec<crate::config::ModelPricing>,
|
||||
}
|
||||
@@ -26,7 +22,7 @@ impl OpenAIProvider {
    pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
        Ok(Self {
            client: reqwest::Client::new(),
            _config: config.clone(),
            config: config.clone(),
            api_key,
            pricing: app_config.pricing.openai.clone(),
        })
@@ -47,40 +43,13 @@ impl super::Provider for OpenAIProvider {
        true
    }

    async fn chat_completion(
        &self,
        request: UnifiedRequest,
    ) -> Result<ProviderResponse, AppError> {
        let mut body = serde_json::json!({
            "model": request.model,
            "messages": request.messages.iter().map(|m| {
                serde_json::json!({
                    "role": m.role,
                    "content": m.content.iter().map(|p| {
                        match p {
                            crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
                            crate::models::ContentPart::Image(image_input) => {
                                let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
                                serde_json::json!({
                                    "type": "image_url",
                                    "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
                                })
                            }
                        }
                    }).collect::<Vec<_>>()
                })
            }).collect::<Vec<_>>(),
            "stream": false,
        });
    async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
        let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
        let body = helpers::build_openai_body(&request, messages_json, false);

        if let Some(temp) = request.temperature {
            body["temperature"] = serde_json::json!(temp);
        }
        if let Some(max_tokens) = request.max_tokens {
            body["max_tokens"] = serde_json::json!(max_tokens);
        }

        let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
        let response = self
            .client
            .post(format!("{}/chat/completions", self.config.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&body)
            .send()
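messages_to_openai_json and build_openai_body are likewise defined in the shared helpers module and not shown in any hunk here. A rough sketch under stated assumptions: the Message and ContentPart shapes are taken from the deleted inline code, to_base64 stays async (now awaited instead of the deadlock-prone futures::executor::block_on), its error type is assumed to implement Display, and the temperature/max_tokens handling is presumed to have moved into the body builder, since call sites now bind an immutable `let body`:

    use serde_json::{json, Value};

    /// Converts unified messages to OpenAI-style JSON parts.
    /// `crate::models::Message` as the element type is an assumption.
    pub async fn messages_to_openai_json(
        messages: &[crate::models::Message],
    ) -> Result<Value, AppError> {
        let mut out = Vec::with_capacity(messages.len());
        for m in messages {
            let mut parts = Vec::new();
            for p in &m.content {
                match p {
                    crate::models::ContentPart::Text { text } => {
                        parts.push(json!({ "type": "text", "text": text }));
                    }
                    crate::models::ContentPart::Image(image_input) => {
                        // Awaited properly; the old per-provider code used
                        // futures::executor::block_on inside the Tokio runtime.
                        let (base64_data, mime_type) = image_input
                            .to_base64()
                            .await
                            .map_err(|e| AppError::ProviderError(e.to_string()))?;
                        parts.push(json!({
                            "type": "image_url",
                            "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
                        }));
                    }
                }
            }
            out.push(json!({ "role": m.role, "content": parts }));
        }
        Ok(Value::Array(out))
    }

    /// Assembles the common OpenAI-compatible request body.
    pub fn build_openai_body(request: &UnifiedRequest, messages: Value, stream: bool) -> Value {
        let mut body = json!({
            "model": request.model,
            "messages": messages,
            "stream": stream,
        });
        // Assumed: the per-provider temperature/max_tokens if-lets moved here.
        if let Some(temp) = request.temperature {
            body["temperature"] = json!(temp);
        }
        if let Some(max_tokens) = request.max_tokens {
            body["max_tokens"] = json!(max_tokens);
        }
        body
    }

messages_to_openai_json_text_only, used in the Ollama hunk above, would follow the same skeleton but emit plain-text content and skip image parts.
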
@@ -92,125 +61,51 @@ impl super::Provider for OpenAIProvider {
            return Err(AppError::ProviderError(format!("OpenAI API error: {}", error_text)));
        }

        let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;

        let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
        let message = &choice["message"];

        let content = message["content"].as_str().unwrap_or_default().to_string();
        let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());

        let usage = &resp_json["usage"];
        let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
        let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
        let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
        let resp_json: serde_json::Value = response
            .json()
            .await
            .map_err(|e| AppError::ProviderError(e.to_string()))?;

        Ok(ProviderResponse {
            content,
            reasoning_content,
            prompt_tokens,
            completion_tokens,
            total_tokens,
            model: request.model,
        })
        helpers::parse_openai_response(&resp_json, request.model)
    }

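parse_openai_response replaces the inline extraction deleted above. Its body is not in this diff, but the removed lines pin down what it must produce; only the signature is inferred from the call site:

    /// Extracts content, reasoning, and token usage from an
    /// OpenAI-compatible non-streaming response body.
    pub fn parse_openai_response(
        resp_json: &serde_json::Value,
        model: String,
    ) -> Result<ProviderResponse, AppError> {
        let choice = resp_json["choices"]
            .get(0)
            .ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
        let message = &choice["message"];
        let usage = &resp_json["usage"];

        Ok(ProviderResponse {
            content: message["content"].as_str().unwrap_or_default().to_string(),
            reasoning_content: message["reasoning_content"].as_str().map(|s| s.to_string()),
            prompt_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
            completion_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
            total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
            model,
        })
    }
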
    fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
        Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
    }

    fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
        if let Some(metadata) = registry.find_model(model) {
            if let Some(cost) = &metadata.cost {
                return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
                    (completion_tokens as f64 * cost.output / 1_000_000.0);
            }
        }

        let (prompt_rate, completion_rate) = self.pricing.iter()
            .find(|p| model.contains(&p.model))
            .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
            .unwrap_or((0.15, 0.60));

        (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
    fn calculate_cost(
        &self,
        model: &str,
        prompt_tokens: u32,
        completion_tokens: u32,
        registry: &crate::models::registry::ModelRegistry,
    ) -> f64 {
        helpers::calculate_cost_with_registry(
            model,
            prompt_tokens,
            completion_tokens,
            registry,
            &self.pricing,
            0.15,
            0.60,
        )
    }

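calculate_cost_with_registry centralizes the registry-first, config-second, default-last pricing lookup that each provider previously duplicated, with the hardcoded fallback rates promoted to parameters. The seven arguments at the call site fix the signature, and the deleted inline code supplies the body, so this sketch should be close:

    /// Cost in dollars for a request, preferring registry metadata,
    /// then configured pricing, then the provider's default rates.
    pub fn calculate_cost_with_registry(
        model: &str,
        prompt_tokens: u32,
        completion_tokens: u32,
        registry: &crate::models::registry::ModelRegistry,
        pricing: &[crate::config::ModelPricing],
        default_prompt_rate: f64,
        default_completion_rate: f64,
    ) -> f64 {
        // Authoritative per-model costs from the registry win.
        if let Some(metadata) = registry.find_model(model) {
            if let Some(cost) = &metadata.cost {
                return (prompt_tokens as f64 * cost.input / 1_000_000.0)
                    + (completion_tokens as f64 * cost.output / 1_000_000.0);
            }
        }
        // Fall back to configured pricing, then to the caller's defaults.
        let (prompt_rate, completion_rate) = pricing
            .iter()
            .find(|p| model.contains(&p.model))
            .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
            .unwrap_or((default_prompt_rate, default_completion_rate));
        (prompt_tokens as f64 * prompt_rate / 1_000_000.0)
            + (completion_tokens as f64 * completion_rate / 1_000_000.0)
    }
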
    async fn chat_completion_stream(
        &self,
        request: UnifiedRequest,
    ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
        let mut body = serde_json::json!({
            "model": request.model,
            "messages": request.messages.iter().map(|m| {
                serde_json::json!({
                    "role": m.role,
                    "content": m.content.iter().map(|p| {
                        match p {
                            crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
                            crate::models::ContentPart::Image(image_input) => {
                                let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
                                serde_json::json!({
                                    "type": "image_url",
                                    "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
                                })
                            }
                        }
                    }).collect::<Vec<_>>()
                })
            }).collect::<Vec<_>>(),
            "stream": true,
        });
        let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
        let body = helpers::build_openai_body(&request, messages_json, true);

        if let Some(temp) = request.temperature {
            body["temperature"] = serde_json::json!(temp);
        }
        if let Some(max_tokens) = request.max_tokens {
            body["max_tokens"] = serde_json::json!(max_tokens);
        }
        let es = reqwest_eventsource::EventSource::new(
            self.client
                .post(format!("{}/chat/completions", self.config.base_url))
                .header("Authorization", format!("Bearer {}", self.api_key))
                .json(&body),
        )
        .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;

        // Create eventsource stream
        use reqwest_eventsource::{EventSource, Event};
        let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&body))
            .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;

        let model = request.model.clone();

        let stream = async_stream::try_stream! {
            let mut es = es;
            while let Some(event) = es.next().await {
                match event {
                    Ok(Event::Message(msg)) => {
                        if msg.data == "[DONE]" {
                            break;
                        }

                        let chunk: Value = serde_json::from_str(&msg.data)
                            .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;

                        if let Some(choice) = chunk["choices"].get(0) {
                            let delta = &choice["delta"];
                            let content = delta["content"].as_str().unwrap_or_default().to_string();
                            let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
                            let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());

                            yield ProviderStreamChunk {
                                content,
                                reasoning_content,
                                finish_reason,
                                model: model.clone(),
                            };
                        }
                    }
                    Ok(_) => continue,
                    Err(e) => {
                        Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
                    }
                }
            }
        };

        Ok(Box::pin(stream))
        Ok(helpers::create_openai_stream(es, request.model, None))
    }
}