Compare commits
2 Commits
6010ec97a8
...
fb98f0ebb8
| Author | SHA1 | Date | |
|---|---|---|---|
| fb98f0ebb8 | |||
| 6b7e245827 |
@@ -1,6 +1,7 @@
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures::stream::BoxStream;
|
use futures::stream::{BoxStream, StreamExt};
|
||||||
|
use reqwest_eventsource::Event;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
@@ -215,7 +216,7 @@ impl GeminiProvider {
|
|||||||
|
|
||||||
let role = match msg.role.as_str() {
|
let role = match msg.role.as_str() {
|
||||||
"assistant" => "model".to_string(),
|
"assistant" => "model".to_string(),
|
||||||
"tool" => "user".to_string(), // Tool results are technically from the user side in Gemini
|
"tool" => "user".to_string(), // Tool results are user-side in Gemini
|
||||||
_ => "user".to_string(),
|
_ => "user".to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -232,7 +233,6 @@ impl GeminiProvider {
|
|||||||
})
|
})
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
// Gemini function response MUST have a name. Fallback to tool_call_id if name is missing.
|
|
||||||
let name = msg.name.clone().or_else(|| msg.tool_call_id.clone()).unwrap_or_else(|| "unknown_function".to_string());
|
let name = msg.name.clone().or_else(|| msg.tool_call_id.clone()).unwrap_or_else(|| "unknown_function".to_string());
|
||||||
let response_value = serde_json::from_str::<Value>(&text_content)
|
let response_value = serde_json::from_str::<Value>(&text_content)
|
||||||
.unwrap_or_else(|_| serde_json::json!({ "result": text_content }));
|
.unwrap_or_else(|_| serde_json::json!({ "result": text_content }));
|
||||||
@@ -249,7 +249,6 @@ impl GeminiProvider {
|
|||||||
} else if msg.role == "assistant" && msg.tool_calls.is_some() {
|
} else if msg.role == "assistant" && msg.tool_calls.is_some() {
|
||||||
// Assistant messages with tool_calls
|
// Assistant messages with tool_calls
|
||||||
if let Some(tool_calls) = &msg.tool_calls {
|
if let Some(tool_calls) = &msg.tool_calls {
|
||||||
// Include text content if present
|
|
||||||
for p in &msg.content {
|
for p in &msg.content {
|
||||||
if let ContentPart::Text { text } = p {
|
if let ContentPart::Text { text } = p {
|
||||||
if !text.trim().is_empty() {
|
if !text.trim().is_empty() {
|
||||||
@@ -315,7 +314,8 @@ impl GeminiProvider {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Merge with previous message if role matches
|
// STRATEGY: Strictly enforce alternating roles.
|
||||||
|
// If current message has the same role as the last one, merge their parts.
|
||||||
if let Some(last_content) = contents.last_mut() {
|
if let Some(last_content) = contents.last_mut() {
|
||||||
if last_content.role.as_ref() == Some(&role) {
|
if last_content.role.as_ref() == Some(&role) {
|
||||||
last_content.parts.extend(parts);
|
last_content.parts.extend(parts);
|
||||||
@@ -330,7 +330,6 @@ impl GeminiProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Gemini requires the first message to be from "user".
|
// Gemini requires the first message to be from "user".
|
||||||
// If it starts with "model", we prepend a placeholder user message.
|
|
||||||
if let Some(first) = contents.first() {
|
if let Some(first) = contents.first() {
|
||||||
if first.role.as_deref() == Some("model") {
|
if first.role.as_deref() == Some("model") {
|
||||||
contents.insert(0, GeminiContent {
|
contents.insert(0, GeminiContent {
|
||||||
@@ -345,6 +344,12 @@ impl GeminiProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Final check: ensure we don't have empty contents after filtering.
|
||||||
|
// If the last message was merged or filtered, we might have an empty array.
|
||||||
|
if contents.is_empty() && system_parts.is_empty() {
|
||||||
|
return Err(AppError::ProviderError("No valid content parts after filtering".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
let system_instruction = if !system_parts.is_empty() {
|
let system_instruction = if !system_parts.is_empty() {
|
||||||
Some(GeminiContent {
|
Some(GeminiContent {
|
||||||
parts: system_parts,
|
parts: system_parts,
|
||||||
@@ -662,93 +667,19 @@ impl super::Provider for GeminiProvider {
|
|||||||
);
|
);
|
||||||
tracing::debug!("Calling Gemini Stream API: {}", url);
|
tracing::debug!("Calling Gemini Stream API: {}", url);
|
||||||
|
|
||||||
// (no fallback_request needed here)
|
// Capture a clone of the request to probe for errors (Gemini 400s are common)
|
||||||
|
let probe_request = gemini_request.clone();
|
||||||
|
let probe_client = self.client.clone();
|
||||||
|
let probe_url = url.clone();
|
||||||
|
let probe_api_key = self.api_key.clone();
|
||||||
|
|
||||||
use futures::StreamExt;
|
// Create the EventSource first (it doesn't send until polled)
|
||||||
use reqwest_eventsource::Event;
|
let es = reqwest_eventsource::EventSource::new(
|
||||||
|
|
||||||
// Try to create an SSE event source for streaming. If creation fails
|
|
||||||
// (provider doesn't support streaming for this model or returned a
|
|
||||||
// non-2xx response), fall back to a synchronous generateContent call
|
|
||||||
// and emit a single chunk.
|
|
||||||
// Prepare clones for HTTP fallback usage inside non-streaming paths.
|
|
||||||
let http_client = self.client.clone();
|
|
||||||
let http_api_key = self.api_key.clone();
|
|
||||||
let http_base = base_url.clone();
|
|
||||||
let gemini_request_clone = gemini_request.clone();
|
|
||||||
|
|
||||||
let es_result = reqwest_eventsource::EventSource::new(
|
|
||||||
self.client
|
self.client
|
||||||
.post(&url)
|
.post(&url)
|
||||||
.header("x-goog-api-key", &self.api_key)
|
.header("x-goog-api-key", &self.api_key)
|
||||||
.json(&gemini_request),
|
.json(&gemini_request),
|
||||||
);
|
).map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||||
|
|
||||||
if let Err(_e) = es_result {
|
|
||||||
// Fallback: call non-streaming generateContent via HTTP and convert to a single-stream chunk
|
|
||||||
let resp_http = http_client
|
|
||||||
.post(format!("{}/models/{}:generateContent", http_base, model))
|
|
||||||
.header("x-goog-api-key", &http_api_key)
|
|
||||||
.json(&gemini_request_clone)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e2| AppError::ProviderError(format!("Failed to call generateContent fallback: {}", e2)))?;
|
|
||||||
|
|
||||||
if !resp_http.status().is_success() {
|
|
||||||
let status = resp_http.status();
|
|
||||||
let err = resp_http.text().await.unwrap_or_default();
|
|
||||||
return Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, err)));
|
|
||||||
}
|
|
||||||
|
|
||||||
let gemini_response: GeminiResponse = resp_http
|
|
||||||
.json()
|
|
||||||
.await
|
|
||||||
.map_err(|e2| AppError::ProviderError(format!("Failed to parse generateContent response: {}", e2)))?;
|
|
||||||
|
|
||||||
let candidate = gemini_response.candidates.first();
|
|
||||||
let content = candidate
|
|
||||||
.and_then(|c| c.content.parts.iter().find_map(|p| p.text.clone()))
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
let prompt_tokens = gemini_response
|
|
||||||
.usage_metadata
|
|
||||||
.as_ref()
|
|
||||||
.map(|u| u.prompt_token_count)
|
|
||||||
.unwrap_or(0);
|
|
||||||
let completion_tokens = gemini_response
|
|
||||||
.usage_metadata
|
|
||||||
.as_ref()
|
|
||||||
.map(|u| u.candidates_token_count)
|
|
||||||
.unwrap_or(0);
|
|
||||||
let total_tokens = gemini_response
|
|
||||||
.usage_metadata
|
|
||||||
.as_ref()
|
|
||||||
.map(|u| u.total_token_count)
|
|
||||||
.unwrap_or(0);
|
|
||||||
|
|
||||||
let single_stream = async_stream::try_stream! {
|
|
||||||
let chunk = ProviderStreamChunk {
|
|
||||||
content,
|
|
||||||
reasoning_content: None,
|
|
||||||
finish_reason: Some("stop".to_string()),
|
|
||||||
tool_calls: None,
|
|
||||||
model: model.clone(),
|
|
||||||
usage: Some(super::StreamUsage {
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
total_tokens,
|
|
||||||
cache_read_tokens: gemini_response.usage_metadata.as_ref().map(|u| u.cached_content_token_count).unwrap_or(0),
|
|
||||||
cache_write_tokens: 0,
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
|
|
||||||
yield chunk;
|
|
||||||
};
|
|
||||||
|
|
||||||
return Ok(Box::pin(single_stream));
|
|
||||||
}
|
|
||||||
|
|
||||||
let es = es_result.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
|
||||||
|
|
||||||
let stream = async_stream::try_stream! {
|
let stream = async_stream::try_stream! {
|
||||||
let mut es = es;
|
let mut es = es;
|
||||||
@@ -757,7 +688,9 @@ impl super::Provider for GeminiProvider {
|
|||||||
Ok(Event::Message(msg)) => {
|
Ok(Event::Message(msg)) => {
|
||||||
let gemini_response: GeminiStreamResponse = serde_json::from_str(&msg.data)
|
let gemini_response: GeminiStreamResponse = serde_json::from_str(&msg.data)
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
||||||
|
|
||||||
|
// (rest of processing remains identical)
|
||||||
|
|
||||||
// Extract usage from usageMetadata if present (reported on every/last chunk)
|
// Extract usage from usageMetadata if present (reported on every/last chunk)
|
||||||
let stream_usage = gemini_response.usage_metadata.as_ref().map(|u| {
|
let stream_usage = gemini_response.usage_metadata.as_ref().map(|u| {
|
||||||
super::StreamUsage {
|
super::StreamUsage {
|
||||||
@@ -825,7 +758,25 @@ impl super::Provider for GeminiProvider {
|
|||||||
}
|
}
|
||||||
Ok(_) => continue,
|
Ok(_) => continue,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
// On stream error, attempt to probe for the actual error body from the provider
|
||||||
|
let probe_resp = probe_client
|
||||||
|
.post(&probe_url)
|
||||||
|
.header("x-goog-api-key", &probe_api_key)
|
||||||
|
.json(&probe_request)
|
||||||
|
.send()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match probe_resp {
|
||||||
|
Ok(r) if !r.status().is_success() => {
|
||||||
|
let status = r.status();
|
||||||
|
let body = r.text().await.unwrap_or_default();
|
||||||
|
tracing::error!("Gemini Stream Error Probe ({}): {}", status, body);
|
||||||
|
Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, body)))?;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user