Files
GopherGate/src/config/mod.rs
hobokenchicken 9b8483e797 feat(security): implement AES-256-GCM encryption for API keys and HMAC-signed session tokens
This commit introduces:
- AES-256-GCM encryption for LLM provider API keys in the database.
- HMAC-SHA256 signed session tokens with activity-based refresh logic.
- Standardized frontend XSS protection using a global escapeHtml utility.
- Hardened security headers and request body size limits.
- Improved database integrity with foreign key enforcement and atomic transactions.
- Integration tests for the full encrypted key storage and proxy usage lifecycle.
2026-03-06 14:17:56 -05:00

261 lines
8.9 KiB
Rust

use anyhow::Result;
use base64::{Engine as _};
use config::{Config, File, FileFormat};
use hex;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
/// Settings read from the `server` section of the configuration.
///
/// NOTE(review): the derived `Debug` and `Serialize` impls include
/// `auth_tokens` verbatim — confirm that this struct is never logged or
/// serialized somewhere the tokens should not appear.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
/// Port number (`AppConfig::load_from_path` sets the default `8080`).
pub port: u16,
/// Host string (`AppConfig::load_from_path` sets the default `"0.0.0.0"`).
pub host: String,
/// Token list; may be given as a sequence or as one comma-separated string
/// (see `deserialize_vec_or_string`). Defaults to an empty list.
/// Presumably the tokens clients must present to the proxy — confirm against the auth code.
#[serde(deserialize_with = "deserialize_vec_or_string")]
pub auth_tokens: Vec<String>,
}
/// Settings read from the `database` section of the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
/// Database file path (`AppConfig::load_from_path` sets the default `"./data/llm_proxy.db"`).
pub path: PathBuf,
/// Connection limit (default `10`); presumably the pool size — confirm against the database setup code.
pub max_connections: u32,
}
/// Settings read from the `providers` section: one sub-section per upstream provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
/// `providers.openai` settings.
pub openai: OpenAIConfig,
/// `providers.gemini` settings.
pub gemini: GeminiConfig,
/// `providers.deepseek` settings.
pub deepseek: DeepSeekConfig,
/// `providers.grok` settings.
pub grok: GrokConfig,
/// `providers.ollama` settings (the only provider without an `api_key_env`).
pub ollama: OllamaConfig,
}
/// Settings for the `openai` provider (`providers.openai`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIConfig {
/// Name of the environment variable that `AppConfig::get_api_key("openai")`
/// reads (default `"OPENAI_API_KEY"`). This is the variable's name, not the key itself.
pub api_key_env: String,
/// Base URL of the provider API (default `"https://api.openai.com/v1"`).
pub base_url: String,
/// Default model name (default `"gpt-4o"`).
pub default_model: String,
/// Enable flag (default `true`); how it is consumed is not visible in this file.
pub enabled: bool,
}
/// Settings for the `gemini` provider (`providers.gemini`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeminiConfig {
/// Name of the environment variable that `AppConfig::get_api_key("gemini")`
/// reads (default `"GEMINI_API_KEY"`). This is the variable's name, not the key itself.
pub api_key_env: String,
/// Base URL of the provider API (default `"https://generativelanguage.googleapis.com/v1"`).
pub base_url: String,
/// Default model name (default `"gemini-2.0-flash"`).
pub default_model: String,
/// Enable flag (default `true`); how it is consumed is not visible in this file.
pub enabled: bool,
}
/// Settings for the `deepseek` provider (`providers.deepseek`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeepSeekConfig {
/// Name of the environment variable that `AppConfig::get_api_key("deepseek")`
/// reads (default `"DEEPSEEK_API_KEY"`). This is the variable's name, not the key itself.
pub api_key_env: String,
/// Base URL of the provider API (default `"https://api.deepseek.com"`).
pub base_url: String,
/// Default model name (default `"deepseek-reasoner"`).
pub default_model: String,
/// Enable flag (default `true`); how it is consumed is not visible in this file.
pub enabled: bool,
}
/// Settings for the `grok` provider (`providers.grok`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrokConfig {
/// Name of the environment variable that `AppConfig::get_api_key("grok")`
/// reads (default `"GROK_API_KEY"`). This is the variable's name, not the key itself.
pub api_key_env: String,
/// Base URL of the provider API (default `"https://api.x.ai/v1"`).
pub base_url: String,
/// Default model name (default `"grok-beta"`).
pub default_model: String,
/// Enable flag (default `true`); how it is consumed is not visible in this file.
pub enabled: bool,
}
/// Settings for the `ollama` provider (`providers.ollama`).
///
/// Unlike the other providers there is no `api_key_env` field, and
/// `AppConfig::get_api_key` has no `"ollama"` arm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaConfig {
/// Base URL of the Ollama endpoint (default `"http://localhost:11434/v1"`).
pub base_url: String,
/// Enable flag (default `false`); how it is consumed is not visible in this file.
pub enabled: bool,
/// Model name list; may be given as a sequence or as one comma-separated
/// string (see `deserialize_vec_or_string`). Defaults to an empty list.
#[serde(deserialize_with = "deserialize_vec_or_string")]
pub models: Vec<String>,
}
/// Model-mapping table.
///
/// `AppConfig::load_from_path` currently always builds this with an empty
/// `patterns` list rather than reading it from the configuration sources.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMappingConfig {
/// Pairs of strings; presumably (model-name pattern, target) — the meaning
/// of each element is not visible in this file, confirm against the consumer.
pub patterns: Vec<(String, String)>,
}
/// Per-provider pricing tables.
///
/// `AppConfig::load_from_path` currently always builds this with every list
/// empty rather than reading it from the configuration sources.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PricingConfig {
/// Pricing entries for the `openai` provider.
pub openai: Vec<ModelPricing>,
/// Pricing entries for the `gemini` provider.
pub gemini: Vec<ModelPricing>,
/// Pricing entries for the `deepseek` provider.
pub deepseek: Vec<ModelPricing>,
/// Pricing entries for the `grok` provider.
pub grok: Vec<ModelPricing>,
/// Pricing entries for the `ollama` provider.
pub ollama: Vec<ModelPricing>,
}
/// Pricing entry for a single model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing {
/// Model name this entry applies to.
pub model: String,
/// Presumably the price per one million prompt tokens; the currency/unit is
/// not stated in this file — TODO confirm.
pub prompt_tokens_per_million: f64,
/// Presumably the price per one million completion tokens; the currency/unit
/// is not stated in this file — TODO confirm.
pub completion_tokens_per_million: f64,
}
/// Fully resolved application configuration, as produced by `AppConfig::load`
/// and `AppConfig::load_from_path`.
///
/// `Debug` is implemented by hand (instead of derived) so that the
/// encryption key is never written to logs or panic messages.
#[derive(Clone)]
pub struct AppConfig {
    /// `server` section.
    pub server: ServerConfig,
    /// `database` section.
    pub database: DatabaseConfig,
    /// `providers` section.
    pub providers: ProviderConfig,
    /// Model-mapping table.
    pub model_mapping: ModelMappingConfig,
    /// Per-provider pricing tables.
    pub pricing: PricingConfig,
    /// Path of the configuration file that was considered when loading, if any.
    pub config_path: Option<PathBuf>,
    /// Encryption key exactly as configured (hex or base64 text, not the decoded bytes).
    pub encryption_key: String,
}

