Files
GopherGate/src/providers/helpers.rs
hobokenchicken 9c01b97f82
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled
fix(providers): handle tool messages in text_only message converter
messages_to_openai_json_text_only() was missing tool-calling support,
causing DeepSeek 400 errors when conversations included tool turns.
Now mirrors messages_to_openai_json() logic for tool-role messages
(tool_call_id, name) and assistant tool_calls, with images replaced
by "[Image]" text.
2026-03-02 11:30:38 -05:00

301 lines
11 KiB
Rust

use super::{ProviderResponse, ProviderStreamChunk};
use crate::errors::AppError;
use crate::models::{ContentPart, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest};
use futures::stream::{BoxStream, StreamExt};
use serde_json::Value;
/// Convert messages to OpenAI-compatible JSON, resolving images asynchronously.
///
/// This avoids the deadlock caused by `futures::executor::block_on` inside a
/// Tokio async context. All image base64 conversions are awaited properly.
/// Handles tool-calling messages: assistant messages with tool_calls, and
/// tool-role messages with tool_call_id/name.
///
/// # Errors
/// Returns `AppError::MultimodalError` when an image cannot be converted to base64.
pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result<Vec<serde_json::Value>, AppError> {
    let mut result = Vec::new();
    for m in messages {
        // Tool-role messages: { role: "tool", content: "...", tool_call_id: "...", name: "..." }
        if m.role == "tool" {
            // Concatenate every content part rather than taking only the first:
            // a tool result split across multiple text parts previously dropped
            // everything after part 0. Tool messages cannot carry images in the
            // OpenAI schema, so any image part becomes a "[Image]" placeholder.
            let text_content = m
                .content
                .iter()
                .map(|p| match p {
                    ContentPart::Text { text } => text.clone(),
                    ContentPart::Image(_) => "[Image]".to_string(),
                })
                .collect::<Vec<_>>()
                .join("\n");
            let mut msg = serde_json::json!({
                "role": "tool",
                "content": text_content
            });
            // tool_call_id links this result back to the assistant's tool call.
            if let Some(tool_call_id) = &m.tool_call_id {
                msg["tool_call_id"] = serde_json::json!(tool_call_id);
            }
            if let Some(name) = &m.name {
                msg["name"] = serde_json::json!(name);
            }
            result.push(msg);
            continue;
        }
        // Build content parts for non-tool messages.
        let mut parts = Vec::new();
        for p in &m.content {
            match p {
                ContentPart::Text { text } => {
                    parts.push(serde_json::json!({ "type": "text", "text": text }));
                }
                ContentPart::Image(image_input) => {
                    // Await the conversion directly (no block_on) so we never
                    // deadlock inside a Tokio worker thread.
                    let (base64_data, mime_type) = image_input
                        .to_base64()
                        .await
                        .map_err(|e| AppError::MultimodalError(e.to_string()))?;
                    parts.push(serde_json::json!({
                        "type": "image_url",
                        "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
                    }));
                }
            }
        }
        let mut msg = serde_json::json!({ "role": m.role });
        // For assistant messages with tool_calls, content can be null.
        if let Some(tool_calls) = &m.tool_calls {
            if parts.is_empty() {
                msg["content"] = serde_json::Value::Null;
            } else {
                msg["content"] = serde_json::json!(parts);
            }
            msg["tool_calls"] = serde_json::json!(tool_calls);
        } else {
            msg["content"] = serde_json::json!(parts);
        }
        if let Some(name) = &m.name {
            msg["name"] = serde_json::json!(name);
        }
        result.push(msg);
    }
    Ok(result)
}
/// Convert messages to OpenAI-compatible JSON, but replace images with a
/// text placeholder "[Image]". Useful for providers that don't support
/// multimodal in streaming mode or at all.
///
/// Handles tool-calling messages identically to `messages_to_openai_json`:
/// assistant messages with `tool_calls`, and tool-role messages with
/// `tool_call_id`/`name`.
pub async fn messages_to_openai_json_text_only(
    messages: &[UnifiedMessage],
) -> Result<Vec<serde_json::Value>, AppError> {
    let mut result = Vec::new();
    for m in messages {
        // Tool-role messages: { role: "tool", content: "...", tool_call_id: "...", name: "..." }
        if m.role == "tool" {
            // Concatenate every content part rather than taking only the first:
            // a tool result split across multiple text parts previously dropped
            // everything after part 0. Images become a "[Image]" placeholder.
            let text_content = m
                .content
                .iter()
                .map(|p| match p {
                    ContentPart::Text { text } => text.clone(),
                    ContentPart::Image(_) => "[Image]".to_string(),
                })
                .collect::<Vec<_>>()
                .join("\n");
            let mut msg = serde_json::json!({
                "role": "tool",
                "content": text_content
            });
            // tool_call_id links this result back to the assistant's tool call.
            if let Some(tool_call_id) = &m.tool_call_id {
                msg["tool_call_id"] = serde_json::json!(tool_call_id);
            }
            if let Some(name) = &m.name {
                msg["name"] = serde_json::json!(name);
            }
            result.push(msg);
            continue;
        }
        // Build content parts for non-tool messages (images become "[Image]" text).
        let mut parts = Vec::new();
        for p in &m.content {
            match p {
                ContentPart::Text { text } => {
                    parts.push(serde_json::json!({ "type": "text", "text": text }));
                }
                ContentPart::Image(_) => {
                    parts.push(serde_json::json!({ "type": "text", "text": "[Image]" }));
                }
            }
        }
        let mut msg = serde_json::json!({ "role": m.role });
        // For assistant messages with tool_calls, content can be null.
        if let Some(tool_calls) = &m.tool_calls {
            if parts.is_empty() {
                msg["content"] = serde_json::Value::Null;
            } else {
                msg["content"] = serde_json::json!(parts);
            }
            msg["tool_calls"] = serde_json::json!(tool_calls);
        } else {
            msg["content"] = serde_json::json!(parts);
        }
        if let Some(name) = &m.name {
            msg["name"] = serde_json::json!(name);
        }
        result.push(msg);
    }
    Ok(result)
}
/// Build an OpenAI-compatible request body from a UnifiedRequest and pre-converted messages.
/// Includes tools and tool_choice when present.
pub fn build_openai_body(
    request: &UnifiedRequest,
    messages_json: Vec<serde_json::Value>,
    stream: bool,
) -> serde_json::Value {
    // Start from the mandatory fields; everything else is inserted conditionally.
    let mut body = serde_json::json!({
        "model": request.model,
        "messages": messages_json,
        "stream": stream,
    });
    let fields = body.as_object_mut().expect("json! object literal");
    if let Some(temperature) = request.temperature {
        fields.insert("temperature".to_string(), serde_json::json!(temperature));
    }
    if let Some(limit) = request.max_tokens {
        fields.insert("max_tokens".to_string(), serde_json::json!(limit));
    }
    if let Some(tools) = &request.tools {
        fields.insert("tools".to_string(), serde_json::json!(tools));
    }
    if let Some(choice) = &request.tool_choice {
        fields.insert("tool_choice".to_string(), serde_json::json!(choice));
    }
    body
}
/// Parse an OpenAI-compatible chat completion response JSON into a ProviderResponse.
/// Extracts tool_calls from the message when present.
pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<ProviderResponse, AppError> {
    // A well-formed completion carries at least one choice; bail out otherwise.
    let Some(choice) = resp_json["choices"].get(0) else {
        return Err(AppError::ProviderError("No choices in response".to_string()));
    };
    let message = &choice["message"];
    let content = message["content"]
        .as_str()
        .map(str::to_string)
        .unwrap_or_default();
    let reasoning_content = message["reasoning_content"].as_str().map(str::to_string);
    // tool_calls are optional; a malformed array is treated as absent.
    let tool_calls = message
        .get("tool_calls")
        .and_then(|raw| serde_json::from_value::<Vec<ToolCall>>(raw.clone()).ok());
    // Missing usage fields default to 0.
    let usage = &resp_json["usage"];
    let token_count = |field: &str| usage[field].as_u64().unwrap_or(0) as u32;
    Ok(ProviderResponse {
        content,
        reasoning_content,
        tool_calls,
        prompt_tokens: token_count("prompt_tokens"),
        completion_tokens: token_count("completion_tokens"),
        total_tokens: token_count("total_tokens"),
        model,
    })
}
/// Create an SSE stream that parses OpenAI-compatible streaming chunks.
///
/// The optional `reasoning_field` allows overriding the field name for
/// reasoning content (e.g., "thought" for Ollama).
/// Parses tool_calls deltas from streaming chunks when present.
pub fn create_openai_stream(
    es: reqwest_eventsource::EventSource,
    model: String,
    reasoning_field: Option<&'static str>,
) -> BoxStream<'static, Result<ProviderStreamChunk, AppError>> {
    use reqwest_eventsource::Event;
    let stream = async_stream::try_stream! {
        // Move the event source into the generator so it lives as long as the stream.
        let mut es = es;
        while let Some(event) = es.next().await {
            match event {
                Ok(Event::Message(msg)) => {
                    // OpenAI-compatible APIs terminate the SSE stream with a literal "[DONE]" sentinel.
                    if msg.data == "[DONE]" {
                        break;
                    }
                    let chunk: Value = serde_json::from_str(&msg.data)
                        .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
                    // Chunks without any choice (e.g. bare usage frames) are skipped.
                    if let Some(choice) = chunk["choices"].get(0) {
                        let delta = &choice["delta"];
                        let content = delta["content"].as_str().unwrap_or_default().to_string();
                        // Prefer the standard "reasoning_content" field; fall back to the
                        // provider-specific override (e.g. "thought") when it is absent.
                        let reasoning_content = delta["reasoning_content"]
                            .as_str()
                            .or_else(|| reasoning_field.and_then(|f| delta[f].as_str()))
                            .map(|s| s.to_string());
                        let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
                        // Parse tool_calls deltas from the stream chunk; malformed arrays are ignored.
                        let tool_calls: Option<Vec<ToolCallDelta>> = delta
                            .get("tool_calls")
                            .and_then(|tc| serde_json::from_value(tc.clone()).ok());
                        yield ProviderStreamChunk {
                            content,
                            reasoning_content,
                            finish_reason,
                            tool_calls,
                            model: model.clone(),
                        };
                    }
                }
                // Ignore non-message events (e.g. the initial connection-open event).
                Ok(_) => continue,
                Err(e) => {
                    // Surface transport/protocol errors to the consumer; `?` inside
                    // try_stream! yields the Err and ends the stream.
                    Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
                }
            }
        }
    };
    Box::pin(stream)
}
/// Calculate cost using the model registry first, then falling back to provider pricing config.
pub fn calculate_cost_with_registry(
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
registry: &crate::models::registry::ModelRegistry,
pricing: &[crate::config::ModelPricing],
default_prompt_rate: f64,
default_completion_rate: f64,
) -> f64 {
if let Some(metadata) = registry.find_model(model)
&& let Some(cost) = &metadata.cost
{
return (prompt_tokens as f64 * cost.input / 1_000_000.0)
+ (completion_tokens as f64 * cost.output / 1_000_000.0);
}
let (prompt_rate, completion_rate) = pricing
.iter()
.find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
.unwrap_or((default_prompt_rate, default_completion_rate));
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
}