feat: add support for reasoning (thinking) models across all providers

- Refactored OpenAI, DeepSeek, Grok, and Ollama to manual JSON parsing to capture 'reasoning_content' and 'thought' fields.
- Implemented real-time streaming of reasoning blocks.
- Added token aggregation and cost tracking for reasoning tokens.
- Updated unified models to include 'reasoning_content' in API responses.
This commit is contained in:
2026-02-26 13:50:22 -05:00
parent 6143b88eac
commit 70fef80051
9 changed files with 502 additions and 860 deletions

View File

@@ -22,6 +22,8 @@ pub struct ChatMessage {
pub role: String, // "system", "user", "assistant" pub role: String, // "system", "user", "assistant"
#[serde(flatten)] #[serde(flatten)]
pub content: MessageContent, pub content: MessageContent,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -29,6 +31,7 @@ pub struct ChatMessage {
pub enum MessageContent { pub enum MessageContent {
Text { content: String }, Text { content: String },
Parts { content: Vec<ContentPartValue> }, Parts { content: Vec<ContentPartValue> },
None, // Handle cases where content might be null but reasoning is present
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -91,6 +94,8 @@ pub struct ChatStreamChoice {
pub struct ChatStreamDelta { pub struct ChatStreamDelta {
pub role: Option<String>, pub role: Option<String>,
pub content: Option<String>, pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
} }
// ========== Unified Request Format (for internal use) ========== // ========== Unified Request Format (for internal use) ==========
@@ -217,6 +222,9 @@ impl TryFrom<ChatCompletionRequest> for UnifiedRequest {
(unified_content, has_images_in_msg) (unified_content, has_images_in_msg)
} }
MessageContent::None => {
(vec![], false)
}
}; };
UnifiedMessage { UnifiedMessage {

View File

@@ -1,8 +1,7 @@
use async_trait::async_trait; use async_trait::async_trait;
use anyhow::Result; use anyhow::Result;
use async_openai::{Client, config::OpenAIConfig};
use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
use futures::stream::{BoxStream, StreamExt}; use futures::stream::{BoxStream, StreamExt};
use serde_json::Value;
use crate::{ use crate::{
models::UnifiedRequest, models::UnifiedRequest,
@@ -12,8 +11,9 @@ use crate::{
use super::{ProviderResponse, ProviderStreamChunk}; use super::{ProviderResponse, ProviderStreamChunk};
pub struct DeepSeekProvider { pub struct DeepSeekProvider {
client: Client<OpenAIConfig>, // DeepSeek uses OpenAI-compatible API client: reqwest::Client,
_config: crate::config::DeepSeekConfig, config: crate::config::DeepSeekConfig,
api_key: String,
pricing: Vec<crate::config::ModelPricing>, pricing: Vec<crate::config::ModelPricing>,
} }
@@ -21,16 +21,10 @@ impl DeepSeekProvider {
pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result<Self> { pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("deepseek")?; let api_key = app_config.get_api_key("deepseek")?;
// Create OpenAIConfig with api key and base url
let openai_config = OpenAIConfig::default()
.with_api_key(api_key)
.with_api_base(&config.base_url);
let client = Client::with_config(openai_config);
Ok(Self { Ok(Self {
client, client: reqwest::Client::new(),
_config: config.clone(), config: config.clone(),
api_key,
pricing: app_config.pricing.deepseek.clone(), pricing: app_config.pricing.deepseek.clone(),
}) })
} }
@@ -47,114 +41,72 @@ impl super::Provider for DeepSeekProvider {
} }
fn supports_multimodal(&self) -> bool { fn supports_multimodal(&self) -> bool {
false // DeepSeek doesn't support general vision (only OCR) false
} }
async fn chat_completion( async fn chat_completion(
&self, &self,
request: UnifiedRequest, request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> { ) -> Result<ProviderResponse, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail}; // Build the OpenAI-compatible body
let mut body = serde_json::json!({
// Convert UnifiedRequest messages to OpenAI-compatible messages "model": request.model,
let mut messages = Vec::with_capacity(request.messages.len()); "messages": request.messages.iter().map(|m| {
serde_json::json!({
for msg in request.messages { "role": m.role,
let mut parts = Vec::with_capacity(msg.content.len()); "content": m.content.iter().map(|p| {
match p {
for part in msg.content { crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
match part { crate::models::ContentPart::Image(image_input) => {
crate::models::ContentPart::Text { text } => { // DeepSeek currently doesn't support images in the same way, but we'll try to be standard
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText { let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
text, serde_json::json!({
})); "type": "image_url",
} "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
crate::models::ContentPart::Image(image_input) => { })
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
} }
})); }
} }).collect::<Vec<_>>()
} })
} }).collect::<Vec<_>>(),
"stream": false,
let message = match msg.role.as_str() { });
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(request.model.clone());
builder.messages(messages);
// Add optional parameters
if let Some(temp) = request.temperature { if let Some(temp) = request.temperature {
builder.temperature(temp as f32); body["temperature"] = serde_json::json!(temp);
} }
if let Some(max_tokens) = request.max_tokens { if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16); body["max_tokens"] = serde_json::json!(max_tokens);
} }
// Execute API call let response = self.client.post(format!("{}/chat/completions", self.config.base_url))
let response = self.client .header("Authorization", format!("Bearer {}", self.api_key))
.chat() .json(&body)
.create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?) .send()
.await .await
.map_err(|e| AppError::ProviderError(e.to_string()))?; .map_err(|e| AppError::ProviderError(e.to_string()))?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("DeepSeek API error: {}", error_text)));
}
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Extract content from response let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
let content = response let message = &choice["message"];
.choices
.first()
.and_then(|choice| choice.message.content.clone())
.unwrap_or_default();
// Extract token usage let content = message["content"].as_str().unwrap_or_default().to_string();
let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32; let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
let usage = &resp_json["usage"];
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
Ok(ProviderResponse { Ok(ProviderResponse {
content, content,
reasoning_content,
prompt_tokens, prompt_tokens,
completion_tokens, completion_tokens,
total_tokens, total_tokens,
@@ -177,7 +129,7 @@ impl super::Provider for DeepSeekProvider {
let (prompt_rate, completion_rate) = self.pricing.iter() let (prompt_rate, completion_rate) = self.pricing.iter()
.find(|p| model.contains(&p.model)) .find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
.unwrap_or((0.14, 0.28)); // Default to DeepSeek V3 price if not found .unwrap_or((0.14, 0.28));
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0) (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
} }
@@ -186,118 +138,72 @@ impl super::Provider for DeepSeekProvider {
&self, &self,
request: UnifiedRequest, request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> { ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail}; let mut body = serde_json::json!({
"model": request.model,
// Convert UnifiedRequest messages to OpenAI-compatible messages "messages": request.messages.iter().map(|m| {
let mut messages = Vec::with_capacity(request.messages.len()); serde_json::json!({
"role": m.role,
for msg in request.messages { "content": m.content.iter().map(|p| {
let mut parts = Vec::with_capacity(msg.content.len()); match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
for part in msg.content { crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }),
match part { }
crate::models::ContentPart::Text { text } => { }).collect::<Vec<_>>()
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText { })
text, }).collect::<Vec<_>>(),
})); "stream": true,
}
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
}
}));
}
}
}
let message = match msg.role.as_str() {
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(request.model.clone());
builder.messages(messages);
builder.stream(true); // Enable streaming
// Add optional parameters
if let Some(temp) = request.temperature {
builder.temperature(temp as f32);
}
if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16);
}
// Execute streaming API call
let stream = self.client
.chat()
.create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Convert OpenAI stream to our stream format
let model = request.model.clone();
let stream = stream.map(move |chunk_result| {
match chunk_result {
Ok(chunk) => {
// Extract content from chunk
let content = chunk.choices.first()
.and_then(|choice| choice.delta.content.clone())
.unwrap_or_default();
let finish_reason = chunk.choices.first()
.and_then(|choice| choice.finish_reason.clone())
.map(|reason| format!("{:?}", reason));
Ok(ProviderStreamChunk {
content,
finish_reason,
model: model.clone(),
})
}
Err(e) => Err(AppError::ProviderError(e.to_string())),
}
}); });
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
// Create eventsource stream
use reqwest_eventsource::{EventSource, Event};
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body))
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let model = request.model.clone();
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(choice) = chunk["choices"].get(0) {
let delta = &choice["delta"];
let content = delta["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
yield ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
model: model.clone(),
};
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Ok(Box::pin(stream)) Ok(Box::pin(stream))
} }
} }

