Init repo
This commit is contained in:
285
src/multimodal/mod.rs
Normal file
285
src/multimodal/mod.rs
Normal file
@@ -0,0 +1,285 @@
|
||||
//! Multimodal support for image processing and conversion
//!
//! This module handles:
//! 1. Image format detection and conversion
//! 2. Base64 encoding/decoding
//! 3. URL fetching for images
//! 4. Provider-specific image format conversion
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Supported image sources for multimodal input.
///
/// Each variant describes one way a caller can hand an image to this module.
/// `PartialEq`/`Eq` are derived so inputs can be compared directly in tests
/// and assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImageInput {
    /// Base64-encoded image data with MIME type
    Base64 {
        /// Base64 text of the image payload.
        data: String,
        /// MIME type accompanying the payload (e.g. `image/jpeg`).
        mime_type: String,
    },
    /// URL to fetch image from
    Url(String),
    /// Raw bytes with MIME type
    Bytes {
        /// Undecoded image file contents.
        data: Vec<u8>,
        /// MIME type accompanying the bytes (e.g. `image/png`).
        mime_type: String,
    },
}
|
||||
|
||||
impl ImageInput {
|
||||
/// Create ImageInput from base64 string
|
||||
pub fn from_base64(data: String, mime_type: String) -> Self {
|
||||
Self::Base64 { data, mime_type }
|
||||
}
|
||||
|
||||
/// Create ImageInput from URL
|
||||
pub fn from_url(url: String) -> Self {
|
||||
Self::Url(url)
|
||||
}
|
||||
|
||||
/// Create ImageInput from raw bytes
|
||||
pub fn from_bytes(data: Vec<u8>, mime_type: String) -> Self {
|
||||
Self::Bytes { data, mime_type }
|
||||
}
|
||||
|
||||
/// Get MIME type if available
|
||||
pub fn mime_type(&self) -> Option<&str> {
|
||||
match self {
|
||||
Self::Base64 { mime_type, .. } => Some(mime_type),
|
||||
Self::Bytes { mime_type, .. } => Some(mime_type),
|
||||
Self::Url(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to base64 if not already
|
||||
pub async fn to_base64(&self) -> Result<(String, String)> {
|
||||
match self {
|
||||
Self::Base64 { data, mime_type } => Ok((data.clone(), mime_type.clone())),
|
||||
Self::Bytes { data, mime_type } => {
|
||||
let base64_data = general_purpose::STANDARD.encode(data);
|
||||
Ok((base64_data, mime_type.clone()))
|
||||
}
|
||||
Self::Url(url) => {
|
||||
// Fetch image from URL
|
||||
info!("Fetching image from URL: {}", url);
|
||||
let response = reqwest::get(url)
|
||||
.await
|
||||
.context("Failed to fetch image from URL")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
anyhow::bail!("Failed to fetch image: HTTP {}", response.status());
|
||||
}
|
||||
|
||||
let mime_type = response
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.unwrap_or("image/jpeg")
|
||||
.to_string();
|
||||
|
||||
let bytes = response.bytes().await.context("Failed to read image bytes")?;
|
||||
|
||||
let base64_data = general_purpose::STANDARD.encode(&bytes);
|
||||
Ok((base64_data, mime_type))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get image dimensions (width, height)
|
||||
pub async fn get_dimensions(&self) -> Result<(u32, u32)> {
|
||||
let bytes = match self {
|
||||
Self::Base64 { data, .. } => {
|
||||
general_purpose::STANDARD.decode(data).context("Failed to decode base64")?
|
||||
}
|
||||
Self::Bytes { data, .. } => data.clone(),
|
||||
Self::Url(_) => {
|
||||
let (base64_data, _) = self.to_base64().await?;
|
||||
general_purpose::STANDARD.decode(&base64_data).context("Failed to decode base64")?
|
||||
}
|
||||
};
|
||||
|
||||
let img = image::load_from_memory(&bytes).context("Failed to load image from bytes")?;
|
||||
Ok((img.width(), img.height()))
|
||||
}
|
||||
|
||||
/// Validate image size and format
|
||||
pub async fn validate(&self, max_size_mb: f64) -> Result<()> {
|
||||
let (width, height) = self.get_dimensions().await?;
|
||||
|
||||
// Check dimensions
|
||||
if width > 4096 || height > 4096 {
|
||||
warn!("Image dimensions too large: {}x{}", width, height);
|
||||
// Continue anyway, but log warning
|
||||
}
|
||||
|
||||
// Check file size
|
||||
let size_bytes = match self {
|
||||
Self::Base64 { data, .. } => {
|
||||
// Base64 size is ~4/3 of original
|
||||
(data.len() as f64 * 0.75) as usize
|
||||
}
|
||||
Self::Bytes { data, .. } => data.len(),
|
||||
Self::Url(_) => {
|
||||
// For URLs, we'd need to fetch to check size
|
||||
// Skip size check for URLs for now
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
let size_mb = size_bytes as f64 / (1024.0 * 1024.0);
|
||||
if size_mb > max_size_mb {
|
||||
anyhow::bail!("Image too large: {:.2}MB > {:.2}MB limit", size_mb, max_size_mb);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Provider-specific image format conversion
|
||||
pub struct ImageConverter;
|
||||
|
||||
impl ImageConverter {
|
||||
/// Convert image to OpenAI-compatible format
|
||||
pub async fn to_openai_format(image: &ImageInput) -> Result<serde_json::Value> {
|
||||
let (base64_data, mime_type) = image.to_base64().await?;
|
||||
|
||||
// OpenAI expects data URL format: "data:image/jpeg;base64,{data}"
|
||||
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": data_url,
|
||||
"detail": "auto" // Can be "low", "high", or "auto"
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// Convert image to Gemini-compatible format
|
||||
pub async fn to_gemini_format(image: &ImageInput) -> Result<serde_json::Value> {
|
||||
let (base64_data, mime_type) = image.to_base64().await?;
|
||||
|
||||
// Gemini expects inline data format
|
||||
Ok(serde_json::json!({
|
||||
"inline_data": {
|
||||
"mime_type": mime_type,
|
||||
"data": base64_data
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// Convert image to DeepSeek-compatible format
|
||||
pub async fn to_deepseek_format(image: &ImageInput) -> Result<serde_json::Value> {
|
||||
// DeepSeek uses OpenAI-compatible format for vision models
|
||||
Self::to_openai_format(image).await
|
||||
}
|
||||
|
||||
/// Detect if a model supports multimodal input
|
||||
pub fn model_supports_multimodal(model: &str) -> bool {
|
||||
// OpenAI vision models
|
||||
if (model.starts_with("gpt-4") && (model.contains("vision") || model.contains("-v") || model.contains("4o"))) ||
|
||||
model.starts_with("o1-") || model.starts_with("o3-") {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Gemini vision models
|
||||
if model.starts_with("gemini") {
|
||||
// Most Gemini models support vision
|
||||
return true;
|
||||
}
|
||||
|
||||
// DeepSeek vision models
|
||||
if model.starts_with("deepseek-vl") {
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse OpenAI-compatible multimodal message content
|
||||
pub fn parse_openai_content(content: &serde_json::Value) -> Result<Vec<(String, Option<ImageInput>)>> {
|
||||
let mut parts = Vec::new();
|
||||
|
||||
if let Some(content_str) = content.as_str() {
|
||||
// Simple text content
|
||||
parts.push((content_str.to_string(), None));
|
||||
} else if let Some(content_array) = content.as_array() {
|
||||
// Array of content parts (text and/or images)
|
||||
for part in content_array {
|
||||
if let Some(part_obj) = part.as_object() {
|
||||
if let Some(part_type) = part_obj.get("type").and_then(|t| t.as_str()) {
|
||||
match part_type {
|
||||
"text" => {
|
||||
if let Some(text) = part_obj.get("text").and_then(|t| t.as_str()) {
|
||||
parts.push((text.to_string(), None));
|
||||
}
|
||||
}
|
||||
"image_url" => {
|
||||
if let Some(image_url_obj) = part_obj.get("image_url").and_then(|o| o.as_object()) {
|
||||
if let Some(url) = image_url_obj.get("url").and_then(|u| u.as_str()) {
|
||||
if url.starts_with("data:") {
|
||||
// Parse data URL
|
||||
if let Some((mime_type, data)) = parse_data_url(url) {
|
||||
let image_input = ImageInput::from_base64(data, mime_type);
|
||||
parts.push(("".to_string(), Some(image_input)));
|
||||
}
|
||||
} else {
|
||||
// Regular URL
|
||||
let image_input = ImageInput::from_url(url.to_string());
|
||||
parts.push(("".to_string(), Some(image_input)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
warn!("Unknown content part type: {}", part_type);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(parts)
|
||||
}
|
||||
|
||||
/// Parse data URL (data:image/jpeg;base64,{data}).
///
/// Returns `(mime_type, base64_data)` when `data_url` starts with `data:` and
/// contains exactly one `;base64,` separator; otherwise `None`. Media-type
/// parameters are dropped, so `data:image/png;charset=utf-8;base64,AAAA`
/// yields the MIME type `image/png`. The payload is returned as-is without
/// being decoded or checked.
fn parse_data_url(data_url: &str) -> Option<(String, String)> {
    let rest = data_url.strip_prefix("data:")?;
    let (media_type, payload) = rest.split_once(";base64,")?;

    // A second separator cannot be part of a base64 payload; treat it as malformed.
    if payload.contains(";base64,") {
        return None;
    }

    // Keep only the bare type, without `;param=value` suffixes.
    let mime_type = media_type.split(';').next().unwrap_or(media_type).trim();

    Some((mime_type.to_string(), payload.to_string()))
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // These tests exercise synchronous functions only, so they use the plain
    // test harness and need no async runtime.

    /// A well-formed base64 data URL splits into MIME type and payload.
    #[test]
    fn test_parse_data_url() {
        let test_url = "data:image/jpeg;base64,SGVsbG8gV29ybGQ="; // "Hello World" in base64
        let (mime_type, data) = parse_data_url(test_url).unwrap();

        assert_eq!(mime_type, "image/jpeg");
        assert_eq!(data, "SGVsbG8gV29ybGQ=");
    }

    /// Anything that is not a `data:` URL is rejected.
    #[test]
    fn test_parse_data_url_rejects_regular_url() {
        assert!(parse_data_url("https://example.com/image.png").is_none());
    }

    /// Model-name detection for vision-capable and text-only models.
    #[test]
    fn test_model_supports_multimodal() {
        assert!(ImageConverter::model_supports_multimodal("gpt-4-vision-preview"));
        assert!(ImageConverter::model_supports_multimodal("gemini-pro-vision"));
        assert!(ImageConverter::model_supports_multimodal("deepseek-vl"));
        assert!(!ImageConverter::model_supports_multimodal("gpt-3.5-turbo"));
        // `gemini-pro` is expected to be classified as text-only, unlike
        // `gemini-pro-vision` above.
        assert!(!ImageConverter::model_supports_multimodal("gemini-pro"));
    }

    /// Plain string content becomes a single text part without an image.
    #[test]
    fn test_parse_openai_content_plain_text() {
        let parts = parse_openai_content(&serde_json::json!("hello")).unwrap();

        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0].0, "hello");
        assert!(parts[0].1.is_none());
    }
}
|
||||
Reference in New Issue
Block a user