feat: add support for reasoning (thinking) models across all providers

- Refactored OpenAI, DeepSeek, Grok, and Ollama to manual JSON parsing to capture 'reasoning_content' and 'thought' fields.
- Implemented real-time streaming of reasoning blocks.
- Added token aggregation and cost tracking for reasoning tokens.
- Updated unified models to include 'reasoning_content' in API responses.
This commit is contained in:
2026-02-26 13:50:22 -05:00
parent 6143b88eac
commit 70fef80051
9 changed files with 502 additions and 860 deletions

View File

@@ -22,6 +22,8 @@ pub struct ChatMessage {
pub role: String, // "system", "user", "assistant"
#[serde(flatten)]
pub content: MessageContent,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -29,6 +31,7 @@ pub struct ChatMessage {
pub enum MessageContent {
Text { content: String },
Parts { content: Vec<ContentPartValue> },
None, // Handle cases where content might be null but reasoning is present
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -91,6 +94,8 @@ pub struct ChatStreamChoice {
pub struct ChatStreamDelta {
pub role: Option<String>,
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
}
// ========== Unified Request Format (for internal use) ==========
@@ -217,6 +222,9 @@ impl TryFrom<ChatCompletionRequest> for UnifiedRequest {
(unified_content, has_images_in_msg)
}
MessageContent::None => {
(vec![], false)
}
};
UnifiedMessage {

View File

@@ -1,8 +1,7 @@
use async_trait::async_trait;
use anyhow::Result;
use async_openai::{Client, config::OpenAIConfig};
use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
use futures::stream::{BoxStream, StreamExt};
use serde_json::Value;
use crate::{
models::UnifiedRequest,
@@ -12,8 +11,9 @@ use crate::{
use super::{ProviderResponse, ProviderStreamChunk};
pub struct DeepSeekProvider {
client: Client<OpenAIConfig>, // DeepSeek uses OpenAI-compatible API
_config: crate::config::DeepSeekConfig,
client: reqwest::Client,
config: crate::config::DeepSeekConfig,
api_key: String,
pricing: Vec<crate::config::ModelPricing>,
}
@@ -21,16 +21,10 @@ impl DeepSeekProvider {
pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("deepseek")?;
// Create OpenAIConfig with api key and base url
let openai_config = OpenAIConfig::default()
.with_api_key(api_key)
.with_api_base(&config.base_url);
let client = Client::with_config(openai_config);
Ok(Self {
client,
_config: config.clone(),
client: reqwest::Client::new(),
config: config.clone(),
api_key,
pricing: app_config.pricing.deepseek.clone(),
})
}
@@ -47,114 +41,72 @@ impl super::Provider for DeepSeekProvider {
}
fn supports_multimodal(&self) -> bool {
false // DeepSeek doesn't support general vision (only OCR)
false
}
async fn chat_completion(
&self,
request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
// Convert UnifiedRequest messages to OpenAI-compatible messages
let mut messages = Vec::with_capacity(request.messages.len());
for msg in request.messages {
let mut parts = Vec::with_capacity(msg.content.len());
for part in msg.content {
match part {
crate::models::ContentPart::Text { text } => {
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
text,
}));
}
// Build the OpenAI-compatible body
let mut body = serde_json::json!({
"model": request.model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
// DeepSeek currently doesn't support images in the same way, but we'll try to be standard
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})
}
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": false,
});
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
}
}));
}
}
}
let message = match msg.role.as_str() {
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(request.model.clone());
builder.messages(messages);
// Add optional parameters
if let Some(temp) = request.temperature {
builder.temperature(temp as f32);
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16);
body["max_tokens"] = serde_json::json!(max_tokens);
}
// Execute API call
let response = self.client
.chat()
.create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
let response = self.client.post(format!("{}/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Extract content from response
let content = response
.choices
.first()
.and_then(|choice| choice.message.content.clone())
.unwrap_or_default();
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("DeepSeek API error: {}", error_text)));
}
// Extract token usage
let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
let message = &choice["message"];
let content = message["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
let usage = &resp_json["usage"];
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
Ok(ProviderResponse {
content,
reasoning_content,
prompt_tokens,
completion_tokens,
total_tokens,
@@ -177,7 +129,7 @@ impl super::Provider for DeepSeekProvider {
let (prompt_rate, completion_rate) = self.pricing.iter()
.find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
.unwrap_or((0.14, 0.28)); // Default to DeepSeek V3 price if not found
.unwrap_or((0.14, 0.28));
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
}
@@ -186,117 +138,71 @@ impl super::Provider for DeepSeekProvider {
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
let mut body = serde_json::json!({
"model": request.model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }),
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": true,
});
// Convert UnifiedRequest messages to OpenAI-compatible messages
let mut messages = Vec::with_capacity(request.messages.len());
for msg in request.messages {
let mut parts = Vec::with_capacity(msg.content.len());
for part in msg.content {
match part {
crate::models::ContentPart::Text { text } => {
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
text,
}));
}
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
}
}));
}
}
}
let message = match msg.role.as_str() {
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(request.model.clone());
builder.messages(messages);
builder.stream(true); // Enable streaming
// Add optional parameters
if let Some(temp) = request.temperature {
builder.temperature(temp as f32);
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16);
body["max_tokens"] = serde_json::json!(max_tokens);
}
// Execute streaming API call
let stream = self.client
.chat()
.create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Create eventsource stream
use reqwest_eventsource::{EventSource, Event};
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body))
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
// Convert OpenAI stream to our stream format
let model = request.model.clone();
let stream = stream.map(move |chunk_result| {
match chunk_result {
Ok(chunk) => {
// Extract content from chunk
let content = chunk.choices.first()
.and_then(|choice| choice.delta.content.clone())
.unwrap_or_default();
let finish_reason = chunk.choices.first()
.and_then(|choice| choice.finish_reason.clone())
.map(|reason| format!("{:?}", reason));
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
Ok(ProviderStreamChunk {
let chunk: Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(choice) = chunk["choices"].get(0) {
let delta = &choice["delta"];
let content = delta["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
yield ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
model: model.clone(),
})
};
}
Err(e) => Err(AppError::ProviderError(e.to_string())),
}
});
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Ok(Box::pin(stream))
}

