diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs
index 2b23dafb..a7fcb4d2 100644
--- a/src/providers/gemini.rs
+++ b/src/providers/gemini.rs
@@ -555,13 +555,46 @@ impl super::Provider for GeminiProvider {
         use futures::StreamExt;
         use reqwest_eventsource::{Event, EventSource};
 
-        let es = EventSource::new(
+        // Try to create an SSE event source for streaming. If creation fails
+        // (the provider doesn't support streaming for this model or returned a
+        // non-2xx response), fall back to a synchronous generateContent call
+        // and emit a single chunk.
+        let es_result = reqwest_eventsource::EventSource::new(
             self.client
                 .post(&url)
                 .header("x-goog-api-key", &self.api_key)
                 .json(&gemini_request),
-        )
-        .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
+        );
+
+        if let Err(e) = &es_result {
+            // Fallback: call the non-streaming path and convert it to a single stream chunk
+            let resp = self.chat_completion(request.clone()).await.map_err(|e2| {
+                AppError::ProviderError(format!("Failed to create EventSource: {}; fallback error: {}", e, e2))
+            })?;
+
+            let single_stream = async_stream::try_stream! {
+                let chunk = ProviderStreamChunk {
+                    content: resp.content,
+                    reasoning_content: resp.reasoning_content,
+                    finish_reason: Some("stop".to_string()),
+                    tool_calls: None,
+                    model: resp.model.clone(),
+                    usage: Some(super::StreamUsage {
+                        prompt_tokens: resp.prompt_tokens,
+                        completion_tokens: resp.completion_tokens,
+                        total_tokens: resp.total_tokens,
+                        cache_read_tokens: resp.cache_read_tokens,
+                        cache_write_tokens: resp.cache_write_tokens,
+                    }),
+                };
+
+                yield chunk;
+            };
+
+            return Ok(Box::pin(single_stream));
+        }
+
+        let es = es_result.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
 
         let stream = async_stream::try_stream! {
             let mut es = es;
@@ -638,7 +671,33 @@ impl super::Provider for GeminiProvider {
                 }
                 Ok(_) => continue,
                 Err(e) => {
-                    Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
+                    // On streaming errors, attempt a synchronous fallback once.
+                    // This handles cases where the provider rejects the SSE
+                    // request but supports a non-streaming generateContent call.
+                    match self.chat_completion(request.clone()).await {
+                        Ok(resp) => {
+                            let chunk = ProviderStreamChunk {
+                                content: resp.content,
+                                reasoning_content: resp.reasoning_content,
+                                finish_reason: Some("stop".to_string()),
+                                tool_calls: resp.tool_calls.map(|d| d.into_iter().map(|tc| tc.into()).collect()),
+                                model: resp.model.clone(),
+                                usage: Some(super::StreamUsage {
+                                    prompt_tokens: resp.prompt_tokens,
+                                    completion_tokens: resp.completion_tokens,
+                                    total_tokens: resp.total_tokens,
+                                    cache_read_tokens: resp.cache_read_tokens,
+                                    cache_write_tokens: resp.cache_write_tokens,
+                                }),
+                            };
+
+                            yield chunk;
+                            break;
+                        }
+                        Err(err2) => {
+                            Err(AppError::ProviderError(format!("Stream error: {}; fallback error: {}", e, err2)))?;
+                        }
+                    }
                 }
             }
         }
diff --git a/src/providers/mod.rs b/src/providers/mod.rs
index 078a01c8..2a2e138d 100644
--- a/src/providers/mod.rs
+++ b/src/providers/mod.rs
@@ -28,6 +28,13 @@ pub trait Provider: Send + Sync {
     /// Process a chat completion request
     async fn chat_completion(&self, request: UnifiedRequest) -> Result;
 
+    /// Process a chat request using a provider-specific "responses"-style endpoint.
+    /// The default implementation falls back to `chat_completion` for providers
+    /// that do not implement a dedicated responses endpoint.
+    async fn chat_responses(&self, request: UnifiedRequest) -> Result {
+        self.chat_completion(request).await
+    }
+
     /// Process a streaming chat completion request
     async fn chat_completion_stream(
         &self,
diff --git a/src/providers/openai.rs b/src/providers/openai.rs
index 15001f0a..19c34f24 100644
--- a/src/providers/openai.rs
+++ b/src/providers/openai.rs
@@ -65,7 +65,107 @@ impl super::Provider for OpenAIProvider {
             .map_err(|e| AppError::ProviderError(e.to_string()))?;
 
         if !response.status().is_success() {
+            // Read the error body to diagnose. If the model requires the Responses
+            // API (v1/responses), retry against that endpoint.
             let error_text = response.text().await.unwrap_or_default();
+
+            if error_text.to_lowercase().contains("v1/responses") || error_text.to_lowercase().contains("only supported in v1/responses") {
+                // Build a simple `input` string by concatenating message parts.
+                let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
+                let mut inputs: Vec<String> = Vec::new();
+                for m in &messages_json {
+                    let role = m["role"].as_str().unwrap_or("");
+                    let parts = m.get("content").and_then(|c| c.as_array()).cloned().unwrap_or_default();
+                    let mut text_parts = Vec::new();
+                    for p in parts {
+                        if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
+                            text_parts.push(t.to_string());
+                        }
+                    }
+                    inputs.push(format!("{}: {}", role, text_parts.join("")));
+                }
+                let input_text = inputs.join("\n");
+
+                let resp = self
+                    .client
+                    .post(format!("{}/responses", self.config.base_url))
+                    .header("Authorization", format!("Bearer {}", self.api_key))
+                    .json(&serde_json::json!({ "model": request.model, "input": input_text }))
+                    .send()
+                    .await
+                    .map_err(|e| AppError::ProviderError(e.to_string()))?;
+
+                if !resp.status().is_success() {
+                    let err = resp.text().await.unwrap_or_default();
+                    return Err(AppError::ProviderError(format!("OpenAI Responses API error: {}", err)));
+                }
+
+                let resp_json: serde_json::Value = resp.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
+                // Try to normalize: if it's chat-style, use the existing parser
+                if resp_json.get("choices").is_some() {
+                    return helpers::parse_openai_response(&resp_json, request.model);
+                }
+
+                // Responses API: try to extract text from `output` or `candidates`
+                // output -> [{"content": [{"type":..., "text": "..."}, ...]}]
+                let mut content_text = String::new();
+                if let Some(output) = resp_json.get("output").and_then(|o| o.as_array()) {
+                    if let Some(first) = output.get(0) {
+                        if let Some(contents) = first.get("content").and_then(|c| c.as_array()) {
+                            for item in contents {
+                                if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
+                                    if !content_text.is_empty() {
+                                        content_text.push_str("\n");
+                                    }
+                                    content_text.push_str(text);
+                                } else if let Some(parts) = item.get("parts").and_then(|p| p.as_array()) {
+                                    for p in parts {
+                                        if let Some(t) = p.as_str() {
+                                            if !content_text.is_empty() { content_text.push_str("\n"); }
+                                            content_text.push_str(t);
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+
+                // Fallback: check `candidates` -> candidate.content.parts.text
+                if content_text.is_empty() {
+                    if let Some(cands) = resp_json.get("candidates").and_then(|c| c.as_array()) {
+                        if let Some(c0) = cands.get(0) {
+                            if let Some(content) = c0.get("content") {
+                                if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) {
+                                    for p in parts {
+                                        if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
+                                            if !content_text.is_empty() { content_text.push_str("\n"); }
+                                            content_text.push_str(t);
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+
+                // Extract simple usage if present
+                let prompt_tokens = resp_json.get("usage").and_then(|u| u.get("prompt_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
+                let completion_tokens = resp_json.get("usage").and_then(|u| u.get("completion_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
+                let total_tokens = resp_json.get("usage").and_then(|u| u.get("total_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
+
+                return Ok(ProviderResponse {
+                    content: content_text,
+                    reasoning_content: None,
+                    tool_calls: None,
+                    prompt_tokens,
+                    completion_tokens,
+                    total_tokens,
+                    cache_read_tokens: 0,
+                    cache_write_tokens: 0,
+                    model: request.model,
+                });
+            }
+
             return Err(AppError::ProviderError(format!("OpenAI API error: {}", error_text)));
         }
 
@@ -77,6 +177,95 @@ impl super::Provider for OpenAIProvider {
         helpers::parse_openai_response(&resp_json, request.model)
     }
 
+    async fn chat_responses(&self, request: UnifiedRequest) -> Result {
+        // Build a simple `input` string by concatenating message parts.
+        let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
+        let mut inputs: Vec<String> = Vec::new();
+        for m in &messages_json {
+            let role = m["role"].as_str().unwrap_or("");
+            let parts = m.get("content").and_then(|c| c.as_array()).cloned().unwrap_or_default();
+            let mut text_parts = Vec::new();
+            for p in parts {
+                if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
+                    text_parts.push(t.to_string());
+                }
+            }
+            inputs.push(format!("{}: {}", role, text_parts.join("")));
+        }
+        let input_text = inputs.join("\n");
+
+        let resp = self
+            .client
+            .post(format!("{}/responses", self.config.base_url))
+            .header("Authorization", format!("Bearer {}", self.api_key))
+            .json(&serde_json::json!({ "model": request.model, "input": input_text }))
+            .send()
+            .await
+            .map_err(|e| AppError::ProviderError(e.to_string()))?;
+
+        if !resp.status().is_success() {
+            let err = resp.text().await.unwrap_or_default();
+            return Err(AppError::ProviderError(format!("OpenAI Responses API error: {}", err)));
+        }
+
+        let resp_json: serde_json::Value = resp.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
+
+        // Normalize Responses API output into ProviderResponse
+        let mut content_text = String::new();
+        if let Some(output) = resp_json.get("output").and_then(|o| o.as_array()) {
+            if let Some(first) = output.get(0) {
+                if let Some(contents) = first.get("content").and_then(|c| c.as_array()) {
+                    for item in contents {
+                        if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
+                            if !content_text.is_empty() { content_text.push_str("\n"); }
+                            content_text.push_str(text);
+                        } else if let Some(parts) = item.get("parts").and_then(|p| p.as_array()) {
+                            for p in parts {
+                                if let Some(t) = p.as_str() {
+                                    if !content_text.is_empty() { content_text.push_str("\n"); }
+                                    content_text.push_str(t);
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        if content_text.is_empty() {
+            if let Some(cands) = resp_json.get("candidates").and_then(|c| c.as_array()) {
+                if let Some(c0) = cands.get(0) {
+                    if let Some(content) = c0.get("content") {
+                        if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) {
+                            for p in parts {
+                                if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
+                                    if !content_text.is_empty() { content_text.push_str("\n"); }
+                                    content_text.push_str(t);
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        let prompt_tokens = resp_json.get("usage").and_then(|u| u.get("prompt_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
+        let completion_tokens = resp_json.get("usage").and_then(|u| u.get("completion_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
+        let total_tokens = resp_json.get("usage").and_then(|u| u.get("total_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
+
+        Ok(ProviderResponse {
+            content: content_text,
+            reasoning_content: None,
+            tool_calls: None,
+            prompt_tokens,
+            completion_tokens,
+            total_tokens,
+            cache_read_tokens: 0,
+            cache_write_tokens: 0,
+            model: request.model,
+        })
+    }
+
     fn estimate_tokens(&self, request: &UnifiedRequest) -> Result {
         Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
     }
@@ -110,13 +299,43 @@ impl super::Provider for OpenAIProvider {
         let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
         let body = helpers::build_openai_body(&request, messages_json, true);
 
-        let es = reqwest_eventsource::EventSource::new(
+        // Try to create an EventSource for streaming; if creation fails or
+        // the stream errors, fall back to a single synchronous request and
+        // emit its result as a single chunk.
+        let es_result = reqwest_eventsource::EventSource::new(
             self.client
                 .post(format!("{}/chat/completions", self.config.base_url))
                 .header("Authorization", format!("Bearer {}", self.api_key))
                 .json(&body),
-        )
-        .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
+        );
+
+        if es_result.is_err() {
+            // Fallback to the non-streaming request, which may itself retry
+            // against the Responses API if necessary (handled in chat_completion).
+            let resp = self.chat_completion(request.clone()).await?;
+            let single_stream = async_stream::try_stream! {
+                let chunk = ProviderStreamChunk {
+                    content: resp.content,
+                    reasoning_content: resp.reasoning_content,
+                    finish_reason: Some("stop".to_string()),
+                    tool_calls: None,
+                    model: resp.model.clone(),
+                    usage: Some(super::StreamUsage {
+                        prompt_tokens: resp.prompt_tokens,
+                        completion_tokens: resp.completion_tokens,
+                        total_tokens: resp.total_tokens,
+                        cache_read_tokens: resp.cache_read_tokens,
+                        cache_write_tokens: resp.cache_write_tokens,
+                    }),
+                };
+
+                yield chunk;
+            };
+
+            return Ok(Box::pin(single_stream));
+        }
+
+        let es = es_result.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
 
         Ok(helpers::create_openai_stream(es, request.model, None))
     }
diff --git a/src/server/mod.rs b/src/server/mod.rs
index 5a33f848..9c7d7ab5 100644
--- a/src/server/mod.rs
+++ b/src/server/mod.rs
@@ -304,7 +304,17 @@ async fn chat_completions(
         }
     } else {
         // Handle non-streaming response
-        let result = provider.chat_completion(unified_request).await;
+        // Allow provider-specific routing: for OpenAI, some models prefer the
+        // Responses API (/v1/responses). Use the model registry heuristic to
+        // choose chat_responses vs chat_completion automatically.
+        let use_responses = provider.name() == "openai"
+            && crate::utils::registry::model_prefers_responses(&state.model_registry, &unified_request.model);
+
+        let result = if use_responses {
+            provider.chat_responses(unified_request).await
+        } else {
+            provider.chat_completion(unified_request).await
+        };
 
         match result {
             Ok(response) => {
diff --git a/src/utils/registry.rs b/src/utils/registry.rs
index ca7f8d52..1bf9e1cf 100644
--- a/src/utils/registry.rs
+++ b/src/utils/registry.rs
@@ -22,3 +22,28 @@ pub async fn fetch_registry() -> Result {
 
     Ok(registry)
 }
+
+/// Heuristic: decide whether a model should be routed to the OpenAI Responses API
+/// instead of the legacy chat/completions endpoint.
+///
+/// Currently this uses simple patterns (codex, gpt-5 series) and also checks
+/// the loaded registry metadata name for the substring "codex" as a hint.
+pub fn model_prefers_responses(registry: &ModelRegistry, model: &str) -> bool {
+    let model_lc = model.to_lowercase();
+
+    if model_lc.contains("codex") {
+        return true;
+    }
+
+    if model_lc.starts_with("gpt-5") || model_lc.contains("gpt-5.") {
+        return true;
+    }
+
+    if let Some(meta) = registry.find_model(model) {
+        if meta.name.to_lowercase().contains("codex") {
+            return true;
+        }
+    }
+
+    false
+}
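
Below is a self-contained sketch of the Responses-payload text extraction that the two openai.rs code paths above duplicate. The field layout (output[0].content[*].text, falling back to candidates[0].content.parts[*].text) is taken from the patch itself; the helper name, the sample payload, and the assertion are hypothetical and not part of the crate.

// Sketch only: a standalone restatement of the extraction logic in the patch.
// Assumes only serde_json; field names mirror the patch's assumptions.
use serde_json::{json, Value};

fn extract_responses_text(resp_json: &Value) -> String {
    let mut content_text = String::new();

    // Responses-style payload: output[0].content[*].text
    if let Some(contents) = resp_json
        .get("output")
        .and_then(|o| o.as_array())
        .and_then(|o| o.first())
        .and_then(|first| first.get("content"))
        .and_then(|c| c.as_array())
    {
        for item in contents {
            if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
                if !content_text.is_empty() {
                    content_text.push('\n');
                }
                content_text.push_str(text);
            }
        }
    }

    // Fallback: candidates[0].content.parts[*].text
    if content_text.is_empty() {
        if let Some(parts) = resp_json
            .get("candidates")
            .and_then(|c| c.as_array())
            .and_then(|c| c.first())
            .and_then(|c0| c0.get("content"))
            .and_then(|content| content.get("parts"))
            .and_then(|p| p.as_array())
        {
            for p in parts {
                if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
                    if !content_text.is_empty() {
                        content_text.push('\n');
                    }
                    content_text.push_str(t);
                }
            }
        }
    }

    content_text
}

fn main() {
    let sample = json!({
        "output": [{
            "content": [
                { "type": "output_text", "text": "Hello" },
                { "type": "output_text", "text": "world" }
            ]
        }]
    });
    assert_eq!(extract_responses_text(&sample), "Hello\nworld");
}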