From 371e1f2e776adf4388ae2bdb00f6eb817db154cf Mon Sep 17 00:00:00 2001
From: hobokenchicken
Date: Thu, 26 Feb 2026 13:31:09 -0500
Subject: [PATCH] feat: add Ollama provider support and dashboard integration

---
 .env.example            |   5 +
 .gitignore              |   5 +
 README.md               |   9 ++
 src/config/mod.rs       |  15 +-
 src/dashboard/mod.rs    |  17 +++
 src/main.rs             |  11 ++
 src/providers/mod.rs    |   1 +
 src/providers/ollama.rs | 313 ++++++++++++++++++++++++++++++++++++++++
 8 files changed, 375 insertions(+), 1 deletion(-)
 create mode 100644 .gitignore
 create mode 100644 src/providers/ollama.rs

diff --git a/.env.example b/.env.example
index c84d43d9..f2851a1f 100644
--- a/.env.example
+++ b/.env.example
@@ -13,6 +13,11 @@ DEEPSEEK_API_KEY=your_deepseek_api_key_here
 # xAI Grok (not yet available)
 GROK_API_KEY=your_grok_api_key_here
 
+# Ollama (local server)
+# LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL=http://your-ollama-host:11434/v1
+# LLM_PROXY__PROVIDERS__OLLAMA__ENABLED=true
+# LLM_PROXY__PROVIDERS__OLLAMA__MODELS=llama3,mistral,llava
+
 # Authentication tokens (comma-separated list)
 LLM_PROXY__SERVER__AUTH_TOKENS=your_bearer_token_here,another_token
 
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..fdb9268f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+/target
+/.env
+/*.db
+/*.db-shm
+/*.db-wal
diff --git a/README.md b/README.md
index fbecffb5..299b98e0 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,7 @@ A unified, high-performance LLM proxy gateway built in Rust. It provides a singl
   * **Google Gemini:** Support for the latest Gemini 2.0 models.
   * **DeepSeek:** High-performance, low-cost integration.
   * **xAI Grok:** Integration for Grok-series models.
+  * **Ollama:** Support for local LLMs running on your machine or another host.
 - **Observability & Tracking:**
   * **Real-time Costing:** Fetches live pricing and context specs from `models.dev` on startup.
   * **Token Counting:** Precise estimation using `tiktoken-rs`.
@@ -52,6 +53,14 @@ A unified, high-performance LLM proxy gateway built in Rust. It provides a singl
 3. Configure providers and server:
    Edit `config.toml` to customize models, pricing fallbacks, and port settings.
 
+   **Ollama Example (config.toml):**
+   ```toml
+   [providers.ollama]
+   enabled = true
+   base_url = "http://192.168.1.50:11434/v1"
+   models = ["llama3", "mistral"]
+   ```
+
 4. Run the proxy:
    ```bash
    cargo run --release
diff --git a/src/config/mod.rs b/src/config/mod.rs
index 4a251dca..649a35e8 100644
--- a/src/config/mod.rs
+++ b/src/config/mod.rs
@@ -23,6 +23,7 @@ pub struct ProviderConfig {
     pub gemini: GeminiConfig,
     pub deepseek: DeepSeekConfig,
     pub grok: GrokConfig,
+    pub ollama: OllamaConfig,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -57,6 +58,13 @@ pub struct GrokConfig {
     pub enabled: bool,
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct OllamaConfig {
+    pub base_url: String,
+    pub enabled: bool,
+    pub models: Vec<String>,
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct ModelMappingConfig {
     pub patterns: Vec<(String, String)>,
@@ -68,6 +76,7 @@ pub struct PricingConfig {
     pub gemini: Vec,
     pub deepseek: Vec,
     pub grok: Vec,
+    pub ollama: Vec,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct
@@ -119,7 +128,10 @@ impl AppConfig {
             .set_default("providers.grok.api_key_env", "GROK_API_KEY")?
             .set_default("providers.grok.base_url", "https://api.x.ai/v1")?
             .set_default("providers.grok.default_model", "grok-beta")?
- .set_default("providers.grok.enabled", false)?; // Disabled by default until API is researched + .set_default("providers.grok.enabled", false)? + .set_default("providers.ollama.base_url", "http://localhost:11434/v1")? + .set_default("providers.ollama.enabled", false)? + .set_default("providers.ollama.models", Vec::::new())?; // Load from config file if exists let config_path = config_path.unwrap_or_else(|| std::env::current_dir().unwrap().join("config.toml")); @@ -151,6 +163,7 @@ impl AppConfig { gemini: vec![], deepseek: vec![], grok: vec![], + ollama: vec![], }; Ok(Arc::new(AppConfig { diff --git a/src/dashboard/mod.rs b/src/dashboard/mod.rs index 511e620a..0079e538 100644 --- a/src/dashboard/mod.rs +++ b/src/dashboard/mod.rs @@ -500,6 +500,16 @@ async fn handle_get_providers(State(state): State) -> Json) -> Json Result<()> { } } + // Initialize Ollama + if config.providers.ollama.enabled { + match llm_proxy::providers::ollama::OllamaProvider::new(&config.providers.ollama, &config) { + Ok(p) => { + provider_manager.add_provider(Arc::new(p)); + info!("Ollama provider initialized at {}", config.providers.ollama.base_url); + } + Err(e) => error!("Failed to initialize Ollama provider: {}", e), + } + } + // Create rate limit manager let rate_limit_manager = RateLimitManager::new( RateLimiterConfig::default(), diff --git a/src/providers/mod.rs b/src/providers/mod.rs index d7151629..6de57d1c 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -10,6 +10,7 @@ pub mod openai; pub mod gemini; pub mod deepseek; pub mod grok; +pub mod ollama; #[async_trait] pub trait Provider: Send + Sync { diff --git a/src/providers/ollama.rs b/src/providers/ollama.rs new file mode 100644 index 00000000..e6adc894 --- /dev/null +++ b/src/providers/ollama.rs @@ -0,0 +1,313 @@ +use async_trait::async_trait; +use anyhow::Result; +use async_openai::{Client, config::OpenAIConfig}; +use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent}; +use futures::stream::{BoxStream, StreamExt}; + +use crate::{ + models::UnifiedRequest, + errors::AppError, + config::AppConfig, +}; +use super::{ProviderResponse, ProviderStreamChunk}; + +pub struct OllamaProvider { + client: Client, + _config: crate::config::OllamaConfig, + pricing: Vec, +} + +impl OllamaProvider { + pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result { + // Ollama usually doesn't need an API key, use a dummy one + let openai_config = OpenAIConfig::default() + .with_api_key("ollama") + .with_api_base(&config.base_url); + + let client = Client::with_config(openai_config); + + Ok(Self { + client, + _config: config.clone(), + pricing: app_config.pricing.ollama.clone(), + }) + } +} + +#[async_trait] +impl super::Provider for OllamaProvider { + fn name(&self) -> &str { + "ollama" + } + + fn supports_model(&self, model: &str) -> bool { + // Check if model is in the list of configured Ollama models + self._config.models.iter().any(|m| m == model) || model.starts_with("ollama/") + } + + fn supports_multimodal(&self) -> bool { + true // Many Ollama models support vision (e.g. 
+
+    async fn chat_completion(
+        &self,
+        request: UnifiedRequest,
+    ) -> Result<ProviderResponse, AppError> {
+        use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
+
+        // Strip "ollama/" prefix if present
+        let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
+
+        // Convert UnifiedRequest messages to OpenAI messages
+        let mut messages = Vec::with_capacity(request.messages.len());
+
+        for msg in request.messages {
+            let mut parts = Vec::with_capacity(msg.content.len());
+
+            for part in msg.content {
+                match part {
+                    crate::models::ContentPart::Text { text } => {
+                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
+                            text,
+                        }));
+                    }
+                    crate::models::ContentPart::Image(image_input) => {
+                        let (base64_data, mime_type) = image_input.to_base64().await
+                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
+                        let data_url = format!("data:{};base64,{}", mime_type, base64_data);
+
+                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
+                            image_url: ImageUrl {
+                                url: data_url,
+                                detail: Some(ImageDetail::Auto),
+                            }
+                        }));
+                    }
+                }
+            }
+
+            let message = match msg.role.as_str() {
+                "system" => ChatCompletionRequestMessage::System(
+                    ChatCompletionRequestSystemMessage {
+                        content: ChatCompletionRequestSystemMessageContent::Text(
+                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<String>>().join("\n")
+                        ),
+                        name: None,
+                    }
+                ),
+                "assistant" => ChatCompletionRequestMessage::Assistant(
+                    ChatCompletionRequestAssistantMessage {
+                        content: Some(ChatCompletionRequestAssistantMessageContent::Text(
+                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<String>>().join("\n")
+                        )),
+                        name: None,
+                        tool_calls: None,
+                        refusal: None,
+                        audio: None,
+                        #[allow(deprecated)]
+                        function_call: None,
+                    }
+                ),
+                _ => ChatCompletionRequestMessage::User(
+                    ChatCompletionRequestUserMessage {
+                        content: ChatCompletionRequestUserMessageContent::Array(parts),
+                        name: None,
+                    }
+                ),
+            };
+            messages.push(message);
+        }
+
+        if messages.is_empty() {
+            return Err(AppError::ProviderError("No valid text messages to send".to_string()));
+        }
+
+        // Build request using builder pattern
+        let mut builder = CreateChatCompletionRequestArgs::default();
+        builder.model(model);
+        builder.messages(messages);
+
+        // Add optional parameters
+        if let Some(temp) = request.temperature {
+            builder.temperature(temp as f32);
+        }
+
+        if let Some(max_tokens) = request.max_tokens {
+            builder.max_tokens(max_tokens as u16);
+        }
+
+        // Execute API call
+        let response = self.client
+            .chat()
+            .create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
+            .await
+            .map_err(|e| AppError::ProviderError(e.to_string()))?;
+
+        // Extract content from response
+        let content = response
+            .choices
+            .first()
+            .and_then(|choice| choice.message.content.clone())
+            .unwrap_or_default();
+
+        // Extract token usage
+        let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
+        let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
+        let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
+
+        Ok(ProviderResponse {
+            content,
+            prompt_tokens,
+            completion_tokens,
+            total_tokens,
+            model: request.model,
+        })
+    }
+
+    fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32, AppError> {
+        Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
+    }
+
+    fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
+        if let Some(metadata) = registry.find_model(model) {
+            if let Some(cost) = &metadata.cost {
+                return (prompt_tokens as f64 * cost.input / 1_000_000.0)
+                    + (completion_tokens as f64 * cost.output / 1_000_000.0);
+            }
+        }
+
+        // Fall back to configured pricing; Ollama is free by default
+        let (prompt_rate, completion_rate) = self.pricing.iter()
+            .find(|p| model.contains(&p.model))
+            .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
+            .unwrap_or((0.0, 0.0));
+
+        (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
+    }
+
+    async fn chat_completion_stream(
+        &self,
+        request: UnifiedRequest,
+    ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
+        use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
+
+        // Strip "ollama/" prefix if present
+        let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
+
+        // Convert UnifiedRequest messages to OpenAI messages
+        let mut messages = Vec::with_capacity(request.messages.len());
+
+        for msg in request.messages {
+            let mut parts = Vec::with_capacity(msg.content.len());
+
+            for part in msg.content {
+                match part {
+                    crate::models::ContentPart::Text { text } => {
+                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
+                            text,
+                        }));
+                    }
+                    crate::models::ContentPart::Image(image_input) => {
+                        let (base64_data, mime_type) = image_input.to_base64().await
+                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
+                        let data_url = format!("data:{};base64,{}", mime_type, base64_data);
+
+                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
+                            image_url: ImageUrl {
+                                url: data_url,
+                                detail: Some(ImageDetail::Auto),
+                            }
+                        }));
+                    }
+                }
+            }
+
+            let message = match msg.role.as_str() {
+                "system" => ChatCompletionRequestMessage::System(
+                    ChatCompletionRequestSystemMessage {
+                        content: ChatCompletionRequestSystemMessageContent::Text(
+                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<String>>().join("\n")
+                        ),
+                        name: None,
+                    }
+                ),
+                "assistant" => ChatCompletionRequestMessage::Assistant(
+                    ChatCompletionRequestAssistantMessage {
+                        content: Some(ChatCompletionRequestAssistantMessageContent::Text(
+                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<String>>().join("\n")
+                        )),
+                        name: None,
+                        tool_calls: None,
+                        refusal: None,
+                        audio: None,
+                        #[allow(deprecated)]
+                        function_call: None,
+                    }
+                ),
+                _ => ChatCompletionRequestMessage::User(
+                    ChatCompletionRequestUserMessage {
+                        content: ChatCompletionRequestUserMessageContent::Array(parts),
+                        name: None,
+                    }
+                ),
+            };
+            messages.push(message);
+        }
+
+        if messages.is_empty() {
+            return Err(AppError::ProviderError("No valid text messages to send".to_string()));
+        }
+
+        // Build request using builder pattern
+        let mut builder = CreateChatCompletionRequestArgs::default();
+        builder.model(model);
+        builder.messages(messages);
+        builder.stream(true); // Enable streaming
+
+        // Add optional parameters
+        if let Some(temp) = request.temperature {
+            builder.temperature(temp as f32);
+        }
+
+        if let Some(max_tokens) = request.max_tokens {
+            builder.max_tokens(max_tokens as u16);
+        }
+
+        // Execute streaming API call
+        let stream = self.client
+            .chat()
+            .create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
+            .await
+            .map_err(|e| AppError::ProviderError(e.to_string()))?;
+
+        // Convert OpenAI stream to our stream format
+        let model_name = request.model.clone();
+        let stream = stream.map(move |chunk_result| {
+            match chunk_result {
+                Ok(chunk) => {
+                    // Extract content from chunk
+                    let content = chunk.choices.first()
+                        .and_then(|choice| choice.delta.content.clone())
+                        .unwrap_or_default();
+
+                    let finish_reason = chunk.choices.first()
+                        .and_then(|choice| choice.finish_reason.clone())
+                        .map(|reason| format!("{:?}", reason));
+
+                    Ok(ProviderStreamChunk {
+                        content,
+                        finish_reason,
+                        model: model_name.clone(),
+                    })
+                }
+                Err(e) => Err(AppError::ProviderError(e.to_string())),
+            }
+        });
+
+        Ok(Box::pin(stream))
+    }
+}
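
Once the provider is enabled, requests routed through the gateway can name the model either by its bare Ollama tag (when it is listed in `models`) or with the `ollama/` prefix, which the provider strips before forwarding. A minimal smoke test is sketched below; it assumes the gateway exposes an OpenAI-compatible `/v1/chat/completions` route on port 3000 (substitute the port configured in `config.toml`) and that `llama3` has already been pulled on the Ollama host — neither detail is part of this patch.

```bash
# Hypothetical smoke test: endpoint path and port are assumptions, not defined by this patch.
curl http://localhost:3000/v1/chat/completions \
  -H "Authorization: Bearer your_bearer_token_here" \
  -H "Content-Type: application/json" \
  -d '{
        "model": "ollama/llama3",
        "messages": [{"role": "user", "content": "Say hello from Ollama"}]
      }'
```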