From 6143b88eac812576450d48ce29bcc459e9fd5429 Mon Sep 17 00:00:00 2001
From: hobokenchicken
Date: Thu, 26 Feb 2026 13:41:25 -0500
Subject: [PATCH] fix: update Grok provider to be OpenAI-compatible with vision and streaming support

---
 src/main.rs           |   3 +-
 src/providers/grok.rs | 254 ++++++++++++++++++++++++++++++++++++++----
 2 files changed, 235 insertions(+), 22 deletions(-)

diff --git a/src/main.rs b/src/main.rs
index ea22a8e7..13f72efc 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -13,6 +13,7 @@ use llm_proxy::{
         gemini::GeminiProvider,
         deepseek::DeepSeekProvider,
         grok::GrokProvider,
+        ollama::OllamaProvider,
     },
     database, server,
@@ -87,7 +88,7 @@ async fn main() -> Result<()> {
 
     // Initialize Ollama
     if config.providers.ollama.enabled {
-        match llm_proxy::providers::ollama::OllamaProvider::new(&config.providers.ollama, &config) {
+        match OllamaProvider::new(&config.providers.ollama, &config) {
            Ok(p) => {
                provider_manager.add_provider(Arc::new(p));
                info!("Ollama provider initialized at {}", config.providers.ollama.base_url);
diff --git a/src/providers/grok.rs b/src/providers/grok.rs
index a0c64312..0e546e00 100644
--- a/src/providers/grok.rs
+++ b/src/providers/grok.rs
@@ -1,6 +1,8 @@
 use async_trait::async_trait;
 use anyhow::Result;
-use futures::stream::BoxStream;
+use async_openai::{Client, config::OpenAIConfig};
+use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
+use futures::stream::{BoxStream, StreamExt};
 
 use crate::{
     models::UnifiedRequest,
@@ -10,9 +12,8 @@ use crate::{
 use super::{ProviderResponse, ProviderStreamChunk};
 
 pub struct GrokProvider {
-    _client: reqwest::Client,
+    client: Client<OpenAIConfig>,
     _config: crate::config::GrokConfig,
-    _api_key: String,
     pricing: Vec<ModelPricing>,
 }
 
@@ -20,14 +21,16 @@ impl GrokProvider {
     pub fn new(config: &crate::config::GrokConfig, app_config: &AppConfig) -> Result<Self> {
         let api_key = app_config.get_api_key("grok")?;
 
-        let client = reqwest::Client::builder()
-            .timeout(std::time::Duration::from_secs(30))
-            .build()?;
+        // Grok is OpenAI-compatible
+        let openai_config = OpenAIConfig::default()
+            .with_api_key(api_key)
+            .with_api_base(&config.base_url);
+
+        let client = Client::with_config(openai_config);
 
         Ok(Self {
-            _client: client,
+            client,
             _config: config.clone(),
-            _api_key: api_key,
             pricing: app_config.pricing.grok.clone(),
         })
     }
@@ -40,24 +43,121 @@ impl super::Provider for GrokProvider {
     }
 
     fn supports_model(&self, model: &str) -> bool {
-        model.starts_with("grok-") || model.contains("grok")
+        model.starts_with("grok-")
     }
 
     fn supports_multimodal(&self) -> bool {
-        false // Unknown - assume false until API is researched
+        true // Grok supports vision models
     }
 
     async fn chat_completion(
         &self,
         request: UnifiedRequest,
     ) -> Result<ProviderResponse, AppError> {
-        // TODO: Implement actual Grok API call (once API is available)
-        // For now, return placeholder response
+        use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
+
+        // Convert UnifiedRequest messages to OpenAI messages
+        let mut messages = Vec::with_capacity(request.messages.len());
+
+        for msg in request.messages {
+            let mut parts = Vec::with_capacity(msg.content.len());
+
+            for part in msg.content {
+                match part {
+                    crate::models::ContentPart::Text { text } => {
+                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
+                            text,
+                        }));
+                    }
+                    crate::models::ContentPart::Image(image_input) => {
+                        let (base64_data, mime_type) = image_input.to_base64().await
+                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
+                        let data_url = format!("data:{};base64,{}", mime_type, base64_data);
+
+                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
+                            image_url: ImageUrl {
+                                url: data_url,
+                                detail: Some(ImageDetail::Auto),
+                            }
+                        }));
+                    }
+                }
+            }
+
+            let message = match msg.role.as_str() {
+                "system" => ChatCompletionRequestMessage::System(
+                    ChatCompletionRequestSystemMessage {
+                        content: ChatCompletionRequestSystemMessageContent::Text(
+                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
+                        ),
+                        name: None,
+                    }
+                ),
+                "assistant" => ChatCompletionRequestMessage::Assistant(
+                    ChatCompletionRequestAssistantMessage {
+                        content: Some(ChatCompletionRequestAssistantMessageContent::Text(
+                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
+                        )),
+                        name: None,
+                        tool_calls: None,
+                        refusal: None,
+                        audio: None,
+                        #[allow(deprecated)]
+                        function_call: None,
+                    }
+                ),
+                _ => ChatCompletionRequestMessage::User(
+                    ChatCompletionRequestUserMessage {
+                        content: ChatCompletionRequestUserMessageContent::Array(parts),
+                        name: None,
+                    }
+                ),
+            };
+            messages.push(message);
+        }
+
+        if messages.is_empty() {
+            return Err(AppError::ProviderError("No valid text messages to send".to_string()));
+        }
+
+        // Build request using builder pattern
+        let mut builder = CreateChatCompletionRequestArgs::default();
+        builder.model(request.model.clone());
+        builder.messages(messages);
+
+        // Add optional parameters
+        if let Some(temp) = request.temperature {
+            builder.temperature(temp as f32);
+        }
+
+        if let Some(max_tokens) = request.max_tokens {
+            builder.max_tokens(max_tokens as u16);
+        }
+
+        // Execute API call
+        let response = self.client
+            .chat()
+            .create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
+            .await
+            .map_err(|e| AppError::ProviderError(e.to_string()))?;
+
+        // Extract content from response
+        let content = response
+            .choices
+            .first()
+            .and_then(|choice| choice.message.content.clone())
+            .unwrap_or_default();
+
+        // Extract token usage
+        let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
+        let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
+        let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
+
         Ok(ProviderResponse {
-            content: "Grok provider not yet implemented (API not researched)".to_string(),
-            prompt_tokens: 0,
-            completion_tokens: 0,
-            total_tokens: 0,
+            content,
+            prompt_tokens,
+            completion_tokens,
+            total_tokens,
             model: request.model,
         })
     }
@@ -74,19 +174,131 @@ impl super::Provider for GrokProvider {
             }
         }
 
+        // Fallback to static pricing if not in registry
         let (prompt_rate, completion_rate) = self.pricing.iter()
             .find(|p| model.contains(&p.model))
             .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
-            .unwrap_or((1.0, 3.0)); // Default to some reasonable Grok price if not found
+            .unwrap_or((5.0, 15.0)); // Grok-2 pricing is roughly this
 
         (prompt_tokens as f64 * prompt_rate / 1_000_000.0) +
         (completion_tokens as f64 * completion_rate / 1_000_000.0)
     }
 
     async fn chat_completion_stream(
         &self,
-        _request: UnifiedRequest,
+        request: UnifiedRequest,
     ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
-        // Grok API not yet implemented
-        Err(AppError::ProviderError("Streaming not supported for Grok provider (API not implemented)".to_string()))
+        use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
+
+        // Convert UnifiedRequest messages to OpenAI messages
+        let mut messages = Vec::with_capacity(request.messages.len());
+
+        for msg in request.messages {
+            let mut parts = Vec::with_capacity(msg.content.len());
+
+            for part in msg.content {
+                match part {
+                    crate::models::ContentPart::Text { text } => {
+                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
+                            text,
+                        }));
+                    }
+                    crate::models::ContentPart::Image(image_input) => {
+                        let (base64_data, mime_type) = image_input.to_base64().await
+                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
+                        let data_url = format!("data:{};base64,{}", mime_type, base64_data);
+
+                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
+                            image_url: ImageUrl {
+                                url: data_url,
+                                detail: Some(ImageDetail::Auto),
+                            }
+                        }));
+                    }
+                }
+            }
+
+            let message = match msg.role.as_str() {
+                "system" => ChatCompletionRequestMessage::System(
+                    ChatCompletionRequestSystemMessage {
+                        content: ChatCompletionRequestSystemMessageContent::Text(
+                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
+                        ),
+                        name: None,
+                    }
+                ),
+                "assistant" => ChatCompletionRequestMessage::Assistant(
+                    ChatCompletionRequestAssistantMessage {
+                        content: Some(ChatCompletionRequestAssistantMessageContent::Text(
+                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
+                        )),
+                        name: None,
+                        tool_calls: None,
+                        refusal: None,
+                        audio: None,
+                        #[allow(deprecated)]
+                        function_call: None,
+                    }
+                ),
+                _ => ChatCompletionRequestMessage::User(
+                    ChatCompletionRequestUserMessage {
+                        content: ChatCompletionRequestUserMessageContent::Array(parts),
+                        name: None,
+                    }
+                ),
+            };
+            messages.push(message);
+        }
+
+        if messages.is_empty() {
+            return Err(AppError::ProviderError("No valid text messages to send".to_string()));
+        }
+
+        // Build request using builder pattern
+        let mut builder = CreateChatCompletionRequestArgs::default();
+        builder.model(request.model.clone());
+        builder.messages(messages);
+        builder.stream(true); // Enable streaming
+
+        // Add optional parameters
+        if let Some(temp) = request.temperature {
+            builder.temperature(temp as f32);
+        }
+
+        if let Some(max_tokens) = request.max_tokens {
+            builder.max_tokens(max_tokens as u16);
+        }
+
+        // Execute streaming API call
+        let stream = self.client
+            .chat()
+            .create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
+            .await
+            .map_err(|e| AppError::ProviderError(e.to_string()))?;
+
+        // Convert OpenAI stream to our stream format
+        let model = request.model.clone();
+        let stream = stream.map(move |chunk_result| {
+            match chunk_result {
+                Ok(chunk) => {
+                    // Extract content from chunk
+                    let content = chunk.choices.first()
+                        .and_then(|choice| choice.delta.content.clone())
+                        .unwrap_or_default();
+
+                    let finish_reason = chunk.choices.first()
+                        .and_then(|choice| choice.finish_reason.clone())
+                        .map(|reason| format!("{:?}", reason));
+
+                    Ok(ProviderStreamChunk {
+                        content,
+                        finish_reason,
+                        model: model.clone(),
+                    })
+                }
+                Err(e) => Err(AppError::ProviderError(e.to_string())),
+            }
+        });
+
+        Ok(Box::pin(stream))
     }
-}
\ No newline at end of file
+}
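
Reviewer note (usage): the heart of this change is that Grok speaks the OpenAI wire format, so the provider replaces the hand-rolled reqwest client with an async_openai client pointed at config.base_url. The same wiring can be exercised outside the Provider trait. The sketch below is a minimal, self-contained illustration, not code from this repo: it assumes an async_openai version exposing the types::chat paths the patch imports, substitutes the public xAI endpoint https://api.x.ai/v1 for the repo's config.base_url, and the XAI_API_KEY environment variable and model name are invented for the example.

    use async_openai::{config::OpenAIConfig, Client};
    use async_openai::types::chat::{
        ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
        ChatCompletionRequestUserMessageContent, CreateChatCompletionRequestArgs,
    };

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Same pattern as GrokProvider::new: a stock OpenAI client with the base URL swapped.
        let config = OpenAIConfig::default()
            .with_api_key(std::env::var("XAI_API_KEY")?) // assumed env var, not from the repo
            .with_api_base("https://api.x.ai/v1");       // public xAI endpoint
        let client = Client::with_config(config);

        // One plain-text user message; the model name is illustrative.
        let request = CreateChatCompletionRequestArgs::default()
            .model("grok-2-latest")
            .messages(vec![ChatCompletionRequestMessage::User(
                ChatCompletionRequestUserMessage {
                    content: ChatCompletionRequestUserMessageContent::Text(
                        "Say hello in one sentence.".into(),
                    ),
                    name: None,
                },
            )])
            .build()?;

        // Extract the first choice's content, mirroring chat_completion above.
        let response = client.chat().create(request).await?;
        if let Some(choice) = response.choices.first() {
            println!("{}", choice.message.content.clone().unwrap_or_default());
        }
        Ok(())
    }

Reusing the OpenAI client also explains why the diff drops the stored _api_key and the 30-second reqwest timeout: authentication and transport now live inside the async_openai client configuration.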
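Reviewer note (pricing): as a sanity check on the new fallback rates of 5.0 / 15.0 USD per million tokens, a request that used 1,000 prompt tokens and 500 completion tokens costs 1_000 * 5.0 / 1_000_000 + 500 * 15.0 / 1_000_000 = 0.005 + 0.0075 = $0.0125, which is what calculate_cost returns when the model appears in neither the registry nor the static pricing table.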