// Changelog:
// - Exclude 'HARM_CATEGORY_CIVIC_INTEGRITY' when using v1 endpoint (v1beta only).
// - Filter out empty strings from 'stop_sequences' which are rejected by Gemini.
// - Update error probe to use non-streaming endpoint for better JSON error diagnostics.
use anyhow::Result;
|
|
use async_trait::async_trait;
|
|
use futures::stream::{BoxStream, StreamExt};
|
|
use reqwest_eventsource::Event;
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::Value;
|
|
use uuid::Uuid;
|
|
|
|
use super::{ProviderResponse, ProviderStreamChunk};
|
|
use crate::{
|
|
config::AppConfig,
|
|
errors::AppError,
|
|
models::{ContentPart, FunctionCall, FunctionCallDelta, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest},
|
|
};
|
|
|
|
// ========== Gemini Request Structs ==========
|
|
|
|
/// Top-level request body for Gemini's `generateContent` /
/// `streamGenerateContent` REST endpoints.
///
/// Serialized with camelCase keys to match the API; every optional section is
/// omitted from the JSON entirely when `None`.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiRequest {
    // Conversation turns, alternating "user"/"model" roles.
    contents: Vec<GeminiContent>,
    // System prompt, sent separately from the conversation (role is None).
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GeminiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GeminiGenerationConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GeminiTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_config: Option<GeminiToolConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    safety_settings: Option<Vec<GeminiSafetySetting>>,
}
|
|
|
|
/// One safety-filter override: a harm category paired with a block threshold.
/// This provider always sends "BLOCK_NONE" (see `get_safety_settings`).
#[derive(Debug, Clone, Serialize)]
struct GeminiSafetySetting {
    category: String,
    threshold: String,
}
|
|
|
|
/// A single content block: an ordered list of parts plus an optional role.
/// Role is "user" or "model" for conversation turns, and `None` when the
/// block is used as a systemInstruction.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiContent {
    parts: Vec<GeminiPart>,
    #[serde(skip_serializing_if = "Option::is_none")]
    role: Option<String>,
}
|
|
|
|
/// One part of a content block. The API treats this as a oneof: the code in
/// this file always sets exactly one of the four fields to `Some` per part
/// (text, inline image data, a model-issued function call, or a function
/// response sent back by us).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiPart {
    #[serde(skip_serializing_if = "Option::is_none")]
    text: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    inline_data: Option<GeminiInlineData>,
    #[serde(skip_serializing_if = "Option::is_none")]
    function_call: Option<GeminiFunctionCall>,
    #[serde(skip_serializing_if = "Option::is_none")]
    function_response: Option<GeminiFunctionResponse>,
}
|
|
|
|
/// Base64-encoded inline media (used here for images).
/// NOTE(review): this struct has no camelCase rename, so it serializes as
/// `mime_type`/`data`; Gemini's protobuf-JSON transcoding is presumed to
/// accept snake_case field names — confirm against the API.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiInlineData {
    mime_type: String,
    data: String,
}
|
|
|
|
/// A function call emitted by the model: function name plus structured
/// JSON arguments (already parsed, not a string).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiFunctionCall {
    name: String,
    args: Value,
}
|
|
|
|
/// A tool result sent back to the model, matched to the original call by
/// function name. `response` is an arbitrary JSON value.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiFunctionResponse {
    name: String,
    response: Value,
}
|
|
|
|
|
|
/// Sampling / output controls, mirroring Gemini's `generationConfig`.
/// All fields are optional; omitted fields fall back to API defaults.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u32>,
    // Must not contain empty strings — the API rejects them (callers filter).
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    candidate_count: Option<u32>,
}
|
|
|
|
// ========== Gemini Tool Structs ==========
|
|
|
|
/// Tool container: Gemini groups all function declarations under a single
/// tool entry (this file always sends a one-element `tools` array).
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiTool {
    function_declarations: Vec<GeminiFunctionDeclaration>,
}
|
|
|
|
/// One callable function exposed to the model. `parameters` is a JSON Schema
/// object that has been run through `sanitize_schema` to strip fields the
/// Gemini API rejects.
#[derive(Debug, Clone, Serialize)]
struct GeminiFunctionDeclaration {
    name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    parameters: Option<Value>,
}
|
|
|
|
/// Wrapper for the function-calling mode configuration (`toolConfig`).
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiToolConfig {
    function_calling_config: GeminiFunctionCallingConfig,
}
|
|
|
|
/// Function-calling mode: "AUTO", "NONE", or "ANY"; when mode is "ANY" the
/// model may be restricted to a whitelist of function names.
#[derive(Debug, Clone, Serialize)]
struct GeminiFunctionCallingConfig {
    mode: String,
    #[serde(skip_serializing_if = "Option::is_none", rename = "allowedFunctionNames")]
    allowed_function_names: Option<Vec<String>>,
}
|
|
|
|
// ========== Gemini Response Structs ==========
|
|
|
|
/// One candidate in a non-streaming response. `content` is required here
/// (unlike the streaming variant). `finish_reason` is deserialized but
/// currently unused by the non-streaming path, hence the dead_code allow.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiCandidate {
    content: GeminiContent,
    #[serde(default)]
    #[allow(dead_code)]
    finish_reason: Option<String>,
}
|
|
|
|
/// Token accounting from the response's `usageMetadata`. Every field defaults
/// to 0 when absent so partial metadata never fails deserialization.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiUsageMetadata {
    #[serde(default)]
    prompt_token_count: u32,
    #[serde(default)]
    candidates_token_count: u32,
    #[serde(default)]
    total_token_count: u32,
    // Tokens served from context caching (mapped to cache_read_tokens).
    #[serde(default)]
    cached_content_token_count: u32,
}
|
|
|
|
/// Full non-streaming `generateContent` response body (only the fields this
/// provider consumes).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiResponse {
    candidates: Vec<GeminiCandidate>,
    usage_metadata: Option<GeminiUsageMetadata>,
}
|
|
|
|
// Streaming responses from Gemini may include messages without `candidates` (e.g. promptFeedback).
|
|
// Use a more permissive struct for streaming to avoid aborting the SSE prematurely.
|
|
/// Permissive streaming response: `candidates` defaults to empty because some
/// SSE events (e.g. promptFeedback) carry no candidates at all, and a strict
/// struct would abort the stream on them.
#[derive(Debug, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
struct GeminiStreamResponse {
    #[serde(default)]
    candidates: Vec<GeminiStreamCandidate>,
    #[serde(default)]
    usage_metadata: Option<GeminiUsageMetadata>,
}
|
|
|
|
/// Streaming candidate: unlike `GeminiCandidate`, `content` is optional since
/// a chunk may carry only a finish reason or usage information.
#[derive(Debug, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
struct GeminiStreamCandidate {
    #[serde(default)]
    content: Option<GeminiContent>,
    #[serde(default)]
    finish_reason: Option<String>,
}
|
|
|
|
// ========== Provider Implementation ==========
|
|
|
|
/// Provider adapter translating the app's unified chat interface to the
/// Google Gemini REST API.
pub struct GeminiProvider {
    // Shared HTTP client (connection pooling + timeouts configured in new_with_key).
    client: reqwest::Client,
    config: crate::config::GeminiConfig,
    // Sent via the "x-goog-api-key" header on every request.
    api_key: String,
    // Per-model pricing table used by calculate_cost.
    pricing: Vec<crate::config::ModelPricing>,
}
|
|
|
|
impl GeminiProvider {
|
|
pub fn new(config: &crate::config::GeminiConfig, app_config: &AppConfig) -> Result<Self> {
|
|
let api_key = app_config.get_api_key("gemini")?;
|
|
Self::new_with_key(config, app_config, api_key)
|
|
}
|
|
|
|
pub fn new_with_key(config: &crate::config::GeminiConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
|
let client = reqwest::Client::builder()
|
|
.connect_timeout(std::time::Duration::from_secs(5))
|
|
.timeout(std::time::Duration::from_secs(300))
|
|
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
|
.pool_max_idle_per_host(4)
|
|
.tcp_keepalive(std::time::Duration::from_secs(30))
|
|
.build()?;
|
|
|
|
Ok(Self {
|
|
client,
|
|
config: config.clone(),
|
|
api_key,
|
|
pricing: app_config.pricing.gemini.clone(),
|
|
})
|
|
}
|
|
|
|
/// Convert unified messages to Gemini content format.
/// Handles text, images, tool calls (assistant), and tool results.
/// Returns (contents, system_instruction)
async fn convert_messages(
    messages: Vec<UnifiedMessage>,
) -> Result<(Vec<GeminiContent>, Option<GeminiContent>), AppError> {
    let mut contents: Vec<GeminiContent> = Vec::new();
    // Text pulled out of "system" messages; emitted separately as Gemini's
    // systemInstruction rather than as conversation turns.
    let mut system_parts = Vec::new();

    for msg in messages {
        // System messages: keep only non-empty text parts (any image parts in
        // a system message are silently dropped).
        if msg.role == "system" {
            for part in msg.content {
                if let ContentPart::Text { text } = part {
                    if !text.trim().is_empty() {
                        system_parts.push(GeminiPart {
                            text: Some(text),
                            inline_data: None,
                            function_call: None,
                            function_response: None,
                        });
                    }
                }
            }
            continue;
        }

        // Gemini only understands "user" and "model" roles.
        let role = match msg.role.as_str() {
            "assistant" => "model".to_string(),
            "tool" => "user".to_string(), // Tool results are user-side in Gemini
            _ => "user".to_string(),
        };

        let mut parts = Vec::new();

        // Handle tool results (role "tool")
        if msg.role == "tool" {
            // NOTE(review): only the FIRST content part of a tool result is
            // used; any additional parts are ignored — confirm intended.
            let text_content = msg
                .content
                .first()
                .map(|p| match p {
                    ContentPart::Text { text } => text.clone(),
                    ContentPart::Image(_) => "[Image]".to_string(),
                })
                .unwrap_or_default();

            // Gemini matches a functionResponse to its call by NAME, so fall
            // back from the explicit name to tool_call_id, then a placeholder.
            let name = msg.name.clone().or_else(|| msg.tool_call_id.clone()).unwrap_or_else(|| "unknown_function".to_string());
            // Prefer structured JSON; wrap non-JSON text as {"result": ...}.
            let response_value = serde_json::from_str::<Value>(&text_content)
                .unwrap_or_else(|_| serde_json::json!({ "result": text_content }));

            parts.push(GeminiPart {
                text: None,
                inline_data: None,
                function_call: None,
                function_response: Some(GeminiFunctionResponse {
                    name,
                    response: response_value,
                }),
            });
        } else if msg.role == "assistant" && msg.tool_calls.is_some() {
            // Assistant messages with tool_calls
            if let Some(tool_calls) = &msg.tool_calls {
                // Any accompanying assistant text becomes plain text parts
                // preceding the function-call parts.
                for p in &msg.content {
                    if let ContentPart::Text { text } = p {
                        if !text.trim().is_empty() {
                            parts.push(GeminiPart {
                                text: Some(text.clone()),
                                inline_data: None,
                                function_call: None,
                                function_response: None,
                            });
                        }
                    }
                }

                // Each tool call becomes a functionCall part; unparsable
                // argument strings degrade to an empty JSON object.
                for tc in tool_calls {
                    let args = serde_json::from_str::<Value>(&tc.function.arguments)
                        .unwrap_or_else(|_| serde_json::json!({}));
                    parts.push(GeminiPart {
                        text: None,
                        inline_data: None,
                        function_call: Some(GeminiFunctionCall {
                            name: tc.function.name.clone(),
                            args,
                        }),
                        function_response: None,
                    });
                }
            }
        } else {
            // Regular text/image messages
            for part in msg.content {
                match part {
                    ContentPart::Text { text } => {
                        // Whitespace-only text parts are dropped.
                        if !text.trim().is_empty() {
                            parts.push(GeminiPart {
                                text: Some(text),
                                inline_data: None,
                                function_call: None,
                                function_response: None,
                            });
                        }
                    }
                    ContentPart::Image(image_input) => {
                        // Images go inline as base64; to_base64 may fetch or
                        // re-encode, hence the await and provider-error mapping.
                        let (base64_data, mime_type) = image_input
                            .to_base64()
                            .await
                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;

                        parts.push(GeminiPart {
                            text: None,
                            inline_data: Some(GeminiInlineData {
                                mime_type,
                                data: base64_data,
                            }),
                            function_call: None,
                            function_response: None,
                        });
                    }
                }
            }
        }

        // Messages that produced no usable parts are dropped entirely.
        if parts.is_empty() {
            continue;
        }

        // STRATEGY: Strictly enforce alternating roles.
        // If current message has the same role as the last one, merge their parts.
        if let Some(last_content) = contents.last_mut() {
            if last_content.role.as_ref() == Some(&role) {
                last_content.parts.extend(parts);
                continue;
            }
        }

        contents.push(GeminiContent {
            parts,
            role: Some(role),
        });
    }

    // Gemini requires the first message to be from "user".
    // If the history starts with a model turn, inject a synthetic user turn.
    if let Some(first) = contents.first() {
        if first.role.as_deref() == Some("model") {
            contents.insert(0, GeminiContent {
                role: Some("user".to_string()),
                parts: vec![GeminiPart {
                    text: Some("Continue conversation.".to_string()),
                    inline_data: None,
                    function_call: None,
                    function_response: None,
                }],
            });
        }
    }

    // Final check: ensure we don't have empty contents after filtering.
    // If the last message was merged or filtered, we might have an empty array.
    if contents.is_empty() && system_parts.is_empty() {
        return Err(AppError::ProviderError("No valid content parts after filtering".to_string()));
    }

    let system_instruction = if !system_parts.is_empty() {
        Some(GeminiContent {
            parts: system_parts,
            role: None,
        })
    } else {
        None
    };

    Ok((contents, system_instruction))
}
|
|
|
|
/// Convert OpenAI tools to Gemini function declarations.
|
|
fn convert_tools(request: &UnifiedRequest) -> Option<Vec<GeminiTool>> {
|
|
request.tools.as_ref().map(|tools| {
|
|
let declarations: Vec<GeminiFunctionDeclaration> = tools
|
|
.iter()
|
|
.map(|t| {
|
|
let mut parameters = t.function.parameters.clone().unwrap_or(serde_json::json!({
|
|
"type": "object",
|
|
"properties": {}
|
|
}));
|
|
Self::sanitize_schema(&mut parameters);
|
|
|
|
GeminiFunctionDeclaration {
|
|
name: t.function.name.clone(),
|
|
description: t.function.description.clone(),
|
|
parameters: Some(parameters),
|
|
}
|
|
})
|
|
.collect();
|
|
vec![GeminiTool {
|
|
function_declarations: declarations,
|
|
}]
|
|
})
|
|
}
|
|
|
|
/// Recursively remove unsupported JSON Schema fields that Gemini's API rejects.
|
|
fn sanitize_schema(value: &mut Value) {
|
|
if let Value::Object(map) = value {
|
|
// Remove unsupported fields at this level
|
|
map.remove("$schema");
|
|
map.remove("additionalProperties");
|
|
map.remove("exclusiveMaximum");
|
|
map.remove("exclusiveMinimum");
|
|
|
|
// Recursively sanitize all object properties
|
|
if let Some(properties) = map.get_mut("properties") {
|
|
if let Value::Object(props_map) = properties {
|
|
for prop_value in props_map.values_mut() {
|
|
Self::sanitize_schema(prop_value);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Recursively sanitize array items
|
|
if let Some(items) = map.get_mut("items") {
|
|
Self::sanitize_schema(items);
|
|
}
|
|
|
|
// Gemini 1.5/2.0+ supports anyOf in some contexts, but it's often
|
|
// the source of additionalProperties errors when nested.
|
|
if let Some(any_of) = map.get_mut("anyOf") {
|
|
if let Value::Array(arr) = any_of {
|
|
for item in arr {
|
|
Self::sanitize_schema(item);
|
|
}
|
|
}
|
|
}
|
|
if let Some(one_of) = map.get_mut("oneOf") {
|
|
if let Value::Array(arr) = one_of {
|
|
for item in arr {
|
|
Self::sanitize_schema(item);
|
|
}
|
|
}
|
|
}
|
|
if let Some(all_of) = map.get_mut("allOf") {
|
|
if let Value::Array(arr) = all_of {
|
|
for item in arr {
|
|
Self::sanitize_schema(item);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Convert OpenAI tool_choice to Gemini tool_config.
|
|
fn convert_tool_config(request: &UnifiedRequest) -> Option<GeminiToolConfig> {
|
|
request.tool_choice.as_ref().map(|tc| {
|
|
let (mode, allowed_names) = match tc {
|
|
crate::models::ToolChoice::Mode(mode) => {
|
|
let gemini_mode = match mode.as_str() {
|
|
"auto" => "AUTO",
|
|
"none" => "NONE",
|
|
"required" => "ANY",
|
|
_ => "AUTO",
|
|
};
|
|
(gemini_mode.to_string(), None)
|
|
}
|
|
crate::models::ToolChoice::Specific(specific) => {
|
|
("ANY".to_string(), Some(vec![specific.function.name.clone()]))
|
|
}
|
|
};
|
|
GeminiToolConfig {
|
|
function_calling_config: GeminiFunctionCallingConfig {
|
|
mode,
|
|
allowed_function_names: allowed_names,
|
|
},
|
|
}
|
|
})
|
|
}
|
|
|
|
/// Extract tool calls from Gemini response parts into OpenAI-format ToolCalls.
|
|
fn extract_tool_calls(parts: &[GeminiPart]) -> Option<Vec<ToolCall>> {
|
|
let calls: Vec<ToolCall> = parts
|
|
.iter()
|
|
.filter_map(|p| p.function_call.as_ref())
|
|
.map(|fc| ToolCall {
|
|
id: format!("call_{}", Uuid::new_v4().simple()),
|
|
call_type: "function".to_string(),
|
|
function: FunctionCall {
|
|
name: fc.name.clone(),
|
|
arguments: serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string()),
|
|
},
|
|
})
|
|
.collect();
|
|
|
|
if calls.is_empty() { None } else { Some(calls) }
|
|
}
|
|
|
|
/// Extract tool call deltas from Gemini response parts for streaming.
|
|
fn extract_tool_call_deltas(parts: &[GeminiPart]) -> Option<Vec<ToolCallDelta>> {
|
|
let deltas: Vec<ToolCallDelta> = parts
|
|
.iter()
|
|
.filter_map(|p| p.function_call.as_ref())
|
|
.enumerate()
|
|
.map(|(i, fc)| ToolCallDelta {
|
|
index: i as u32,
|
|
id: Some(format!("call_{}", Uuid::new_v4().simple())),
|
|
call_type: Some("function".to_string()),
|
|
function: Some(FunctionCallDelta {
|
|
name: Some(fc.name.clone()),
|
|
arguments: Some(serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string())),
|
|
}),
|
|
})
|
|
.collect();
|
|
|
|
if deltas.is_empty() { None } else { Some(deltas) }
|
|
}
|
|
|
|
/// Determine the appropriate base URL for the model.
|
|
/// "preview" models often require the v1beta endpoint, but newer promoted ones may be on v1.
|
|
fn get_base_url(&self, model: &str) -> String {
|
|
// Only use v1beta for older preview models or specific "thinking" experimental models.
|
|
// Gemini 3.0+ models are typically released on v1 even in preview.
|
|
if (model.contains("preview") && !model.contains("gemini-3")) || model.contains("thinking") {
|
|
self.config.base_url.replace("/v1", "/v1beta")
|
|
} else {
|
|
self.config.base_url.clone()
|
|
}
|
|
}
|
|
|
|
/// Default safety settings to avoid blocking responses.
|
|
fn get_safety_settings(&self, base_url: &str) -> Vec<GeminiSafetySetting> {
|
|
let mut categories = vec![
|
|
"HARM_CATEGORY_HARASSMENT",
|
|
"HARM_CATEGORY_HATE_SPEECH",
|
|
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
|
"HARM_CATEGORY_DANGEROUS_CONTENT",
|
|
];
|
|
|
|
// Civic integrity is only available in v1beta
|
|
if base_url.contains("v1beta") {
|
|
categories.push("HARM_CATEGORY_CIVIC_INTEGRITY");
|
|
}
|
|
|
|
categories
|
|
.into_iter()
|
|
.map(|c| GeminiSafetySetting {
|
|
category: c.to_string(),
|
|
threshold: "BLOCK_NONE".to_string(),
|
|
})
|
|
.collect()
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl super::Provider for GeminiProvider {
|
|
// Provider identifier used for routing and API-key lookup.
fn name(&self) -> &str {
    "gemini"
}
|
|
|
|
// A model is routed to this provider iff its id starts with "gemini-".
fn supports_model(&self, model: &str) -> bool {
    model.starts_with("gemini-")
}
|
|
|
|
// Image inputs are supported (sent inline as base64 parts).
fn supports_multimodal(&self) -> bool {
    true // Gemini supports vision
}
|
|
|
|
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
|
let mut model = request.model.clone();
|
|
|
|
// Normalize model name: If it's a known Gemini model version, use it;
|
|
// otherwise, if it starts with gemini- but is an unknown legacy version,
|
|
// fallback to the default model to avoid 400 errors.
|
|
// We now allow gemini-3+ as valid versions.
|
|
let is_known_version = model.starts_with("gemini-1.5") ||
|
|
model.starts_with("gemini-2.0") ||
|
|
model.starts_with("gemini-2.5") ||
|
|
model.starts_with("gemini-3");
|
|
|
|
if !is_known_version && model.starts_with("gemini-") {
|
|
tracing::info!("Mapping unknown Gemini model {} to default {}", model, self.config.default_model);
|
|
model = self.config.default_model.clone();
|
|
}
|
|
|
|
let tools = Self::convert_tools(&request);
|
|
let tool_config = Self::convert_tool_config(&request);
|
|
let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
|
|
|
|
if contents.is_empty() && system_instruction.is_none() {
|
|
return Err(AppError::ProviderError("No valid messages to send".to_string()));
|
|
}
|
|
|
|
let base_url = self.get_base_url(&model);
|
|
|
|
// Sanitize stop sequences: Gemini rejects empty strings
|
|
let stop_sequences = request.stop.map(|s| {
|
|
s.into_iter()
|
|
.filter(|seq| !seq.is_empty())
|
|
.collect::<Vec<_>>()
|
|
}).filter(|s| !s.is_empty());
|
|
|
|
let generation_config = Some(GeminiGenerationConfig {
|
|
temperature: request.temperature,
|
|
top_p: request.top_p,
|
|
top_k: request.top_k,
|
|
max_output_tokens: request.max_tokens.map(|t| t.min(8192)),
|
|
stop_sequences,
|
|
candidate_count: request.n,
|
|
});
|
|
|
|
let gemini_request = GeminiRequest {
|
|
contents,
|
|
system_instruction,
|
|
generation_config,
|
|
tools,
|
|
tool_config,
|
|
safety_settings: Some(self.get_safety_settings(&base_url)),
|
|
};
|
|
|
|
let url = format!("{}/models/{}:generateContent", base_url, model);
|
|
tracing::debug!("Calling Gemini API: {}", url);
|
|
|
|
let response = self
|
|
.client
|
|
.post(&url)
|
|
.header("x-goog-api-key", &self.api_key)
|
|
.json(&gemini_request)
|
|
.send()
|
|
.await
|
|
.map_err(|e| AppError::ProviderError(format!("HTTP request failed: {}", e)))?;
|
|
|
|
let status = response.status();
|
|
if !status.is_success() {
|
|
let error_text = response.text().await.unwrap_or_default();
|
|
return Err(AppError::ProviderError(format!(
|
|
"Gemini API error ({}): {}",
|
|
status, error_text
|
|
)));
|
|
}
|
|
|
|
let gemini_response: GeminiResponse = response
|
|
.json()
|
|
.await
|
|
.map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?;
|
|
|
|
let candidate = gemini_response.candidates.first();
|
|
|
|
// Extract text content (may be absent if only function calls)
|
|
let content = candidate
|
|
.and_then(|c| c.content.parts.iter().find_map(|p| p.text.clone()))
|
|
.unwrap_or_default();
|
|
|
|
// Extract function calls → OpenAI tool_calls
|
|
let tool_calls = candidate.and_then(|c| Self::extract_tool_calls(&c.content.parts));
|
|
|
|
let prompt_tokens = gemini_response
|
|
.usage_metadata
|
|
.as_ref()
|
|
.map(|u| u.prompt_token_count)
|
|
.unwrap_or(0);
|
|
let completion_tokens = gemini_response
|
|
.usage_metadata
|
|
.as_ref()
|
|
.map(|u| u.candidates_token_count)
|
|
.unwrap_or(0);
|
|
let total_tokens = gemini_response
|
|
.usage_metadata
|
|
.as_ref()
|
|
.map(|u| u.total_token_count)
|
|
.unwrap_or(0);
|
|
let cache_read_tokens = gemini_response
|
|
.usage_metadata
|
|
.as_ref()
|
|
.map(|u| u.cached_content_token_count)
|
|
.unwrap_or(0);
|
|
|
|
Ok(ProviderResponse {
|
|
content,
|
|
reasoning_content: None,
|
|
tool_calls,
|
|
prompt_tokens,
|
|
completion_tokens,
|
|
total_tokens,
|
|
cache_read_tokens,
|
|
cache_write_tokens: 0, // Gemini doesn't report cache writes separately
|
|
model,
|
|
})
|
|
}
|
|
|
|
// Heuristic token estimate (no API call); delegated to the shared estimator.
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
    Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
|
|
|
|
// Cost computation delegated to the shared registry-based helper, using this
// provider's pricing table. The trailing 0.075 / 0.30 literals are fallback
// input/output prices when a model has no registry or table entry —
// presumably USD per 1M tokens; confirm against the helper's contract.
fn calculate_cost(
    &self,
    model: &str,
    prompt_tokens: u32,
    completion_tokens: u32,
    cache_read_tokens: u32,
    cache_write_tokens: u32,
    registry: &crate::models::registry::ModelRegistry,
) -> f64 {
    super::helpers::calculate_cost_with_registry(
        model,
        prompt_tokens,
        completion_tokens,
        cache_read_tokens,
        cache_write_tokens,
        registry,
        &self.pricing,
        0.075,
        0.30,
    )
}
|
|
|
|
async fn chat_completion_stream(
|
|
&self,
|
|
request: UnifiedRequest,
|
|
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
|
let mut model = request.model.clone();
|
|
|
|
// Normalize model name: fallback to default if unknown Gemini model is requested
|
|
let is_known_version = model.starts_with("gemini-1.5") ||
|
|
model.starts_with("gemini-2.0") ||
|
|
model.starts_with("gemini-2.5") ||
|
|
model.starts_with("gemini-3");
|
|
|
|
if !is_known_version && model.starts_with("gemini-") {
|
|
tracing::info!("Mapping unknown Gemini model {} to default {}", model, self.config.default_model);
|
|
model = self.config.default_model.clone();
|
|
}
|
|
|
|
let tools = Self::convert_tools(&request);
|
|
let tool_config = Self::convert_tool_config(&request);
|
|
let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
|
|
|
|
if contents.is_empty() && system_instruction.is_none() {
|
|
return Err(AppError::ProviderError("No valid messages to send".to_string()));
|
|
}
|
|
|
|
let base_url = self.get_base_url(&model);
|
|
|
|
// Sanitize stop sequences: Gemini rejects empty strings
|
|
let stop_sequences = request.stop.map(|s| {
|
|
s.into_iter()
|
|
.filter(|seq| !seq.is_empty())
|
|
.collect::<Vec<_>>()
|
|
}).filter(|s| !s.is_empty());
|
|
|
|
let generation_config = Some(GeminiGenerationConfig {
|
|
temperature: request.temperature,
|
|
top_p: request.top_p,
|
|
top_k: request.top_k,
|
|
max_output_tokens: request.max_tokens.map(|t| t.min(8192)),
|
|
stop_sequences,
|
|
candidate_count: request.n,
|
|
});
|
|
|
|
let gemini_request = GeminiRequest {
|
|
contents,
|
|
system_instruction,
|
|
generation_config,
|
|
tools,
|
|
tool_config,
|
|
safety_settings: Some(self.get_safety_settings(&base_url)),
|
|
};
|
|
|
|
let url = format!(
|
|
"{}/models/{}:streamGenerateContent?alt=sse",
|
|
base_url, model,
|
|
);
|
|
tracing::debug!("Calling Gemini Stream API: {}", url);
|
|
|
|
// Capture a clone of the request to probe for errors (Gemini 400s are common)
|
|
let probe_request = gemini_request.clone();
|
|
let probe_client = self.client.clone();
|
|
// Use non-streaming URL for probing to get a valid JSON error body
|
|
let probe_url = format!("{}/models/{}:generateContent", base_url, model);
|
|
let probe_api_key = self.api_key.clone();
|
|
|
|
// Create the EventSource first (it doesn't send until polled)
|
|
let es = reqwest_eventsource::EventSource::new(
|
|
self.client
|
|
.post(&url)
|
|
.header("x-goog-api-key", &self.api_key)
|
|
.header("Accept", "text/event-stream")
|
|
.json(&gemini_request),
|
|
).map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
|
|
|
let stream = async_stream::try_stream! {
|
|
let mut es = es;
|
|
while let Some(event) = es.next().await {
|
|
match event {
|
|
Ok(Event::Message(msg)) => {
|
|
let gemini_response: GeminiStreamResponse = serde_json::from_str(&msg.data)
|
|
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
|
|
|
// (rest of processing remains identical)
|
|
|
|
// Extract usage from usageMetadata if present (reported on every/last chunk)
|
|
let stream_usage = gemini_response.usage_metadata.as_ref().map(|u| {
|
|
super::StreamUsage {
|
|
prompt_tokens: u.prompt_token_count,
|
|
completion_tokens: u.candidates_token_count,
|
|
total_tokens: u.total_token_count,
|
|
cache_read_tokens: u.cached_content_token_count,
|
|
cache_write_tokens: 0,
|
|
}
|
|
});
|
|
|
|
// Some streaming events may not contain candidates (e.g. promptFeedback).
|
|
// Only emit chunks when we have candidate content or tool calls.
|
|
if let Some(candidate) = gemini_response.candidates.first() {
|
|
if let Some(content_obj) = &candidate.content {
|
|
let content = content_obj
|
|
.parts
|
|
.iter()
|
|
.find_map(|p| p.text.clone())
|
|
.unwrap_or_default();
|
|
|
|
let tool_calls = Self::extract_tool_call_deltas(&content_obj.parts);
|
|
|
|
// Determine finish_reason
|
|
let finish_reason = candidate.finish_reason.as_ref().map(|fr| {
|
|
match fr.as_str() {
|
|
"STOP" => "stop".to_string(),
|
|
_ => fr.to_lowercase(),
|
|
}
|
|
});
|
|
|
|
// Avoid emitting completely empty chunks unless they carry usage.
|
|
if !content.is_empty() || tool_calls.is_some() || stream_usage.is_some() {
|
|
yield ProviderStreamChunk {
|
|
content,
|
|
reasoning_content: None,
|
|
finish_reason,
|
|
tool_calls,
|
|
model: model.clone(),
|
|
usage: stream_usage,
|
|
};
|
|
}
|
|
} else if stream_usage.is_some() {
|
|
// Usage-only update
|
|
yield ProviderStreamChunk {
|
|
content: String::new(),
|
|
reasoning_content: None,
|
|
finish_reason: None,
|
|
tool_calls: None,
|
|
model: model.clone(),
|
|
usage: stream_usage,
|
|
};
|
|
}
|
|
} else if stream_usage.is_some() {
|
|
// No candidates but usage present
|
|
yield ProviderStreamChunk {
|
|
content: String::new(),
|
|
reasoning_content: None,
|
|
finish_reason: None,
|
|
tool_calls: None,
|
|
model: model.clone(),
|
|
usage: stream_usage,
|
|
};
|
|
}
|
|
}
|
|
Ok(_) => continue,
|
|
Err(e) => {
|
|
// On stream error, attempt to probe for the actual error body from the provider
|
|
let probe_resp = probe_client
|
|
.post(&probe_url)
|
|
.header("x-goog-api-key", &probe_api_key)
|
|
.json(&probe_request)
|
|
.send()
|
|
.await;
|
|
|
|
match probe_resp {
|
|
Ok(r) if !r.status().is_success() => {
|
|
let status = r.status();
|
|
let body = r.text().await.unwrap_or_default();
|
|
tracing::error!("Gemini Stream Error Probe ({}): {}", status, body);
|
|
Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, body)))?;
|
|
}
|
|
_ => {
|
|
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
Ok(Box::pin(stream))
|
|
}
|
|
}
|