- Fix DeepSeek R1 (reasoner) 400 errors by ensuring assistant messages with tool_calls in history always have non-null 'content' and 'reasoning_content'. - Implement deterministic tool call ID truncation (max 40 chars) for OpenAI compatibility (fixes errors when history contains long Gemini signatures). - Automatic transition from 'max_tokens' to 'max_completion_tokens' for newer OpenAI models (o1, o3, gpt-5-nano). - Added 'reasoning' and 'thought' aliases to reasoning_content for robust deserialization from various clients.
446 lines
17 KiB
Rust
446 lines
17 KiB
Rust
use super::{ProviderResponse, ProviderStreamChunk, StreamUsage};
|
|
use crate::errors::AppError;
|
|
use crate::models::{ContentPart, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest};
|
|
use futures::stream::{BoxStream, StreamExt};
|
|
use serde_json::Value;
|
|
|
|
/// Convert messages to OpenAI-compatible JSON, resolving images asynchronously.
|
|
///
|
|
/// This avoids the deadlock caused by `futures::executor::block_on` inside a
|
|
/// Tokio async context. All image base64 conversions are awaited properly.
|
|
/// Handles tool-calling messages: assistant messages with tool_calls, and
|
|
/// tool-role messages with tool_call_id/name.
|
|
pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result<Vec<serde_json::Value>, AppError> {
|
|
let mut result = Vec::new();
|
|
for m in messages {
|
|
// Tool-role messages: { role: "tool", content: "...", tool_call_id: "...", name: "..." }
|
|
if m.role == "tool" {
|
|
let text_content = m
|
|
.content
|
|
.first()
|
|
.map(|p| match p {
|
|
ContentPart::Text { text } => text.clone(),
|
|
ContentPart::Image(_) => "[Image]".to_string(),
|
|
})
|
|
.unwrap_or_default();
|
|
|
|
let mut msg = serde_json::json!({
|
|
"role": "tool",
|
|
"content": text_content
|
|
});
|
|
if let Some(tool_call_id) = &m.tool_call_id {
|
|
// OpenAI and others have a 40-char limit for tool_call_id.
|
|
// Gemini signatures (56 chars) must be shortened for compatibility.
|
|
let id = if tool_call_id.len() > 40 {
|
|
&tool_call_id[..40]
|
|
} else {
|
|
tool_call_id
|
|
};
|
|
msg["tool_call_id"] = serde_json::json!(id);
|
|
}
|
|
if let Some(name) = &m.name {
|
|
msg["name"] = serde_json::json!(name);
|
|
}
|
|
result.push(msg);
|
|
continue;
|
|
}
|
|
|
|
// Build content parts for non-tool messages
|
|
let mut parts = Vec::new();
|
|
for p in &m.content {
|
|
match p {
|
|
ContentPart::Text { text } => {
|
|
parts.push(serde_json::json!({ "type": "text", "text": text }));
|
|
}
|
|
ContentPart::Image(image_input) => {
|
|
let (base64_data, mime_type) = image_input
|
|
.to_base64()
|
|
.await
|
|
.map_err(|e| AppError::MultimodalError(e.to_string()))?;
|
|
parts.push(serde_json::json!({
|
|
"type": "image_url",
|
|
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
|
}));
|
|
}
|
|
}
|
|
}
|
|
|
|
let mut msg = serde_json::json!({ "role": m.role });
|
|
|
|
// Include reasoning_content if present (DeepSeek R1/reasoner requires this in history)
|
|
if let Some(reasoning) = &m.reasoning_content {
|
|
msg["reasoning_content"] = serde_json::json!(reasoning);
|
|
}
|
|
|
|
// For assistant messages with tool_calls, content can be empty string
|
|
if let Some(tool_calls) = &m.tool_calls {
|
|
// Sanitize tool call IDs for OpenAI compatibility (max 40 chars)
|
|
let sanitized_calls: Vec<_> = tool_calls.iter().map(|tc| {
|
|
let mut sanitized = tc.clone();
|
|
if sanitized.id.len() > 40 {
|
|
sanitized.id = sanitized.id[..40].to_string();
|
|
}
|
|
sanitized
|
|
}).collect();
|
|
|
|
if parts.is_empty() {
|
|
msg["content"] = serde_json::json!("");
|
|
} else {
|
|
msg["content"] = serde_json::json!(parts);
|
|
}
|
|
msg["tool_calls"] = serde_json::json!(sanitized_calls);
|
|
} else {
|
|
msg["content"] = serde_json::json!(parts);
|
|
}
|
|
|
|
if let Some(name) = &m.name {
|
|
msg["name"] = serde_json::json!(name);
|
|
}
|
|
|
|
result.push(msg);
|
|
}
|
|
Ok(result)
|
|
}
|
|
|
|
/// Convert messages to OpenAI-compatible JSON, but replace images with a
|
|
/// text placeholder "[Image]". Useful for providers that don't support
|
|
/// multimodal in streaming mode or at all.
|
|
///
|
|
/// Handles tool-calling messages identically to `messages_to_openai_json`:
|
|
/// assistant messages with `tool_calls`, and tool-role messages with
|
|
/// `tool_call_id`/`name`.
|
|
pub async fn messages_to_openai_json_text_only(
|
|
messages: &[UnifiedMessage],
|
|
) -> Result<Vec<serde_json::Value>, AppError> {
|
|
let mut result = Vec::new();
|
|
for m in messages {
|
|
// Tool-role messages: { role: "tool", content: "...", tool_call_id: "...", name: "..." }
|
|
if m.role == "tool" {
|
|
let text_content = m
|
|
.content
|
|
.first()
|
|
.map(|p| match p {
|
|
ContentPart::Text { text } => text.clone(),
|
|
ContentPart::Image(_) => "[Image]".to_string(),
|
|
})
|
|
.unwrap_or_default();
|
|
|
|
let mut msg = serde_json::json!({
|
|
"role": "tool",
|
|
"content": text_content
|
|
});
|
|
if let Some(tool_call_id) = &m.tool_call_id {
|
|
// OpenAI and others have a 40-char limit for tool_call_id.
|
|
let id = if tool_call_id.len() > 40 {
|
|
&tool_call_id[..40]
|
|
} else {
|
|
tool_call_id
|
|
};
|
|
msg["tool_call_id"] = serde_json::json!(id);
|
|
}
|
|
if let Some(name) = &m.name {
|
|
msg["name"] = serde_json::json!(name);
|
|
}
|
|
result.push(msg);
|
|
continue;
|
|
}
|
|
|
|
// Build content parts for non-tool messages (images become "[Image]" text)
|
|
let mut parts = Vec::new();
|
|
for p in &m.content {
|
|
match p {
|
|
ContentPart::Text { text } => {
|
|
parts.push(serde_json::json!({ "type": "text", "text": text }));
|
|
}
|
|
ContentPart::Image(_) => {
|
|
parts.push(serde_json::json!({ "type": "text", "text": "[Image]" }));
|
|
}
|
|
}
|
|
}
|
|
|
|
let mut msg = serde_json::json!({ "role": m.role });
|
|
|
|
// Include reasoning_content if present (DeepSeek R1/reasoner requires this in history)
|
|
if let Some(reasoning) = &m.reasoning_content {
|
|
msg["reasoning_content"] = serde_json::json!(reasoning);
|
|
}
|
|
|
|
// For assistant messages with tool_calls, content can be empty string
|
|
if let Some(tool_calls) = &m.tool_calls {
|
|
// Sanitize tool call IDs for OpenAI compatibility (max 40 chars)
|
|
let sanitized_calls: Vec<_> = tool_calls.iter().map(|tc| {
|
|
let mut sanitized = tc.clone();
|
|
if sanitized.id.len() > 40 {
|
|
sanitized.id = sanitized.id[..40].to_string();
|
|
}
|
|
sanitized
|
|
}).collect();
|
|
|
|
if parts.is_empty() {
|
|
msg["content"] = serde_json::json!("");
|
|
} else {
|
|
msg["content"] = serde_json::json!(parts);
|
|
}
|
|
msg["tool_calls"] = serde_json::json!(sanitized_calls);
|
|
} else {
|
|
msg["content"] = serde_json::json!(parts);
|
|
}
|
|
|
|
if let Some(name) = &m.name {
|
|
msg["name"] = serde_json::json!(name);
|
|
}
|
|
|
|
result.push(msg);
|
|
}
|
|
Ok(result)
|
|
}
|
|
|
|
/// Build an OpenAI-compatible request body from a UnifiedRequest and pre-converted messages.
|
|
/// Includes tools and tool_choice when present.
|
|
/// When streaming, adds `stream_options.include_usage: true` so providers report
|
|
/// token counts in the final SSE chunk.
|
|
pub fn build_openai_body(
|
|
request: &UnifiedRequest,
|
|
messages_json: Vec<serde_json::Value>,
|
|
stream: bool,
|
|
) -> serde_json::Value {
|
|
let mut body = serde_json::json!({
|
|
"model": request.model,
|
|
"messages": messages_json,
|
|
"stream": stream,
|
|
});
|
|
|
|
if stream {
|
|
body["stream_options"] = serde_json::json!({ "include_usage": true });
|
|
}
|
|
|
|
if let Some(temp) = request.temperature {
|
|
body["temperature"] = serde_json::json!(temp);
|
|
}
|
|
if let Some(max_tokens) = request.max_tokens {
|
|
body["max_tokens"] = serde_json::json!(max_tokens);
|
|
}
|
|
if let Some(tools) = &request.tools {
|
|
body["tools"] = serde_json::json!(tools);
|
|
}
|
|
if let Some(tool_choice) = &request.tool_choice {
|
|
body["tool_choice"] = serde_json::json!(tool_choice);
|
|
}
|
|
|
|
body
|
|
}
|
|
|
|
/// Parse an OpenAI-compatible chat completion response JSON into a ProviderResponse.
|
|
/// Extracts tool_calls from the message when present.
|
|
/// Extracts cache token counts from:
|
|
/// - OpenAI/Grok: `usage.prompt_tokens_details.cached_tokens`
|
|
/// - DeepSeek: `usage.prompt_cache_hit_tokens` / `usage.prompt_cache_miss_tokens`
|
|
pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<ProviderResponse, AppError> {
|
|
let choice = resp_json["choices"]
|
|
.get(0)
|
|
.ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
|
|
let message = &choice["message"];
|
|
|
|
let content = message["content"].as_str().unwrap_or_default().to_string();
|
|
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
|
|
|
|
// Parse tool_calls from the response message
|
|
let tool_calls: Option<Vec<ToolCall>> = message
|
|
.get("tool_calls")
|
|
.and_then(|tc| serde_json::from_value(tc.clone()).ok());
|
|
|
|
let usage = &resp_json["usage"];
|
|
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
|
|
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
|
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
|
|
|
|
// Extract cache tokens — try OpenAI/Grok format first, then DeepSeek format
|
|
let cache_read_tokens = usage["prompt_tokens_details"]["cached_tokens"]
|
|
.as_u64()
|
|
// DeepSeek uses a different field name
|
|
.or_else(|| usage["prompt_cache_hit_tokens"].as_u64())
|
|
.unwrap_or(0) as u32;
|
|
|
|
// DeepSeek reports cache_write as prompt_cache_miss_tokens (tokens written to cache for future use).
|
|
// OpenAI doesn't report cache_write in this location, but may in the future.
|
|
let cache_write_tokens = usage["prompt_cache_miss_tokens"].as_u64().unwrap_or(0) as u32;
|
|
|
|
Ok(ProviderResponse {
|
|
content,
|
|
reasoning_content,
|
|
tool_calls,
|
|
prompt_tokens,
|
|
completion_tokens,
|
|
total_tokens,
|
|
cache_read_tokens,
|
|
cache_write_tokens,
|
|
model,
|
|
})
|
|
}
|
|
|
|
/// Parse a single OpenAI-compatible stream chunk into a ProviderStreamChunk.
|
|
/// Returns None if the chunk should be skipped (e.g. promptFeedback).
|
|
pub fn parse_openai_stream_chunk(
|
|
chunk: &Value,
|
|
model: &str,
|
|
reasoning_field: Option<&'static str>,
|
|
) -> Option<Result<ProviderStreamChunk, AppError>> {
|
|
// Parse usage from the final chunk (sent when stream_options.include_usage is true).
|
|
// This chunk may have an empty `choices` array.
|
|
let stream_usage = chunk.get("usage").and_then(|u| {
|
|
if u.is_null() {
|
|
return None;
|
|
}
|
|
let prompt_tokens = u["prompt_tokens"].as_u64().unwrap_or(0) as u32;
|
|
let completion_tokens = u["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
|
let total_tokens = u["total_tokens"].as_u64().unwrap_or(0) as u32;
|
|
|
|
let cache_read_tokens = u["prompt_tokens_details"]["cached_tokens"]
|
|
.as_u64()
|
|
.or_else(|| u["prompt_cache_hit_tokens"].as_u64())
|
|
.unwrap_or(0) as u32;
|
|
|
|
let cache_write_tokens = u["prompt_cache_miss_tokens"]
|
|
.as_u64()
|
|
.unwrap_or(0) as u32;
|
|
|
|
Some(StreamUsage {
|
|
prompt_tokens,
|
|
completion_tokens,
|
|
total_tokens,
|
|
cache_read_tokens,
|
|
cache_write_tokens,
|
|
})
|
|
});
|
|
|
|
if let Some(choice) = chunk["choices"].get(0) {
|
|
let delta = &choice["delta"];
|
|
let content = delta["content"].as_str().unwrap_or_default().to_string();
|
|
let reasoning_content = delta["reasoning_content"]
|
|
.as_str()
|
|
.or_else(|| reasoning_field.and_then(|f| delta[f].as_str()))
|
|
.map(|s| s.to_string());
|
|
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
|
|
|
|
// Parse tool_calls deltas from the stream chunk
|
|
let tool_calls: Option<Vec<ToolCallDelta>> = delta
|
|
.get("tool_calls")
|
|
.and_then(|tc| serde_json::from_value(tc.clone()).ok());
|
|
|
|
Some(Ok(ProviderStreamChunk {
|
|
content,
|
|
reasoning_content,
|
|
finish_reason,
|
|
tool_calls,
|
|
model: model.to_string(),
|
|
usage: stream_usage,
|
|
}))
|
|
} else if stream_usage.is_some() {
|
|
// Final usage-only chunk (empty choices array) — yield it so
|
|
// AggregatingStream can capture the real token counts.
|
|
Some(Ok(ProviderStreamChunk {
|
|
content: String::new(),
|
|
reasoning_content: None,
|
|
finish_reason: None,
|
|
tool_calls: None,
|
|
model: model.to_string(),
|
|
usage: stream_usage,
|
|
}))
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
/// Create an SSE stream that parses OpenAI-compatible streaming chunks.
|
|
///
|
|
/// The optional `reasoning_field` allows overriding the field name for
|
|
/// reasoning content (e.g., "thought" for Ollama).
|
|
/// Parses tool_calls deltas from streaming chunks when present.
|
|
/// When `stream_options.include_usage: true` was sent, the provider sends a
|
|
/// final chunk with `usage` data — this is parsed into `StreamUsage` and
|
|
/// attached to the yielded `ProviderStreamChunk`.
|
|
pub fn create_openai_stream(
|
|
es: reqwest_eventsource::EventSource,
|
|
model: String,
|
|
reasoning_field: Option<&'static str>,
|
|
) -> BoxStream<'static, Result<ProviderStreamChunk, AppError>> {
|
|
use reqwest_eventsource::Event;
|
|
|
|
let stream = async_stream::try_stream! {
|
|
let mut es = es;
|
|
while let Some(event) = es.next().await {
|
|
match event {
|
|
Ok(Event::Message(msg)) => {
|
|
if msg.data == "[DONE]" {
|
|
break;
|
|
}
|
|
|
|
let chunk: Value = serde_json::from_str(&msg.data)
|
|
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
|
|
|
if let Some(p_chunk) = parse_openai_stream_chunk(&chunk, &model, reasoning_field) {
|
|
yield p_chunk?;
|
|
}
|
|
}
|
|
Ok(_) => continue,
|
|
Err(e) => {
|
|
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
Box::pin(stream)
|
|
}
|
|
|
|
/// Calculate cost using the model registry first, then falling back to provider pricing config.
|
|
///
|
|
/// When the registry provides `cache_read` / `cache_write` rates, the formula is:
|
|
/// (prompt_tokens - cache_read_tokens) * input_rate
|
|
/// + cache_read_tokens * cache_read_rate
|
|
/// + cache_write_tokens * cache_write_rate (if applicable)
|
|
/// + completion_tokens * output_rate
|
|
///
|
|
/// All rates are per-token (the registry stores per-million-token rates).
|
|
pub fn calculate_cost_with_registry(
|
|
model: &str,
|
|
prompt_tokens: u32,
|
|
completion_tokens: u32,
|
|
cache_read_tokens: u32,
|
|
cache_write_tokens: u32,
|
|
registry: &crate::models::registry::ModelRegistry,
|
|
pricing: &[crate::config::ModelPricing],
|
|
default_prompt_rate: f64,
|
|
default_completion_rate: f64,
|
|
) -> f64 {
|
|
if let Some(metadata) = registry.find_model(model)
|
|
&& let Some(cost) = &metadata.cost
|
|
{
|
|
let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
|
|
let mut total = (non_cached_prompt as f64 * cost.input / 1_000_000.0)
|
|
+ (completion_tokens as f64 * cost.output / 1_000_000.0);
|
|
|
|
if let Some(cache_read_rate) = cost.cache_read {
|
|
total += cache_read_tokens as f64 * cache_read_rate / 1_000_000.0;
|
|
} else {
|
|
// No cache_read rate — charge cached tokens at full input rate
|
|
total += cache_read_tokens as f64 * cost.input / 1_000_000.0;
|
|
}
|
|
|
|
if let Some(cache_write_rate) = cost.cache_write {
|
|
total += cache_write_tokens as f64 * cache_write_rate / 1_000_000.0;
|
|
}
|
|
|
|
return total;
|
|
}
|
|
|
|
// Fallback: no registry entry — use provider pricing config (no cache awareness)
|
|
let (prompt_rate, completion_rate) = pricing
|
|
.iter()
|
|
.find(|p| model.contains(&p.model))
|
|
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
|
.unwrap_or((default_prompt_rate, default_completion_rate));
|
|
|
|
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
|
}
|