feat: add support for reasoning (thinking) models across all providers
- Refactored OpenAI, DeepSeek, Grok, and Ollama to manual JSON parsing to capture 'reasoning_content' and 'thought' fields. - Implemented real-time streaming of reasoning blocks. - Added token aggregation and cost tracking for reasoning tokens. - Updated unified models to include 'reasoning_content' in API responses.
This commit is contained in:
@@ -22,6 +22,8 @@ pub struct ChatMessage {
|
|||||||
pub role: String, // "system", "user", "assistant"
|
pub role: String, // "system", "user", "assistant"
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
pub content: MessageContent,
|
pub content: MessageContent,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub reasoning_content: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -29,6 +31,7 @@ pub struct ChatMessage {
|
|||||||
pub enum MessageContent {
|
pub enum MessageContent {
|
||||||
Text { content: String },
|
Text { content: String },
|
||||||
Parts { content: Vec<ContentPartValue> },
|
Parts { content: Vec<ContentPartValue> },
|
||||||
|
None, // Handle cases where content might be null but reasoning is present
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -91,6 +94,8 @@ pub struct ChatStreamChoice {
|
|||||||
pub struct ChatStreamDelta {
|
pub struct ChatStreamDelta {
|
||||||
pub role: Option<String>,
|
pub role: Option<String>,
|
||||||
pub content: Option<String>,
|
pub content: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub reasoning_content: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ========== Unified Request Format (for internal use) ==========
|
// ========== Unified Request Format (for internal use) ==========
|
||||||
@@ -217,6 +222,9 @@ impl TryFrom<ChatCompletionRequest> for UnifiedRequest {
|
|||||||
|
|
||||||
(unified_content, has_images_in_msg)
|
(unified_content, has_images_in_msg)
|
||||||
}
|
}
|
||||||
|
MessageContent::None => {
|
||||||
|
(vec![], false)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
UnifiedMessage {
|
UnifiedMessage {
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use async_openai::{Client, config::OpenAIConfig};
|
|
||||||
use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
|
|
||||||
use futures::stream::{BoxStream, StreamExt};
|
use futures::stream::{BoxStream, StreamExt};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
models::UnifiedRequest,
|
models::UnifiedRequest,
|
||||||
@@ -12,8 +11,9 @@ use crate::{
|
|||||||
use super::{ProviderResponse, ProviderStreamChunk};
|
use super::{ProviderResponse, ProviderStreamChunk};
|
||||||
|
|
||||||
pub struct DeepSeekProvider {
|
pub struct DeepSeekProvider {
|
||||||
client: Client<OpenAIConfig>, // DeepSeek uses OpenAI-compatible API
|
client: reqwest::Client,
|
||||||
_config: crate::config::DeepSeekConfig,
|
config: crate::config::DeepSeekConfig,
|
||||||
|
api_key: String,
|
||||||
pricing: Vec<crate::config::ModelPricing>,
|
pricing: Vec<crate::config::ModelPricing>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -21,16 +21,10 @@ impl DeepSeekProvider {
|
|||||||
pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result<Self> {
|
pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result<Self> {
|
||||||
let api_key = app_config.get_api_key("deepseek")?;
|
let api_key = app_config.get_api_key("deepseek")?;
|
||||||
|
|
||||||
// Create OpenAIConfig with api key and base url
|
|
||||||
let openai_config = OpenAIConfig::default()
|
|
||||||
.with_api_key(api_key)
|
|
||||||
.with_api_base(&config.base_url);
|
|
||||||
|
|
||||||
let client = Client::with_config(openai_config);
|
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
client,
|
client: reqwest::Client::new(),
|
||||||
_config: config.clone(),
|
config: config.clone(),
|
||||||
|
api_key,
|
||||||
pricing: app_config.pricing.deepseek.clone(),
|
pricing: app_config.pricing.deepseek.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -47,114 +41,72 @@ impl super::Provider for DeepSeekProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn supports_multimodal(&self) -> bool {
|
fn supports_multimodal(&self) -> bool {
|
||||||
false // DeepSeek doesn't support general vision (only OCR)
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_completion(
|
async fn chat_completion(
|
||||||
&self,
|
&self,
|
||||||
request: UnifiedRequest,
|
request: UnifiedRequest,
|
||||||
) -> Result<ProviderResponse, AppError> {
|
) -> Result<ProviderResponse, AppError> {
|
||||||
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
|
// Build the OpenAI-compatible body
|
||||||
|
let mut body = serde_json::json!({
|
||||||
// Convert UnifiedRequest messages to OpenAI-compatible messages
|
"model": request.model,
|
||||||
let mut messages = Vec::with_capacity(request.messages.len());
|
"messages": request.messages.iter().map(|m| {
|
||||||
|
serde_json::json!({
|
||||||
for msg in request.messages {
|
"role": m.role,
|
||||||
let mut parts = Vec::with_capacity(msg.content.len());
|
"content": m.content.iter().map(|p| {
|
||||||
|
match p {
|
||||||
for part in msg.content {
|
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
|
||||||
match part {
|
crate::models::ContentPart::Image(image_input) => {
|
||||||
crate::models::ContentPart::Text { text } => {
|
// DeepSeek currently doesn't support images in the same way, but we'll try to be standard
|
||||||
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
|
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
|
||||||
text,
|
serde_json::json!({
|
||||||
}));
|
"type": "image_url",
|
||||||
}
|
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
||||||
crate::models::ContentPart::Image(image_input) => {
|
})
|
||||||
let (base64_data, mime_type) = image_input.to_base64().await
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
|
||||||
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
|
|
||||||
|
|
||||||
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
|
|
||||||
image_url: ImageUrl {
|
|
||||||
url: data_url,
|
|
||||||
detail: Some(ImageDetail::Auto),
|
|
||||||
}
|
}
|
||||||
}));
|
}
|
||||||
}
|
}).collect::<Vec<_>>()
|
||||||
}
|
})
|
||||||
}
|
}).collect::<Vec<_>>(),
|
||||||
|
"stream": false,
|
||||||
|
});
|
||||||
|
|
||||||
let message = match msg.role.as_str() {
|
|
||||||
"system" => ChatCompletionRequestMessage::System(
|
|
||||||
ChatCompletionRequestSystemMessage {
|
|
||||||
content: ChatCompletionRequestSystemMessageContent::Text(
|
|
||||||
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
|
|
||||||
),
|
|
||||||
name: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
"assistant" => ChatCompletionRequestMessage::Assistant(
|
|
||||||
ChatCompletionRequestAssistantMessage {
|
|
||||||
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
|
|
||||||
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
|
|
||||||
)),
|
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
refusal: None,
|
|
||||||
audio: None,
|
|
||||||
#[allow(deprecated)]
|
|
||||||
function_call: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
_ => ChatCompletionRequestMessage::User(
|
|
||||||
ChatCompletionRequestUserMessage {
|
|
||||||
content: ChatCompletionRequestUserMessageContent::Array(parts),
|
|
||||||
name: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
};
|
|
||||||
messages.push(message);
|
|
||||||
}
|
|
||||||
|
|
||||||
if messages.is_empty() {
|
|
||||||
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build request using builder pattern
|
|
||||||
let mut builder = CreateChatCompletionRequestArgs::default();
|
|
||||||
builder.model(request.model.clone());
|
|
||||||
builder.messages(messages);
|
|
||||||
|
|
||||||
// Add optional parameters
|
|
||||||
if let Some(temp) = request.temperature {
|
if let Some(temp) = request.temperature {
|
||||||
builder.temperature(temp as f32);
|
body["temperature"] = serde_json::json!(temp);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(max_tokens) = request.max_tokens {
|
if let Some(max_tokens) = request.max_tokens {
|
||||||
builder.max_tokens(max_tokens as u16);
|
body["max_tokens"] = serde_json::json!(max_tokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute API call
|
let response = self.client.post(format!("{}/chat/completions", self.config.base_url))
|
||||||
let response = self.client
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
.chat()
|
.json(&body)
|
||||||
.create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
|
.send()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||||
|
|
||||||
// Extract content from response
|
if !response.status().is_success() {
|
||||||
let content = response
|
let error_text = response.text().await.unwrap_or_default();
|
||||||
.choices
|
return Err(AppError::ProviderError(format!("DeepSeek API error: {}", error_text)));
|
||||||
.first()
|
}
|
||||||
.and_then(|choice| choice.message.content.clone())
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
// Extract token usage
|
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||||
let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
|
|
||||||
let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
|
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
|
||||||
let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
|
let message = &choice["message"];
|
||||||
|
|
||||||
|
let content = message["content"].as_str().unwrap_or_default().to_string();
|
||||||
|
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
|
||||||
|
|
||||||
|
let usage = &resp_json["usage"];
|
||||||
|
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
|
||||||
|
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
||||||
|
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
|
||||||
|
|
||||||
Ok(ProviderResponse {
|
Ok(ProviderResponse {
|
||||||
content,
|
content,
|
||||||
|
reasoning_content,
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
completion_tokens,
|
completion_tokens,
|
||||||
total_tokens,
|
total_tokens,
|
||||||
@@ -177,7 +129,7 @@ impl super::Provider for DeepSeekProvider {
|
|||||||
let (prompt_rate, completion_rate) = self.pricing.iter()
|
let (prompt_rate, completion_rate) = self.pricing.iter()
|
||||||
.find(|p| model.contains(&p.model))
|
.find(|p| model.contains(&p.model))
|
||||||
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
||||||
.unwrap_or((0.14, 0.28)); // Default to DeepSeek V3 price if not found
|
.unwrap_or((0.14, 0.28));
|
||||||
|
|
||||||
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
||||||
}
|
}
|
||||||
@@ -186,118 +138,72 @@ impl super::Provider for DeepSeekProvider {
|
|||||||
&self,
|
&self,
|
||||||
request: UnifiedRequest,
|
request: UnifiedRequest,
|
||||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||||
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
|
let mut body = serde_json::json!({
|
||||||
|
"model": request.model,
|
||||||
// Convert UnifiedRequest messages to OpenAI-compatible messages
|
"messages": request.messages.iter().map(|m| {
|
||||||
let mut messages = Vec::with_capacity(request.messages.len());
|
serde_json::json!({
|
||||||
|
"role": m.role,
|
||||||
for msg in request.messages {
|
"content": m.content.iter().map(|p| {
|
||||||
let mut parts = Vec::with_capacity(msg.content.len());
|
match p {
|
||||||
|
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
|
||||||
for part in msg.content {
|
crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }),
|
||||||
match part {
|
}
|
||||||
crate::models::ContentPart::Text { text } => {
|
}).collect::<Vec<_>>()
|
||||||
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
|
})
|
||||||
text,
|
}).collect::<Vec<_>>(),
|
||||||
}));
|
"stream": true,
|
||||||
}
|
|
||||||
crate::models::ContentPart::Image(image_input) => {
|
|
||||||
let (base64_data, mime_type) = image_input.to_base64().await
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
|
||||||
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
|
|
||||||
|
|
||||||
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
|
|
||||||
image_url: ImageUrl {
|
|
||||||
url: data_url,
|
|
||||||
detail: Some(ImageDetail::Auto),
|
|
||||||
}
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let message = match msg.role.as_str() {
|
|
||||||
"system" => ChatCompletionRequestMessage::System(
|
|
||||||
ChatCompletionRequestSystemMessage {
|
|
||||||
content: ChatCompletionRequestSystemMessageContent::Text(
|
|
||||||
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
|
|
||||||
),
|
|
||||||
name: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
"assistant" => ChatCompletionRequestMessage::Assistant(
|
|
||||||
ChatCompletionRequestAssistantMessage {
|
|
||||||
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
|
|
||||||
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
|
|
||||||
)),
|
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
refusal: None,
|
|
||||||
audio: None,
|
|
||||||
#[allow(deprecated)]
|
|
||||||
function_call: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
_ => ChatCompletionRequestMessage::User(
|
|
||||||
ChatCompletionRequestUserMessage {
|
|
||||||
content: ChatCompletionRequestUserMessageContent::Array(parts),
|
|
||||||
name: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
};
|
|
||||||
messages.push(message);
|
|
||||||
}
|
|
||||||
|
|
||||||
if messages.is_empty() {
|
|
||||||
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build request using builder pattern
|
|
||||||
let mut builder = CreateChatCompletionRequestArgs::default();
|
|
||||||
builder.model(request.model.clone());
|
|
||||||
builder.messages(messages);
|
|
||||||
builder.stream(true); // Enable streaming
|
|
||||||
|
|
||||||
// Add optional parameters
|
|
||||||
if let Some(temp) = request.temperature {
|
|
||||||
builder.temperature(temp as f32);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(max_tokens) = request.max_tokens {
|
|
||||||
builder.max_tokens(max_tokens as u16);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute streaming API call
|
|
||||||
let stream = self.client
|
|
||||||
.chat()
|
|
||||||
.create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
|
|
||||||
.await
|
|
||||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
||||||
|
|
||||||
// Convert OpenAI stream to our stream format
|
|
||||||
let model = request.model.clone();
|
|
||||||
let stream = stream.map(move |chunk_result| {
|
|
||||||
match chunk_result {
|
|
||||||
Ok(chunk) => {
|
|
||||||
// Extract content from chunk
|
|
||||||
let content = chunk.choices.first()
|
|
||||||
.and_then(|choice| choice.delta.content.clone())
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
let finish_reason = chunk.choices.first()
|
|
||||||
.and_then(|choice| choice.finish_reason.clone())
|
|
||||||
.map(|reason| format!("{:?}", reason));
|
|
||||||
|
|
||||||
Ok(ProviderStreamChunk {
|
|
||||||
content,
|
|
||||||
finish_reason,
|
|
||||||
model: model.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
Err(e) => Err(AppError::ProviderError(e.to_string())),
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if let Some(temp) = request.temperature {
|
||||||
|
body["temperature"] = serde_json::json!(temp);
|
||||||
|
}
|
||||||
|
if let Some(max_tokens) = request.max_tokens {
|
||||||
|
body["max_tokens"] = serde_json::json!(max_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create eventsource stream
|
||||||
|
use reqwest_eventsource::{EventSource, Event};
|
||||||
|
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self.config.base_url))
|
||||||
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
|
.json(&body))
|
||||||
|
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||||
|
|
||||||
|
let model = request.model.clone();
|
||||||
|
|
||||||
|
let stream = async_stream::try_stream! {
|
||||||
|
let mut es = es;
|
||||||
|
while let Some(event) = es.next().await {
|
||||||
|
match event {
|
||||||
|
Ok(Event::Message(msg)) => {
|
||||||
|
if msg.data == "[DONE]" {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let chunk: Value = serde_json::from_str(&msg.data)
|
||||||
|
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
||||||
|
|
||||||
|
if let Some(choice) = chunk["choices"].get(0) {
|
||||||
|
let delta = &choice["delta"];
|
||||||
|
let content = delta["content"].as_str().unwrap_or_default().to_string();
|
||||||
|
let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
|
||||||
|
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
|
||||||
|
|
||||||
|
yield ProviderStreamChunk {
|
||||||
|
content,
|
||||||
|
reasoning_content,
|
||||||
|
finish_reason,
|
||||||
|
model: model.clone(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(_) => continue,
|
||||||
|
Err(e) => {
|
||||||
|
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
Ok(Box::pin(stream))
|
Ok(Box::pin(stream))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -206,6 +206,7 @@ impl super::Provider for GeminiProvider {
|
|||||||
|
|
||||||
Ok(ProviderResponse {
|
Ok(ProviderResponse {
|
||||||
content,
|
content,
|
||||||
|
reasoning_content: None, // Gemini doesn't use this field name
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
completion_tokens,
|
completion_tokens,
|
||||||
total_tokens,
|
total_tokens,
|
||||||
@@ -324,6 +325,7 @@ impl super::Provider for GeminiProvider {
|
|||||||
|
|
||||||
yield ProviderStreamChunk {
|
yield ProviderStreamChunk {
|
||||||
content,
|
content,
|
||||||
|
reasoning_content: None,
|
||||||
finish_reason: None, // Will be set in the last chunk
|
finish_reason: None, // Will be set in the last chunk
|
||||||
model: model.clone(),
|
model: model.clone(),
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use async_openai::{Client, config::OpenAIConfig};
|
|
||||||
use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
|
|
||||||
use futures::stream::{BoxStream, StreamExt};
|
use futures::stream::{BoxStream, StreamExt};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
models::UnifiedRequest,
|
models::UnifiedRequest,
|
||||||
@@ -12,8 +11,9 @@ use crate::{
|
|||||||
use super::{ProviderResponse, ProviderStreamChunk};
|
use super::{ProviderResponse, ProviderStreamChunk};
|
||||||
|
|
||||||
pub struct GrokProvider {
|
pub struct GrokProvider {
|
||||||
client: Client<OpenAIConfig>,
|
client: reqwest::Client,
|
||||||
_config: crate::config::GrokConfig,
|
_config: crate::config::GrokConfig,
|
||||||
|
api_key: String,
|
||||||
pricing: Vec<crate::config::ModelPricing>,
|
pricing: Vec<crate::config::ModelPricing>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -21,16 +21,10 @@ impl GrokProvider {
|
|||||||
pub fn new(config: &crate::config::GrokConfig, app_config: &AppConfig) -> Result<Self> {
|
pub fn new(config: &crate::config::GrokConfig, app_config: &AppConfig) -> Result<Self> {
|
||||||
let api_key = app_config.get_api_key("grok")?;
|
let api_key = app_config.get_api_key("grok")?;
|
||||||
|
|
||||||
// Grok is OpenAI-compatible
|
|
||||||
let openai_config = OpenAIConfig::default()
|
|
||||||
.with_api_key(api_key)
|
|
||||||
.with_api_base(&config.base_url);
|
|
||||||
|
|
||||||
let client = Client::with_config(openai_config);
|
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
client,
|
client: reqwest::Client::new(),
|
||||||
_config: config.clone(),
|
_config: config.clone(),
|
||||||
|
api_key,
|
||||||
pricing: app_config.pricing.grok.clone(),
|
pricing: app_config.pricing.grok.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -47,114 +41,70 @@ impl super::Provider for GrokProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn supports_multimodal(&self) -> bool {
|
fn supports_multimodal(&self) -> bool {
|
||||||
true // Grok supports vision models
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_completion(
|
async fn chat_completion(
|
||||||
&self,
|
&self,
|
||||||
request: UnifiedRequest,
|
request: UnifiedRequest,
|
||||||
) -> Result<ProviderResponse, AppError> {
|
) -> Result<ProviderResponse, AppError> {
|
||||||
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
|
let mut body = serde_json::json!({
|
||||||
|
"model": request.model,
|
||||||
// Convert UnifiedRequest messages to OpenAI messages
|
"messages": request.messages.iter().map(|m| {
|
||||||
let mut messages = Vec::with_capacity(request.messages.len());
|
serde_json::json!({
|
||||||
|
"role": m.role,
|
||||||
for msg in request.messages {
|
"content": m.content.iter().map(|p| {
|
||||||
let mut parts = Vec::with_capacity(msg.content.len());
|
match p {
|
||||||
|
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
|
||||||
for part in msg.content {
|
crate::models::ContentPart::Image(image_input) => {
|
||||||
match part {
|
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
|
||||||
crate::models::ContentPart::Text { text } => {
|
serde_json::json!({
|
||||||
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
|
"type": "image_url",
|
||||||
text,
|
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
||||||
}));
|
})
|
||||||
}
|
|
||||||
crate::models::ContentPart::Image(image_input) => {
|
|
||||||
let (base64_data, mime_type) = image_input.to_base64().await
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
|
||||||
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
|
|
||||||
|
|
||||||
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
|
|
||||||
image_url: ImageUrl {
|
|
||||||
url: data_url,
|
|
||||||
detail: Some(ImageDetail::Auto),
|
|
||||||
}
|
}
|
||||||
}));
|
}
|
||||||
}
|
}).collect::<Vec<_>>()
|
||||||
}
|
})
|
||||||
}
|
}).collect::<Vec<_>>(),
|
||||||
|
"stream": false,
|
||||||
|
});
|
||||||
|
|
||||||
let message = match msg.role.as_str() {
|
|
||||||
"system" => ChatCompletionRequestMessage::System(
|
|
||||||
ChatCompletionRequestSystemMessage {
|
|
||||||
content: ChatCompletionRequestSystemMessageContent::Text(
|
|
||||||
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
|
|
||||||
),
|
|
||||||
name: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
"assistant" => ChatCompletionRequestMessage::Assistant(
|
|
||||||
ChatCompletionRequestAssistantMessage {
|
|
||||||
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
|
|
||||||
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
|
|
||||||
)),
|
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
refusal: None,
|
|
||||||
audio: None,
|
|
||||||
#[allow(deprecated)]
|
|
||||||
function_call: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
_ => ChatCompletionRequestMessage::User(
|
|
||||||
ChatCompletionRequestUserMessage {
|
|
||||||
content: ChatCompletionRequestUserMessageContent::Array(parts),
|
|
||||||
name: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
};
|
|
||||||
messages.push(message);
|
|
||||||
}
|
|
||||||
|
|
||||||
if messages.is_empty() {
|
|
||||||
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build request using builder pattern
|
|
||||||
let mut builder = CreateChatCompletionRequestArgs::default();
|
|
||||||
builder.model(request.model.clone());
|
|
||||||
builder.messages(messages);
|
|
||||||
|
|
||||||
// Add optional parameters
|
|
||||||
if let Some(temp) = request.temperature {
|
if let Some(temp) = request.temperature {
|
||||||
builder.temperature(temp as f32);
|
body["temperature"] = serde_json::json!(temp);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(max_tokens) = request.max_tokens {
|
if let Some(max_tokens) = request.max_tokens {
|
||||||
builder.max_tokens(max_tokens as u16);
|
body["max_tokens"] = serde_json::json!(max_tokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute API call
|
let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
|
||||||
let response = self.client
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
.chat()
|
.json(&body)
|
||||||
.create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
|
.send()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||||
|
|
||||||
// Extract content from response
|
if !response.status().is_success() {
|
||||||
let content = response
|
let error_text = response.text().await.unwrap_or_default();
|
||||||
.choices
|
return Err(AppError::ProviderError(format!("Grok API error: {}", error_text)));
|
||||||
.first()
|
}
|
||||||
.and_then(|choice| choice.message.content.clone())
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
// Extract token usage
|
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||||
let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
|
|
||||||
let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
|
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
|
||||||
let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
|
let message = &choice["message"];
|
||||||
|
|
||||||
|
let content = message["content"].as_str().unwrap_or_default().to_string();
|
||||||
|
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
|
||||||
|
|
||||||
|
let usage = &resp_json["usage"];
|
||||||
|
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
|
||||||
|
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
||||||
|
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
|
||||||
|
|
||||||
Ok(ProviderResponse {
|
Ok(ProviderResponse {
|
||||||
content,
|
content,
|
||||||
|
reasoning_content,
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
completion_tokens,
|
completion_tokens,
|
||||||
total_tokens,
|
total_tokens,
|
||||||
@@ -174,11 +124,10 @@ impl super::Provider for GrokProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to static pricing if not in registry
|
|
||||||
let (prompt_rate, completion_rate) = self.pricing.iter()
|
let (prompt_rate, completion_rate) = self.pricing.iter()
|
||||||
.find(|p| model.contains(&p.model))
|
.find(|p| model.contains(&p.model))
|
||||||
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
||||||
.unwrap_or((5.0, 15.0)); // Grok-2 pricing is roughly this
|
.unwrap_or((5.0, 15.0));
|
||||||
|
|
||||||
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
||||||
}
|
}
|
||||||
@@ -187,118 +136,78 @@ impl super::Provider for GrokProvider {
|
|||||||
&self,
|
&self,
|
||||||
request: UnifiedRequest,
|
request: UnifiedRequest,
|
||||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||||
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
|
let mut body = serde_json::json!({
|
||||||
|
"model": request.model,
|
||||||
// Convert UnifiedRequest messages to OpenAI messages
|
"messages": request.messages.iter().map(|m| {
|
||||||
let mut messages = Vec::with_capacity(request.messages.len());
|
serde_json::json!({
|
||||||
|
"role": m.role,
|
||||||
for msg in request.messages {
|
"content": m.content.iter().map(|p| {
|
||||||
let mut parts = Vec::with_capacity(msg.content.len());
|
match p {
|
||||||
|
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
|
||||||
for part in msg.content {
|
crate::models::ContentPart::Image(image_input) => {
|
||||||
match part {
|
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
|
||||||
crate::models::ContentPart::Text { text } => {
|
serde_json::json!({
|
||||||
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
|
"type": "image_url",
|
||||||
text,
|
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
||||||
}));
|
})
|
||||||
}
|
|
||||||
crate::models::ContentPart::Image(image_input) => {
|
|
||||||
let (base64_data, mime_type) = image_input.to_base64().await
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
|
||||||
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
|
|
||||||
|
|
||||||
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
|
|
||||||
image_url: ImageUrl {
|
|
||||||
url: data_url,
|
|
||||||
detail: Some(ImageDetail::Auto),
|
|
||||||
}
|
}
|
||||||
}));
|
}
|
||||||
}
|
}).collect::<Vec<_>>()
|
||||||
}
|
})
|
||||||
}
|
}).collect::<Vec<_>>(),
|
||||||
|
"stream": true,
|
||||||
let message = match msg.role.as_str() {
|
|
||||||
"system" => ChatCompletionRequestMessage::System(
|
|
||||||
ChatCompletionRequestSystemMessage {
|
|
||||||
content: ChatCompletionRequestSystemMessageContent::Text(
|
|
||||||
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
|
|
||||||
),
|
|
||||||
name: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
"assistant" => ChatCompletionRequestMessage::Assistant(
|
|
||||||
ChatCompletionRequestAssistantMessage {
|
|
||||||
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
|
|
||||||
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
|
|
||||||
)),
|
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
refusal: None,
|
|
||||||
audio: None,
|
|
||||||
#[allow(deprecated)]
|
|
||||||
function_call: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
_ => ChatCompletionRequestMessage::User(
|
|
||||||
ChatCompletionRequestUserMessage {
|
|
||||||
content: ChatCompletionRequestUserMessageContent::Array(parts),
|
|
||||||
name: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
};
|
|
||||||
messages.push(message);
|
|
||||||
}
|
|
||||||
|
|
||||||
if messages.is_empty() {
|
|
||||||
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build request using builder pattern
|
|
||||||
let mut builder = CreateChatCompletionRequestArgs::default();
|
|
||||||
builder.model(request.model.clone());
|
|
||||||
builder.messages(messages);
|
|
||||||
builder.stream(true); // Enable streaming
|
|
||||||
|
|
||||||
// Add optional parameters
|
|
||||||
if let Some(temp) = request.temperature {
|
|
||||||
builder.temperature(temp as f32);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(max_tokens) = request.max_tokens {
|
|
||||||
builder.max_tokens(max_tokens as u16);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute streaming API call
|
|
||||||
let stream = self.client
|
|
||||||
.chat()
|
|
||||||
.create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
|
|
||||||
.await
|
|
||||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
||||||
|
|
||||||
// Convert OpenAI stream to our stream format
|
|
||||||
let model = request.model.clone();
|
|
||||||
let stream = stream.map(move |chunk_result| {
|
|
||||||
match chunk_result {
|
|
||||||
Ok(chunk) => {
|
|
||||||
// Extract content from chunk
|
|
||||||
let content = chunk.choices.first()
|
|
||||||
.and_then(|choice| choice.delta.content.clone())
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
let finish_reason = chunk.choices.first()
|
|
||||||
.and_then(|choice| choice.finish_reason.clone())
|
|
||||||
.map(|reason| format!("{:?}", reason));
|
|
||||||
|
|
||||||
Ok(ProviderStreamChunk {
|
|
||||||
content,
|
|
||||||
finish_reason,
|
|
||||||
model: model.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
Err(e) => Err(AppError::ProviderError(e.to_string())),
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if let Some(temp) = request.temperature {
|
||||||
|
body["temperature"] = serde_json::json!(temp);
|
||||||
|
}
|
||||||
|
if let Some(max_tokens) = request.max_tokens {
|
||||||
|
body["max_tokens"] = serde_json::json!(max_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create eventsource stream
|
||||||
|
use reqwest_eventsource::{EventSource, Event};
|
||||||
|
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
|
||||||
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
|
.json(&body))
|
||||||
|
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||||
|
|
||||||
|
let model = request.model.clone();
|
||||||
|
|
||||||
|
let stream = async_stream::try_stream! {
|
||||||
|
let mut es = es;
|
||||||
|
while let Some(event) = es.next().await {
|
||||||
|
match event {
|
||||||
|
Ok(Event::Message(msg)) => {
|
||||||
|
if msg.data == "[DONE]" {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let chunk: Value = serde_json::from_str(&msg.data)
|
||||||
|
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
||||||
|
|
||||||
|
if let Some(choice) = chunk["choices"].get(0) {
|
||||||
|
let delta = &choice["delta"];
|
||||||
|
let content = delta["content"].as_str().unwrap_or_default().to_string();
|
||||||
|
let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
|
||||||
|
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
|
||||||
|
|
||||||
|
yield ProviderStreamChunk {
|
||||||
|
content,
|
||||||
|
reasoning_content,
|
||||||
|
finish_reason,
|
||||||
|
model: model.clone(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(_) => continue,
|
||||||
|
Err(e) => {
|
||||||
|
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
Ok(Box::pin(stream))
|
Ok(Box::pin(stream))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ pub trait Provider: Send + Sync {
|
|||||||
|
|
||||||
pub struct ProviderResponse {
|
pub struct ProviderResponse {
|
||||||
pub content: String,
|
pub content: String,
|
||||||
|
pub reasoning_content: Option<String>,
|
||||||
pub prompt_tokens: u32,
|
pub prompt_tokens: u32,
|
||||||
pub completion_tokens: u32,
|
pub completion_tokens: u32,
|
||||||
pub total_tokens: u32,
|
pub total_tokens: u32,
|
||||||
@@ -53,6 +54,7 @@ pub struct ProviderResponse {
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct ProviderStreamChunk {
|
pub struct ProviderStreamChunk {
|
||||||
pub content: String,
|
pub content: String,
|
||||||
|
pub reasoning_content: Option<String>,
|
||||||
pub finish_reason: Option<String>,
|
pub finish_reason: Option<String>,
|
||||||
pub model: String,
|
pub model: String,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use async_openai::{Client, config::OpenAIConfig};
|
|
||||||
use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
|
|
||||||
use futures::stream::{BoxStream, StreamExt};
|
use futures::stream::{BoxStream, StreamExt};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
models::UnifiedRequest,
|
models::UnifiedRequest,
|
||||||
@@ -12,22 +11,15 @@ use crate::{
|
|||||||
use super::{ProviderResponse, ProviderStreamChunk};
|
use super::{ProviderResponse, ProviderStreamChunk};
|
||||||
|
|
||||||
pub struct OllamaProvider {
|
pub struct OllamaProvider {
|
||||||
client: Client<OpenAIConfig>,
|
client: reqwest::Client,
|
||||||
_config: crate::config::OllamaConfig,
|
_config: crate::config::OllamaConfig,
|
||||||
pricing: Vec<crate::config::ModelPricing>,
|
pricing: Vec<crate::config::ModelPricing>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OllamaProvider {
|
impl OllamaProvider {
|
||||||
pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
|
pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
|
||||||
// Ollama usually doesn't need an API key, use a dummy one
|
|
||||||
let openai_config = OpenAIConfig::default()
|
|
||||||
.with_api_key("ollama")
|
|
||||||
.with_api_base(&config.base_url);
|
|
||||||
|
|
||||||
let client = Client::with_config(openai_config);
|
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
client,
|
client: reqwest::Client::new(),
|
||||||
_config: config.clone(),
|
_config: config.clone(),
|
||||||
pricing: app_config.pricing.ollama.clone(),
|
pricing: app_config.pricing.ollama.clone(),
|
||||||
})
|
})
|
||||||
@@ -41,124 +33,75 @@ impl super::Provider for OllamaProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn supports_model(&self, model: &str) -> bool {
|
fn supports_model(&self, model: &str) -> bool {
|
||||||
// Check if model is in the list of configured Ollama models
|
|
||||||
self._config.models.iter().any(|m| m == model) || model.starts_with("ollama/")
|
self._config.models.iter().any(|m| m == model) || model.starts_with("ollama/")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn supports_multimodal(&self) -> bool {
|
fn supports_multimodal(&self) -> bool {
|
||||||
true // Many Ollama models support vision (e.g. llava, moondream)
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_completion(
|
async fn chat_completion(
|
||||||
&self,
|
&self,
|
||||||
request: UnifiedRequest,
|
request: UnifiedRequest,
|
||||||
) -> Result<ProviderResponse, AppError> {
|
) -> Result<ProviderResponse, AppError> {
|
||||||
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
|
|
||||||
|
|
||||||
// Strip "ollama/" prefix if present
|
|
||||||
let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
|
let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
|
||||||
|
|
||||||
// Convert UnifiedRequest messages to OpenAI messages
|
let mut body = serde_json::json!({
|
||||||
let mut messages = Vec::with_capacity(request.messages.len());
|
"model": model,
|
||||||
|
"messages": request.messages.iter().map(|m| {
|
||||||
for msg in request.messages {
|
serde_json::json!({
|
||||||
let mut parts = Vec::with_capacity(msg.content.len());
|
"role": m.role,
|
||||||
|
"content": m.content.iter().map(|p| {
|
||||||
for part in msg.content {
|
match p {
|
||||||
match part {
|
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
|
||||||
crate::models::ContentPart::Text { text } => {
|
crate::models::ContentPart::Image(image_input) => {
|
||||||
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
|
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
|
||||||
text,
|
serde_json::json!({
|
||||||
}));
|
"type": "image_url",
|
||||||
}
|
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
||||||
crate::models::ContentPart::Image(image_input) => {
|
})
|
||||||
let (base64_data, mime_type) = image_input.to_base64().await
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
|
||||||
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
|
|
||||||
|
|
||||||
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
|
|
||||||
image_url: ImageUrl {
|
|
||||||
url: data_url,
|
|
||||||
detail: Some(ImageDetail::Auto),
|
|
||||||
}
|
}
|
||||||
}));
|
}
|
||||||
}
|
}).collect::<Vec<_>>()
|
||||||
}
|
})
|
||||||
}
|
}).collect::<Vec<_>>(),
|
||||||
|
"stream": false,
|
||||||
|
});
|
||||||
|
|
||||||
let message = match msg.role.as_str() {
|
|
||||||
"system" => ChatCompletionRequestMessage::System(
|
|
||||||
ChatCompletionRequestSystemMessage {
|
|
||||||
content: ChatCompletionRequestSystemMessageContent::Text(
|
|
||||||
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("
|
|
||||||
")
|
|
||||||
),
|
|
||||||
name: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
"assistant" => ChatCompletionRequestMessage::Assistant(
|
|
||||||
ChatCompletionRequestAssistantMessage {
|
|
||||||
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
|
|
||||||
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("
|
|
||||||
")
|
|
||||||
)),
|
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
refusal: None,
|
|
||||||
audio: None,
|
|
||||||
#[allow(deprecated)]
|
|
||||||
function_call: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
_ => ChatCompletionRequestMessage::User(
|
|
||||||
ChatCompletionRequestUserMessage {
|
|
||||||
content: ChatCompletionRequestUserMessageContent::Array(parts),
|
|
||||||
name: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
};
|
|
||||||
messages.push(message);
|
|
||||||
}
|
|
||||||
|
|
||||||
if messages.is_empty() {
|
|
||||||
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build request using builder pattern
|
|
||||||
let mut builder = CreateChatCompletionRequestArgs::default();
|
|
||||||
builder.model(model);
|
|
||||||
builder.messages(messages);
|
|
||||||
|
|
||||||
// Add optional parameters
|
|
||||||
if let Some(temp) = request.temperature {
|
if let Some(temp) = request.temperature {
|
||||||
builder.temperature(temp as f32);
|
body["temperature"] = serde_json::json!(temp);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(max_tokens) = request.max_tokens {
|
if let Some(max_tokens) = request.max_tokens {
|
||||||
builder.max_tokens(max_tokens as u16);
|
body["max_tokens"] = serde_json::json!(max_tokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute API call
|
let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
|
||||||
let response = self.client
|
.json(&body)
|
||||||
.chat()
|
.send()
|
||||||
.create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
|
|
||||||
.await
|
.await
|
||||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||||
|
|
||||||
// Extract content from response
|
if !response.status().is_success() {
|
||||||
let content = response
|
let error_text = response.text().await.unwrap_or_default();
|
||||||
.choices
|
return Err(AppError::ProviderError(format!("Ollama API error: {}", error_text)));
|
||||||
.first()
|
}
|
||||||
.and_then(|choice| choice.message.content.clone())
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
// Extract token usage
|
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||||
let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
|
|
||||||
let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
|
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
|
||||||
let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
|
let message = &choice["message"];
|
||||||
|
|
||||||
|
let content = message["content"].as_str().unwrap_or_default().to_string();
|
||||||
|
let reasoning_content = message["reasoning_content"].as_str().or_else(|| message["thought"].as_str()).map(|s| s.to_string());
|
||||||
|
|
||||||
|
let usage = &resp_json["usage"];
|
||||||
|
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
|
||||||
|
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
||||||
|
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
|
||||||
|
|
||||||
Ok(ProviderResponse {
|
Ok(ProviderResponse {
|
||||||
content,
|
content,
|
||||||
|
reasoning_content,
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
completion_tokens,
|
completion_tokens,
|
||||||
total_tokens,
|
total_tokens,
|
||||||
@@ -178,7 +121,6 @@ impl super::Provider for OllamaProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ollama is free by default
|
|
||||||
let (prompt_rate, completion_rate) = self.pricing.iter()
|
let (prompt_rate, completion_rate) = self.pricing.iter()
|
||||||
.find(|p| model.contains(&p.model))
|
.find(|p| model.contains(&p.model))
|
||||||
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
||||||
@@ -191,123 +133,73 @@ impl super::Provider for OllamaProvider {
|
|||||||
&self,
|
&self,
|
||||||
request: UnifiedRequest,
|
request: UnifiedRequest,
|
||||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||||
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
|
|
||||||
|
|
||||||
// Strip "ollama/" prefix if present
|
|
||||||
let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
|
let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
|
||||||
|
|
||||||
// Convert UnifiedRequest messages to OpenAI messages
|
let mut body = serde_json::json!({
|
||||||
let mut messages = Vec::with_capacity(request.messages.len());
|
"model": model,
|
||||||
|
"messages": request.messages.iter().map(|m| {
|
||||||
for msg in request.messages {
|
serde_json::json!({
|
||||||
let mut parts = Vec::with_capacity(msg.content.len());
|
"role": m.role,
|
||||||
|
"content": m.content.iter().map(|p| {
|
||||||
for part in msg.content {
|
match p {
|
||||||
match part {
|
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
|
||||||
crate::models::ContentPart::Text { text } => {
|
crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }),
|
||||||
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
|
}
|
||||||
text,
|
}).collect::<Vec<_>>()
|
||||||
}));
|
})
|
||||||
}
|
}).collect::<Vec<_>>(),
|
||||||
crate::models::ContentPart::Image(image_input) => {
|
"stream": true,
|
||||||
let (base64_data, mime_type) = image_input.to_base64().await
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
|
||||||
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
|
|
||||||
|
|
||||||
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
|
|
||||||
image_url: ImageUrl {
|
|
||||||
url: data_url,
|
|
||||||
detail: Some(ImageDetail::Auto),
|
|
||||||
}
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let message = match msg.role.as_str() {
|
|
||||||
"system" => ChatCompletionRequestMessage::System(
|
|
||||||
ChatCompletionRequestSystemMessage {
|
|
||||||
content: ChatCompletionRequestSystemMessageContent::Text(
|
|
||||||
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("
|
|
||||||
")
|
|
||||||
),
|
|
||||||
name: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
"assistant" => ChatCompletionRequestMessage::Assistant(
|
|
||||||
ChatCompletionRequestAssistantMessage {
|
|
||||||
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
|
|
||||||
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("
|
|
||||||
")
|
|
||||||
)),
|
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
refusal: None,
|
|
||||||
audio: None,
|
|
||||||
#[allow(deprecated)]
|
|
||||||
function_call: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
_ => ChatCompletionRequestMessage::User(
|
|
||||||
ChatCompletionRequestUserMessage {
|
|
||||||
content: ChatCompletionRequestUserMessageContent::Array(parts),
|
|
||||||
name: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
};
|
|
||||||
messages.push(message);
|
|
||||||
}
|
|
||||||
|
|
||||||
if messages.is_empty() {
|
|
||||||
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build request using builder pattern
|
|
||||||
let mut builder = CreateChatCompletionRequestArgs::default();
|
|
||||||
builder.model(model);
|
|
||||||
builder.messages(messages);
|
|
||||||
builder.stream(true); // Enable streaming
|
|
||||||
|
|
||||||
// Add optional parameters
|
|
||||||
if let Some(temp) = request.temperature {
|
|
||||||
builder.temperature(temp as f32);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(max_tokens) = request.max_tokens {
|
|
||||||
builder.max_tokens(max_tokens as u16);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute streaming API call
|
|
||||||
let stream = self.client
|
|
||||||
.chat()
|
|
||||||
.create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
|
|
||||||
.await
|
|
||||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
||||||
|
|
||||||
// Convert OpenAI stream to our stream format
|
|
||||||
let model_name = request.model.clone();
|
|
||||||
let stream = stream.map(move |chunk_result| {
|
|
||||||
match chunk_result {
|
|
||||||
Ok(chunk) => {
|
|
||||||
// Extract content from chunk
|
|
||||||
let content = chunk.choices.first()
|
|
||||||
.and_then(|choice| choice.delta.content.clone())
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
let finish_reason = chunk.choices.first()
|
|
||||||
.and_then(|choice| choice.finish_reason.clone())
|
|
||||||
.map(|reason| format!("{:?}", reason));
|
|
||||||
|
|
||||||
Ok(ProviderStreamChunk {
|
|
||||||
content,
|
|
||||||
finish_reason,
|
|
||||||
model: model_name.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
Err(e) => Err(AppError::ProviderError(e.to_string())),
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if let Some(temp) = request.temperature {
|
||||||
|
body["temperature"] = serde_json::json!(temp);
|
||||||
|
}
|
||||||
|
if let Some(max_tokens) = request.max_tokens {
|
||||||
|
body["max_tokens"] = serde_json::json!(max_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create eventsource stream
|
||||||
|
use reqwest_eventsource::{EventSource, Event};
|
||||||
|
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
|
||||||
|
.json(&body))
|
||||||
|
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||||
|
|
||||||
|
let model_name = request.model.clone();
|
||||||
|
|
||||||
|
let stream = async_stream::try_stream! {
|
||||||
|
let mut es = es;
|
||||||
|
while let Some(event) = es.next().await {
|
||||||
|
match event {
|
||||||
|
Ok(Event::Message(msg)) => {
|
||||||
|
if msg.data == "[DONE]" {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let chunk: Value = serde_json::from_str(&msg.data)
|
||||||
|
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
||||||
|
|
||||||
|
if let Some(choice) = chunk["choices"].get(0) {
|
||||||
|
let delta = &choice["delta"];
|
||||||
|
let content = delta["content"].as_str().unwrap_or_default().to_string();
|
||||||
|
let reasoning_content = delta["reasoning_content"].as_str().or_else(|| delta["thought"].as_str()).map(|s| s.to_string());
|
||||||
|
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
|
||||||
|
|
||||||
|
yield ProviderStreamChunk {
|
||||||
|
content,
|
||||||
|
reasoning_content,
|
||||||
|
finish_reason,
|
||||||
|
model: model_name.clone(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(_) => continue,
|
||||||
|
Err(e) => {
|
||||||
|
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
Ok(Box::pin(stream))
|
Ok(Box::pin(stream))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use async_openai::{Client, config::OpenAIConfig};
|
|
||||||
use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
|
|
||||||
use futures::stream::{BoxStream, StreamExt};
|
use futures::stream::{BoxStream, StreamExt};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
models::UnifiedRequest,
|
models::UnifiedRequest,
|
||||||
@@ -12,8 +11,9 @@ use crate::{
|
|||||||
use super::{ProviderResponse, ProviderStreamChunk};
|
use super::{ProviderResponse, ProviderStreamChunk};
|
||||||
|
|
||||||
pub struct OpenAIProvider {
|
pub struct OpenAIProvider {
|
||||||
client: Client<OpenAIConfig>,
|
client: reqwest::Client,
|
||||||
_config: crate::config::OpenAIConfig,
|
_config: crate::config::OpenAIConfig,
|
||||||
|
api_key: String,
|
||||||
pricing: Vec<crate::config::ModelPricing>,
|
pricing: Vec<crate::config::ModelPricing>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -21,16 +21,10 @@ impl OpenAIProvider {
|
|||||||
pub fn new(config: &crate::config::OpenAIConfig, app_config: &AppConfig) -> Result<Self> {
|
pub fn new(config: &crate::config::OpenAIConfig, app_config: &AppConfig) -> Result<Self> {
|
||||||
let api_key = app_config.get_api_key("openai")?;
|
let api_key = app_config.get_api_key("openai")?;
|
||||||
|
|
||||||
// Create OpenAIConfig with api key and base url
|
|
||||||
let openai_config = OpenAIConfig::default()
|
|
||||||
.with_api_key(api_key)
|
|
||||||
.with_api_base(&config.base_url);
|
|
||||||
|
|
||||||
let client = Client::with_config(openai_config);
|
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
client,
|
client: reqwest::Client::new(),
|
||||||
_config: config.clone(),
|
_config: config.clone(),
|
||||||
|
api_key,
|
||||||
pricing: app_config.pricing.openai.clone(),
|
pricing: app_config.pricing.openai.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -47,114 +41,70 @@ impl super::Provider for OpenAIProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn supports_multimodal(&self) -> bool {
|
fn supports_multimodal(&self) -> bool {
|
||||||
true // OpenAI supports vision models
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_completion(
|
async fn chat_completion(
|
||||||
&self,
|
&self,
|
||||||
request: UnifiedRequest,
|
request: UnifiedRequest,
|
||||||
) -> Result<ProviderResponse, AppError> {
|
) -> Result<ProviderResponse, AppError> {
|
||||||
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
|
let mut body = serde_json::json!({
|
||||||
|
"model": request.model,
|
||||||
// Convert UnifiedRequest messages to OpenAI messages
|
"messages": request.messages.iter().map(|m| {
|
||||||
let mut messages = Vec::with_capacity(request.messages.len());
|
serde_json::json!({
|
||||||
|
"role": m.role,
|
||||||
for msg in request.messages {
|
"content": m.content.iter().map(|p| {
|
||||||
let mut parts = Vec::with_capacity(msg.content.len());
|
match p {
|
||||||
|
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
|
||||||
for part in msg.content {
|
crate::models::ContentPart::Image(image_input) => {
|
||||||
match part {
|
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
|
||||||
crate::models::ContentPart::Text { text } => {
|
serde_json::json!({
|
||||||
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
|
"type": "image_url",
|
||||||
text,
|
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
||||||
}));
|
})
|
||||||
}
|
|
||||||
crate::models::ContentPart::Image(image_input) => {
|
|
||||||
let (base64_data, mime_type) = image_input.to_base64().await
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
|
||||||
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
|
|
||||||
|
|
||||||
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
|
|
||||||
image_url: ImageUrl {
|
|
||||||
url: data_url,
|
|
||||||
detail: Some(ImageDetail::Auto),
|
|
||||||
}
|
}
|
||||||
}));
|
}
|
||||||
}
|
}).collect::<Vec<_>>()
|
||||||
}
|
})
|
||||||
}
|
}).collect::<Vec<_>>(),
|
||||||
|
"stream": false,
|
||||||
|
});
|
||||||
|
|
||||||
let message = match msg.role.as_str() {
|
|
||||||
"system" => ChatCompletionRequestMessage::System(
|
|
||||||
ChatCompletionRequestSystemMessage {
|
|
||||||
content: ChatCompletionRequestSystemMessageContent::Text(
|
|
||||||
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
|
|
||||||
),
|
|
||||||
name: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
"assistant" => ChatCompletionRequestMessage::Assistant(
|
|
||||||
ChatCompletionRequestAssistantMessage {
|
|
||||||
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
|
|
||||||
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
|
|
||||||
)),
|
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
refusal: None,
|
|
||||||
audio: None,
|
|
||||||
#[allow(deprecated)]
|
|
||||||
function_call: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
_ => ChatCompletionRequestMessage::User(
|
|
||||||
ChatCompletionRequestUserMessage {
|
|
||||||
content: ChatCompletionRequestUserMessageContent::Array(parts),
|
|
||||||
name: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
};
|
|
||||||
messages.push(message);
|
|
||||||
}
|
|
||||||
|
|
||||||
if messages.is_empty() {
|
|
||||||
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build request using builder pattern
|
|
||||||
let mut builder = CreateChatCompletionRequestArgs::default();
|
|
||||||
builder.model(request.model.clone());
|
|
||||||
builder.messages(messages);
|
|
||||||
|
|
||||||
// Add optional parameters
|
|
||||||
if let Some(temp) = request.temperature {
|
if let Some(temp) = request.temperature {
|
||||||
builder.temperature(temp as f32);
|
body["temperature"] = serde_json::json!(temp);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(max_tokens) = request.max_tokens {
|
if let Some(max_tokens) = request.max_tokens {
|
||||||
builder.max_tokens(max_tokens as u16);
|
body["max_tokens"] = serde_json::json!(max_tokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute API call
|
let response = self.client.post(format!("{}/chat/completions", self._config.base_url))
|
||||||
let response = self.client
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
.chat()
|
.json(&body)
|
||||||
.create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
|
.send()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||||
|
|
||||||
// Extract content from response
|
if !response.status().is_success() {
|
||||||
let content = response
|
let error_text = response.text().await.unwrap_or_default();
|
||||||
.choices
|
return Err(AppError::ProviderError(format!("OpenAI API error: {}", error_text)));
|
||||||
.first()
|
}
|
||||||
.and_then(|choice| choice.message.content.clone())
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
// Extract token usage
|
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||||
let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
|
|
||||||
let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
|
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
|
||||||
let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
|
let message = &choice["message"];
|
||||||
|
|
||||||
|
let content = message["content"].as_str().unwrap_or_default().to_string();
|
||||||
|
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
|
||||||
|
|
||||||
|
let usage = &resp_json["usage"];
|
||||||
|
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
|
||||||
|
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
||||||
|
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
|
||||||
|
|
||||||
Ok(ProviderResponse {
|
Ok(ProviderResponse {
|
||||||
content,
|
content,
|
||||||
|
reasoning_content,
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
completion_tokens,
|
completion_tokens,
|
||||||
total_tokens,
|
total_tokens,
|
||||||
@@ -174,7 +124,6 @@ impl super::Provider for OpenAIProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback to static pricing if not in registry
|
|
||||||
let (prompt_rate, completion_rate) = self.pricing.iter()
|
let (prompt_rate, completion_rate) = self.pricing.iter()
|
||||||
.find(|p| model.contains(&p.model))
|
.find(|p| model.contains(&p.model))
|
||||||
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
||||||
@@ -187,118 +136,78 @@ impl super::Provider for OpenAIProvider {
|
|||||||
&self,
|
&self,
|
||||||
request: UnifiedRequest,
|
request: UnifiedRequest,
|
||||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||||
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
|
let mut body = serde_json::json!({
|
||||||
|
"model": request.model,
|
||||||
// Convert UnifiedRequest messages to OpenAI messages
|
"messages": request.messages.iter().map(|m| {
|
||||||
let mut messages = Vec::with_capacity(request.messages.len());
|
serde_json::json!({
|
||||||
|
"role": m.role,
|
||||||
for msg in request.messages {
|
"content": m.content.iter().map(|p| {
|
||||||
let mut parts = Vec::with_capacity(msg.content.len());
|
match p {
|
||||||
|
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
|
||||||
for part in msg.content {
|
crate::models::ContentPart::Image(image_input) => {
|
||||||
match part {
|
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
|
||||||
crate::models::ContentPart::Text { text } => {
|
serde_json::json!({
|
||||||
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
|
"type": "image_url",
|
||||||
text,
|
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
||||||
}));
|
})
|
||||||
}
|
|
||||||
crate::models::ContentPart::Image(image_input) => {
|
|
||||||
let (base64_data, mime_type) = image_input.to_base64().await
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
|
||||||
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
|
|
||||||
|
|
||||||
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
|
|
||||||
image_url: ImageUrl {
|
|
||||||
url: data_url,
|
|
||||||
detail: Some(ImageDetail::Auto),
|
|
||||||
}
|
}
|
||||||
}));
|
}
|
||||||
}
|
}).collect::<Vec<_>>()
|
||||||
}
|
})
|
||||||
}
|
}).collect::<Vec<_>>(),
|
||||||
|
"stream": true,
|
||||||
let message = match msg.role.as_str() {
|
|
||||||
"system" => ChatCompletionRequestMessage::System(
|
|
||||||
ChatCompletionRequestSystemMessage {
|
|
||||||
content: ChatCompletionRequestSystemMessageContent::Text(
|
|
||||||
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
|
|
||||||
),
|
|
||||||
name: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
"assistant" => ChatCompletionRequestMessage::Assistant(
|
|
||||||
ChatCompletionRequestAssistantMessage {
|
|
||||||
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
|
|
||||||
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
|
|
||||||
)),
|
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
refusal: None,
|
|
||||||
audio: None,
|
|
||||||
#[allow(deprecated)]
|
|
||||||
function_call: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
_ => ChatCompletionRequestMessage::User(
|
|
||||||
ChatCompletionRequestUserMessage {
|
|
||||||
content: ChatCompletionRequestUserMessageContent::Array(parts),
|
|
||||||
name: None,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
};
|
|
||||||
messages.push(message);
|
|
||||||
}
|
|
||||||
|
|
||||||
if messages.is_empty() {
|
|
||||||
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build request using builder pattern
|
|
||||||
let mut builder = CreateChatCompletionRequestArgs::default();
|
|
||||||
builder.model(request.model.clone());
|
|
||||||
builder.messages(messages);
|
|
||||||
builder.stream(true); // Enable streaming
|
|
||||||
|
|
||||||
// Add optional parameters
|
|
||||||
if let Some(temp) = request.temperature {
|
|
||||||
builder.temperature(temp as f32);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(max_tokens) = request.max_tokens {
|
|
||||||
builder.max_tokens(max_tokens as u16);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute streaming API call
|
|
||||||
let stream = self.client
|
|
||||||
.chat()
|
|
||||||
.create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
|
|
||||||
.await
|
|
||||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
||||||
|
|
||||||
// Convert OpenAI stream to our stream format
|
|
||||||
let model = request.model.clone();
|
|
||||||
let stream = stream.map(move |chunk_result| {
|
|
||||||
match chunk_result {
|
|
||||||
Ok(chunk) => {
|
|
||||||
// Extract content from chunk
|
|
||||||
let content = chunk.choices.first()
|
|
||||||
.and_then(|choice| choice.delta.content.clone())
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
let finish_reason = chunk.choices.first()
|
|
||||||
.and_then(|choice| choice.finish_reason.clone())
|
|
||||||
.map(|reason| format!("{:?}", reason));
|
|
||||||
|
|
||||||
Ok(ProviderStreamChunk {
|
|
||||||
content,
|
|
||||||
finish_reason,
|
|
||||||
model: model.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
Err(e) => Err(AppError::ProviderError(e.to_string())),
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if let Some(temp) = request.temperature {
|
||||||
|
body["temperature"] = serde_json::json!(temp);
|
||||||
|
}
|
||||||
|
if let Some(max_tokens) = request.max_tokens {
|
||||||
|
body["max_tokens"] = serde_json::json!(max_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create eventsource stream
|
||||||
|
use reqwest_eventsource::{EventSource, Event};
|
||||||
|
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self._config.base_url))
|
||||||
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||||
|
.json(&body))
|
||||||
|
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||||
|
|
||||||
|
let model = request.model.clone();
|
||||||
|
|
||||||
|
let stream = async_stream::try_stream! {
|
||||||
|
let mut es = es;
|
||||||
|
while let Some(event) = es.next().await {
|
||||||
|
match event {
|
||||||
|
Ok(Event::Message(msg)) => {
|
||||||
|
if msg.data == "[DONE]" {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let chunk: Value = serde_json::from_str(&msg.data)
|
||||||
|
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
||||||
|
|
||||||
|
if let Some(choice) = chunk["choices"].get(0) {
|
||||||
|
let delta = &choice["delta"];
|
||||||
|
let content = delta["content"].as_str().unwrap_or_default().to_string();
|
||||||
|
let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
|
||||||
|
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
|
||||||
|
|
||||||
|
yield ProviderStreamChunk {
|
||||||
|
content,
|
||||||
|
reasoning_content,
|
||||||
|
finish_reason,
|
||||||
|
model: model.clone(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(_) => continue,
|
||||||
|
Err(e) => {
|
||||||
|
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
Ok(Box::pin(stream))
|
Ok(Box::pin(stream))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -102,6 +102,7 @@ async fn chat_completions(
|
|||||||
delta: ChatStreamDelta {
|
delta: ChatStreamDelta {
|
||||||
role: None,
|
role: None,
|
||||||
content: Some(chunk.content),
|
content: Some(chunk.content),
|
||||||
|
reasoning_content: chunk.reasoning_content,
|
||||||
},
|
},
|
||||||
finish_reason: chunk.finish_reason,
|
finish_reason: chunk.finish_reason,
|
||||||
}],
|
}],
|
||||||
@@ -177,6 +178,7 @@ async fn chat_completions(
|
|||||||
content: crate::models::MessageContent::Text {
|
content: crate::models::MessageContent::Text {
|
||||||
content: response.content
|
content: response.content
|
||||||
},
|
},
|
||||||
|
reasoning_content: response.reasoning_content,
|
||||||
},
|
},
|
||||||
finish_reason: Some("stop".to_string()),
|
finish_reason: Some("stop".to_string()),
|
||||||
}],
|
}],
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ pub struct AggregatingStream<S> {
|
|||||||
prompt_tokens: u32,
|
prompt_tokens: u32,
|
||||||
has_images: bool,
|
has_images: bool,
|
||||||
accumulated_content: String,
|
accumulated_content: String,
|
||||||
|
accumulated_reasoning: String,
|
||||||
logger: Arc<RequestLogger>,
|
logger: Arc<RequestLogger>,
|
||||||
client_manager: Arc<ClientManager>,
|
client_manager: Arc<ClientManager>,
|
||||||
model_registry: Arc<crate::models::registry::ModelRegistry>,
|
model_registry: Arc<crate::models::registry::ModelRegistry>,
|
||||||
@@ -46,6 +47,7 @@ where
|
|||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
has_images,
|
has_images,
|
||||||
accumulated_content: String::new(),
|
accumulated_content: String::new(),
|
||||||
|
accumulated_reasoning: String::new(),
|
||||||
logger,
|
logger,
|
||||||
client_manager,
|
client_manager,
|
||||||
model_registry,
|
model_registry,
|
||||||
@@ -71,8 +73,15 @@ where
|
|||||||
let has_images = self.has_images;
|
let has_images = self.has_images;
|
||||||
let registry = self.model_registry.clone();
|
let registry = self.model_registry.clone();
|
||||||
|
|
||||||
// Estimate completion tokens
|
// Estimate completion tokens (including reasoning if present)
|
||||||
let completion_tokens = estimate_completion_tokens(&self.accumulated_content, &model);
|
let content_tokens = estimate_completion_tokens(&self.accumulated_content, &model);
|
||||||
|
let reasoning_tokens = if !self.accumulated_reasoning.is_empty() {
|
||||||
|
estimate_completion_tokens(&self.accumulated_reasoning, &model)
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
|
||||||
|
let completion_tokens = content_tokens + reasoning_tokens;
|
||||||
let total_tokens = prompt_tokens + completion_tokens;
|
let total_tokens = prompt_tokens + completion_tokens;
|
||||||
let cost = provider.calculate_cost(&model, prompt_tokens, completion_tokens, ®istry);
|
let cost = provider.calculate_cost(&model, prompt_tokens, completion_tokens, ®istry);
|
||||||
|
|
||||||
@@ -116,6 +125,9 @@ where
|
|||||||
match &result {
|
match &result {
|
||||||
Poll::Ready(Some(Ok(chunk))) => {
|
Poll::Ready(Some(Ok(chunk))) => {
|
||||||
self.accumulated_content.push_str(&chunk.content);
|
self.accumulated_content.push_str(&chunk.content);
|
||||||
|
if let Some(reasoning) = &chunk.reasoning_content {
|
||||||
|
self.accumulated_reasoning.push_str(reasoning);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Poll::Ready(Some(Err(_))) => {
|
Poll::Ready(Some(Err(_))) => {
|
||||||
// If there's an error, we might still want to log what we got so far?
|
// If there's an error, we might still want to log what we got so far?
|
||||||
|
|||||||
Reference in New Issue
Block a user