diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs index a7fcb4d2..8dd286a9 100644 --- a/src/providers/gemini.rs +++ b/src/providers/gemini.rs @@ -14,7 +14,7 @@ use crate::{ // ========== Gemini Request Structs ========== -#[derive(Debug, Serialize)] +#[derive(Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] struct GeminiRequest { contents: Vec, @@ -26,13 +26,13 @@ struct GeminiRequest { tool_config: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] struct GeminiContent { parts: Vec, role: String, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] struct GeminiPart { #[serde(skip_serializing_if = "Option::is_none")] @@ -45,25 +45,26 @@ struct GeminiPart { function_response: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] struct GeminiInlineData { mime_type: String, data: String, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] struct GeminiFunctionCall { name: String, args: Value, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] struct GeminiFunctionResponse { name: String, response: Value, } -#[derive(Debug, Serialize)] + +#[derive(Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] struct GeminiGenerationConfig { temperature: Option, @@ -72,13 +73,13 @@ struct GeminiGenerationConfig { // ========== Gemini Tool Structs ========== -#[derive(Debug, Serialize)] +#[derive(Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] struct GeminiTool { function_declarations: Vec, } -#[derive(Debug, Serialize)] +#[derive(Debug, Clone, Serialize)] struct GeminiFunctionDeclaration { name: String, #[serde(skip_serializing_if = "Option::is_none")] @@ -87,13 +88,13 @@ struct GeminiFunctionDeclaration { parameters: Option, } -#[derive(Debug, Serialize)] +#[derive(Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] struct GeminiToolConfig { function_calling_config: GeminiFunctionCallingConfig, } -#[derive(Debug, Serialize)] +#[derive(Debug, Clone, Serialize)] struct GeminiFunctionCallingConfig { mode: String, #[serde(skip_serializing_if = "Option::is_none", rename = "allowedFunctionNames")] @@ -405,7 +406,7 @@ impl super::Provider for GeminiProvider { let model = request.model.clone(); let tools = Self::convert_tools(&request); let tool_config = Self::convert_tool_config(&request); - let contents = Self::convert_messages(request.messages).await?; + let contents = Self::convert_messages(request.messages.clone()).await?; if contents.is_empty() { return Err(AppError::ProviderError("No valid messages to send".to_string())); @@ -529,7 +530,7 @@ impl super::Provider for GeminiProvider { let model = request.model.clone(); let tools = Self::convert_tools(&request); let tool_config = Self::convert_tool_config(&request); - let contents = Self::convert_messages(request.messages).await?; + let contents = Self::convert_messages(request.messages.clone()).await?; let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() { Some(GeminiGenerationConfig { @@ -552,13 +553,21 @@ impl super::Provider for GeminiProvider { self.config.base_url, model, ); + // (no fallback_request needed here) + use futures::StreamExt; - use reqwest_eventsource::{Event, EventSource}; + use reqwest_eventsource::Event; // Try to create an SSE event source for streaming. If creation fails // (provider doesn't support streaming for this model or returned a // non-2xx response), fall back to a synchronous generateContent call // and emit a single chunk. + // Prepare clones for HTTP fallback usage inside non-streaming paths. + let http_client = self.client.clone(); + let http_api_key = self.api_key.clone(); + let http_base = self.config.base_url.clone(); + let gemini_request_clone = gemini_request.clone(); + let es_result = reqwest_eventsource::EventSource::new( self.client .post(&url) @@ -566,25 +575,61 @@ impl super::Provider for GeminiProvider { .json(&gemini_request), ); - if let Err(e) = es_result { - // Fallback: call non-streaming path and convert to a single-stream chunk - let resp = self.chat_completion(request.clone()).await.map_err(|e2| { - AppError::ProviderError(format!("Failed to create EventSource: {} ; fallback error: {}", e, e2)) - })?; + if let Err(_e) = es_result { + // Fallback: call non-streaming generateContent via HTTP and convert to a single-stream chunk + let resp_http = http_client + .post(format!("{}/models/{}:generateContent", http_base, model)) + .header("x-goog-api-key", &http_api_key) + .json(&gemini_request_clone) + .send() + .await + .map_err(|e2| AppError::ProviderError(format!("Failed to call generateContent fallback: {}", e2)))?; + + if !resp_http.status().is_success() { + let status = resp_http.status(); + let err = resp_http.text().await.unwrap_or_default(); + return Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, err))); + } + + let gemini_response: GeminiResponse = resp_http + .json() + .await + .map_err(|e2| AppError::ProviderError(format!("Failed to parse generateContent response: {}", e2)))?; + + let candidate = gemini_response.candidates.first(); + let content = candidate + .and_then(|c| c.content.parts.iter().find_map(|p| p.text.clone())) + .unwrap_or_default(); + + let prompt_tokens = gemini_response + .usage_metadata + .as_ref() + .map(|u| u.prompt_token_count) + .unwrap_or(0); + let completion_tokens = gemini_response + .usage_metadata + .as_ref() + .map(|u| u.candidates_token_count) + .unwrap_or(0); + let total_tokens = gemini_response + .usage_metadata + .as_ref() + .map(|u| u.total_token_count) + .unwrap_or(0); let single_stream = async_stream::try_stream! { let chunk = ProviderStreamChunk { - content: resp.content, - reasoning_content: resp.reasoning_content, + content, + reasoning_content: None, finish_reason: Some("stop".to_string()), tool_calls: None, - model: resp.model.clone(), + model: model.clone(), usage: Some(super::StreamUsage { - prompt_tokens: resp.prompt_tokens, - completion_tokens: resp.completion_tokens, - total_tokens: resp.total_tokens, - cache_read_tokens: resp.cache_read_tokens, - cache_write_tokens: resp.cache_write_tokens, + prompt_tokens, + completion_tokens, + total_tokens, + cache_read_tokens: gemini_response.usage_metadata.as_ref().map(|u| u.cached_content_token_count).unwrap_or(0), + cache_write_tokens: 0, }), }; @@ -671,33 +716,7 @@ impl super::Provider for GeminiProvider { } Ok(_) => continue, Err(e) => { - // On streaming errors, attempt a synchronous fallback once. - // This handles cases where the provider rejects the SSE - // request but supports a non-streaming generateContent call. - match self.chat_completion(request.clone()).await { - Ok(resp) => { - let chunk = ProviderStreamChunk { - content: resp.content, - reasoning_content: resp.reasoning_content, - finish_reason: Some("stop".to_string()), - tool_calls: resp.tool_calls.map(|d| d.into_iter().map(|tc| tc.into()).collect()), - model: resp.model.clone(), - usage: Some(super::StreamUsage { - prompt_tokens: resp.prompt_tokens, - completion_tokens: resp.completion_tokens, - total_tokens: resp.total_tokens, - cache_read_tokens: resp.cache_read_tokens, - cache_write_tokens: resp.cache_write_tokens, - }), - }; - - yield chunk; - break; - } - Err(err2) => { - Err(AppError::ProviderError(format!("Stream error: {} ; fallback error: {}", e, err2)))?; - } - } + Err(AppError::ProviderError(format!("Stream error: {}", e)))?; } } } diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 19c34f24..19297888 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -309,7 +309,7 @@ impl super::Provider for OpenAIProvider { .json(&body), ); - if let Err(e) = es_result { + if es_result.is_err() { // Fallback to non-streaming request which itself may retry to // Responses API if necessary (handled in chat_completion). let resp = self.chat_completion(request.clone()).await?;