use async_trait::async_trait;
use anyhow::Result;
use async_openai::{Client, config::OpenAIConfig};
use async_openai::types::chat::{
    CreateChatCompletionRequestArgs, ChatCompletionRequestMessage,
    ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage,
    ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent,
    ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent,
};
use futures::stream::{BoxStream, StreamExt};

use crate::{
    models::UnifiedRequest,
    errors::AppError,
    config::AppConfig,
};
use super::{ProviderResponse, ProviderStreamChunk};

pub struct OpenAIProvider {
    client: Client<OpenAIConfig>,
    _config: crate::config::OpenAIConfig,
    // Element type name assumed; each entry must expose `model`,
    // `prompt_tokens_per_million`, and `completion_tokens_per_million`
    // (see `calculate_cost` below).
    pricing: Vec<crate::config::ModelPricing>,
}

impl OpenAIProvider {
    pub fn new(config: &crate::config::OpenAIConfig, app_config: &AppConfig) -> Result<Self> {
        let api_key = app_config.get_api_key("openai")?;

        // Create OpenAIConfig with api key and base url
        let openai_config = OpenAIConfig::default()
            .with_api_key(api_key)
            .with_api_base(&config.base_url);

        let client = Client::with_config(openai_config);

        Ok(Self {
            client,
            _config: config.clone(),
            pricing: app_config.pricing.openai.clone(),
        })
    }
}

#[async_trait]
impl super::Provider for OpenAIProvider {
    fn name(&self) -> &str {
        "openai"
    }

    fn supports_model(&self, model: &str) -> bool {
        model.starts_with("gpt-") || model.starts_with("o1-") || model.starts_with("o3-")
    }

    fn supports_multimodal(&self) -> bool {
        true // OpenAI supports vision models
    }

    async fn chat_completion(
        &self,
        request: UnifiedRequest,
    ) -> Result<ProviderResponse, AppError> {
        use async_openai::types::chat::{
            ChatCompletionRequestUserMessageContentPart,
            ChatCompletionRequestMessageContentPartText,
            ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail,
        };

        // Convert UnifiedRequest messages to OpenAI messages
        let mut messages = Vec::with_capacity(request.messages.len());
        for msg in request.messages {
            let mut parts = Vec::with_capacity(msg.content.len());
            for part in msg.content {
                match part {
                    crate::models::ContentPart::Text { text } => {
                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(
                            ChatCompletionRequestMessageContentPartText { text },
                        ));
                    }
                    crate::models::ContentPart::Image(image_input) => {
                        // Inline the image as a base64 data URL.
                        let (base64_data, mime_type) = image_input
                            .to_base64()
                            .await
                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
                        let data_url = format!("data:{};base64,{}", mime_type, base64_data);
                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(
                            ChatCompletionRequestMessageContentPartImage {
                                image_url: ImageUrl {
                                    url: data_url,
                                    detail: Some(ImageDetail::Auto),
                                },
                            },
                        ));
                    }
                }
            }

            // Non-user roles only accept text, so image parts are dropped
            // and the remaining text parts are joined with newlines.
            let message = match msg.role.as_str() {
                "system" => ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
                    content: ChatCompletionRequestSystemMessageContent::Text(
                        parts
                            .iter()
                            .filter_map(|p| {
                                if let ChatCompletionRequestUserMessageContentPart::Text(t) = p {
                                    Some(t.text.clone())
                                } else {
                                    None
                                }
                            })
                            .collect::<Vec<_>>()
                            .join("\n"),
                    ),
                    name: None,
                }),
                "assistant" => ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
                    content: Some(ChatCompletionRequestAssistantMessageContent::Text(
                        parts
                            .iter()
                            .filter_map(|p| {
                                if let ChatCompletionRequestUserMessageContentPart::Text(t) = p {
                                    Some(t.text.clone())
                                } else {
                                    None
                                }
                            })
                            .collect::<Vec<_>>()
                            .join("\n"),
                    )),
                    name: None,
                    tool_calls: None,
                    refusal: None,
                    audio: None,
                    #[allow(deprecated)]
                    function_call: None,
                }),
                _ => ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
                    content: ChatCompletionRequestUserMessageContent::Array(parts),
                    name: None,
                }),
            };
            messages.push(message);
        }

        if messages.is_empty() {
            return Err(AppError::ProviderError("No valid text messages to send".to_string()));
        }

        // Build request using builder pattern
        let mut builder = CreateChatCompletionRequestArgs::default();
        builder.model(request.model.clone());
        builder.messages(messages);

        // Add optional parameters
        if let Some(temp) = request.temperature {
            builder.temperature(temp as f32);
        }
        if let Some(max_tokens) = request.max_tokens {
            builder.max_tokens(max_tokens as u16);
        }

        // Execute API call
        let response = self
            .client
            .chat()
            .create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
            .await
            .map_err(|e| AppError::ProviderError(e.to_string()))?;

        // Extract content from response
        let content = response
            .choices
            .first()
            .and_then(|choice| choice.message.content.clone())
            .unwrap_or_default();

        // Extract token usage
        let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
        let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
        let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;

        Ok(ProviderResponse {
            content,
            prompt_tokens,
            completion_tokens,
            total_tokens,
            model: request.model,
        })
    }

    // Token-count type assumed to be u32, matching `calculate_cost`'s parameters.
    fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32, AppError> {
        Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
    }

    fn calculate_cost(
        &self,
        model: &str,
        prompt_tokens: u32,
        completion_tokens: u32,
        registry: &crate::models::registry::ModelRegistry,
    ) -> f64 {
        // Prefer per-model pricing from the registry when available.
        if let Some(metadata) = registry.find_model(model) {
            if let Some(cost) = &metadata.cost {
                return (prompt_tokens as f64 * cost.input / 1_000_000.0)
                    + (completion_tokens as f64 * cost.output / 1_000_000.0);
            }
        }

        // Fallback to static pricing if not in registry
        let (prompt_rate, completion_rate) = self
            .pricing
            .iter()
            .find(|p| model.contains(&p.model))
            .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
            .unwrap_or((0.15, 0.60));

        (prompt_tokens as f64 * prompt_rate / 1_000_000.0)
            + (completion_tokens as f64 * completion_rate / 1_000_000.0)
    }

    async fn chat_completion_stream(
        &self,
        request: UnifiedRequest,
    ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
        use async_openai::types::chat::{
            ChatCompletionRequestUserMessageContentPart,
            ChatCompletionRequestMessageContentPartText,
            ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail,
        };

        // Convert UnifiedRequest messages to OpenAI messages
        // (mirrors the conversion in `chat_completion` above).
        let mut messages = Vec::with_capacity(request.messages.len());
        for msg in request.messages {
            let mut parts = Vec::with_capacity(msg.content.len());
            for part in msg.content {
                match part {
                    crate::models::ContentPart::Text { text } => {
                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(
                            ChatCompletionRequestMessageContentPartText { text },
                        ));
                    }
                    crate::models::ContentPart::Image(image_input) => {
                        let (base64_data, mime_type) = image_input
                            .to_base64()
                            .await
                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
                        let data_url = format!("data:{};base64,{}", mime_type, base64_data);
                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(
                            ChatCompletionRequestMessageContentPartImage {
                                image_url: ImageUrl {
                                    url: data_url,
                                    detail: Some(ImageDetail::Auto),
                                },
                            },
                        ));
                    }
                }
            }

            let message = match msg.role.as_str() {
                "system" => ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
                    content: ChatCompletionRequestSystemMessageContent::Text(
                        parts
                            .iter()
                            .filter_map(|p| {
                                if let ChatCompletionRequestUserMessageContentPart::Text(t) = p {
                                    Some(t.text.clone())
                                } else {
                                    None
                                }
                            })
                            .collect::<Vec<_>>()
                            .join("\n"),
                    ),
                    name: None,
                }),
                "assistant" => ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
                    content: Some(ChatCompletionRequestAssistantMessageContent::Text(
                        parts
                            .iter()
                            .filter_map(|p| {
                                if let ChatCompletionRequestUserMessageContentPart::Text(t) = p {
                                    Some(t.text.clone())
                                } else {
                                    None
                                }
                            })
                            .collect::<Vec<_>>()
                            .join("\n"),
                    )),
                    name: None,
                    tool_calls: None,
                    refusal: None,
                    audio: None,
                    #[allow(deprecated)]
                    function_call: None,
                }),
                _ => ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
                    content: ChatCompletionRequestUserMessageContent::Array(parts),
                    name: None,
                }),
            };

            messages.push(message);
        }

        if messages.is_empty() {
            return Err(AppError::ProviderError("No valid text messages to send".to_string()));
        }

        // Build request using builder pattern
        let mut builder = CreateChatCompletionRequestArgs::default();
        builder.model(request.model.clone());
        builder.messages(messages);
        builder.stream(true); // Enable streaming

        // Add optional parameters
        if let Some(temp) = request.temperature {
            builder.temperature(temp as f32);
        }
        if let Some(max_tokens) = request.max_tokens {
            builder.max_tokens(max_tokens as u16);
        }

        // Execute streaming API call
        let stream = self
            .client
            .chat()
            .create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
            .await
            .map_err(|e| AppError::ProviderError(e.to_string()))?;

        // Convert OpenAI stream to our stream format
        let model = request.model.clone();
        let stream = stream.map(move |chunk_result| {
            match chunk_result {
                Ok(chunk) => {
                    // Extract content from chunk
                    let content = chunk
                        .choices
                        .first()
                        .and_then(|choice| choice.delta.content.clone())
                        .unwrap_or_default();
                    let finish_reason = chunk
                        .choices
                        .first()
                        .and_then(|choice| choice.finish_reason.clone())
                        .map(|reason| format!("{:?}", reason));

                    Ok(ProviderStreamChunk {
                        content,
                        finish_reason,
                        model: model.clone(),
                    })
                }
                Err(e) => Err(AppError::ProviderError(e.to_string())),
            }
        });

        Ok(Box::pin(stream))
    }
}
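
/// Usage sketch (illustrative only): drives the streaming path above and
/// prints each delta as it arrives. A minimal sketch assuming the caller
/// already holds a constructed `OpenAIProvider` and a `UnifiedRequest`; the
/// function name `print_stream` is hypothetical and not part of the provider
/// API.
#[allow(dead_code)]
async fn print_stream(provider: &OpenAIProvider, request: UnifiedRequest) -> Result<(), AppError> {
    use super::Provider; // bring the trait method into scope

    let mut stream = provider.chat_completion_stream(request).await?;
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;                // propagate provider errors
        print!("{}", chunk.content);       // incremental delta text
        if chunk.finish_reason.is_some() { // set once the model stops
            break;
        }
    }
    Ok(())
}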