View File

@@ -206,6 +206,7 @@ impl super::Provider for GeminiProvider {
Ok(ProviderResponse {
content,
reasoning_content: None, // Gemini doesn't use this field name
prompt_tokens,
completion_tokens,
total_tokens,
@@ -324,6 +325,7 @@ impl super::Provider for GeminiProvider {
yield ProviderStreamChunk {
content,
reasoning_content: None,
finish_reason: None, // Will be set in the last chunk
model: model.clone(),
};

View File

@@ -1,8 +1,7 @@
use async_trait::async_trait;
use anyhow::Result;
use async_openai::{Client, config::OpenAIConfig};
use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
use futures::stream::{BoxStream, StreamExt};
use serde_json::Value;
use crate::{
models::UnifiedRequest,
@@ -12,8 +11,9 @@ use crate::{
use super::{ProviderResponse, ProviderStreamChunk};
pub struct GrokProvider {
client: Client<OpenAIConfig>,
client: reqwest::Client,
_config: crate::config::GrokConfig,
api_key: String,
pricing: Vec<crate::config::ModelPricing>,
}
@@ -21,16 +21,10 @@ impl GrokProvider {
pub fn new(config: &crate::config::GrokConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("grok")?;
// Grok is OpenAI-compatible
let openai_config = OpenAIConfig::default()
.with_api_key(api_key)
.with_api_base(&config.base_url);
let client = Client::with_config(openai_config);
Ok(Self {
client,
client: reqwest::Client::new(),
_config: config.clone(),
api_key,
pricing: app_config.pricing.grok.clone(),
})
}
@@ -47,114 +41,70 @@ impl super::Provider for GrokProvider {
}
fn supports_multimodal(&self) -> bool {
true // Grok supports vision models
true
}
async fn chat_completion(
&self,
request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
// Convert UnifiedRequest messages to OpenAI messages
let mut messages = Vec::with_capacity(request.messages.len());
for msg in request.messages {
let mut parts = Vec::with_capacity(msg.content.len());
for part in msg.content {
match part {
crate::models::ContentPart::Text { text } => {
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
text,
}));
}
let mut body = serde_json::json!({
"model": request.model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})
}
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": false,
});
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
}
}));
}
}
}
let message = match msg.role.as_str() {
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(request.model.clone());
builder.messages(messages);
// Add optional parameters
if let Some(temp) = request.temperature {
builder.temperature(temp as f32);
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16);
body["max_tokens"] = serde_json::json!(max_tokens);
}
// Execute API call
let response = self.client
.chat()
.create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Extract content from response
let content = response
.choices
.first()
.and_then(|choice| choice.message.content.clone())
.unwrap_or_default();
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("Grok API error: {}", error_text)));
}
// Extract token usage
let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
let message = &choice["message"];
let content = message["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
let usage = &resp_json["usage"];
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
Ok(ProviderResponse {
content,
reasoning_content,
prompt_tokens,
completion_tokens,
total_tokens,
@@ -174,11 +124,10 @@ impl super::Provider for GrokProvider {
}
}
// Fallback to static pricing if not in registry
let (prompt_rate, completion_rate) = self.pricing.iter()
.find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
.unwrap_or((5.0, 15.0)); // Grok-2 pricing is roughly this
.unwrap_or((5.0, 15.0));
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
}
@@ -187,118 +136,78 @@ impl super::Provider for GrokProvider {
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
// Convert UnifiedRequest messages to OpenAI messages
let mut messages = Vec::with_capacity(request.messages.len());
for msg in request.messages {
let mut parts = Vec::with_capacity(msg.content.len());
for part in msg.content {
match part {
crate::models::ContentPart::Text { text } => {
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
text,
}));
}
let mut body = serde_json::json!({
"model": request.model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
}
}));
}
}
}
let message = match msg.role.as_str() {
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(request.model.clone());
builder.messages(messages);
builder.stream(true); // Enable streaming
// Add optional parameters
if let Some(temp) = request.temperature {
builder.temperature(temp as f32);
}
if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16);
}
// Execute streaming API call
let stream = self.client
.chat()
.create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Convert OpenAI stream to our stream format
let model = request.model.clone();
let stream = stream.map(move |chunk_result| {
match chunk_result {
Ok(chunk) => {
// Extract content from chunk
let content = chunk.choices.first()
.and_then(|choice| choice.delta.content.clone())
.unwrap_or_default();
let finish_reason = chunk.choices.first()
.and_then(|choice| choice.finish_reason.clone())
.map(|reason| format!("{:?}", reason));
Ok(ProviderStreamChunk {
content,
finish_reason,
model: model.clone(),
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})
}
Err(e) => Err(AppError::ProviderError(e.to_string())),
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": true,
});
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
// Create eventsource stream
use reqwest_eventsource::{EventSource, Event};
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body))
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let model = request.model.clone();
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(choice) = chunk["choices"].get(0) {
let delta = &choice["delta"];
let content = delta["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
yield ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
model: model.clone(),
};
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Ok(Box::pin(stream))
}
}

