feat: add Ollama provider support and dashboard integration
This commit is contained in:
@@ -13,6 +13,11 @@ DEEPSEEK_API_KEY=your_deepseek_api_key_here
|
|||||||
# xAI Grok (not yet available)
|
# xAI Grok (not yet available)
|
||||||
GROK_API_KEY=your_grok_api_key_here
|
GROK_API_KEY=your_grok_api_key_here
|
||||||
|
|
||||||
|
# Ollama (local server)
|
||||||
|
# LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL=http://your-ollama-host:11434/v1
|
||||||
|
# LLM_PROXY__PROVIDERS__OLLAMA__ENABLED=true
|
||||||
|
# LLM_PROXY__PROVIDERS__OLLAMA__MODELS=llama3,mistral,llava
|
||||||
|
|
||||||
# Authentication tokens (comma-separated list)
|
# Authentication tokens (comma-separated list)
|
||||||
LLM_PROXY__SERVER__AUTH_TOKENS=your_bearer_token_here,another_token
|
LLM_PROXY__SERVER__AUTH_TOKENS=your_bearer_token_here,another_token
|
||||||
|
|
||||||
|
|||||||
5
.gitignore
vendored
Normal file
5
.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
/target
|
||||||
|
/.env
|
||||||
|
/*.db
|
||||||
|
/*.db-shm
|
||||||
|
/*.db-wal
|
||||||
@@ -10,6 +10,7 @@ A unified, high-performance LLM proxy gateway built in Rust. It provides a singl
|
|||||||
* **Google Gemini:** Support for the latest Gemini 2.0 models.
|
* **Google Gemini:** Support for the latest Gemini 2.0 models.
|
||||||
* **DeepSeek:** High-performance, low-cost integration.
|
* **DeepSeek:** High-performance, low-cost integration.
|
||||||
* **xAI Grok:** Integration for Grok-series models.
|
* **xAI Grok:** Integration for Grok-series models.
|
||||||
|
* **Ollama:** Support for local LLMs running on your machine or another host.
|
||||||
- **Observability & Tracking:**
|
- **Observability & Tracking:**
|
||||||
* **Real-time Costing:** Fetches live pricing and context specs from `models.dev` on startup.
|
* **Real-time Costing:** Fetches live pricing and context specs from `models.dev` on startup.
|
||||||
* **Token Counting:** Precise estimation using `tiktoken-rs`.
|
* **Token Counting:** Precise estimation using `tiktoken-rs`.
|
||||||
@@ -52,6 +53,14 @@ A unified, high-performance LLM proxy gateway built in Rust. It provides a singl
|
|||||||
3. Configure providers and server:
|
3. Configure providers and server:
|
||||||
Edit `config.toml` to customize models, pricing fallbacks, and port settings.
|
Edit `config.toml` to customize models, pricing fallbacks, and port settings.
|
||||||
|
|
||||||
|
**Ollama Example (config.toml):**
|
||||||
|
```toml
|
||||||
|
[providers.ollama]
|
||||||
|
enabled = true
|
||||||
|
base_url = "http://192.168.1.50:11434/v1"
|
||||||
|
models = ["llama3", "mistral"]
|
||||||
|
```
|
||||||
|
|
||||||
4. Run the proxy:
|
4. Run the proxy:
|
||||||
```bash
|
```bash
|
||||||
cargo run --release
|
cargo run --release
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ pub struct ProviderConfig {
|
|||||||
pub gemini: GeminiConfig,
|
pub gemini: GeminiConfig,
|
||||||
pub deepseek: DeepSeekConfig,
|
pub deepseek: DeepSeekConfig,
|
||||||
pub grok: GrokConfig,
|
pub grok: GrokConfig,
|
||||||
|
pub ollama: OllamaConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -57,6 +58,13 @@ pub struct GrokConfig {
|
|||||||
pub enabled: bool,
|
pub enabled: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct OllamaConfig {
|
||||||
|
pub base_url: String,
|
||||||
|
pub enabled: bool,
|
||||||
|
pub models: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ModelMappingConfig {
|
pub struct ModelMappingConfig {
|
||||||
pub patterns: Vec<(String, String)>,
|
pub patterns: Vec<(String, String)>,
|
||||||
@@ -68,6 +76,7 @@ pub struct PricingConfig {
|
|||||||
pub gemini: Vec<ModelPricing>,
|
pub gemini: Vec<ModelPricing>,
|
||||||
pub deepseek: Vec<ModelPricing>,
|
pub deepseek: Vec<ModelPricing>,
|
||||||
pub grok: Vec<ModelPricing>,
|
pub grok: Vec<ModelPricing>,
|
||||||
|
pub ollama: Vec<ModelPricing>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -119,7 +128,10 @@ impl AppConfig {
|
|||||||
.set_default("providers.grok.api_key_env", "GROK_API_KEY")?
|
.set_default("providers.grok.api_key_env", "GROK_API_KEY")?
|
||||||
.set_default("providers.grok.base_url", "https://api.x.ai/v1")?
|
.set_default("providers.grok.base_url", "https://api.x.ai/v1")?
|
||||||
.set_default("providers.grok.default_model", "grok-beta")?
|
.set_default("providers.grok.default_model", "grok-beta")?
|
||||||
.set_default("providers.grok.enabled", false)?; // Disabled by default until API is researched
|
.set_default("providers.grok.enabled", false)?
|
||||||
|
.set_default("providers.ollama.base_url", "http://localhost:11434/v1")?
|
||||||
|
.set_default("providers.ollama.enabled", false)?
|
||||||
|
.set_default("providers.ollama.models", Vec::<String>::new())?;
|
||||||
|
|
||||||
// Load from config file if exists
|
// Load from config file if exists
|
||||||
let config_path = config_path.unwrap_or_else(|| std::env::current_dir().unwrap().join("config.toml"));
|
let config_path = config_path.unwrap_or_else(|| std::env::current_dir().unwrap().join("config.toml"));
|
||||||
@@ -151,6 +163,7 @@ impl AppConfig {
|
|||||||
gemini: vec![],
|
gemini: vec![],
|
||||||
deepseek: vec![],
|
deepseek: vec![],
|
||||||
grok: vec![],
|
grok: vec![],
|
||||||
|
ollama: vec![],
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Arc::new(AppConfig {
|
Ok(Arc::new(AppConfig {
|
||||||
|
|||||||
@@ -500,6 +500,16 @@ async fn handle_get_providers(State(state): State<DashboardState>) -> Json<ApiRe
|
|||||||
"last_used": null, // TODO: track last used
|
"last_used": null, // TODO: track last used
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add Ollama explicitly
|
||||||
|
providers_json.push(serde_json::json!({
|
||||||
|
"id": "ollama",
|
||||||
|
"name": "Ollama",
|
||||||
|
"enabled": true,
|
||||||
|
"status": "online",
|
||||||
|
"models": ["llama3", "mistral", "phi3"],
|
||||||
|
"last_used": null,
|
||||||
|
}));
|
||||||
|
|
||||||
Json(ApiResponse::success(serde_json::json!(providers_json)))
|
Json(ApiResponse::success(serde_json::json!(providers_json)))
|
||||||
}
|
}
|
||||||
@@ -538,6 +548,13 @@ async fn handle_system_health(State(state): State<DashboardState>) -> Json<ApiRe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check Ollama health
|
||||||
|
if state.app_state.rate_limit_manager.check_provider_request("ollama").await.unwrap_or(true) {
|
||||||
|
components.insert("ollama", "online");
|
||||||
|
} else {
|
||||||
|
components.insert("ollama", "degraded");
|
||||||
|
}
|
||||||
|
|
||||||
Json(ApiResponse::success(serde_json::json!({
|
Json(ApiResponse::success(serde_json::json!({
|
||||||
"status": "healthy",
|
"status": "healthy",
|
||||||
"timestamp": chrono::Utc::now().to_rfc3339(),
|
"timestamp": chrono::Utc::now().to_rfc3339(),
|
||||||
|
|||||||
11
src/main.rs
11
src/main.rs
@@ -85,6 +85,17 @@ async fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize Ollama
|
||||||
|
if config.providers.ollama.enabled {
|
||||||
|
match llm_proxy::providers::ollama::OllamaProvider::new(&config.providers.ollama, &config) {
|
||||||
|
Ok(p) => {
|
||||||
|
provider_manager.add_provider(Arc::new(p));
|
||||||
|
info!("Ollama provider initialized at {}", config.providers.ollama.base_url);
|
||||||
|
}
|
||||||
|
Err(e) => error!("Failed to initialize Ollama provider: {}", e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create rate limit manager
|
// Create rate limit manager
|
||||||
let rate_limit_manager = RateLimitManager::new(
|
let rate_limit_manager = RateLimitManager::new(
|
||||||
RateLimiterConfig::default(),
|
RateLimiterConfig::default(),
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ pub mod openai;
|
|||||||
pub mod gemini;
|
pub mod gemini;
|
||||||
pub mod deepseek;
|
pub mod deepseek;
|
||||||
pub mod grok;
|
pub mod grok;
|
||||||
|
pub mod ollama;
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Provider: Send + Sync {
|
pub trait Provider: Send + Sync {
|
||||||
|
|||||||
313
src/providers/ollama.rs
Normal file
313
src/providers/ollama.rs
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
use async_trait::async_trait;
|
||||||
|
use anyhow::Result;
|
||||||
|
use async_openai::{Client, config::OpenAIConfig};
|
||||||
|
use async_openai::types::chat::{CreateChatCompletionRequestArgs, ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessageContent, ChatCompletionRequestSystemMessageContent, ChatCompletionRequestAssistantMessageContent};
|
||||||
|
use futures::stream::{BoxStream, StreamExt};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
models::UnifiedRequest,
|
||||||
|
errors::AppError,
|
||||||
|
config::AppConfig,
|
||||||
|
};
|
||||||
|
use super::{ProviderResponse, ProviderStreamChunk};
|
||||||
|
|
||||||
|
pub struct OllamaProvider {
|
||||||
|
client: Client<OpenAIConfig>,
|
||||||
|
_config: crate::config::OllamaConfig,
|
||||||
|
pricing: Vec<crate::config::ModelPricing>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OllamaProvider {
|
||||||
|
pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
|
||||||
|
// Ollama usually doesn't need an API key, use a dummy one
|
||||||
|
let openai_config = OpenAIConfig::default()
|
||||||
|
.with_api_key("ollama")
|
||||||
|
.with_api_base(&config.base_url);
|
||||||
|
|
||||||
|
let client = Client::with_config(openai_config);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
client,
|
||||||
|
_config: config.clone(),
|
||||||
|
pricing: app_config.pricing.ollama.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl super::Provider for OllamaProvider {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"ollama"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supports_model(&self, model: &str) -> bool {
|
||||||
|
// Check if model is in the list of configured Ollama models
|
||||||
|
self._config.models.iter().any(|m| m == model) || model.starts_with("ollama/")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supports_multimodal(&self) -> bool {
|
||||||
|
true // Many Ollama models support vision (e.g. llava, moondream)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn chat_completion(
|
||||||
|
&self,
|
||||||
|
request: UnifiedRequest,
|
||||||
|
) -> Result<ProviderResponse, AppError> {
|
||||||
|
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
|
||||||
|
|
||||||
|
// Strip "ollama/" prefix if present
|
||||||
|
let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
|
||||||
|
|
||||||
|
// Convert UnifiedRequest messages to OpenAI messages
|
||||||
|
let mut messages = Vec::with_capacity(request.messages.len());
|
||||||
|
|
||||||
|
for msg in request.messages {
|
||||||
|
let mut parts = Vec::with_capacity(msg.content.len());
|
||||||
|
|
||||||
|
for part in msg.content {
|
||||||
|
match part {
|
||||||
|
crate::models::ContentPart::Text { text } => {
|
||||||
|
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
|
||||||
|
text,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
crate::models::ContentPart::Image(image_input) => {
|
||||||
|
let (base64_data, mime_type) = image_input.to_base64().await
|
||||||
|
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
||||||
|
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
|
||||||
|
|
||||||
|
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
|
||||||
|
image_url: ImageUrl {
|
||||||
|
url: data_url,
|
||||||
|
detail: Some(ImageDetail::Auto),
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let message = match msg.role.as_str() {
|
||||||
|
"system" => ChatCompletionRequestMessage::System(
|
||||||
|
ChatCompletionRequestSystemMessage {
|
||||||
|
content: ChatCompletionRequestSystemMessageContent::Text(
|
||||||
|
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("
|
||||||
|
")
|
||||||
|
),
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"assistant" => ChatCompletionRequestMessage::Assistant(
|
||||||
|
ChatCompletionRequestAssistantMessage {
|
||||||
|
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
|
||||||
|
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("
|
||||||
|
")
|
||||||
|
)),
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
refusal: None,
|
||||||
|
audio: None,
|
||||||
|
#[allow(deprecated)]
|
||||||
|
function_call: None,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
_ => ChatCompletionRequestMessage::User(
|
||||||
|
ChatCompletionRequestUserMessage {
|
||||||
|
content: ChatCompletionRequestUserMessageContent::Array(parts),
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
};
|
||||||
|
messages.push(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
if messages.is_empty() {
|
||||||
|
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build request using builder pattern
|
||||||
|
let mut builder = CreateChatCompletionRequestArgs::default();
|
||||||
|
builder.model(model);
|
||||||
|
builder.messages(messages);
|
||||||
|
|
||||||
|
// Add optional parameters
|
||||||
|
if let Some(temp) = request.temperature {
|
||||||
|
builder.temperature(temp as f32);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(max_tokens) = request.max_tokens {
|
||||||
|
builder.max_tokens(max_tokens as u16);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute API call
|
||||||
|
let response = self.client
|
||||||
|
.chat()
|
||||||
|
.create(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||||
|
|
||||||
|
// Extract content from response
|
||||||
|
let content = response
|
||||||
|
.choices
|
||||||
|
.first()
|
||||||
|
.and_then(|choice| choice.message.content.clone())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
// Extract token usage
|
||||||
|
let prompt_tokens = response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0) as u32;
|
||||||
|
let completion_tokens = response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0) as u32;
|
||||||
|
let total_tokens = response.usage.as_ref().map(|u| u.total_tokens).unwrap_or(0) as u32;
|
||||||
|
|
||||||
|
Ok(ProviderResponse {
|
||||||
|
content,
|
||||||
|
prompt_tokens,
|
||||||
|
completion_tokens,
|
||||||
|
total_tokens,
|
||||||
|
model: request.model,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
||||||
|
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
|
||||||
|
if let Some(metadata) = registry.find_model(model) {
|
||||||
|
if let Some(cost) = &metadata.cost {
|
||||||
|
return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
|
||||||
|
(completion_tokens as f64 * cost.output / 1_000_000.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ollama is free by default
|
||||||
|
let (prompt_rate, completion_rate) = self.pricing.iter()
|
||||||
|
.find(|p| model.contains(&p.model))
|
||||||
|
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
||||||
|
.unwrap_or((0.0, 0.0));
|
||||||
|
|
||||||
|
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn chat_completion_stream(
|
||||||
|
&self,
|
||||||
|
request: UnifiedRequest,
|
||||||
|
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||||
|
use async_openai::types::chat::{ChatCompletionRequestUserMessageContentPart, ChatCompletionRequestMessageContentPartText, ChatCompletionRequestMessageContentPartImage, ImageUrl, ImageDetail};
|
||||||
|
|
||||||
|
// Strip "ollama/" prefix if present
|
||||||
|
let model = request.model.strip_prefix("ollama/").unwrap_or(&request.model).to_string();
|
||||||
|
|
||||||
|
// Convert UnifiedRequest messages to OpenAI messages
|
||||||
|
let mut messages = Vec::with_capacity(request.messages.len());
|
||||||
|
|
||||||
|
for msg in request.messages {
|
||||||
|
let mut parts = Vec::with_capacity(msg.content.len());
|
||||||
|
|
||||||
|
for part in msg.content {
|
||||||
|
match part {
|
||||||
|
crate::models::ContentPart::Text { text } => {
|
||||||
|
parts.push(ChatCompletionRequestUserMessageContentPart::Text(ChatCompletionRequestMessageContentPartText {
|
||||||
|
text,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
crate::models::ContentPart::Image(image_input) => {
|
||||||
|
let (base64_data, mime_type) = image_input.to_base64().await
|
||||||
|
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
||||||
|
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
|
||||||
|
|
||||||
|
parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(ChatCompletionRequestMessageContentPartImage {
|
||||||
|
image_url: ImageUrl {
|
||||||
|
url: data_url,
|
||||||
|
detail: Some(ImageDetail::Auto),
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let message = match msg.role.as_str() {
|
||||||
|
"system" => ChatCompletionRequestMessage::System(
|
||||||
|
ChatCompletionRequestSystemMessage {
|
||||||
|
content: ChatCompletionRequestSystemMessageContent::Text(
|
||||||
|
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("
|
||||||
|
")
|
||||||
|
),
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
"assistant" => ChatCompletionRequestMessage::Assistant(
|
||||||
|
ChatCompletionRequestAssistantMessage {
|
||||||
|
content: Some(ChatCompletionRequestAssistantMessageContent::Text(
|
||||||
|
parts.iter().filter_map(|p| if let ChatCompletionRequestUserMessageContentPart::Text(t) = p { Some(t.text.clone()) } else { None }).collect::<Vec<_>>().join("
|
||||||
|
")
|
||||||
|
)),
|
||||||
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
refusal: None,
|
||||||
|
audio: None,
|
||||||
|
#[allow(deprecated)]
|
||||||
|
function_call: None,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
_ => ChatCompletionRequestMessage::User(
|
||||||
|
ChatCompletionRequestUserMessage {
|
||||||
|
content: ChatCompletionRequestUserMessageContent::Array(parts),
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
};
|
||||||
|
messages.push(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
if messages.is_empty() {
|
||||||
|
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build request using builder pattern
|
||||||
|
let mut builder = CreateChatCompletionRequestArgs::default();
|
||||||
|
builder.model(model);
|
||||||
|
builder.messages(messages);
|
||||||
|
builder.stream(true); // Enable streaming
|
||||||
|
|
||||||
|
// Add optional parameters
|
||||||
|
if let Some(temp) = request.temperature {
|
||||||
|
builder.temperature(temp as f32);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(max_tokens) = request.max_tokens {
|
||||||
|
builder.max_tokens(max_tokens as u16);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute streaming API call
|
||||||
|
let stream = self.client
|
||||||
|
.chat()
|
||||||
|
.create_stream(builder.build().map_err(|e| AppError::ProviderError(e.to_string()))?)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||||
|
|
||||||
|
// Convert OpenAI stream to our stream format
|
||||||
|
let model_name = request.model.clone();
|
||||||
|
let stream = stream.map(move |chunk_result| {
|
||||||
|
match chunk_result {
|
||||||
|
Ok(chunk) => {
|
||||||
|
// Extract content from chunk
|
||||||
|
let content = chunk.choices.first()
|
||||||
|
.and_then(|choice| choice.delta.content.clone())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let finish_reason = chunk.choices.first()
|
||||||
|
.and_then(|choice| choice.finish_reason.clone())
|
||||||
|
.map(|reason| format!("{:?}", reason));
|
||||||
|
|
||||||
|
Ok(ProviderStreamChunk {
|
||||||
|
content,
|
||||||
|
finish_reason,
|
||||||
|
model: model_name.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Err(e) => Err(AppError::ProviderError(e.to_string())),
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(Box::pin(stream))
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user