234 lines
8.0 KiB
Rust
234 lines
8.0 KiB
Rust
use anyhow::Result;
|
|
use config::{Config, File, FileFormat};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::path::PathBuf;
|
|
use std::sync::Arc;
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ServerConfig {
|
|
pub port: u16,
|
|
pub host: String,
|
|
#[serde(deserialize_with = "deserialize_vec_or_string")]
|
|
pub auth_tokens: Vec<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct DatabaseConfig {
|
|
pub path: PathBuf,
|
|
pub max_connections: u32,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ProviderConfig {
|
|
pub openai: OpenAIConfig,
|
|
pub gemini: GeminiConfig,
|
|
pub deepseek: DeepSeekConfig,
|
|
pub grok: GrokConfig,
|
|
pub ollama: OllamaConfig,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct OpenAIConfig {
|
|
pub api_key_env: String,
|
|
pub base_url: String,
|
|
pub default_model: String,
|
|
pub enabled: bool,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct GeminiConfig {
|
|
pub api_key_env: String,
|
|
pub base_url: String,
|
|
pub default_model: String,
|
|
pub enabled: bool,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct DeepSeekConfig {
|
|
pub api_key_env: String,
|
|
pub base_url: String,
|
|
pub default_model: String,
|
|
pub enabled: bool,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct GrokConfig {
|
|
pub api_key_env: String,
|
|
pub base_url: String,
|
|
pub default_model: String,
|
|
pub enabled: bool,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct OllamaConfig {
|
|
pub base_url: String,
|
|
pub enabled: bool,
|
|
#[serde(deserialize_with = "deserialize_vec_or_string")]
|
|
pub models: Vec<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ModelMappingConfig {
|
|
pub patterns: Vec<(String, String)>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct PricingConfig {
|
|
pub openai: Vec<ModelPricing>,
|
|
pub gemini: Vec<ModelPricing>,
|
|
pub deepseek: Vec<ModelPricing>,
|
|
pub grok: Vec<ModelPricing>,
|
|
pub ollama: Vec<ModelPricing>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ModelPricing {
|
|
pub model: String,
|
|
pub prompt_tokens_per_million: f64,
|
|
pub completion_tokens_per_million: f64,
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct AppConfig {
|
|
pub server: ServerConfig,
|
|
pub database: DatabaseConfig,
|
|
pub providers: ProviderConfig,
|
|
pub model_mapping: ModelMappingConfig,
|
|
pub pricing: PricingConfig,
|
|
pub config_path: PathBuf,
|
|
}
|
|
|
|
impl AppConfig {
|
|
pub async fn load() -> Result<Arc<Self>> {
|
|
Self::load_from_path(None).await
|
|
}
|
|
|
|
/// Load configuration from a specific path (for testing)
|
|
pub async fn load_from_path(config_path: Option<PathBuf>) -> Result<Arc<Self>> {
|
|
// Load configuration from multiple sources
|
|
let mut config_builder = Config::builder();
|
|
|
|
// Default configuration
|
|
config_builder = config_builder
|
|
.set_default("server.port", 8080)?
|
|
.set_default("server.host", "0.0.0.0")?
|
|
.set_default("server.auth_tokens", Vec::<String>::new())?
|
|
.set_default("database.path", "./data/llm_proxy.db")?
|
|
.set_default("database.max_connections", 10)?
|
|
.set_default("providers.openai.api_key_env", "OPENAI_API_KEY")?
|
|
.set_default("providers.openai.base_url", "https://api.openai.com/v1")?
|
|
.set_default("providers.openai.default_model", "gpt-4o")?
|
|
.set_default("providers.openai.enabled", true)?
|
|
.set_default("providers.gemini.api_key_env", "GEMINI_API_KEY")?
|
|
.set_default("providers.gemini.base_url", "https://generativelanguage.googleapis.com/v1")?
|
|
.set_default("providers.gemini.default_model", "gemini-2.0-flash")?
|
|
.set_default("providers.gemini.enabled", true)?
|
|
.set_default("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")?
|
|
.set_default("providers.deepseek.base_url", "https://api.deepseek.com")?
|
|
.set_default("providers.deepseek.default_model", "deepseek-reasoner")?
|
|
.set_default("providers.deepseek.enabled", true)?
|
|
.set_default("providers.grok.api_key_env", "GROK_API_KEY")?
|
|
.set_default("providers.grok.base_url", "https://api.x.ai/v1")?
|
|
.set_default("providers.grok.default_model", "grok-beta")?
|
|
.set_default("providers.grok.enabled", false)?
|
|
.set_default("providers.ollama.base_url", "http://localhost:11434/v1")?
|
|
.set_default("providers.ollama.enabled", false)?
|
|
.set_default("providers.ollama.models", Vec::<String>::new())?;
|
|
|
|
// Load from config file if exists
|
|
let config_path = config_path.unwrap_or_else(|| std::env::current_dir().unwrap().join("config.toml"));
|
|
if config_path.exists() {
|
|
config_builder = config_builder.add_source(File::from(config_path.clone()).format(FileFormat::Toml));
|
|
}
|
|
|
|
// Load from .env file
|
|
dotenvy::dotenv().ok();
|
|
|
|
// Load from environment variables (with prefix "LLM_PROXY_")
|
|
config_builder = config_builder.add_source(
|
|
config::Environment::with_prefix("LLM_PROXY")
|
|
.separator("__")
|
|
.try_parsing(true),
|
|
);
|
|
|
|
let config = config_builder.build()?;
|
|
|
|
// Deserialize configuration
|
|
let server: ServerConfig = config.get("server")?;
|
|
let database: DatabaseConfig = config.get("database")?;
|
|
let providers: ProviderConfig = config.get("providers")?;
|
|
|
|
// For now, use empty model mapping and pricing (will be populated later)
|
|
let model_mapping = ModelMappingConfig { patterns: vec![] };
|
|
let pricing = PricingConfig {
|
|
openai: vec![],
|
|
gemini: vec![],
|
|
deepseek: vec![],
|
|
grok: vec![],
|
|
ollama: vec![],
|
|
};
|
|
|
|
Ok(Arc::new(AppConfig {
|
|
server,
|
|
database,
|
|
providers,
|
|
model_mapping,
|
|
pricing,
|
|
config_path,
|
|
}))
|
|
}
|
|
|
|
pub fn get_api_key(&self, provider: &str) -> Result<String> {
|
|
let env_var = match provider {
|
|
"openai" => &self.providers.openai.api_key_env,
|
|
"gemini" => &self.providers.gemini.api_key_env,
|
|
"deepseek" => &self.providers.deepseek.api_key_env,
|
|
"grok" => &self.providers.grok.api_key_env,
|
|
_ => return Err(anyhow::anyhow!("Unknown provider: {}", provider)),
|
|
};
|
|
|
|
std::env::var(env_var)
|
|
.map_err(|_| anyhow::anyhow!("Environment variable {} not set for {}", env_var, provider))
|
|
}
|
|
}
|
|
|
|
/// Helper function to deserialize a Vec<String> from either a sequence or a comma-separated string
|
|
fn deserialize_vec_or_string<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
|
where
|
|
D: serde::Deserializer<'de>,
|
|
{
|
|
struct VecOrString;
|
|
|
|
impl<'de> serde::de::Visitor<'de> for VecOrString {
|
|
type Value = Vec<String>;
|
|
|
|
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
|
formatter.write_str("a sequence or a comma-separated string")
|
|
}
|
|
|
|
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(value
|
|
.split(',')
|
|
.map(|s| s.trim().to_string())
|
|
.filter(|s| !s.is_empty())
|
|
.collect())
|
|
}
|
|
|
|
fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
|
|
where
|
|
S: serde::de::SeqAccess<'de>,
|
|
{
|
|
let mut vec = Vec::new();
|
|
while let Some(element) = seq.next_element()? {
|
|
vec.push(element);
|
|
}
|
|
Ok(vec)
|
|
}
|
|
}
|
|
|
|
deserializer.deserialize_any(VecOrString)
|
|
}
|
|
|