View File

@@ -44,6 +44,7 @@ pub trait Provider: Send + Sync {
pub struct ProviderResponse {
pub content: String,
pub reasoning_content: Option<String>,
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
@@ -53,6 +54,7 @@ pub struct ProviderResponse {
#[derive(Debug, Clone)]
pub struct ProviderStreamChunk {
pub content: String,
pub reasoning_content: Option<String>,
pub finish_reason: Option<String>,
pub model: String,
}

View File

@@ -1,8 +1,7 @@
use async_trait::async_trait;
use anyhow::Result;
use async_openai::{Client, config::OpenAIConfig};
use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
use futures::stream::{BoxStream, StreamExt};
use serde_json::Value;
use crate::{
models::UnifiedRequest,
@@ -12,22 +11,15 @@ use crate::{
use super::{ProviderResponse, ProviderStreamChunk};
pub struct OllamaProvider {
client: Client<OpenAIConfig>,
client: reqwest::Client,
_config: crate::config::OllamaConfig,
pricing: Vec<crate::config::ModelPricing>,
}
impl OllamaProvider {
pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
// Ollama usually doesn't need an API key, use a dummy one
let openai_config = OpenAIConfig::default()
.with_api_key("ollama")
.with_api_base(&config.base_url);
let client = Client::with_config(openai_config);
Ok(Self {
client,
client: reqwest::Client::new(),
_config: config.clone(),
pricing: app_config.pricing.ollama.clone(),
})
@@ -41,124 +33,75 @@ impl super::Provider for OllamaProvider {
}
fn supports_model(&self, model: &str) -> bool {
// Check if model is in the list of configured Ollama models
self._config.models.iter().any(|m| m == model) || model.starts_with("ollama/")
}
fn supports_multimodal(&self) -> bool {
true // Many Ollama models support vision (e.g. llava, moondream)
true
}
async fn chat_completion(
&self,
request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
// Strip "ollama/" prefix if present
let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
// Convert UnifiedRequest messages to OpenAI messages
let mut messages = Vec::with_capacity(request.messages.len());
for msg in request.messages {
let mut parts = Vec::with_capacity(msg.content.len());
for part in msg.content {
match part {
crate::models::ContentPart::Text { text } => {
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
text,
}));
}
let mut body = serde_json::json!({
"model": model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})
}
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": false,
});
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
}
}));
}
}
}
let message = match msg.role.as_str() {
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("
")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("
")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(model);
builder.messages(messages);
// Add optional parameters
if let Some(temp) = request.temperature {
builder.temperature(temp as f32);
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16);
body["max_tokens"] = serde_json::json!(max_tokens);
}
// Execute API call
let response = self.client
.chat()
.create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
.json(&body)
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Extract content from response
let content = response
.choices
.first()
.and_then(|choice| choice.message.content.clone())
.unwrap_or_default();
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("Ollama API error: {}", error_text)));
}
// Extract token usage
let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
let message = &choice["message"];
let content = message["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = message["reasoning_content"].as_str().or_else(|| message["thought"].as_str()).map(|s| s.to_string());
let usage = &resp_json["usage"];
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
Ok(ProviderResponse {
content,
reasoning_content,
prompt_tokens,
completion_tokens,
total_tokens,
@@ -178,7 +121,6 @@ impl super::Provider for OllamaProvider {
}
}
// Ollama is free by default
let (prompt_rate, completion_rate) = self.pricing.iter()
.find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
@@ -191,122 +133,72 @@ impl super::Provider for OllamaProvider {
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
// Strip "ollama/" prefix if present
let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
// Convert UnifiedRequest messages to OpenAI messages
let mut messages = Vec::with_capacity(request.messages.len());
let mut body = serde_json::json!({
"model": model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }),
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": true,
});
for msg in request.messages {
let mut parts = Vec::with_capacity(msg.content.len());
for part in msg.content {
match part {
crate::models::ContentPart::Text { text } => {
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
text,
}));
}
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
}
}));
}
}
}
let message = match msg.role.as_str() {
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("
")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("
")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(model);
builder.messages(messages);
builder.stream(true); // Enable streaming
// Add optional parameters
if let Some(temp) = request.temperature {
builder.temperature(temp as f32);
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16);
body["max_tokens"] = serde_json::json!(max_tokens);
}
// Execute streaming API call
let stream = self.client
.chat()
.create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Create eventsource stream
use reqwest_eventsource::{EventSource, Event};
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
.json(&body))
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
// Convert OpenAI stream to our stream format
let model_name = request.model.clone();
let stream = stream.map(move |chunk_result| {
match chunk_result {
Ok(chunk) => {
// Extract content from chunk
let content = chunk.choices.first()
.and_then(|choice| choice.delta.content.clone())
.unwrap_or_default();
let finish_reason = chunk.choices.first()
.and_then(|choice| choice.finish_reason.clone())
.map(|reason| format!("{:?}", reason));
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
Ok(ProviderStreamChunk {
let chunk: Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(choice) = chunk["choices"].get(0) {
let delta = &choice["delta"];
let content = delta["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = delta["reasoning_content"].as_str().or_else(|| delta["thought"].as_str()).map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
yield ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
model: model_name.clone(),
})
};
}
Err(e) => Err(AppError::ProviderError(e.to_string())),
}
});
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Ok(Box::pin(stream))
}

