use serde::{Deserialize, Serialize};
use serde_json::Value;

pub mod registry;

// NOTE(review): generic type arguments appear to be missing throughout this
// text (e.g. `Vec,`, `Option>,`, `impl TryFrom for ...`, `-> Result {`).
// They look stripped by an extraction step rather than absent by design;
// restore them from the authoritative copy of the file before building.
// The comments below describe only what the remaining tokens establish.

// ========== OpenAI-compatible Request/Response Structs ==========

/// Body of an OpenAI-compatible chat completion request.
///
/// Only `model` and `messages` lack `#[serde(default)]`; every other field
/// falls back to its default when the key is absent from the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
    /// Model identifier requested by the client.
    pub model: String,
    /// Conversation messages, in the order supplied.
    pub messages: Vec,
    #[serde(default)]
    pub temperature: Option,
    #[serde(default)]
    pub top_p: Option,
    #[serde(default)]
    pub top_k: Option,
    #[serde(default)]
    pub n: Option,
    /// Stop sequence(s). The conversion into `UnifiedRequest` matches this
    /// against `Value::String` and `Value::Array`.
    #[serde(default)]
    pub stop: Option, // Can be string or array of strings
    #[serde(default)]
    pub max_tokens: Option,
    #[serde(default)]
    pub presence_penalty: Option,
    #[serde(default)]
    pub frequency_penalty: Option,
    /// Streaming flag; the conversion into `UnifiedRequest` treats a missing
    /// value as `false`.
    #[serde(default)]
    pub stream: Option,
    /// Tool definitions; left out of serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tools: Option>,
    /// Tool selection; left out of serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option,
}

/// One chat message in the OpenAI-compatible wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    pub role: String, // "system", "user", "assistant", "tool"
    /// Flattened into the message object, so the `content` key carried by the
    /// `MessageContent` variants sits next to `role` rather than in a nested
    /// object.
    #[serde(flatten)]
    pub content: MessageContent,
    /// Also accepted under the keys `reasoning` and `thought` when
    /// deserializing; left out of serialized output when `None`.
    #[serde(alias = "reasoning", alias = "thought", skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option,
}

/// Message body: a plain string, a list of parts, or nothing.
///
/// The enum is `untagged`, so serde tries the variants in declaration order
/// and keeps the first one that matches.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
    /// `content` is a single string.
    Text { content: String },
    /// `content` is a list of parts.
    Parts { content: Vec },
    // NOTE(review): confirm with a round-trip test that this unit variant is
    // actually selected for `"content": null` / missing content when the enum
    // is both `untagged` and flattened into `ChatMessage`.
    None, // Handle cases where content might be null but reasoning is present
}

/// One element of a multi-part message body.
///
/// Internally tagged: the `type` key selects the variant, with snake_case
/// names (`text`, `image_url`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPartValue {
    Text { text: String },
    ImageUrl { image_url: ImageUrl },
}

/// Image reference used by `image_url` content parts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrl {
    pub url: String,
    /// Left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option,
}

// ========== Tool-Calling Types ==========

/// Tool definition offered to the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Serialized under the key `type`.
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: FunctionDef,
}

/// Description of a callable function; optional fields are left out of
/// serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDef {
    pub name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option,
}

/// Tool selection: either a bare mode string or a specific function.
///
/// `untagged`, so a JSON string matches `Mode` and an object is tried against
/// `Specific`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    Mode(String), // "auto", "none", "required"
    Specific(ToolChoiceSpecific),
}

/// Object form of `ToolChoice` naming one function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceSpecific {
    /// Serialized under the key `type`.
    #[serde(rename = "type")]
    pub choice_type: String,
    pub function: ToolChoiceFunction,
}

/// Function name referenced by `ToolChoiceSpecific`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceFunction {
    pub name: String,
}

/// A complete tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    pub id: String,
    /// Serialized under the key `type`.
    #[serde(rename = "type")]
    pub call_type: String,
    pub function: FunctionCall,
}

/// Function name plus its arguments, held as a string
/// (presumably JSON-encoded text; confirm against the callers).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
    pub name: String,
    pub arguments: String,
}

/// Partial tool call identified by `index`; every other field is optional and
/// left out of serialized output when `None`.
/// (Presumably the incremental form used in streaming deltas; confirm.)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallDelta {
    pub index: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option,
    /// Serialized under the key `type`.
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub call_type: Option,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option,
}

