feat(gemini): implement stream error probing for better diagnostics

This commit is contained in:
2026-03-05 15:40:32 +00:00
parent 6010ec97a8
commit 6b7e245827

View File

@@ -1,6 +1,7 @@
use anyhow::Result; use anyhow::Result;
use async_trait::async_trait; use async_trait::async_trait;
use futures::stream::BoxStream; use futures::stream::{BoxStream, StreamExt};
use reqwest_eventsource::Event;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use uuid::Uuid; use uuid::Uuid;
@@ -662,93 +663,19 @@ impl super::Provider for GeminiProvider {
); );
tracing::debug!("Calling Gemini Stream API: {}", url); tracing::debug!("Calling Gemini Stream API: {}", url);
// (no fallback_request needed here) // Capture a clone of the request to probe for errors (Gemini 400s are common)
let probe_request = gemini_request.clone();
let probe_client = self.client.clone();
let probe_url = url.clone();
let probe_api_key = self.api_key.clone();
use futures::StreamExt; // Create the EventSource first (it doesn't send until polled)
use reqwest_eventsource::Event; let es = reqwest_eventsource::EventSource::new(
// Try to create an SSE event source for streaming. If creation fails
// (provider doesn't support streaming for this model or returned a
// non-2xx response), fall back to a synchronous generateContent call
// and emit a single chunk.
// Prepare clones for HTTP fallback usage inside non-streaming paths.
let http_client = self.client.clone();
let http_api_key = self.api_key.clone();
let http_base = base_url.clone();
let gemini_request_clone = gemini_request.clone();
let es_result = reqwest_eventsource::EventSource::new(
self.client self.client
.post(&url) .post(&url)
.header("x-goog-api-key", &self.api_key) .header("x-goog-api-key", &self.api_key)
.json(&gemini_request), .json(&gemini_request),
); ).map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
if let Err(_e) = es_result {
// Fallback: call non-streaming generateContent via HTTP and convert to a single-stream chunk
let resp_http = http_client
.post(format!("{}/models/{}:generateContent", http_base, model))
.header("x-goog-api-key", &http_api_key)
.json(&gemini_request_clone)
.send()
.await
.map_err(|e2| AppError::ProviderError(format!("Failed to call generateContent fallback: {}", e2)))?;
if !resp_http.status().is_success() {
let status = resp_http.status();
let err = resp_http.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, err)));
}
let gemini_response: GeminiResponse = resp_http
.json()
.await
.map_err(|e2| AppError::ProviderError(format!("Failed to parse generateContent response: {}", e2)))?;
let candidate = gemini_response.candidates.first();
let content = candidate
.and_then(|c| c.content.parts.iter().find_map(|p| p.text.clone()))
.unwrap_or_default();
let prompt_tokens = gemini_response
.usage_metadata
.as_ref()
.map(|u| u.prompt_token_count)
.unwrap_or(0);
let completion_tokens = gemini_response
.usage_metadata
.as_ref()
.map(|u| u.candidates_token_count)
.unwrap_or(0);
let total_tokens = gemini_response
.usage_metadata
.as_ref()
.map(|u| u.total_token_count)
.unwrap_or(0);
let single_stream = async_stream::try_stream! {
let chunk = ProviderStreamChunk {
content,
reasoning_content: None,
finish_reason: Some("stop".to_string()),
tool_calls: None,
model: model.clone(),
usage: Some(super::StreamUsage {
prompt_tokens,
completion_tokens,
total_tokens,
cache_read_tokens: gemini_response.usage_metadata.as_ref().map(|u| u.cached_content_token_count).unwrap_or(0),
cache_write_tokens: 0,
}),
};
yield chunk;
};
return Ok(Box::pin(single_stream));
}
let es = es_result.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let stream = async_stream::try_stream! { let stream = async_stream::try_stream! {
let mut es = es; let mut es = es;
@@ -758,6 +685,8 @@ impl super::Provider for GeminiProvider {
let gemini_response: GeminiStreamResponse = serde_json::from_str(&msg.data) let gemini_response: GeminiStreamResponse = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?; .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
// (rest of processing remains identical)
// Extract usage from usageMetadata if present (reported on every/last chunk) // Extract usage from usageMetadata if present (reported on every/last chunk)
let stream_usage = gemini_response.usage_metadata.as_ref().map(|u| { let stream_usage = gemini_response.usage_metadata.as_ref().map(|u| {
super::StreamUsage { super::StreamUsage {
@@ -825,10 +754,28 @@ impl super::Provider for GeminiProvider {
} }
Ok(_) => continue, Ok(_) => continue,
Err(e) => { Err(e) => {
// On stream error, attempt to probe for the actual error body from the provider
let probe_resp = probe_client
.post(&probe_url)
.header("x-goog-api-key", &probe_api_key)
.json(&probe_request)
.send()
.await;
match probe_resp {
Ok(r) if !r.status().is_success() => {
let status = r.status();
let body = r.text().await.unwrap_or_default();
tracing::error!("Gemini Stream Error Probe ({}): {}", status, body);
Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, body)))?;
}
_ => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?; Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
} }
} }
} }
}
}
}; };
Ok(Box::pin(stream)) Ok(Box::pin(stream))