fix: update Grok provider to be OpenAI-compatible with vision and streaming support
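Replace the placeholder Grok provider with a working implementation built on the async-openai client: point an OpenAIConfig at the configured base URL, convert unified text and image message parts into OpenAI chat messages (images as base64 data URLs), implement both chat_completion and chat_completion_stream, report real token usage, and update the static pricing fallback. main.rs now imports OllamaProvider directly.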
@@ -13,6 +13,7 @@ use llm_proxy::{
         gemini::GeminiProvider,
         deepseek::DeepSeekProvider,
         grok::GrokProvider,
+        ollama::OllamaProvider,
     },
     database,
     server,
@@ -87,7 +88,7 @@ async fn main() -> Result<()> {

     // Initialize Ollama
     if config.providers.ollama.enabled {
-        match llm_proxy::providers::ollama::OllamaProvider::new(&config.providers.ollama, &config) {
+        match OllamaProvider::new(&config.providers.ollama, &config) {
             Ok(p) => {
                 provider_manager.add_provider(Arc::new(p));
                 info!("Ollama provider initialized at {}", config.providers.ollama.base_url);
@@ -1,6 +1,8 @@
 use async_trait::async_trait;
 use anyhow::Result;
-use futures::stream::BoxStream;
+use async_openai::{Client, config::OpenAIConfig};
+use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
+use futures::stream::{BoxStream, StreamExt};

 use crate::{
     models::UnifiedRequest,
@@ -10,9 +12,8 @@ use crate::{
 use super::{ProviderResponse, ProviderStreamChunk};

 pub struct GrokProvider {
-    _client: reqwest::Client,
+    client: Client<OpenAIConfig>,
     _config: crate::config::GrokConfig,
-    _api_key: String,
     pricing: Vec<crate::config::ModelPricing>,
 }

@@ -20,14 +21,16 @@ impl GrokProvider {
     pub fn new(config: &crate::config::GrokConfig, app_config: &AppConfig) -> Result<Self> {
         let api_key = app_config.get_api_key("grok")?;

-        let client = reqwest::Client::builder()
-            .timeout(std::time::Duration::from_secs(30))
-            .build()?;
+        // Grok is OpenAI-compatible
+        let openai_config = OpenAIConfig::default()
+            .with_api_key(api_key)
+            .with_api_base(&config.base_url);
+
+        let client = Client::with_config(openai_config);

         Ok(Self {
-            _client: client,
+            client,
             _config: config.clone(),
-            _api_key: api_key,
             pricing: app_config.pricing.grok.clone(),
         })
     }
@@ -40,24 +43,121 @@ impl super::Provider for GrokProvider {
     }

     fn supports_model(&self, model: &str) -> bool {
-        model.starts_with("grok-") || model.contains("grok")
+        model.starts_with("grok-")
     }

     fn supports_multimodal(&self) -> bool {
-        false // Unknown - assume false until API is researched
+        true // Grok supports vision models
     }

     async fn chat_completion(
         &self,
         request: UnifiedRequest,
     ) -> Result<ProviderResponse, AppError> {
-        // TODO: Implement actual Grok API call (once API is available)
-        // For now, return placeholder response
+        use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
+
+        // Convert UnifiedRequest messages to OpenAI messages
+        let mut messages = Vec::with_capacity(request.messages.len());
+
+        for msg in request.messages {
+            let mut parts = Vec::with_capacity(msg.content.len());
+
+            for part in msg.content {
+                match part {
+                    crate::models::ContentPart::Text { text } => {
+                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
+                            text,
+                        }));
+                    }
+                    crate::models::ContentPart::Image(image_input) => {
+                        let (base64_data, mime_type) = image_input.to_base64().await
+                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
+                        let data_url = format!("data:{};base64,{}", mime_type, base64_data);
+
+                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
+                            image_url: ImageUrl {
+                                url: data_url,
+                                detail: Some(ImageDetail::Auto),
+                            }
+                        }));
+                    }
+                }
+            }
+
+            let message = match msg.role.as_str() {
+                "system" => ChatCompletionRequestMessage::System(
+                    ChatCompletionRequestSystemMessage {
+                        content: ChatCompletionRequestSystemMessageContent::Text(
+                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
+                        ),
+                        name: None,
+                    }
+                ),
+                "assistant" => ChatCompletionRequestMessage::Assistant(
+                    ChatCompletionRequestAssistantMessage {
+                        content: Some(ChatCompletionRequestAssistantMessageContent::Text(
+                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
+                        )),
+                        name: None,
+                        tool_calls: None,
+                        refusal: None,
+                        audio: None,
+                        #[allow(deprecated)]
+                        function_call: None,
+                    }
+                ),
+                _ => ChatCompletionRequestMessage::User(
+                    ChatCompletionRequestUserMessage {
+                        content: ChatCompletionRequestUserMessageContent::Array(parts),
+                        name: None,
+                    }
+                ),
+            };
+            messages.push(message);
+        }
+
+        if messages.is_empty() {
+            return Err(AppError::ProviderError("No valid text messages to send".to_string()));
+        }
+
+        // Build request using builder pattern
+        let mut builder = CreateChatCompletionRequestArgs::default();
+        builder.model(request.model.clone());
+        builder.messages(messages);
+
+        // Add optional parameters
+        if let Some(temp) = request.temperature {
+            builder.temperature(temp as f32);
+        }
+
+        if let Some(max_tokens) = request.max_tokens {
+            builder.max_tokens(max_tokens as u16);
+        }
+
+        // Execute API call
+        let response = self.client
+            .chat()
+            .create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
+            .await
+            .map_err(|e| AppError::ProviderError(e.to_string()))?;
+
+        // Extract content from response
+        let content = response
+            .choices
+            .first()
+            .and_then(|choice| choice.message.content.clone())
+            .unwrap_or_default();
+
+        // Extract token usage
+        let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
+        let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
+        let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
+
         Ok(ProviderResponse {
-            content: "Grok provider not yet implemented (API not researched)".to_string(),
-            prompt_tokens: 0,
-            completion_tokens: 0,
-            total_tokens: 0,
+            content,
+            prompt_tokens,
+            completion_tokens,
+            total_tokens,
             model: request.model,
         })
     }
@@ -74,19 +174,131 @@ impl super::Provider for GrokProvider {
             }
         }

+        // Fallback to static pricing if not in registry
         let (prompt_rate, completion_rate) = self.pricing.iter()
             .find(|p| model.contains(&p.model))
             .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
-            .unwrap_or((1.0, 3.0)); // Default to some reasonable Grok price if not found
+            .unwrap_or((5.0, 15.0)); // Grok-2 pricing is roughly this

         (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
     }

     async fn chat_completion_stream(
         &self,
-        _request: UnifiedRequest,
+        request: UnifiedRequest,
     ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
-        // Grok API not yet implemented
-        Err(AppError::ProviderError("Streaming not supported for Grok provider (API not implemented)".to_string()))
+        use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
+
+        // Convert UnifiedRequest messages to OpenAI messages
+        let mut messages = Vec::with_capacity(request.messages.len());
+
+        for msg in request.messages {
+            let mut parts = Vec::with_capacity(msg.content.len());
+
+            for part in msg.content {
+                match part {
+                    crate::models::ContentPart::Text { text } => {
+                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
+                            text,
+                        }));
+                    }
+                    crate::models::ContentPart::Image(image_input) => {
+                        let (base64_data, mime_type) = image_input.to_base64().await
+                            .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
+                        let data_url = format!("data:{};base64,{}", mime_type, base64_data);
+
+                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
+                            image_url: ImageUrl {
+                                url: data_url,
+                                detail: Some(ImageDetail::Auto),
+                            }
+                        }));
+                    }
+                }
+            }
+
+            let message = match msg.role.as_str() {
+                "system" => ChatCompletionRequestMessage::System(
+                    ChatCompletionRequestSystemMessage {
+                        content: ChatCompletionRequestSystemMessageContent::Text(
+                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
+                        ),
+                        name: None,
+                    }
+                ),
+                "assistant" => ChatCompletionRequestMessage::Assistant(
+                    ChatCompletionRequestAssistantMessage {
+                        content: Some(ChatCompletionRequestAssistantMessageContent::Text(
+                            parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("\n")
+                        )),
+                        name: None,
+                        tool_calls: None,
+                        refusal: None,
+                        audio: None,
+                        #[allow(deprecated)]
+                        function_call: None,
+                    }
+                ),
+                _ => ChatCompletionRequestMessage::User(
+                    ChatCompletionRequestUserMessage {
+                        content: ChatCompletionRequestUserMessageContent::Array(parts),
+                        name: None,
+                    }
+                ),
+            };
+            messages.push(message);
+        }
+
+        if messages.is_empty() {
+            return Err(AppError::ProviderError("No valid text messages to send".to_string()));
+        }
+
+        // Build request using builder pattern
+        let mut builder = CreateChatCompletionRequestArgs::default();
+        builder.model(request.model.clone());
+        builder.messages(messages);
+        builder.stream(true); // Enable streaming
+
+        // Add optional parameters
+        if let Some(temp) = request.temperature {
+            builder.temperature(temp as f32);
+        }
+
+        if let Some(max_tokens) = request.max_tokens {
+            builder.max_tokens(max_tokens as u16);
+        }
+
+        // Execute streaming API call
+        let stream = self.client
+            .chat()
+            .create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
+            .await
+            .map_err(|e| AppError::ProviderError(e.to_string()))?;
+
+        // Convert OpenAI stream to our stream format
+        let model = request.model.clone();
+        let stream = stream.map(move |chunk_result| {
+            match chunk_result {
+                Ok(chunk) => {
+                    // Extract content from chunk
+                    let content = chunk.choices.first()
+                        .and_then(|choice| choice.delta.content.clone())
+                        .unwrap_or_default();
+
+                    let finish_reason = chunk.choices.first()
+                        .and_then(|choice| choice.finish_reason.clone())
+                        .map(|reason| format!("{:?}", reason));
+
+                    Ok(ProviderStreamChunk {
+                        content,
+                        finish_reason,
+                        model: model.clone(),
+                    })
+                }
+                Err(e) => Err(AppError::ProviderError(e.to_string())),
+            }
+        });
+
+        Ok(Box::pin(stream))
     }
 }
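
Note: below is a minimal sketch of how a caller might drive the new streaming path. It assumes an already-constructed GrokProvider and UnifiedRequest with the Provider trait in scope; the print_stream helper is hypothetical and not part of this commit.

use futures::StreamExt; // for .next() on the returned BoxStream

// Hypothetical helper: drain the stream, printing deltas as they arrive.
// `provider` and `request` are assumed to be built elsewhere.
async fn print_stream(
    provider: &GrokProvider,
    request: UnifiedRequest,
) -> Result<(), AppError> {
    // chat_completion_stream yields Result<ProviderStreamChunk, AppError> items
    let mut stream = provider.chat_completion_stream(request).await?;
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?; // each item may carry a provider error
        print!("{}", chunk.content);
        if let Some(reason) = chunk.finish_reason {
            eprintln!("\n[finish_reason: {}]", reason);
        }
    }
    Ok(())
}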
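On the pricing fallback: rates are dollars per million tokens, so with the new default of (5.0, 15.0), a call using 2,000 prompt tokens and 500 completion tokens would cost 2,000 * 5.0 / 1,000,000 + 500 * 15.0 / 1,000,000 = 0.01 + 0.0075 = $0.0175. The token counts here are illustrative, not taken from the commit.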