378 lines
11 KiB
Rust
378 lines
11 KiB
Rust
use serde::{Deserialize, Serialize};
|
|
use serde_json::Value;
|
|
|
|
pub mod registry;
|
|
|
|
// ========== OpenAI-compatible Request/Response Structs ==========
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ChatCompletionRequest {
|
|
pub model: String,
|
|
pub messages: Vec<ChatMessage>,
|
|
#[serde(default)]
|
|
pub temperature: Option<f64>,
|
|
#[serde(default)]
|
|
pub top_p: Option<f64>,
|
|
#[serde(default)]
|
|
pub top_k: Option<u32>,
|
|
#[serde(default)]
|
|
pub n: Option<u32>,
|
|
#[serde(default)]
|
|
pub stop: Option<Value>, // Can be string or array of strings
|
|
#[serde(default)]
|
|
pub max_tokens: Option<u32>,
|
|
#[serde(default)]
|
|
pub presence_penalty: Option<f64>,
|
|
#[serde(default)]
|
|
pub frequency_penalty: Option<f64>,
|
|
#[serde(default)]
|
|
pub stream: Option<bool>,
|
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
pub tools: Option<Vec<Tool>>,
|
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
pub tool_choice: Option<ToolChoice>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ChatMessage {
|
|
pub role: String, // "system", "user", "assistant", "tool"
|
|
#[serde(flatten)]
|
|
pub content: MessageContent,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub reasoning_content: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub tool_calls: Option<Vec<ToolCall>>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub name: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub tool_call_id: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(untagged)]
|
|
pub enum MessageContent {
|
|
Text { content: String },
|
|
Parts { content: Vec<ContentPartValue> },
|
|
None, // Handle cases where content might be null but reasoning is present
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(tag = "type", rename_all = "snake_case")]
|
|
pub enum ContentPartValue {
|
|
Text { text: String },
|
|
ImageUrl { image_url: ImageUrl },
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ImageUrl {
|
|
pub url: String,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub detail: Option<String>,
|
|
}
|
|
|
|
// ========== Tool-Calling Types ==========
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct Tool {
|
|
#[serde(rename = "type")]
|
|
pub tool_type: String,
|
|
pub function: FunctionDef,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct FunctionDef {
|
|
pub name: String,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub description: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub parameters: Option<Value>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(untagged)]
|
|
pub enum ToolChoice {
|
|
Mode(String), // "auto", "none", "required"
|
|
Specific(ToolChoiceSpecific),
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ToolChoiceSpecific {
|
|
#[serde(rename = "type")]
|
|
pub choice_type: String,
|
|
pub function: ToolChoiceFunction,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ToolChoiceFunction {
|
|
pub name: String,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ToolCall {
|
|
pub id: String,
|
|
#[serde(rename = "type")]
|
|
pub call_type: String,
|
|
pub function: FunctionCall,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct FunctionCall {
|
|
pub name: String,
|
|
pub arguments: String,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ToolCallDelta {
|
|
pub index: u32,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub id: Option<String>,
|
|
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
|
|
pub call_type: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub function: Option<FunctionCallDelta>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct FunctionCallDelta {
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub name: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub arguments: Option<String>,
|
|
}
|
|
|
|
// ========== OpenAI-compatible Response Structs ==========
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ChatCompletionResponse {
|
|
pub id: String,
|
|
pub object: String,
|
|
pub created: u64,
|
|
pub model: String,
|
|
pub choices: Vec<ChatChoice>,
|
|
pub usage: Option<Usage>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ChatChoice {
|
|
pub index: u32,
|
|
pub message: ChatMessage,
|
|
pub finish_reason: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct Usage {
|
|
pub prompt_tokens: u32,
|
|
pub completion_tokens: u32,
|
|
pub total_tokens: u32,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub cache_read_tokens: Option<u32>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub cache_write_tokens: Option<u32>,
|
|
}
|
|
|
|
// ========== Streaming Response Structs ==========
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ChatCompletionStreamResponse {
|
|
pub id: String,
|
|
pub object: String,
|
|
pub created: u64,
|
|
pub model: String,
|
|
pub choices: Vec<ChatStreamChoice>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ChatStreamChoice {
|
|
pub index: u32,
|
|
pub delta: ChatStreamDelta,
|
|
pub finish_reason: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ChatStreamDelta {
|
|
pub role: Option<String>,
|
|
pub content: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub reasoning_content: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub tool_calls: Option<Vec<ToolCallDelta>>,
|
|
}
|
|
|
|
// ========== Unified Request Format (for internal use) ==========
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct UnifiedRequest {
|
|
pub client_id: String,
|
|
pub model: String,
|
|
pub messages: Vec<UnifiedMessage>,
|
|
pub temperature: Option<f64>,
|
|
pub top_p: Option<f64>,
|
|
pub top_k: Option<u32>,
|
|
pub n: Option<u32>,
|
|
pub stop: Option<Vec<String>>,
|
|
pub max_tokens: Option<u32>,
|
|
pub presence_penalty: Option<f64>,
|
|
pub frequency_penalty: Option<f64>,
|
|
pub stream: bool,
|
|
pub has_images: bool,
|
|
pub tools: Option<Vec<Tool>>,
|
|
pub tool_choice: Option<ToolChoice>,
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct UnifiedMessage {
|
|
pub role: String,
|
|
pub content: Vec<ContentPart>,
|
|
pub reasoning_content: Option<String>,
|
|
pub tool_calls: Option<Vec<ToolCall>>,
|
|
pub name: Option<String>,
|
|
pub tool_call_id: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub enum ContentPart {
|
|
Text { text: String },
|
|
Image(crate::multimodal::ImageInput),
|
|
}
|
|
|
|
// ========== Provider-specific Structs ==========
|
|
|
|
#[derive(Debug, Clone, Serialize)]
|
|
pub struct OpenAIRequest {
|
|
pub model: String,
|
|
pub messages: Vec<OpenAIMessage>,
|
|
pub temperature: Option<f64>,
|
|
pub max_tokens: Option<u32>,
|
|
pub stream: Option<bool>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize)]
|
|
pub struct OpenAIMessage {
|
|
pub role: String,
|
|
pub content: Vec<OpenAIContentPart>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize)]
|
|
#[serde(tag = "type", rename_all = "snake_case")]
|
|
pub enum OpenAIContentPart {
|
|
Text { text: String },
|
|
ImageUrl { image_url: ImageUrl },
|
|
}
|
|
|
|
// Note: `OpenAIContentPart::ImageUrl` reuses the shared `ImageUrl` struct defined earlier in this file.
|
|
|
|
// ========== Conversion Traits ==========
|
|
|
|
pub trait ToOpenAI {
|
|
fn to_openai(&self) -> Result<OpenAIRequest, anyhow::Error>;
|
|
}
|
|
|
|
pub trait FromOpenAI {
|
|
fn from_openai(request: &OpenAIRequest) -> Result<Self, anyhow::Error>
|
|
where
|
|
Self: Sized;
|
|
}
|
|
|
|
impl UnifiedRequest {
|
|
/// Hydrate all image content by fetching URLs and converting to base64/bytes
|
|
pub async fn hydrate_images(&mut self) -> anyhow::Result<()> {
|
|
if !self.has_images {
|
|
return Ok(());
|
|
}
|
|
|
|
for msg in &mut self.messages {
|
|
for part in &mut msg.content {
|
|
if let ContentPart::Image(image_input) = part {
|
|
// Pre-fetch and validate if it's a URL
|
|
if let crate::multimodal::ImageInput::Url(_url) = image_input {
|
|
let (base64_data, mime_type) = image_input.to_base64().await?;
|
|
*image_input = crate::multimodal::ImageInput::Base64 {
|
|
data: base64_data,
|
|
mime_type,
|
|
};
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
impl TryFrom<ChatCompletionRequest> for UnifiedRequest {
|
|
type Error = anyhow::Error;
|
|
|
|
fn try_from(req: ChatCompletionRequest) -> Result<Self, Self::Error> {
|
|
let mut has_images = false;
|
|
|
|
// Convert OpenAI-compatible request to unified format
|
|
let messages = req
|
|
.messages
|
|
.into_iter()
|
|
.map(|msg| {
|
|
let (content, _images_in_message) = match msg.content {
|
|
MessageContent::Text { content } => (vec![ContentPart::Text { text: content }], false),
|
|
MessageContent::Parts { content } => {
|
|
let mut unified_content = Vec::new();
|
|
let mut has_images_in_msg = false;
|
|
|
|
for part in content {
|
|
match part {
|
|
ContentPartValue::Text { text } => {
|
|
unified_content.push(ContentPart::Text { text });
|
|
}
|
|
ContentPartValue::ImageUrl { image_url } => {
|
|
has_images_in_msg = true;
|
|
has_images = true;
|
|
unified_content.push(ContentPart::Image(crate::multimodal::ImageInput::from_url(
|
|
image_url.url,
|
|
)));
|
|
}
|
|
}
|
|
}
|
|
|
|
(unified_content, has_images_in_msg)
|
|
}
|
|
MessageContent::None => (vec![], false),
|
|
};
|
|
|
|
UnifiedMessage {
|
|
role: msg.role,
|
|
content,
|
|
reasoning_content: msg.reasoning_content,
|
|
tool_calls: msg.tool_calls,
|
|
name: msg.name,
|
|
tool_call_id: msg.tool_call_id,
|
|
}
|
|
})
|
|
.collect();
|
|
|
|
let stop = match req.stop {
|
|
Some(Value::String(s)) => Some(vec![s]),
|
|
Some(Value::Array(a)) => Some(
|
|
a.into_iter()
|
|
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
|
.collect(),
|
|
),
|
|
_ => None,
|
|
};
|
|
|
|
Ok(UnifiedRequest {
|
|
client_id: String::new(), // Will be populated by auth middleware
|
|
model: req.model,
|
|
messages,
|
|
temperature: req.temperature,
|
|
top_p: req.top_p,
|
|
top_k: req.top_k,
|
|
n: req.n,
|
|
stop,
|
|
max_tokens: req.max_tokens,
|
|
presence_penalty: req.presence_penalty,
|
|
frequency_penalty: req.frequency_penalty,
|
|
stream: req.stream.unwrap_or(false),
|
|
has_images,
|
|
tools: req.tools,
|
|
tool_choice: req.tool_choice,
|
|
})
|
|
}
|
|
}
|