use async_trait::async_trait; use anyhow::Result; use futures::stream::{BoxStream, StreamExt}; use serde_json::Value; use crate::{ models::UnifiedRequest, errors::AppError, config::AppConfig, }; use super::{ProviderResponse, ProviderStreamChunk}; pub struct DeepSeekProvider { client: reqwest::Client, config: crate::config::DeepSeekConfig, api_key: String, pricing: Vec, } impl DeepSeekProvider { pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result { let api_key = app_config.get_api_key("deepseek")?; Ok(Self { client: reqwest::Client::new(), config: config.clone(), api_key, pricing: app_config.pricing.deepseek.clone(), }) } } #[async_trait] impl super::Provider for DeepSeekProvider { fn name(&self) -> &str { "deepseek" } fn supports_model(&self, model: &str) -> bool { model.starts_with("deepseek-") || model.contains("deepseek") } fn supports_multimodal(&self) -> bool { false } async fn chat_completion( &self, request: UnifiedRequest, ) -> Result { // Build the OpenAI-compatible body let mut body = serde_json::json!({ "model": request.model, "messages": request.messages.iter().map(|m| { serde_json::json!({ "role": m.role, "content": m.content.iter().map(|p| { match p { crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }), crate::models::ContentPart::Image(image_input) => { // DeepSeek currently doesn't support images in the same way, but we'll try to be standard let (base64_data, mime_type) = futures::executor::block_on(image_input.to_base64()).unwrap_or_default(); serde_json::json!({ "type": "image_url", "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) } }) } } }).collect::>() }) }).collect::>(), "stream": false, }); if let Some(temp) = request.temperature { body["temperature"] = serde_json::json!(temp); } if let Some(max_tokens) = request.max_tokens { body["max_tokens"] = serde_json::json!(max_tokens); } let response = self.client.post(format!("{}/chat/completions", self.config.base_url)) .header("Authorization", format!("Bearer {}", self.api_key)) .json(&body) .send() .await .map_err(|e| AppError::ProviderError(e.to_string()))?; if !response.status().is_success() { let error_text = response.text().await.unwrap_or_default(); return Err(AppError::ProviderError(format!("DeepSeek API error: {}", error_text))); } let resp_json: Value = response.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?; let choice = resp_json["choices"].get(0).ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?; let message = &choice["message"]; let content = message["content"].as_str().unwrap_or_default().to_string(); let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string()); let usage = &resp_json["usage"]; let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32; let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32; let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32; Ok(ProviderResponse { content, reasoning_content, prompt_tokens, completion_tokens, total_tokens, model: request.model, }) } fn estimate_tokens(&self, request: &UnifiedRequest) -> Result { Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request)) } fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 { if let Some(metadata) = registry.find_model(model) { if let Some(cost) = &metadata.cost { return (prompt_tokens as f64 * cost.input / 1_000_000.0) + (completion_tokens as f64 * cost.output / 1_000_000.0); } } let (prompt_rate, completion_rate) = self.pricing.iter() .find(|p| model.contains(&p.model)) .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) .unwrap_or((0.14, 0.28)); (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0) } async fn chat_completion_stream( &self, request: UnifiedRequest, ) -> Result>, AppError> { let mut body = serde_json::json!({ "model": request.model, "messages": request.messages.iter().map(|m| { serde_json::json!({ "role": m.role, "content": m.content.iter().map(|p| { match p { crate::models::ContentPart::Text { text } => serde_json::json!({ "type": "text", "text": text }), crate::models::ContentPart::Image(_) => serde_json::json!({ "type": "text", "text": "[Image]" }), } }).collect::>() }) }).collect::>(), "stream": true, }); if let Some(temp) = request.temperature { body["temperature"] = serde_json::json!(temp); } if let Some(max_tokens) = request.max_tokens { body["max_tokens"] = serde_json::json!(max_tokens); } // Create eventsource stream use reqwest_eventsource::{EventSource, Event}; let es = EventSource::new(self.client.post(format!("{}/chat/completions", self.config.base_url)) .header("Authorization", format!("Bearer {}", self.api_key)) .json(&body)) .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; let model = request.model.clone(); let stream = async_stream::try_stream! { let mut es = es; while let Some(event) = es.next().await { match event { Ok(Event::Message(msg)) => { if msg.data == "[DONE]" { break; } let chunk: Value = serde_json::from_str(&msg.data) .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?; if let Some(choice) = chunk["choices"].get(0) { let delta = &choice["delta"]; let content = delta["content"].as_str().unwrap_or_default().to_string(); let reasoning_content = delta["reasoning_content"].as_str().map(|s| s.to_string()); let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string()); yield ProviderStreamChunk { content, reasoning_content, finish_reason, model: model.clone(), }; } } Ok(_) => continue, Err(e) => { Err(AppError::ProviderError(format!("Stream error: {}", e)))?; } } } }; Ok(Box::pin(stream)) } }