Files
GopherGate/src/providers/gemini.rs
hobokenchicken d32386df3f
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled
fix(gemini): resolve 400 Bad Request by sanitizing thought_signature and improving tool name resolution
This commit fixes the Gemini API 'Invalid value at thought_signature' error by ensuring synthetic 'call_' IDs are not passed into the TYPE_BYTES field. It also adds a pre-pass to correctly resolve function names from tool call IDs for tool responses.
2026-03-06 14:59:04 +00:00

1038 lines
42 KiB
Rust

use anyhow::Result;
use async_trait::async_trait;
use futures::stream::{BoxStream, StreamExt};
use reqwest_eventsource::Event;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use uuid::Uuid;
use super::{ProviderResponse, ProviderStreamChunk};
use crate::{
config::AppConfig,
errors::AppError,
models::{ContentPart, FunctionCall, FunctionCallDelta, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest},
};
// ========== Gemini Request Structs ==========
/// Top-level body for Gemini `generateContent` / `streamGenerateContent` calls.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiRequest {
    /// Conversation turns, alternating user/model roles.
    contents: Vec<GeminiContent>,
    /// System prompt; Gemini takes it out-of-band rather than as a message turn.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GeminiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GeminiGenerationConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GeminiTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_config: Option<GeminiToolConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    safety_settings: Option<Vec<GeminiSafetySetting>>,
}
/// One harm-category/threshold pair (e.g. "HARM_CATEGORY_HARASSMENT" / "BLOCK_NONE").
#[derive(Debug, Clone, Serialize)]
struct GeminiSafetySetting {
    category: String,
    threshold: String,
}
/// A single turn: ordered parts plus an optional role
/// ("user" / "model"; `None` when used as the system instruction).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiContent {
    parts: Vec<GeminiPart>,
    #[serde(skip_serializing_if = "Option::is_none")]
    role: Option<String>,
}
/// One content part. In practice exactly one payload field is set per part.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiPart {
    #[serde(skip_serializing_if = "Option::is_none")]
    text: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    inline_data: Option<GeminiInlineData>,
    #[serde(skip_serializing_if = "Option::is_none")]
    function_call: Option<GeminiFunctionCall>,
    #[serde(skip_serializing_if = "Option::is_none")]
    function_response: Option<GeminiFunctionResponse>,
    /// Model "thinking" text from reasoning-capable models.
    #[serde(skip_serializing_if = "Option::is_none")]
    thought: Option<String>,
    // Both spellings of the signature key are modelled so the deserializer
    // accepts either; when serializing, callers set both to the same value
    // (see convert_messages / extract_tool_calls).
    #[serde(skip_serializing_if = "Option::is_none", rename = "thought_signature")]
    thought_signature_snake: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none", rename = "thoughtSignature")]
    thought_signature_camel: Option<String>,
}
/// Base64-encoded inline media (images) attached to a part.
// NOTE(review): no camelCase rename here, so this serializes as `mime_type` —
// presumably accepted by the API's JSON mapping; confirm against the REST docs.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiInlineData {
    mime_type: String,
    data: String,
}
/// A function invocation emitted by the model (name + JSON args object).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiFunctionCall {
    name: String,
    args: Value,
}
/// The tool-result payload sent back for a prior function call.
/// `response` must be a JSON object (google.protobuf.Struct) — enforced by the caller.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiFunctionResponse {
    name: String,
    response: Value,
}
/// Sampling / output-shaping parameters; all optional and omitted when unset.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    candidate_count: Option<u32>,
}
// ========== Gemini Tool Structs ==========
/// A tool entry: Gemini groups all function declarations under one tool object.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiTool {
    function_declarations: Vec<GeminiFunctionDeclaration>,
}
/// One callable function exposed to the model (name + optional JSON-schema params).
#[derive(Debug, Clone, Serialize)]
struct GeminiFunctionDeclaration {
    name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    /// JSON schema for the arguments, pre-sanitized by `sanitize_schema`.
    #[serde(skip_serializing_if = "Option::is_none")]
    parameters: Option<Value>,
}
/// Wrapper for the function-calling configuration (mirrors the API's toolConfig).
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiToolConfig {
    function_calling_config: GeminiFunctionCallingConfig,
}
/// Function-calling mode ("AUTO" / "NONE" / "ANY") and optional allow-list.
#[derive(Debug, Clone, Serialize)]
struct GeminiFunctionCallingConfig {
    mode: String,
    #[serde(skip_serializing_if = "Option::is_none", rename = "allowedFunctionNames")]
    allowed_function_names: Option<Vec<String>>,
}
// ========== Gemini Response Structs ==========
/// One generated candidate in a non-streaming response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiCandidate {
    content: GeminiContent,
    #[serde(default)]
    #[allow(dead_code)]
    finish_reason: Option<String>,
}
/// Token accounting reported by the API; every counter defaults to 0 when absent.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiUsageMetadata {
    #[serde(default)]
    prompt_token_count: u32,
    #[serde(default)]
    candidates_token_count: u32,
    #[serde(default)]
    total_token_count: u32,
    /// Tokens served from context caching (read side only).
    #[serde(default)]
    cached_content_token_count: u32,
}
/// Non-streaming response envelope. `candidates` is required here; streaming
/// uses the more permissive `GeminiStreamResponse` below.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiResponse {
    candidates: Vec<GeminiCandidate>,
    usage_metadata: Option<GeminiUsageMetadata>,
}
// Streaming responses from Gemini may include messages without `candidates` (e.g. promptFeedback).
// Use a more permissive struct for streaming to avoid aborting the SSE prematurely.
#[derive(Debug, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
struct GeminiStreamResponse {
    #[serde(default)]
    candidates: Vec<GeminiStreamCandidate>,
    #[serde(default)]
    usage_metadata: Option<GeminiUsageMetadata>,
}
/// Streaming candidate: both fields optional since chunks may carry only a
/// finish reason or only content.
#[derive(Debug, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
struct GeminiStreamCandidate {
    #[serde(default)]
    content: Option<GeminiContent>,
    #[serde(default)]
    finish_reason: Option<String>,
}
// ========== Provider Implementation ==========
/// Provider implementation for Google's Gemini API.
///
/// Holds a pooled HTTP client, the Gemini-specific configuration, the API key,
/// and the per-model pricing table used for cost calculation.
pub struct GeminiProvider {
    client: reqwest::Client,
    config: crate::config::GeminiConfig,
    api_key: String,
    pricing: Vec<crate::config::ModelPricing>,
}
impl GeminiProvider {
/// Build a provider, resolving the `gemini` API key from the app configuration.
pub fn new(config: &crate::config::GeminiConfig, app_config: &AppConfig) -> Result<Self> {
    Self::new_with_key(config, app_config, app_config.get_api_key("gemini")?)
}
/// Build a provider with an explicitly supplied API key.
///
/// The HTTP client is tuned for long-lived streaming calls: fast connect
/// failure (5s), a generous overall timeout (300s), and keep-alive/pooling so
/// consecutive requests can reuse connections.
pub fn new_with_key(config: &crate::config::GeminiConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
    use std::time::Duration;
    let http = reqwest::Client::builder()
        .connect_timeout(Duration::from_secs(5))
        .timeout(Duration::from_secs(300))
        .pool_idle_timeout(Duration::from_secs(90))
        .pool_max_idle_per_host(4)
        .tcp_keepalive(Duration::from_secs(30))
        .build()?;
    Ok(Self {
        client: http,
        config: config.clone(),
        api_key,
        pricing: app_config.pricing.gemini.clone(),
    })
}
/// Convert unified messages to Gemini content format.
/// Handles text, images, tool calls (assistant), and tool results.
/// Returns (contents, system_instruction)
async fn convert_messages(
messages: Vec<UnifiedMessage>,
) -> Result<(Vec<GeminiContent>, Option<GeminiContent>), AppError> {
let mut contents: Vec<GeminiContent> = Vec::new();
let mut system_parts = Vec::new();
// PRE-PASS: Build tool_id -> function_name mapping for tool responses
let mut tool_id_to_name = std::collections::HashMap::new();
for msg in &messages {
if let Some(tool_calls) = &msg.tool_calls {
for tc in tool_calls {
tool_id_to_name.insert(tc.id.clone(), tc.function.name.clone());
}
}
}
for msg in messages {
if msg.role == "system" {
for part in msg.content {
if let ContentPart::Text { text } = part {
if !text.trim().is_empty() {
system_parts.push(GeminiPart {
text: Some(text),
inline_data: None,
function_call: None,
function_response: None,
thought: None,
thought_signature_snake: None,
thought_signature_camel: None,
});
}
}
}
continue;
}
let role = match msg.role.as_str() {
"assistant" => "model".to_string(),
"tool" => "user".to_string(), // Tool results are user-side in Gemini
_ => "user".to_string(),
};
let mut parts = Vec::new();
// Handle tool results (role "tool")
if msg.role == "tool" {
let text_content = msg
.content
.first()
.map(|p| match p {
ContentPart::Text { text } => text.clone(),
ContentPart::Image(_) => "[Image]".to_string(),
})
.unwrap_or_default();
// RESOLVE: Use msg.name if present, otherwise look up by tool_call_id
let name = msg.name.clone()
.or_else(|| {
msg.tool_call_id.as_ref()
.and_then(|id| tool_id_to_name.get(id).cloned())
})
.or_else(|| msg.tool_call_id.clone())
.unwrap_or_else(|| "unknown_function".to_string());
// Gemini API requires 'response' to be a JSON object (google.protobuf.Struct).
// If it is an array or primitive, wrap it in an object.
let mut response_value = serde_json::from_str::<Value>(&text_content)
.unwrap_or_else(|_| serde_json::json!({ "result": text_content }));
if !response_value.is_object() {
response_value = serde_json::json!({ "result": response_value });
}
parts.push(GeminiPart {
text: None,
inline_data: None,
function_call: None,
function_response: Some(GeminiFunctionResponse {
name,
response: response_value,
}),
thought: None,
thought_signature_snake: None,
thought_signature_camel: None,
});
} else if msg.role == "assistant" {
// Assistant messages: handle text, thought (reasoning), and tool_calls
for p in &msg.content {
if let ContentPart::Text { text } = p {
if !text.trim().is_empty() {
parts.push(GeminiPart {
text: Some(text.clone()),
inline_data: None,
function_call: None,
function_response: None,
thought: None,
thought_signature_snake: None,
thought_signature_camel: None,
});
}
}
}
// If reasoning_content is present, include it as a 'thought' part
if let Some(reasoning) = &msg.reasoning_content {
if !reasoning.trim().is_empty() {
parts.push(GeminiPart {
text: None,
inline_data: None,
function_call: None,
function_response: None,
thought: Some(reasoning.clone()),
thought_signature_snake: None,
thought_signature_camel: None,
});
}
}
if let Some(tool_calls) = &msg.tool_calls {
for tc in tool_calls {
let args = serde_json::from_str::<Value>(&tc.function.arguments)
.unwrap_or_else(|_| serde_json::json!({}));
// RESTORE: Only use tc.id as thought_signature if it's NOT a synthetic ID.
// Synthetic IDs (starting with 'call_') cause 400 errors as they are not valid Base64 for the TYPE_BYTES field.
let thought_signature = if tc.id.starts_with("call_") {
None
} else {
Some(tc.id.clone())
};
parts.push(GeminiPart {
text: None,
inline_data: None,
function_call: Some(GeminiFunctionCall {
name: tc.function.name.clone(),
args,
}),
function_response: None,
thought: None,
thought_signature_snake: thought_signature.clone(),
thought_signature_camel: thought_signature,
});
}
}
} else {
// Regular text/image messages (mostly user)
for part in msg.content {
match part {
ContentPart::Text { text } => {
if !text.trim().is_empty() {
parts.push(GeminiPart {
text: Some(text),
inline_data: None,
function_call: None,
function_response: None,
thought: None,
thought_signature_snake: None,
thought_signature_camel: None,
});
}
}
ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input
.to_base64()
.await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
parts.push(GeminiPart {
text: None,
inline_data: Some(GeminiInlineData {
mime_type,
data: base64_data,
}),
function_call: None,
function_response: None,
thought: None,
thought_signature_snake: None,
thought_signature_camel: None,
});
}
}
}
}
if parts.is_empty() {
continue;
}
// STRATEGY: Strictly enforce alternating roles.
if let Some(last_content) = contents.last_mut() {
if last_content.role.as_ref() == Some(&role) {
last_content.parts.extend(parts);
continue;
}
}
contents.push(GeminiContent {
parts,
role: Some(role),
});
}
// Gemini requires the first message to be from "user".
if let Some(first) = contents.first() {
if first.role.as_deref() == Some("model") {
contents.insert(0, GeminiContent {
role: Some("user".to_string()),
parts: vec![GeminiPart {
text: Some("Continue conversation.".to_string()),
inline_data: None,
function_call: None,
function_response: None,
thought: None,
thought_signature_snake: None,
thought_signature_camel: None,
}],
});
}
}
// Final check: ensure we don't have empty contents after filtering.
if contents.is_empty() && system_parts.is_empty() {
return Err(AppError::ProviderError("No valid content parts after filtering".to_string()));
}
let system_instruction = if !system_parts.is_empty() {
Some(GeminiContent {
parts: system_parts,
role: None,
})
} else {
None
};
Ok((contents, system_instruction))
}
/// Convert OpenAI-style tool definitions into Gemini function declarations.
/// Returns `None` when the request carries no tools.
fn convert_tools(request: &UnifiedRequest) -> Option<Vec<GeminiTool>> {
    let tools = request.tools.as_ref()?;
    let mut declarations = Vec::with_capacity(tools.len());
    for tool in tools {
        // Default to an empty object schema when no parameters were provided.
        let mut parameters = match tool.function.parameters.clone() {
            Some(schema) => schema,
            None => serde_json::json!({
                "type": "object",
                "properties": {}
            }),
        };
        // Strip schema fields Gemini's API rejects.
        Self::sanitize_schema(&mut parameters);
        declarations.push(GeminiFunctionDeclaration {
            name: tool.function.name.clone(),
            description: tool.function.description.clone(),
            parameters: Some(parameters),
        });
    }
    // Gemini groups all declarations under a single tool object.
    Some(vec![GeminiTool {
        function_declarations: declarations,
    }])
}
/// Recursively remove JSON Schema fields that Gemini's API rejects.
///
/// At every object level this strips `$schema`, `additionalProperties` and the
/// exclusive bounds, then recurses into `properties` values, `items`, and each
/// branch of the `anyOf`/`oneOf`/`allOf` combinators (previously three
/// copy-pasted blocks, now a single loop).
fn sanitize_schema(value: &mut Value) {
    if let Value::Object(map) = value {
        // Remove unsupported fields at this level.
        map.remove("$schema");
        map.remove("additionalProperties");
        map.remove("exclusiveMaximum");
        map.remove("exclusiveMinimum");
        // Recursively sanitize all object properties.
        if let Some(Value::Object(props_map)) = map.get_mut("properties") {
            for prop_value in props_map.values_mut() {
                Self::sanitize_schema(prop_value);
            }
        }
        // Recursively sanitize array items.
        if let Some(items) = map.get_mut("items") {
            Self::sanitize_schema(items);
        }
        // Gemini 1.5/2.0+ supports anyOf in some contexts, but schemas nested
        // in combinators are a common source of additionalProperties errors —
        // sanitize every branch of every combinator.
        for combinator in ["anyOf", "oneOf", "allOf"] {
            if let Some(Value::Array(branches)) = map.get_mut(combinator) {
                for branch in branches {
                    Self::sanitize_schema(branch);
                }
            }
        }
    }
}
/// Map an OpenAI `tool_choice` onto Gemini's `toolConfig`.
/// Returns `None` when the request specifies no tool choice.
fn convert_tool_config(request: &UnifiedRequest) -> Option<GeminiToolConfig> {
    let choice = request.tool_choice.as_ref()?;
    let (mode, allowed_function_names) = match choice {
        crate::models::ToolChoice::Mode(mode) => {
            // "required" maps to Gemini's ANY; unknown modes fall back to AUTO.
            let gemini_mode = match mode.as_str() {
                "auto" => "AUTO",
                "none" => "NONE",
                "required" => "ANY",
                _ => "AUTO",
            };
            (gemini_mode.to_string(), None)
        }
        crate::models::ToolChoice::Specific(specific) => {
            // Force exactly this one function via ANY + an allow-list of one.
            ("ANY".to_string(), Some(vec![specific.function.name.clone()]))
        }
    };
    Some(GeminiToolConfig {
        function_calling_config: GeminiFunctionCallingConfig {
            mode,
            allowed_function_names,
        },
    })
}
/// Extract Gemini functionCall parts into OpenAI-format `ToolCall`s.
///
/// The tool-call id is the part's thoughtSignature when the API provided one
/// (either key spelling); otherwise a synthetic `call_<uuid>` id is generated.
/// Returns `None` when the parts contain no function calls.
fn extract_tool_calls(parts: &[GeminiPart]) -> Option<Vec<ToolCall>> {
    let calls: Vec<ToolCall> = parts
        .iter()
        // filter_map replaces the filter(is_some) + unwrap pattern, so no
        // unwrap can desynchronize from the presence check.
        .filter_map(|p| {
            let fc = p.function_call.as_ref()?;
            // CAPTURE: prefer a real thought signature from sibling fields.
            let id = p
                .thought_signature_camel
                .clone()
                .or_else(|| p.thought_signature_snake.clone())
                .unwrap_or_else(|| format!("call_{}", Uuid::new_v4().simple()));
            Some(ToolCall {
                id,
                call_type: "function".to_string(),
                function: FunctionCall {
                    name: fc.name.clone(),
                    arguments: serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string()),
                },
            })
        })
        .collect();
    if calls.is_empty() { None } else { Some(calls) }
}
/// Determine the appropriate base URL for the model.
/// "preview" models often require the v1beta endpoint, but newer promoted ones may be on v1.
fn get_base_url(&self, model: &str) -> String {
    let base = &self.config.base_url;
    let needs_v1beta =
        model.contains("preview") || model.contains("thinking") || model.contains("gemini-3");
    // Upgrade a trailing "/v1" to "/v1beta" by appending to the verified
    // suffix. The previous `base.replace("/v1", "/v1beta")` substituted EVERY
    // "/v1" occurrence, corrupting URLs that contain "/v1" elsewhere in the
    // path or host (str::replace replaces all matches).
    if needs_v1beta && base.ends_with("/v1") {
        return format!("{}beta", base);
    }
    // If the model is a standard model but the base is v1beta, we could downgrade it,
    // but typically v1beta is a superset, so we just return the base as configured.
    base.clone()
}
/// Default safety settings: set every harm category to BLOCK_NONE so the API
/// does not withhold responses.
fn get_safety_settings(&self, base_url: &str) -> Vec<GeminiSafetySetting> {
    let mut settings: Vec<GeminiSafetySetting> = [
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    ]
    .iter()
    .map(|category| GeminiSafetySetting {
        category: (*category).to_string(),
        threshold: "BLOCK_NONE".to_string(),
    })
    .collect();
    // Civic integrity is only available in v1beta.
    if base_url.contains("v1beta") {
        settings.push(GeminiSafetySetting {
            category: "HARM_CATEGORY_CIVIC_INTEGRITY".to_string(),
            threshold: "BLOCK_NONE".to_string(),
        });
    }
    settings
}
}
#[async_trait]
impl super::Provider for GeminiProvider {
/// Stable provider identifier used for registry lookup and key resolution.
fn name(&self) -> &str {
    "gemini"
}
/// A model is routed to this provider when its name carries the `gemini-` prefix.
fn supports_model(&self, model: &str) -> bool {
    model.starts_with("gemini-")
}
/// Image inputs are accepted (converted to inlineData parts in convert_messages).
fn supports_multimodal(&self) -> bool {
    true // Gemini supports vision
}
/// Non-streaming chat completion against Gemini's `generateContent` endpoint.
///
/// Normalizes the model name, converts messages/tools to Gemini's wire format,
/// issues the request, and maps the first candidate back to a
/// `ProviderResponse` (text, reasoning, tool calls, token usage).
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
    let mut model = request.model.clone();
    // Normalize model name: known Gemini versions pass through; any other
    // "gemini-*" (unknown/legacy) falls back to the configured default to
    // avoid 400 errors. gemini-3+ is accepted as valid.
    let is_known_version = model.starts_with("gemini-1.5")
        || model.starts_with("gemini-2.0")
        || model.starts_with("gemini-2.5")
        || model.starts_with("gemini-3");
    if !is_known_version && model.starts_with("gemini-") {
        tracing::info!("Mapping unknown Gemini model {} to default {}", model, self.config.default_model);
        model = self.config.default_model.clone();
    }
    let tools = Self::convert_tools(&request);
    let tool_config = Self::convert_tool_config(&request);
    let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
    if contents.is_empty() && system_instruction.is_none() {
        return Err(AppError::ProviderError("No valid messages to send".to_string()));
    }
    let base_url = self.get_base_url(&model);
    // Sanitize stop sequences: Gemini rejects empty strings; drop the whole
    // list if nothing survives.
    let stop_sequences = request
        .stop
        .map(|s| s.into_iter().filter(|seq| !seq.is_empty()).collect::<Vec<_>>())
        .filter(|s| !s.is_empty());
    let generation_config = Some(GeminiGenerationConfig {
        temperature: request.temperature,
        top_p: request.top_p,
        top_k: request.top_k,
        // Clamp to the output-token ceiling accepted by the API.
        max_output_tokens: request.max_tokens.map(|t| t.min(65536)),
        stop_sequences,
        candidate_count: request.n,
    });
    let gemini_request = GeminiRequest {
        contents,
        system_instruction,
        generation_config,
        tools,
        tool_config,
        safety_settings: Some(self.get_safety_settings(&base_url)),
    };
    let url = format!("{}/models/{}:generateContent", base_url, model);
    tracing::info!("Calling Gemini API: {}", url);
    let response = self
        .client
        .post(&url)
        .header("x-goog-api-key", &self.api_key)
        .json(&gemini_request)
        .send()
        .await
        .map_err(|e| AppError::ProviderError(format!("HTTP request failed: {}", e)))?;
    let status = response.status();
    if !status.is_success() {
        let error_text = response.text().await.unwrap_or_default();
        return Err(AppError::ProviderError(format!(
            "Gemini API error ({}): {}",
            status, error_text
        )));
    }
    let gemini_response: GeminiResponse = response
        .json()
        .await
        .map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?;
    let candidate = gemini_response.candidates.first();
    // Extract text content (first text part of the first candidate).
    let content = candidate
        .and_then(|c| c.content.parts.iter().find_map(|p| p.text.clone()))
        .unwrap_or_default();
    // Extract reasoning (Gemini 3 'thought' parts).
    let reasoning_content =
        candidate.and_then(|c| c.content.parts.iter().find_map(|p| p.thought.clone()));
    // Extract function calls → OpenAI tool_calls.
    let tool_calls = candidate.and_then(|c| Self::extract_tool_calls(&c.content.parts));
    // Borrow usage_metadata once instead of re-borrowing it per counter.
    let usage = gemini_response.usage_metadata.as_ref();
    let prompt_tokens = usage.map(|u| u.prompt_token_count).unwrap_or(0);
    let completion_tokens = usage.map(|u| u.candidates_token_count).unwrap_or(0);
    let total_tokens = usage.map(|u| u.total_token_count).unwrap_or(0);
    let cache_read_tokens = usage.map(|u| u.cached_content_token_count).unwrap_or(0);
    Ok(ProviderResponse {
        content,
        reasoning_content,
        tool_calls,
        prompt_tokens,
        completion_tokens,
        total_tokens,
        cache_read_tokens,
        cache_write_tokens: 0, // Gemini doesn't report cache writes separately
        model,
    })
}
/// Heuristic token estimate; delegates to the shared per-model estimator.
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
    Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
