use anyhow::Result; use async_trait::async_trait; use futures::stream::BoxStream; use serde::{Deserialize, Serialize}; use serde_json::Value; use uuid::Uuid; use super::{ProviderResponse, ProviderStreamChunk}; use crate::{ config::AppConfig, errors::AppError, models::{ContentPart, FunctionCall, FunctionCallDelta, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest}, }; // ========== Gemini Request Structs ========== #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] struct GeminiRequest { contents: Vec, #[serde(skip_serializing_if = "Option::is_none")] generation_config: Option, #[serde(skip_serializing_if = "Option::is_none")] tools: Option>, #[serde(skip_serializing_if = "Option::is_none")] tool_config: Option, } #[derive(Debug, Serialize, Deserialize)] struct GeminiContent { parts: Vec, role: String, } #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] struct GeminiPart { #[serde(skip_serializing_if = "Option::is_none")] text: Option, #[serde(skip_serializing_if = "Option::is_none")] inline_data: Option, #[serde(skip_serializing_if = "Option::is_none")] function_call: Option, #[serde(skip_serializing_if = "Option::is_none")] function_response: Option, } #[derive(Debug, Serialize, Deserialize)] struct GeminiInlineData { mime_type: String, data: String, } #[derive(Debug, Serialize, Deserialize)] struct GeminiFunctionCall { name: String, args: Value, } #[derive(Debug, Serialize, Deserialize)] struct GeminiFunctionResponse { name: String, response: Value, } #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] struct GeminiGenerationConfig { temperature: Option, max_output_tokens: Option, } // ========== Gemini Tool Structs ========== #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] struct GeminiTool { function_declarations: Vec, } #[derive(Debug, Serialize)] struct GeminiFunctionDeclaration { name: String, #[serde(skip_serializing_if = "Option::is_none")] description: Option, #[serde(skip_serializing_if = "Option::is_none")] parameters: Option, } #[derive(Debug, Serialize)] #[serde(rename_all = "camelCase")] struct GeminiToolConfig { function_calling_config: GeminiFunctionCallingConfig, } #[derive(Debug, Serialize)] struct GeminiFunctionCallingConfig { mode: String, #[serde(skip_serializing_if = "Option::is_none", rename = "allowedFunctionNames")] allowed_function_names: Option>, } // ========== Gemini Response Structs ========== #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] struct GeminiCandidate { content: GeminiContent, #[serde(default)] finish_reason: Option, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] struct GeminiUsageMetadata { #[serde(default)] prompt_token_count: u32, #[serde(default)] candidates_token_count: u32, #[serde(default)] total_token_count: u32, } #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] struct GeminiResponse { candidates: Vec, usage_metadata: Option, } // ========== Provider Implementation ========== pub struct GeminiProvider { client: reqwest::Client, config: crate::config::GeminiConfig, api_key: String, pricing: Vec, } impl GeminiProvider { pub fn new(config: &crate::config::GeminiConfig, app_config: &AppConfig) -> Result { let api_key = app_config.get_api_key("gemini")?; Self::new_with_key(config, app_config, api_key) } pub fn new_with_key(config: &crate::config::GeminiConfig, app_config: &AppConfig, api_key: String) -> Result { let client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(30)) .build()?; Ok(Self { client, config: config.clone(), api_key, pricing: app_config.pricing.gemini.clone(), }) } /// Convert unified messages to Gemini content format. /// Handles text, images, tool calls (assistant), and tool results. async fn convert_messages(messages: Vec) -> Result, AppError> { let mut contents = Vec::with_capacity(messages.len()); for msg in messages { // Tool-result messages → functionResponse parts under role "user" if msg.role == "tool" { let text_content = msg .content .first() .map(|p| match p { ContentPart::Text { text } => text.clone(), ContentPart::Image(_) => "[Image]".to_string(), }) .unwrap_or_default(); let name = msg.name.unwrap_or_default(); // Parse the content as JSON if possible, otherwise wrap as string let response_value = serde_json::from_str::(&text_content) .unwrap_or_else(|_| serde_json::json!({ "result": text_content })); contents.push(GeminiContent { parts: vec![GeminiPart { text: None, inline_data: None, function_call: None, function_response: Some(GeminiFunctionResponse { name, response: response_value, }), }], role: "user".to_string(), }); continue; } // Assistant messages with tool_calls → functionCall parts if msg.role == "assistant" { if let Some(tool_calls) = &msg.tool_calls { let mut parts = Vec::new(); // Include text content if present for p in &msg.content { if let ContentPart::Text { text } = p { if !text.is_empty() { parts.push(GeminiPart { text: Some(text.clone()), inline_data: None, function_call: None, function_response: None, }); } } } // Convert each tool call to a functionCall part for tc in tool_calls { let args = serde_json::from_str::(&tc.function.arguments) .unwrap_or_else(|_| serde_json::json!({})); parts.push(GeminiPart { text: None, inline_data: None, function_call: Some(GeminiFunctionCall { name: tc.function.name.clone(), args, }), function_response: None, }); } contents.push(GeminiContent { parts, role: "model".to_string(), }); continue; } } // Regular text/image messages let mut parts = Vec::with_capacity(msg.content.len()); for part in msg.content { match part { ContentPart::Text { text } => { parts.push(GeminiPart { text: Some(text), inline_data: None, function_call: None, function_response: None, }); } ContentPart::Image(image_input) => { let (base64_data, mime_type) = image_input .to_base64() .await .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?; parts.push(GeminiPart { text: None, inline_data: Some(GeminiInlineData { mime_type, data: base64_data, }), function_call: None, function_response: None, }); } } } let role = match msg.role.as_str() { "assistant" => "model".to_string(), _ => "user".to_string(), }; contents.push(GeminiContent { parts, role }); } Ok(contents) } /// Convert OpenAI tools to Gemini function declarations. fn convert_tools(request: &UnifiedRequest) -> Option> { request.tools.as_ref().map(|tools| { let declarations: Vec = tools .iter() .map(|t| GeminiFunctionDeclaration { name: t.function.name.clone(), description: t.function.description.clone(), parameters: t.function.parameters.clone(), }) .collect(); vec![GeminiTool { function_declarations: declarations, }] }) } /// Convert OpenAI tool_choice to Gemini tool_config. fn convert_tool_config(request: &UnifiedRequest) -> Option { request.tool_choice.as_ref().map(|tc| { let (mode, allowed_names) = match tc { crate::models::ToolChoice::Mode(mode) => { let gemini_mode = match mode.as_str() { "auto" => "AUTO", "none" => "NONE", "required" => "ANY", _ => "AUTO", }; (gemini_mode.to_string(), None) } crate::models::ToolChoice::Specific(specific) => { ("ANY".to_string(), Some(vec![specific.function.name.clone()])) } }; GeminiToolConfig { function_calling_config: GeminiFunctionCallingConfig { mode, allowed_function_names: allowed_names, }, } }) } /// Extract tool calls from Gemini response parts into OpenAI-format ToolCalls. fn extract_tool_calls(parts: &[GeminiPart]) -> Option> { let calls: Vec = parts .iter() .filter_map(|p| p.function_call.as_ref()) .map(|fc| ToolCall { id: format!("call_{}", Uuid::new_v4().simple()), call_type: "function".to_string(), function: FunctionCall { name: fc.name.clone(), arguments: serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string()), }, }) .collect(); if calls.is_empty() { None } else { Some(calls) } } /// Extract tool call deltas from Gemini response parts for streaming. fn extract_tool_call_deltas(parts: &[GeminiPart]) -> Option> { let deltas: Vec = parts .iter() .filter_map(|p| p.function_call.as_ref()) .enumerate() .map(|(i, fc)| ToolCallDelta { index: i as u32, id: Some(format!("call_{}", Uuid::new_v4().simple())), call_type: Some("function".to_string()), function: Some(FunctionCallDelta { name: Some(fc.name.clone()), arguments: Some(serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string())), }), }) .collect(); if deltas.is_empty() { None } else { Some(deltas) } } } #[async_trait] impl super::Provider for GeminiProvider { fn name(&self) -> &str { "gemini" } fn supports_model(&self, model: &str) -> bool { model.starts_with("gemini-") } fn supports_multimodal(&self) -> bool { true // Gemini supports vision } async fn chat_completion(&self, request: UnifiedRequest) -> Result { let model = request.model.clone(); let tools = Self::convert_tools(&request); let tool_config = Self::convert_tool_config(&request); let contents = Self::convert_messages(request.messages).await?; if contents.is_empty() { return Err(AppError::ProviderError("No valid messages to send".to_string())); } let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() { Some(GeminiGenerationConfig { temperature: request.temperature, max_output_tokens: request.max_tokens, }) } else { None }; let gemini_request = GeminiRequest { contents, generation_config, tools, tool_config, }; let url = format!("{}/models/{}:generateContent", self.config.base_url, model); let response = self .client .post(&url) .header("x-goog-api-key", &self.api_key) .json(&gemini_request) .send() .await .map_err(|e| AppError::ProviderError(format!("HTTP request failed: {}", e)))?; let status = response.status(); if !status.is_success() { let error_text = response.text().await.unwrap_or_default(); return Err(AppError::ProviderError(format!( "Gemini API error ({}): {}", status, error_text ))); } let gemini_response: GeminiResponse = response .json() .await .map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?; let candidate = gemini_response.candidates.first(); // Extract text content (may be absent if only function calls) let content = candidate .and_then(|c| c.content.parts.iter().find_map(|p| p.text.clone())) .unwrap_or_default(); // Extract function calls → OpenAI tool_calls let tool_calls = candidate.and_then(|c| Self::extract_tool_calls(&c.content.parts)); let prompt_tokens = gemini_response .usage_metadata .as_ref() .map(|u| u.prompt_token_count) .unwrap_or(0); let completion_tokens = gemini_response .usage_metadata .as_ref() .map(|u| u.candidates_token_count) .unwrap_or(0); let total_tokens = gemini_response .usage_metadata .as_ref() .map(|u| u.total_token_count) .unwrap_or(0); Ok(ProviderResponse { content, reasoning_content: None, tool_calls, prompt_tokens, completion_tokens, total_tokens, model, }) } fn estimate_tokens(&self, request: &UnifiedRequest) -> Result { Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request)) } fn calculate_cost( &self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry, ) -> f64 { super::helpers::calculate_cost_with_registry( model, prompt_tokens, completion_tokens, registry, &self.pricing, 0.075, 0.30, ) } async fn chat_completion_stream( &self, request: UnifiedRequest, ) -> Result>, AppError> { let model = request.model.clone(); let tools = Self::convert_tools(&request); let tool_config = Self::convert_tool_config(&request); let contents = Self::convert_messages(request.messages).await?; let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() { Some(GeminiGenerationConfig { temperature: request.temperature, max_output_tokens: request.max_tokens, }) } else { None }; let gemini_request = GeminiRequest { contents, generation_config, tools, tool_config, }; let url = format!( "{}/models/{}:streamGenerateContent?alt=sse", self.config.base_url, model, ); use futures::StreamExt; use reqwest_eventsource::{Event, EventSource}; let es = EventSource::new( self.client .post(&url) .header("x-goog-api-key", &self.api_key) .json(&gemini_request), ) .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; let stream = async_stream::try_stream! { let mut es = es; while let Some(event) = es.next().await { match event { Ok(Event::Message(msg)) => { let gemini_response: GeminiResponse = serde_json::from_str(&msg.data) .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?; if let Some(candidate) = gemini_response.candidates.first() { let content = candidate .content .parts .iter() .find_map(|p| p.text.clone()) .unwrap_or_default(); let tool_calls = Self::extract_tool_call_deltas(&candidate.content.parts); // Determine finish_reason let finish_reason = candidate.finish_reason.as_ref().map(|fr| { match fr.as_str() { "STOP" => "stop".to_string(), _ => fr.to_lowercase(), } }); yield ProviderStreamChunk { content, reasoning_content: None, finish_reason, tool_calls, model: model.clone(), }; } } Ok(_) => continue, Err(e) => { Err(AppError::ProviderError(format!("Stream error: {}", e)))?; } } } }; Ok(Box::pin(stream)) } }