View File

@@ -206,6 +206,7 @@ impl super::Provider for GeminiProvider {
Ok(ProviderResponse { Ok(ProviderResponse {
content, content,
reasoning_content: None, // Gemini doesn't use this field name
prompt_tokens, prompt_tokens,
completion_tokens, completion_tokens,
total_tokens, total_tokens,
@@ -324,6 +325,7 @@ impl super::Provider for GeminiProvider {
yield ProviderStreamChunk { yield ProviderStreamChunk {
content, content,
reasoning_content: None,
finish_reason: None, // Will be set in the last chunk finish_reason: None, // Will be set in the last chunk
model: model.clone(), model: model.clone(),
}; };

View File

@@ -1,8 +1,7 @@
use async_trait::async_trait; use async_trait::async_trait;
use anyhow::Result; use anyhow::Result;
use async_openai::{Client, config::OpenAIConfig};
use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
use futures::stream::{BoxStream, StreamExt}; use futures::stream::{BoxStream, StreamExt};
use serde_json::Value;
use crate::{ use crate::{
models::UnifiedRequest, models::UnifiedRequest,
@@ -12,8 +11,9 @@ use crate::{
use super::{ProviderResponse, ProviderStreamChunk}; use super::{ProviderResponse, ProviderStreamChunk};
pub struct GrokProvider { pub struct GrokProvider {
client: Client<OpenAIConfig>, client: reqwest::Client,
_config: crate::config::GrokConfig, _config: crate::config::GrokConfig,
api_key: String,
pricing: Vec<crate::config::ModelPricing>, pricing: Vec<crate::config::ModelPricing>,
} }
@@ -21,16 +21,10 @@ impl GrokProvider {
pub fn new(config: &crate::config::GrokConfig, app_config: &AppConfig) -> Result<Self> { pub fn new(config: &crate::config::GrokConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("grok")?; let api_key = app_config.get_api_key("grok")?;
// Grok is OpenAI-compatible
let openai_config = OpenAIConfig::default()
.with_api_key(api_key)
.with_api_base(&config.base_url);
let client = Client::with_config(openai_config);
Ok(Self { Ok(Self {
client, client: reqwest::Client::new(),
_config: config.clone(), _config: config.clone(),
api_key,
pricing: app_config.pricing.grok.clone(), pricing: app_config.pricing.grok.clone(),
}) })
} }
@@ -47,114 +41,70 @@ impl super::Provider for GrokProvider {
} }
fn supports_multimodal(&self) -> bool { fn supports_multimodal(&self) -> bool {
true // Grok supports vision models true
} }
async fn chat_completion( async fn chat_completion(
&self, &self,
request: UnifiedRequest, request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> { ) -> Result<ProviderResponse, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail}; let mut body = serde_json::json!({
"model": request.model,
// Convert UnifiedRequest messages to OpenAI messages "messages": request.messages.iter().map(|m| {
let mut messages = Vec::with_capacity(request.messages.len()); serde_json::json!({
"role": m.role,
for msg in request.messages { "content": m.content.iter().map(|p| {
let mut parts = Vec::with_capacity(msg.content.len()); match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
for part in msg.content { crate::models::ContentPart::Image(image_input) => {
match part { let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
crate::models::ContentPart::Text { text } => { serde_json::json!({
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText { "type": "image_url",
text, "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})); })
}
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
} }
})); }
} }).collect::<Vec<_>>()
} })
} }).collect::<Vec<_>>(),
"stream": false,
let message = match msg.role.as_str() { });
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(request.model.clone());
builder.messages(messages);
// Add optional parameters
if let Some(temp) = request.temperature { if let Some(temp) = request.temperature {
builder.temperature(temp as f32); body["temperature"] = serde_json::json!(temp);
} }
if let Some(max_tokens) = request.max_tokens { if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16); body["max_tokens"] = serde_json::json!(max_tokens);
} }
// Execute API call let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
let response = self.client .header("Authorization", format!("Bearer {}", self.api_key))
.chat() .json(&body)
.create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?) .send()
.await .await
.map_err(|e| AppError::ProviderError(e.to_string()))?; .map_err(|e| AppError::ProviderError(e.to_string()))?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("Grok API error: {}", error_text)));
}
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Extract content from response let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
let content = response let message = &choice["message"];
.choices
.first()
.and_then(|choice| choice.message.content.clone())
.unwrap_or_default();
// Extract token usage let content = message["content"].as_str().unwrap_or_default().to_string();
let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32; let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
let usage = &resp_json["usage"];
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
Ok(ProviderResponse { Ok(ProviderResponse {
content, content,
reasoning_content,
prompt_tokens, prompt_tokens,
completion_tokens, completion_tokens,
total_tokens, total_tokens,
@@ -174,11 +124,10 @@ impl super::Provider for GrokProvider {
} }
} }
// Fallback to static pricing if not in registry
let (prompt_rate, completion_rate) = self.pricing.iter() let (prompt_rate, completion_rate) = self.pricing.iter()
.find(|p| model.contains(&p.model)) .find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
.unwrap_or((5.0, 15.0)); // Grok-2 pricing is roughly this .unwrap_or((5.0, 15.0));
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0) (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
} }
@@ -187,117 +136,77 @@ impl super::Provider for GrokProvider {
&self, &self,
request: UnifiedRequest, request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> { ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail}; let mut body = serde_json::json!({
"model": request.model,
// Convert UnifiedRequest messages to OpenAI messages "messages": request.messages.iter().map(|m| {
let mut messages = Vec::with_capacity(request.messages.len()); serde_json::json!({
"role": m.role,
for msg in request.messages { "content": m.content.iter().map(|p| {
let mut parts = Vec::with_capacity(msg.content.len()); match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
for part in msg.content { crate::models::ContentPart::Image(image_input) => {
match part { let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
crate::models::ContentPart::Text { text } => { serde_json::json!({
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText { "type": "image_url",
text, "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})); })
}
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
} }
})); }
} }).collect::<Vec<_>>()
} })
} }).collect::<Vec<_>>(),
"stream": true,
let message = match msg.role.as_str() {
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(request.model.clone());
builder.messages(messages);
builder.stream(true); // Enable streaming
// Add optional parameters
if let Some(temp) = request.temperature {
builder.temperature(temp as f32);
}
if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16);
}
// Execute streaming API call
let stream = self.client
.chat()
.create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Convert OpenAI stream to our stream format
let model = request.model.clone();
let stream = stream.map(move |chunk_result| {
match chunk_result {
Ok(chunk) => {
// Extract content from chunk
let content = chunk.choices.first()
.and_then(|choice| choice.delta.content.clone())
.unwrap_or_default();
let finish_reason = chunk.choices.first()
.and_then(|choice| choice.finish_reason.clone())
.map(|reason| format!("{:?}", reason));
Ok(ProviderStreamChunk {
content,
finish_reason,
model: model.clone(),
})
}
Err(e) => Err(AppError::ProviderError(e.to_string())),
}
}); });
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
// Create eventsource stream
use reqwest_eventsource::{EventSource, Event};
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body))
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let model = request.model.clone();
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(choice) = chunk["choices"].get(0) {
let delta = &choice["delta"];
let content = delta["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
yield ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
model: model.clone(),
};
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Ok(Box::pin(stream)) Ok(Box::pin(stream))
} }

