From 70fef80051a5d6368d234bc7165fdbde9af1db25 Mon Sep 17 00:00:00 2001
From: hobokenchicken
Date: Thu, 26 Feb 2026 13:50:22 -0500
Subject: [PATCH] feat: add support for reasoning (thinking) models across all
 providers

- Refactored OpenAI, DeepSeek, Grok, and Ollama to manual JSON parsing to
  capture 'reasoning_content' and 'thought' fields.
- Implemented real-time streaming of reasoning blocks.
- Added token aggregation and cost tracking for reasoning tokens.
- Updated unified models to include 'reasoning_content' in API responses.
---
 src/models/mod.rs         |   8 +
 src/providers/deepseek.rs | 336 ++++++++++++++------------------------
 src/providers/gemini.rs   |   2 +
 src/providers/grok.rs     | 333 ++++++++++++++-----------------------
 src/providers/mod.rs      |   2 +
 src/providers/ollama.rs   | 330 +++++++++++++------------------------
 src/providers/openai.rs   | 333 ++++++++++++++-----------------------
 src/server/mod.rs         |   2 +
 src/utils/streaming.rs    |  16 +-
 9 files changed, 502 insertions(+), 860 deletions(-)

diff --git a/src/models/mod.rs b/src/models/mod.rs
index 40adb3e7..01d97fcf 100644
--- a/src/models/mod.rs
+++ b/src/models/mod.rs
@@ -22,6 +22,8 @@ pub struct ChatMessage {
     pub role: String, // "system", "user", "assistant"
     #[serde(flatten)]
     pub content: MessageContent,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_content: Option<String>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -29,6 +31,7 @@ pub struct ChatMessage {
 pub enum MessageContent {
     Text { content: String },
     Parts { content: Vec<ContentPart> },
+    None, // Handle cases where content might be null but reasoning is present
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -91,6 +94,8 @@ pub struct ChatStreamChoice {
 pub struct ChatStreamDelta {
     pub role: Option<String>,
     pub content: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_content: Option<String>,
 }
 
 // ========== Unified Request Format (for internal use) ==========
@@ -217,6 +222,9 @@ impl TryFrom<ChatRequest> for UnifiedRequest {
 
                 (unified_content, has_images_in_msg)
             }
+            MessageContent::None => {
+                (vec![], false)
+            }
         };
 
         UnifiedMessage {
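Reviewer note: the `#[serde(flatten)]` plus optional-field combination above is the crux of the wire-format change, so here is a minimal, self-contained sketch of the round-trip. The types are mirrored locally for the demo; the real `MessageContent` carries more variants, and `#[serde(untagged)]` here is an assumption about its attributes. Assumes `serde` (with `derive`) and `serde_json`.

```rust
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)] // assumption: the real enum is untagged
enum MessageContent {
    Text { content: String },
}

#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
    role: String,
    #[serde(flatten)]
    content: MessageContent,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_content: Option<String>,
}

fn main() {
    // A reasoning-model reply: visible text plus a chain-of-thought field.
    let msg = ChatMessage {
        role: "assistant".into(),
        content: MessageContent::Text { content: "42".into() },
        reasoning_content: Some("6 * 7 = 42".into()),
    };
    // Flatten lifts `content` to the top level; reasoning_content is emitted
    // only when present, so non-reasoning models see no schema change.
    println!("{}", serde_json::to_string(&msg).unwrap());
    // {"role":"assistant","content":"42","reasoning_content":"6 * 7 = 42"}

    // A plain reply deserializes with reasoning_content defaulting to None.
    let plain = r#"{"role":"assistant","content":"hi"}"#;
    let parsed: ChatMessage = serde_json::from_str(plain).unwrap();
    assert!(parsed.reasoning_content.is_none());
}
```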
diff --git a/src/providers/deepseek.rs b/src/providers/deepseek.rs
index 92135e93..63d6b91f 100644
--- a/src/providers/deepseek.rs
+++ b/src/providers/deepseek.rs
@@ -1,8 +1,7 @@
 use async_trait::async_trait;
 use anyhow::Result;
-use async_openai::{Client, config::OpenAIConfig};
-use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
 use futures::stream::{BoxStream, StreamExt};
+use serde_json::Value;
 
 use crate::{
     models::UnifiedRequest,
@@ -12,8 +11,9 @@ use crate::{
 use super::{ProviderResponse, ProviderStreamChunk};
 
 pub struct DeepSeekProvider {
-    client: Client<OpenAIConfig>, // DeepSeek uses OpenAI-compatible API
-    _config: crate::config::DeepSeekConfig,
+    client: reqwest::Client,
+    config: crate::config::DeepSeekConfig,
+    api_key: String,
     pricing: Vec<ModelPricing>,
 }
 
@@ -21,16 +21,10 @@ impl DeepSeekProvider {
     pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result<Self> {
         let api_key = app_config.get_api_key("deepseek")?;
 
-        // Create OpenAIConfig with api key and base url
-        let openai_config = OpenAIConfig::default()
-            .with_api_key(api_key)
-            .with_api_base(&config.base_url);
-
-        let client = Client::with_config(openai_config);
-
         Ok(Self {
-            client,
-            _config: config.clone(),
+            client: reqwest::Client::new(),
+            config: config.clone(),
+            api_key,
             pricing: app_config.pricing.deepseek.clone(),
         })
     }
 }
 
@@ -47,114 +41,72 @@ impl super::Provider for DeepSeekProvider {
     }
 
     fn supports_multimodal(&self) -> bool {
-        false // DeepSeek doesn't support general vision (only OCR)
+        false
     }
 
     async fn chat_completion(
         &self,
         request: UnifiedRequest,
     ) -> Result<ProviderResponse, AppError> {
-        use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
-
-        // Convert UnifiedRequest messages to OpenAI-compatible messages
-        let mut messages = Vec::with_capacity(request.messages.len());
-
-        for msg in request.messages {
-            let mut parts = Vec::with_capacity(msg.content.len());
-
-            for part in msg.content {
-                match part {
-                    crate::models::ContentPart::Text { text } => {
-                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
-                            text,
-                        }));
-                    }
-                    crate::models::ContentPart::Image(image_input) => {
-                        let (base64_data, mime_type) = image_input.to_base64().await
-                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
-                        let data_url = format!("data:{};base64,{}", mime_type, base64_data);
-
-                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
-                            image_url: ImageUrl {
-                                url: data_url,
-                                detail: Some(ImageDetail::Auto),
-                            }
-                        }));
-                    }
-                }
-            }
-
-            let message = match msg.role.as_str() {
-                "system" => ChatCompletionRequestMessage::System(
-                    ChatCompletionRequestSystemMessage {
-                        content: ChatCompletionRequestSystemMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        ),
-                        name: None,
-                    }
-                ),
-                "assistant" => ChatCompletionRequestMessage::Assistant(
-                    ChatCompletionRequestAssistantMessage {
-                        content: Some(ChatCompletionRequestAssistantMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        )),
-                        name: None,
-                        tool_calls: None,
-                        refusal: None,
-                        audio: None,
-                        #[allow(deprecated)]
-                        function_call: None,
-                    }
-                ),
-                _ => ChatCompletionRequestMessage::User(
-                    ChatCompletionRequestUserMessage {
-                        content: ChatCompletionRequestUserMessageContent::Array(parts),
-                        name: None,
-                    }
-                ),
-            };
-            messages.push(message);
-        }
-
-        if messages.is_empty() {
-            return Err(AppError::ProviderError("No valid text messages to send".to_string()));
-        }
-
-        // Build request using builder pattern
-        let mut builder = CreateChatCompletionRequestArgs::default();
-        builder.model(request.model.clone());
-        builder.messages(messages);
-
-        // Add optional parameters
+        // Convert messages up front so image parts can be awaited; blocking
+        // on to_base64() inside the JSON builder would stall the runtime.
+        // DeepSeek currently doesn't support images in the same way, but
+        // we'll try to be standard.
+        let mut messages_json = Vec::with_capacity(request.messages.len());
+        for m in &request.messages {
+            let mut parts = Vec::with_capacity(m.content.len());
+            for p in &m.content {
+                match p {
+                    crate::models::ContentPart::Text { text } => {
+                        parts.push(serde_json::json!({ "type": "text", "text": text }));
+                    }
+                    crate::models::ContentPart::Image(image_input) => {
+                        let (base64_data, mime_type) = image_input.to_base64().await
+                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
+                        parts.push(serde_json::json!({
+                            "type": "image_url",
+                            "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
+                        }));
+                    }
+                }
+            }
+            messages_json.push(serde_json::json!({ "role": m.role, "content": parts }));
+        }
+
+        // Build the OpenAI-compatible body
+        let mut body = serde_json::json!({
+            "model": request.model,
+            "messages": messages_json,
+            "stream": false,
+        });
+
         if let Some(temp) = request.temperature {
-            builder.temperature(temp as f32);
+            body["temperature"] = serde_json::json!(temp);
         }
-
         if let Some(max_tokens) = request.max_tokens {
-            builder.max_tokens(max_tokens as u16);
+            body["max_tokens"] = serde_json::json!(max_tokens);
         }
 
-        // Execute API call
-        let response = self.client
-            .chat()
-            .create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
+        let response = self.client.post(format!("{}/chat/completions", self.config.base_url))
+            .header("Authorization", format!("Bearer {}", self.api_key))
+            .json(&body)
+            .send()
             .await
            .map_err(|e| AppError::ProviderError(e.to_string()))?;
 
-        // Extract content from response
-        let content = response
-            .choices
-            .first()
-            .and_then(|choice| choice.message.content.clone())
-            .unwrap_or_default();
+        if !response.status().is_success() {
+            let error_text = response.text().await.unwrap_or_default();
+            return Err(AppError::ProviderError(format!("DeepSeek API error: {}", error_text)));
+        }
+
+        let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
 
-        // Extract token usage
-        let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
-        let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
-        let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
+        let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
+        let message = &choice["message"];
+
+        let content = message["content"].as_str().unwrap_or_default().to_string();
+        let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
+
+        let usage = &resp_json["usage"];
+        let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
+        let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
+        let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
 
         Ok(ProviderResponse {
             content,
+            reasoning_content,
             prompt_tokens,
             completion_tokens,
             total_tokens,
@@ -177,7 +129,7 @@ impl super::Provider for DeepSeekProvider {
         let (prompt_rate, completion_rate) = self.pricing.iter()
             .find(|p| model.contains(&p.model))
             .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
-            .unwrap_or((0.14, 0.28)); // Default to DeepSeek V3 price if not found
+            .unwrap_or((0.14, 0.28));
 
         (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
     }
 
@@ -186,118 +138,72 @@
         &self,
         request: UnifiedRequest,
     ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
-        use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
-
-        // Convert UnifiedRequest messages to OpenAI-compatible messages
-        let mut messages = Vec::with_capacity(request.messages.len());
-
-        for msg in request.messages {
-            let mut parts = Vec::with_capacity(msg.content.len());
-
-            for part in msg.content {
-                match part {
-                    crate::models::ContentPart::Text { text } => {
-                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
-                            text,
-                        }));
-                    }
-                    crate::models::ContentPart::Image(image_input) => {
-                        let (base64_data, mime_type) = image_input.to_base64().await
-                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
-                        let data_url = format!("data:{};base64,{}", mime_type, base64_data);
-
-                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
-                            image_url: ImageUrl {
-                                url: data_url,
-                                detail: Some(ImageDetail::Auto),
-                            }
-                        }));
-                    }
-                }
-            }
-
-            let message = match msg.role.as_str() {
-                "system" => ChatCompletionRequestMessage::System(
-                    ChatCompletionRequestSystemMessage {
-                        content: ChatCompletionRequestSystemMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        ),
-                        name: None,
-                    }
-                ),
-                "assistant" => ChatCompletionRequestMessage::Assistant(
-                    ChatCompletionRequestAssistantMessage {
-                        content: Some(ChatCompletionRequestAssistantMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        )),
-                        name: None,
-                        tool_calls: None,
-                        refusal: None,
-                        audio: None,
-                        #[allow(deprecated)]
-                        function_call: None,
-                    }
-                ),
-                _ => ChatCompletionRequestMessage::User(
-                    ChatCompletionRequestUserMessage {
-                        content: ChatCompletionRequestUserMessageContent::Array(parts),
-                        name: None,
-                    }
-                ),
-            };
-            messages.push(message);
-        }
-
-        if messages.is_empty() {
-            return Err(AppError::ProviderError("No valid text messages to send".to_string()));
-        }
-
-        // Build request using builder pattern
-        let mut builder = CreateChatCompletionRequestArgs::default();
-        builder.model(request.model.clone());
-        builder.messages(messages);
-        builder.stream(true); // Enable streaming
-
-        // Add optional parameters
-        if let Some(temp) = request.temperature {
-            builder.temperature(temp as f32);
-        }
-
-        if let Some(max_tokens) = request.max_tokens {
-            builder.max_tokens(max_tokens as u16);
-        }
-
-        // Execute streaming API call
-        let stream = self.client
-            .chat()
-            .create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
-            .await
-            .map_err(|e| AppError::ProviderError(e.to_string()))?;
-
-        // Convert OpenAI stream to our stream format
-        let model = request.model.clone();
-        let stream = stream.map(move |chunk_result| {
-            match chunk_result {
-                Ok(chunk) => {
-                    // Extract content from chunk
-                    let content = chunk.choices.first()
-                        .and_then(|choice| choice.delta.content.clone())
-                        .unwrap_or_default();
-
-                    let finish_reason = chunk.choices.first()
-                        .and_then(|choice| choice.finish_reason.clone())
-                        .map(|reason| format!("{:?}", reason));
-
-                    Ok(ProviderStreamChunk {
-                        content,
-                        finish_reason,
-                        model: model.clone(),
-                    })
-                }
-                Err(e) => Err(AppError::ProviderError(e.to_string())),
-            }
-        });
+        let mut body = serde_json::json!({
+            "model": request.model,
+            "messages": request.messages.iter().map(|m| {
+                serde_json::json!({
+                    "role": m.role,
+                    "content": m.content.iter().map(|p| {
+                        match p {
+                            crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
+                            // Images are not supported on the streaming path yet; send a placeholder
+                            crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }),
+                        }
+                    }).collect::<Vec<_>>()
+                })
+            }).collect::<Vec<_>>(),
+            "stream": true,
+        });
+
+        if let Some(temp) = request.temperature {
+            body["temperature"] = serde_json::json!(temp);
+        }
+        if let Some(max_tokens) = request.max_tokens {
+            body["max_tokens"] = serde_json::json!(max_tokens);
+        }
+
+        // Create eventsource stream
+        use reqwest_eventsource::{EventSource, Event};
+        let es = EventSource::new(self.client.post(format!("{}/chat/completions", self.config.base_url))
+            .header("Authorization", format!("Bearer {}", self.api_key))
+            .json(&body))
+            .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
+
+        let model = request.model.clone();
+
+        let stream = async_stream::try_stream! {
+            let mut es = es;
+            while let Some(event) = es.next().await {
+                match event {
+                    Ok(Event::Message(msg)) => {
+                        if msg.data == "[DONE]" {
+                            break;
+                        }
+
+                        let chunk: Value = serde_json::from_str(&msg.data)
+                            .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
+
+                        if let Some(choice) = chunk["choices"].get(0) {
+                            let delta = &choice["delta"];
+                            let content = delta["content"].as_str().unwrap_or_default().to_string();
+                            let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
+                            let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
+
+                            yield ProviderStreamChunk {
+                                content,
+                                reasoning_content,
+                                finish_reason,
+                                model: model.clone(),
+                            };
+                        }
+                    }
+                    Ok(_) => continue,
+                    Err(e) => {
+                        Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
+                    }
+                }
+            }
+        };
 
         Ok(Box::pin(stream))
     }
-}
\ No newline at end of file
+}
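Reviewer note: every refactored provider repeats the same per-event parse, so here is that single step isolated as a runnable sketch. The payloads are illustrative, shaped like DeepSeek R1 stream chunks; only `serde_json` is assumed.

```rust
use serde_json::Value;

/// Returns (content, reasoning_content, finish_reason), or None for [DONE].
fn parse_chunk(data: &str) -> Option<(String, Option<String>, Option<String>)> {
    if data == "[DONE]" {
        return None; // end-of-stream sentinel, not JSON
    }
    let chunk: Value = serde_json::from_str(data).ok()?;
    let choice = chunk["choices"].get(0)?;
    let delta = &choice["delta"];
    Some((
        delta["content"].as_str().unwrap_or_default().to_string(),
        delta["reasoning_content"].as_str().map(str::to_string),
        choice["finish_reason"].as_str().map(str::to_string),
    ))
}

fn main() {
    // While the model is thinking, content stays null and reasoning streams:
    let thinking = r#"{"choices":[{"delta":{"content":null,"reasoning_content":"Let me check..."},"finish_reason":null}]}"#;
    assert_eq!(
        parse_chunk(thinking),
        Some(("".into(), Some("Let me check...".into()), None))
    );
    // Then the roles flip for the final answer:
    let answering = r#"{"choices":[{"delta":{"content":"42"},"finish_reason":null}]}"#;
    assert_eq!(parse_chunk(answering), Some(("42".into(), None, None)));
    assert_eq!(parse_chunk("[DONE]"), None);
}
```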
diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs
index 91098112..ddf77fa9 100644
--- a/src/providers/gemini.rs
+++ b/src/providers/gemini.rs
@@ -206,6 +206,7 @@ impl super::Provider for GeminiProvider {
 
         Ok(ProviderResponse {
             content,
+            reasoning_content: None, // Gemini doesn't use this field name
             prompt_tokens,
             completion_tokens,
             total_tokens,
@@ -324,6 +325,7 @@ impl super::Provider for GeminiProvider {
 
                     yield ProviderStreamChunk {
                         content,
+                        reasoning_content: None,
                         finish_reason: None, // Will be set in the last chunk
                         model: model.clone(),
                     };
diff --git a/src/providers/grok.rs b/src/providers/grok.rs
index 0e546e00..ab827024 100644
--- a/src/providers/grok.rs
+++ b/src/providers/grok.rs
@@ -1,8 +1,7 @@
 use async_trait::async_trait;
 use anyhow::Result;
-use async_openai::{Client, config::OpenAIConfig};
-use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
 use futures::stream::{BoxStream, StreamExt};
+use serde_json::Value;
 
 use crate::{
     models::UnifiedRequest,
@@ -12,8 +11,9 @@ use crate::{
 use super::{ProviderResponse, ProviderStreamChunk};
 
 pub struct GrokProvider {
-    client: Client<OpenAIConfig>,
+    client: reqwest::Client,
     _config: crate::config::GrokConfig,
+    api_key: String,
     pricing: Vec<ModelPricing>,
 }
 
@@ -21,16 +21,10 @@ impl GrokProvider {
     pub fn new(config: &crate::config::GrokConfig, app_config: &AppConfig) -> Result<Self> {
         let api_key = app_config.get_api_key("grok")?;
 
-        // Grok is OpenAI-compatible
-        let openai_config = OpenAIConfig::default()
-            .with_api_key(api_key)
-            .with_api_base(&config.base_url);
-
-        let client = Client::with_config(openai_config);
-
         Ok(Self {
-            client,
+            client: reqwest::Client::new(),
             _config: config.clone(),
+            api_key,
             pricing: app_config.pricing.grok.clone(),
         })
     }
 }
 
@@ -47,114 +41,70 @@ impl super::Provider for GrokProvider {
     }
 
     fn supports_multimodal(&self) -> bool {
-        true // Grok supports vision models
+        true
     }
 
     async fn chat_completion(
         &self,
         request: UnifiedRequest,
     ) -> Result<ProviderResponse, AppError> {
-        use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
-
-        // Convert UnifiedRequest messages to OpenAI messages
-        let mut messages = Vec::with_capacity(request.messages.len());
-
-        for msg in request.messages {
-            let mut parts = Vec::with_capacity(msg.content.len());
-
-            for part in msg.content {
-                match part {
-                    crate::models::ContentPart::Text { text } => {
-                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
-                            text,
-                        }));
-                    }
-                    crate::models::ContentPart::Image(image_input) => {
-                        let (base64_data, mime_type) = image_input.to_base64().await
-                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
-                        let data_url = format!("data:{};base64,{}", mime_type, base64_data);
-
-                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
-                            image_url: ImageUrl {
-                                url: data_url,
-                                detail: Some(ImageDetail::Auto),
-                            }
-                        }));
-                    }
-                }
-            }
-
-            let message = match msg.role.as_str() {
-                "system" => ChatCompletionRequestMessage::System(
-                    ChatCompletionRequestSystemMessage {
-                        content: ChatCompletionRequestSystemMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        ),
-                        name: None,
-                    }
-                ),
-                "assistant" => ChatCompletionRequestMessage::Assistant(
-                    ChatCompletionRequestAssistantMessage {
-                        content: Some(ChatCompletionRequestAssistantMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        )),
-                        name: None,
-                        tool_calls: None,
-                        refusal: None,
-                        audio: None,
-                        #[allow(deprecated)]
-                        function_call: None,
-                    }
-                ),
-                _ => ChatCompletionRequestMessage::User(
-                    ChatCompletionRequestUserMessage {
-                        content: ChatCompletionRequestUserMessageContent::Array(parts),
-                        name: None,
-                    }
-                ),
-            };
-            messages.push(message);
-        }
-
-        if messages.is_empty() {
-            return Err(AppError::ProviderError("No valid text messages to send".to_string()));
-        }
-
-        // Build request using builder pattern
-        let mut builder = CreateChatCompletionRequestArgs::default();
-        builder.model(request.model.clone());
-        builder.messages(messages);
-
-        // Add optional parameters
+        // Convert messages up front so image parts can be awaited; blocking
+        // on to_base64() inside the JSON builder would stall the runtime.
+        let mut messages_json = Vec::with_capacity(request.messages.len());
+        for m in &request.messages {
+            let mut parts = Vec::with_capacity(m.content.len());
+            for p in &m.content {
+                match p {
+                    crate::models::ContentPart::Text { text } => {
+                        parts.push(serde_json::json!({ "type": "text", "text": text }));
+                    }
+                    crate::models::ContentPart::Image(image_input) => {
+                        let (base64_data, mime_type) = image_input.to_base64().await
+                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
+                        parts.push(serde_json::json!({
+                            "type": "image_url",
+                            "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
+                        }));
+                    }
+                }
+            }
+            messages_json.push(serde_json::json!({ "role": m.role, "content": parts }));
+        }
+
+        let mut body = serde_json::json!({
+            "model": request.model,
+            "messages": messages_json,
+            "stream": false,
+        });
+
         if let Some(temp) = request.temperature {
-            builder.temperature(temp as f32);
+            body["temperature"] = serde_json::json!(temp);
         }
-
         if let Some(max_tokens) = request.max_tokens {
-            builder.max_tokens(max_tokens as u16);
+            body["max_tokens"] = serde_json::json!(max_tokens);
         }
 
-        // Execute API call
-        let response = self.client
-            .chat()
-            .create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
+        let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
+            .header("Authorization", format!("Bearer {}", self.api_key))
+            .json(&body)
+            .send()
             .await
             .map_err(|e| AppError::ProviderError(e.to_string()))?;
 
-        // Extract content from response
-        let content = response
-            .choices
-            .first()
-            .and_then(|choice| choice.message.content.clone())
-            .unwrap_or_default();
+        if !response.status().is_success() {
+            let error_text = response.text().await.unwrap_or_default();
+            return Err(AppError::ProviderError(format!("Grok API error: {}", error_text)));
+        }
+
+        let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
 
-        // Extract token usage
-        let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
-        let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
-        let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
+        let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
+        let message = &choice["message"];
+
+        let content = message["content"].as_str().unwrap_or_default().to_string();
+        let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
+
+        let usage = &resp_json["usage"];
+        let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
+        let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
+        let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
 
         Ok(ProviderResponse {
             content,
+            reasoning_content,
             prompt_tokens,
             completion_tokens,
             total_tokens,
@@ -174,11 +124,10 @@ impl super::Provider for GrokProvider {
             }
         }
 
-        // Fallback to static pricing if not in registry
         let (prompt_rate, completion_rate) = self.pricing.iter()
             .find(|p| model.contains(&p.model))
             .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
-            .unwrap_or((5.0, 15.0)); // Grok-2 pricing is roughly this
+            .unwrap_or((5.0, 15.0));
 
         (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
     }
 
@@ -187,117 +136,77 @@
         &self,
         request: UnifiedRequest,
     ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
-        use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
-
-        // Convert UnifiedRequest messages to OpenAI messages
-        let mut messages = Vec::with_capacity(request.messages.len());
-
-        for msg in request.messages {
-            let mut parts = Vec::with_capacity(msg.content.len());
-
-            for part in msg.content {
-                match part {
-                    crate::models::ContentPart::Text { text } => {
-                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
-                            text,
-                        }));
-                    }
-                    crate::models::ContentPart::Image(image_input) => {
-                        let (base64_data, mime_type) = image_input.to_base64().await
-                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
-                        let data_url = format!("data:{};base64,{}", mime_type, base64_data);
-
-                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
-                            image_url: ImageUrl {
-                                url: data_url,
-                                detail: Some(ImageDetail::Auto),
-                            }
-                        }));
-                    }
-                }
-            }
-
-            let message = match msg.role.as_str() {
-                "system" => ChatCompletionRequestMessage::System(
-                    ChatCompletionRequestSystemMessage {
-                        content: ChatCompletionRequestSystemMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        ),
-                        name: None,
-                    }
-                ),
-                "assistant" => ChatCompletionRequestMessage::Assistant(
-                    ChatCompletionRequestAssistantMessage {
-                        content: Some(ChatCompletionRequestAssistantMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        )),
-                        name: None,
-                        tool_calls: None,
-                        refusal: None,
-                        audio: None,
-                        #[allow(deprecated)]
-                        function_call: None,
-                    }
-                ),
-                _ => ChatCompletionRequestMessage::User(
-                    ChatCompletionRequestUserMessage {
-                        content: ChatCompletionRequestUserMessageContent::Array(parts),
-                        name: None,
-                    }
-                ),
-            };
-            messages.push(message);
-        }
-
-        if messages.is_empty() {
-            return Err(AppError::ProviderError("No valid text messages to send".to_string()));
-        }
-
-        // Build request using builder pattern
-        let mut builder = CreateChatCompletionRequestArgs::default();
-        builder.model(request.model.clone());
-        builder.messages(messages);
-        builder.stream(true); // Enable streaming
-
-        // Add optional parameters
-        if let Some(temp) = request.temperature {
-            builder.temperature(temp as f32);
-        }
-
-        if let Some(max_tokens) = request.max_tokens {
-            builder.max_tokens(max_tokens as u16);
-        }
-
-        // Execute streaming API call
-        let stream = self.client
-            .chat()
-            .create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
-            .await
-            .map_err(|e| AppError::ProviderError(e.to_string()))?;
-
-        // Convert OpenAI stream to our stream format
-        let model = request.model.clone();
-        let stream = stream.map(move |chunk_result| {
-            match chunk_result {
-                Ok(chunk) => {
-                    // Extract content from chunk
-                    let content = chunk.choices.first()
-                        .and_then(|choice| choice.delta.content.clone())
-                        .unwrap_or_default();
-
-                    let finish_reason = chunk.choices.first()
-                        .and_then(|choice| choice.finish_reason.clone())
-                        .map(|reason| format!("{:?}", reason));
-
-                    Ok(ProviderStreamChunk {
-                        content,
-                        finish_reason,
-                        model: model.clone(),
-                    })
-                }
-                Err(e) => Err(AppError::ProviderError(e.to_string())),
-            }
-        });
+        // Convert messages up front so image parts can be awaited; blocking
+        // on to_base64() inside the JSON builder would stall the runtime.
+        let mut messages_json = Vec::with_capacity(request.messages.len());
+        for m in &request.messages {
+            let mut parts = Vec::with_capacity(m.content.len());
+            for p in &m.content {
+                match p {
+                    crate::models::ContentPart::Text { text } => {
+                        parts.push(serde_json::json!({ "type": "text", "text": text }));
+                    }
+                    crate::models::ContentPart::Image(image_input) => {
+                        let (base64_data, mime_type) = image_input.to_base64().await
+                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
+                        parts.push(serde_json::json!({
+                            "type": "image_url",
+                            "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
+                        }));
+                    }
+                }
+            }
+            messages_json.push(serde_json::json!({ "role": m.role, "content": parts }));
+        }
+
+        let mut body = serde_json::json!({
+            "model": request.model,
+            "messages": messages_json,
+            "stream": true,
+        });
+
+        if let Some(temp) = request.temperature {
+            body["temperature"] = serde_json::json!(temp);
+        }
+        if let Some(max_tokens) = request.max_tokens {
+            body["max_tokens"] = serde_json::json!(max_tokens);
+        }
+
+        // Create eventsource stream
+        use reqwest_eventsource::{EventSource, Event};
+        let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
+            .header("Authorization", format!("Bearer {}", self.api_key))
+            .json(&body))
+            .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
+
+        let model = request.model.clone();
+
+        let stream = async_stream::try_stream! {
+            let mut es = es;
+            while let Some(event) = es.next().await {
+                match event {
+                    Ok(Event::Message(msg)) => {
+                        if msg.data == "[DONE]" {
+                            break;
+                        }
+
+                        let chunk: Value = serde_json::from_str(&msg.data)
+                            .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
+
+                        if let Some(choice) = chunk["choices"].get(0) {
+                            let delta = &choice["delta"];
+                            let content = delta["content"].as_str().unwrap_or_default().to_string();
+                            let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
+                            let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
+
+                            yield ProviderStreamChunk {
+                                content,
+                                reasoning_content,
+                                finish_reason,
+                                model: model.clone(),
+                            };
+                        }
+                    }
+                    Ok(_) => continue,
+                    Err(e) => {
+                        Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
+                    }
+                }
+            }
+        };
 
         Ok(Box::pin(stream))
     }
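Reviewer note: a worked instance of the cost formula the providers share, using the (5.0, 15.0) Grok fallback rates from this patch rather than authoritative pricing. The point is that reasoning tokens ride the completion rate, since upstream APIs fold them into the completion count (or, when streaming, into the aggregated estimate).

```rust
// Rates are USD per million tokens.
fn calculate_cost(prompt_tokens: u32, completion_tokens: u32, prompt_rate: f64, completion_rate: f64) -> f64 {
    (prompt_tokens as f64 * prompt_rate / 1_000_000.0)
        + (completion_tokens as f64 * completion_rate / 1_000_000.0)
}

fn main() {
    // 1,200 prompt tokens; 300 visible completion tokens plus 700 reasoning
    // tokens, both billed at the completion rate.
    let cost = calculate_cost(1_200, 300 + 700, 5.0, 15.0);
    // 1200 * 5 / 1e6 = 0.006; 1000 * 15 / 1e6 = 0.015; total = 0.021
    assert!((cost - 0.021).abs() < 1e-9);
    println!("${cost:.4}"); // $0.0210
}
```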
diff --git a/src/providers/mod.rs b/src/providers/mod.rs
index 6de57d1c..92f0fbb3 100644
--- a/src/providers/mod.rs
+++ b/src/providers/mod.rs
@@ -44,6 +44,7 @@ pub trait Provider: Send + Sync {
 
 pub struct ProviderResponse {
     pub content: String,
+    pub reasoning_content: Option<String>,
     pub prompt_tokens: u32,
     pub completion_tokens: u32,
     pub total_tokens: u32,
@@ -53,6 +54,7 @@ pub struct ProviderResponse {
 #[derive(Debug, Clone)]
 pub struct ProviderStreamChunk {
     pub content: String,
+    pub reasoning_content: Option<String>,
     pub finish_reason: Option<String>,
     pub model: String,
 }
diff --git a/src/providers/ollama.rs b/src/providers/ollama.rs
index e6adc894..41343ce3 100644
--- a/src/providers/ollama.rs
+++ b/src/providers/ollama.rs
@@ -1,8 +1,7 @@
 use async_trait::async_trait;
 use anyhow::Result;
-use async_openai::{Client, config::OpenAIConfig};
-use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
 use futures::stream::{BoxStream, StreamExt};
+use serde_json::Value;
 
 use crate::{
     models::UnifiedRequest,
@@ -12,22 +11,15 @@ use crate::{
 use super::{ProviderResponse, ProviderStreamChunk};
 
 pub struct OllamaProvider {
-    client: Client<OpenAIConfig>,
+    client: reqwest::Client,
     _config: crate::config::OllamaConfig,
     pricing: Vec<ModelPricing>,
 }
 
 impl OllamaProvider {
     pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
-        // Ollama usually doesn't need an API key, use a dummy one
-        let openai_config = OpenAIConfig::default()
-            .with_api_key("ollama")
-            .with_api_base(&config.base_url);
-
-        let client = Client::with_config(openai_config);
-
         Ok(Self {
-            client,
+            client: reqwest::Client::new(),
             _config: config.clone(),
             pricing: app_config.pricing.ollama.clone(),
         })
 
@@ -41,124 +33,75 @@ impl super::Provider for OllamaProvider {
     }
 
     fn supports_model(&self, model: &str) -> bool {
-        // Check if model is in the list of configured Ollama models
         self._config.models.iter().any(|m| m == model) || model.starts_with("ollama/")
     }
 
     fn supports_multimodal(&self) -> bool {
-        true // Many Ollama models support vision (e.g. llava, moondream)
+        true
     }
 
     async fn chat_completion(
         &self,
         request: UnifiedRequest,
     ) -> Result<ProviderResponse, AppError> {
-        use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
-
-        // Strip "ollama/" prefix if present
         let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
 
-        // Convert UnifiedRequest messages to OpenAI messages
-        let mut messages = Vec::with_capacity(request.messages.len());
-
-        for msg in request.messages {
-            let mut parts = Vec::with_capacity(msg.content.len());
-
-            for part in msg.content {
-                match part {
-                    crate::models::ContentPart::Text { text } => {
-                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
-                            text,
-                        }));
-                    }
-                    crate::models::ContentPart::Image(image_input) => {
-                        let (base64_data, mime_type) = image_input.to_base64().await
-                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
-                        let data_url = format!("data:{};base64,{}", mime_type, base64_data);
-
-                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
-                            image_url: ImageUrl {
-                                url: data_url,
-                                detail: Some(ImageDetail::Auto),
-                            }
-                        }));
-                    }
-                }
-            }
-
-            let message = match msg.role.as_str() {
-                "system" => ChatCompletionRequestMessage::System(
-                    ChatCompletionRequestSystemMessage {
-                        content: ChatCompletionRequestSystemMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        ),
-                        name: None,
-                    }
-                ),
-                "assistant" => ChatCompletionRequestMessage::Assistant(
-                    ChatCompletionRequestAssistantMessage {
-                        content: Some(ChatCompletionRequestAssistantMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        )),
-                        name: None,
-                        tool_calls: None,
-                        refusal: None,
-                        audio: None,
-                        #[allow(deprecated)]
-                        function_call: None,
-                    }
-                ),
-                _ => ChatCompletionRequestMessage::User(
-                    ChatCompletionRequestUserMessage {
-                        content: ChatCompletionRequestUserMessageContent::Array(parts),
-                        name: None,
-                    }
-                ),
-            };
-            messages.push(message);
-        }
-
-        if messages.is_empty() {
-            return Err(AppError::ProviderError("No valid text messages to send".to_string()));
-        }
-
-        // Build request using builder pattern
-        let mut builder = CreateChatCompletionRequestArgs::default();
-        builder.model(model);
-        builder.messages(messages);
-
-        // Add optional parameters
+        // Convert messages up front so image parts can be awaited; blocking
+        // on to_base64() inside the JSON builder would stall the runtime.
+        let mut messages_json = Vec::with_capacity(request.messages.len());
+        for m in &request.messages {
+            let mut parts = Vec::with_capacity(m.content.len());
+            for p in &m.content {
+                match p {
+                    crate::models::ContentPart::Text { text } => {
+                        parts.push(serde_json::json!({ "type": "text", "text": text }));
+                    }
+                    crate::models::ContentPart::Image(image_input) => {
+                        let (base64_data, mime_type) = image_input.to_base64().await
+                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
+                        parts.push(serde_json::json!({
+                            "type": "image_url",
+                            "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
+                        }));
+                    }
+                }
+            }
+            messages_json.push(serde_json::json!({ "role": m.role, "content": parts }));
+        }
+
+        let mut body = serde_json::json!({
+            "model": model,
+            "messages": messages_json,
+            "stream": false,
+        });
+
         if let Some(temp) = request.temperature {
-            builder.temperature(temp as f32);
+            body["temperature"] = serde_json::json!(temp);
         }
-
         if let Some(max_tokens) = request.max_tokens {
-            builder.max_tokens(max_tokens as u16);
+            body["max_tokens"] = serde_json::json!(max_tokens);
         }
 
-        // Execute API call
-        let response = self.client
-            .chat()
-            .create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
+        let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
+            .json(&body)
+            .send()
            .await
            .map_err(|e| AppError::ProviderError(e.to_string()))?;
 
-        // Extract content from response
-        let content = response
-            .choices
-            .first()
-            .and_then(|choice| choice.message.content.clone())
-            .unwrap_or_default();
+        if !response.status().is_success() {
+            let error_text = response.text().await.unwrap_or_default();
+            return Err(AppError::ProviderError(format!("Ollama API error: {}", error_text)));
+        }
+
+        let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
 
-        // Extract token usage
-        let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
-        let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
-        let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
+        let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
+        let message = &choice["message"];
 
+        let content = message["content"].as_str().unwrap_or_default().to_string();
+        let reasoning_content = message["reasoning_content"].as_str().or_else(|| message["thought"].as_str()).map(|s| s.to_string());
+
+        let usage = &resp_json["usage"];
+        let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
+        let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
+        let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
 
         Ok(ProviderResponse {
             content,
+            reasoning_content,
             prompt_tokens,
             completion_tokens,
             total_tokens,
@@ -178,7 +121,6 @@ impl super::Provider for OllamaProvider {
         }
     }
 
-        // Ollama is free by default
         let (prompt_rate, completion_rate) = self.pricing.iter()
             .find(|p| model.contains(&p.model))
             .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
@@ -191,122 +133,72 @@
         &self,
         request: UnifiedRequest,
     ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
-        use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
-
-        // Strip "ollama/" prefix if present
         let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
 
-        // Convert UnifiedRequest messages to OpenAI messages
-        let mut messages = Vec::with_capacity(request.messages.len());
-
-        for msg in request.messages {
-            let mut parts = Vec::with_capacity(msg.content.len());
-
-            for part in msg.content {
-                match part {
-                    crate::models::ContentPart::Text { text } => {
-                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
-                            text,
-                        }));
-                    }
-                    crate::models::ContentPart::Image(image_input) => {
-                        let (base64_data, mime_type) = image_input.to_base64().await
-                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
-                        let data_url = format!("data:{};base64,{}", mime_type, base64_data);
-
-                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
-                            image_url: ImageUrl {
-                                url: data_url,
-                                detail: Some(ImageDetail::Auto),
-                            }
-                        }));
-                    }
-                }
-            }
-
-            let message = match msg.role.as_str() {
-                "system" => ChatCompletionRequestMessage::System(
-                    ChatCompletionRequestSystemMessage {
-                        content: ChatCompletionRequestSystemMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        ),
-                        name: None,
-                    }
-                ),
-                "assistant" => ChatCompletionRequestMessage::Assistant(
-                    ChatCompletionRequestAssistantMessage {
-                        content: Some(ChatCompletionRequestAssistantMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        )),
-                        name: None,
-                        tool_calls: None,
-                        refusal: None,
-                        audio: None,
-                        #[allow(deprecated)]
-                        function_call: None,
-                    }
-                ),
-                _ => ChatCompletionRequestMessage::User(
-                    ChatCompletionRequestUserMessage {
-                        content: ChatCompletionRequestUserMessageContent::Array(parts),
-                        name: None,
-                    }
-                ),
-            };
-            messages.push(message);
-        }
-
-        if messages.is_empty() {
-            return Err(AppError::ProviderError("No valid text messages to send".to_string()));
-        }
-
-        // Build request using builder pattern
-        let mut builder = CreateChatCompletionRequestArgs::default();
-        builder.model(model);
-        builder.messages(messages);
-        builder.stream(true); // Enable streaming
-
-        // Add optional parameters
-        if let Some(temp) = request.temperature {
-            builder.temperature(temp as f32);
-        }
-
-        if let Some(max_tokens) = request.max_tokens {
-            builder.max_tokens(max_tokens as u16);
-        }
-
-        // Execute streaming API call
-        let stream = self.client
-            .chat()
-            .create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
-            .await
-            .map_err(|e| AppError::ProviderError(e.to_string()))?;
-
-        // Convert OpenAI stream to our stream format
-        let model_name = request.model.clone();
-        let stream = stream.map(move |chunk_result| {
-            match chunk_result {
-                Ok(chunk) => {
-                    // Extract content from chunk
-                    let content = chunk.choices.first()
-                        .and_then(|choice| choice.delta.content.clone())
-                        .unwrap_or_default();
-
-                    let finish_reason = chunk.choices.first()
-                        .and_then(|choice| choice.finish_reason.clone())
-                        .map(|reason| format!("{:?}", reason));
-
-                    Ok(ProviderStreamChunk {
-                        content,
-                        finish_reason,
-                        model: model_name.clone(),
-                    })
-                }
-                Err(e) => Err(AppError::ProviderError(e.to_string())),
-            }
-        });
+        let mut body = serde_json::json!({
+            "model": model,
+            "messages": request.messages.iter().map(|m| {
+                serde_json::json!({
+                    "role": m.role,
+                    "content": m.content.iter().map(|p| {
+                        match p {
+                            crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
+                            // Images are not supported on the streaming path yet; send a placeholder
+                            crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }),
+                        }
+                    }).collect::<Vec<_>>()
+                })
+            }).collect::<Vec<_>>(),
+            "stream": true,
+        });
+
+        if let Some(temp) = request.temperature {
+            body["temperature"] = serde_json::json!(temp);
+        }
+        if let Some(max_tokens) = request.max_tokens {
+            body["max_tokens"] = serde_json::json!(max_tokens);
+        }
+
+        // Create eventsource stream
+        use reqwest_eventsource::{EventSource, Event};
+        let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
+            .json(&body))
+            .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
+
+        let model_name = request.model.clone();
+
+        let stream = async_stream::try_stream! {
+            let mut es = es;
+            while let Some(event) = es.next().await {
+                match event {
+                    Ok(Event::Message(msg)) => {
+                        if msg.data == "[DONE]" {
+                            break;
+                        }
+
+                        let chunk: Value = serde_json::from_str(&msg.data)
+                            .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
+
+                        if let Some(choice) = chunk["choices"].get(0) {
+                            let delta = &choice["delta"];
+                            let content = delta["content"].as_str().unwrap_or_default().to_string();
+                            let reasoning_content = delta["reasoning_content"].as_str().or_else(|| delta["thought"].as_str()).map(|s| s.to_string());
+                            let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
+
+                            yield ProviderStreamChunk {
+                                content,
+                                reasoning_content,
+                                finish_reason,
+                                model: model_name.clone(),
+                            };
+                        }
+                    }
+                    Ok(_) => continue,
+                    Err(e) => {
+                        Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
+                    }
+                }
+            }
+        };
 
         Ok(Box::pin(stream))
     }
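Reviewer note: the Ollama path is the only one that has to probe two field names, since OpenAI-compatible servers use `reasoning_content` while some Ollama builds expose `thought`. A minimal sketch of that fallback with illustrative payloads (only `serde_json` assumed):

```rust
use serde_json::Value;

fn extract_reasoning(message: &Value) -> Option<String> {
    message["reasoning_content"]
        .as_str()
        .or_else(|| message["thought"].as_str())
        .map(str::to_string)
}

fn main() {
    let openai_style = serde_json::json!({ "content": "42", "reasoning_content": "6*7" });
    let ollama_style = serde_json::json!({ "content": "42", "thought": "6*7" });
    let plain = serde_json::json!({ "content": "42" });

    assert_eq!(extract_reasoning(&openai_style).as_deref(), Some("6*7"));
    assert_eq!(extract_reasoning(&ollama_style).as_deref(), Some("6*7"));
    assert_eq!(extract_reasoning(&plain), None); // non-reasoning models unaffected
}
```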
diff --git a/src/providers/openai.rs b/src/providers/openai.rs
index 0569f8de..e1f53a59 100644
--- a/src/providers/openai.rs
+++ b/src/providers/openai.rs
@@ -1,8 +1,7 @@
 use async_trait::async_trait;
 use anyhow::Result;
-use async_openai::{Client, config::OpenAIConfig};
-use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
 use futures::stream::{BoxStream, StreamExt};
+use serde_json::Value;
 
 use crate::{
     models::UnifiedRequest,
@@ -12,8 +11,9 @@ use crate::{
 use super::{ProviderResponse, ProviderStreamChunk};
 
 pub struct OpenAIProvider {
-    client: Client<OpenAIConfig>,
+    client: reqwest::Client,
     _config: crate::config::OpenAIConfig,
+    api_key: String,
     pricing: Vec<ModelPricing>,
 }
 
@@ -21,16 +21,10 @@ impl OpenAIProvider {
     pub fn new(config: &crate::config::OpenAIConfig, app_config: &AppConfig) -> Result<Self> {
         let api_key = app_config.get_api_key("openai")?;
 
-        // Create OpenAIConfig with api key and base url
-        let openai_config = OpenAIConfig::default()
-            .with_api_key(api_key)
-            .with_api_base(&config.base_url);
-
-        let client = Client::with_config(openai_config);
-
         Ok(Self {
-            client,
+            client: reqwest::Client::new(),
             _config: config.clone(),
+            api_key,
             pricing: app_config.pricing.openai.clone(),
         })
     }
 }
 
@@ -47,114 +41,70 @@ impl super::Provider for OpenAIProvider {
     }
 
     fn supports_multimodal(&self) -> bool {
-        true // OpenAI supports vision models
+        true
     }
 
     async fn chat_completion(
         &self,
         request: UnifiedRequest,
     ) -> Result<ProviderResponse, AppError> {
-        use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
-
-        // Convert UnifiedRequest messages to OpenAI messages
-        let mut messages = Vec::with_capacity(request.messages.len());
-
-        for msg in request.messages {
-            let mut parts = Vec::with_capacity(msg.content.len());
-
-            for part in msg.content {
-                match part {
-                    crate::models::ContentPart::Text { text } => {
-                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
-                            text,
-                        }));
-                    }
-                    crate::models::ContentPart::Image(image_input) => {
-                        let (base64_data, mime_type) = image_input.to_base64().await
-                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
-                        let data_url = format!("data:{};base64,{}", mime_type, base64_data);
-
-                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
-                            image_url: ImageUrl {
-                                url: data_url,
-                                detail: Some(ImageDetail::Auto),
-                            }
-                        }));
-                    }
-                }
-            }
-
-            let message = match msg.role.as_str() {
-                "system" => ChatCompletionRequestMessage::System(
-                    ChatCompletionRequestSystemMessage {
-                        content: ChatCompletionRequestSystemMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        ),
-                        name: None,
-                    }
-                ),
-                "assistant" => ChatCompletionRequestMessage::Assistant(
-                    ChatCompletionRequestAssistantMessage {
-                        content: Some(ChatCompletionRequestAssistantMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        )),
-                        name: None,
-                        tool_calls: None,
-                        refusal: None,
-                        audio: None,
-                        #[allow(deprecated)]
-                        function_call: None,
-                    }
-                ),
-                _ => ChatCompletionRequestMessage::User(
-                    ChatCompletionRequestUserMessage {
-                        content: ChatCompletionRequestUserMessageContent::Array(parts),
-                        name: None,
-                    }
-                ),
-            };
-            messages.push(message);
-        }
-
-        if messages.is_empty() {
-            return Err(AppError::ProviderError("No valid text messages to send".to_string()));
-        }
-
-        // Build request using builder pattern
-        let mut builder = CreateChatCompletionRequestArgs::default();
-        builder.model(request.model.clone());
-        builder.messages(messages);
-
-        // Add optional parameters
+        // Convert messages up front so image parts can be awaited; blocking
+        // on to_base64() inside the JSON builder would stall the runtime.
+        let mut messages_json = Vec::with_capacity(request.messages.len());
+        for m in &request.messages {
+            let mut parts = Vec::with_capacity(m.content.len());
+            for p in &m.content {
+                match p {
+                    crate::models::ContentPart::Text { text } => {
+                        parts.push(serde_json::json!({ "type": "text", "text": text }));
+                    }
+                    crate::models::ContentPart::Image(image_input) => {
+                        let (base64_data, mime_type) = image_input.to_base64().await
+                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
+                        parts.push(serde_json::json!({
+                            "type": "image_url",
+                            "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
+                        }));
+                    }
+                }
+            }
+            messages_json.push(serde_json::json!({ "role": m.role, "content": parts }));
+        }
+
+        let mut body = serde_json::json!({
+            "model": request.model,
+            "messages": messages_json,
+            "stream": false,
+        });
+
         if let Some(temp) = request.temperature {
-            builder.temperature(temp as f32);
+            body["temperature"] = serde_json::json!(temp);
         }
-
         if let Some(max_tokens) = request.max_tokens {
-            builder.max_tokens(max_tokens as u16);
+            body["max_tokens"] = serde_json::json!(max_tokens);
         }
 
-        // Execute API call
-        let response = self.client
-            .chat()
-            .create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
+        let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
+            .header("Authorization", format!("Bearer {}", self.api_key))
+            .json(&body)
+            .send()
             .await
             .map_err(|e| AppError::ProviderError(e.to_string()))?;
 
-        // Extract content from response
-        let content = response
-            .choices
-            .first()
-            .and_then(|choice| choice.message.content.clone())
-            .unwrap_or_default();
+        if !response.status().is_success() {
+            let error_text = response.text().await.unwrap_or_default();
+            return Err(AppError::ProviderError(format!("OpenAI API error: {}", error_text)));
+        }
+
+        let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
 
-        // Extract token usage
-        let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
-        let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
-        let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
+        let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
+        let message = &choice["message"];
 
+        let content = message["content"].as_str().unwrap_or_default().to_string();
+        let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
+
+        let usage = &resp_json["usage"];
+        let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
+        let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
+        let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
 
         Ok(ProviderResponse {
             content,
+            reasoning_content,
             prompt_tokens,
             completion_tokens,
             total_tokens,
@@ -174,7 +124,6 @@ impl super::Provider for OpenAIProvider {
         }
     }
 
-        // Fallback to static pricing if not in registry
         let (prompt_rate, completion_rate) = self.pricing.iter()
             .find(|p| model.contains(&p.model))
             .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
@@ -187,118 +136,78 @@
         &self,
         request: UnifiedRequest,
     ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
-        use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
-
-        // Convert UnifiedRequest messages to OpenAI messages
-        let mut messages = Vec::with_capacity(request.messages.len());
-
-        for msg in request.messages {
-            let mut parts = Vec::with_capacity(msg.content.len());
-
-            for part in msg.content {
-                match part {
-                    crate::models::ContentPart::Text { text } => {
-                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
-                            text,
-                        }));
-                    }
-                    crate::models::ContentPart::Image(image_input) => {
-                        let (base64_data, mime_type) = image_input.to_base64().await
-                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
-                        let data_url = format!("data:{};base64,{}", mime_type, base64_data);
-
-                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
-                            image_url: ImageUrl {
-                                url: data_url,
-                                detail: Some(ImageDetail::Auto),
-                            }
-                        }));
-                    }
-                }
-            }
-
-            let message = match msg.role.as_str() {
-                "system" => ChatCompletionRequestMessage::System(
-                    ChatCompletionRequestSystemMessage {
-                        content: ChatCompletionRequestSystemMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        ),
-                        name: None,
-                    }
-                ),
-                "assistant" => ChatCompletionRequestMessage::Assistant(
-                    ChatCompletionRequestAssistantMessage {
-                        content: Some(ChatCompletionRequestAssistantMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        )),
-                        name: None,
-                        tool_calls: None,
-                        refusal: None,
-                        audio: None,
-                        #[allow(deprecated)]
-                        function_call: None,
-                    }
-                ),
-                _ => ChatCompletionRequestMessage::User(
-                    ChatCompletionRequestUserMessage {
-                        content: ChatCompletionRequestUserMessageContent::Array(parts),
-                        name: None,
-                    }
-                ),
-            };
-            messages.push(message);
-        }
-
-        if messages.is_empty() {
-            return Err(AppError::ProviderError("No valid text messages to send".to_string()));
-        }
-
-        // Build request using builder pattern
-        let mut builder = CreateChatCompletionRequestArgs::default();
-        builder.model(request.model.clone());
-        builder.messages(messages);
-        builder.stream(true); // Enable streaming
-
-        // Add optional parameters
-        if let Some(temp) = request.temperature {
-            builder.temperature(temp as f32);
-        }
-
-        if let Some(max_tokens) = request.max_tokens {
-            builder.max_tokens(max_tokens as u16);
-        }
-
-        // Execute streaming API call
-        let stream = self.client
-            .chat()
-            .create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
-            .await
-            .map_err(|e| AppError::ProviderError(e.to_string()))?;
-
-        // Convert OpenAI stream to our stream format
-        let model = request.model.clone();
-        let stream = stream.map(move |chunk_result| {
-            match chunk_result {
-                Ok(chunk) => {
-                    // Extract content from chunk
-                    let content = chunk.choices.first()
-                        .and_then(|choice| choice.delta.content.clone())
-                        .unwrap_or_default();
-
-                    let finish_reason = chunk.choices.first()
-                        .and_then(|choice| choice.finish_reason.clone())
-                        .map(|reason| format!("{:?}", reason));
-
-                    Ok(ProviderStreamChunk {
-                        content,
-                        finish_reason,
-                        model: model.clone(),
-                    })
-                }
-                Err(e) => Err(AppError::ProviderError(e.to_string())),
-            }
-        });
+        // Convert messages up front so image parts can be awaited; blocking
+        // on to_base64() inside the JSON builder would stall the runtime.
+        let mut messages_json = Vec::with_capacity(request.messages.len());
+        for m in &request.messages {
+            let mut parts = Vec::with_capacity(m.content.len());
+            for p in &m.content {
+                match p {
+                    crate::models::ContentPart::Text { text } => {
+                        parts.push(serde_json::json!({ "type": "text", "text": text }));
+                    }
+                    crate::models::ContentPart::Image(image_input) => {
+                        let (base64_data, mime_type) = image_input.to_base64().await
+                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
+                        parts.push(serde_json::json!({
+                            "type": "image_url",
+                            "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
+                        }));
+                    }
+                }
+            }
+            messages_json.push(serde_json::json!({ "role": m.role, "content": parts }));
+        }
+
+        let mut body = serde_json::json!({
+            "model": request.model,
+            "messages": messages_json,
+            "stream": true,
+        });
+
+        if let Some(temp) = request.temperature {
+            body["temperature"] = serde_json::json!(temp);
+        }
+        if let Some(max_tokens) = request.max_tokens {
+            body["max_tokens"] = serde_json::json!(max_tokens);
+        }
+
+        // Create eventsource stream
+        use reqwest_eventsource::{EventSource, Event};
+        let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
+            .header("Authorization", format!("Bearer {}", self.api_key))
+            .json(&body))
+            .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
+
+        let model = request.model.clone();
+
+        let stream = async_stream::try_stream! {
+            let mut es = es;
+            while let Some(event) = es.next().await {
+                match event {
+                    Ok(Event::Message(msg)) => {
+                        if msg.data == "[DONE]" {
+                            break;
+                        }
+
+                        let chunk: Value = serde_json::from_str(&msg.data)
+                            .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
+
+                        if let Some(choice) = chunk["choices"].get(0) {
+                            let delta = &choice["delta"];
+                            let content = delta["content"].as_str().unwrap_or_default().to_string();
+                            let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
+                            let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
+
+                            yield ProviderStreamChunk {
+                                content,
+                                reasoning_content,
+                                finish_reason,
+                                model: model.clone(),
+                            };
+                        }
+                    }
+                    Ok(_) => continue,
+                    Err(e) => {
+                        Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
+                    }
+                }
+            }
+        };
 
         Ok(Box::pin(stream))
     }
-}
\ No newline at end of file
+}
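Reviewer note: stripped of project types, the streaming plumbing all four refactored providers now share reduces to the reqwest-eventsource loop below. Endpoint, key, and body are placeholders; error handling is collapsed for brevity; assumes the `reqwest`, `reqwest-eventsource`, `futures`, and `serde_json` crates and a Tokio context to call it from.

```rust
use futures::StreamExt;
use reqwest_eventsource::{Event, EventSource};

async fn stream_demo() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::Client::new();
    let body = serde_json::json!({
        "model": "some-model", // placeholder
        "messages": [{ "role": "user", "content": "hi" }],
        "stream": true,
    });
    // EventSource wraps the RequestBuilder and handles the SSE framing.
    let mut es = EventSource::new(
        client
            .post("https://api.example.com/v1/chat/completions") // placeholder URL
            .header("Authorization", "Bearer <key>")             // placeholder key
            .json(&body),
    )?;

    while let Some(event) = es.next().await {
        match event {
            Ok(Event::Open) => {}
            Ok(Event::Message(msg)) if msg.data == "[DONE]" => break,
            Ok(Event::Message(msg)) => println!("chunk: {}", msg.data),
            Err(err) => {
                es.close(); // stop automatic retries before surfacing the error
                return Err(err.into());
            }
        }
    }
    Ok(())
}
```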
diff --git a/src/server/mod.rs b/src/server/mod.rs
index 93b53275..7257dcf4 100644
--- a/src/server/mod.rs
+++ b/src/server/mod.rs
@@ -102,6 +102,7 @@ async fn chat_completions(
                 delta: ChatStreamDelta {
                     role: None,
                     content: Some(chunk.content),
+                    reasoning_content: chunk.reasoning_content,
                 },
                 finish_reason: chunk.finish_reason,
             }],
@@ -177,6 +178,7 @@ async fn chat_completions(
                 content: crate::models::MessageContent::Text {
                     content: response.content
                 },
+                reasoning_content: response.reasoning_content,
             },
             finish_reason: Some("stop".to_string()),
         }],
diff --git a/src/utils/streaming.rs b/src/utils/streaming.rs
index 3e021f55..ad5d4bcb 100644
--- a/src/utils/streaming.rs
+++ b/src/utils/streaming.rs
@@ -16,6 +16,7 @@ pub struct AggregatingStream<S> {
     prompt_tokens: u32,
     has_images: bool,
     accumulated_content: String,
+    accumulated_reasoning: String,
     logger: Arc<Logger>,
     client_manager: Arc<ClientManager>,
     model_registry: Arc<ModelRegistry>,
@@ -46,6 +47,7 @@ where
             prompt_tokens,
             has_images,
             accumulated_content: String::new(),
+            accumulated_reasoning: String::new(),
             logger,
             client_manager,
             model_registry,
@@ -71,8 +73,15 @@ where
         let has_images = self.has_images;
         let registry = self.model_registry.clone();
 
-        // Estimate completion tokens
-        let completion_tokens = estimate_completion_tokens(&self.accumulated_content, &model);
+        // Estimate completion tokens (including reasoning if present)
+        let content_tokens = estimate_completion_tokens(&self.accumulated_content, &model);
+        let reasoning_tokens = if !self.accumulated_reasoning.is_empty() {
+            estimate_completion_tokens(&self.accumulated_reasoning, &model)
+        } else {
+            0
+        };
+
+        let completion_tokens = content_tokens + reasoning_tokens;
         let total_tokens = prompt_tokens + completion_tokens;
         let cost = provider.calculate_cost(&model, prompt_tokens, completion_tokens, &registry);
@@ -116,6 +125,9 @@ where
         match &result {
             Poll::Ready(Some(Ok(chunk))) => {
                 self.accumulated_content.push_str(&chunk.content);
+                if let Some(reasoning) = &chunk.reasoning_content {
+                    self.accumulated_reasoning.push_str(reasoning);
+                }
             }
             Poll::Ready(Some(Err(_))) => {
                 // If there's an error, we might still want to log what we got so far?
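Reviewer note: a stand-alone sketch of the new aggregation math in `AggregatingStream`. The `estimate_tokens` function below is a stand-in for the crate's `estimate_completion_tokens` heuristic (assumed here to be roughly four characters per token); only the split-then-sum structure mirrors the patch.

```rust
fn estimate_tokens(text: &str) -> u32 {
    // Crude stand-in heuristic: ~4 characters per token, rounded up.
    (text.chars().count() as u32 + 3) / 4
}

struct Aggregate {
    accumulated_content: String,
    accumulated_reasoning: String,
}

impl Aggregate {
    fn completion_tokens(&self) -> u32 {
        let content_tokens = estimate_tokens(&self.accumulated_content);
        let reasoning_tokens = if self.accumulated_reasoning.is_empty() {
            0
        } else {
            estimate_tokens(&self.accumulated_reasoning)
        };
        content_tokens + reasoning_tokens
    }
}

fn main() {
    let agg = Aggregate {
        accumulated_content: "The answer is 42.".into(),      // 17 chars -> 5 tokens
        accumulated_reasoning: "6 times 7 equals 42.".into(), // 21 chars -> 6 tokens
    };
    // Reasoning inflates billable completion tokens even when the client
    // never renders it, which is why the aggregator must count it.
    assert_eq!(agg.completion_tokens(), 11);
}
```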