impl std::fmt::Debug for AppConfig {
    /// Formats every field like the derived implementation would, except that
    /// the value of `encryption_key` is replaced by a `<redacted>` placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AppConfig")
            .field("server", &self.server)
            .field("database", &self.database)
            .field("providers", &self.providers)
            .field("model_mapping", &self.model_mapping)
            .field("pricing", &self.pricing)
            .field("config_path", &self.config_path)
            // Secret material: never print the actual key.
            .field("encryption_key", &format_args!("<redacted>"))
            .finish()
    }
}
impl AppConfig {
    /// Loads the configuration using the default config-file lookup.
    ///
    /// Shorthand for `Self::load_from_path(None)`.
    ///
    /// # Errors
    /// Returns whatever error `load_from_path` returns.
    pub async fn load() -> Result<Arc<Self>> {
        Self::load_from_path(None).await
    }

    /// Load configuration from a specific path (for testing)
    ///
    /// Sources are layered, later ones overriding earlier ones:
    /// 1. built-in defaults,
    /// 2. the TOML config file, if it exists,
    /// 3. environment variables starting with `LLM_PROXY__`, using `__` as the
    ///    nesting separator (e.g. `LLM_PROXY__SERVER__PORT`).
    ///
    /// The config file location is chosen with this priority: the
    /// `config_path` argument, then the `LLM_PROXY__CONFIG_PATH` environment
    /// variable, then `config.toml` in the current directory.
    ///
    /// # Errors
    /// Fails when a source cannot be built or deserialized, or when the
    /// encryption key is missing, is not valid hex/base64, or does not decode
    /// to exactly 32 bytes.
    pub async fn load_from_path(config_path: Option<PathBuf>) -> Result<Arc<Self>> {
        // Load from .env file FIRST: the config-path lookup below reads
        // `LLM_PROXY__CONFIG_PATH` from the process environment, so a value
        // defined in `.env` must already be present by then. Failure to find
        // or read a `.env` file is deliberately ignored.
        dotenvy::dotenv().ok();

        // Load configuration from multiple sources
        let mut config_builder = Config::builder();

        // Default configuration
        config_builder = config_builder
            .set_default("server.port", 8080)?
            .set_default("server.host", "0.0.0.0")?
            .set_default("server.auth_tokens", Vec::<String>::new())?
            .set_default("database.path", "./data/llm_proxy.db")?
            .set_default("database.max_connections", 10)?
            .set_default("providers.openai.api_key_env", "OPENAI_API_KEY")?
            .set_default("providers.openai.base_url", "https://api.openai.com/v1")?
            .set_default("providers.openai.default_model", "gpt-4o")?
            .set_default("providers.openai.enabled", true)?
            .set_default("providers.gemini.api_key_env", "GEMINI_API_KEY")?
            .set_default(
                "providers.gemini.base_url",
                "https://generativelanguage.googleapis.com/v1",
            )?
            .set_default("providers.gemini.default_model", "gemini-2.0-flash")?
            .set_default("providers.gemini.enabled", true)?
            .set_default("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")?
            .set_default("providers.deepseek.base_url", "https://api.deepseek.com")?
            .set_default("providers.deepseek.default_model", "deepseek-reasoner")?
            .set_default("providers.deepseek.enabled", true)?
            .set_default("providers.grok.api_key_env", "GROK_API_KEY")?
            .set_default("providers.grok.base_url", "https://api.x.ai/v1")?
            .set_default("providers.grok.default_model", "grok-beta")?
            .set_default("providers.grok.enabled", true)?
            .set_default("providers.ollama.base_url", "http://localhost:11434/v1")?
            .set_default("providers.ollama.enabled", false)?
            .set_default("providers.ollama.models", Vec::<String>::new())?
            .set_default("encryption_key", "")?;

        // Load from config file if exists
        // Priority: explicit path arg > LLM_PROXY__CONFIG_PATH env var > ./config.toml
        let config_path = config_path
            .or_else(|| std::env::var("LLM_PROXY__CONFIG_PATH").ok().map(PathBuf::from))
            .unwrap_or_else(|| {
                std::env::current_dir()
                    .unwrap_or_else(|_| PathBuf::from("."))
                    .join("config.toml")
            });
        if config_path.exists() {
            config_builder = config_builder.add_source(File::from(config_path.clone()).format(FileFormat::Toml));
        }

        // Load from environment variables (prefix "LLM_PROXY", joined to the
        // key with the same "__" separator, i.e. variables look like
        // `LLM_PROXY__SECTION__KEY`)
        config_builder = config_builder.add_source(
            config::Environment::with_prefix("LLM_PROXY")
                .separator("__")
                .try_parsing(true),
        );

        let config = config_builder.build()?;

        // Deserialize configuration
        let server: ServerConfig = config.get("server")?;
        let database: DatabaseConfig = config.get("database")?;
        let providers: ProviderConfig = config.get("providers")?;
        let encryption_key: String = config.get("encryption_key")?;

        // Reject a missing or malformed key up front; the original text form
        // of the key is what gets stored in the returned configuration.
        Self::validate_encryption_key(&encryption_key)?;

        // For now, use empty model mapping and pricing (will be populated later)
        let model_mapping = ModelMappingConfig { patterns: vec![] };
        let pricing = PricingConfig {
            openai: vec![],
            gemini: vec![],
            deepseek: vec![],
            grok: vec![],
            ollama: vec![],
        };

        Ok(Arc::new(AppConfig {
            server,
            database,
            providers,
            model_mapping,
            pricing,
            config_path: Some(config_path),
            encryption_key,
        }))
    }

