use async_trait::async_trait; use anyhow::Result; use serde::{Deserialize, Serialize}; use futures::stream::BoxStream; use crate::{ models::UnifiedRequest, errors::AppError, config::AppConfig, }; use super::{ProviderResponse, ProviderStreamChunk}; #[derive(Debug, Serialize)] struct GeminiRequest { contents: Vec, generation_config: Option, } #[derive(Debug, Serialize, Deserialize)] struct GeminiContent { parts: Vec, role: String, } #[derive(Debug, Serialize, Deserialize)] struct GeminiPart { #[serde(skip_serializing_if = "Option::is_none")] text: Option, #[serde(skip_serializing_if = "Option::is_none")] inline_data: Option, } #[derive(Debug, Serialize, Deserialize)] struct GeminiInlineData { mime_type: String, data: String, } #[derive(Debug, Serialize)] struct GeminiGenerationConfig { temperature: Option, max_output_tokens: Option, } #[derive(Debug, Deserialize)] struct GeminiCandidate { content: GeminiContent, _finish_reason: Option, } #[derive(Debug, Deserialize)] struct GeminiUsageMetadata { prompt_token_count: u32, candidates_token_count: u32, total_token_count: u32, } #[derive(Debug, Deserialize)] struct GeminiResponse { candidates: Vec, usage_metadata: Option, } pub struct GeminiProvider { client: reqwest::Client, config: crate::config::GeminiConfig, api_key: String, pricing: Vec, } impl GeminiProvider { pub fn new(config: &crate::config::GeminiConfig, app_config: &AppConfig) -> Result { let api_key = app_config.get_api_key("gemini")?; Self::new_with_key(config, app_config, api_key) } pub fn new_with_key(config: &crate::config::GeminiConfig, app_config: &AppConfig, api_key: String) -> Result { let client = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(30)) .build()?; Ok(Self { client, config: config.clone(), api_key, pricing: app_config.pricing.gemini.clone(), }) } } #[async_trait] impl super::Provider for GeminiProvider { fn name(&self) -> &str { "gemini" } fn supports_model(&self, model: &str) -> bool { model.starts_with("gemini-") } fn supports_multimodal(&self) -> bool { true // Gemini supports vision } async fn chat_completion( &self, request: UnifiedRequest, ) -> Result { // Convert UnifiedRequest to Gemini request let mut contents = Vec::with_capacity(request.messages.len()); for msg in request.messages { let mut parts = Vec::with_capacity(msg.content.len()); for part in msg.content { match part { crate::models::ContentPart::Text { text } => { parts.push(GeminiPart { text: Some(text), inline_data: None, }); } crate::models::ContentPart::Image(image_input) => { let (base64_data, mime_type) = image_input.to_base64().await .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?; parts.push(GeminiPart { text: None, inline_data: Some(GeminiInlineData { mime_type, data: base64_data, }), }); } } } // Map role: "user" -> "user", "assistant" -> "model", "system" -> "user" let role = match msg.role.as_str() { "assistant" => "model".to_string(), _ => "user".to_string(), }; contents.push(GeminiContent { parts, role, }); } if contents.is_empty() { return Err(AppError::ProviderError("No valid text messages to send".to_string())); } // Build generation config let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() { Some(GeminiGenerationConfig { temperature: request.temperature, max_output_tokens: request.max_tokens, }) } else { None }; let gemini_request = GeminiRequest { contents, generation_config, }; // Build URL let url = format!("{}/models/{}:generateContent?key={}", self.config.base_url, request.model, self.api_key ); // Send request let response = self.client .post(&url) .json(&gemini_request) .send() .await .map_err(|e| AppError::ProviderError(format!("HTTP request failed: {}", e)))?; // Check status let status = response.status(); if !status.is_success() { let error_text = response.text().await.unwrap_or_default(); return Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, error_text))); } let gemini_response: GeminiResponse = response .json() .await .map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?; // Extract content from first candidate let content = gemini_response.candidates .first() .and_then(|c| c.content.parts.first()) .and_then(|p| p.text.clone()) .unwrap_or_default(); // Extract token usage let prompt_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.prompt_token_count).unwrap_or(0); let completion_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.candidates_token_count).unwrap_or(0); let total_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.total_token_count).unwrap_or(0); Ok(ProviderResponse { content, reasoning_content: None, // Gemini doesn't use this field name prompt_tokens, completion_tokens, total_tokens, model: request.model, }) } fn estimate_tokens(&self, request: &UnifiedRequest) -> Result { Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request)) } fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 { if let Some(metadata) = registry.find_model(model) { if let Some(cost) = &metadata.cost { return (prompt_tokens as f64 * cost.input / 1_000_000.0) + (completion_tokens as f64 * cost.output / 1_000_000.0); } } let (prompt_rate, completion_rate) = self.pricing.iter() .find(|p| model.contains(&p.model)) .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) .unwrap_or((0.075, 0.30)); // Default to Gemini 2.0 Flash price if not found (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0) } async fn chat_completion_stream( &self, request: UnifiedRequest, ) -> Result>, AppError> { // Convert UnifiedRequest to Gemini request let mut contents = Vec::with_capacity(request.messages.len()); for msg in request.messages { let mut parts = Vec::with_capacity(msg.content.len()); for part in msg.content { match part { crate::models::ContentPart::Text { text } => { parts.push(GeminiPart { text: Some(text), inline_data: None, }); } crate::models::ContentPart::Image(image_input) => { let (base64_data, mime_type) = image_input.to_base64().await .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?; parts.push(GeminiPart { text: None, inline_data: Some(GeminiInlineData { mime_type, data: base64_data, }), }); } } } // Map role let role = match msg.role.as_str() { "assistant" => "model".to_string(), _ => "user".to_string(), }; contents.push(GeminiContent { parts, role, }); } // Build generation config let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() { Some(GeminiGenerationConfig { temperature: request.temperature, max_output_tokens: request.max_tokens, }) } else { None }; let gemini_request = GeminiRequest { contents, generation_config, }; // Build URL for streaming let url = format!("{}/models/{}:streamGenerateContent?alt=sse&key={}", self.config.base_url, request.model, self.api_key ); // Create eventsource stream use reqwest_eventsource::{EventSource, Event}; use futures::StreamExt; let es = EventSource::new(self.client.post(&url).json(&gemini_request)) .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; let model = request.model.clone(); let stream = async_stream::try_stream! { let mut es = es; while let Some(event) = es.next().await { match event { Ok(Event::Message(msg)) => { let gemini_response: GeminiResponse = serde_json::from_str(&msg.data) .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?; if let Some(candidate) = gemini_response.candidates.first() { let content = candidate.content.parts.first() .and_then(|p| p.text.clone()) .unwrap_or_default(); yield ProviderStreamChunk { content, reasoning_content: None, finish_reason: None, // Will be set in the last chunk model: model.clone(), }; } } Ok(_) => continue, Err(e) => { Err(AppError::ProviderError(format!("Stream error: {}", e)))?; } } } }; Ok(Box::pin(stream)) } }