diff --git a/src/dashboard/providers.rs b/src/dashboard/providers.rs index 7bee91ee..5d9eb86c 100644 --- a/src/dashboard/providers.rs +++ b/src/dashboard/providers.rs @@ -325,11 +325,16 @@ pub(super) async fn handle_test_provider( messages: vec![crate::models::UnifiedMessage { role: "user".to_string(), content: vec![crate::models::ContentPart::Text { text: "Hi".to_string() }], + tool_calls: None, + name: None, + tool_call_id: None, }], temperature: None, max_tokens: Some(5), stream: false, has_images: false, + tools: None, + tool_choice: None, }; match provider.chat_completion(test_request).await { diff --git a/src/models/mod.rs b/src/models/mod.rs index 29d56c2f..d632d9d7 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,4 +1,5 @@ use serde::{Deserialize, Serialize}; +use serde_json::Value; pub mod registry; @@ -14,16 +15,25 @@ pub struct ChatCompletionRequest { pub max_tokens: Option, #[serde(default)] pub stream: Option, - // Add other OpenAI-compatible fields as needed + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatMessage { - pub role: String, // "system", "user", "assistant" + pub role: String, // "system", "user", "assistant", "tool" #[serde(flatten)] pub content: MessageContent, #[serde(skip_serializing_if = "Option::is_none")] pub reasoning_content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -48,6 +58,78 @@ pub struct ImageUrl { pub detail: Option, } +// ========== Tool-Calling Types ========== + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Tool { + #[serde(rename = "type")] + pub tool_type: String, + pub function: FunctionDef, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FunctionDef { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub parameters: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ToolChoice { + Mode(String), // "auto", "none", "required" + Specific(ToolChoiceSpecific), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolChoiceSpecific { + #[serde(rename = "type")] + pub choice_type: String, + pub function: ToolChoiceFunction, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolChoiceFunction { + pub name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + pub id: String, + #[serde(rename = "type")] + pub call_type: String, + pub function: FunctionCall, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FunctionCall { + pub name: String, + pub arguments: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCallDelta { + pub index: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub call_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub function: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FunctionCallDelta { + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, +} + +// ========== OpenAI-compatible Response Structs ========== + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatCompletionResponse { pub id: String, @@ -96,6 +178,8 @@ pub struct ChatStreamDelta { pub content: Option, #[serde(skip_serializing_if = "Option::is_none")] pub reasoning_content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, } // ========== Unified Request Format (for internal use) ========== @@ -109,12 +193,17 @@ pub struct UnifiedRequest { pub max_tokens: Option, pub stream: bool, pub has_images: bool, + pub tools: Option>, + pub tool_choice: Option, } #[derive(Debug, Clone)] pub struct UnifiedMessage { pub role: String, pub content: Vec, + pub tool_calls: Option>, + pub name: Option, + pub tool_call_id: Option, } #[derive(Debug, Clone)] @@ -226,6 +315,9 @@ impl TryFrom for UnifiedRequest { UnifiedMessage { role: msg.role, content, + tool_calls: msg.tool_calls, + name: msg.name, + tool_call_id: msg.tool_call_id, } }) .collect(); @@ -238,6 +330,8 @@ impl TryFrom for UnifiedRequest { max_tokens: req.max_tokens, stream: req.stream.unwrap_or(false), has_images, + tools: req.tools, + tool_choice: req.tool_choice, }) } } diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs index 3fc43084..7939d078 100644 --- a/src/providers/gemini.rs +++ b/src/providers/gemini.rs @@ -2,14 +2,28 @@ use anyhow::Result; use async_trait::async_trait; use futures::stream::BoxStream; use serde::{Deserialize, Serialize}; +use serde_json::Value; +use uuid::Uuid; use super::{ProviderResponse, ProviderStreamChunk}; -use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest}; +use crate::{ + config::AppConfig, + errors::AppError, + models::{ContentPart, FunctionCall, FunctionCallDelta, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest}, +}; + +// ========== Gemini Request Structs ========== #[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] struct GeminiRequest { contents: Vec, + #[serde(skip_serializing_if = "Option::is_none")] generation_config: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_config: Option, } #[derive(Debug, Serialize, Deserialize)] @@ -19,11 +33,16 @@ struct GeminiContent { } #[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] struct GeminiPart { #[serde(skip_serializing_if = "Option::is_none")] text: Option, #[serde(skip_serializing_if = "Option::is_none")] inline_data: Option, + #[serde(skip_serializing_if = "Option::is_none")] + function_call: Option, + #[serde(skip_serializing_if = "Option::is_none")] + function_response: Option, } #[derive(Debug, Serialize, Deserialize)] @@ -32,31 +51,85 @@ struct GeminiInlineData { data: String, } +#[derive(Debug, Serialize, Deserialize)] +struct GeminiFunctionCall { + name: String, + args: Value, +} + +#[derive(Debug, Serialize, Deserialize)] +struct GeminiFunctionResponse { + name: String, + response: Value, +} + #[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] struct GeminiGenerationConfig { temperature: Option, max_output_tokens: Option, } +// ========== Gemini Tool Structs ========== + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +struct GeminiTool { + function_declarations: Vec, +} + +#[derive(Debug, Serialize)] +struct GeminiFunctionDeclaration { + name: String, + #[serde(skip_serializing_if = "Option::is_none")] + description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + parameters: Option, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +struct GeminiToolConfig { + function_calling_config: GeminiFunctionCallingConfig, +} + +#[derive(Debug, Serialize)] +struct GeminiFunctionCallingConfig { + mode: String, + #[serde(skip_serializing_if = "Option::is_none", rename = "allowedFunctionNames")] + allowed_function_names: Option>, +} + +// ========== Gemini Response Structs ========== + #[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] struct GeminiCandidate { content: GeminiContent, - _finish_reason: Option, + #[serde(default)] + finish_reason: Option, } #[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] struct GeminiUsageMetadata { + #[serde(default)] prompt_token_count: u32, + #[serde(default)] candidates_token_count: u32, + #[serde(default)] total_token_count: u32, } #[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] struct GeminiResponse { candidates: Vec, usage_metadata: Option, } +// ========== Provider Implementation ========== + pub struct GeminiProvider { client: reqwest::Client, config: crate::config::GeminiConfig, @@ -82,6 +155,209 @@ impl GeminiProvider { pricing: app_config.pricing.gemini.clone(), }) } + + /// Convert unified messages to Gemini content format. + /// Handles text, images, tool calls (assistant), and tool results. + async fn convert_messages(messages: Vec) -> Result, AppError> { + let mut contents = Vec::with_capacity(messages.len()); + + for msg in messages { + // Tool-result messages → functionResponse parts under role "user" + if msg.role == "tool" { + let text_content = msg + .content + .first() + .map(|p| match p { + ContentPart::Text { text } => text.clone(), + ContentPart::Image(_) => "[Image]".to_string(), + }) + .unwrap_or_default(); + + let name = msg.name.unwrap_or_default(); + + // Parse the content as JSON if possible, otherwise wrap as string + let response_value = serde_json::from_str::(&text_content) + .unwrap_or_else(|_| serde_json::json!({ "result": text_content })); + + contents.push(GeminiContent { + parts: vec![GeminiPart { + text: None, + inline_data: None, + function_call: None, + function_response: Some(GeminiFunctionResponse { + name, + response: response_value, + }), + }], + role: "user".to_string(), + }); + continue; + } + + // Assistant messages with tool_calls → functionCall parts + if msg.role == "assistant" { + if let Some(tool_calls) = &msg.tool_calls { + let mut parts = Vec::new(); + + // Include text content if present + for p in &msg.content { + if let ContentPart::Text { text } = p { + if !text.is_empty() { + parts.push(GeminiPart { + text: Some(text.clone()), + inline_data: None, + function_call: None, + function_response: None, + }); + } + } + } + + // Convert each tool call to a functionCall part + for tc in tool_calls { + let args = serde_json::from_str::(&tc.function.arguments) + .unwrap_or_else(|_| serde_json::json!({})); + parts.push(GeminiPart { + text: None, + inline_data: None, + function_call: Some(GeminiFunctionCall { + name: tc.function.name.clone(), + args, + }), + function_response: None, + }); + } + + contents.push(GeminiContent { + parts, + role: "model".to_string(), + }); + continue; + } + } + + // Regular text/image messages + let mut parts = Vec::with_capacity(msg.content.len()); + for part in msg.content { + match part { + ContentPart::Text { text } => { + parts.push(GeminiPart { + text: Some(text), + inline_data: None, + function_call: None, + function_response: None, + }); + } + ContentPart::Image(image_input) => { + let (base64_data, mime_type) = image_input + .to_base64() + .await + .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?; + + parts.push(GeminiPart { + text: None, + inline_data: Some(GeminiInlineData { + mime_type, + data: base64_data, + }), + function_call: None, + function_response: None, + }); + } + } + } + + let role = match msg.role.as_str() { + "assistant" => "model".to_string(), + _ => "user".to_string(), + }; + + contents.push(GeminiContent { parts, role }); + } + + Ok(contents) + } + + /// Convert OpenAI tools to Gemini function declarations. + fn convert_tools(request: &UnifiedRequest) -> Option> { + request.tools.as_ref().map(|tools| { + let declarations: Vec = tools + .iter() + .map(|t| GeminiFunctionDeclaration { + name: t.function.name.clone(), + description: t.function.description.clone(), + parameters: t.function.parameters.clone(), + }) + .collect(); + vec![GeminiTool { + function_declarations: declarations, + }] + }) + } + + /// Convert OpenAI tool_choice to Gemini tool_config. + fn convert_tool_config(request: &UnifiedRequest) -> Option { + request.tool_choice.as_ref().map(|tc| { + let (mode, allowed_names) = match tc { + crate::models::ToolChoice::Mode(mode) => { + let gemini_mode = match mode.as_str() { + "auto" => "AUTO", + "none" => "NONE", + "required" => "ANY", + _ => "AUTO", + }; + (gemini_mode.to_string(), None) + } + crate::models::ToolChoice::Specific(specific) => { + ("ANY".to_string(), Some(vec![specific.function.name.clone()])) + } + }; + GeminiToolConfig { + function_calling_config: GeminiFunctionCallingConfig { + mode, + allowed_function_names: allowed_names, + }, + } + }) + } + + /// Extract tool calls from Gemini response parts into OpenAI-format ToolCalls. + fn extract_tool_calls(parts: &[GeminiPart]) -> Option> { + let calls: Vec = parts + .iter() + .filter_map(|p| p.function_call.as_ref()) + .map(|fc| ToolCall { + id: format!("call_{}", Uuid::new_v4().simple()), + call_type: "function".to_string(), + function: FunctionCall { + name: fc.name.clone(), + arguments: serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string()), + }, + }) + .collect(); + + if calls.is_empty() { None } else { Some(calls) } + } + + /// Extract tool call deltas from Gemini response parts for streaming. + fn extract_tool_call_deltas(parts: &[GeminiPart]) -> Option> { + let deltas: Vec = parts + .iter() + .filter_map(|p| p.function_call.as_ref()) + .enumerate() + .map(|(i, fc)| ToolCallDelta { + index: i as u32, + id: Some(format!("call_{}", Uuid::new_v4().simple())), + call_type: Some("function".to_string()), + function: Some(FunctionCallDelta { + name: Some(fc.name.clone()), + arguments: Some(serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string())), + }), + }) + .collect(); + + if deltas.is_empty() { None } else { Some(deltas) } + } } #[async_trait] @@ -99,51 +375,15 @@ impl super::Provider for GeminiProvider { } async fn chat_completion(&self, request: UnifiedRequest) -> Result { - // Convert UnifiedRequest to Gemini request - let mut contents = Vec::with_capacity(request.messages.len()); - - for msg in request.messages { - let mut parts = Vec::with_capacity(msg.content.len()); - - for part in msg.content { - match part { - crate::models::ContentPart::Text { text } => { - parts.push(GeminiPart { - text: Some(text), - inline_data: None, - }); - } - crate::models::ContentPart::Image(image_input) => { - let (base64_data, mime_type) = image_input - .to_base64() - .await - .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?; - - parts.push(GeminiPart { - text: None, - inline_data: Some(GeminiInlineData { - mime_type, - data: base64_data, - }), - }); - } - } - } - - // Map role: "user" -> "user", "assistant" -> "model", "system" -> "user" - let role = match msg.role.as_str() { - "assistant" => "model".to_string(), - _ => "user".to_string(), - }; - - contents.push(GeminiContent { parts, role }); - } + let model = request.model.clone(); + let tools = Self::convert_tools(&request); + let tool_config = Self::convert_tool_config(&request); + let contents = Self::convert_messages(request.messages).await?; if contents.is_empty() { - return Err(AppError::ProviderError("No valid text messages to send".to_string())); + return Err(AppError::ProviderError("No valid messages to send".to_string())); } - // Build generation config let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() { Some(GeminiGenerationConfig { temperature: request.temperature, @@ -156,12 +396,12 @@ impl super::Provider for GeminiProvider { let gemini_request = GeminiRequest { contents, generation_config, + tools, + tool_config, }; - // Build URL - let url = format!("{}/models/{}:generateContent", self.config.base_url, request.model,); + let url = format!("{}/models/{}:generateContent", self.config.base_url, model); - // Send request let response = self .client .post(&url) @@ -171,7 +411,6 @@ impl super::Provider for GeminiProvider { .await .map_err(|e| AppError::ProviderError(format!("HTTP request failed: {}", e)))?; - // Check status let status = response.status(); if !status.is_success() { let error_text = response.text().await.unwrap_or_default(); @@ -186,15 +425,16 @@ impl super::Provider for GeminiProvider { .await .map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?; - // Extract content from first candidate - let content = gemini_response - .candidates - .first() - .and_then(|c| c.content.parts.first()) - .and_then(|p| p.text.clone()) + let candidate = gemini_response.candidates.first(); + + // Extract text content (may be absent if only function calls) + let content = candidate + .and_then(|c| c.content.parts.iter().find_map(|p| p.text.clone())) .unwrap_or_default(); - // Extract token usage + // Extract function calls → OpenAI tool_calls + let tool_calls = candidate.and_then(|c| Self::extract_tool_calls(&c.content.parts)); + let prompt_tokens = gemini_response .usage_metadata .as_ref() @@ -213,11 +453,12 @@ impl super::Provider for GeminiProvider { Ok(ProviderResponse { content, - reasoning_content: None, // Gemini doesn't use this field name + reasoning_content: None, + tool_calls, prompt_tokens, completion_tokens, total_tokens, - model: request.model, + model, }) } @@ -247,47 +488,11 @@ impl super::Provider for GeminiProvider { &self, request: UnifiedRequest, ) -> Result>, AppError> { - // Convert UnifiedRequest to Gemini request - let mut contents = Vec::with_capacity(request.messages.len()); + let model = request.model.clone(); + let tools = Self::convert_tools(&request); + let tool_config = Self::convert_tool_config(&request); + let contents = Self::convert_messages(request.messages).await?; - for msg in request.messages { - let mut parts = Vec::with_capacity(msg.content.len()); - - for part in msg.content { - match part { - crate::models::ContentPart::Text { text } => { - parts.push(GeminiPart { - text: Some(text), - inline_data: None, - }); - } - crate::models::ContentPart::Image(image_input) => { - let (base64_data, mime_type) = image_input - .to_base64() - .await - .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?; - - parts.push(GeminiPart { - text: None, - inline_data: Some(GeminiInlineData { - mime_type, - data: base64_data, - }), - }); - } - } - } - - // Map role - let role = match msg.role.as_str() { - "assistant" => "model".to_string(), - _ => "user".to_string(), - }; - - contents.push(GeminiContent { parts, role }); - } - - // Build generation config let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() { Some(GeminiGenerationConfig { temperature: request.temperature, @@ -300,15 +505,15 @@ impl super::Provider for GeminiProvider { let gemini_request = GeminiRequest { contents, generation_config, + tools, + tool_config, }; - // Build URL for streaming let url = format!( "{}/models/{}:streamGenerateContent?alt=sse", - self.config.base_url, request.model, + self.config.base_url, model, ); - // Create eventsource stream use futures::StreamExt; use reqwest_eventsource::{Event, EventSource}; @@ -320,8 +525,6 @@ impl super::Provider for GeminiProvider { ) .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; - let model = request.model.clone(); - let stream = async_stream::try_stream! { let mut es = es; while let Some(event) = es.next().await { @@ -331,14 +534,28 @@ impl super::Provider for GeminiProvider { .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?; if let Some(candidate) = gemini_response.candidates.first() { - let content = candidate.content.parts.first() - .and_then(|p| p.text.clone()) + let content = candidate + .content + .parts + .iter() + .find_map(|p| p.text.clone()) .unwrap_or_default(); + let tool_calls = Self::extract_tool_call_deltas(&candidate.content.parts); + + // Determine finish_reason + let finish_reason = candidate.finish_reason.as_ref().map(|fr| { + match fr.as_str() { + "STOP" => "stop".to_string(), + _ => fr.to_lowercase(), + } + }); + yield ProviderStreamChunk { content, reasoning_content: None, - finish_reason: None, // Will be set in the last chunk + finish_reason, + tool_calls, model: model.clone(), }; } diff --git a/src/providers/helpers.rs b/src/providers/helpers.rs index c0542f53..11f686aa 100644 --- a/src/providers/helpers.rs +++ b/src/providers/helpers.rs @@ -1,6 +1,6 @@ use super::{ProviderResponse, ProviderStreamChunk}; use crate::errors::AppError; -use crate::models::{ContentPart, UnifiedMessage, UnifiedRequest}; +use crate::models::{ContentPart, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest}; use futures::stream::{BoxStream, StreamExt}; use serde_json::Value; @@ -8,9 +8,37 @@ use serde_json::Value; /// /// This avoids the deadlock caused by `futures::executor::block_on` inside a /// Tokio async context. All image base64 conversions are awaited properly. +/// Handles tool-calling messages: assistant messages with tool_calls, and +/// tool-role messages with tool_call_id/name. pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result, AppError> { let mut result = Vec::new(); for m in messages { + // Tool-role messages: { role: "tool", content: "...", tool_call_id: "...", name: "..." } + if m.role == "tool" { + let text_content = m + .content + .first() + .map(|p| match p { + ContentPart::Text { text } => text.clone(), + ContentPart::Image(_) => "[Image]".to_string(), + }) + .unwrap_or_default(); + + let mut msg = serde_json::json!({ + "role": "tool", + "content": text_content + }); + if let Some(tool_call_id) = &m.tool_call_id { + msg["tool_call_id"] = serde_json::json!(tool_call_id); + } + if let Some(name) = &m.name { + msg["name"] = serde_json::json!(name); + } + result.push(msg); + continue; + } + + // Build content parts for non-tool messages let mut parts = Vec::new(); for p in &m.content { match p { @@ -29,10 +57,26 @@ pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result, @@ -82,11 +127,18 @@ pub fn build_openai_body( if let Some(max_tokens) = request.max_tokens { body["max_tokens"] = serde_json::json!(max_tokens); } + if let Some(tools) = &request.tools { + body["tools"] = serde_json::json!(tools); + } + if let Some(tool_choice) = &request.tool_choice { + body["tool_choice"] = serde_json::json!(tool_choice); + } body } /// Parse an OpenAI-compatible chat completion response JSON into a ProviderResponse. +/// Extracts tool_calls from the message when present. pub fn parse_openai_response(resp_json: &Value, model: String) -> Result { let choice = resp_json["choices"] .get(0) @@ -96,6 +148,11 @@ pub fn parse_openai_response(resp_json: &Value, model: String) -> Result> = message + .get("tool_calls") + .and_then(|tc| serde_json::from_value(tc.clone()).ok()); + let usage = &resp_json["usage"]; let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32; let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32; @@ -104,6 +161,7 @@ pub fn parse_openai_response(resp_json: &Value, model: String) -> Result Result> = delta + .get("tool_calls") + .and_then(|tc| serde_json::from_value(tc.clone()).ok()); + yield ProviderStreamChunk { content, reasoning_content, finish_reason, + tool_calls, model: model.clone(), }; } diff --git a/src/providers/mod.rs b/src/providers/mod.rs index d64b3d79..446ea643 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -50,6 +50,7 @@ pub trait Provider: Send + Sync { pub struct ProviderResponse { pub content: String, pub reasoning_content: Option, + pub tool_calls: Option>, pub prompt_tokens: u32, pub completion_tokens: u32, pub total_tokens: u32, @@ -61,6 +62,7 @@ pub struct ProviderStreamChunk { pub content: String, pub reasoning_content: Option, pub finish_reason: Option, + pub tool_calls: Option>, pub model: String, } diff --git a/src/providers/openai.rs b/src/providers/openai.rs index ca14e8f2..08458d3f 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -36,7 +36,7 @@ impl super::Provider for OpenAIProvider { } fn supports_model(&self, model: &str) -> bool { - model.starts_with("gpt-") || model.starts_with("o1-") || model.starts_with("o3-") + model.starts_with("gpt-") || model.starts_with("o1-") || model.starts_with("o3-") || model.starts_with("o4-") } fn supports_multimodal(&self) -> bool { diff --git a/src/server/mod.rs b/src/server/mod.rs index fe3a3888..08aee338 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -174,6 +174,7 @@ async fn chat_completions( role: None, content: Some(chunk.content), reasoning_content: chunk.reasoning_content, + tool_calls: chunk.tool_calls, }, finish_reason: chunk.finish_reason, }], @@ -248,6 +249,12 @@ async fn chat_completions( .await; // Convert ProviderResponse to ChatCompletionResponse + let finish_reason = if response.tool_calls.is_some() { + "tool_calls".to_string() + } else { + "stop".to_string() + }; + let chat_response = ChatCompletionResponse { id: format!("chatcmpl-{}", Uuid::new_v4()), object: "chat.completion".to_string(), @@ -261,8 +268,11 @@ async fn chat_completions( content: response.content, }, reasoning_content: response.reasoning_content, + tool_calls: response.tool_calls, + name: None, + tool_call_id: None, }, - finish_reason: Some("stop".to_string()), + finish_reason: Some(finish_reason), }], usage: Some(Usage { prompt_tokens: response.prompt_tokens, diff --git a/src/utils/streaming.rs b/src/utils/streaming.rs index 4f557a76..eaa20467 100644 --- a/src/utils/streaming.rs +++ b/src/utils/streaming.rs @@ -1,6 +1,7 @@ use crate::client::ClientManager; use crate::errors::AppError; use crate::logging::{RequestLog, RequestLogger}; +use crate::models::ToolCall; use crate::providers::{Provider, ProviderStreamChunk}; use crate::utils::tokens::estimate_completion_tokens; use futures::stream::Stream; @@ -31,6 +32,7 @@ pub struct AggregatingStream { has_images: bool, accumulated_content: String, accumulated_reasoning: String, + accumulated_tool_calls: Vec, logger: Arc, client_manager: Arc, model_registry: Arc, @@ -53,6 +55,7 @@ where has_images: config.has_images, accumulated_content: String::new(), accumulated_reasoning: String::new(), + accumulated_tool_calls: Vec::new(), logger: config.logger, client_manager: config.client_manager, model_registry: config.model_registry, @@ -153,6 +156,38 @@ where if let Some(reasoning) = &chunk.reasoning_content { self.accumulated_reasoning.push_str(reasoning); } + // Accumulate tool call deltas into complete tool calls + if let Some(deltas) = &chunk.tool_calls { + for delta in deltas { + let idx = delta.index as usize; + // Grow the accumulated_tool_calls vec if needed + while self.accumulated_tool_calls.len() <= idx { + self.accumulated_tool_calls.push(ToolCall { + id: String::new(), + call_type: "function".to_string(), + function: crate::models::FunctionCall { + name: String::new(), + arguments: String::new(), + }, + }); + } + let tc = &mut self.accumulated_tool_calls[idx]; + if let Some(id) = &delta.id { + tc.id.clone_from(id); + } + if let Some(ct) = &delta.call_type { + tc.call_type.clone_from(ct); + } + if let Some(f) = &delta.function { + if let Some(name) = &f.name { + tc.function.name.push_str(name); + } + if let Some(args) = &f.arguments { + tc.function.arguments.push_str(args); + } + } + } + } } Poll::Ready(Some(Err(_))) => { // If there's an error, we might still want to log what we got so far? @@ -217,12 +252,14 @@ mod tests { content: "Hello".to_string(), reasoning_content: None, finish_reason: None, + tool_calls: None, model: "test".to_string(), }), Ok(ProviderStreamChunk { content: " World".to_string(), reasoning_content: None, finish_reason: Some("stop".to_string()), + tool_calls: None, model: "test".to_string(), }), ];