    /// Checks that `encryption_key` is non-empty and decodes (hex first, then
    /// standard base64) to exactly 32 bytes.
    fn validate_encryption_key(encryption_key: &str) -> Result<()> {
        if encryption_key.is_empty() {
            anyhow::bail!("Encryption key is required (LLM_PROXY__ENCRYPTION_KEY environment variable)");
        }

        // Try hex decode first, then base64. Hex must come first: a 64-char
        // hex string is also syntactically valid base64 but would decode to
        // the wrong bytes.
        let key_bytes = hex::decode(encryption_key)
            .or_else(|_| base64::engine::general_purpose::STANDARD.decode(encryption_key))
            .map_err(|e| anyhow::anyhow!("Encryption key must be hex or base64 encoded: {}", e))?;

        if key_bytes.len() != 32 {
            anyhow::bail!("Encryption key must be 32 bytes (256 bits), got {} bytes", key_bytes.len());
        }
        Ok(())
    }

    /// Reads the API key for `provider` from the process environment.
    ///
    /// `provider` must be one of `"openai"`, `"gemini"`, `"deepseek"` or
    /// `"grok"`; the variable consulted is that provider's configured
    /// `api_key_env` name.
    ///
    /// # Errors
    /// Fails for any other provider name, or when the environment variable
    /// cannot be read.
    pub fn get_api_key(&self, provider: &str) -> Result<String> {
        let env_var = match provider {
            "openai" => &self.providers.openai.api_key_env,
            "gemini" => &self.providers.gemini.api_key_env,
            "deepseek" => &self.providers.deepseek.api_key_env,
            "grok" => &self.providers.grok.api_key_env,
            _ => return Err(anyhow::anyhow!("Unknown provider: {}", provider)),
        };
        std::env::var(env_var).map_err(|_| anyhow::anyhow!("Environment variable {} not set for {}", env_var, provider))
    }
}
/// Serde `deserialize_with` helper that yields a `Vec<String>`.
///
/// Two input shapes are accepted:
/// - a sequence, whose elements are collected as-is;
/// - a single string, which is split on `,`; every piece is trimmed and
///   pieces that end up empty are dropped.
///
/// Any other input shape falls through to the visitor's default methods and
/// is reported as an error mentioning "a sequence or a comma-separated string".
fn deserialize_vec_or_string<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    /// Visitor implementing the two accepted input shapes.
    struct CommaListVisitor;

    impl<'de> serde::de::Visitor<'de> for CommaListVisitor {
        type Value = Vec<String>;

        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
            formatter.write_str("a sequence or a comma-separated string")
        }

        /// Splits a comma-separated string into trimmed, non-empty items.
        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            let mut items = Vec::new();
            for piece in value.split(',') {
                let piece = piece.trim();
                if !piece.is_empty() {
                    items.push(piece.to_string());
                }
            }
            Ok(items)
        }

        /// Collects every element of the sequence, in order.
        fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
        where
            S: serde::de::SeqAccess<'de>,
        {
            let mut items: Vec<String> = Vec::new();
            loop {
                match seq.next_element()? {
                    Some(item) => items.push(item),
                    None => return Ok(items),
                }
            }
        }
    }

    // Let the input format decide which visitor method to call.
    deserializer.deserialize_any(CommaListVisitor)
}