fix(providers): add registry routing, OpenAI Responses support and Gemini streaming fallbacks; compile fixes
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled

This commit is contained in:
2026-03-04 14:17:30 -05:00
parent 5a8510bf1e
commit 5b6583301d
2 changed files with 74 additions and 55 deletions

View File

@@ -14,7 +14,7 @@ use crate::{
// ========== Gemini Request Structs ========== // ========== Gemini Request Structs ==========
#[derive(Debug, Serialize)] #[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
struct GeminiRequest { struct GeminiRequest {
contents: Vec<GeminiContent>, contents: Vec<GeminiContent>,
@@ -26,13 +26,13 @@ struct GeminiRequest {
tool_config: Option<GeminiToolConfig>, tool_config: Option<GeminiToolConfig>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiContent { struct GeminiContent {
parts: Vec<GeminiPart>, parts: Vec<GeminiPart>,
role: String, role: String,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
struct GeminiPart { struct GeminiPart {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
@@ -45,25 +45,26 @@ struct GeminiPart {
function_response: Option<GeminiFunctionResponse>, function_response: Option<GeminiFunctionResponse>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiInlineData { struct GeminiInlineData {
mime_type: String, mime_type: String,
data: String, data: String,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiFunctionCall { struct GeminiFunctionCall {
name: String, name: String,
args: Value, args: Value,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiFunctionResponse { struct GeminiFunctionResponse {
name: String, name: String,
response: Value, response: Value,
} }
#[derive(Debug, Serialize)]
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
struct GeminiGenerationConfig { struct GeminiGenerationConfig {
temperature: Option<f64>, temperature: Option<f64>,
@@ -72,13 +73,13 @@ struct GeminiGenerationConfig {
// ========== Gemini Tool Structs ========== // ========== Gemini Tool Structs ==========
#[derive(Debug, Serialize)] #[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
struct GeminiTool { struct GeminiTool {
function_declarations: Vec<GeminiFunctionDeclaration>, function_declarations: Vec<GeminiFunctionDeclaration>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Clone, Serialize)]
struct GeminiFunctionDeclaration { struct GeminiFunctionDeclaration {
name: String, name: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
@@ -87,13 +88,13 @@ struct GeminiFunctionDeclaration {
parameters: Option<Value>, parameters: Option<Value>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
struct GeminiToolConfig { struct GeminiToolConfig {
function_calling_config: GeminiFunctionCallingConfig, function_calling_config: GeminiFunctionCallingConfig,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Clone, Serialize)]
struct GeminiFunctionCallingConfig { struct GeminiFunctionCallingConfig {
mode: String, mode: String,
#[serde(skip_serializing_if = "Option::is_none", rename = "allowedFunctionNames")] #[serde(skip_serializing_if = "Option::is_none", rename = "allowedFunctionNames")]
@@ -405,7 +406,7 @@ impl super::Provider for GeminiProvider {
let model = request.model.clone(); let model = request.model.clone();
let tools = Self::convert_tools(&request); let tools = Self::convert_tools(&request);
let tool_config = Self::convert_tool_config(&request); let tool_config = Self::convert_tool_config(&request);
let contents = Self::convert_messages(request.messages).await?; let contents = Self::convert_messages(request.messages.clone()).await?;
if contents.is_empty() { if contents.is_empty() {
return Err(AppError::ProviderError("No valid messages to send".to_string())); return Err(AppError::ProviderError("No valid messages to send".to_string()));
@@ -529,7 +530,7 @@ impl super::Provider for GeminiProvider {
let model = request.model.clone(); let model = request.model.clone();
let tools = Self::convert_tools(&request); let tools = Self::convert_tools(&request);
let tool_config = Self::convert_tool_config(&request); let tool_config = Self::convert_tool_config(&request);
let contents = Self::convert_messages(request.messages).await?; let contents = Self::convert_messages(request.messages.clone()).await?;
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() { let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
Some(GeminiGenerationConfig { Some(GeminiGenerationConfig {
@@ -552,13 +553,21 @@ impl super::Provider for GeminiProvider {
self.config.base_url, model, self.config.base_url, model,
); );
// (no fallback_request needed here)
use futures::StreamExt; use futures::StreamExt;
use reqwest_eventsource::{Event, EventSource}; use reqwest_eventsource::Event;
// Try to create an SSE event source for streaming. If creation fails // Try to create an SSE event source for streaming. If creation fails
// (provider doesn't support streaming for this model or returned a // (provider doesn't support streaming for this model or returned a
// non-2xx response), fall back to a synchronous generateContent call // non-2xx response), fall back to a synchronous generateContent call
// and emit a single chunk. // and emit a single chunk.
// Prepare clones for HTTP fallback usage inside non-streaming paths.
let http_client = self.client.clone();
let http_api_key = self.api_key.clone();
let http_base = self.config.base_url.clone();
let gemini_request_clone = gemini_request.clone();
let es_result = reqwest_eventsource::EventSource::new( let es_result = reqwest_eventsource::EventSource::new(
self.client self.client
.post(&url) .post(&url)
@@ -566,25 +575,61 @@ impl super::Provider for GeminiProvider {
.json(&gemini_request), .json(&gemini_request),
); );
if let Err(e) = es_result { if let Err(_e) = es_result {
// Fallback: call non-streaming path and convert to a single-stream chunk // Fallback: call non-streaming generateContent via HTTP and convert to a single-stream chunk
let resp = self.chat_completion(request.clone()).await.map_err(|e2| { let resp_http = http_client
AppError::ProviderError(format!("Failed to create EventSource: {} ; fallback error: {}", e, e2)) .post(format!("{}/models/{}:generateContent", http_base, model))
})?; .header("x-goog-api-key", &http_api_key)
.json(&gemini_request_clone)
.send()
.await
.map_err(|e2| AppError::ProviderError(format!("Failed to call generateContent fallback: {}", e2)))?;
if !resp_http.status().is_success() {
let status = resp_http.status();
let err = resp_http.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, err)));
}
let gemini_response: GeminiResponse = resp_http
.json()
.await
.map_err(|e2| AppError::ProviderError(format!("Failed to parse generateContent response: {}", e2)))?;
let candidate = gemini_response.candidates.first();
let content = candidate
.and_then(|c| c.content.parts.iter().find_map(|p| p.text.clone()))
.unwrap_or_default();
let prompt_tokens = gemini_response
.usage_metadata
.as_ref()
.map(|u| u.prompt_token_count)
.unwrap_or(0);
let completion_tokens = gemini_response
.usage_metadata
.as_ref()
.map(|u| u.candidates_token_count)
.unwrap_or(0);
let total_tokens = gemini_response
.usage_metadata
.as_ref()
.map(|u| u.total_token_count)
.unwrap_or(0);
let single_stream = async_stream::try_stream! { let single_stream = async_stream::try_stream! {
let chunk = ProviderStreamChunk { let chunk = ProviderStreamChunk {
content: resp.content, content,
reasoning_content: resp.reasoning_content, reasoning_content: None,
finish_reason: Some("stop".to_string()), finish_reason: Some("stop".to_string()),
tool_calls: None, tool_calls: None,
model: resp.model.clone(), model: model.clone(),
usage: Some(super::StreamUsage { usage: Some(super::StreamUsage {
prompt_tokens: resp.prompt_tokens, prompt_tokens,
completion_tokens: resp.completion_tokens, completion_tokens,
total_tokens: resp.total_tokens, total_tokens,
cache_read_tokens: resp.cache_read_tokens, cache_read_tokens: gemini_response.usage_metadata.as_ref().map(|u| u.cached_content_token_count).unwrap_or(0),
cache_write_tokens: resp.cache_write_tokens, cache_write_tokens: 0,
}), }),
}; };
@@ -671,33 +716,7 @@ impl super::Provider for GeminiProvider {
} }
Ok(_) => continue, Ok(_) => continue,
Err(e) => { Err(e) => {
// On streaming errors, attempt a synchronous fallback once. Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
// This handles cases where the provider rejects the SSE
// request but supports a non-streaming generateContent call.
match self.chat_completion(request.clone()).await {
Ok(resp) => {
let chunk = ProviderStreamChunk {
content: resp.content,
reasoning_content: resp.reasoning_content,
finish_reason: Some("stop".to_string()),
tool_calls: resp.tool_calls.map(|d| d.into_iter().map(|tc| tc.into()).collect()),
model: resp.model.clone(),
usage: Some(super::StreamUsage {
prompt_tokens: resp.prompt_tokens,
completion_tokens: resp.completion_tokens,
total_tokens: resp.total_tokens,
cache_read_tokens: resp.cache_read_tokens,
cache_write_tokens: resp.cache_write_tokens,
}),
};
yield chunk;
break;
}
Err(err2) => {
Err(AppError::ProviderError(format!("Stream error: {} ; fallback error: {}", e, err2)))?;
}
}
} }
} }
} }

View File

@@ -309,7 +309,7 @@ impl super::Provider for OpenAIProvider {
.json(&body), .json(&body),
); );
if let Err(e) = es_result { if es_result.is_err() {
// Fallback to non-streaming request which itself may retry to // Fallback to non-streaming request which itself may retry to
// Responses API if necessary (handled in chat_completion). // Responses API if necessary (handled in chat_completion).
let resp = self.chat_completion(request.clone()).await?; let resp = self.chat_completion(request.clone()).await?;