use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use futures::StreamExt;

use super::helpers;
use super::{ProviderResponse, ProviderStreamChunk};
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
/// Provider backend for the DeepSeek chat-completions API.
///
/// Construct via [`DeepSeekProvider::new`] (key resolved from app config)
/// or [`DeepSeekProvider::new_with_key`] (key supplied explicitly).
pub struct DeepSeekProvider {
    // Shared HTTP client; configured with connect/request timeouts,
    // connection pooling and TCP keep-alive in `new_with_key`.
    client: reqwest::Client,
    // DeepSeek-specific settings; `base_url` is used to build request URLs.
    config: crate::config::DeepSeekConfig,
    // Bearer token sent in the `Authorization` header on every request.
    api_key: String,
    // Per-model pricing rows, used as a fallback in `calculate_cost` when
    // the model registry has no cost metadata.
    pricing: Vec<crate::config::ModelPricing>,
}
impl DeepSeekProvider {
|
|
pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result<Self> {
|
|
let api_key = app_config.get_api_key("deepseek")?;
|
|
Self::new_with_key(config, app_config, api_key)
|
|
}
|
|
|
|
pub fn new_with_key(
|
|
config: &crate::config::DeepSeekConfig,
|
|
app_config: &AppConfig,
|
|
api_key: String,
|
|
) -> Result<Self> {
|
|
let client = reqwest::Client::builder()
|
|
.connect_timeout(std::time::Duration::from_secs(5))
|
|
.timeout(std::time::Duration::from_secs(300))
|
|
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
|
.pool_max_idle_per_host(4)
|
|
.tcp_keepalive(std::time::Duration::from_secs(30))
|
|
.build()?;
|
|
|
|
Ok(Self {
|
|
client,
|
|
config: config.clone(),
|
|
api_key,
|
|
pricing: app_config.pricing.deepseek.clone(),
|
|
})
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl super::Provider for DeepSeekProvider {
|
|
fn name(&self) -> &str {
|
|
"deepseek"
|
|
}
|
|
|
|
fn supports_model(&self, model: &str) -> bool {
|
|
model.starts_with("deepseek-") || model.contains("deepseek")
|
|
}
|
|
|
|
fn supports_multimodal(&self) -> bool {
|
|
false
|
|
}
|
|
|
|
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
|
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
|
let mut body = helpers::build_openai_body(&request, messages_json, false);
|
|
|
|
// Sanitize and fix for deepseek-reasoner (R1)
|
|
if request.model == "deepseek-reasoner" {
|
|
if let Some(obj) = body.as_object_mut() {
|
|
// Remove unsupported parameters
|
|
obj.remove("temperature");
|
|
obj.remove("top_p");
|
|
obj.remove("presence_penalty");
|
|
obj.remove("frequency_penalty");
|
|
obj.remove("logit_bias");
|
|
obj.remove("logprobs");
|
|
obj.remove("top_logprobs");
|
|
|
|
// ENSURE: EVERY assistant message must have reasoning_content and valid content
|
|
if let Some(messages) = obj.get_mut("messages").and_then(|m| m.as_array_mut()) {
|
|
for m in messages {
|
|
if m["role"].as_str() == Some("assistant") {
|
|
// DeepSeek R1 requires reasoning_content for consistency in history.
|
|
if m.get("reasoning_content").is_none() || m["reasoning_content"].is_null() {
|
|
m["reasoning_content"] = serde_json::json!(" ");
|
|
}
|
|
// DeepSeek R1 often requires content to be a string, not null/array
|
|
if m.get("content").is_none() || m["content"].is_null() || m["content"].is_array() {
|
|
m["content"] = serde_json::json!("");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let response = self
|
|
.client
|
|
.post(format!("{}/chat/completions", self.config.base_url))
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.json(&body)
|
|
.send()
|
|
.await
|
|
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let error_text = response.text().await.unwrap_or_default();
|
|
tracing::error!("DeepSeek API error ({}): {}", status, error_text);
|
|
tracing::error!("Offending DeepSeek Request Body: {}", serde_json::to_string(&body).unwrap_or_default());
|
|
return Err(AppError::ProviderError(format!("DeepSeek API error ({}): {}", status, error_text)));
|
|
}
|
|
|
|
let resp_json: serde_json::Value = response
|
|
.json()
|
|
.await
|
|
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
|
|
helpers::parse_openai_response(&resp_json, request.model)
|
|
}
|
|
|
|
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
|
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
|
}
|
|
|
|
fn calculate_cost(
|
|
&self,
|
|
model: &str,
|
|
prompt_tokens: u32,
|
|
completion_tokens: u32,
|
|
cache_read_tokens: u32,
|
|
cache_write_tokens: u32,
|
|
registry: &crate::models::registry::ModelRegistry,
|
|
) -> f64 {
|
|
if let Some(metadata) = registry.find_model(model) {
|
|
if metadata.cost.is_some() {
|
|
return helpers::calculate_cost_with_registry(
|
|
model,
|
|
prompt_tokens,
|
|
completion_tokens,
|
|
cache_read_tokens,
|
|
cache_write_tokens,
|
|
registry,
|
|
&self.pricing,
|
|
0.28,
|
|
0.42,
|
|
);
|
|
}
|
|
}
|
|
|
|
// Custom DeepSeek fallback that correctly handles cache hits
|
|
let (prompt_rate, completion_rate) = self
|
|
.pricing
|
|
.iter()
|
|
.find(|p| model.contains(&p.model))
|
|
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
|
.unwrap_or((0.28, 0.42)); // Default to DeepSeek's current API pricing
|
|
|
|
let cache_hit_rate = prompt_rate / 10.0;
|
|
let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
|
|
|
|
(non_cached_prompt as f64 * prompt_rate / 1_000_000.0)
|
|
+ (cache_read_tokens as f64 * cache_hit_rate / 1_000_000.0)
|
|
+ (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
|
}
|
|
|
|
async fn chat_completion_stream(
|
|
&self,
|
|
request: UnifiedRequest,
|
|
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
|
// DeepSeek doesn't support images in streaming, use text-only
|
|
let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?;
|
|
let mut body = helpers::build_openai_body(&request, messages_json, true);
|
|
|
|
// Sanitize and fix for deepseek-reasoner (R1)
|
|
if request.model == "deepseek-reasoner" {
|
|
if let Some(obj) = body.as_object_mut() {
|
|
// Keep stream_options if present (DeepSeek supports include_usage)
|
|
|
|
// Remove unsupported parameters
|
|
obj.remove("temperature");
|
|
|
|
obj.remove("top_p");
|
|
obj.remove("presence_penalty");
|
|
obj.remove("frequency_penalty");
|
|
obj.remove("logit_bias");
|
|
obj.remove("logprobs");
|
|
obj.remove("top_logprobs");
|
|
|
|
// ENSURE: EVERY assistant message must have reasoning_content and valid content
|
|
if let Some(messages) = obj.get_mut("messages").and_then(|m| m.as_array_mut()) {
|
|
for m in messages {
|
|
if m["role"].as_str() == Some("assistant") {
|
|
// DeepSeek R1 requires reasoning_content for consistency in history.
|
|
if m.get("reasoning_content").is_none() || m["reasoning_content"].is_null() {
|
|
m["reasoning_content"] = serde_json::json!(" ");
|
|
}
|
|
// DeepSeek R1 often requires content to be a string, not null/array
|
|
if m.get("content").is_none() || m["content"].is_null() || m["content"].is_array() {
|
|
m["content"] = serde_json::json!("");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let url = format!("{}/chat/completions", self.config.base_url);
|
|
let api_key = self.api_key.clone();
|
|
let probe_client = self.client.clone();
|
|
let probe_body = body.clone();
|
|
let model = request.model.clone();
|
|
|
|
let es = reqwest_eventsource::EventSource::new(
|
|
self.client
|
|
.post(&url)
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.json(&body),
|
|
)
|
|
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
|
|
|
let stream = async_stream::try_stream! {
|
|
let mut es = es;
|
|
while let Some(event) = es.next().await {
|
|
match event {
|
|
Ok(reqwest_eventsource::Event::Message(msg)) => {
|
|
if msg.data == "[DONE]" {
|
|
break;
|
|
}
|
|
|
|
let chunk: serde_json::Value = serde_json::from_str(&msg.data)
|
|
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
|
|
|
if let Some(p_chunk) = helpers::parse_openai_stream_chunk(&chunk, &model, None) {
|
|
yield p_chunk?;
|
|
}
|
|
}
|
|
Ok(_) => continue,
|
|
Err(e) => {
|
|
// Attempt to probe for the actual error body
|
|
let probe_resp = probe_client
|
|
.post(&url)
|
|
.header("Authorization", format!("Bearer {}", api_key))
|
|
.json(&probe_body)
|
|
.send()
|
|
.await;
|
|
|
|
match probe_resp {
|
|
Ok(r) if !r.status().is_success() => {
|
|
let status = r.status();
|
|
let error_body = r.text().await.unwrap_or_default();
|
|
tracing::error!("DeepSeek Stream Error Probe ({}): {}", status, error_body);
|
|
// Log the offending request body at ERROR level so it shows up in standard logs
|
|
tracing::error!("Offending DeepSeek Request Body: {}", serde_json::to_string(&probe_body).unwrap_or_default());
|
|
Err(AppError::ProviderError(format!("DeepSeek API error ({}): {}", status, error_body)))?;
|
|
}
|
|
Ok(_) => {
|
|
Err(AppError::ProviderError(format!("Stream error (probe returned 200): {}", e)))?;
|
|
}
|
|
Err(probe_err) => {
|
|
tracing::error!("DeepSeek Stream Error Probe failed: {}", probe_err);
|
|
Err(AppError::ProviderError(format!("Stream error (probe failed: {}): {}", probe_err, e)))?;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
Ok(Box::pin(stream))
|
|
}
|
|
}
|