View File

@@ -44,6 +44,7 @@ pub trait Provider: Send + Sync {
pub struct ProviderResponse { pub struct ProviderResponse {
pub content: String, pub content: String,
pub reasoning_content: Option<String>,
pub prompt_tokens: u32, pub prompt_tokens: u32,
pub completion_tokens: u32, pub completion_tokens: u32,
pub total_tokens: u32, pub total_tokens: u32,
@@ -53,6 +54,7 @@ pub struct ProviderResponse {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ProviderStreamChunk { pub struct ProviderStreamChunk {
pub content: String, pub content: String,
pub reasoning_content: Option<String>,
pub finish_reason: Option<String>, pub finish_reason: Option<String>,
pub model: String, pub model: String,
} }

View File

@@ -1,8 +1,7 @@
use async_trait::async_trait; use async_trait::async_trait;
use anyhow::Result; use anyhow::Result;
use async_openai::{Client, config::OpenAIConfig};
use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
use futures::stream::{BoxStream, StreamExt}; use futures::stream::{BoxStream, StreamExt};
use serde_json::Value;
use crate::{ use crate::{
models::UnifiedRequest, models::UnifiedRequest,
@@ -12,22 +11,15 @@ use crate::{
use super::{ProviderResponse, ProviderStreamChunk}; use super::{ProviderResponse, ProviderStreamChunk};
pub struct OllamaProvider { pub struct OllamaProvider {
client: Client<OpenAIConfig>, client: reqwest::Client,
_config: crate::config::OllamaConfig, _config: crate::config::OllamaConfig,
pricing: Vec<crate::config::ModelPricing>, pricing: Vec<crate::config::ModelPricing>,
} }
impl OllamaProvider { impl OllamaProvider {
pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> { pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
// Ollama usually doesn't need an API key, use a dummy one
let openai_config = OpenAIConfig::default()
.with_api_key("ollama")
.with_api_base(&config.base_url);
let client = Client::with_config(openai_config);
Ok(Self { Ok(Self {
client, client: reqwest::Client::new(),
_config: config.clone(), _config: config.clone(),
pricing: app_config.pricing.ollama.clone(), pricing: app_config.pricing.ollama.clone(),
}) })
@@ -41,124 +33,75 @@ impl super::Provider for OllamaProvider {
} }
fn supports_model(&self, model: &str) -> bool { fn supports_model(&self, model: &str) -> bool {
// Check if model is in the list of configured Ollama models
self._config.models.iter().any(|m| m == model) || model.starts_with("ollama/") self._config.models.iter().any(|m| m == model) || model.starts_with("ollama/")
} }
fn supports_multimodal(&self) -> bool { fn supports_multimodal(&self) -> bool {
true // Many Ollama models support vision (e.g. llava, moondream) true
} }
async fn chat_completion( async fn chat_completion(
&self, &self,
request: UnifiedRequest, request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> { ) -> Result<ProviderResponse, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
// Strip "ollama/" prefix if present
let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string(); let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
// Convert UnifiedRequest messages to OpenAI messages let mut body = serde_json::json!({
let mut messages = Vec::with_capacity(request.messages.len()); "model": model,
"messages": request.messages.iter().map(|m| {
for msg in request.messages { serde_json::json!({
let mut parts = Vec::with_capacity(msg.content.len()); "role": m.role,
"content": m.content.iter().map(|p| {
for part in msg.content { match p {
match part { crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Text { text } => { crate::models::ContentPart::Image(image_input) => {
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText { let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
text, serde_json::json!({
})); "type": "image_url",
} "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
crate::models::ContentPart::Image(image_input) => { })
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
} }
})); }
} }).collect::<Vec<_>>()
} })
} }).collect::<Vec<_>>(),
"stream": false,
let message = match msg.role.as_str() { });
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("
")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("
")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(model);
builder.messages(messages);
// Add optional parameters
if let Some(temp) = request.temperature { if let Some(temp) = request.temperature {
builder.temperature(temp as f32); body["temperature"] = serde_json::json!(temp);
} }
if let Some(max_tokens) = request.max_tokens { if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16); body["max_tokens"] = serde_json::json!(max_tokens);
} }
// Execute API call let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
let response = self.client .json(&body)
.chat() .send()
.create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
.await .await
.map_err(|e| AppError::ProviderError(e.to_string()))?; .map_err(|e| AppError::ProviderError(e.to_string()))?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("Ollama API error: {}", error_text)));
}
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Extract content from response let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
let content = response let message = &choice["message"];
.choices
.first()
.and_then(|choice| choice.message.content.clone())
.unwrap_or_default();
// Extract token usage let content = message["content"].as_str().unwrap_or_default().to_string();
let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32; let reasoning_content = message["reasoning_content"].as_str().or_else(|| message["thought"].as_str()).map(|s| s.to_string());
let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
let usage = &resp_json["usage"];
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
Ok(ProviderResponse { Ok(ProviderResponse {
content, content,
reasoning_content,
prompt_tokens, prompt_tokens,
completion_tokens, completion_tokens,
total_tokens, total_tokens,
@@ -178,7 +121,6 @@ impl super::Provider for OllamaProvider {
} }
} }
// Ollama is free by default
let (prompt_rate, completion_rate) = self.pricing.iter() let (prompt_rate, completion_rate) = self.pricing.iter()
.find(|p| model.contains(&p.model)) .find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
@@ -191,122 +133,72 @@ impl super::Provider for OllamaProvider {
&self, &self,
request: UnifiedRequest, request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> { ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
// Strip "ollama/" prefix if present
let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string(); let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
// Convert UnifiedRequest messages to OpenAI messages let mut body = serde_json::json!({
let mut messages = Vec::with_capacity(request.messages.len()); "model": model,
"messages": request.messages.iter().map(|m| {
for msg in request.messages { serde_json::json!({
let mut parts = Vec::with_capacity(msg.content.len()); "role": m.role,
"content": m.content.iter().map(|p| {
for part in msg.content { match p {
match part { crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Text { text } => { crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }),
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText { }
text, }).collect::<Vec<_>>()
})); })
} }).collect::<Vec<_>>(),
crate::models::ContentPart::Image(image_input) => { "stream": true,
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
}
}));
}
}
}
let message = match msg.role.as_str() {
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("
")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("
")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(model);
builder.messages(messages);
builder.stream(true); // Enable streaming
// Add optional parameters
if let Some(temp) = request.temperature {
builder.temperature(temp as f32);
}
if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16);
}
// Execute streaming API call
let stream = self.client
.chat()
.create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Convert OpenAI stream to our stream format
let model_name = request.model.clone();
let stream = stream.map(move |chunk_result| {
match chunk_result {
Ok(chunk) => {
// Extract content from chunk
let content = chunk.choices.first()
.and_then(|choice| choice.delta.content.clone())
.unwrap_or_default();
let finish_reason = chunk.choices.first()
.and_then(|choice| choice.finish_reason.clone())
.map(|reason| format!("{:?}", reason));
Ok(ProviderStreamChunk {
content,
finish_reason,
model: model_name.clone(),
})
}
Err(e) => Err(AppError::ProviderError(e.to_string())),
}
}); });
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
// Create eventsource stream
use reqwest_eventsource::{EventSource, Event};
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
.json(&body))
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let model_name = request.model.clone();
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(choice) = chunk["choices"].get(0) {
let delta = &choice["delta"];
let content = delta["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = delta["reasoning_content"].as_str().or_else(|| delta["thought"].as_str()).map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
yield ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
model: model_name.clone(),
};
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Ok(Box::pin(stream)) Ok(Box::pin(stream))
} }