View File

@@ -1,8 +1,7 @@
use async_trait::async_trait;
use anyhow::Result;
use async_openai::{Client, config::OpenAIConfig};
use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
use futures::stream::{BoxStream, StreamExt};
use serde_json::Value;
use crate::{
models::UnifiedRequest,
@@ -12,8 +11,9 @@ use crate::{
use super::{ProviderResponse, ProviderStreamChunk};
pub struct OpenAIProvider {
client: Client<OpenAIConfig>,
client: reqwest::Client,
_config: crate::config::OpenAIConfig,
api_key: String,
pricing: Vec<crate::config::ModelPricing>,
}
@@ -21,16 +21,10 @@ impl OpenAIProvider {
pub fn new(config: &crate::config::OpenAIConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("openai")?;
// Create OpenAIConfig with api key and base url
let openai_config = OpenAIConfig::default()
.with_api_key(api_key)
.with_api_base(&config.base_url);
let client = Client::with_config(openai_config);
Ok(Self {
client,
client: reqwest::Client::new(),
_config: config.clone(),
api_key,
pricing: app_config.pricing.openai.clone(),
})
}
@@ -47,114 +41,70 @@ impl super::Provider for OpenAIProvider {
}
fn supports_multimodal(&self) -> bool {
true // OpenAI supports vision models
true
}
async fn chat_completion(
&self,
request: UnifiedRequest,
) -> Result<ProviderResponse, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
// Convert UnifiedRequest messages to OpenAI messages
let mut messages = Vec::with_capacity(request.messages.len());
for msg in request.messages {
let mut parts = Vec::with_capacity(msg.content.len());
for part in msg.content {
match part {
crate::models::ContentPart::Text { text } => {
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
text,
}));
}
let mut body = serde_json::json!({
"model": request.model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})
}
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": false,
});
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
}
}));
}
}
}
let message = match msg.role.as_str() {
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(request.model.clone());
builder.messages(messages);
// Add optional parameters
if let Some(temp) = request.temperature {
builder.temperature(temp as f32);
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16);
body["max_tokens"] = serde_json::json!(max_tokens);
}
// Execute API call
let response = self.client
.chat()
.create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Extract content from response
let content = response
.choices
.first()
.and_then(|choice| choice.message.content.clone())
.unwrap_or_default();
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("OpenAI API error: {}", error_text)));
}
// Extract token usage
let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
let message = &choice["message"];
let content = message["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
let usage = &resp_json["usage"];
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
Ok(ProviderResponse {
content,
reasoning_content,
prompt_tokens,
completion_tokens,
total_tokens,
@@ -174,7 +124,6 @@ impl super::Provider for OpenAIProvider {
}
}
// Fallback to static pricing if not in registry
let (prompt_rate, completion_rate) = self.pricing.iter()
.find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
@@ -187,118 +136,78 @@ impl super::Provider for OpenAIProvider {
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
// Convert UnifiedRequest messages to OpenAI messages
let mut messages = Vec::with_capacity(request.messages.len());
for msg in request.messages {
let mut parts = Vec::with_capacity(msg.content.len());
for part in msg.content {
match part {
crate::models::ContentPart::Text { text } => {
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
text,
}));
}
let mut body = serde_json::json!({
"model": request.model,
"messages": request.messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content.iter().map(|p| {
match p {
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input.to_base64().await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
image_url: ImageUrl {
url: data_url,
detail: Some(ImageDetail::Auto),
}
}));
}
}
}
let message = match msg.role.as_str() {
"system" => ChatCompletionRequestMessage::System(
ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
),
name: None,
}
),
"assistant" => ChatCompletionRequestMessage::Assistant(
ChatCompletionRequestAssistantMessage {
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
)),
name: None,
tool_calls: None,
refusal: None,
audio: None,
#[allow(deprecated)]
function_call: None,
}
),
_ => ChatCompletionRequestMessage::User(
ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Array(parts),
name: None,
}
),
};
messages.push(message);
}
if messages.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
}
// Build request using builder pattern
let mut builder = CreateChatCompletionRequestArgs::default();
builder.model(request.model.clone());
builder.messages(messages);
builder.stream(true); // Enable streaming
// Add optional parameters
if let Some(temp) = request.temperature {
builder.temperature(temp as f32);
}
if let Some(max_tokens) = request.max_tokens {
builder.max_tokens(max_tokens as u16);
}
// Execute streaming API call
let stream = self.client
.chat()
.create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Convert OpenAI stream to our stream format
let model = request.model.clone();
let stream = stream.map(move |chunk_result| {
match chunk_result {
Ok(chunk) => {
// Extract content from chunk
let content = chunk.choices.first()
.and_then(|choice| choice.delta.content.clone())
.unwrap_or_default();
let finish_reason = chunk.choices.first()
.and_then(|choice| choice.finish_reason.clone())
.map(|reason| format!("{:?}", reason));
Ok(ProviderStreamChunk {
content,
finish_reason,
model: model.clone(),
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
})
}
Err(e) => Err(AppError::ProviderError(e.to_string())),
}
}).collect::<Vec<_>>()
})
}).collect::<Vec<_>>(),
"stream": true,
});
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
// Create eventsource stream
use reqwest_eventsource::{EventSource, Event};
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body))
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let model = request.model.clone();
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(choice) = chunk["choices"].get(0) {
let delta = &choice["delta"];
let content = delta["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
yield ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
model: model.clone(),
};
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Ok(Box::pin(stream))
}
}

