feat: add tool-calling passthrough for all providers
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled

Implement full OpenAI-compatible tool-calling support across the proxy,
enabling OpenCode to use llm-proxy as its sole LLM backend.

- Add 9 tool-calling types (Tool, FunctionDef, ToolChoice, ToolCall, etc.)
- Update ChatCompletionRequest/ChatMessage/ChatStreamDelta with tool fields
- Update UnifiedRequest/UnifiedMessage to carry tool data through the pipeline
- Shared helpers: messages_to_openai_json handles tool messages, build_openai_body
  includes tools/tool_choice, parse/stream extract tool_calls from responses
- Gemini: full OpenAI<->Gemini format translation (functionDeclarations,
  functionCall/functionResponse, synthetic call IDs, tool_config mapping)
- Gemini: extract duplicated message-conversion into shared convert_messages()
- Server: SSE streams include tool_calls deltas, finish_reason='tool_calls'
- AggregatingStream: accumulate tool call deltas across stream chunks
- OpenAI provider: add o4- prefix to supports_model()
This commit is contained in:
2026-03-02 09:40:57 -05:00
parent 942aa23f88
commit 9318336f62
8 changed files with 543 additions and 113 deletions

View File

@@ -325,11 +325,16 @@ pub(super) async fn handle_test_provider(
messages: vec![crate::models::UnifiedMessage { messages: vec![crate::models::UnifiedMessage {
role: "user".to_string(), role: "user".to_string(),
content: vec![crate::models::ContentPart::Text { text: "Hi".to_string() }], content: vec![crate::models::ContentPart::Text { text: "Hi".to_string() }],
tool_calls: None,
name: None,
tool_call_id: None,
}], }],
temperature: None, temperature: None,
max_tokens: Some(5), max_tokens: Some(5),
stream: false, stream: false,
has_images: false, has_images: false,
tools: None,
tool_choice: None,
}; };
match provider.chat_completion(test_request).await { match provider.chat_completion(test_request).await {

View File

@@ -1,4 +1,5 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value;
pub mod registry; pub mod registry;
@@ -14,16 +15,25 @@ pub struct ChatCompletionRequest {
pub max_tokens: Option<u32>, pub max_tokens: Option<u32>,
#[serde(default)] #[serde(default)]
pub stream: Option<bool>, pub stream: Option<bool>,
// Add other OpenAI-compatible fields as needed #[serde(default, skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage { pub struct ChatMessage {
pub role: String, // "system", "user", "assistant" pub role: String, // "system", "user", "assistant", "tool"
#[serde(flatten)] #[serde(flatten)]
pub content: MessageContent, pub content: MessageContent,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>, pub reasoning_content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -48,6 +58,78 @@ pub struct ImageUrl {
pub detail: Option<String>, pub detail: Option<String>,
} }
// ========== Tool-Calling Types ==========
/// An OpenAI-format tool definition supplied on a chat completion request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Tool kind; serialized as `"type"`. OpenAI currently defines only "function".
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The callable function's declaration (name, description, parameter schema).
    pub function: FunctionDef,
}
/// Declaration of a function the model may call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDef {
    /// Function name the model uses to reference this tool.
    pub name: String,
    /// Optional human-readable description shown to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Optional JSON Schema describing the function's arguments.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<Value>,
}
/// OpenAI `tool_choice` field: either a bare mode string or an object
/// forcing one specific function. `untagged` lets serde accept both shapes.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    /// Bare string form.
    Mode(String), // "auto", "none", "required"
    /// Object form selecting a specific function by name.
    Specific(ToolChoiceSpecific),
}
/// Object form of `tool_choice`: `{"type": "function", "function": {"name": ...}}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceSpecific {
    /// Serialized as `"type"`; expected to be "function".
    #[serde(rename = "type")]
    pub choice_type: String,
    /// The function the caller is forcing the model to invoke.
    pub function: ToolChoiceFunction,
}
/// Names the single function a `ToolChoice::Specific` selection targets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceFunction {
    pub name: String,
}
/// A complete (non-streaming) tool call emitted by the assistant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Call identifier; clients echo it back via `tool_call_id` on the result message.
    pub id: String,
    /// Serialized as `"type"`; "function" for function calls.
    #[serde(rename = "type")]
    pub call_type: String,
    /// The invoked function name and its arguments.
    pub function: FunctionCall,
}
/// A function invocation: the name plus its arguments.
/// Note: `arguments` is a JSON-encoded *string*, per the OpenAI wire format,
/// not a structured value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
    pub name: String,
    pub arguments: String,
}
/// Incremental tool-call fragment carried in streaming deltas.
/// Fragments sharing the same `index` belong to one logical call and are
/// accumulated by the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallDelta {
    /// Position of the call this fragment contributes to within the message.
    pub index: u32,
    /// Present only on the first fragment of a call.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub call_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<FunctionCallDelta>,
}
/// Partial function data inside a `ToolCallDelta`: name and/or a chunk of
/// the JSON-encoded arguments string; argument chunks are concatenated by
/// the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
// ========== OpenAI-compatible Response Structs ==========
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse { pub struct ChatCompletionResponse {
pub id: String, pub id: String,
@@ -96,6 +178,8 @@ pub struct ChatStreamDelta {
pub content: Option<String>, pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>, pub reasoning_content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallDelta>>,
} }
// ========== Unified Request Format (for internal use) ========== // ========== Unified Request Format (for internal use) ==========
@@ -109,12 +193,17 @@ pub struct UnifiedRequest {
pub max_tokens: Option<u32>, pub max_tokens: Option<u32>,
pub stream: bool, pub stream: bool,
pub has_images: bool, pub has_images: bool,
pub tools: Option<Vec<Tool>>,
pub tool_choice: Option<ToolChoice>,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct UnifiedMessage { pub struct UnifiedMessage {
pub role: String, pub role: String,
pub content: Vec<ContentPart>, pub content: Vec<ContentPart>,
pub tool_calls: Option<Vec<ToolCall>>,
pub name: Option<String>,
pub tool_call_id: Option<String>,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@@ -226,6 +315,9 @@ impl TryFrom<ChatCompletionRequest> for UnifiedRequest {
UnifiedMessage { UnifiedMessage {
role: msg.role, role: msg.role,
content, content,
tool_calls: msg.tool_calls,
name: msg.name,
tool_call_id: msg.tool_call_id,
} }
}) })
.collect(); .collect();
@@ -238,6 +330,8 @@ impl TryFrom<ChatCompletionRequest> for UnifiedRequest {
max_tokens: req.max_tokens, max_tokens: req.max_tokens,
stream: req.stream.unwrap_or(false), stream: req.stream.unwrap_or(false),
has_images, has_images,
tools: req.tools,
tool_choice: req.tool_choice,
}) })
} }
} }