View File

@@ -1,8 +1,7 @@
use async_trait::async_trait; use async_trait::async_trait;
use anyhow::Result; use anyhow::Result;
use async_openai::{Client, config::OpenAIConfig};
use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
use futures::stream::{BoxStream, StreamExt}; use futures::stream::{BoxStream, StreamExt};
use serde_json::Value;
use crate::{ use crate::{
models::UnifiedRequest, models::UnifiedRequest,
@@ -12,8 +11,9 @@ use crate::{
use super::{ProviderResponse, ProviderStreamChunk}; use super::{ProviderResponse, ProviderStreamChunk};
pub struct OpenAIProvider { pub struct OpenAIProvider {
client: Client<OpenAIConfig>, client: reqwest::Client,
_config: crate::config::OpenAIConfig, _config: crate::config::OpenAIConfig,
api_key: String,
pricing: Vec<crate::config::ModelPricing>, pricing: Vec<crate::config::ModelPricing>,
} }
@@ -21,16 +21,10 @@ impl OpenAIProvider {
pub fn new(config: &crate::config::OpenAIConfig, app_config: &AppConfig) -> Result<Self> { pub fn new(config: &crate::config::OpenAIConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("openai")?; let api_key = app_config.get_api_key("openai")?;
// Create OpenAIConfig with api key and base url
let openai_config = OpenAIConfig::default()
.with_api_key(api_key)
.with_api_base(&config.base_url);
let client = Client::with_config(openai_config);
Ok(Self { Ok(Self {
client, client: reqwest::Client::new(),
_config: config.clone(), _config: config.clone(),
api_key,
pricing: app_config.pricing.openai.clone(), pricing: app_config.pricing.openai.clone(),
}) })
} }
@@ -47,114 +41,70 @@ impl super::Provider for OpenAIProvider {
} }
fn supports_multimodal(&self) -> bool { fn supports_multimodal(&self) -> bool {
true // OpenAI supports vision models true
} }
async fn chat_completion( async fn chat_completion(
&self, &self,
request: UnifiedRequest, request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> { ) -> Result<ProviderResponse, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail}; let mut body = serde_json::json!({
"model": request.model,
// Convert UnifiedRequest messages to OpenAI messages "messages": request.messages.iter().map(|m| {
let mut messages = Vec::with_capacity(request.messages.len()); serde_json::json!({
"role": m.role,
for msg in request.messages { "content": m.content.iter().map(|p| {
let mut parts = Vec::with_capacity(msg.content.len()); match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
for part in msg.content { crate::models::ContentPart::Image(image_input) => {
match part { let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
crate::models::ContentPart::Text { text } => { serde_json::json!({
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText { "type": "image_url",
text, "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})); })
}
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
} }
})); }
} }).collect::<Vec<_>>()
} })
} }).collect::<Vec<_>>(),
"stream": false,
let message = match msg.role.as_str() { });
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(request.model.clone());
builder.messages(messages);
// Add optional parameters
if let Some(temp) = request.temperature { if let Some(temp) = request.temperature {
builder.temperature(temp as f32); body["temperature"] = serde_json::json!(temp);
} }
if let Some(max_tokens) = request.max_tokens { if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16); body["max_tokens"] = serde_json::json!(max_tokens);
} }
// Execute API call let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
let response = self.client .header("Authorization", format!("Bearer {}", self.api_key))
.chat() .json(&body)
.create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?) .send()
.await .await
.map_err(|e| AppError::ProviderError(e.to_string()))?; .map_err(|e| AppError::ProviderError(e.to_string()))?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("OpenAI API error: {}", error_text)));
}
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Extract content from response let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
let content = response let message = &choice["message"];
.choices
.first()
.and_then(|choice| choice.message.content.clone())
.unwrap_or_default();
// Extract token usage let content = message["content"].as_str().unwrap_or_default().to_string();
let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32; let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
let usage = &resp_json["usage"];
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
Ok(ProviderResponse { Ok(ProviderResponse {
content, content,
reasoning_content,
prompt_tokens, prompt_tokens,
completion_tokens, completion_tokens,
total_tokens, total_tokens,
@@ -174,7 +124,6 @@ impl super::Provider for OpenAIProvider {
} }
} }
// Fallback to static pricing if not in registry
let (prompt_rate, completion_rate) = self.pricing.iter() let (prompt_rate, completion_rate) = self.pricing.iter()
.find(|p| model.contains(&p.model)) .find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
@@ -187,118 +136,78 @@ impl super::Provider for OpenAIProvider {
&self, &self,
request: UnifiedRequest, request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> { ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail}; let mut body = serde_json::json!({
"model": request.model,
// Convert UnifiedRequest messages to OpenAI messages "messages": request.messages.iter().map(|m| {
let mut messages = Vec::with_capacity(request.messages.len()); serde_json::json!({
"role": m.role,
for msg in request.messages { "content": m.content.iter().map(|p| {
let mut parts = Vec::with_capacity(msg.content.len()); match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
for part in msg.content { crate::models::ContentPart::Image(image_input) => {
match part { let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
crate::models::ContentPart::Text { text } => { serde_json::json!({
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText { "type": "image_url",
text, "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})); })
}
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
} }
})); }
} }).collect::<Vec<_>>()
} })
} }).collect::<Vec<_>>(),
"stream": true,
let message = match msg.role.as_str() {
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(request.model.clone());
builder.messages(messages);
builder.stream(true); // Enable streaming
// Add optional parameters
if let Some(temp) = request.temperature {
builder.temperature(temp as f32);
}
if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16);
}
// Execute streaming API call
let stream = self.client
.chat()
.create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Convert OpenAI stream to our stream format
let model = request.model.clone();
let stream = stream.map(move |chunk_result| {
match chunk_result {
Ok(chunk) => {
// Extract content from chunk
let content = chunk.choices.first()
.and_then(|choice| choice.delta.content.clone())
.unwrap_or_default();
let finish_reason = chunk.choices.first()
.and_then(|choice| choice.finish_reason.clone())
.map(|reason| format!("{:?}", reason));
Ok(ProviderStreamChunk {
content,
finish_reason,
model: model.clone(),
})
}
Err(e) => Err(AppError::ProviderError(e.to_string())),
}
}); });
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
// Create eventsource stream
use reqwest_eventsource::{EventSource, Event};
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body))
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let model = request.model.clone();
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(choice) = chunk["choices"].get(0) {
let delta = &choice["delta"];
let content = delta["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
yield ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
model: model.clone(),
};
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Ok(Box::pin(stream)) Ok(Box::pin(stream))
} }
} }