/// Partial function call; both fields are optional and left out of serialized
/// output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option,
}

// ========== OpenAI-compatible Response Structs ==========

/// Non-streaming chat completion response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec,
    /// No `skip_serializing_if` here, so unlike the streaming response this
    /// field is still written out when it is `None`.
    pub usage: Option,
}

/// One choice of a non-streaming response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatChoice {
    pub index: u32,
    pub message: ChatMessage,
    pub finish_reason: Option,
}

/// Token accounting; the three optional counters are left out of serialized
/// output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_tokens: Option,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_tokens: Option,
}

// ========== Streaming Response Structs ==========

/// One chunk of a streaming chat completion response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec,
    /// Left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option,
}

/// One choice inside a streaming chunk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatStreamChoice {
    pub index: u32,
    pub delta: ChatStreamDelta,
    pub finish_reason: Option,
}

/// Payload of a streaming choice.
///
/// `role` and `content` carry no `skip_serializing_if`, so they are written
/// out even when `None`; the remaining two fields are omitted when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatStreamDelta {
    pub role: Option,
    pub content: Option,
    /// Also accepted under the keys `reasoning` and `thought` when
    /// deserializing.
    #[serde(alias = "reasoning", alias = "thought", skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option>,
}

// ========== Unified Request Format (for internal use) ==========

/// Internal request representation.
///
/// Derives only `Debug` and `Clone` (no serde); it is produced from a
/// `ChatCompletionRequest` by the `TryFrom` impl at the end of this file.
#[derive(Debug, Clone)]
pub struct UnifiedRequest {
    /// Set to an empty string by the `TryFrom` conversion.
    pub client_id: String,
    pub model: String,
    pub messages: Vec,
    pub temperature: Option,
    pub top_p: Option,
    pub top_k: Option,
    pub n: Option,
    /// Stop sequences normalized to a list of strings by the conversion.
    pub stop: Option>,
    pub max_tokens: Option,
    pub presence_penalty: Option,
    pub frequency_penalty: Option,
    pub stream: bool,
    /// Set by the conversion when any message part is an image; consulted by
    /// `hydrate_images` to skip the traversal entirely.
    pub has_images: bool,
    pub tools: Option>,
    pub tool_choice: Option,
}

/// Internal message representation; `content` is always a list of parts
/// (empty when the source message had no content).
#[derive(Debug, Clone)]
pub struct UnifiedMessage {
    pub role: String,
    pub content: Vec,
    pub reasoning_content: Option,
    pub tool_calls: Option>,
    pub name: Option,
    pub tool_call_id: Option,
}

/// Internal content part: text, or an image described by
/// `crate::multimodal::ImageInput` (defined outside this file).
#[derive(Debug, Clone)]
pub enum ContentPart {
    Text { text: String },
    Image(crate::multimodal::ImageInput),
}

// ========== Provider-specific Structs ==========

/// Serialize-only request shape for the OpenAI provider.
#[derive(Debug, Clone, Serialize)]
pub struct OpenAIRequest {
    pub model: String,
    pub messages: Vec,
    pub temperature: Option,
    pub max_tokens: Option,
    pub stream: Option,
}

/// Serialize-only message whose `content` is always a list.
#[derive(Debug, Clone, Serialize)]
pub struct OpenAIMessage {
    pub role: String,
    pub content: Vec,
}

/// Serialize-only content part, tagged by `type` with snake_case names
/// (`text`, `image_url`).
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum OpenAIContentPart {
    Text { text: String },
    ImageUrl { image_url: ImageUrl },
}

// Note: ImageUrl struct is defined earlier in the file

// ========== Conversion Traits ==========

/// Fallible conversion of `self` into the OpenAI provider format.
// NOTE(review): the `Result` type arguments are missing in this text.
pub trait ToOpenAI {
    fn to_openai(&self) -> Result;
}

/// Fallible construction of `Self` from a borrowed `OpenAIRequest`.
// NOTE(review): the `Result` type arguments are missing in this text.
pub trait FromOpenAI {
    fn from_openai(request: &OpenAIRequest) -> Result
    where
        Self: Sized;
}

