//! Gemini provider.
//!
//! Note on streaming: Gemini often sends tool calls in one chunk and then
//! 'STOP' in a final chunk. If we pass the raw 'stop' through at the end,
//! clients stop and ignore the previously received tool calls. We therefore
//! track whether any tool calls were seen and override the final 'stop'
//! finish reason to 'tool_calls'.
use anyhow::Result;
|
|
use async_trait::async_trait;
|
|
use futures::stream::{BoxStream, StreamExt};
|
|
use reqwest_eventsource::Event;
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::Value;
|
|
use uuid::Uuid;
|
|
|
|
use super::{ProviderResponse, ProviderStreamChunk};
|
|
use crate::{
|
|
config::AppConfig,
|
|
errors::AppError,
|
|
models::{ContentPart, FunctionCall, FunctionCallDelta, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest},
|
|
};
|
|
|
|
// ========== Gemini Request Structs ==========
|
|
|
|
/// Top-level request body for Gemini `generateContent` / `streamGenerateContent`.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiRequest {
    /// Conversation turns (alternating "user"/"model" roles).
    contents: Vec<GeminiContent>,
    /// System prompt; Gemini takes it separately from `contents`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GeminiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GeminiGenerationConfig>,
    /// Function declarations exposed to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GeminiTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_config: Option<GeminiToolConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    safety_settings: Option<Vec<GeminiSafetySetting>>,
}

/// One safety-filter setting: a harm category plus its blocking threshold.
#[derive(Debug, Clone, Serialize)]
struct GeminiSafetySetting {
    category: String,
    threshold: String,
}

/// A single conversation turn: an ordered list of parts plus an optional role.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiContent {
    parts: Vec<GeminiPart>,
    /// "user" or "model"; left `None` for the system instruction.
    #[serde(skip_serializing_if = "Option::is_none")]
    role: Option<String>,
}

/// One part of a turn. Normally exactly one of the payload fields is set.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiPart {
    #[serde(skip_serializing_if = "Option::is_none")]
    text: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    inline_data: Option<GeminiInlineData>,
    #[serde(skip_serializing_if = "Option::is_none")]
    function_call: Option<GeminiFunctionCall>,
    #[serde(skip_serializing_if = "Option::is_none")]
    function_response: Option<GeminiFunctionResponse>,
    /// Reasoning ("thought") text, as this proxy models it for thinking models.
    #[serde(skip_serializing_if = "Option::is_none")]
    thought: Option<String>,
    // The thought signature has been observed under both key spellings,
    // so both are modeled side by side and checked in preference order
    // (camelCase first — see `extract_tool_calls`).
    #[serde(skip_serializing_if = "Option::is_none", rename = "thought_signature")]
    thought_signature_snake: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none", rename = "thoughtSignature")]
    thought_signature_camel: Option<String>,
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct GeminiInlineData {
|
|
mime_type: String,
|
|
data: String,
|
|
}
|
|
|
|
/// A function call emitted by (or replayed to) the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiFunctionCall {
    name: String,
    /// Arguments as an arbitrary JSON value.
    args: Value,
}

/// A tool result sent back to the model.
/// `response` must be a JSON object — callers wrap non-objects (see `convert_messages`).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiFunctionResponse {
    name: String,
    response: Value,
}

/// Sampling / output controls; every field is optional so unset values are omitted.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<u32>,
    // Callers cap this at 65536 before building the config.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u32>,
    // Callers strip empty strings first; Gemini rejects them.
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    candidate_count: Option<u32>,
}
|
|
|
|
// ========== Gemini Tool Structs ==========
|
|
|
|
/// Wrapper exposing a set of function declarations to the model.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiTool {
    function_declarations: Vec<GeminiFunctionDeclaration>,
}

/// A single callable function: name plus optional description and JSON-Schema parameters.
#[derive(Debug, Clone, Serialize)]
struct GeminiFunctionDeclaration {
    name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    /// JSON Schema, pre-filtered by `sanitize_schema` to drop fields Gemini rejects.
    #[serde(skip_serializing_if = "Option::is_none")]
    parameters: Option<Value>,
}

/// Tool-choice configuration wrapper.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiToolConfig {
    function_calling_config: GeminiFunctionCallingConfig,
}

/// Function-calling mode ("AUTO" / "NONE" / "ANY"), optionally
/// restricted to a set of allowed function names.
#[derive(Debug, Clone, Serialize)]
struct GeminiFunctionCallingConfig {
    mode: String,
    #[serde(skip_serializing_if = "Option::is_none", rename = "allowedFunctionNames")]
    allowed_function_names: Option<Vec<String>>,
}
|
|
|
|
// ========== Gemini Response Structs ==========
|
|
|
|
/// One generated candidate in a non-streaming response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiCandidate {
    content: GeminiContent,
    // Parsed for completeness but unused in the non-streaming path.
    #[serde(default)]
    #[allow(dead_code)]
    finish_reason: Option<String>,
}

