use anyhow::Result; use config::{Config, File, FileFormat}; use serde::{Deserialize, Serialize}; use std::path::PathBuf; use std::sync::Arc; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ServerConfig { pub port: u16, pub host: String, #[serde(deserialize_with = "deserialize_vec_or_string")] pub auth_tokens: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DatabaseConfig { pub path: PathBuf, pub max_connections: u32, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ProviderConfig { pub openai: OpenAIConfig, pub gemini: GeminiConfig, pub deepseek: DeepSeekConfig, pub grok: GrokConfig, pub ollama: OllamaConfig, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OpenAIConfig { pub api_key_env: String, pub base_url: String, pub default_model: String, pub enabled: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GeminiConfig { pub api_key_env: String, pub base_url: String, pub default_model: String, pub enabled: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DeepSeekConfig { pub api_key_env: String, pub base_url: String, pub default_model: String, pub enabled: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GrokConfig { pub api_key_env: String, pub base_url: String, pub default_model: String, pub enabled: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OllamaConfig { pub base_url: String, pub enabled: bool, #[serde(deserialize_with = "deserialize_vec_or_string")] pub models: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ModelMappingConfig { pub patterns: Vec<(String, String)>, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PricingConfig { pub openai: Vec, pub gemini: Vec, pub deepseek: Vec, pub grok: Vec, pub ollama: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ModelPricing { pub model: String, pub prompt_tokens_per_million: f64, pub completion_tokens_per_million: f64, } #[derive(Debug, Clone)] pub struct AppConfig { pub server: ServerConfig, pub database: DatabaseConfig, pub providers: ProviderConfig, pub model_mapping: ModelMappingConfig, pub pricing: PricingConfig, pub config_path: PathBuf, } impl AppConfig { pub async fn load() -> Result> { Self::load_from_path(None).await } /// Load configuration from a specific path (for testing) pub async fn load_from_path(config_path: Option) -> Result> { // Load configuration from multiple sources let mut config_builder = Config::builder(); // Default configuration config_builder = config_builder .set_default("server.port", 8080)? .set_default("server.host", "0.0.0.0")? .set_default("server.auth_tokens", Vec::::new())? .set_default("database.path", "./data/llm_proxy.db")? .set_default("database.max_connections", 10)? .set_default("providers.openai.api_key_env", "OPENAI_API_KEY")? .set_default("providers.openai.base_url", "https://api.openai.com/v1")? .set_default("providers.openai.default_model", "gpt-4o")? .set_default("providers.openai.enabled", true)? .set_default("providers.gemini.api_key_env", "GEMINI_API_KEY")? .set_default("providers.gemini.base_url", "https://generativelanguage.googleapis.com/v1")? .set_default("providers.gemini.default_model", "gemini-2.0-flash")? .set_default("providers.gemini.enabled", true)? .set_default("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")? .set_default("providers.deepseek.base_url", "https://api.deepseek.com")? .set_default("providers.deepseek.default_model", "deepseek-reasoner")? .set_default("providers.deepseek.enabled", true)? .set_default("providers.grok.api_key_env", "GROK_API_KEY")? .set_default("providers.grok.base_url", "https://api.x.ai/v1")? .set_default("providers.grok.default_model", "grok-beta")? .set_default("providers.grok.enabled", false)? .set_default("providers.ollama.base_url", "http://localhost:11434/v1")? .set_default("providers.ollama.enabled", false)? .set_default("providers.ollama.models", Vec::::new())?; // Load from config file if exists let config_path = config_path.unwrap_or_else(|| std::env::current_dir().unwrap().join("config.toml")); if config_path.exists() { config_builder = config_builder.add_source(File::from(config_path.clone()).format(FileFormat::Toml)); } // Load from .env file dotenvy::dotenv().ok(); // Load from environment variables (with prefix "LLM_PROXY_") config_builder = config_builder.add_source( config::Environment::with_prefix("LLM_PROXY") .separator("__") .try_parsing(true), ); let config = config_builder.build()?; // Deserialize configuration let server: ServerConfig = config.get("server")?; let database: DatabaseConfig = config.get("database")?; let providers: ProviderConfig = config.get("providers")?; // For now, use empty model mapping and pricing (will be populated later) let model_mapping = ModelMappingConfig { patterns: vec![] }; let pricing = PricingConfig { openai: vec![], gemini: vec![], deepseek: vec![], grok: vec![], ollama: vec![], }; Ok(Arc::new(AppConfig { server, database, providers, model_mapping, pricing, config_path, })) } pub fn get_api_key(&self, provider: &str) -> Result { let env_var = match provider { "openai" => &self.providers.openai.api_key_env, "gemini" => &self.providers.gemini.api_key_env, "deepseek" => &self.providers.deepseek.api_key_env, "grok" => &self.providers.grok.api_key_env, _ => return Err(anyhow::anyhow!("Unknown provider: {}", provider)), }; std::env::var(env_var) .map_err(|_| anyhow::anyhow!("Environment variable {} not set for {}", env_var, provider)) } } /// Helper function to deserialize a Vec from either a sequence or a comma-separated string fn deserialize_vec_or_string<'de, D>(deserializer: D) -> Result, D::Error> where D: serde::Deserializer<'de>, { struct VecOrString; impl<'de> serde::de::Visitor<'de> for VecOrString { type Value = Vec; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { formatter.write_str("a sequence or a comma-separated string") } fn visit_str(self, value: &str) -> Result where E: serde::de::Error, { Ok(value .split(',') .map(|s| s.trim().to_string()) .filter(|s| !s.is_empty()) .collect()) } fn visit_seq(self, mut seq: S) -> Result where S: serde::de::SeqAccess<'de>, { let mut vec = Vec::new(); while let Some(element) = seq.next_element()? { vec.push(element); } Ok(vec) } } deserializer.deserialize_any(VecOrString) }