/// Compute the request cost via the shared registry/pricing helper.
///
/// Looks up `model` in the registry, then this provider's configured pricing
/// table; the trailing 0.075 / 0.30 literals are fallback input/output prices
/// (NOTE(review): presumably USD per 1M tokens — confirm against the helpers
/// module before relying on the unit).
fn calculate_cost(
    &self,
    model: &str,
    prompt_tokens: u32,
    completion_tokens: u32,
    cache_read_tokens: u32,
    cache_write_tokens: u32,
    registry: &crate::models::registry::ModelRegistry,
) -> f64 {
    super::helpers::calculate_cost_with_registry(
        model,
        prompt_tokens,
        completion_tokens,
        cache_read_tokens,
        cache_write_tokens,
        registry,
        &self.pricing,
        0.075,
        0.30,
    )
}
/// Streaming chat completion against Gemini's `streamGenerateContent` (SSE).
///
/// Builds the same request body as `chat_completion`, then yields
/// `ProviderStreamChunk`s as SSE events arrive. On a stream-level error the
/// identical request is re-sent to the non-streaming endpoint ("probe")
/// solely to obtain a readable JSON error body, since SSE failures often
/// hide the provider's 400 details.
async fn chat_completion_stream(
    &self,
    request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
    let mut model = request.model.clone();
    // Normalize model name: fallback to default if unknown Gemini model is requested.
    let is_known_version = model.starts_with("gemini-1.5")
        || model.starts_with("gemini-2.0")
        || model.starts_with("gemini-2.5")
        || model.starts_with("gemini-3");
    if !is_known_version && model.starts_with("gemini-") {
        tracing::info!("Mapping unknown Gemini model {} to default {}", model, self.config.default_model);
        model = self.config.default_model.clone();
    }
    let tools = Self::convert_tools(&request);
    let tool_config = Self::convert_tool_config(&request);
    let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
    if contents.is_empty() && system_instruction.is_none() {
        return Err(AppError::ProviderError("No valid messages to send".to_string()));
    }
    let base_url = self.get_base_url(&model);
    // Sanitize stop sequences: Gemini rejects empty strings.
    let stop_sequences = request
        .stop
        .map(|s| s.into_iter().filter(|seq| !seq.is_empty()).collect::<Vec<_>>())
        .filter(|s| !s.is_empty());
    let generation_config = Some(GeminiGenerationConfig {
        temperature: request.temperature,
        top_p: request.top_p,
        top_k: request.top_k,
        max_output_tokens: request.max_tokens.map(|t| t.min(65536)),
        stop_sequences,
        candidate_count: request.n,
    });
    let gemini_request = GeminiRequest {
        contents,
        system_instruction,
        generation_config,
        tools,
        tool_config,
        safety_settings: Some(self.get_safety_settings(&base_url)),
    };
    let url = format!(
        "{}/models/{}:streamGenerateContent?alt=sse",
        base_url, model,
    );
    tracing::info!("Calling Gemini Stream API: {}", url);
    // Clones captured by the stream for the error-probe path (Gemini 400s are common).
    let probe_request = gemini_request.clone();
    let probe_client = self.client.clone();
    // Use the non-streaming URL for probing to get a valid JSON error body.
    let probe_url = format!("{}/models/{}:generateContent", base_url, model);
    let probe_api_key = self.api_key.clone();
    // Create the EventSource first (it doesn't send until polled).
    let es = reqwest_eventsource::EventSource::new(
        self.client
            .post(&url)
            .header("x-goog-api-key", &self.api_key)
            .header("Accept", "text/event-stream")
            .json(&gemini_request),
    )
    .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
    let stream = async_stream::try_stream! {
        let mut es = es;
        // Track tool call IDs by part index so each id stays stable across chunks:
        // Gemini doesn't always repeat the thoughtSignature for the same part.
        let mut tool_call_ids: std::collections::HashMap<u32, String> = std::collections::HashMap::new();
        let mut seen_tool_calls = false;
        while let Some(event) = es.next().await {
            match event {
                Ok(Event::Message(msg)) => {
                    let gemini_response: GeminiStreamResponse = serde_json::from_str(&msg.data)
                        .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
                    tracing::info!("Received Gemini stream chunk (candidates: {}, has_usage: {}, finish_reason: {:?})",
                        gemini_response.candidates.len(),
                        gemini_response.usage_metadata.is_some(),
                        gemini_response.candidates.first().and_then(|c| c.finish_reason.as_deref())
                    );
                    // Extract usage from usageMetadata if present (reported on every/last chunk).
                    let stream_usage = gemini_response.usage_metadata.as_ref().map(|u| {
                        super::StreamUsage {
                            prompt_tokens: u.prompt_token_count,
                            completion_tokens: u.candidates_token_count,
                            total_tokens: u.total_token_count,
                            cache_read_tokens: u.cached_content_token_count,
                            cache_write_tokens: 0,
                        }
                    });
                    // Some streaming events may not contain candidates (e.g. promptFeedback).
                    // Only emit chunks when we have candidate content or tool calls.
                    if let Some(candidate) = gemini_response.candidates.first() {
                        if let Some(content_obj) = &candidate.content {
                            let content = content_obj
                                .parts
                                .iter()
                                .find_map(|p| p.text.clone())
                                .unwrap_or_default();
                            let reasoning_content = content_obj
                                .parts
                                .iter()
                                .find_map(|p| p.thought.clone());
                            // Extract tool calls with index and ID stability.
                            let mut deltas = Vec::new();
                            for (p_idx, p) in content_obj.parts.iter().enumerate() {
                                if let Some(fc) = &p.function_call {
                                    seen_tool_calls = true;
                                    let tool_call_idx = p_idx as u32;
                                    // Attempt to find a signature in sibling fields.
                                    let signature = p.thought_signature_camel.clone()
                                        .or_else(|| p.thought_signature_snake.clone());
                                    // Keep the ID stable for this tool-call index; if a real
                                    // signature shows up later, upgrade a synthetic 'call_' id.
                                    let entry = tool_call_ids.entry(tool_call_idx);
                                    let current_id = match entry {
                                        std::collections::hash_map::Entry::Occupied(mut e) => {
                                            if let Some(sig) = signature {
                                                if e.get().starts_with("call_") {
                                                    e.insert(sig);
                                                }
                                            }
                                            e.get().clone()
                                        }
                                        std::collections::hash_map::Entry::Vacant(e) => {
                                            let id = signature.unwrap_or_else(|| format!("call_{}", Uuid::new_v4().simple()));
                                            e.insert(id.clone());
                                            id
                                        }
                                    };
                                    deltas.push(ToolCallDelta {
                                        index: tool_call_idx,
                                        id: Some(current_id),
                                        call_type: Some("function".to_string()),
                                        function: Some(FunctionCallDelta {
                                            name: Some(fc.name.clone()),
                                            arguments: Some(serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string())),
                                        }),
                                    });
                                }
                            }
                            let tool_calls = if deltas.is_empty() { None } else { Some(deltas) };
                            // Determine finish_reason: map Gemini's "STOP" to "stop",
                            // lowercase everything else.
                            let mut finish_reason = candidate.finish_reason.as_ref().map(|fr| {
                                match fr.as_str() {
                                    "STOP" => "stop".to_string(),
                                    _ => fr.to_lowercase(),
                                }
                            });
                            // STRATEGY: if any tool call appeared in this stream, a provider
                            // "stop" must surface as "tool_calls" so the client executes the
                            // tools. Mid-stream chunks that carry tool calls without a finish
                            // signal are deliberately left untouched — OpenAI-style clients
                            // expect the override only at the actual stop.
                            if seen_tool_calls && finish_reason.as_deref() == Some("stop") {
                                finish_reason = Some("tool_calls".to_string());
                            }
                            // Avoid emitting completely empty chunks unless they carry usage.
                            if !content.is_empty() || reasoning_content.is_some() || tool_calls.is_some() || stream_usage.is_some() {
                                yield ProviderStreamChunk {
                                    content,
                                    reasoning_content,
                                    finish_reason,
                                    tool_calls,
                                    model: model.clone(),
                                    usage: stream_usage,
                                };
                            }
                        } else if stream_usage.is_some() {
                            // Usage-only update.
                            yield ProviderStreamChunk {
                                content: String::new(),
                                reasoning_content: None,
                                finish_reason: None,
                                tool_calls: None,
                                model: model.clone(),
                                usage: stream_usage,
                            };
                        }
                    } else if stream_usage.is_some() {
                        // No candidates but usage present.
                        yield ProviderStreamChunk {
                            content: String::new(),
                            reasoning_content: None,
                            finish_reason: None,
                            tool_calls: None,
                            model: model.clone(),
                            usage: stream_usage,
                        };
                    }
                }
                Ok(_) => continue,
                Err(e) => {
                    // reqwest-eventsource signals normal EOF via the StreamEnded error
                    // variant; match it directly instead of relying solely on the
                    // fragile Display-text check (kept as a fallback).
                    if matches!(e, reqwest_eventsource::Error::StreamEnded)
                        || e.to_string().contains("Stream ended")
                    {
                        break;
                    }
                    // On a real stream error, probe the non-streaming endpoint for the
                    // provider's actual error body.
                    let probe_resp = probe_client
                        .post(&probe_url)
                        .header("x-goog-api-key", &probe_api_key)
                        .json(&probe_request)
                        .send()
                        .await;
                    match probe_resp {
                        Ok(r) if !r.status().is_success() => {
                            let status = r.status();
                            let body = r.text().await.unwrap_or_default();
                            tracing::error!("Gemini Stream Error Probe ({}): {}", status, body);
                            Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, body)))?;
                        }
                        _ => {
                            Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
                        }
                    }
                }
            }
        }
    };
    Ok(Box::pin(stream))
}
}