//! Gemini provider implementation.
//!
//! Change log (from the original commit message):
//! - Added `provider_configs` and `model_configs` tables to the database.
//! - Refactored `ProviderManager` to support thread-safe dynamic updates and database overrides.
//! - Implemented a "Models" tab in the dashboard to manage model visibility, mapping, and pricing.
//! - Added a provider-configuration modal to the "Providers" tab.
//! - Integrated database overrides into chat-completion logic (enabled state, mapping, and cost).
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use serde::{Deserialize, Serialize};

use crate::{
    config::AppConfig,
    errors::AppError,
    models::UnifiedRequest,
};

use super::{ProviderResponse, ProviderStreamChunk};
#[derive(Debug, Serialize)]
|
|
struct GeminiRequest {
|
|
contents: Vec<GeminiContent>,
|
|
generation_config: Option<GeminiGenerationConfig>,
|
|
}
|
|
|
|
/// A single conversation turn in Gemini's `contents` array.
///
/// Used on both the request side (Serialize) and when reading candidates
/// back out of responses (Deserialize).
#[derive(Debug, Serialize, Deserialize)]
struct GeminiContent {
    // Ordered message fragments (text and/or inline image data).
    parts: Vec<GeminiPart>,
    // Either "user" or "model" — see the role mapping in `chat_completion`.
    role: String,
}
#[derive(Debug, Serialize, Deserialize)]
|
|
struct GeminiPart {
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
text: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
inline_data: Option<GeminiInlineData>,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
struct GeminiInlineData {
|
|
mime_type: String,
|
|
data: String,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
struct GeminiGenerationConfig {
|
|
temperature: Option<f64>,
|
|
max_output_tokens: Option<u32>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct GeminiCandidate {
|
|
content: GeminiContent,
|
|
_finish_reason: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct GeminiUsageMetadata {
|
|
prompt_token_count: u32,
|
|
candidates_token_count: u32,
|
|
total_token_count: u32,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct GeminiResponse {
|
|
candidates: Vec<GeminiCandidate>,
|
|
usage_metadata: Option<GeminiUsageMetadata>,
|
|
}
|
|
|
|
|
|
|
|
/// Provider backend for Google Gemini models.
pub struct GeminiProvider {
    // Shared HTTP client (30s timeout, see `new_with_key`).
    client: reqwest::Client,
    // Static provider settings (notably `base_url`).
    config: crate::config::GeminiConfig,
    // API key resolved from AppConfig or supplied explicitly via `new_with_key`.
    api_key: String,
    // Fallback per-model pricing from the static config; the model registry
    // (database overrides) takes precedence in `calculate_cost`.
    pricing: Vec<crate::config::ModelPricing>,
}
impl GeminiProvider {
|
|
pub fn new(config: &crate::config::GeminiConfig, app_config: &AppConfig) -> Result<Self> {
|
|
let api_key = app_config.get_api_key("gemini")?;
|
|
Self::new_with_key(config, app_config, api_key)
|
|
}
|
|
|
|
pub fn new_with_key(config: &crate::config::GeminiConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
|
let client = reqwest::Client::builder()
|
|
.timeout(std::time::Duration::from_secs(30))
|
|
.build()?;
|
|
|
|
Ok(Self {
|
|
client,
|
|
config: config.clone(),
|
|
api_key,
|
|
pricing: app_config.pricing.gemini.clone(),
|
|
})
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl super::Provider for GeminiProvider {
|
|
fn name(&self) -> &str {
|
|
"gemini"
|
|
}
|
|
|
|
fn supports_model(&self, model: &str) -> bool {
|
|
model.starts_with("gemini-")
|
|
}
|
|
|
|
fn supports_multimodal(&self) -> bool {
|
|
true // Gemini supports vision
|
|
}
|
|
|
|
async fn chat_completion(
|
|
&self,
|
|
request: UnifiedRequest,
|
|
) -> Result<ProviderResponse, AppError> {
|
|
// Convert UnifiedRequest to Gemini request
|
|
let mut contents = Vec::with_capacity(request.messages.len());
|
|
|
|
for msg in request.messages {
|
|
let mut parts = Vec::with_capacity(msg.content.len());
|
|
|
|
for part in msg.content {
|
|
match part {
|
|
crate::models::ContentPart::Text { text } => {
|
|
parts.push(GeminiPart {
|
|
text: Some(text),
|
|
inline_data: None,
|
|
});
|
|
}
|
|
crate::models::ContentPart::Image(image_input) => {
|
|
let (base64_data, mime_type) = image_input.to_base64().await
|
|
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
|
|
|
parts.push(GeminiPart {
|
|
text: None,
|
|
inline_data: Some(GeminiInlineData {
|
|
mime_type,
|
|
data: base64_data,
|
|
}),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
// Map role: "user" -> "user", "assistant" -> "model", "system" -> "user"
|
|
let role = match msg.role.as_str() {
|
|
"assistant" => "model".to_string(),
|
|
_ => "user".to_string(),
|
|
};
|
|
|
|
contents.push(GeminiContent {
|
|
parts,
|
|
role,
|
|
});
|
|
}
|
|
|
|
if contents.is_empty() {
|
|
return Err(AppError::ProviderError("No valid text messages to send".to_string()));
|
|
}
|
|
|
|
// Build generation config
|
|
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
|
|
Some(GeminiGenerationConfig {
|
|
temperature: request.temperature,
|
|
max_output_tokens: request.max_tokens,
|
|
})
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let gemini_request = GeminiRequest {
|
|
contents,
|
|
generation_config,
|
|
};
|
|
|
|
// Build URL
|
|
let url = format!("{}/models/{}:generateContent?key={}",
|
|
self.config.base_url,
|
|
request.model,
|
|
self.api_key
|
|
);
|
|
|
|
// Send request
|
|
let response = self.client
|
|
.post(&url)
|
|
.json(&gemini_request)
|
|
.send()
|
|
.await
|
|
.map_err(|e| AppError::ProviderError(format!("HTTP request failed: {}", e)))?;
|
|
|
|
// Check status
|
|
let status = response.status();
|
|
if !status.is_success() {
|
|
let error_text = response.text().await.unwrap_or_default();
|
|
return Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, error_text)));
|
|
}
|
|
|
|
let gemini_response: GeminiResponse = response
|
|
.json()
|
|
.await
|
|
.map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?;
|
|
|
|
// Extract content from first candidate
|
|
let content = gemini_response.candidates
|
|
.first()
|
|
.and_then(|c| c.content.parts.first())
|
|
.and_then(|p| p.text.clone())
|
|
.unwrap_or_default();
|
|
|
|
// Extract token usage
|
|
let prompt_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.prompt_token_count).unwrap_or(0);
|
|
let completion_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.candidates_token_count).unwrap_or(0);
|
|
let total_tokens = gemini_response.usage_metadata.as_ref().map(|u| u.total_token_count).unwrap_or(0);
|
|
|
|
Ok(ProviderResponse {
|
|
content,
|
|
reasoning_content: None, // Gemini doesn't use this field name
|
|
prompt_tokens,
|
|
completion_tokens,
|
|
total_tokens,
|
|
model: request.model,
|
|
})
|
|
}
|
|
|
|
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
|
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
|
}
|
|
|
|
fn calculate_cost(&self, model: &str, prompt_tokens: u32, completion_tokens: u32, registry: &crate::models::registry::ModelRegistry) -> f64 {
|
|
if let Some(metadata) = registry.find_model(model) {
|
|
if let Some(cost) = &metadata.cost {
|
|
return (prompt_tokens as f64 * cost.input / 1_000_000.0) +
|
|
(completion_tokens as f64 * cost.output / 1_000_000.0);
|
|
}
|
|
}
|
|
|
|
let (prompt_rate, completion_rate) = self.pricing.iter()
|
|
.find(|p| model.contains(&p.model))
|
|
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
|
.unwrap_or((0.075, 0.30)); // Default to Gemini 2.0 Flash price if not found
|
|
|
|
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
|
}
|
|
|
|
async fn chat_completion_stream(
|
|
&self,
|
|
request: UnifiedRequest,
|
|
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
|
// Convert UnifiedRequest to Gemini request
|
|
let mut contents = Vec::with_capacity(request.messages.len());
|
|
|
|
for msg in request.messages {
|
|
let mut parts = Vec::with_capacity(msg.content.len());
|
|
|
|
for part in msg.content {
|
|
match part {
|
|
crate::models::ContentPart::Text { text } => {
|
|
parts.push(GeminiPart {
|
|
text: Some(text),
|
|
inline_data: None,
|
|
});
|
|
}
|
|
crate::models::ContentPart::Image(image_input) => {
|
|
let (base64_data, mime_type) = image_input.to_base64().await
|
|
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
|
|
|
|
parts.push(GeminiPart {
|
|
text: None,
|
|
inline_data: Some(GeminiInlineData {
|
|
mime_type,
|
|
data: base64_data,
|
|
}),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
// Map role
|
|
let role = match msg.role.as_str() {
|
|
"assistant" => "model".to_string(),
|
|
_ => "user".to_string(),
|
|
};
|
|
|
|
contents.push(GeminiContent {
|
|
parts,
|
|
role,
|
|
});
|
|
}
|
|
|
|
// Build generation config
|
|
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
|
|
Some(GeminiGenerationConfig {
|
|
temperature: request.temperature,
|
|
max_output_tokens: request.max_tokens,
|
|
})
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let gemini_request = GeminiRequest {
|
|
contents,
|
|
generation_config,
|
|
};
|
|
|
|
// Build URL for streaming
|
|
let url = format!("{}/models/{}:streamGenerateContent?alt=sse&key={}",
|
|
self.config.base_url,
|
|
request.model,
|
|
self.api_key
|
|
);
|
|
|
|
// Create eventsource stream
|
|
use reqwest_eventsource::{EventSource, Event};
|
|
use futures::StreamExt;
|
|
|
|
let es = EventSource::new(self.client.post(&url).json(&gemini_request))
|
|
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
|
|
|
let model = request.model.clone();
|
|
|
|
let stream = async_stream::try_stream! {
|
|
let mut es = es;
|
|
while let Some(event) = es.next().await {
|
|
match event {
|
|
Ok(Event::Message(msg)) => {
|
|
let gemini_response: GeminiResponse = serde_json::from_str(&msg.data)
|
|
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
|
|
|
if let Some(candidate) = gemini_response.candidates.first() {
|
|
let content = candidate.content.parts.first()
|
|
.and_then(|p| p.text.clone())
|
|
.unwrap_or_default();
|
|
|
|
yield ProviderStreamChunk {
|
|
content,
|
|
reasoning_content: None,
|
|
finish_reason: None, // Will be set in the last chunk
|
|
model: model.clone(),
|
|
};
|
|
}
|
|
}
|
|
Ok(_) => continue,
|
|
Err(e) => {
|
|
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
Ok(Box::pin(stream))
|
|
}
|
|
} |