/// Token accounting reported by Gemini; all counts default to 0 when absent.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiUsageMetadata {
    #[serde(default)]
    prompt_token_count: u32,
    #[serde(default)]
    candidates_token_count: u32,
    #[serde(default)]
    total_token_count: u32,
    #[serde(default)]
    cached_content_token_count: u32,
}

/// Full body of a non-streaming `generateContent` response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiResponse {
    candidates: Vec<GeminiCandidate>,
    usage_metadata: Option<GeminiUsageMetadata>,
}

// Streaming responses from Gemini may include messages without `candidates` (e.g. promptFeedback).
// Use a more permissive struct for streaming to avoid aborting the SSE prematurely.
#[derive(Debug, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
struct GeminiStreamResponse {
    #[serde(default)]
    candidates: Vec<GeminiStreamCandidate>,
    #[serde(default)]
    usage_metadata: Option<GeminiUsageMetadata>,
}

/// Candidate within a streaming chunk; both fields may be missing mid-stream.
#[derive(Debug, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
struct GeminiStreamCandidate {
    #[serde(default)]
    content: Option<GeminiContent>,
    /// Raw provider finish reason (e.g. "STOP"); mapped downstream.
    #[serde(default)]
    finish_reason: Option<String>,
}
|
|
|
|
// ========== Provider Implementation ==========
|
|
|
|
/// Provider implementation for Google's Gemini API.
pub struct GeminiProvider {
    /// Shared HTTP client; timeouts/keep-alive tuned in `new_with_key`.
    client: reqwest::Client,
    /// Gemini-specific configuration (base URL, default model, ...).
    config: crate::config::GeminiConfig,
    /// API key sent via the `x-goog-api-key` header.
    api_key: String,
    /// Pricing table snapshot used by `calculate_cost`.
    pricing: Vec<crate::config::ModelPricing>,
}
|
|
|
|
impl GeminiProvider {
|
|
pub fn new(config: &crate::config::GeminiConfig, app_config: &AppConfig) -> Result<Self> {
|
|
let api_key = app_config.get_api_key("gemini")?;
|
|
Self::new_with_key(config, app_config, api_key)
|
|
}
|
|
|
|
pub fn new_with_key(config: &crate::config::GeminiConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
|
let client = reqwest::Client::builder()
|
|
.connect_timeout(std::time::Duration::from_secs(5))
|
|
.timeout(std::time::Duration::from_secs(300))
|
|
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
|
.pool_max_idle_per_host(4)
|
|
.tcp_keepalive(std::time::Duration::from_secs(30))
|
|
.build()?;
|
|
|
|
Ok(Self {
|
|
client,
|
|
config: config.clone(),
|
|
api_key,
|
|
pricing: app_config.pricing.gemini.clone(),
|
|
})
|
|
}
|
|
|
|
/// Convert unified messages to Gemini content format.
|
|
/// Handles text, images, tool calls (assistant), and tool results.
|
|
/// Returns (contents, system_instruction)
|
|
async fn convert_messages(
|
|
messages: Vec<UnifiedMessage>,
|
|
) -> Result<(Vec<GeminiContent>, Option<GeminiContent>), AppError> {
|
|
let mut contents: Vec<GeminiContent> = Vec::new();
|
|
let mut system_parts = Vec::new();
|
|
|
|
for msg in messages {
|
|
if msg.role == "system" {
|
|
for part in msg.content {
|
|
if let ContentPart::Text { text } = part {
|
|
if !text.trim().is_empty() {
|
|
system_parts.push(GeminiPart {
|
|
text: Some(text),
|
|
inline_data: None,
|
|
function_call: None,
|
|
function_response: None,
|
|
thought: None,
|
|
thought_signature_snake: None,
|
|
thought_signature_camel: None,
|
|
});
|
|
}
|
|
}
|
|
}
|
|
continue;
|
|
}
|
|
|
|
let role = match msg.role.as_str() {
|
|
"assistant" => "model".to_string(),
|
|
"tool" => "user".to_string(), // Tool results are user-side in Gemini
|
|
_ => "user".to_string(),
|
|
};
|
|
|
|
let mut parts = Vec::new();
|
|
|
|
// Handle tool results (role "tool")
|
|
if msg.role == "tool" {
|
|
let text_content = msg
|
|
.content
|
|
.first()
|
|
.map(|p| match p {
|
|
ContentPart::Text { text } => text.clone(),
|
|
ContentPart::Image(_) => "[Image]".to_string(),
|
|
})
|
|
.unwrap_or_default();
|
|
|
|
let name = msg.name.clone().or_else(|| msg.tool_call_id.clone()).unwrap_or_else(|| "unknown_function".to_string());
|
|
|
|
// Gemini API requires 'response' to be a JSON object (google.protobuf.Struct).
|
|
// If it is an array or primitive, wrap it in an object.
|
|
let mut response_value = serde_json::from_str::<Value>(&text_content)
|
|
.unwrap_or_else(|_| serde_json::json!({ "result": text_content }));
|
|
|
|
if !response_value.is_object() {
|
|
response_value = serde_json::json!({ "result": response_value });
|
|
}
|
|
|
|
parts.push(GeminiPart {
|
|
text: None,
|
|
inline_data: None,
|
|
function_call: None,
|
|
function_response: Some(GeminiFunctionResponse {
|
|
name,
|
|
response: response_value,
|
|
}),
|
|
thought: None,
|
|
thought_signature_snake: None,
|
|
thought_signature_camel: None,
|
|
});
|
|
} else if msg.role == "assistant" {
|
|
// Assistant messages: handle text, thought (reasoning), and tool_calls
|
|
for p in &msg.content {
|
|
if let ContentPart::Text { text } = p {
|
|
if !text.trim().is_empty() {
|
|
parts.push(GeminiPart {
|
|
text: Some(text.clone()),
|
|
inline_data: None,
|
|
function_call: None,
|
|
function_response: None,
|
|
thought: None,
|
|
thought_signature_snake: None,
|
|
thought_signature_camel: None,
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
// If reasoning_content is present, include it as a 'thought' part
|
|
if let Some(reasoning) = &msg.reasoning_content {
|
|
if !reasoning.trim().is_empty() {
|
|
parts.push(GeminiPart {
|
|
text: None,
|
|
inline_data: None,
|
|
function_call: None,
|
|
function_response: None,
|
|
thought: Some(reasoning.clone()),
|
|
thought_signature_snake: None,
|
|
thought_signature_camel: None,
|
|
});
|
|
}
|
|
}
|
|
|
|
if let Some(tool_calls) = &msg.tool_calls {
|
|
for tc in tool_calls {
|
|
let args = serde_json::from_str::<Value>(&tc.function.arguments)
|
|
.unwrap_or_else(|_| serde_json::json!({}));
|
|
|
|
// RESTORE: Use tc.id as thought_signature.
|
|
// Gemini 3 models require this field for any function call in the history.
|
|
// We include it regardless of format to ensure the model has context.
|
|
let thought_signature = Some(tc.id.clone());
|
|
|
|
parts.push(GeminiPart {
|
|
text: None,
|
|
inline_data: None,
|
|
function_call: Some(GeminiFunctionCall {
|
|
name: tc.function.name.clone(),
|
|
args,
|
|
}),
|
|
function_response: None,
|
|
thought: None,
|
|
thought_signature_snake: thought_signature.clone(),
|
|
thought_signature_camel: thought_signature,
|
|
});
|
|
}
|
|
}
|
|
} else {
|
|
// Regular text/image messages (mostly user)
|
|
for part in msg.content {
|
|
match part {
|
|
ContentPart::Text { text } => {
|
|
if !text.trim().is_empty() {
|
|
parts.push(GeminiPart {
|
|
text: Some(text),
|
|
inline_data: None,
|
|
function_call: None,
|
|
function_response: None,
|
|
thought: None,
|
|
thought_signature_snake: None,
|
|
thought_signature_camel: None,
|
|
});
|
|
}
|
|
}
|
|
ContentPart::Image(image_input) => {
|
|
let (base64_data, mime_type) = image_input
|
|
.to_base64()
|
|
.await
|
|
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
|
|
|
parts.push(GeminiPart {
|
|
text: None,
|
|
inline_data: Some(GeminiInlineData {
|
|
mime_type,
|
|
data: base64_data,
|
|
}),
|
|
function_call: None,
|
|
function_response: None,
|
|
thought: None,
|
|
thought_signature_snake: None,
|
|
thought_signature_camel: None,
|
|
});
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if parts.is_empty() {
|
|
continue;
|
|
}
|
|
|
|
// STRATEGY: Strictly enforce alternating roles.
|
|
if let Some(last_content) = contents.last_mut() {
|
|
if last_content.role.as_ref() == Some(&role) {
|
|
last_content.parts.extend(parts);
|
|
continue;
|
|
}
|
|
}
|
|
|
|
contents.push(GeminiContent {
|
|
parts,
|
|
role: Some(role),
|
|
});
|
|
}
|
|
|
|
// Gemini requires the first message to be from "user".
|
|
if let Some(first) = contents.first() {
|
|
if first.role.as_deref() == Some("model") {
|
|
contents.insert(0, GeminiContent {
|
|
role: Some("user".to_string()),
|
|
parts: vec![GeminiPart {
|
|
text: Some("Continue conversation.".to_string()),
|
|
inline_data: None,
|
|
function_call: None,
|
|
function_response: None,
|
|
thought: None,
|
|
thought_signature_snake: None,
|
|
thought_signature_camel: None,
|
|
}],
|
|
});
|
|
}
|
|
}
|
|
|
|
// Final check: ensure we don't have empty contents after filtering.
|
|
if contents.is_empty() && system_parts.is_empty() {
|
|
return Err(AppError::ProviderError("No valid content parts after filtering".to_string()));
|
|
}
|
|
|
|
let system_instruction = if !system_parts.is_empty() {
|
|
Some(GeminiContent {
|
|
parts: system_parts,
|
|
role: None,
|
|
})
|
|
} else {
|
|
None
|
|
};
|
|
|
|
Ok((contents, system_instruction))
|
|
}
|
|
|
|
/// Convert OpenAI tools to Gemini function declarations.
|
|
fn convert_tools(request: &UnifiedRequest) -> Option<Vec<GeminiTool>> {
|
|
request.tools.as_ref().map(|tools| {
|
|
let declarations: Vec<GeminiFunctionDeclaration> = tools
|
|
.iter()
|
|
.map(|t| {
|
|
let mut parameters = t.function.parameters.clone().unwrap_or(serde_json::json!({
|
|
"type": "object",
|
|
"properties": {}
|
|
}));
|
|
Self::sanitize_schema(&mut parameters);
|
|
|
|
GeminiFunctionDeclaration {
|
|
name: t.function.name.clone(),
|
|
description: t.function.description.clone(),
|
|
parameters: Some(parameters),
|
|
}
|
|
})
|
|
.collect();
|
|
vec![GeminiTool {
|
|
function_declarations: declarations,
|
|
}]
|
|
})
|
|
}
|
|
|
|
/// Recursively remove unsupported JSON Schema fields that Gemini's API rejects.
|
|
fn sanitize_schema(value: &mut Value) {
|
|
if let Value::Object(map) = value {
|
|
// Remove unsupported fields at this level
|
|
map.remove("$schema");
|
|
map.remove("additionalProperties");
|
|
map.remove("exclusiveMaximum");
|
|
map.remove("exclusiveMinimum");
|
|
|
|
// Recursively sanitize all object properties
|
|
if let Some(properties) = map.get_mut("properties") {
|
|
if let Value::Object(props_map) = properties {
|
|
for prop_value in props_map.values_mut() {
|
|
Self::sanitize_schema(prop_value);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Recursively sanitize array items
|
|
if let Some(items) = map.get_mut("items") {
|
|
Self::sanitize_schema(items);
|
|
}
|
|
|
|
// Gemini 1.5/2.0+ supports anyOf in some contexts, but it's often
|
|
// the source of additionalProperties errors when nested.
|
|
if let Some(any_of) = map.get_mut("anyOf") {
|
|
if let Value::Array(arr) = any_of {
|
|
for item in arr {
|
|
Self::sanitize_schema(item);
|
|
}
|
|
}
|
|
}
|
|
if let Some(one_of) = map.get_mut("oneOf") {
|
|
if let Value::Array(arr) = one_of {
|
|
for item in arr {
|
|
Self::sanitize_schema(item);
|
|
}
|
|
}
|
|
}
|
|
if let Some(all_of) = map.get_mut("allOf") {
|
|
if let Value::Array(arr) = all_of {
|
|
for item in arr {
|
|
Self::sanitize_schema(item);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Convert OpenAI tool_choice to Gemini tool_config.
|
|
fn convert_tool_config(request: &UnifiedRequest) -> Option<GeminiToolConfig> {
|
|
request.tool_choice.as_ref().map(|tc| {
|
|
let (mode, allowed_names) = match tc {
|
|
crate::models::ToolChoice::Mode(mode) => {
|
|
let gemini_mode = match mode.as_str() {
|
|
"auto" => "AUTO",
|
|
"none" => "NONE",
|
|
"required" => "ANY",
|
|
_ => "AUTO",
|
|
};
|
|
(gemini_mode.to_string(), None)
|
|
}
|
|
crate::models::ToolChoice::Specific(specific) => {
|
|
("ANY".to_string(), Some(vec![specific.function.name.clone()]))
|
|
}
|
|
};
|
|
GeminiToolConfig {
|
|
function_calling_config: GeminiFunctionCallingConfig {
|
|
mode,
|
|
allowed_function_names: allowed_names,
|
|
},
|
|
}
|
|
})
|
|
}
|
|
|
|
/// Extract tool calls from Gemini response parts into OpenAI-format ToolCalls.
|
|
fn extract_tool_calls(parts: &[GeminiPart]) -> Option<Vec<ToolCall>> {
|
|
let calls: Vec<ToolCall> = parts
|
|
.iter()
|
|
.filter(|p| p.function_call.is_some())
|
|
.map(|p| {
|
|
let fc = p.function_call.as_ref().unwrap();
|
|
// CAPTURE: Try extracting thought_signature from sibling fields
|
|
let id = p.thought_signature_camel.clone()
|
|
.or_else(|| p.thought_signature_snake.clone())
|
|
.unwrap_or_else(|| format!("call_{}", Uuid::new_v4().simple()));
|
|
|
|
ToolCall {
|
|
id,
|
|
call_type: "function".to_string(),
|
|
function: FunctionCall {
|
|
name: fc.name.clone(),
|
|
arguments: serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string()),
|
|
},
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
if calls.is_empty() { None } else { Some(calls) }
|
|
}
|
|
|
|
/// Determine the appropriate base URL for the model.
|
|
/// "preview" models often require the v1beta endpoint, but newer promoted ones may be on v1.
|
|
fn get_base_url(&self, model: &str) -> String {
|
|
let base = &self.config.base_url;
|
|
|
|
// If the model requires v1beta but the base is currently v1
|
|
if (model.contains("preview") || model.contains("thinking") || model.contains("gemini-3")) && base.ends_with("/v1") {
|
|
return base.replace("/v1", "/v1beta");
|
|
}
|
|
|
|
// If the model is a standard model but the base is v1beta, we could downgrade it,
|
|
// but typically v1beta is a superset, so we just return the base as configured.
|
|
base.clone()
|
|
}
|
|
|
|
/// Default safety settings to avoid blocking responses.
|
|
fn get_safety_settings(&self, base_url: &str) -> Vec<GeminiSafetySetting> {
|
|
let mut categories = vec![
|
|
"HARM_CATEGORY_HARASSMENT",
|
|
"HARM_CATEGORY_HATE_SPEECH",
|
|
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
|
"HARM_CATEGORY_DANGEROUS_CONTENT",
|
|
];
|
|
|
|
// Civic integrity is only available in v1beta
|
|
if base_url.contains("v1beta") {
|
|
categories.push("HARM_CATEGORY_CIVIC_INTEGRITY");
|
|
}
|
|
|
|
categories
|
|
.into_iter()
|
|
.map(|c| GeminiSafetySetting {
|
|
category: c.to_string(),
|
|
threshold: "BLOCK_NONE".to_string(),
|
|
})
|
|
.collect()
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl super::Provider for GeminiProvider {
|
|
/// Provider identifier used for registry lookup and logging.
fn name(&self) -> &str {
    "gemini"
}

/// This provider claims every model id with the "gemini-" prefix;
/// unknown versions are remapped to the default model at request time.
fn supports_model(&self, model: &str) -> bool {
    model.starts_with("gemini-")
}

/// Gemini requests may carry inline image parts, so multimodal is supported.
fn supports_multimodal(&self) -> bool {
    true // Gemini supports vision
}
|
|
|
|
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
|
let mut model = request.model.clone();
|
|
|
|
// Normalize model name: If it's a known Gemini model version, use it;
|
|
// otherwise, if it starts with gemini- but is an unknown legacy version,
|
|
// fallback to the default model to avoid 400 errors.
|
|
// We now allow gemini-3+ as valid versions.
|
|
let is_known_version = model.starts_with("gemini-1.5") ||
|
|
model.starts_with("gemini-2.0") ||
|
|
model.starts_with("gemini-2.5") ||
|
|
model.starts_with("gemini-3");
|
|
|
|
if !is_known_version && model.starts_with("gemini-") {
|
|
tracing::info!("Mapping unknown Gemini model {} to default {}", model, self.config.default_model);
|
|
model = self.config.default_model.clone();
|
|
}
|
|
|
|
let tools = Self::convert_tools(&request);
|
|
let tool_config = Self::convert_tool_config(&request);
|
|
let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
|
|
|
|
if contents.is_empty() && system_instruction.is_none() {
|
|
return Err(AppError::ProviderError("No valid messages to send".to_string()));
|
|
}
|
|
|
|
let base_url = self.get_base_url(&model);
|
|
|
|
// Sanitize stop sequences: Gemini rejects empty strings
|
|
let stop_sequences = request.stop.map(|s| {
|
|
s.into_iter()
|
|
.filter(|seq| !seq.is_empty())
|
|
.collect::<Vec<_>>()
|
|
}).filter(|s| !s.is_empty());
|
|
|
|
let generation_config = Some(GeminiGenerationConfig {
|
|
temperature: request.temperature,
|
|
top_p: request.top_p,
|
|
top_k: request.top_k,
|
|
max_output_tokens: request.max_tokens.map(|t| t.min(65536)),
|
|
stop_sequences,
|
|
candidate_count: request.n,
|
|
});
|
|
|
|
let gemini_request = GeminiRequest {
|
|
contents,
|
|
system_instruction,
|
|
generation_config,
|
|
tools,
|
|
tool_config,
|
|
safety_settings: Some(self.get_safety_settings(&base_url)),
|
|
};
|
|
|
|
let url = format!("{}/models/{}:generateContent", base_url, model);
|
|
tracing::info!("Calling Gemini API: {}", url);
|
|
|
|
let response = self
|
|
.client
|
|
.post(&url)
|
|
.header("x-goog-api-key", &self.api_key)
|
|
.json(&gemini_request)
|
|
.send()
|
|
.await
|
|
.map_err(|e| AppError::ProviderError(format!("HTTP request failed: {}", e)))?;
|
|
|
|
let status = response.status();
|
|
if !status.is_success() {
|
|
let error_text = response.text().await.unwrap_or_default();
|
|
return Err(AppError::ProviderError(format!(
|
|
"Gemini API error ({}): {}",
|
|
status, error_text
|
|
)));
|
|
}
|
|
|
|
let gemini_response: GeminiResponse = response
|
|
.json()
|
|
.await
|
|
.map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?;
|
|
|
|
let candidate = gemini_response.candidates.first();
|
|
|
|
// Extract text content
|
|
let content = candidate
|
|
.and_then(|c| c.content.parts.iter().find_map(|p| p.text.clone()))
|
|
.unwrap_or_default();
|
|
|
|
// Extract reasoning (Gemini 3 'thought' parts)
|
|
let reasoning_content = candidate
|
|
.and_then(|c| c.content.parts.iter().find_map(|p| p.thought.clone()));
|
|
|
|
// Extract function calls → OpenAI tool_calls
|
|
let tool_calls = candidate.and_then(|c| Self::extract_tool_calls(&c.content.parts));
|
|
|
|
let prompt_tokens = gemini_response
|
|
.usage_metadata
|
|
.as_ref()
|
|
.map(|u| u.prompt_token_count)
|
|
.unwrap_or(0);
|
|
let completion_tokens = gemini_response
|
|
.usage_metadata
|
|
.as_ref()
|
|
.map(|u| u.candidates_token_count)
|
|
.unwrap_or(0);
|
|
let total_tokens = gemini_response
|
|
.usage_metadata
|
|
.as_ref()
|
|
.map(|u| u.total_token_count)
|
|
.unwrap_or(0);
|
|
let cache_read_tokens = gemini_response
|
|
.usage_metadata
|
|
.as_ref()
|
|
.map(|u| u.cached_content_token_count)
|
|
.unwrap_or(0);
|
|
|
|
Ok(ProviderResponse {
|
|
content,
|
|
reasoning_content,
|
|
tool_calls,
|
|
prompt_tokens,
|
|
completion_tokens,
|
|
total_tokens,
|
|
cache_read_tokens,
|
|
cache_write_tokens: 0, // Gemini doesn't report cache writes separately
|
|
model,
|
|
})
|
|
}
|
|
|
|
/// Rough token estimate, delegated to the shared tokenizer heuristics.
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
    Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}

/// Compute the request cost in USD via the shared registry-based helper.
///
/// NOTE(review): the trailing literals (0.075 / 0.30) appear to be fallback
/// input/output prices used when neither the registry nor `self.pricing`
/// knows the model — confirm against `helpers::calculate_cost_with_registry`.
fn calculate_cost(
    &self,
    model: &str,
    prompt_tokens: u32,
    completion_tokens: u32,
    cache_read_tokens: u32,
    cache_write_tokens: u32,
    registry: &crate::models::registry::ModelRegistry,
) -> f64 {
    super::helpers::calculate_cost_with_registry(
        model,
        prompt_tokens,
        completion_tokens,
        cache_read_tokens,
        cache_write_tokens,
        registry,
        &self.pricing,
        0.075,
        0.30,
    )
}
|
|
|
|
/// Streaming chat completion via `streamGenerateContent?alt=sse`.
///
/// Key behaviors:
/// - Tool-call IDs are kept stable across chunks through `tool_call_ids`,
///   keyed by part index, since Gemini doesn't repeat the thoughtSignature
///   in every chunk.
/// - If any tool call was seen, a terminal "stop" finish reason is rewritten
///   to "tool_calls" so OpenAI-style clients actually execute the tools.
/// - On stream errors, the same request is replayed against the non-streaming
///   endpoint ("probe") to recover a readable JSON error body, because SSE
///   transport errors are opaque.
async fn chat_completion_stream(
    &self,
    request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
    let mut model = request.model.clone();

    // Normalize model name: fallback to default if unknown Gemini model is requested
    let is_known_version = model.starts_with("gemini-1.5") ||
        model.starts_with("gemini-2.0") ||
        model.starts_with("gemini-2.5") ||
        model.starts_with("gemini-3");

    if !is_known_version && model.starts_with("gemini-") {
        tracing::info!("Mapping unknown Gemini model {} to default {}", model, self.config.default_model);
        model = self.config.default_model.clone();
    }

    let tools = Self::convert_tools(&request);
    let tool_config = Self::convert_tool_config(&request);
    let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;

    if contents.is_empty() && system_instruction.is_none() {
        return Err(AppError::ProviderError("No valid messages to send".to_string()));
    }

    let base_url = self.get_base_url(&model);

    // Sanitize stop sequences: Gemini rejects empty strings
    let stop_sequences = request.stop.map(|s| {
        s.into_iter()
            .filter(|seq| !seq.is_empty())
            .collect::<Vec<_>>()
    }).filter(|s| !s.is_empty());

    let generation_config = Some(GeminiGenerationConfig {
        temperature: request.temperature,
        top_p: request.top_p,
        top_k: request.top_k,
        max_output_tokens: request.max_tokens.map(|t| t.min(65536)),
        stop_sequences,
        candidate_count: request.n,
    });

    let gemini_request = GeminiRequest {
        contents,
        system_instruction,
        generation_config,
        tools,
        tool_config,
        safety_settings: Some(self.get_safety_settings(&base_url)),
    };

    let url = format!(
        "{}/models/{}:streamGenerateContent?alt=sse",
        base_url, model,
    );
    tracing::info!("Calling Gemini Stream API: {}", url);

    // Capture a clone of the request to probe for errors (Gemini 400s are common)
    let probe_request = gemini_request.clone();
    let probe_client = self.client.clone();
    // Use non-streaming URL for probing to get a valid JSON error body
    let probe_url = format!("{}/models/{}:generateContent", base_url, model);
    let probe_api_key = self.api_key.clone();

    // Create the EventSource first (it doesn't send until polled)
    let es = reqwest_eventsource::EventSource::new(
        self.client
            .post(&url)
            .header("x-goog-api-key", &self.api_key)
            .header("Accept", "text/event-stream")
            .json(&gemini_request),
    ).map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;

    let stream = async_stream::try_stream! {
        let mut es = es;
        // Track tool call IDs by their part index to ensure stability during streaming.
        // Gemini doesn't always include the thoughtSignature in every chunk for the same part.
        let mut tool_call_ids: std::collections::HashMap<u32, String> = std::collections::HashMap::new();
        // Sticky flag: once any tool call is seen, the final "stop" is rewritten.
        let mut seen_tool_calls = false;

        while let Some(event) = es.next().await {
            match event {
                Ok(Event::Message(msg)) => {
                    let gemini_response: GeminiStreamResponse = serde_json::from_str(&msg.data)
                        .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;

                    tracing::info!("Received Gemini stream chunk (candidates: {}, has_usage: {}, finish_reason: {:?})",
                        gemini_response.candidates.len(),
                        gemini_response.usage_metadata.is_some(),
                        gemini_response.candidates.first().and_then(|c| c.finish_reason.as_deref())
                    );

                    // Extract usage from usageMetadata if present (reported on every/last chunk)
                    let stream_usage = gemini_response.usage_metadata.as_ref().map(|u| {
                        super::StreamUsage {
                            prompt_tokens: u.prompt_token_count,
                            completion_tokens: u.candidates_token_count,
                            total_tokens: u.total_token_count,
                            cache_read_tokens: u.cached_content_token_count,
                            cache_write_tokens: 0,
                        }
                    });

                    // Some streaming events may not contain candidates (e.g. promptFeedback).
                    // Only emit chunks when we have candidate content or tool calls.
                    if let Some(candidate) = gemini_response.candidates.first() {
                        if let Some(content_obj) = &candidate.content {
                            // NOTE(review): only the FIRST text/thought part of each
                            // chunk is forwarded — assumes stream chunks carry a single
                            // text part; confirm against observed Gemini chunks.
                            let content = content_obj
                                .parts
                                .iter()
                                .find_map(|p| p.text.clone())
                                .unwrap_or_default();

                            let reasoning_content = content_obj
                                .parts
                                .iter()
                                .find_map(|p| p.thought.clone());

                            // Extract tool calls with index and ID stability
                            let mut deltas = Vec::new();
                            for (p_idx, p) in content_obj.parts.iter().enumerate() {
                                if let Some(fc) = &p.function_call {
                                    seen_tool_calls = true;
                                    let tool_call_idx = p_idx as u32;

                                    // Attempt to find a signature in sibling fields
                                    let signature = p.thought_signature_camel.clone()
                                        .or_else(|| p.thought_signature_snake.clone());

                                    // Ensure the ID remains stable for this tool call index.
                                    // If we found a real signature now, we update it; otherwise use the existing or new random ID.
                                    let entry = tool_call_ids.entry(tool_call_idx);
                                    let current_id = match entry {
                                        std::collections::hash_map::Entry::Occupied(mut e) => {
                                            if let Some(sig) = signature {
                                                // If we previously had a 'call_' ID but now found a real signature, upgrade it.
                                                if e.get().starts_with("call_") {
                                                    e.insert(sig);
                                                }
                                            }
                                            e.get().clone()
                                        }
                                        std::collections::hash_map::Entry::Vacant(e) => {
                                            let id = signature.unwrap_or_else(|| format!("call_{}", Uuid::new_v4().simple()));
                                            e.insert(id.clone());
                                            id
                                        }
                                    };

                                    deltas.push(ToolCallDelta {
                                        index: tool_call_idx,
                                        id: Some(current_id),
                                        call_type: Some("function".to_string()),
                                        function: Some(FunctionCallDelta {
                                            name: Some(fc.name.clone()),
                                            arguments: Some(serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string())),
                                        }),
                                    });
                                }
                            }
                            let tool_calls = if deltas.is_empty() { None } else { Some(deltas) };

                            // Determine finish_reason
                            // STRATEGY: If we have tool calls in this chunk, OR if we have seen them
                            // previously in the stream, the finish_reason MUST be "tool_calls"
                            // if the provider signals a stop. This ensures the client executes tools.
                            let mut finish_reason = candidate.finish_reason.as_ref().map(|fr| {
                                match fr.as_str() {
                                    "STOP" => "stop".to_string(),
                                    _ => fr.to_lowercase(),
                                }
                            });

                            if seen_tool_calls && finish_reason.as_deref() == Some("stop") {
                                finish_reason = Some("tool_calls".to_string());
                            } else if tool_calls.is_some() && finish_reason.is_none() {
                                // Optional: Could signal tool_calls here too, but OpenAI often waits until EOF.
                                // For now we only override it at the actual stop signal.
                            }

                            // Avoid emitting completely empty chunks unless they carry usage.
                            if !content.is_empty() || reasoning_content.is_some() || tool_calls.is_some() || stream_usage.is_some() {
                                yield ProviderStreamChunk {
                                    content,
                                    reasoning_content,
                                    finish_reason,
                                    tool_calls,
                                    model: model.clone(),
                                    usage: stream_usage,
                                };
                            }
                        } else if stream_usage.is_some() {
                            // Usage-only update
                            yield ProviderStreamChunk {
                                content: String::new(),
                                reasoning_content: None,
                                finish_reason: None,
                                tool_calls: None,
                                model: model.clone(),
                                usage: stream_usage,
                            };
                        }
                    } else if stream_usage.is_some() {
                        // No candidates but usage present
                        yield ProviderStreamChunk {
                            content: String::new(),
                            reasoning_content: None,
                            finish_reason: None,
                            tool_calls: None,
                            model: model.clone(),
                            usage: stream_usage,
                        };
                    }
                }
                Ok(_) => continue,
                Err(e) => {
                    // "Stream ended" is usually a normal EOF signal in reqwest-eventsource.
                    // We check the string representation to avoid returning it as an error.
                    if e.to_string().contains("Stream ended") {
                        break;
                    }

                    // On stream error, attempt to probe for the actual error body from the provider
                    let probe_resp = probe_client
                        .post(&probe_url)
                        .header("x-goog-api-key", &probe_api_key)
                        .json(&probe_request)
                        .send()
                        .await;

                    match probe_resp {
                        Ok(r) if !r.status().is_success() => {
                            let status = r.status();
                            let body = r.text().await.unwrap_or_default();
                            tracing::error!("Gemini Stream Error Probe ({}): {}", status, body);
                            Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, body)))?;
                        }
                        _ => {
                            Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
                        }
                    }
                }
            }
        }
    };

    Ok(Box::pin(stream))
}
|
|
}
|