fix(providers): add registry routing, OpenAI Responses support and Gemini streaming fallbacks; compile fixes
This commit is contained in:
@@ -14,7 +14,7 @@ use crate::{
|
|||||||
|
|
||||||
// ========== Gemini Request Structs ==========
|
// ========== Gemini Request Structs ==========
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Clone, Serialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
struct GeminiRequest {
|
struct GeminiRequest {
|
||||||
contents: Vec<GeminiContent>,
|
contents: Vec<GeminiContent>,
|
||||||
@@ -26,13 +26,13 @@ struct GeminiRequest {
|
|||||||
tool_config: Option<GeminiToolConfig>,
|
tool_config: Option<GeminiToolConfig>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
struct GeminiContent {
|
struct GeminiContent {
|
||||||
parts: Vec<GeminiPart>,
|
parts: Vec<GeminiPart>,
|
||||||
role: String,
|
role: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
struct GeminiPart {
|
struct GeminiPart {
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
@@ -45,25 +45,26 @@ struct GeminiPart {
|
|||||||
function_response: Option<GeminiFunctionResponse>,
|
function_response: Option<GeminiFunctionResponse>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
struct GeminiInlineData {
|
struct GeminiInlineData {
|
||||||
mime_type: String,
|
mime_type: String,
|
||||||
data: String,
|
data: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
struct GeminiFunctionCall {
|
struct GeminiFunctionCall {
|
||||||
name: String,
|
name: String,
|
||||||
args: Value,
|
args: Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
struct GeminiFunctionResponse {
|
struct GeminiFunctionResponse {
|
||||||
name: String,
|
name: String,
|
||||||
response: Value,
|
response: Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
struct GeminiGenerationConfig {
|
struct GeminiGenerationConfig {
|
||||||
temperature: Option<f64>,
|
temperature: Option<f64>,
|
||||||
@@ -72,13 +73,13 @@ struct GeminiGenerationConfig {
|
|||||||
|
|
||||||
// ========== Gemini Tool Structs ==========
|
// ========== Gemini Tool Structs ==========
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Clone, Serialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
struct GeminiTool {
|
struct GeminiTool {
|
||||||
function_declarations: Vec<GeminiFunctionDeclaration>,
|
function_declarations: Vec<GeminiFunctionDeclaration>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Clone, Serialize)]
|
||||||
struct GeminiFunctionDeclaration {
|
struct GeminiFunctionDeclaration {
|
||||||
name: String,
|
name: String,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
@@ -87,13 +88,13 @@ struct GeminiFunctionDeclaration {
|
|||||||
parameters: Option<Value>,
|
parameters: Option<Value>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Clone, Serialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
struct GeminiToolConfig {
|
struct GeminiToolConfig {
|
||||||
function_calling_config: GeminiFunctionCallingConfig,
|
function_calling_config: GeminiFunctionCallingConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Clone, Serialize)]
|
||||||
struct GeminiFunctionCallingConfig {
|
struct GeminiFunctionCallingConfig {
|
||||||
mode: String,
|
mode: String,
|
||||||
#[serde(skip_serializing_if = "Option::is_none", rename = "allowedFunctionNames")]
|
#[serde(skip_serializing_if = "Option::is_none", rename = "allowedFunctionNames")]
|
||||||
@@ -405,7 +406,7 @@ impl super::Provider for GeminiProvider {
|
|||||||
let model = request.model.clone();
|
let model = request.model.clone();
|
||||||
let tools = Self::convert_tools(&request);
|
let tools = Self::convert_tools(&request);
|
||||||
let tool_config = Self::convert_tool_config(&request);
|
let tool_config = Self::convert_tool_config(&request);
|
||||||
let contents = Self::convert_messages(request.messages).await?;
|
let contents = Self::convert_messages(request.messages.clone()).await?;
|
||||||
|
|
||||||
if contents.is_empty() {
|
if contents.is_empty() {
|
||||||
return Err(AppError::ProviderError("No valid messages to send".to_string()));
|
return Err(AppError::ProviderError("No valid messages to send".to_string()));
|
||||||
@@ -529,7 +530,7 @@ impl super::Provider for GeminiProvider {
|
|||||||
let model = request.model.clone();
|
let model = request.model.clone();
|
||||||
let tools = Self::convert_tools(&request);
|
let tools = Self::convert_tools(&request);
|
||||||
let tool_config = Self::convert_tool_config(&request);
|
let tool_config = Self::convert_tool_config(&request);
|
||||||
let contents = Self::convert_messages(request.messages).await?;
|
let contents = Self::convert_messages(request.messages.clone()).await?;
|
||||||
|
|
||||||
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
|
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
|
||||||
Some(GeminiGenerationConfig {
|
Some(GeminiGenerationConfig {
|
||||||
@@ -552,13 +553,21 @@ impl super::Provider for GeminiProvider {
|
|||||||
self.config.base_url, model,
|
self.config.base_url, model,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// (no fallback_request needed here)
|
||||||
|
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use reqwest_eventsource::{Event, EventSource};
|
use reqwest_eventsource::Event;
|
||||||
|
|
||||||
// Try to create an SSE event source for streaming. If creation fails
|
// Try to create an SSE event source for streaming. If creation fails
|
||||||
// (provider doesn't support streaming for this model or returned a
|
// (provider doesn't support streaming for this model or returned a
|
||||||
// non-2xx response), fall back to a synchronous generateContent call
|
// non-2xx response), fall back to a synchronous generateContent call
|
||||||
// and emit a single chunk.
|
// and emit a single chunk.
|
||||||
|
// Prepare clones for HTTP fallback usage inside non-streaming paths.
|
||||||
|
let http_client = self.client.clone();
|
||||||
|
let http_api_key = self.api_key.clone();
|
||||||
|
let http_base = self.config.base_url.clone();
|
||||||
|
let gemini_request_clone = gemini_request.clone();
|
||||||
|
|
||||||
let es_result = reqwest_eventsource::EventSource::new(
|
let es_result = reqwest_eventsource::EventSource::new(
|
||||||
self.client
|
self.client
|
||||||
.post(&url)
|
.post(&url)
|
||||||
@@ -566,25 +575,61 @@ impl super::Provider for GeminiProvider {
|
|||||||
.json(&gemini_request),
|
.json(&gemini_request),
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Err(e) = es_result {
|
if let Err(_e) = es_result {
|
||||||
// Fallback: call non-streaming path and convert to a single-stream chunk
|
// Fallback: call non-streaming generateContent via HTTP and convert to a single-stream chunk
|
||||||
let resp = self.chat_completion(request.clone()).await.map_err(|e2| {
|
let resp_http = http_client
|
||||||
AppError::ProviderError(format!("Failed to create EventSource: {} ; fallback error: {}", e, e2))
|
.post(format!("{}/models/{}:generateContent", http_base, model))
|
||||||
})?;
|
.header("x-goog-api-key", &http_api_key)
|
||||||
|
.json(&gemini_request_clone)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e2| AppError::ProviderError(format!("Failed to call generateContent fallback: {}", e2)))?;
|
||||||
|
|
||||||
|
if !resp_http.status().is_success() {
|
||||||
|
let status = resp_http.status();
|
||||||
|
let err = resp_http.text().await.unwrap_or_default();
|
||||||
|
return Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, err)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let gemini_response: GeminiResponse = resp_http
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.map_err(|e2| AppError::ProviderError(format!("Failed to parse generateContent response: {}", e2)))?;
|
||||||
|
|
||||||
|
let candidate = gemini_response.candidates.first();
|
||||||
|
let content = candidate
|
||||||
|
.and_then(|c| c.content.parts.iter().find_map(|p| p.text.clone()))
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let prompt_tokens = gemini_response
|
||||||
|
.usage_metadata
|
||||||
|
.as_ref()
|
||||||
|
.map(|u| u.prompt_token_count)
|
||||||
|
.unwrap_or(0);
|
||||||
|
let completion_tokens = gemini_response
|
||||||
|
.usage_metadata
|
||||||
|
.as_ref()
|
||||||
|
.map(|u| u.candidates_token_count)
|
||||||
|
.unwrap_or(0);
|
||||||
|
let total_tokens = gemini_response
|
||||||
|
.usage_metadata
|
||||||
|
.as_ref()
|
||||||
|
.map(|u| u.total_token_count)
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
let single_stream = async_stream::try_stream! {
|
let single_stream = async_stream::try_stream! {
|
||||||
let chunk = ProviderStreamChunk {
|
let chunk = ProviderStreamChunk {
|
||||||
content: resp.content,
|
content,
|
||||||
reasoning_content: resp.reasoning_content,
|
reasoning_content: None,
|
||||||
finish_reason: Some("stop".to_string()),
|
finish_reason: Some("stop".to_string()),
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
model: resp.model.clone(),
|
model: model.clone(),
|
||||||
usage: Some(super::StreamUsage {
|
usage: Some(super::StreamUsage {
|
||||||
prompt_tokens: resp.prompt_tokens,
|
prompt_tokens,
|
||||||
completion_tokens: resp.completion_tokens,
|
completion_tokens,
|
||||||
total_tokens: resp.total_tokens,
|
total_tokens,
|
||||||
cache_read_tokens: resp.cache_read_tokens,
|
cache_read_tokens: gemini_response.usage_metadata.as_ref().map(|u| u.cached_content_token_count).unwrap_or(0),
|
||||||
cache_write_tokens: resp.cache_write_tokens,
|
cache_write_tokens: 0,
|
||||||
}),
|
}),
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -671,33 +716,7 @@ impl super::Provider for GeminiProvider {
|
|||||||
}
|
}
|
||||||
Ok(_) => continue,
|
Ok(_) => continue,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// On streaming errors, attempt a synchronous fallback once.
|
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
||||||
// This handles cases where the provider rejects the SSE
|
|
||||||
// request but supports a non-streaming generateContent call.
|
|
||||||
match self.chat_completion(request.clone()).await {
|
|
||||||
Ok(resp) => {
|
|
||||||
let chunk = ProviderStreamChunk {
|
|
||||||
content: resp.content,
|
|
||||||
reasoning_content: resp.reasoning_content,
|
|
||||||
finish_reason: Some("stop".to_string()),
|
|
||||||
tool_calls: resp.tool_calls.map(|d| d.into_iter().map(|tc| tc.into()).collect()),
|
|
||||||
model: resp.model.clone(),
|
|
||||||
usage: Some(super::StreamUsage {
|
|
||||||
prompt_tokens: resp.prompt_tokens,
|
|
||||||
completion_tokens: resp.completion_tokens,
|
|
||||||
total_tokens: resp.total_tokens,
|
|
||||||
cache_read_tokens: resp.cache_read_tokens,
|
|
||||||
cache_write_tokens: resp.cache_write_tokens,
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
|
|
||||||
yield chunk;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
Err(err2) => {
|
|
||||||
Err(AppError::ProviderError(format!("Stream error: {} ; fallback error: {}", e, err2)))?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -309,7 +309,7 @@ impl super::Provider for OpenAIProvider {
|
|||||||
.json(&body),
|
.json(&body),
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Err(e) = es_result {
|
if es_result.is_err() {
|
||||||
// Fallback to non-streaming request which itself may retry to
|
// Fallback to non-streaming request which itself may retry to
|
||||||
// Responses API if necessary (handled in chat_completion).
|
// Responses API if necessary (handled in chat_completion).
|
||||||
let resp = self.chat_completion(request.clone()).await?;
|
let resp = self.chat_completion(request.clone()).await?;
|
||||||
|
|||||||
Reference in New Issue
Block a user