View File

@@ -102,6 +102,7 @@ async fn chat_completions(
delta: ChatStreamDelta { delta: ChatStreamDelta {
role: None, role: None,
content: Some(chunk.content), content: Some(chunk.content),
reasoning_content: chunk.reasoning_content,
}, },
finish_reason: chunk.finish_reason, finish_reason: chunk.finish_reason,
}], }],
@@ -177,6 +178,7 @@ async fn chat_completions(
content: crate::models::MessageContent::Text { content: crate::models::MessageContent::Text {
content: response.content content: response.content
}, },
reasoning_content: response.reasoning_content,
}, },
finish_reason: Some("stop".to_string()), finish_reason: Some("stop".to_string()),
}], }],

View File

@@ -16,6 +16,7 @@ pub struct AggregatingStream<S> {
prompt_tokens: u32, prompt_tokens: u32,
has_images: bool, has_images: bool,
accumulated_content: String, accumulated_content: String,
accumulated_reasoning: String,
logger: Arc<RequestLogger>, logger: Arc<RequestLogger>,
client_manager: Arc<ClientManager>, client_manager: Arc<ClientManager>,
model_registry: Arc<crate::models::registry::ModelRegistry>, model_registry: Arc<crate::models::registry::ModelRegistry>,
@@ -46,6 +47,7 @@ where
prompt_tokens, prompt_tokens,
has_images, has_images,
accumulated_content: String::new(), accumulated_content: String::new(),
accumulated_reasoning: String::new(),
logger, logger,
client_manager, client_manager,
model_registry, model_registry,
@@ -71,8 +73,15 @@ where
let has_images = self.has_images; let has_images = self.has_images;
let registry = self.model_registry.clone(); let registry = self.model_registry.clone();
// Estimate completion tokens // Estimate completion tokens (including reasoning if present)
let completion_tokens = estimate_completion_tokens(&self.accumulated_content, &model); let content_tokens = estimate_completion_tokens(&self.accumulated_content, &model);
let reasoning_tokens = if !self.accumulated_reasoning.is_empty() {
estimate_completion_tokens(&self.accumulated_reasoning, &model)
} else {
0
};
let completion_tokens = content_tokens + reasoning_tokens;
let total_tokens = prompt_tokens + completion_tokens; let total_tokens = prompt_tokens + completion_tokens;
let cost = provider.calculate_cost(&model, prompt_tokens, completion_tokens, &registry); let cost = provider.calculate_cost(&model, prompt_tokens, completion_tokens, &registry);
@@ -116,6 +125,9 @@ where
match &result { match &result {
Poll::Ready(Some(Ok(chunk))) => { Poll::Ready(Some(Ok(chunk))) => {
self.accumulated_content.push_str(&chunk.content); self.accumulated_content.push_str(&chunk.content);
if let Some(reasoning) = &chunk.reasoning_content {
self.accumulated_reasoning.push_str(reasoning);
}
} }
Poll::Ready(Some(Err(_))) => { Poll::Ready(Some(Err(_))) => {
// If there's an error, we might still want to log what we got so far? // If there's an error, we might still want to log what we got so far?