feat: add tool-calling passthrough for all providers
Implement full OpenAI-compatible tool-calling support across the proxy, enabling OpenCode to use llm-proxy as its sole LLM backend. - Add 9 tool-calling types (Tool, FunctionDef, ToolChoice, ToolCall, etc.) - Update ChatCompletionRequest/ChatMessage/ChatStreamDelta with tool fields - Update UnifiedRequest/UnifiedMessage to carry tool data through the pipeline - Shared helpers: messages_to_openai_json handles tool messages, build_openai_body includes tools/tool_choice, parse/stream extract tool_calls from responses - Gemini: full OpenAI<->Gemini format translation (functionDeclarations, functionCall/functionResponse, synthetic call IDs, tool_config mapping) - Gemini: extract duplicated message-conversion into shared convert_messages() - Server: SSE streams include tool_calls deltas, finish_reason='tool_calls' - AggregatingStream: accumulate tool call deltas across stream chunks - OpenAI provider: add o4- prefix to supports_model()
This commit is contained in:
@@ -2,14 +2,28 @@ use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{ProviderResponse, ProviderStreamChunk};
|
||||
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
|
||||
use crate::{
|
||||
config::AppConfig,
|
||||
errors::AppError,
|
||||
models::{ContentPart, FunctionCall, FunctionCallDelta, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest},
|
||||
};
|
||||
|
||||
// ========== Gemini Request Structs ==========
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct GeminiRequest {
|
||||
contents: Vec<GeminiContent>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
generation_config: Option<GeminiGenerationConfig>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<GeminiTool>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_config: Option<GeminiToolConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@@ -19,11 +33,16 @@ struct GeminiContent {
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct GeminiPart {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
text: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
inline_data: Option<GeminiInlineData>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
function_call: Option<GeminiFunctionCall>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
function_response: Option<GeminiFunctionResponse>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@@ -32,31 +51,85 @@ struct GeminiInlineData {
|
||||
data: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct GeminiFunctionCall {
|
||||
name: String,
|
||||
args: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct GeminiFunctionResponse {
|
||||
name: String,
|
||||
response: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct GeminiGenerationConfig {
|
||||
temperature: Option<f64>,
|
||||
max_output_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
// ========== Gemini Tool Structs ==========
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct GeminiTool {
|
||||
function_declarations: Vec<GeminiFunctionDeclaration>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct GeminiFunctionDeclaration {
|
||||
name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
description: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
parameters: Option<Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct GeminiToolConfig {
|
||||
function_calling_config: GeminiFunctionCallingConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct GeminiFunctionCallingConfig {
|
||||
mode: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename = "allowedFunctionNames")]
|
||||
allowed_function_names: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
// ========== Gemini Response Structs ==========
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct GeminiCandidate {
|
||||
content: GeminiContent,
|
||||
_finish_reason: Option<String>,
|
||||
#[serde(default)]
|
||||
finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct GeminiUsageMetadata {
|
||||
#[serde(default)]
|
||||
prompt_token_count: u32,
|
||||
#[serde(default)]
|
||||
candidates_token_count: u32,
|
||||
#[serde(default)]
|
||||
total_token_count: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct GeminiResponse {
|
||||
candidates: Vec<GeminiCandidate>,
|
||||
usage_metadata: Option<GeminiUsageMetadata>,
|
||||
}
|
||||
|
||||
// ========== Provider Implementation ==========
|
||||
|
||||
pub struct GeminiProvider {
|
||||
client: reqwest::Client,
|
||||
config: crate::config::GeminiConfig,
|
||||
@@ -82,6 +155,209 @@ impl GeminiProvider {
|
||||
pricing: app_config.pricing.gemini.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert unified messages to Gemini content format.
|
||||
/// Handles text, images, tool calls (assistant), and tool results.
|
||||
async fn convert_messages(messages: Vec<UnifiedMessage>) -> Result<Vec<GeminiContent>, AppError> {
|
||||
let mut contents = Vec::with_capacity(messages.len());
|
||||
|
||||
for msg in messages {
|
||||
// Tool-result messages → functionResponse parts under role "user"
|
||||
if msg.role == "tool" {
|
||||
let text_content = msg
|
||||
.content
|
||||
.first()
|
||||
.map(|p| match p {
|
||||
ContentPart::Text { text } => text.clone(),
|
||||
ContentPart::Image(_) => "[Image]".to_string(),
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let name = msg.name.unwrap_or_default();
|
||||
|
||||
// Parse the content as JSON if possible, otherwise wrap as string
|
||||
let response_value = serde_json::from_str::<Value>(&text_content)
|
||||
.unwrap_or_else(|_| serde_json::json!({ "result": text_content }));
|
||||
|
||||
contents.push(GeminiContent {
|
||||
parts: vec![GeminiPart {
|
||||
text: None,
|
||||
inline_data: None,
|
||||
function_call: None,
|
||||
function_response: Some(GeminiFunctionResponse {
|
||||
name,
|
||||
response: response_value,
|
||||
}),
|
||||
}],
|
||||
role: "user".to_string(),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
// Assistant messages with tool_calls → functionCall parts
|
||||
if msg.role == "assistant" {
|
||||
if let Some(tool_calls) = &msg.tool_calls {
|
||||
let mut parts = Vec::new();
|
||||
|
||||
// Include text content if present
|
||||
for p in &msg.content {
|
||||
if let ContentPart::Text { text } = p {
|
||||
if !text.is_empty() {
|
||||
parts.push(GeminiPart {
|
||||
text: Some(text.clone()),
|
||||
inline_data: None,
|
||||
function_call: None,
|
||||
function_response: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert each tool call to a functionCall part
|
||||
for tc in tool_calls {
|
||||
let args = serde_json::from_str::<Value>(&tc.function.arguments)
|
||||
.unwrap_or_else(|_| serde_json::json!({}));
|
||||
parts.push(GeminiPart {
|
||||
text: None,
|
||||
inline_data: None,
|
||||
function_call: Some(GeminiFunctionCall {
|
||||
name: tc.function.name.clone(),
|
||||
args,
|
||||
}),
|
||||
function_response: None,
|
||||
});
|
||||
}
|
||||
|
||||
contents.push(GeminiContent {
|
||||
parts,
|
||||
role: "model".to_string(),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Regular text/image messages
|
||||
let mut parts = Vec::with_capacity(msg.content.len());
|
||||
for part in msg.content {
|
||||
match part {
|
||||
ContentPart::Text { text } => {
|
||||
parts.push(GeminiPart {
|
||||
text: Some(text),
|
||||
inline_data: None,
|
||||
function_call: None,
|
||||
function_response: None,
|
||||
});
|
||||
}
|
||||
ContentPart::Image(image_input) => {
|
||||
let (base64_data, mime_type) = image_input
|
||||
.to_base64()
|
||||
.await
|
||||
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
||||
|
||||
parts.push(GeminiPart {
|
||||
text: None,
|
||||
inline_data: Some(GeminiInlineData {
|
||||
mime_type,
|
||||
data: base64_data,
|
||||
}),
|
||||
function_call: None,
|
||||
function_response: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let role = match msg.role.as_str() {
|
||||
"assistant" => "model".to_string(),
|
||||
_ => "user".to_string(),
|
||||
};
|
||||
|
||||
contents.push(GeminiContent { parts, role });
|
||||
}
|
||||
|
||||
Ok(contents)
|
||||
}
|
||||
|
||||
/// Convert OpenAI tools to Gemini function declarations.
|
||||
fn convert_tools(request: &UnifiedRequest) -> Option<Vec<GeminiTool>> {
|
||||
request.tools.as_ref().map(|tools| {
|
||||
let declarations: Vec<GeminiFunctionDeclaration> = tools
|
||||
.iter()
|
||||
.map(|t| GeminiFunctionDeclaration {
|
||||
name: t.function.name.clone(),
|
||||
description: t.function.description.clone(),
|
||||
parameters: t.function.parameters.clone(),
|
||||
})
|
||||
.collect();
|
||||
vec![GeminiTool {
|
||||
function_declarations: declarations,
|
||||
}]
|
||||
})
|
||||
}
|
||||
|
||||
/// Convert OpenAI tool_choice to Gemini tool_config.
|
||||
fn convert_tool_config(request: &UnifiedRequest) -> Option<GeminiToolConfig> {
|
||||
request.tool_choice.as_ref().map(|tc| {
|
||||
let (mode, allowed_names) = match tc {
|
||||
crate::models::ToolChoice::Mode(mode) => {
|
||||
let gemini_mode = match mode.as_str() {
|
||||
"auto" => "AUTO",
|
||||
"none" => "NONE",
|
||||
"required" => "ANY",
|
||||
_ => "AUTO",
|
||||
};
|
||||
(gemini_mode.to_string(), None)
|
||||
}
|
||||
crate::models::ToolChoice::Specific(specific) => {
|
||||
("ANY".to_string(), Some(vec![specific.function.name.clone()]))
|
||||
}
|
||||
};
|
||||
GeminiToolConfig {
|
||||
function_calling_config: GeminiFunctionCallingConfig {
|
||||
mode,
|
||||
allowed_function_names: allowed_names,
|
||||
},
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract tool calls from Gemini response parts into OpenAI-format ToolCalls.
|
||||
fn extract_tool_calls(parts: &[GeminiPart]) -> Option<Vec<ToolCall>> {
|
||||
let calls: Vec<ToolCall> = parts
|
||||
.iter()
|
||||
.filter_map(|p| p.function_call.as_ref())
|
||||
.map(|fc| ToolCall {
|
||||
id: format!("call_{}", Uuid::new_v4().simple()),
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: fc.name.clone(),
|
||||
arguments: serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string()),
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
|
||||
if calls.is_empty() { None } else { Some(calls) }
|
||||
}
|
||||
|
||||
/// Extract tool call deltas from Gemini response parts for streaming.
|
||||
fn extract_tool_call_deltas(parts: &[GeminiPart]) -> Option<Vec<ToolCallDelta>> {
|
||||
let deltas: Vec<ToolCallDelta> = parts
|
||||
.iter()
|
||||
.filter_map(|p| p.function_call.as_ref())
|
||||
.enumerate()
|
||||
.map(|(i, fc)| ToolCallDelta {
|
||||
index: i as u32,
|
||||
id: Some(format!("call_{}", Uuid::new_v4().simple())),
|
||||
call_type: Some("function".to_string()),
|
||||
function: Some(FunctionCallDelta {
|
||||
name: Some(fc.name.clone()),
|
||||
arguments: Some(serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string())),
|
||||
}),
|
||||
})
|
||||
.collect();
|
||||
|
||||
if deltas.is_empty() { None } else { Some(deltas) }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -99,51 +375,15 @@ impl super::Provider for GeminiProvider {
|
||||
}
|
||||
|
||||
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
||||
// Convert UnifiedRequest to Gemini request
|
||||
let mut contents = Vec::with_capacity(request.messages.len());
|
||||
|
||||
for msg in request.messages {
|
||||
let mut parts = Vec::with_capacity(msg.content.len());
|
||||
|
||||
for part in msg.content {
|
||||
match part {
|
||||
crate::models::ContentPart::Text { text } => {
|
||||
parts.push(GeminiPart {
|
||||
text: Some(text),
|
||||
inline_data: None,
|
||||
});
|
||||
}
|
||||
crate::models::ContentPart::Image(image_input) => {
|
||||
let (base64_data, mime_type) = image_input
|
||||
.to_base64()
|
||||
.await
|
||||
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
||||
|
||||
parts.push(GeminiPart {
|
||||
text: None,
|
||||
inline_data: Some(GeminiInlineData {
|
||||
mime_type,
|
||||
data: base64_data,
|
||||
}),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Map role: "user" -> "user", "assistant" -> "model", "system" -> "user"
|
||||
let role = match msg.role.as_str() {
|
||||
"assistant" => "model".to_string(),
|
||||
_ => "user".to_string(),
|
||||
};
|
||||
|
||||
contents.push(GeminiContent { parts, role });
|
||||
}
|
||||
let model = request.model.clone();
|
||||
let tools = Self::convert_tools(&request);
|
||||
let tool_config = Self::convert_tool_config(&request);
|
||||
let contents = Self::convert_messages(request.messages).await?;
|
||||
|
||||
if contents.is_empty() {
|
||||
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
|
||||
return Err(AppError::ProviderError("No valid messages to send".to_string()));
|
||||
}
|
||||
|
||||
// Build generation config
|
||||
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
|
||||
Some(GeminiGenerationConfig {
|
||||
temperature: request.temperature,
|
||||
@@ -156,12 +396,12 @@ impl super::Provider for GeminiProvider {
|
||||
let gemini_request = GeminiRequest {
|
||||
contents,
|
||||
generation_config,
|
||||
tools,
|
||||
tool_config,
|
||||
};
|
||||
|
||||
// Build URL
|
||||
let url = format!("{}/models/{}:generateContent", self.config.base_url, request.model,);
|
||||
let url = format!("{}/models/{}:generateContent", self.config.base_url, model);
|
||||
|
||||
// Send request
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
@@ -171,7 +411,6 @@ impl super::Provider for GeminiProvider {
|
||||
.await
|
||||
.map_err(|e| AppError::ProviderError(format!("HTTP request failed: {}", e)))?;
|
||||
|
||||
// Check status
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
@@ -186,15 +425,16 @@ impl super::Provider for GeminiProvider {
|
||||
.await
|
||||
.map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?;
|
||||
|
||||
// Extract content from first candidate
|
||||
let content = gemini_response
|
||||
.candidates
|
||||
.first()
|
||||
.and_then(|c| c.content.parts.first())
|
||||
.and_then(|p| p.text.clone())
|
||||
let candidate = gemini_response.candidates.first();
|
||||
|
||||
// Extract text content (may be absent if only function calls)
|
||||
let content = candidate
|
||||
.and_then(|c| c.content.parts.iter().find_map(|p| p.text.clone()))
|
||||
.unwrap_or_default();
|
||||
|
||||
// Extract token usage
|
||||
// Extract function calls → OpenAI tool_calls
|
||||
let tool_calls = candidate.and_then(|c| Self::extract_tool_calls(&c.content.parts));
|
||||
|
||||
let prompt_tokens = gemini_response
|
||||
.usage_metadata
|
||||
.as_ref()
|
||||
@@ -213,11 +453,12 @@ impl super::Provider for GeminiProvider {
|
||||
|
||||
Ok(ProviderResponse {
|
||||
content,
|
||||
reasoning_content: None, // Gemini doesn't use this field name
|
||||
reasoning_content: None,
|
||||
tool_calls,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
model: request.model,
|
||||
model,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -247,47 +488,11 @@ impl super::Provider for GeminiProvider {
|
||||
&self,
|
||||
request: UnifiedRequest,
|
||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||
// Convert UnifiedRequest to Gemini request
|
||||
let mut contents = Vec::with_capacity(request.messages.len());
|
||||
let model = request.model.clone();
|
||||
let tools = Self::convert_tools(&request);
|
||||
let tool_config = Self::convert_tool_config(&request);
|
||||
let contents = Self::convert_messages(request.messages).await?;
|
||||
|
||||
for msg in request.messages {
|
||||
let mut parts = Vec::with_capacity(msg.content.len());
|
||||
|
||||
for part in msg.content {
|
||||
match part {
|
||||
crate::models::ContentPart::Text { text } => {
|
||||
parts.push(GeminiPart {
|
||||
text: Some(text),
|
||||
inline_data: None,
|
||||
});
|
||||
}
|
||||
crate::models::ContentPart::Image(image_input) => {
|
||||
let (base64_data, mime_type) = image_input
|
||||
.to_base64()
|
||||
.await
|
||||
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
||||
|
||||
parts.push(GeminiPart {
|
||||
text: None,
|
||||
inline_data: Some(GeminiInlineData {
|
||||
mime_type,
|
||||
data: base64_data,
|
||||
}),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Map role
|
||||
let role = match msg.role.as_str() {
|
||||
"assistant" => "model".to_string(),
|
||||
_ => "user".to_string(),
|
||||
};
|
||||
|
||||
contents.push(GeminiContent { parts, role });
|
||||
}
|
||||
|
||||
// Build generation config
|
||||
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
|
||||
Some(GeminiGenerationConfig {
|
||||
temperature: request.temperature,
|
||||
@@ -300,15 +505,15 @@ impl super::Provider for GeminiProvider {
|
||||
let gemini_request = GeminiRequest {
|
||||
contents,
|
||||
generation_config,
|
||||
tools,
|
||||
tool_config,
|
||||
};
|
||||
|
||||
// Build URL for streaming
|
||||
let url = format!(
|
||||
"{}/models/{}:streamGenerateContent?alt=sse",
|
||||
self.config.base_url, request.model,
|
||||
self.config.base_url, model,
|
||||
);
|
||||
|
||||
// Create eventsource stream
|
||||
use futures::StreamExt;
|
||||
use reqwest_eventsource::{Event, EventSource};
|
||||
|
||||
@@ -320,8 +525,6 @@ impl super::Provider for GeminiProvider {
|
||||
)
|
||||
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||
|
||||
let model = request.model.clone();
|
||||
|
||||
let stream = async_stream::try_stream! {
|
||||
let mut es = es;
|
||||
while let Some(event) = es.next().await {
|
||||
@@ -331,14 +534,28 @@ impl super::Provider for GeminiProvider {
|
||||
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
||||
|
||||
if let Some(candidate) = gemini_response.candidates.first() {
|
||||
let content = candidate.content.parts.first()
|
||||
.and_then(|p| p.text.clone())
|
||||
let content = candidate
|
||||
.content
|
||||
.parts
|
||||
.iter()
|
||||
.find_map(|p| p.text.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
let tool_calls = Self::extract_tool_call_deltas(&candidate.content.parts);
|
||||
|
||||
// Determine finish_reason
|
||||
let finish_reason = candidate.finish_reason.as_ref().map(|fr| {
|
||||
match fr.as_str() {
|
||||
"STOP" => "stop".to_string(),
|
||||
_ => fr.to_lowercase(),
|
||||
}
|
||||
});
|
||||
|
||||
yield ProviderStreamChunk {
|
||||
content,
|
||||
reasoning_content: None,
|
||||
finish_reason: None, // Will be set in the last chunk
|
||||
finish_reason,
|
||||
tool_calls,
|
||||
model: model.clone(),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::{ProviderResponse, ProviderStreamChunk};
|
||||
use crate::errors::AppError;
|
||||
use crate::models::{ContentPart, UnifiedMessage, UnifiedRequest};
|
||||
use crate::models::{ContentPart, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest};
|
||||
use futures::stream::{BoxStream, StreamExt};
|
||||
use serde_json::Value;
|
||||
|
||||
@@ -8,9 +8,37 @@ use serde_json::Value;
|
||||
///
|
||||
/// This avoids the deadlock caused by `futures::executor::block_on` inside a
|
||||
/// Tokio async context. All image base64 conversions are awaited properly.
|
||||
/// Handles tool-calling messages: assistant messages with tool_calls, and
|
||||
/// tool-role messages with tool_call_id/name.
|
||||
pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result<Vec<serde_json::Value>, AppError> {
|
||||
let mut result = Vec::new();
|
||||
for m in messages {
|
||||
// Tool-role messages: { role: "tool", content: "...", tool_call_id: "...", name: "..." }
|
||||
if m.role == "tool" {
|
||||
let text_content = m
|
||||
.content
|
||||
.first()
|
||||
.map(|p| match p {
|
||||
ContentPart::Text { text } => text.clone(),
|
||||
ContentPart::Image(_) => "[Image]".to_string(),
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut msg = serde_json::json!({
|
||||
"role": "tool",
|
||||
"content": text_content
|
||||
});
|
||||
if let Some(tool_call_id) = &m.tool_call_id {
|
||||
msg["tool_call_id"] = serde_json::json!(tool_call_id);
|
||||
}
|
||||
if let Some(name) = &m.name {
|
||||
msg["name"] = serde_json::json!(name);
|
||||
}
|
||||
result.push(msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Build content parts for non-tool messages
|
||||
let mut parts = Vec::new();
|
||||
for p in &m.content {
|
||||
match p {
|
||||
@@ -29,10 +57,26 @@ pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result<Vec<
|
||||
}
|
||||
}
|
||||
}
|
||||
result.push(serde_json::json!({
|
||||
"role": m.role,
|
||||
"content": parts
|
||||
}));
|
||||
|
||||
let mut msg = serde_json::json!({ "role": m.role });
|
||||
|
||||
// For assistant messages with tool_calls, content can be null
|
||||
if let Some(tool_calls) = &m.tool_calls {
|
||||
if parts.is_empty() {
|
||||
msg["content"] = serde_json::Value::Null;
|
||||
} else {
|
||||
msg["content"] = serde_json::json!(parts);
|
||||
}
|
||||
msg["tool_calls"] = serde_json::json!(tool_calls);
|
||||
} else {
|
||||
msg["content"] = serde_json::json!(parts);
|
||||
}
|
||||
|
||||
if let Some(name) = &m.name {
|
||||
msg["name"] = serde_json::json!(name);
|
||||
}
|
||||
|
||||
result.push(msg);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
@@ -65,6 +109,7 @@ pub async fn messages_to_openai_json_text_only(
|
||||
}
|
||||
|
||||
/// Build an OpenAI-compatible request body from a UnifiedRequest and pre-converted messages.
|
||||
/// Includes tools and tool_choice when present.
|
||||
pub fn build_openai_body(
|
||||
request: &UnifiedRequest,
|
||||
messages_json: Vec<serde_json::Value>,
|
||||
@@ -82,11 +127,18 @@ pub fn build_openai_body(
|
||||
if let Some(max_tokens) = request.max_tokens {
|
||||
body["max_tokens"] = serde_json::json!(max_tokens);
|
||||
}
|
||||
if let Some(tools) = &request.tools {
|
||||
body["tools"] = serde_json::json!(tools);
|
||||
}
|
||||
if let Some(tool_choice) = &request.tool_choice {
|
||||
body["tool_choice"] = serde_json::json!(tool_choice);
|
||||
}
|
||||
|
||||
body
|
||||
}
|
||||
|
||||
/// Parse an OpenAI-compatible chat completion response JSON into a ProviderResponse.
|
||||
/// Extracts tool_calls from the message when present.
|
||||
pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<ProviderResponse, AppError> {
|
||||
let choice = resp_json["choices"]
|
||||
.get(0)
|
||||
@@ -96,6 +148,11 @@ pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<Provide
|
||||
let content = message["content"].as_str().unwrap_or_default().to_string();
|
||||
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
|
||||
|
||||
// Parse tool_calls from the response message
|
||||
let tool_calls: Option<Vec<ToolCall>> = message
|
||||
.get("tool_calls")
|
||||
.and_then(|tc| serde_json::from_value(tc.clone()).ok());
|
||||
|
||||
let usage = &resp_json["usage"];
|
||||
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
@@ -104,6 +161,7 @@ pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<Provide
|
||||
Ok(ProviderResponse {
|
||||
content,
|
||||
reasoning_content,
|
||||
tool_calls,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
@@ -115,6 +173,7 @@ pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<Provide
|
||||
///
|
||||
/// The optional `reasoning_field` allows overriding the field name for
|
||||
/// reasoning content (e.g., "thought" for Ollama).
|
||||
/// Parses tool_calls deltas from streaming chunks when present.
|
||||
pub fn create_openai_stream(
|
||||
es: reqwest_eventsource::EventSource,
|
||||
model: String,
|
||||
@@ -143,10 +202,16 @@ pub fn create_openai_stream(
|
||||
.map(|s| s.to_string());
|
||||
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
|
||||
|
||||
// Parse tool_calls deltas from the stream chunk
|
||||
let tool_calls: Option<Vec<ToolCallDelta>> = delta
|
||||
.get("tool_calls")
|
||||
.and_then(|tc| serde_json::from_value(tc.clone()).ok());
|
||||
|
||||
yield ProviderStreamChunk {
|
||||
content,
|
||||
reasoning_content,
|
||||
finish_reason,
|
||||
tool_calls,
|
||||
model: model.clone(),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -50,6 +50,7 @@ pub trait Provider: Send + Sync {
|
||||
pub struct ProviderResponse {
|
||||
pub content: String,
|
||||
pub reasoning_content: Option<String>,
|
||||
pub tool_calls: Option<Vec<crate::models::ToolCall>>,
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
@@ -61,6 +62,7 @@ pub struct ProviderStreamChunk {
|
||||
pub content: String,
|
||||
pub reasoning_content: Option<String>,
|
||||
pub finish_reason: Option<String>,
|
||||
pub tool_calls: Option<Vec<crate::models::ToolCallDelta>>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ impl super::Provider for OpenAIProvider {
|
||||
}
|
||||
|
||||
fn supports_model(&self, model: &str) -> bool {
|
||||
model.starts_with("gpt-") || model.starts_with("o1-") || model.starts_with("o3-")
|
||||
model.starts_with("gpt-") || model.starts_with("o1-") || model.starts_with("o3-") || model.starts_with("o4-")
|
||||
}
|
||||
|
||||
fn supports_multimodal(&self) -> bool {
|
||||
|
||||
Reference in New Issue
Block a user