diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 20b2ae7f..ff2de314 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -102,12 +102,12 @@ impl OpenAIProvider { for p in patterns { if let Some(start) = result.find(p) { - // Remove the pattern and any whitespace around it + // Remove the pattern result.replace_range(start..start + p.len(), ""); } } - result.trim().to_string() + result } } @@ -730,6 +730,7 @@ impl super::Provider for OpenAIProvider { let stream = async_stream::try_stream! { let mut es = es; let mut content_buffer = String::new(); + let mut has_tool_calls = false; while let Some(event) = es.next().await { match event { @@ -760,6 +761,7 @@ impl super::Provider for OpenAIProvider { "response.output_item.added" => { if let Some(item) = chunk.get("item") { if item.get("type").and_then(|v| v.as_str()) == Some("function_call") { + has_tool_calls = true; let call_id = item.get("call_id").and_then(|v| v.as_str()); let name = item.get("name").and_then(|v| v.as_str()); @@ -777,6 +779,7 @@ impl super::Provider for OpenAIProvider { } "response.function_call_arguments.delta" => { if let Some(delta) = chunk.get("delta").and_then(|v| v.as_str()) { + has_tool_calls = true; tool_calls = Some(vec![crate::models::ToolCallDelta { index: chunk.get("output_index").and_then(|v| v.as_u64()).unwrap_or(0) as u32, id: None, @@ -788,8 +791,8 @@ impl super::Provider for OpenAIProvider { }]); } } - "response.output_text.done" | "response.item.done" | "response.completed" => { - finish_reason = Some("stop".to_string()); + "response.completed" => { + finish_reason = Some(if has_tool_calls { "tool_calls".to_string() } else { "stop".to_string() }); } _ => {} } @@ -800,6 +803,7 @@ impl super::Provider for OpenAIProvider { if content_buffer.contains("{\"tool_uses\":") { let embedded_calls = Self::parse_tool_uses_json(&content_buffer); if !embedded_calls.is_empty() { + has_tool_calls = true; if let Some(start) = content_buffer.find("{\"tool_uses\":") { // Yield text before the JSON block let preamble = content_buffer[..start].to_string();