View File

@@ -102,6 +102,7 @@ async fn chat_completions(
delta: ChatStreamDelta {
role: None,
content: Some(chunk.content),
reasoning_content: chunk.reasoning_content,
},
finish_reason: chunk.finish_reason,
}],
@@ -177,6 +178,7 @@ async fn chat_completions(
content: crate::models::MessageContent::Text {
content: response.content
},
reasoning_content: response.reasoning_content,
},
finish_reason: Some("stop".to_string()),
}],

View File

@@ -16,6 +16,7 @@ pub struct AggregatingStream<S> {
prompt_tokens: u32,
has_images: bool,
accumulated_content: String,
accumulated_reasoning: String,
logger: Arc<RequestLogger>,
client_manager: Arc<ClientManager>,
model_registry: Arc<crate::models::registry::ModelRegistry>,
@@ -46,6 +47,7 @@ where
prompt_tokens,
has_images,
accumulated_content: String::new(),
accumulated_reasoning: String::new(),
logger,
client_manager,
model_registry,
@@ -71,8 +73,15 @@ where
let has_images = self.has_images;
let registry = self.model_registry.clone();
// Estimate completion tokens
let completion_tokens = estimate_completion_tokens(&self.accumulated_content, &model);
// Estimate completion tokens (including reasoning if present)
let content_tokens = estimate_completion_tokens(&self.accumulated_content, &model);
let reasoning_tokens = if !self.accumulated_reasoning.is_empty() {
estimate_completion_tokens(&self.accumulated_reasoning, &model)
} else {
0
};
let completion_tokens = content_tokens + reasoning_tokens;
let total_tokens = prompt_tokens + completion_tokens;
let cost = provider.calculate_cost(&model, prompt_tokens, completion_tokens, &registry);
@@ -116,6 +125,9 @@ where
match &result {
Poll::Ready(Some(Ok(chunk))) => {
self.accumulated_content.push_str(&chunk.content);
if let Some(reasoning) = &chunk.reasoning_content {
self.accumulated_reasoning.push_str(reasoning);
}
}
Poll::Ready(Some(Err(_))) => {
// If there's an error, we might still want to log what we got so far?