impl UnifiedRequest {
    /// Hydrate all image content by fetching URLs and converting to base64/bytes.
    ///
    /// Returns immediately when `has_images` is `false`. Otherwise walks every
    /// content part of every message and, for each `ContentPart::Image` that
    /// currently holds `ImageInput::Url`, awaits `to_base64()` and replaces the
    /// entry in place with `ImageInput::Base64 { data, mime_type }`. Image
    /// inputs in any other variant are left untouched.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `ImageInput::to_base64`; parts
    /// converted before the failure stay converted.
    pub async fn hydrate_images(&mut self) -> anyhow::Result<()> {
        if !self.has_images {
            return Ok(());
        }

        for msg in &mut self.messages {
            for part in &mut msg.content {
                if let ContentPart::Image(image_input) = part {
                    // Pre-fetch and validate if it's a URL
                    // (`_url` is unused: the pattern only tests the variant;
                    // the conversion itself is delegated to `to_base64`).
                    if let crate::multimodal::ImageInput::Url(_url) = image_input {
                        let (base64_data, mime_type) = image_input.to_base64().await?;
                        *image_input = crate::multimodal::ImageInput::Base64 {
                            data: base64_data,
                            mime_type,
                        };
                    }
                }
            }
        }

        Ok(())
    }
}

/// Converts the OpenAI-compatible request into the internal unified format.
///
/// - Text content becomes a single `ContentPart::Text`; part lists are mapped
///   element by element; `MessageContent::None` becomes an empty list.
/// - Image parts are wrapped with `ImageInput::from_url(url)` and set
///   `has_images`.
/// - `stop` is normalized to a list of strings.
/// - `stream` defaults to `false`; `client_id` starts out empty.
///
/// The visible body has no failing path: it always returns `Ok`.
impl TryFrom for UnifiedRequest {
    type Error = anyhow::Error;

    fn try_from(req: ChatCompletionRequest) -> Result {
        // Flipped from inside the mapping closure whenever an image part is seen.
        let mut has_images = false;

        // Convert OpenAI-compatible request to unified format
        let messages = req
            .messages
            .into_iter()
            .map(|msg| {
                // NOTE(review): the second tuple element (`_images_in_message`,
                // fed by `has_images_in_msg`) is never read; only the outer
                // `has_images` flag reaches the result.
                let (content, _images_in_message) = match msg.content {
                    MessageContent::Text { content } => (vec![ContentPart::Text { text: content }], false),
                    MessageContent::Parts { content } => {
                        let mut unified_content = Vec::new();
                        let mut has_images_in_msg = false;

                        for part in content {
                            match part {
                                ContentPartValue::Text { text } => {
                                    unified_content.push(ContentPart::Text { text });
                                }
                                ContentPartValue::ImageUrl { image_url } => {
                                    has_images_in_msg = true;
                                    has_images = true;
                                    // Only `url` is carried over; `image_url.detail`
                                    // is dropped here.
                                    unified_content.push(ContentPart::Image(crate::multimodal::ImageInput::from_url(
                                        image_url.url,
                                    )));
                                }
                            }
                        }
                        (unified_content, has_images_in_msg)
                    }
                    MessageContent::None => (vec![], false),
                };

                UnifiedMessage {
                    role: msg.role,
                    content,
                    reasoning_content: msg.reasoning_content,
                    tool_calls: msg.tool_calls,
                    name: msg.name,
                    tool_call_id: msg.tool_call_id,
                }
            })
            .collect();

        // Normalize `stop`: a string becomes a one-element list; an array keeps
        // only its string elements (non-string entries are silently dropped);
        // anything else, including an absent value, becomes `None`.
        let stop = match req.stop {
            Some(Value::String(s)) => Some(vec![s]),
            Some(Value::Array(a)) => Some(
                a.into_iter()
                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
                    .collect(),
            ),
            _ => None,
        };

        Ok(UnifiedRequest {
            client_id: String::new(), // Will be populated by auth middleware
            model: req.model,
            messages,
            temperature: req.temperature,
            top_p: req.top_p,
            top_k: req.top_k,
            n: req.n,
            stop,
            max_tokens: req.max_tokens,
            presence_penalty: req.presence_penalty,
            frequency_penalty: req.frequency_penalty,
            stream: req.stream.unwrap_or(false),
            has_images,
            tools: req.tools,
            tool_choice: req.tool_choice,
        })
    }
}