Files
GopherGate/src/providers/gemini.rs
hobokenchicken 6440e8cc13
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled
fix(gemini): ensure final finish_reason is 'tool_calls' if any tools were seen
Gemini often sends tool calls in one chunk and then 'STOP' in a final chunk.
If we pass the raw 'stop' at the end, clients stop and ignore the previously
received tool calls. We now track if any tools were seen and override the
final 'stop' to 'tool_calls'.
2026-03-05 17:50:25 +00:00

1018 lines
41 KiB
Rust

use anyhow::Result;
use async_trait::async_trait;
use futures::stream::{BoxStream, StreamExt};
use reqwest_eventsource::Event;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use uuid::Uuid;
use super::{ProviderResponse, ProviderStreamChunk};
use crate::{
config::AppConfig,
errors::AppError,
models::{ContentPart, FunctionCall, FunctionCallDelta, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest},
};
// ========== Gemini Request Structs ==========
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiRequest {
contents: Vec<GeminiContent>,
#[serde(skip_serializing_if = "Option::is_none")]
system_instruction: Option<GeminiContent>,
#[serde(skip_serializing_if = "Option::is_none")]
generation_config: Option<GeminiGenerationConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<GeminiTool>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_config: Option<GeminiToolConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
safety_settings: Option<Vec<GeminiSafetySetting>>,
}
#[derive(Debug, Clone, Serialize)]
struct GeminiSafetySetting {
category: String,
threshold: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiContent {
parts: Vec<GeminiPart>,
#[serde(skip_serializing_if = "Option::is_none")]
role: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiPart {
#[serde(skip_serializing_if = "Option::is_none")]
text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
inline_data: Option<GeminiInlineData>,
#[serde(skip_serializing_if = "Option::is_none")]
function_call: Option<GeminiFunctionCall>,
#[serde(skip_serializing_if = "Option::is_none")]
function_response: Option<GeminiFunctionResponse>,
#[serde(skip_serializing_if = "Option::is_none")]
thought: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", rename = "thought_signature")]
thought_signature_snake: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", rename = "thoughtSignature")]
thought_signature_camel: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiInlineData {
mime_type: String,
data: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiFunctionCall {
name: String,
args: Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiFunctionResponse {
name: String,
response: Value,
}
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiGenerationConfig {
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
top_k: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
max_output_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
stop_sequences: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
candidate_count: Option<u32>,
}
// ========== Gemini Tool Structs ==========
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiTool {
function_declarations: Vec<GeminiFunctionDeclaration>,
}
#[derive(Debug, Clone, Serialize)]
struct GeminiFunctionDeclaration {
name: String,
#[serde(skip_serializing_if = "Option::is_none")]
description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
parameters: Option<Value>,
}
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiToolConfig {
function_calling_config: GeminiFunctionCallingConfig,
}
#[derive(Debug, Clone, Serialize)]
struct GeminiFunctionCallingConfig {
mode: String,
#[serde(skip_serializing_if = "Option::is_none", rename = "allowedFunctionNames")]
allowed_function_names: Option<Vec<String>>,
}
// ========== Gemini Response Structs ==========
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiCandidate {
content: GeminiContent,
#[serde(default)]
#[allow(dead_code)]
finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiUsageMetadata {
#[serde(default)]
prompt_token_count: u32,
#[serde(default)]
candidates_token_count: u32,
#[serde(default)]
total_token_count: u32,
#[serde(default)]
cached_content_token_count: u32,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiResponse {
candidates: Vec<GeminiCandidate>,
usage_metadata: Option<GeminiUsageMetadata>,
}
// Streaming responses from Gemini may include messages without `candidates` (e.g. promptFeedback).
// Use a more permissive struct for streaming to avoid aborting the SSE prematurely.
#[derive(Debug, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
struct GeminiStreamResponse {
#[serde(default)]
candidates: Vec<GeminiStreamCandidate>,
#[serde(default)]
usage_metadata: Option<GeminiUsageMetadata>,
}
#[derive(Debug, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
struct GeminiStreamCandidate {
#[serde(default)]
content: Option<GeminiContent>,
#[serde(default)]
finish_reason: Option<String>,
}
// ========== Provider Implementation ==========
pub struct GeminiProvider {
client: reqwest::Client,
config: crate::config::GeminiConfig,
api_key: String,
pricing: Vec<crate::config::ModelPricing>,
}
impl GeminiProvider {
pub fn new(config: &crate::config::GeminiConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("gemini")?;
Self::new_with_key(config, app_config, api_key)
}
pub fn new_with_key(config: &crate::config::GeminiConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.timeout(std::time::Duration::from_secs(300))
.pool_idle_timeout(std::time::Duration::from_secs(90))
.pool_max_idle_per_host(4)
.tcp_keepalive(std::time::Duration::from_secs(30))
.build()?;
Ok(Self {
client,
config: config.clone(),
api_key,
pricing: app_config.pricing.gemini.clone(),
})
}
/// Convert unified messages to Gemini content format.
/// Handles text, images, tool calls (assistant), and tool results.
/// Returns (contents, system_instruction)
async fn convert_messages(
messages: Vec<UnifiedMessage>,
) -> Result<(Vec<GeminiContent>, Option<GeminiContent>), AppError> {
let mut contents: Vec<GeminiContent> = Vec::new();
let mut system_parts = Vec::new();
for msg in messages {
if msg.role == "system" {
for part in msg.content {
if let ContentPart::Text { text } = part {
if !text.trim().is_empty() {
system_parts.push(GeminiPart {
text: Some(text),
inline_data: None,
function_call: None,
function_response: None,
thought: None,
thought_signature_snake: None,
thought_signature_camel: None,
});
}
}
}
continue;
}
let role = match msg.role.as_str() {
"assistant" => "model".to_string(),
"tool" => "user".to_string(), // Tool results are user-side in Gemini
_ => "user".to_string(),
};
let mut parts = Vec::new();
// Handle tool results (role "tool")
if msg.role == "tool" {
let text_content = msg
.content
.first()
.map(|p| match p {
ContentPart::Text { text } => text.clone(),
ContentPart::Image(_) => "[Image]".to_string(),
})
.unwrap_or_default();
let name = msg.name.clone().or_else(|| msg.tool_call_id.clone()).unwrap_or_else(|| "unknown_function".to_string());
// Gemini API requires 'response' to be a JSON object (google.protobuf.Struct).
// If it is an array or primitive, wrap it in an object.
let mut response_value = serde_json::from_str::<Value>(&text_content)
.unwrap_or_else(|_| serde_json::json!({ "result": text_content }));
if !response_value.is_object() {
response_value = serde_json::json!({ "result": response_value });
}
parts.push(GeminiPart {
text: None,
inline_data: None,
function_call: None,
function_response: Some(GeminiFunctionResponse {
name,
response: response_value,
}),
thought: None,
thought_signature_snake: None,
thought_signature_camel: None,
});
} else if msg.role == "assistant" {
// Assistant messages: handle text, thought (reasoning), and tool_calls
for p in &msg.content {
if let ContentPart::Text { text } = p {
if !text.trim().is_empty() {
parts.push(GeminiPart {
text: Some(text.clone()),
inline_data: None,
function_call: None,
function_response: None,
thought: None,
thought_signature_snake: None,
thought_signature_camel: None,
});
}
}
}
// If reasoning_content is present, include it as a 'thought' part
if let Some(reasoning) = &msg.reasoning_content {
if !reasoning.trim().is_empty() {
parts.push(GeminiPart {
text: None,
inline_data: None,
function_call: None,
function_response: None,
thought: Some(reasoning.clone()),
thought_signature_snake: None,
thought_signature_camel: None,
});
}
}
if let Some(tool_calls) = &msg.tool_calls {
for tc in tool_calls {
let args = serde_json::from_str::<Value>(&tc.function.arguments)
.unwrap_or_else(|_| serde_json::json!({}));
// RESTORE: Use tc.id as thought_signature.
// Gemini 3 models require this field for any function call in the history.
// We include it regardless of format to ensure the model has context.
let thought_signature = Some(tc.id.clone());
parts.push(GeminiPart {
text: None,
inline_data: None,
function_call: Some(GeminiFunctionCall {
name: tc.function.name.clone(),
args,
}),
function_response: None,
thought: None,
thought_signature_snake: thought_signature.clone(),
thought_signature_camel: thought_signature,
});
}
}
} else {
// Regular text/image messages (mostly user)
for part in msg.content {
match part {
ContentPart::Text { text } => {
if !text.trim().is_empty() {
parts.push(GeminiPart {
text: Some(text),
inline_data: None,
function_call: None,
function_response: None,
thought: None,
thought_signature_snake: None,
thought_signature_camel: None,
});
}
}
ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input
.to_base64()
.await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
parts.push(GeminiPart {
text: None,
inline_data: Some(GeminiInlineData {
mime_type,
data: base64_data,
}),
function_call: None,
function_response: None,
thought: None,
thought_signature_snake: None,
thought_signature_camel: None,
});
}
}
}
}
if parts.is_empty() {
continue;
}
// STRATEGY: Strictly enforce alternating roles.
if let Some(last_content) = contents.last_mut() {
if last_content.role.as_ref() == Some(&role) {
last_content.parts.extend(parts);
continue;
}
}
contents.push(GeminiContent {
parts,
role: Some(role),
});
}
// Gemini requires the first message to be from "user".
if let Some(first) = contents.first() {
if first.role.as_deref() == Some("model") {
contents.insert(0, GeminiContent {
role: Some("user".to_string()),
parts: vec![GeminiPart {
text: Some("Continue conversation.".to_string()),
inline_data: None,
function_call: None,
function_response: None,
thought: None,
thought_signature_snake: None,
thought_signature_camel: None,
}],
});
}
}
// Final check: ensure we don't have empty contents after filtering.
if contents.is_empty() && system_parts.is_empty() {
return Err(AppError::ProviderError("No valid content parts after filtering".to_string()));
}
let system_instruction = if !system_parts.is_empty() {
Some(GeminiContent {
parts: system_parts,
role: None,
})
} else {
None
};
Ok((contents, system_instruction))
}
/// Convert OpenAI tools to Gemini function declarations.
fn convert_tools(request: &UnifiedRequest) -> Option<Vec<GeminiTool>> {
request.tools.as_ref().map(|tools| {
let declarations: Vec<GeminiFunctionDeclaration> = tools
.iter()
.map(|t| {
let mut parameters = t.function.parameters.clone().unwrap_or(serde_json::json!({
"type": "object",
"properties": {}
}));
Self::sanitize_schema(&mut parameters);
GeminiFunctionDeclaration {
name: t.function.name.clone(),
description: t.function.description.clone(),
parameters: Some(parameters),
}
})
.collect();
vec![GeminiTool {
function_declarations: declarations,
}]
})
}
/// Recursively remove unsupported JSON Schema fields that Gemini's API rejects.
fn sanitize_schema(value: &mut Value) {
if let Value::Object(map) = value {
// Remove unsupported fields at this level
map.remove("$schema");
map.remove("additionalProperties");
map.remove("exclusiveMaximum");
map.remove("exclusiveMinimum");
// Recursively sanitize all object properties
if let Some(properties) = map.get_mut("properties") {
if let Value::Object(props_map) = properties {
for prop_value in props_map.values_mut() {
Self::sanitize_schema(prop_value);
}
}
}
// Recursively sanitize array items
if let Some(items) = map.get_mut("items") {
Self::sanitize_schema(items);
}
// Gemini 1.5/2.0+ supports anyOf in some contexts, but it's often
// the source of additionalProperties errors when nested.
if let Some(any_of) = map.get_mut("anyOf") {
if let Value::Array(arr) = any_of {
for item in arr {
Self::sanitize_schema(item);
}
}
}
if let Some(one_of) = map.get_mut("oneOf") {
if let Value::Array(arr) = one_of {
for item in arr {
Self::sanitize_schema(item);
}
}
}
if let Some(all_of) = map.get_mut("allOf") {
if let Value::Array(arr) = all_of {
for item in arr {
Self::sanitize_schema(item);
}
}
}
}
}
/// Convert OpenAI tool_choice to Gemini tool_config.
fn convert_tool_config(request: &UnifiedRequest) -> Option<GeminiToolConfig> {
request.tool_choice.as_ref().map(|tc| {
let (mode, allowed_names) = match tc {
crate::models::ToolChoice::Mode(mode) => {
let gemini_mode = match mode.as_str() {
"auto" => "AUTO",
"none" => "NONE",
"required" => "ANY",
_ => "AUTO",
};
(gemini_mode.to_string(), None)
}
crate::models::ToolChoice::Specific(specific) => {
("ANY".to_string(), Some(vec![specific.function.name.clone()]))
}
};
GeminiToolConfig {
function_calling_config: GeminiFunctionCallingConfig {
mode,
allowed_function_names: allowed_names,
},
}
})
}
/// Extract tool calls from Gemini response parts into OpenAI-format ToolCalls.
fn extract_tool_calls(parts: &[GeminiPart]) -> Option<Vec<ToolCall>> {
let calls: Vec<ToolCall> = parts
.iter()
.filter(|p| p.function_call.is_some())
.map(|p| {
let fc = p.function_call.as_ref().unwrap();
// CAPTURE: Try extracting thought_signature from sibling fields
let id = p.thought_signature_camel.clone()
.or_else(|| p.thought_signature_snake.clone())
.unwrap_or_else(|| format!("call_{}", Uuid::new_v4().simple()));
ToolCall {
id,
call_type: "function".to_string(),
function: FunctionCall {
name: fc.name.clone(),
arguments: serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string()),
},
}
})
.collect();
if calls.is_empty() { None } else { Some(calls) }
}
/// Determine the appropriate base URL for the model.
/// "preview" models often require the v1beta endpoint, but newer promoted ones may be on v1.
fn get_base_url(&self, model: &str) -> String {
let base = &self.config.base_url;
// If the model requires v1beta but the base is currently v1
if (model.contains("preview") || model.contains("thinking") || model.contains("gemini-3")) && base.ends_with("/v1") {
return base.replace("/v1", "/v1beta");
}
// If the model is a standard model but the base is v1beta, we could downgrade it,
// but typically v1beta is a superset, so we just return the base as configured.
base.clone()
}
/// Default safety settings to avoid blocking responses.
fn get_safety_settings(&self, base_url: &str) -> Vec<GeminiSafetySetting> {
let mut categories = vec![
"HARM_CATEGORY_HARASSMENT",
"HARM_CATEGORY_HATE_SPEECH",
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
"HARM_CATEGORY_DANGEROUS_CONTENT",
];
// Civic integrity is only available in v1beta
if base_url.contains("v1beta") {
categories.push("HARM_CATEGORY_CIVIC_INTEGRITY");
}
categories
.into_iter()
.map(|c| GeminiSafetySetting {
category: c.to_string(),
threshold: "BLOCK_NONE".to_string(),
})
.collect()
}
}
#[async_trait]
impl super::Provider for GeminiProvider {
fn name(&self) -> &str {
"gemini"
}
fn supports_model(&self, model: &str) -> bool {
model.starts_with("gemini-")
}
fn supports_multimodal(&self) -> bool {
true // Gemini supports vision
}
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
let mut model = request.model.clone();
// Normalize model name: If it's a known Gemini model version, use it;
// otherwise, if it starts with gemini- but is an unknown legacy version,
// fallback to the default model to avoid 400 errors.
// We now allow gemini-3+ as valid versions.
let is_known_version = model.starts_with("gemini-1.5") ||
model.starts_with("gemini-2.0") ||
model.starts_with("gemini-2.5") ||
model.starts_with("gemini-3");
if !is_known_version && model.starts_with("gemini-") {
tracing::info!("Mapping unknown Gemini model {} to default {}", model, self.config.default_model);
model = self.config.default_model.clone();
}
let tools = Self::convert_tools(&request);
let tool_config = Self::convert_tool_config(&request);
let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
if contents.is_empty() && system_instruction.is_none() {
return Err(AppError::ProviderError("No valid messages to send".to_string()));
}
let base_url = self.get_base_url(&model);
// Sanitize stop sequences: Gemini rejects empty strings
let stop_sequences = request.stop.map(|s| {
s.into_iter()
.filter(|seq| !seq.is_empty())
.collect::<Vec<_>>()
}).filter(|s| !s.is_empty());
let generation_config = Some(GeminiGenerationConfig {
temperature: request.temperature,
top_p: request.top_p,
top_k: request.top_k,
max_output_tokens: request.max_tokens.map(|t| t.min(65536)),
stop_sequences,
candidate_count: request.n,
});
let gemini_request = GeminiRequest {
contents,
system_instruction,
generation_config,
tools,
tool_config,
safety_settings: Some(self.get_safety_settings(&base_url)),
};
let url = format!("{}/models/{}:generateContent", base_url, model);
tracing::info!("Calling Gemini API: {}", url);
let response = self
.client
.post(&url)
.header("x-goog-api-key", &self.api_key)
.json(&gemini_request)
.send()
.await
.map_err(|e| AppError::ProviderError(format!("HTTP request failed: {}", e)))?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!(
"Gemini API error ({}): {}",
status, error_text
)));
}
let gemini_response: GeminiResponse = response
.json()
.await
.map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?;
let candidate = gemini_response.candidates.first();
// Extract text content
let content = candidate
.and_then(|c| c.content.parts.iter().find_map(|p| p.text.clone()))
.unwrap_or_default();
// Extract reasoning (Gemini 3 'thought' parts)
let reasoning_content = candidate
.and_then(|c| c.content.parts.iter().find_map(|p| p.thought.clone()));
// Extract function calls → OpenAI tool_calls
let tool_calls = candidate.and_then(|c| Self::extract_tool_calls(&c.content.parts));
let prompt_tokens = gemini_response
.usage_metadata
.as_ref()
.map(|u| u.prompt_token_count)
.unwrap_or(0);
let completion_tokens = gemini_response
.usage_metadata
.as_ref()
.map(|u| u.candidates_token_count)
.unwrap_or(0);
let total_tokens = gemini_response
.usage_metadata
.as_ref()
.map(|u| u.total_token_count)
.unwrap_or(0);
let cache_read_tokens = gemini_response
.usage_metadata
.as_ref()
.map(|u| u.cached_content_token_count)
.unwrap_or(0);
Ok(ProviderResponse {
content,
reasoning_content,
tool_calls,
prompt_tokens,
completion_tokens,
total_tokens,
cache_read_tokens,
cache_write_tokens: 0, // Gemini doesn't report cache writes separately
model,
})
}
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
fn calculate_cost(
&self,
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
cache_read_tokens: u32,
cache_write_tokens: u32,
registry: &crate::models::registry::ModelRegistry,
) -> f64 {
super::helpers::calculate_cost_with_registry(
model,
prompt_tokens,
completion_tokens,
cache_read_tokens,
cache_write_tokens,
registry,
&self.pricing,
0.075,
0.30,
)
}
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
let mut model = request.model.clone();
// Normalize model name: fallback to default if unknown Gemini model is requested
let is_known_version = model.starts_with("gemini-1.5") ||
model.starts_with("gemini-2.0") ||
model.starts_with("gemini-2.5") ||
model.starts_with("gemini-3");
if !is_known_version && model.starts_with("gemini-") {
tracing::info!("Mapping unknown Gemini model {} to default {}", model, self.config.default_model);
model = self.config.default_model.clone();
}
let tools = Self::convert_tools(&request);
let tool_config = Self::convert_tool_config(&request);
let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
if contents.is_empty() && system_instruction.is_none() {
return Err(AppError::ProviderError("No valid messages to send".to_string()));
}
let base_url = self.get_base_url(&model);
// Sanitize stop sequences: Gemini rejects empty strings
let stop_sequences = request.stop.map(|s| {
s.into_iter()
.filter(|seq| !seq.is_empty())
.collect::<Vec<_>>()
}).filter(|s| !s.is_empty());
let generation_config = Some(GeminiGenerationConfig {
temperature: request.temperature,
top_p: request.top_p,
top_k: request.top_k,
max_output_tokens: request.max_tokens.map(|t| t.min(65536)),
stop_sequences,
candidate_count: request.n,
});
let gemini_request = GeminiRequest {
contents,
system_instruction,
generation_config,
tools,
tool_config,
safety_settings: Some(self.get_safety_settings(&base_url)),
};
let url = format!(
"{}/models/{}:streamGenerateContent?alt=sse",
base_url, model,
);
tracing::info!("Calling Gemini Stream API: {}", url);
// Capture a clone of the request to probe for errors (Gemini 400s are common)
let probe_request = gemini_request.clone();
let probe_client = self.client.clone();
// Use non-streaming URL for probing to get a valid JSON error body
let probe_url = format!("{}/models/{}:generateContent", base_url, model);
let probe_api_key = self.api_key.clone();
// Create the EventSource first (it doesn't send until polled)
let es = reqwest_eventsource::EventSource::new(
self.client
.post(&url)
.header("x-goog-api-key", &self.api_key)
.header("Accept", "text/event-stream")
.json(&gemini_request),
).map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let stream = async_stream::try_stream! {
let mut es = es;
// Track tool call IDs by their part index to ensure stability during streaming.
// Gemini doesn't always include the thoughtSignature in every chunk for the same part.
let mut tool_call_ids: std::collections::HashMap<u32, String> = std::collections::HashMap::new();
let mut seen_tool_calls = false;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
let gemini_response: GeminiStreamResponse = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
tracing::info!("Received Gemini stream chunk (candidates: {}, has_usage: {}, finish_reason: {:?})",
gemini_response.candidates.len(),
gemini_response.usage_metadata.is_some(),
gemini_response.candidates.first().and_then(|c| c.finish_reason.as_deref())
);
// Extract usage from usageMetadata if present (reported on every/last chunk)
let stream_usage = gemini_response.usage_metadata.as_ref().map(|u| {
super::StreamUsage {
prompt_tokens: u.prompt_token_count,
completion_tokens: u.candidates_token_count,
total_tokens: u.total_token_count,
cache_read_tokens: u.cached_content_token_count,
cache_write_tokens: 0,
}
});
// Some streaming events may not contain candidates (e.g. promptFeedback).
// Only emit chunks when we have candidate content or tool calls.
if let Some(candidate) = gemini_response.candidates.first() {
if let Some(content_obj) = &candidate.content {
let content = content_obj
.parts
.iter()
.find_map(|p| p.text.clone())
.unwrap_or_default();
let reasoning_content = content_obj
.parts
.iter()
.find_map(|p| p.thought.clone());
// Extract tool calls with index and ID stability
let mut deltas = Vec::new();
for (p_idx, p) in content_obj.parts.iter().enumerate() {
if let Some(fc) = &p.function_call {
seen_tool_calls = true;
let tool_call_idx = p_idx as u32;
// Attempt to find a signature in sibling fields
let signature = p.thought_signature_camel.clone()
.or_else(|| p.thought_signature_snake.clone());
// Ensure the ID remains stable for this tool call index.
// If we found a real signature now, we update it; otherwise use the existing or new random ID.
let entry = tool_call_ids.entry(tool_call_idx);
let current_id = match entry {
std::collections::hash_map::Entry::Occupied(mut e) => {
if let Some(sig) = signature {
// If we previously had a 'call_' ID but now found a real signature, upgrade it.
if e.get().starts_with("call_") {
e.insert(sig);
}
}
e.get().clone()
}
std::collections::hash_map::Entry::Vacant(e) => {
let id = signature.unwrap_or_else(|| format!("call_{}", Uuid::new_v4().simple()));
e.insert(id.clone());
id
}
};
deltas.push(ToolCallDelta {
index: tool_call_idx,
id: Some(current_id),
call_type: Some("function".to_string()),
function: Some(FunctionCallDelta {
name: Some(fc.name.clone()),
arguments: Some(serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string())),
}),
});
}
}
let tool_calls = if deltas.is_empty() { None } else { Some(deltas) };
// Determine finish_reason
// STRATEGY: If we have tool calls in this chunk, OR if we have seen them
// previously in the stream, the finish_reason MUST be "tool_calls"
// if the provider signals a stop. This ensures the client executes tools.
let mut finish_reason = candidate.finish_reason.as_ref().map(|fr| {
match fr.as_str() {
"STOP" => "stop".to_string(),
_ => fr.to_lowercase(),
}
});
if seen_tool_calls && finish_reason.as_deref() == Some("stop") {
finish_reason = Some("tool_calls".to_string());
} else if tool_calls.is_some() && finish_reason.is_none() {
// Optional: Could signal tool_calls here too, but OpenAI often waits until EOF.
// For now we only override it at the actual stop signal.
}
// Avoid emitting completely empty chunks unless they carry usage.
if !content.is_empty() || reasoning_content.is_some() || tool_calls.is_some() || stream_usage.is_some() {
yield ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
tool_calls,
model: model.clone(),
usage: stream_usage,
};
}
} else if stream_usage.is_some() {
// Usage-only update
yield ProviderStreamChunk {
content: String::new(),
reasoning_content: None,
finish_reason: None,
tool_calls: None,
model: model.clone(),
usage: stream_usage,
};
}
} else if stream_usage.is_some() {
// No candidates but usage present
yield ProviderStreamChunk {
content: String::new(),
reasoning_content: None,
finish_reason: None,
tool_calls: None,
model: model.clone(),
usage: stream_usage,
};
}
}
Ok(_) => continue,
Err(e) => {
// "Stream ended" is usually a normal EOF signal in reqwest-eventsource.
// We check the string representation to avoid returning it as an error.
if e.to_string().contains("Stream ended") {
break;
}
// On stream error, attempt to probe for the actual error body from the provider
let probe_resp = probe_client
.post(&probe_url)
.header("x-goog-api-key", &probe_api_key)
.json(&probe_request)
.send()
.await;
match probe_resp {
Ok(r) if !r.status().is_success() => {
let status = r.status();
let body = r.text().await.unwrap_or_default();
tracing::error!("Gemini Stream Error Probe ({}): {}", status, body);
Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, body)))?;
}
_ => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
}
}
};
Ok(Box::pin(stream))
}
}