View File

@@ -2,14 +2,28 @@ use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use futures::stream::BoxStream; use futures::stream::BoxStream;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value;
use uuid::Uuid;
use super::{ProviderResponse, ProviderStreamChunk}; use super::{ProviderResponse, ProviderStreamChunk};
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest}; use crate::{
config::AppConfig,
errors::AppError,
models::{ContentPart, FunctionCall, FunctionCallDelta, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest},
};
// ========== Gemini Request Structs ==========
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiRequest { struct GeminiRequest {
contents: Vec<GeminiContent>, contents: Vec<GeminiContent>,
#[serde(skip_serializing_if = "Option::is_none")]
generation_config: Option<GeminiGenerationConfig>, generation_config: Option<GeminiGenerationConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<GeminiTool>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_config: Option<GeminiToolConfig>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
@@ -19,11 +33,16 @@ struct GeminiContent {
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiPart { struct GeminiPart {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
text: Option<String>, text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
inline_data: Option<GeminiInlineData>, inline_data: Option<GeminiInlineData>,
#[serde(skip_serializing_if = "Option::is_none")]
function_call: Option<GeminiFunctionCall>,
#[serde(skip_serializing_if = "Option::is_none")]
function_response: Option<GeminiFunctionResponse>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
@@ -32,31 +51,85 @@ struct GeminiInlineData {
data: String, data: String,
} }
/// Gemini `functionCall` part. Unlike OpenAI's string-encoded `arguments`,
/// `args` is a structured JSON value.
#[derive(Debug, Serialize, Deserialize)]
struct GeminiFunctionCall {
    name: String,
    args: Value,
}
/// Gemini `functionResponse` part: the tool result sent back to the model,
/// keyed by the originating function's name.
#[derive(Debug, Serialize, Deserialize)]
struct GeminiFunctionResponse {
    name: String,
    response: Value,
}
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiGenerationConfig { struct GeminiGenerationConfig {
temperature: Option<f64>, temperature: Option<f64>,
max_output_tokens: Option<u32>, max_output_tokens: Option<u32>,
} }
// ========== Gemini Tool Structs ==========
/// Gemini tool wrapper: one object carrying all function declarations
/// (serialized as `functionDeclarations`).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiTool {
    function_declarations: Vec<GeminiFunctionDeclaration>,
}
/// One declared function inside a `GeminiTool`; mirrors OpenAI's `FunctionDef`.
#[derive(Debug, Serialize)]
struct GeminiFunctionDeclaration {
    name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    /// JSON Schema for the arguments, passed through unchanged from OpenAI format.
    #[serde(skip_serializing_if = "Option::is_none")]
    parameters: Option<Value>,
}
/// Gemini `toolConfig` wrapper (serialized as `functionCallingConfig` inside).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiToolConfig {
    function_calling_config: GeminiFunctionCallingConfig,
}
/// Gemini function-calling mode, mapped from OpenAI `tool_choice`.
#[derive(Debug, Serialize)]
struct GeminiFunctionCallingConfig {
    /// "AUTO", "NONE", or "ANY" (Gemini's spelling of "required"/forced).
    mode: String,
    /// When forcing a specific function, restricts the model to these names.
    #[serde(skip_serializing_if = "Option::is_none", rename = "allowedFunctionNames")]
    allowed_function_names: Option<Vec<String>>,
}
// ========== Gemini Response Structs ==========
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiCandidate { struct GeminiCandidate {
content: GeminiContent, content: GeminiContent,
_finish_reason: Option<String>, #[serde(default)]
finish_reason: Option<String>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiUsageMetadata { struct GeminiUsageMetadata {
#[serde(default)]
prompt_token_count: u32, prompt_token_count: u32,
#[serde(default)]
candidates_token_count: u32, candidates_token_count: u32,
#[serde(default)]
total_token_count: u32, total_token_count: u32,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiResponse { struct GeminiResponse {
candidates: Vec<GeminiCandidate>, candidates: Vec<GeminiCandidate>,
usage_metadata: Option<GeminiUsageMetadata>, usage_metadata: Option<GeminiUsageMetadata>,
} }
// ========== Provider Implementation ==========
pub struct GeminiProvider { pub struct GeminiProvider {
client: reqwest::Client, client: reqwest::Client,
config: crate::config::GeminiConfig, config: crate::config::GeminiConfig,
@@ -82,6 +155,209 @@ impl GeminiProvider {
pricing: app_config.pricing.gemini.clone(), pricing: app_config.pricing.gemini.clone(),
}) })
} }
/// Convert unified messages to Gemini content format.
/// Handles text, images, tool calls (assistant), and tool results.
///
/// Role mapping: "tool" → "user" with a functionResponse part; "assistant"
/// → "model"; anything else (including "system") → "user".
///
/// # Errors
/// Returns `AppError::ProviderError` if an image part fails base64 conversion.
///
/// NOTE(review): for tool-result messages only the FIRST content part is
/// forwarded; additional parts, if a caller ever sends them, are silently
/// dropped — confirm this is acceptable upstream.
async fn convert_messages(messages: Vec<UnifiedMessage>) -> Result<Vec<GeminiContent>, AppError> {
    let mut contents = Vec::with_capacity(messages.len());
    for msg in messages {
        // Tool-result messages → functionResponse parts under role "user"
        if msg.role == "tool" {
            let text_content = msg
                .content
                .first()
                .map(|p| match p {
                    ContentPart::Text { text } => text.clone(),
                    ContentPart::Image(_) => "[Image]".to_string(),
                })
                .unwrap_or_default();
            // Gemini keys the response by function name, not call id.
            let name = msg.name.unwrap_or_default();
            // Parse the content as JSON if possible, otherwise wrap as string
            let response_value = serde_json::from_str::<Value>(&text_content)
                .unwrap_or_else(|_| serde_json::json!({ "result": text_content }));
            contents.push(GeminiContent {
                parts: vec![GeminiPart {
                    text: None,
                    inline_data: None,
                    function_call: None,
                    function_response: Some(GeminiFunctionResponse {
                        name,
                        response: response_value,
                    }),
                }],
                role: "user".to_string(),
            });
            continue;
        }
        // Assistant messages with tool_calls → functionCall parts
        if msg.role == "assistant" {
            if let Some(tool_calls) = &msg.tool_calls {
                let mut parts = Vec::new();
                // Include text content if present
                for p in &msg.content {
                    if let ContentPart::Text { text } = p {
                        if !text.is_empty() {
                            parts.push(GeminiPart {
                                text: Some(text.clone()),
                                inline_data: None,
                                function_call: None,
                                function_response: None,
                            });
                        }
                    }
                }
                // Convert each tool call to a functionCall part; OpenAI's
                // string-encoded arguments become a structured JSON value
                // (falling back to {} on malformed JSON).
                for tc in tool_calls {
                    let args = serde_json::from_str::<Value>(&tc.function.arguments)
                        .unwrap_or_else(|_| serde_json::json!({}));
                    parts.push(GeminiPart {
                        text: None,
                        inline_data: None,
                        function_call: Some(GeminiFunctionCall {
                            name: tc.function.name.clone(),
                            args,
                        }),
                        function_response: None,
                    });
                }
                contents.push(GeminiContent {
                    parts,
                    role: "model".to_string(),
                });
                continue;
            }
        }
        // Regular text/image messages
        let mut parts = Vec::with_capacity(msg.content.len());
        for part in msg.content {
            match part {
                ContentPart::Text { text } => {
                    parts.push(GeminiPart {
                        text: Some(text),
                        inline_data: None,
                        function_call: None,
                        function_response: None,
                    });
                }
                ContentPart::Image(image_input) => {
                    // May fetch/encode the image; awaited properly (no block_on).
                    let (base64_data, mime_type) = image_input
                        .to_base64()
                        .await
                        .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
                    parts.push(GeminiPart {
                        text: None,
                        inline_data: Some(GeminiInlineData {
                            mime_type,
                            data: base64_data,
                        }),
                        function_call: None,
                        function_response: None,
                    });
                }
            }
        }
        let role = match msg.role.as_str() {
            "assistant" => "model".to_string(),
            _ => "user".to_string(),
        };
        contents.push(GeminiContent { parts, role });
    }
    Ok(contents)
}
/// Translate the request's OpenAI-style tool list into Gemini's shape:
/// a single `GeminiTool` wrapping every function declaration.
/// Returns `None` when the request carries no tools.
fn convert_tools(request: &UnifiedRequest) -> Option<Vec<GeminiTool>> {
    let tools = request.tools.as_ref()?;
    let mut declarations = Vec::with_capacity(tools.len());
    for tool in tools {
        declarations.push(GeminiFunctionDeclaration {
            name: tool.function.name.clone(),
            description: tool.function.description.clone(),
            parameters: tool.function.parameters.clone(),
        });
    }
    // Gemini expects one tool object containing all declarations.
    Some(vec![GeminiTool {
        function_declarations: declarations,
    }])
}
/// Convert OpenAI tool_choice to Gemini tool_config.
///
/// Mode mapping: "auto" → AUTO, "none" → NONE, "required" → ANY
/// (Gemini's name for forced calling); unknown strings fall back to AUTO.
/// A specific-function choice becomes ANY restricted to that one name.
/// Returns `None` when the request has no tool_choice.
fn convert_tool_config(request: &UnifiedRequest) -> Option<GeminiToolConfig> {
    request.tool_choice.as_ref().map(|tc| {
        let (mode, allowed_names) = match tc {
            crate::models::ToolChoice::Mode(mode) => {
                let gemini_mode = match mode.as_str() {
                    "auto" => "AUTO",
                    "none" => "NONE",
                    "required" => "ANY",
                    // Unrecognized mode: be permissive rather than erroring.
                    _ => "AUTO",
                };
                (gemini_mode.to_string(), None)
            }
            crate::models::ToolChoice::Specific(specific) => {
                ("ANY".to_string(), Some(vec![specific.function.name.clone()]))
            }
        };
        GeminiToolConfig {
            function_calling_config: GeminiFunctionCallingConfig {
                mode,
                allowed_function_names: allowed_names,
            },
        }
    })
}
/// Collect every `functionCall` part of a Gemini response into OpenAI-format
/// `ToolCall`s. Gemini supplies no call ids, so a synthetic `call_<uuid>` id
/// is minted for each; structured `args` are re-encoded as a JSON string.
/// Returns `None` when no function calls are present.
fn extract_tool_calls(parts: &[GeminiPart]) -> Option<Vec<ToolCall>> {
    let mut calls = Vec::new();
    for part in parts {
        if let Some(fc) = &part.function_call {
            let arguments =
                serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string());
            calls.push(ToolCall {
                id: format!("call_{}", Uuid::new_v4().simple()),
                call_type: "function".to_string(),
                function: FunctionCall {
                    name: fc.name.clone(),
                    arguments,
                },
            });
        }
    }
    if calls.is_empty() {
        None
    } else {
        Some(calls)
    }
}
/// Extract tool call deltas from Gemini response parts for streaming.
///
/// Each functionCall part becomes a self-contained delta (id, type, name and
/// full arguments in one fragment), indexed by its position within this chunk.
/// Returns `None` when the chunk has no function calls.
///
/// NOTE(review): a fresh UUID id is minted per chunk, so if Gemini ever split
/// one logical call across chunks the ids would not correlate — the stream
/// aggregator presumably matches on `index`; confirm.
fn extract_tool_call_deltas(parts: &[GeminiPart]) -> Option<Vec<ToolCallDelta>> {
    let deltas: Vec<ToolCallDelta> = parts
        .iter()
        .filter_map(|p| p.function_call.as_ref())
        .enumerate()
        .map(|(i, fc)| ToolCallDelta {
            index: i as u32,
            id: Some(format!("call_{}", Uuid::new_v4().simple())),
            call_type: Some("function".to_string()),
            function: Some(FunctionCallDelta {
                name: Some(fc.name.clone()),
                // Re-encode structured args as the OpenAI string form.
                arguments: Some(serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string())),
            }),
        })
        .collect();
    if deltas.is_empty() { None } else { Some(deltas) }
}
} }
#[async_trait] #[async_trait]
@@ -99,51 +375,15 @@ impl super::Provider for GeminiProvider {
} }
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> { async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
// Convert UnifiedRequest to Gemini request let model = request.model.clone();
let mut contents = Vec::with_capacity(request.messages.len()); let tools = Self::convert_tools(&request);
let tool_config = Self::convert_tool_config(&request);
for msg in request.messages { let contents = Self::convert_messages(request.messages).await?;
let mut parts = Vec::with_capacity(msg.content.len());
for part in msg.content {
match part {
crate::models::ContentPart::Text { text } => {
parts.push(GeminiPart {
text: Some(text),
inline_data: None,
});
}
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input
.to_base64()
.await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
parts.push(GeminiPart {
text: None,
inline_data: Some(GeminiInlineData {
mime_type,
data: base64_data,
}),
});
}
}
}
// Map role: "user" -> "user", "assistant" -> "model", "system" -> "user"
let role = match msg.role.as_str() {
"assistant" => "model".to_string(),
_ => "user".to_string(),
};
contents.push(GeminiContent { parts, role });
}
if contents.is_empty() { if contents.is_empty() {
return Err(AppError::ProviderError("No valid text messages to send".to_string())); return Err(AppError::ProviderError("No valid messages to send".to_string()));
} }
// Build generation config
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() { let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
Some(GeminiGenerationConfig { Some(GeminiGenerationConfig {
temperature: request.temperature, temperature: request.temperature,
@@ -156,12 +396,12 @@ impl super::Provider for GeminiProvider {
let gemini_request = GeminiRequest { let gemini_request = GeminiRequest {
contents, contents,
generation_config, generation_config,
tools,
tool_config,
}; };
// Build URL let url = format!("{}/models/{}:generateContent", self.config.base_url, model);
let url = format!("{}/models/{}:generateContent", self.config.base_url, request.model,);
// Send request
let response = self let response = self
.client .client
.post(&url) .post(&url)
@@ -171,7 +411,6 @@ impl super::Provider for GeminiProvider {
.await .await
.map_err(|e| AppError::ProviderError(format!("HTTP request failed: {}", e)))?; .map_err(|e| AppError::ProviderError(format!("HTTP request failed: {}", e)))?;
// Check status
let status = response.status(); let status = response.status();
if !status.is_success() { if !status.is_success() {
let error_text = response.text().await.unwrap_or_default(); let error_text = response.text().await.unwrap_or_default();
@@ -186,15 +425,16 @@ impl super::Provider for GeminiProvider {
.await .await
.map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?; .map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?;
// Extract content from first candidate let candidate = gemini_response.candidates.first();
let content = gemini_response
.candidates // Extract text content (may be absent if only function calls)
.first() let content = candidate
.and_then(|c| c.content.parts.first()) .and_then(|c| c.content.parts.iter().find_map(|p| p.text.clone()))
.and_then(|p| p.text.clone())
.unwrap_or_default(); .unwrap_or_default();
// Extract token usage // Extract function calls → OpenAI tool_calls
let tool_calls = candidate.and_then(|c| Self::extract_tool_calls(&c.content.parts));
let prompt_tokens = gemini_response let prompt_tokens = gemini_response
.usage_metadata .usage_metadata
.as_ref() .as_ref()
@@ -213,11 +453,12 @@ impl super::Provider for GeminiProvider {
Ok(ProviderResponse { Ok(ProviderResponse {
content, content,
reasoning_content: None, // Gemini doesn't use this field name reasoning_content: None,
tool_calls,
prompt_tokens, prompt_tokens,
completion_tokens, completion_tokens,
total_tokens, total_tokens,
model: request.model, model,
}) })
} }
@@ -247,47 +488,11 @@ impl super::Provider for GeminiProvider {
&self, &self,
request: UnifiedRequest, request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> { ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
// Convert UnifiedRequest to Gemini request let model = request.model.clone();
let mut contents = Vec::with_capacity(request.messages.len()); let tools = Self::convert_tools(&request);
let tool_config = Self::convert_tool_config(&request);
let contents = Self::convert_messages(request.messages).await?;
for msg in request.messages {
let mut parts = Vec::with_capacity(msg.content.len());
for part in msg.content {
match part {
crate::models::ContentPart::Text { text } => {
parts.push(GeminiPart {
text: Some(text),
inline_data: None,
});
}
crate::models::ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input
.to_base64()
.await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
parts.push(GeminiPart {
text: None,
inline_data: Some(GeminiInlineData {
mime_type,
data: base64_data,
}),
});
}
}
}
// Map role
let role = match msg.role.as_str() {
"assistant" => "model".to_string(),
_ => "user".to_string(),
};
contents.push(GeminiContent { parts, role });
}
// Build generation config
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() { let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
Some(GeminiGenerationConfig { Some(GeminiGenerationConfig {
temperature: request.temperature, temperature: request.temperature,
@@ -300,15 +505,15 @@ impl super::Provider for GeminiProvider {
let gemini_request = GeminiRequest { let gemini_request = GeminiRequest {
contents, contents,
generation_config, generation_config,
tools,
tool_config,
}; };
// Build URL for streaming
let url = format!( let url = format!(
"{}/models/{}:streamGenerateContent?alt=sse", "{}/models/{}:streamGenerateContent?alt=sse",
self.config.base_url, request.model, self.config.base_url, model,
); );
// Create eventsource stream
use futures::StreamExt; use futures::StreamExt;
use reqwest_eventsource::{Event, EventSource}; use reqwest_eventsource::{Event, EventSource};
@@ -320,8 +525,6 @@ impl super::Provider for GeminiProvider {
) )
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let model = request.model.clone();
let stream = async_stream::try_stream! { let stream = async_stream::try_stream! {
let mut es = es; let mut es = es;
while let Some(event) = es.next().await { while let Some(event) = es.next().await {
@@ -331,14 +534,28 @@ impl super::Provider for GeminiProvider {
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?; .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(candidate) = gemini_response.candidates.first() { if let Some(candidate) = gemini_response.candidates.first() {
let content = candidate.content.parts.first() let content = candidate
.and_then(|p| p.text.clone()) .content
.parts
.iter()
.find_map(|p| p.text.clone())
.unwrap_or_default(); .unwrap_or_default();
let tool_calls = Self::extract_tool_call_deltas(&candidate.content.parts);
// Determine finish_reason
let finish_reason = candidate.finish_reason.as_ref().map(|fr| {
match fr.as_str() {
"STOP" => "stop".to_string(),
_ => fr.to_lowercase(),
}
});
yield ProviderStreamChunk { yield ProviderStreamChunk {
content, content,
reasoning_content: None, reasoning_content: None,
finish_reason: None, // Will be set in the last chunk finish_reason,
tool_calls,
model: model.clone(), model: model.clone(),
}; };
} }

View File

@@ -1,6 +1,6 @@
use super::{ProviderResponse, ProviderStreamChunk}; use super::{ProviderResponse, ProviderStreamChunk};
use crate::errors::AppError; use crate::errors::AppError;
use crate::models::{ContentPart, UnifiedMessage, UnifiedRequest}; use crate::models::{ContentPart, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest};
use futures::stream::{BoxStream, StreamExt}; use futures::stream::{BoxStream, StreamExt};
use serde_json::Value; use serde_json::Value;
@@ -8,9 +8,37 @@ use serde_json::Value;
/// ///
/// This avoids the deadlock caused by `futures::executor::block_on` inside a /// This avoids the deadlock caused by `futures::executor::block_on` inside a
/// Tokio async context. All image base64 conversions are awaited properly. /// Tokio async context. All image base64 conversions are awaited properly.
/// Handles tool-calling messages: assistant messages with tool_calls, and
/// tool-role messages with tool_call_id/name.
pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result<Vec<serde_json::Value>, AppError> { pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result<Vec<serde_json::Value>, AppError> {
let mut result = Vec::new(); let mut result = Vec::new();
for m in messages { for m in messages {
// Tool-role messages: { role: "tool", content: "...", tool_call_id: "...", name: "..." }
if m.role == "tool" {
let text_content = m
.content
.first()
.map(|p| match p {
ContentPart::Text { text } => text.clone(),
ContentPart::Image(_) => "[Image]".to_string(),
})
.unwrap_or_default();
let mut msg = serde_json::json!({
"role": "tool",
"content": text_content
});
if let Some(tool_call_id) = &m.tool_call_id {
msg["tool_call_id"] = serde_json::json!(tool_call_id);
}
if let Some(name) = &m.name {
msg["name"] = serde_json::json!(name);
}
result.push(msg);
continue;
}
// Build content parts for non-tool messages
let mut parts = Vec::new(); let mut parts = Vec::new();
for p in &m.content { for p in &m.content {
match p { match p {
@@ -29,10 +57,26 @@ pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result<Vec<
} }
} }
} }
result.push(serde_json::json!({
"role": m.role, let mut msg = serde_json::json!({ "role": m.role });
"content": parts
})); // For assistant messages with tool_calls, content can be null
if let Some(tool_calls) = &m.tool_calls {
if parts.is_empty() {
msg["content"] = serde_json::Value::Null;
} else {
msg["content"] = serde_json::json!(parts);
}
msg["tool_calls"] = serde_json::json!(tool_calls);
} else {
msg["content"] = serde_json::json!(parts);
}
if let Some(name) = &m.name {
msg["name"] = serde_json::json!(name);
}
result.push(msg);
} }
Ok(result) Ok(result)
} }
@@ -65,6 +109,7 @@ pub async fn messages_to_openai_json_text_only(
} }
/// Build an OpenAI-compatible request body from a UnifiedRequest and pre-converted messages. /// Build an OpenAI-compatible request body from a UnifiedRequest and pre-converted messages.
/// Includes tools and tool_choice when present.
pub fn build_openai_body( pub fn build_openai_body(
request: &UnifiedRequest, request: &UnifiedRequest,
messages_json: Vec<serde_json::Value>, messages_json: Vec<serde_json::Value>,
@@ -82,11 +127,18 @@ pub fn build_openai_body(
if let Some(max_tokens) = request.max_tokens { if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens); body["max_tokens"] = serde_json::json!(max_tokens);
} }
if let Some(tools) = &request.tools {
body["tools"] = serde_json::json!(tools);
}
if let Some(tool_choice) = &request.tool_choice {
body["tool_choice"] = serde_json::json!(tool_choice);
}
body body
} }
/// Parse an OpenAI-compatible chat completion response JSON into a ProviderResponse. /// Parse an OpenAI-compatible chat completion response JSON into a ProviderResponse.
/// Extracts tool_calls from the message when present.
pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<ProviderResponse, AppError> { pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<ProviderResponse, AppError> {
let choice = resp_json["choices"] let choice = resp_json["choices"]
.get(0) .get(0)
@@ -96,6 +148,11 @@ pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<Provide
let content = message["content"].as_str().unwrap_or_default().to_string(); let content = message["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string()); let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
// Parse tool_calls from the response message
let tool_calls: Option<Vec<ToolCall>> = message
.get("tool_calls")
.and_then(|tc| serde_json::from_value(tc.clone()).ok());
let usage = &resp_json["usage"]; let usage = &resp_json["usage"];
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32; let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32; let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
@@ -104,6 +161,7 @@ pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<Provide
Ok(ProviderResponse { Ok(ProviderResponse {
content, content,
reasoning_content, reasoning_content,
tool_calls,
prompt_tokens, prompt_tokens,
completion_tokens, completion_tokens,
total_tokens, total_tokens,
@@ -115,6 +173,7 @@ pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<Provide
/// ///
/// The optional `reasoning_field` allows overriding the field name for /// The optional `reasoning_field` allows overriding the field name for
/// reasoning content (e.g., "thought" for Ollama). /// reasoning content (e.g., "thought" for Ollama).
/// Parses tool_calls deltas from streaming chunks when present.
pub fn create_openai_stream( pub fn create_openai_stream(
es: reqwest_eventsource::EventSource, es: reqwest_eventsource::EventSource,
model: String, model: String,
@@ -143,10 +202,16 @@ pub fn create_openai_stream(
.map(|s| s.to_string()); .map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string()); let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
// Parse tool_calls deltas from the stream chunk
let tool_calls: Option<Vec<ToolCallDelta>> = delta
.get("tool_calls")
.and_then(|tc| serde_json::from_value(tc.clone()).ok());
yield ProviderStreamChunk { yield ProviderStreamChunk {
content, content,
reasoning_content, reasoning_content,
finish_reason, finish_reason,
tool_calls,
model: model.clone(), model: model.clone(),
}; };
} }

View File

@@ -50,6 +50,7 @@ pub trait Provider: Send + Sync {
pub struct ProviderResponse { pub struct ProviderResponse {
pub content: String, pub content: String,
pub reasoning_content: Option<String>, pub reasoning_content: Option<String>,
pub tool_calls: Option<Vec<crate::models::ToolCall>>,
pub prompt_tokens: u32, pub prompt_tokens: u32,
pub completion_tokens: u32, pub completion_tokens: u32,
pub total_tokens: u32, pub total_tokens: u32,
@@ -61,6 +62,7 @@ pub struct ProviderStreamChunk {
pub content: String, pub content: String,
pub reasoning_content: Option<String>, pub reasoning_content: Option<String>,
pub finish_reason: Option<String>, pub finish_reason: Option<String>,
pub tool_calls: Option<Vec<crate::models::ToolCallDelta>>,
pub model: String, pub model: String,
} }

View File

@@ -36,7 +36,7 @@ impl super::Provider for OpenAIProvider {
} }
fn supports_model(&self, model: &str) -> bool { fn supports_model(&self, model: &str) -> bool {
model.starts_with("gpt-") || model.starts_with("o1-") || model.starts_with("o3-") model.starts_with("gpt-") || model.starts_with("o1-") || model.starts_with("o3-") || model.starts_with("o4-")
} }
fn supports_multimodal(&self) -> bool { fn supports_multimodal(&self) -> bool {

View File

@@ -174,6 +174,7 @@ async fn chat_completions(
role: None, role: None,
content: Some(chunk.content), content: Some(chunk.content),
reasoning_content: chunk.reasoning_content, reasoning_content: chunk.reasoning_content,
tool_calls: chunk.tool_calls,
}, },
finish_reason: chunk.finish_reason, finish_reason: chunk.finish_reason,
}], }],
@@ -248,6 +249,12 @@ async fn chat_completions(
.await; .await;
// Convert ProviderResponse to ChatCompletionResponse // Convert ProviderResponse to ChatCompletionResponse
let finish_reason = if response.tool_calls.is_some() {
"tool_calls".to_string()
} else {
"stop".to_string()
};
let chat_response = ChatCompletionResponse { let chat_response = ChatCompletionResponse {
id: format!("chatcmpl-{}", Uuid::new_v4()), id: format!("chatcmpl-{}", Uuid::new_v4()),
object: "chat.completion".to_string(), object: "chat.completion".to_string(),
@@ -261,8 +268,11 @@ async fn chat_completions(
content: response.content, content: response.content,
}, },
reasoning_content: response.reasoning_content, reasoning_content: response.reasoning_content,
tool_calls: response.tool_calls,
name: None,
tool_call_id: None,
}, },
finish_reason: Some("stop".to_string()), finish_reason: Some(finish_reason),
}], }],
usage: Some(Usage { usage: Some(Usage {
prompt_tokens: response.prompt_tokens, prompt_tokens: response.prompt_tokens,

View File

@@ -1,6 +1,7 @@
use crate::client::ClientManager; use crate::client::ClientManager;
use crate::errors::AppError; use crate::errors::AppError;
use crate::logging::{RequestLog, RequestLogger}; use crate::logging::{RequestLog, RequestLogger};
use crate::models::ToolCall;
use crate::providers::{Provider, ProviderStreamChunk}; use crate::providers::{Provider, ProviderStreamChunk};
use crate::utils::tokens::estimate_completion_tokens; use crate::utils::tokens::estimate_completion_tokens;
use futures::stream::Stream; use futures::stream::Stream;
@@ -31,6 +32,7 @@ pub struct AggregatingStream<S> {
has_images: bool, has_images: bool,
accumulated_content: String, accumulated_content: String,
accumulated_reasoning: String, accumulated_reasoning: String,
accumulated_tool_calls: Vec<ToolCall>,
logger: Arc<RequestLogger>, logger: Arc<RequestLogger>,
client_manager: Arc<ClientManager>, client_manager: Arc<ClientManager>,
model_registry: Arc<crate::models::registry::ModelRegistry>, model_registry: Arc<crate::models::registry::ModelRegistry>,
@@ -53,6 +55,7 @@ where
has_images: config.has_images, has_images: config.has_images,
accumulated_content: String::new(), accumulated_content: String::new(),
accumulated_reasoning: String::new(), accumulated_reasoning: String::new(),
accumulated_tool_calls: Vec::new(),
logger: config.logger, logger: config.logger,
client_manager: config.client_manager, client_manager: config.client_manager,
model_registry: config.model_registry, model_registry: config.model_registry,
@@ -153,6 +156,38 @@ where
if let Some(reasoning) = &chunk.reasoning_content { if let Some(reasoning) = &chunk.reasoning_content {
self.accumulated_reasoning.push_str(reasoning); self.accumulated_reasoning.push_str(reasoning);
} }
// Accumulate tool call deltas into complete tool calls
if let Some(deltas) = &chunk.tool_calls {
for delta in deltas {
let idx = delta.index as usize;
// Grow the accumulated_tool_calls vec if needed
while self.accumulated_tool_calls.len() <= idx {
self.accumulated_tool_calls.push(ToolCall {
id: String::new(),
call_type: "function".to_string(),
function: crate::models::FunctionCall {
name: String::new(),
arguments: String::new(),
},
});
}
let tc = &mut self.accumulated_tool_calls[idx];
if let Some(id) = &delta.id {
tc.id.clone_from(id);
}
if let Some(ct) = &delta.call_type {
tc.call_type.clone_from(ct);
}
if let Some(f) = &delta.function {
if let Some(name) = &f.name {
tc.function.name.push_str(name);
}
if let Some(args) = &f.arguments {
tc.function.arguments.push_str(args);
}
}
}
}
} }
Poll::Ready(Some(Err(_))) => { Poll::Ready(Some(Err(_))) => {
// If there's an error, we might still want to log what we got so far? // If there's an error, we might still want to log what we got so far?
@@ -217,12 +252,14 @@ mod tests {
content: "Hello".to_string(), content: "Hello".to_string(),
reasoning_content: None, reasoning_content: None,
finish_reason: None, finish_reason: None,
tool_calls: None,
model: "test".to_string(), model: "test".to_string(),
}), }),
Ok(ProviderStreamChunk { Ok(ProviderStreamChunk {
content: " World".to_string(), content: " World".to_string(),
reasoning_content: None, reasoning_content: None,
finish_reason: Some("stop".to_string()), finish_reason: Some("stop".to_string()),
tool_calls: None,
model: "test".to_string(), model: "test".to_string(),
}), }),
]; ];