//! DeepSeek provider.
//!
//! Change notes: refactored OpenAI, DeepSeek, Grok, and Ollama to manual JSON
//! parsing to capture `reasoning_content` and `thought` fields; implemented
//! real-time streaming of reasoning blocks; added token aggregation and cost
//! tracking for reasoning tokens; updated unified models to include
//! `reasoning_content` in API responses.
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::{BoxStream, StreamExt};
use serde_json::Value;

use crate::{
    config::AppConfig,
    errors::AppError,
    models::UnifiedRequest,
};

use super::{ProviderResponse, ProviderStreamChunk};
/// Provider implementation for the DeepSeek chat-completions API
/// (OpenAI-compatible endpoint).
pub struct DeepSeekProvider {
    // Shared HTTP client reused across all requests.
    client: reqwest::Client,
    // Provider-specific settings (notably `base_url` for the API endpoint).
    config: crate::config::DeepSeekConfig,
    // Bearer token resolved from the application config at construction time.
    api_key: String,
    // Fallback pricing table, used by `calculate_cost` when the model
    // registry carries no cost metadata for the requested model.
    pricing: Vec<crate::config::ModelPricing>,
}
impl DeepSeekProvider {
|
|
pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result<Self> {
|
|
let api_key = app_config.get_api_key("deepseek")?;
|
|
|
|
Ok(Self {
|
|
client: reqwest::Client::new(),
|
|
config: config.clone(),
|
|
api_key,
|
|
pricing: app_config.pricing.deepseek.clone(),
|
|
})
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl super::Provider for DeepSeekProvider {
|
|
fn name(&self) -> &str {
|
|
"deepseek"
|
|
}
|
|
|
|
fn supports_model(&self, model: &str) -> bool {
|
|
model.starts_with("deepseek-") || model.contains("deepseek")
|
|
}
|
|
|
|
fn supports_multimodal(&self) -> bool {
|
|
false
|
|
}
|
|
|
|
async fn chat_completion(
|
|
&self,
|
|
request: UnifiedRequest,
|
|
) -> Result<ProviderResponse, AppError> {
|
|
// Build the OpenAI-compatible body
|
|
let mut body = serde_json::json!({
|
|
"model": request.model,
|
|
"messages": request.messages.iter().map(|m| {
|
|
serde_json::json!({
|
|
"role": m.role,
|
|
"content": m.content.iter().map(|p| {
|
|
match p {
|
|
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
|
|
crate::models::ContentPart::Image(image_input) => {
|
|
// DeepSeek currently doesn't support images in the same way, but we'll try to be standard
|
|
let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default();
|
|
serde_json::json!({
|
|
"type": "image_url",
|
|
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
|
})
|
|
}
|
|
}
|
|
}).collect::<Vec<_>>()
|
|
})
|
|
}).collect::<Vec<_>>(),
|
|
"stream": false,
|
|
});
|
|
|
|
if let Some(temp) = request.temperature {
|
|
body["temperature"] = serde_json::json!(temp);
|
|
}
|
|
if let Some(max_tokens) = request.max_tokens {
|
|
body["max_tokens"] = serde_json::json!(max_tokens);
|
|
}
|
|
|
|
let response = self.client.post(format!("{}/chat/completions", self.config.base_url))
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.json(&body)
|
|
.send()
|
|
.await
|
|
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
|
|
if !response.status().is_success() {
|
|
let error_text = response.text().await.unwrap_or_default();
|
|
return Err(AppError::ProviderError(format!("DeepSeek API error: {}", error_text)));
|
|
}
|
|
|
|
let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
|
|
let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
|
|
let message = &choice["message"];
|
|
|
|
let content = message["content"].as_str().unwrap_or_default().to_string();
|
|
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
|
|
|
|
let usage = &resp_json["usage"];
|
|
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
|
|
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
|
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
|
|
|
|
Ok(ProviderResponse {
|
|
content,
|
|
reasoning_content,
|
|
prompt_tokens,
|
|
completion_tokens,
|
|
total_tokens,
|
|
model: request.model,
|
|
})
|
|
}
|
|
|
|
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
|
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
|
}
|
|
|
|
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
|
|
if let Some(metadata) = registry.find_model(model) {
|
|
if let Some(cost) = &metadata.cost {
|
|
return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
|
|
(completion_tokens as f64 * cost.output / 1_000_000.0);
|
|
}
|
|
}
|
|
|
|
let (prompt_rate, completion_rate) = self.pricing.iter()
|
|
.find(|p| model.contains(&p.model))
|
|
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
|
.unwrap_or((0.14, 0.28));
|
|
|
|
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
|
}
|
|
|
|
async fn chat_completion_stream(
|
|
&self,
|
|
request: UnifiedRequest,
|
|
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
|
let mut body = serde_json::json!({
|
|
"model": request.model,
|
|
"messages": request.messages.iter().map(|m| {
|
|
serde_json::json!({
|
|
"role": m.role,
|
|
"content": m.content.iter().map(|p| {
|
|
match p {
|
|
crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }),
|
|
crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }),
|
|
}
|
|
}).collect::<Vec<_>>()
|
|
})
|
|
}).collect::<Vec<_>>(),
|
|
"stream": true,
|
|
});
|
|
|
|
if let Some(temp) = request.temperature {
|
|
body["temperature"] = serde_json::json!(temp);
|
|
}
|
|
if let Some(max_tokens) = request.max_tokens {
|
|
body["max_tokens"] = serde_json::json!(max_tokens);
|
|
}
|
|
|
|
// Create eventsource stream
|
|
use reqwest_eventsource::{EventSource, Event};
|
|
let es = EventSource::new(self.client.post(format!("{}/chat/completions", self.config.base_url))
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.json(&body))
|
|
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
|
|
|
let model = request.model.clone();
|
|
|
|
let stream = async_stream::try_stream! {
|
|
let mut es = es;
|
|
while let Some(event) = es.next().await {
|
|
match event {
|
|
Ok(Event::Message(msg)) => {
|
|
if msg.data == "[DONE]" {
|
|
break;
|
|
}
|
|
|
|
let chunk: Value = serde_json::from_str(&msg.data)
|
|
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
|
|
|
if let Some(choice) = chunk["choices"].get(0) {
|
|
let delta = &choice["delta"];
|
|
let content = delta["content"].as_str().unwrap_or_default().to_string();
|
|
let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string());
|
|
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
|
|
|
|
yield ProviderStreamChunk {
|
|
content,
|
|
reasoning_content,
|
|
finish_reason,
|
|
model: model.clone(),
|
|
};
|
|
}
|
|
}
|
|
Ok(_) => continue,
|
|
Err(e) => {
|
|
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
Ok(Box::pin(stream))
|
|
}
|
|
}
|