343 lines
15 KiB
Rust
343 lines
15 KiB
Rust
use anyhow::Result;
|
|
use async_trait::async_trait;
|
|
use futures::stream::BoxStream;
|
|
|
|
use super::helpers;
|
|
use super::{ProviderResponse, ProviderStreamChunk};
|
|
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
|
|
|
|
pub struct OpenAIProvider {
|
|
client: reqwest::Client,
|
|
config: crate::config::OpenAIConfig,
|
|
api_key: String,
|
|
pricing: Vec<crate::config::ModelPricing>,
|
|
}
|
|
|
|
impl OpenAIProvider {
|
|
pub fn new(config: &crate::config::OpenAIConfig, app_config: &AppConfig) -> Result<Self> {
|
|
let api_key = app_config.get_api_key("openai")?;
|
|
Self::new_with_key(config, app_config, api_key)
|
|
}
|
|
|
|
pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
|
let client = reqwest::Client::builder()
|
|
.connect_timeout(std::time::Duration::from_secs(5))
|
|
.timeout(std::time::Duration::from_secs(300))
|
|
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
|
.pool_max_idle_per_host(4)
|
|
.tcp_keepalive(std::time::Duration::from_secs(30))
|
|
.build()?;
|
|
|
|
Ok(Self {
|
|
client,
|
|
config: config.clone(),
|
|
api_key,
|
|
pricing: app_config.pricing.openai.clone(),
|
|
})
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl super::Provider for OpenAIProvider {
|
|
fn name(&self) -> &str {
|
|
"openai"
|
|
}
|
|
|
|
fn supports_model(&self, model: &str) -> bool {
|
|
model.starts_with("gpt-") || model.starts_with("o1-") || model.starts_with("o3-") || model.starts_with("o4-")
|
|
}
|
|
|
|
fn supports_multimodal(&self) -> bool {
|
|
true
|
|
}
|
|
|
|
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
|
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
|
let body = helpers::build_openai_body(&request, messages_json, false);
|
|
|
|
let response = self
|
|
.client
|
|
.post(format!("{}/chat/completions", self.config.base_url))
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.json(&body)
|
|
.send()
|
|
.await
|
|
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
|
|
if !response.status().is_success() {
|
|
// Read error body to diagnose. If the model requires the Responses
|
|
// API (v1/responses), retry against that endpoint.
|
|
let error_text = response.text().await.unwrap_or_default();
|
|
if error_text.to_lowercase().contains("v1/responses") || error_text.to_lowercase().contains("only supported in v1/responses") {
|
|
// Build a simple `input` string by concatenating message parts.
|
|
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
|
let mut inputs: Vec<String> = Vec::new();
|
|
for m in &messages_json {
|
|
let role = m["role"].as_str().unwrap_or("");
|
|
let parts = m.get("content").and_then(|c| c.as_array()).cloned().unwrap_or_default();
|
|
let mut text_parts = Vec::new();
|
|
for p in parts {
|
|
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
|
|
text_parts.push(t.to_string());
|
|
}
|
|
}
|
|
inputs.push(format!("{}: {}", role, text_parts.join("")));
|
|
}
|
|
let input_text = inputs.join("\n");
|
|
|
|
let resp = self
|
|
.client
|
|
.post(format!("{}/responses", self.config.base_url))
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.json(&serde_json::json!({ "model": request.model, "input": input_text }))
|
|
.send()
|
|
.await
|
|
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
|
|
if !resp.status().is_success() {
|
|
let err = resp.text().await.unwrap_or_default();
|
|
return Err(AppError::ProviderError(format!("OpenAI Responses API error: {}", err)));
|
|
}
|
|
|
|
let resp_json: serde_json::Value = resp.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
// Try to normalize: if it's chat-style, use existing parser
|
|
if resp_json.get("choices").is_some() {
|
|
return helpers::parse_openai_response(&resp_json, request.model);
|
|
}
|
|
|
|
// Responses API: try to extract text from `output` or `candidates`
|
|
// output -> [{"content": [{"type":..., "text": "..."}, ...]}]
|
|
let mut content_text = String::new();
|
|
if let Some(output) = resp_json.get("output").and_then(|o| o.as_array()) {
|
|
if let Some(first) = output.get(0) {
|
|
if let Some(contents) = first.get("content").and_then(|c| c.as_array()) {
|
|
for item in contents {
|
|
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
|
|
if !content_text.is_empty() {
|
|
content_text.push_str("\n");
|
|
}
|
|
content_text.push_str(text);
|
|
} else if let Some(parts) = item.get("parts").and_then(|p| p.as_array()) {
|
|
for p in parts {
|
|
if let Some(t) = p.as_str() {
|
|
if !content_text.is_empty() { content_text.push_str("\n"); }
|
|
content_text.push_str(t);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Fallback: check `candidates` -> candidate.content.parts.text
|
|
if content_text.is_empty() {
|
|
if let Some(cands) = resp_json.get("candidates").and_then(|c| c.as_array()) {
|
|
if let Some(c0) = cands.get(0) {
|
|
if let Some(content) = c0.get("content") {
|
|
if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) {
|
|
for p in parts {
|
|
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
|
|
if !content_text.is_empty() { content_text.push_str("\n"); }
|
|
content_text.push_str(t);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Extract simple usage if present
|
|
let prompt_tokens = resp_json.get("usage").and_then(|u| u.get("prompt_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
|
|
let completion_tokens = resp_json.get("usage").and_then(|u| u.get("completion_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
|
|
let total_tokens = resp_json.get("usage").and_then(|u| u.get("total_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
|
|
|
|
return Ok(ProviderResponse {
|
|
content: content_text,
|
|
reasoning_content: None,
|
|
tool_calls: None,
|
|
prompt_tokens,
|
|
completion_tokens,
|
|
total_tokens,
|
|
cache_read_tokens: 0,
|
|
cache_write_tokens: 0,
|
|
model: request.model,
|
|
});
|
|
}
|
|
|
|
return Err(AppError::ProviderError(format!("OpenAI API error: {}", error_text)));
|
|
}
|
|
|
|
let resp_json: serde_json::Value = response
|
|
.json()
|
|
.await
|
|
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
|
|
helpers::parse_openai_response(&resp_json, request.model)
|
|
}
|
|
|
|
async fn chat_responses(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
|
// Build a simple `input` string by concatenating message parts.
|
|
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
|
let mut inputs: Vec<String> = Vec::new();
|
|
for m in &messages_json {
|
|
let role = m["role"].as_str().unwrap_or("");
|
|
let parts = m.get("content").and_then(|c| c.as_array()).cloned().unwrap_or_default();
|
|
let mut text_parts = Vec::new();
|
|
for p in parts {
|
|
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
|
|
text_parts.push(t.to_string());
|
|
}
|
|
}
|
|
inputs.push(format!("{}: {}", role, text_parts.join("")));
|
|
}
|
|
let input_text = inputs.join("\n");
|
|
|
|
let resp = self
|
|
.client
|
|
.post(format!("{}/responses", self.config.base_url))
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.json(&serde_json::json!({ "model": request.model, "input": input_text }))
|
|
.send()
|
|
.await
|
|
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
|
|
if !resp.status().is_success() {
|
|
let err = resp.text().await.unwrap_or_default();
|
|
return Err(AppError::ProviderError(format!("OpenAI Responses API error: {}", err)));
|
|
}
|
|
|
|
let resp_json: serde_json::Value = resp.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
|
|
// Normalize Responses API output into ProviderResponse
|
|
let mut content_text = String::new();
|
|
if let Some(output) = resp_json.get("output").and_then(|o| o.as_array()) {
|
|
if let Some(first) = output.get(0) {
|
|
if let Some(contents) = first.get("content").and_then(|c| c.as_array()) {
|
|
for item in contents {
|
|
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
|
|
if !content_text.is_empty() { content_text.push_str("\n"); }
|
|
content_text.push_str(text);
|
|
} else if let Some(parts) = item.get("parts").and_then(|p| p.as_array()) {
|
|
for p in parts {
|
|
if let Some(t) = p.as_str() {
|
|
if !content_text.is_empty() { content_text.push_str("\n"); }
|
|
content_text.push_str(t);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if content_text.is_empty() {
|
|
if let Some(cands) = resp_json.get("candidates").and_then(|c| c.as_array()) {
|
|
if let Some(c0) = cands.get(0) {
|
|
if let Some(content) = c0.get("content") {
|
|
if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) {
|
|
for p in parts {
|
|
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
|
|
if !content_text.is_empty() { content_text.push_str("\n"); }
|
|
content_text.push_str(t);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let prompt_tokens = resp_json.get("usage").and_then(|u| u.get("prompt_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
|
|
let completion_tokens = resp_json.get("usage").and_then(|u| u.get("completion_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
|
|
let total_tokens = resp_json.get("usage").and_then(|u| u.get("total_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
|
|
|
|
Ok(ProviderResponse {
|
|
content: content_text,
|
|
reasoning_content: None,
|
|
tool_calls: None,
|
|
prompt_tokens,
|
|
completion_tokens,
|
|
total_tokens,
|
|
cache_read_tokens: 0,
|
|
cache_write_tokens: 0,
|
|
model: request.model,
|
|
})
|
|
}
|
|
|
|
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
|
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
|
}
|
|
|
|
fn calculate_cost(
|
|
&self,
|
|
model: &str,
|
|
prompt_tokens: u32,
|
|
completion_tokens: u32,
|
|
cache_read_tokens: u32,
|
|
cache_write_tokens: u32,
|
|
registry: &crate::models::registry::ModelRegistry,
|
|
) -> f64 {
|
|
helpers::calculate_cost_with_registry(
|
|
model,
|
|
prompt_tokens,
|
|
completion_tokens,
|
|
cache_read_tokens,
|
|
cache_write_tokens,
|
|
registry,
|
|
&self.pricing,
|
|
0.15,
|
|
0.60,
|
|
)
|
|
}
|
|
|
|
async fn chat_completion_stream(
|
|
&self,
|
|
request: UnifiedRequest,
|
|
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
|
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
|
let body = helpers::build_openai_body(&request, messages_json, true);
|
|
|
|
// Try to create an EventSource for streaming; if creation fails or
|
|
// the stream errors, fall back to a single synchronous request and
|
|
// emit its result as a single chunk.
|
|
let es_result = reqwest_eventsource::EventSource::new(
|
|
self.client
|
|
.post(format!("{}/chat/completions", self.config.base_url))
|
|
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
.json(&body),
|
|
);
|
|
|
|
if es_result.is_err() {
|
|
// Fallback to non-streaming request which itself may retry to
|
|
// Responses API if necessary (handled in chat_completion).
|
|
let resp = self.chat_completion(request.clone()).await?;
|
|
let single_stream = async_stream::try_stream! {
|
|
let chunk = ProviderStreamChunk {
|
|
content: resp.content,
|
|
reasoning_content: resp.reasoning_content,
|
|
finish_reason: Some("stop".to_string()),
|
|
tool_calls: None,
|
|
model: resp.model.clone(),
|
|
usage: Some(super::StreamUsage {
|
|
prompt_tokens: resp.prompt_tokens,
|
|
completion_tokens: resp.completion_tokens,
|
|
total_tokens: resp.total_tokens,
|
|
cache_read_tokens: resp.cache_read_tokens,
|
|
cache_write_tokens: resp.cache_write_tokens,
|
|
}),
|
|
};
|
|
|
|
yield chunk;
|
|
};
|
|
|
|
return Ok(Box::pin(single_stream));
|
|
}
|
|
|
|
let es = es_result.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
|
|
|
Ok(helpers::create_openai_stream(es, request.model, None))
|
|
}
|
|
}
|