feat(providers): model-registry routing + Responses API support and streaming fallbacks for OpenAI/Gemini
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled

2026-03-04 13:36:03 -05:00
parent 1453e64d4b
commit 5a8510bf1e
5 changed files with 328 additions and 8 deletions


@@ -555,13 +555,46 @@ impl super::Provider for GeminiProvider {
use futures::StreamExt;
use reqwest_eventsource::{Event, EventSource};
// Try to create an SSE event source for streaming. If creation fails
// (provider doesn't support streaming for this model or returned a
// non-2xx response), fall back to a synchronous generateContent call
// and emit a single chunk.
let es_result = reqwest_eventsource::EventSource::new(
self.client
.post(&url)
.header("x-goog-api-key", &self.api_key)
.json(&gemini_request),
);
if let Err(e) = es_result {
// Fallback: call non-streaming path and convert to a single-stream chunk
let resp = self.chat_completion(request.clone()).await.map_err(|e2| {
AppError::ProviderError(format!("Failed to create EventSource: {} ; fallback error: {}", e, e2))
})?;
let single_stream = async_stream::try_stream! {
let chunk = ProviderStreamChunk {
content: resp.content,
reasoning_content: resp.reasoning_content,
finish_reason: Some("stop".to_string()),
tool_calls: None,
model: resp.model.clone(),
usage: Some(super::StreamUsage {
prompt_tokens: resp.prompt_tokens,
completion_tokens: resp.completion_tokens,
total_tokens: resp.total_tokens,
cache_read_tokens: resp.cache_read_tokens,
cache_write_tokens: resp.cache_write_tokens,
}),
};
yield chunk;
};
return Ok(Box::pin(single_stream));
}
let es = es_result.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let stream = async_stream::try_stream! {
let mut es = es;
@@ -638,7 +671,33 @@ impl super::Provider for GeminiProvider {
}
Ok(_) => continue,
Err(e) => {
// On streaming errors, attempt a synchronous fallback once.
// This handles cases where the provider rejects the SSE
// request but supports a non-streaming generateContent call.
match self.chat_completion(request.clone()).await {
Ok(resp) => {
let chunk = ProviderStreamChunk {
content: resp.content,
reasoning_content: resp.reasoning_content,
finish_reason: Some("stop".to_string()),
tool_calls: resp.tool_calls.map(|d| d.into_iter().map(|tc| tc.into()).collect()),
model: resp.model.clone(),
usage: Some(super::StreamUsage {
prompt_tokens: resp.prompt_tokens,
completion_tokens: resp.completion_tokens,
total_tokens: resp.total_tokens,
cache_read_tokens: resp.cache_read_tokens,
cache_write_tokens: resp.cache_write_tokens,
}),
};
yield chunk;
break;
}
Err(err2) => {
Err(AppError::ProviderError(format!("Stream error: {} ; fallback error: {}", e, err2)))?;
}
}
}
}
}
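
With this fallback in place, callers of the Gemini streaming path can no longer assume multiple chunks will arrive; the fallback emits exactly one chunk and then ends. A minimal consumption sketch, assuming simplified stand-in types (the real code uses ProviderStreamChunk and AppError from the providers module):

use futures::StreamExt;

// Stand-in for ProviderStreamChunk, reduced to the fields used here.
struct Chunk {
    content: String,
    finish_reason: Option<String>,
}

// Drain a provider stream into a single String. Works unchanged whether the
// provider streamed many chunks or the fallback emitted a single one.
async fn drain(
    mut stream: impl futures::Stream<Item = Result<Chunk, String>> + Unpin,
) -> Result<String, String> {
    let mut full = String::new();
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        full.push_str(&chunk.content);
        if chunk.finish_reason.as_deref() == Some("stop") {
            break; // the fallback path sets finish_reason = "stop" on its only chunk
        }
    }
    Ok(full)
}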


@@ -28,6 +28,13 @@ pub trait Provider: Send + Sync {
/// Process a chat completion request
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError>;
/// Process a chat request using a provider-specific "responses"-style endpoint.
/// The default implementation falls back to `chat_completion` for providers
/// that do not implement a dedicated responses endpoint.
async fn chat_responses(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
self.chat_completion(request).await
}
/// Process a streaming chat completion request
async fn chat_completion_stream(
&self,

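For context, the default-method delegation above means a provider that only implements chat_completion automatically gets a working chat_responses. A minimal sketch of the pattern with simplified stand-in types (async fn in traits, Rust 1.75+; the real trait returns Result<ProviderResponse, AppError> and its dyn-compatibility details are omitted):

trait Provider {
    async fn chat_completion(&self, prompt: String) -> Result<String, String>;

    // Provided default: providers without a dedicated responses endpoint
    // transparently reuse their chat_completion path.
    async fn chat_responses(&self, prompt: String) -> Result<String, String> {
        self.chat_completion(prompt).await
    }
}

struct EchoProvider;

impl Provider for EchoProvider {
    async fn chat_completion(&self, prompt: String) -> Result<String, String> {
        Ok(prompt)
    }
    // chat_responses is not overridden; calls fall through to chat_completion.
}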

@@ -65,7 +65,107 @@ impl super::Provider for OpenAIProvider {
.map_err(|e| AppError::ProviderError(e.to_string()))?;
if !response.status().is_success() {
// Read error body to diagnose. If the model requires the Responses
// API (v1/responses), retry against that endpoint.
let error_text = response.text().await.unwrap_or_default();
if error_text.to_lowercase().contains("v1/responses") {
// Build a simple `input` string by concatenating message parts.
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
let mut inputs: Vec<String> = Vec::new();
for m in &messages_json {
let role = m["role"].as_str().unwrap_or("");
let parts = m.get("content").and_then(|c| c.as_array()).cloned().unwrap_or_default();
let mut text_parts = Vec::new();
for p in parts {
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
text_parts.push(t.to_string());
}
}
inputs.push(format!("{}: {}", role, text_parts.join("")));
}
let input_text = inputs.join("\n");
let resp = self
.client
.post(format!("{}/responses", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&serde_json::json!({ "model": request.model, "input": input_text }))
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
if !resp.status().is_success() {
let err = resp.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("OpenAI Responses API error: {}", err)));
}
let resp_json: serde_json::Value = resp.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Try to normalize: if it's chat-style, use existing parser
if resp_json.get("choices").is_some() {
return helpers::parse_openai_response(&resp_json, request.model);
}
// Responses API: try to extract text from `output` or `candidates`
// output -> [{"content": [{"type":..., "text": "..."}, ...]}]
let mut content_text = String::new();
if let Some(output) = resp_json.get("output").and_then(|o| o.as_array()) {
if let Some(first) = output.get(0) {
if let Some(contents) = first.get("content").and_then(|c| c.as_array()) {
for item in contents {
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
if !content_text.is_empty() {
content_text.push_str("\n");
}
content_text.push_str(text);
} else if let Some(parts) = item.get("parts").and_then(|p| p.as_array()) {
for p in parts {
if let Some(t) = p.as_str() {
if !content_text.is_empty() { content_text.push_str("\n"); }
content_text.push_str(t);
}
}
}
}
}
}
}
// Fallback: check `candidates` -> candidate.content.parts.text
if content_text.is_empty() {
if let Some(cands) = resp_json.get("candidates").and_then(|c| c.as_array()) {
if let Some(c0) = cands.get(0) {
if let Some(content) = c0.get("content") {
if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) {
for p in parts {
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
if !content_text.is_empty() { content_text.push_str("\n"); }
content_text.push_str(t);
}
}
}
}
}
}
}
// Extract simple usage if present
let prompt_tokens = resp_json.get("usage").and_then(|u| u.get("prompt_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
let completion_tokens = resp_json.get("usage").and_then(|u| u.get("completion_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
let total_tokens = resp_json.get("usage").and_then(|u| u.get("total_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
return Ok(ProviderResponse {
content: content_text,
reasoning_content: None,
tool_calls: None,
prompt_tokens,
completion_tokens,
total_tokens,
cache_read_tokens: 0,
cache_write_tokens: 0,
model: request.model,
});
}
return Err(AppError::ProviderError(format!("OpenAI API error: {}", error_text)));
}
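
To illustrate the retry payload: the chat messages are flattened into one role-prefixed `input` string and posted to {base_url}/responses. A standalone sketch with made-up message content (the real input comes from helpers::messages_to_openai_json):

fn main() {
    // Hypothetical two-message conversation, flattened the same way as above.
    let input_text = [
        "system: You are a helpful assistant.",
        "user: Summarize the release notes.",
    ]
    .join("\n");

    // Same body shape as the request sent to {base_url}/responses.
    let body = serde_json::json!({
        "model": "example-model", // placeholder; the real code uses request.model
        "input": input_text,
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}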
@@ -77,6 +177,95 @@ impl super::Provider for OpenAIProvider {
helpers::parse_openai_response(&resp_json, request.model)
}
async fn chat_responses(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
// Build a simple `input` string by concatenating message parts.
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
let mut inputs: Vec<String> = Vec::new();
for m in &messages_json {
let role = m["role"].as_str().unwrap_or("");
let parts = m.get("content").and_then(|c| c.as_array()).cloned().unwrap_or_default();
let mut text_parts = Vec::new();
for p in parts {
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
text_parts.push(t.to_string());
}
}
inputs.push(format!("{}: {}", role, text_parts.join("")));
}
let input_text = inputs.join("\n");
let resp = self
.client
.post(format!("{}/responses", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&serde_json::json!({ "model": request.model, "input": input_text }))
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
if !resp.status().is_success() {
let err = resp.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("OpenAI Responses API error: {}", err)));
}
let resp_json: serde_json::Value = resp.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Normalize Responses API output into ProviderResponse
let mut content_text = String::new();
if let Some(output) = resp_json.get("output").and_then(|o| o.as_array()) {
if let Some(first) = output.get(0) {
if let Some(contents) = first.get("content").and_then(|c| c.as_array()) {
for item in contents {
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
if !content_text.is_empty() { content_text.push_str("\n"); }
content_text.push_str(text);
} else if let Some(parts) = item.get("parts").and_then(|p| p.as_array()) {
for p in parts {
if let Some(t) = p.as_str() {
if !content_text.is_empty() { content_text.push_str("\n"); }
content_text.push_str(t);
}
}
}
}
}
}
}
if content_text.is_empty() {
if let Some(cands) = resp_json.get("candidates").and_then(|c| c.as_array()) {
if let Some(c0) = cands.get(0) {
if let Some(content) = c0.get("content") {
if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) {
for p in parts {
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
if !content_text.is_empty() { content_text.push_str("\n"); }
content_text.push_str(t);
}
}
}
}
}
}
}
let prompt_tokens = resp_json.get("usage").and_then(|u| u.get("prompt_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
let completion_tokens = resp_json.get("usage").and_then(|u| u.get("completion_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
let total_tokens = resp_json.get("usage").and_then(|u| u.get("total_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
Ok(ProviderResponse {
content: content_text,
reasoning_content: None,
tool_calls: None,
prompt_tokens,
completion_tokens,
total_tokens,
cache_read_tokens: 0,
cache_write_tokens: 0,
model: request.model,
})
}
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
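
For clarity, here is the same extraction run against a hand-written sample body. The field layout mirrors what the parser above looks for (output[0].content[*].text plus a usage object); it is not an authoritative description of the Responses API:

fn main() {
    let resp_json = serde_json::json!({
        "output": [{
            "content": [
                { "type": "output_text", "text": "Hello" },
                { "type": "output_text", "text": "world" }
            ]
        }],
        "usage": { "prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5 }
    });

    // Mirrors the loop above: concatenate every content[*].text with newlines.
    let mut content_text = String::new();
    if let Some(contents) = resp_json["output"]
        .get(0)
        .and_then(|first| first.get("content"))
        .and_then(|c| c.as_array())
    {
        for item in contents {
            if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
                if !content_text.is_empty() {
                    content_text.push('\n');
                }
                content_text.push_str(text);
            }
        }
    }
    assert_eq!(content_text, "Hello\nworld");

    let prompt_tokens = resp_json["usage"]["prompt_tokens"].as_u64().unwrap_or(0) as u32;
    assert_eq!(prompt_tokens, 3);
}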
@@ -110,13 +299,43 @@ impl super::Provider for OpenAIProvider {
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
let body = helpers::build_openai_body(&request, messages_json, true);
// Try to create an EventSource for streaming; if creation fails or
// the stream errors, fall back to a single synchronous request and
// emit its result as a single chunk.
let es_result = reqwest_eventsource::EventSource::new(
self.client
.post(format!("{}/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body),
);
if let Err(e) = es_result {
// Fall back to the non-streaming request, which may itself retry against
// the Responses API if necessary (handled in chat_completion).
let resp = self.chat_completion(request.clone()).await?;
let single_stream = async_stream::try_stream! {
let chunk = ProviderStreamChunk {
content: resp.content,
reasoning_content: resp.reasoning_content,
finish_reason: Some("stop".to_string()),
tool_calls: None,
model: resp.model.clone(),
usage: Some(super::StreamUsage {
prompt_tokens: resp.prompt_tokens,
completion_tokens: resp.completion_tokens,
total_tokens: resp.total_tokens,
cache_read_tokens: resp.cache_read_tokens,
cache_write_tokens: resp.cache_write_tokens,
}),
};
yield chunk;
};
return Ok(Box::pin(single_stream));
}
let es = es_result.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
Ok(helpers::create_openai_stream(es, request.model, None))
}
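
Taken together, the OpenAI paths now form a layered fallback: SSE streaming first, then a synchronous chat/completions call, with the /responses retry nested inside that call. A conceptual sketch of the ordering with hypothetical helper names (in the real code the third step lives inside chat_completion rather than being a separate call):

async fn complete_with_fallbacks(prompt: &str) -> Result<String, String> {
    // 1. Preferred path: SSE streaming.
    if let Ok(streamed) = try_sse_stream(prompt).await {
        return Ok(streamed);
    }
    // 2. Streaming unavailable: synchronous chat/completions.
    if let Ok(full) = try_chat_completions(prompt).await {
        return Ok(full);
    }
    // 3. Model only supported on the Responses API: /responses.
    try_responses_endpoint(prompt).await
}

// Hypothetical stand-ins for the real provider calls.
async fn try_sse_stream(_p: &str) -> Result<String, String> {
    Err("streaming not supported".into())
}
async fn try_chat_completions(_p: &str) -> Result<String, String> {
    Err("model is only supported in v1/responses".into())
}
async fn try_responses_endpoint(p: &str) -> Result<String, String> {
    Ok(format!("answer to: {p}"))
}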


@@ -304,7 +304,17 @@ async fn chat_completions(
}
} else {
// Handle non-streaming response
// Allow provider-specific routing: for OpenAI, some models prefer the
// Responses API (/v1/responses). Use the model registry heuristic to
// choose chat_responses vs chat_completion automatically.
let use_responses = provider.name() == "openai"
&& crate::utils::registry::model_prefers_responses(&state.model_registry, &unified_request.model);
let result = if use_responses {
provider.chat_responses(unified_request).await
} else {
provider.chat_completion(unified_request).await
};
match result {
Ok(response) => {


@@ -22,3 +22,28 @@ pub async fn fetch_registry() -> Result<ModelRegistry> {
Ok(registry)
}
/// Heuristic: decide whether a model should be routed to the OpenAI Responses API
/// instead of the legacy chat/completions endpoint.
///
/// Currently this uses simple patterns (codex, gpt-5 series) and also checks
/// the loaded registry metadata name for the substring "codex" as a hint.
pub fn model_prefers_responses(registry: &ModelRegistry, model: &str) -> bool {
let model_lc = model.to_lowercase();
if model_lc.contains("codex") {
return true;
}
if model_lc.starts_with("gpt-5") || model_lc.contains("gpt-5.") {
return true;
}
if let Some(meta) = registry.find_model(model) {
if meta.name.to_lowercase().contains("codex") {
return true;
}
}
false
}
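
A quick usage sketch of the name-pattern half of the heuristic (the registry metadata lookup is omitted; model names are illustrative only):

fn prefers_responses_by_name(model: &str) -> bool {
    let model_lc = model.to_lowercase();
    model_lc.contains("codex")
        || model_lc.starts_with("gpt-5")
        || model_lc.contains("gpt-5.")
}

fn main() {
    assert!(prefers_responses_by_name("gpt-5.1-codex")); // codex substring
    assert!(prefers_responses_by_name("gpt-5-mini"));    // gpt-5 prefix
    assert!(!prefers_responses_by_name("gpt-4o"));       // stays on chat/completions
}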