feat: add support for reasoning (thinking) models across all providers
- Refactored OpenAI, DeepSeek, Grok, and Ollama to manual JSON parsing to capture 'reasoning_content' and 'thought' fields.
- Implemented real-time streaming of reasoning blocks.
- Added token aggregation and cost tracking for reasoning tokens.
- Updated unified models to include 'reasoning_content' in API responses.
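For context, a minimal sketch of the parsing pattern this change adopts (the helper `extract` and the sample payload are illustrative, not code from this repo). Typed OpenAI-client response structs drop fields they don't model, so deserializing into a serde_json::Value is what lets the extra 'reasoning_content' field survive alongside 'content':

```rust
use serde_json::Value;

// Pull both the answer and the optional reasoning trace out of an
// OpenAI-compatible chat-completion response. `reasoning_content` is
// only present on reasoning models; it comes back as None otherwise.
fn extract(resp_json: &Value) -> Option<(String, Option<String>)> {
    let message = &resp_json["choices"].get(0)?["message"];
    let content = message["content"].as_str().unwrap_or_default().to_string();
    let reasoning = message["reasoning_content"].as_str().map(str::to_string);
    Some((content, reasoning))
}

fn main() {
    let sample: Value = serde_json::from_str(
        r#"{"choices":[{"message":{"content":"4","reasoning_content":"2 + 2 = 4"}}]}"#,
    ).unwrap();
    assert_eq!(
        extract(&sample),
        Some(("4".to_string(), Some("2 + 2 = 4".to_string())))
    );
}
```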
@@ -1,8 +1,7 @@
 use async_trait::async_trait;
 use anyhow::Result;
-use async_openai::{Client, config::OpenAIConfig};
-use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
 use futures::stream::{BoxStream, StreamExt};
+use serde_json::Value;

 use crate::{
     models::UnifiedRequest,
@@ -12,8 +11,9 @@ use crate::{
 use super::{ProviderResponse, ProviderStreamChunk};

 pub struct DeepSeekProvider {
-    client: Client<OpenAIConfig>, // DeepSeek uses OpenAI-compatible API
-    _config: crate::config::DeepSeekConfig,
+    client: reqwest::Client,
+    config: crate::config::DeepSeekConfig,
+    api_key: String,
     pricing: Vec<crate::config::ModelPricing>,
 }

@@ -21,16 +21,10 @@ impl DeepSeekProvider {
     pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result<Self> {
         let api_key = app_config.get_api_key("deepseek")?;

-        // Create OpenAIConfig with api key and base url
-        let openai_config = OpenAIConfig::default()
-            .with_api_key(api_key)
-            .with_api_base(&config.base_url);
-
-        let client = Client::with_config(openai_config);
-
         Ok(Self {
-            client,
-            _config: config.clone(),
+            client: reqwest::Client::new(),
+            config: config.clone(),
+            api_key,
             pricing: app_config.pricing.deepseek.clone(),
         })
     }
@@ -47,114 +41,72 @@ impl super::Provider for DeepSeekProvider {
     }

     fn supports_multimodal(&self) -> bool {
-        false // DeepSeek doesn't support general vision (only OCR)
+        false
     }

     async fn chat_completion(
         &self,
         request: UnifiedRequest,
     ) -> Result<ProviderResponse, AppError> {
-        use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
-
-        // Convert UnifiedRequest messages to OpenAI-compatible messages
-        let mut messages = Vec::with_capacity(request.messages.len());
-
-        for msg in request.messages {
-            let mut parts = Vec::with_capacity(msg.content.len());
-
-            for part in msg.content {
-                match part {
-                    crate::models::ContentPart::Text { text } => {
-                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
-                            text,
-                        }));
-                    }
-                    crate::models::ContentPart::Image(image_input) => {
-                        let (base64_data, mime_type) = image_input.to_base64().await
-                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
-                        let data_url = format!("data:{};base64,{}", mime_type, base64_data);
-
-                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
-                            image_url: ImageUrl {
-                                url: data_url,
-                                detail: Some(ImageDetail::Auto),
-                            }
-                        }));
-                    }
-                }
-            }
-
-            let message = match msg.role.as_str() {
-                "system" => ChatCompletionRequestMessage::System(
-                    ChatCompletionRequestSystemMessage {
-                        content: ChatCompletionRequestSystemMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        ),
-                        name: None,
-                    }
-                ),
-                "assistant" => ChatCompletionRequestMessage::Assistant(
-                    ChatCompletionRequestAssistantMessage {
-                        content: Some(ChatCompletionRequestAssistantMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        )),
-                        name: None,
-                        tool_calls: None,
-                        refusal: None,
-                        audio: None,
-                        #[allow(deprecated)]
-                        function_call: None,
-                    }
-                ),
-                _ => ChatCompletionRequestMessage::User(
-                    ChatCompletionRequestUserMessage {
-                        content: ChatCompletionRequestUserMessageContent::Array(parts),
-                        name: None,
-                    }
-                ),
-            };
-            messages.push(message);
-        }
-
-        if messages.is_empty() {
-            return Err(AppError::ProviderError("No valid text messages to send".to_string()));
-        }
-
-        // Build request using builder pattern
-        let mut builder = CreateChatCompletionRequestArgs::default();
-        builder.model(request.model.clone());
-        builder.messages(messages);
-
-        // Add optional parameters
+        // Build the OpenAI-compatible body
+        let mut body = serde_json::json!({
+            "model": request.model,
+            "messages": request.messages.iter().map(|m| {
+                serde_json::json!({
+                    "role": m.role,
+                    "content": m.content.iter().map(|p| {
+                        match p {
+                            crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
+                            crate::models::ContentPart::Image(image_input) => {
+                                // DeepSeek currently doesn't support images in the same way, but we'll try to be standard
+                                let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
+                                serde_json::json!({
+                                    "type": "image_url",
+                                    "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
+                                })
+                            }
+                        }
+                    }).collect::<Vec<_>>()
+                })
+            }).collect::<Vec<_>>(),
+            "stream": false,
+        });

         if let Some(temp) = request.temperature {
-            builder.temperature(temp as f32);
+            body["temperature"] = serde_json::json!(temp);
         }

         if let Some(max_tokens) = request.max_tokens {
-            builder.max_tokens(max_tokens as u16);
+            body["max_tokens"] = serde_json::json!(max_tokens);
         }

-        // Execute API call
-        let response = self.client
-            .chat()
-            .create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
+        let response = self.client.post(format!("{}/chat/completions", self.config.base_url))
+            .header("Authorization", format!("Bearer {}", self.api_key))
+            .json(&body)
+            .send()
             .await
             .map_err(|e| AppError::ProviderError(e.to_string()))?;

+        if !response.status().is_success() {
+            let error_text = response.text().await.unwrap_or_default();
+            return Err(AppError::ProviderError(format!("DeepSeek API error: {}", error_text)));
+        }
+
+        let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
+
-        // Extract content from response
-        let content = response
-            .choices
-            .first()
-            .and_then(|choice| choice.message.content.clone())
-            .unwrap_or_default();
+        let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
+        let message = &choice["message"];

-        // Extract token usage
-        let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
-        let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
-        let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
+        let content = message["content"].as_str().unwrap_or_default().to_string();
+        let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
+
+        let usage = &resp_json["usage"];
+        let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
+        let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
+        let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;

         Ok(ProviderResponse {
             content,
+            reasoning_content,
             prompt_tokens,
             completion_tokens,
             total_tokens,
@@ -177,7 +129,7 @@ impl super::Provider for DeepSeekProvider {
         let (prompt_rate, completion_rate) = self.pricing.iter()
             .find(|p| model.contains(&p.model))
             .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
-            .unwrap_or((0.14, 0.28)); // Default to DeepSeek V3 price if not found
+            .unwrap_or((0.14, 0.28));

         (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
     }
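A quick worked instance of the pricing arithmetic in the hunk above (the token counts are made up for illustration; 0.14 and 0.28 are the fallback per-million-token rates from the diff):

```rust
// Cost = prompt_tokens * prompt_rate / 1e6 + completion_tokens * completion_rate / 1e6
fn main() {
    let (prompt_tokens, completion_tokens) = (12_000u32, 3_000u32);
    let (prompt_rate, completion_rate) = (0.14f64, 0.28f64);
    let cost = (prompt_tokens as f64 * prompt_rate / 1_000_000.0)
        + (completion_tokens as f64 * completion_rate / 1_000_000.0);
    // 12_000 * 0.14 / 1e6 + 3_000 * 0.28 / 1e6 = 0.00168 + 0.00084 = 0.00252 USD
    assert!((cost - 0.00252).abs() < 1e-12);
}
```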
@@ -186,118 +138,72 @@ impl super::Provider for DeepSeekProvider {
         &self,
         request: UnifiedRequest,
     ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
-        use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
-
-        // Convert UnifiedRequest messages to OpenAI-compatible messages
-        let mut messages = Vec::with_capacity(request.messages.len());
-
-        for msg in request.messages {
-            let mut parts = Vec::with_capacity(msg.content.len());
-
-            for part in msg.content {
-                match part {
-                    crate::models::ContentPart::Text { text } => {
-                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
-                            text,
-                        }));
-                    }
-                    crate::models::ContentPart::Image(image_input) => {
-                        let (base64_data, mime_type) = image_input.to_base64().await
-                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
-                        let data_url = format!("data:{};base64,{}", mime_type, base64_data);
-
-                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
-                            image_url: ImageUrl {
-                                url: data_url,
-                                detail: Some(ImageDetail::Auto),
-                            }
-                        }));
-                    }
-                }
-            }
-
-            let message = match msg.role.as_str() {
-                "system" => ChatCompletionRequestMessage::System(
-                    ChatCompletionRequestSystemMessage {
-                        content: ChatCompletionRequestSystemMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        ),
-                        name: None,
-                    }
-                ),
-                "assistant" => ChatCompletionRequestMessage::Assistant(
-                    ChatCompletionRequestAssistantMessage {
-                        content: Some(ChatCompletionRequestAssistantMessageContent::Text(
-                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
-                        )),
-                        name: None,
-                        tool_calls: None,
-                        refusal: None,
-                        audio: None,
-                        #[allow(deprecated)]
-                        function_call: None,
-                    }
-                ),
-                _ => ChatCompletionRequestMessage::User(
-                    ChatCompletionRequestUserMessage {
-                        content: ChatCompletionRequestUserMessageContent::Array(parts),
-                        name: None,
-                    }
-                ),
-            };
-            messages.push(message);
-        }
-
-        if messages.is_empty() {
-            return Err(AppError::ProviderError("No valid text messages to send".to_string()));
-        }
-
-        // Build request using builder pattern
-        let mut builder = CreateChatCompletionRequestArgs::default();
-        builder.model(request.model.clone());
-        builder.messages(messages);
-        builder.stream(true); // Enable streaming
-
-        // Add optional parameters
-        if let Some(temp) = request.temperature {
-            builder.temperature(temp as f32);
-        }
-
-        if let Some(max_tokens) = request.max_tokens {
-            builder.max_tokens(max_tokens as u16);
-        }
-
-        // Execute streaming API call
-        let stream = self.client
-            .chat()
-            .create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
-            .await
-            .map_err(|e| AppError::ProviderError(e.to_string()))?;
-
-        // Convert OpenAI stream to our stream format
-        let model = request.model.clone();
-        let stream = stream.map(move |chunk_result| {
-            match chunk_result {
-                Ok(chunk) => {
-                    // Extract content from chunk
-                    let content = chunk.choices.first()
-                        .and_then(|choice| choice.delta.content.clone())
-                        .unwrap_or_default();
-
-                    let finish_reason = chunk.choices.first()
-                        .and_then(|choice| choice.finish_reason.clone())
-                        .map(|reason| format!("{:?}", reason));
-
-                    Ok(ProviderStreamChunk {
-                        content,
-                        finish_reason,
-                        model: model.clone(),
-                    })
-                }
-                Err(e) => Err(AppError::ProviderError(e.to_string())),
-            }
-        });
+        let mut body = serde_json::json!({
+            "model": request.model,
+            "messages": request.messages.iter().map(|m| {
+                serde_json::json!({
+                    "role": m.role,
+                    "content": m.content.iter().map(|p| {
+                        match p {
+                            crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
+                            crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }),
+                        }
+                    }).collect::<Vec<_>>()
+                })
+            }).collect::<Vec<_>>(),
+            "stream": true,
+        });
+
+        if let Some(temp) = request.temperature {
+            body["temperature"] = serde_json::json!(temp);
+        }
+        if let Some(max_tokens) = request.max_tokens {
+            body["max_tokens"] = serde_json::json!(max_tokens);
+        }
+
+        // Create eventsource stream
+        use reqwest_eventsource::{EventSource, Event};
+        let es = EventSource::new(self.client.post(format!("{}/chat/completions", self.config.base_url))
+            .header("Authorization", format!("Bearer {}", self.api_key))
+            .json(&body))
+            .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
+
+        let model = request.model.clone();
+
+        let stream = async_stream::try_stream! {
+            let mut es = es;
+            while let Some(event) = es.next().await {
+                match event {
+                    Ok(Event::Message(msg)) => {
+                        if msg.data == "[DONE]" {
+                            break;
+                        }
+
+                        let chunk: Value = serde_json::from_str(&msg.data)
+                            .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
+
+                        if let Some(choice) = chunk["choices"].get(0) {
+                            let delta = &choice["delta"];
+                            let content = delta["content"].as_str().unwrap_or_default().to_string();
+                            let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
+                            let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
+
+                            yield ProviderStreamChunk {
+                                content,
+                                reasoning_content,
+                                finish_reason,
+                                model: model.clone(),
+                            };
+                        }
+                    }
+                    Ok(_) => continue,
+                    Err(e) => {
+                        Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
+                    }
+                }
+            }
+        };

         Ok(Box::pin(stream))
     }
 }
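For reference, a sketch of the per-event payload shape the new SSE loop expects from DeepSeek's OpenAI-compatible endpoint (the sample `data:` body is illustrative). Reasoning models stream 'reasoning_content' deltas while thinking, then ordinary 'content' deltas for the final answer, which is what lets reasoning blocks be forwarded in real time:

```rust
use serde_json::Value;

fn main() {
    // One SSE `data:` event as delivered mid-stream (illustrative payload).
    let data = r#"{"choices":[{"delta":{"reasoning_content":"Let me think..."},"finish_reason":null}]}"#;
    let chunk: Value = serde_json::from_str(data).unwrap();
    if let Some(choice) = chunk["choices"].get(0) {
        let delta = &choice["delta"];
        // Exactly the fields the stream loop above surfaces per event.
        let content = delta["content"].as_str().unwrap_or_default();
        let reasoning_content = delta["reasoning_content"].as_str();
        let finish_reason = choice["finish_reason"].as_str();
        println!("content={content:?} reasoning={reasoning_content:?} finish={finish_reason:?}");
    }
}
```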