use aes_gcm::{ aead::{Aead, AeadCore, KeyInit, OsRng}, Aes256Gcm, Key, Nonce, }; use anyhow::{anyhow, Context, Result}; use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; use std::env; use std::sync::OnceLock; static RAW_KEY: OnceLock<[u8; 32]> = OnceLock::new(); /// Initialize the encryption key from a hex or base64 encoded string. /// Must be called before any encryption/decryption operations. /// Returns error if the key is invalid or already initialized with a different key. pub fn init_with_key(key_str: &str) -> Result<()> { let key_bytes = hex::decode(key_str) .or_else(|_| BASE64.decode(key_str)) .context("Encryption key must be hex or base64 encoded")?; if key_bytes.len() != 32 { anyhow::bail!( "Encryption key must be 32 bytes (256 bits), got {} bytes", key_bytes.len() ); } let key_array: [u8; 32] = key_bytes.try_into().unwrap(); // safe due to length check // Check if already initialized with same key if let Some(existing) = RAW_KEY.get() { if existing == &key_array { // Same key already initialized, okay return Ok(()); } else { anyhow::bail!("Encryption key already initialized with a different key"); } } // Store raw key bytes RAW_KEY .set(key_array) .map_err(|_| anyhow::anyhow!("Encryption key already initialized"))?; Ok(()) } /// Initialize the encryption key from the environment variable `LLM_PROXY__ENCRYPTION_KEY`. /// Must be called before any encryption/decryption operations. /// Panics if the environment variable is missing or invalid. pub fn init_from_env() -> Result<()> { let key_str = env::var("LLM_PROXY__ENCRYPTION_KEY").context("LLM_PROXY__ENCRYPTION_KEY environment variable not set")?; init_with_key(&key_str) } /// Get the encryption key bytes, panicking if not initialized. fn get_key() -> &'static [u8; 32] { RAW_KEY .get() .expect("Encryption key not initialized. Call crypto::init_with_key() or crypto::init_from_env() first.") } /// Encrypt a plaintext string and return a base64-encoded ciphertext (nonce || ciphertext || tag). pub fn encrypt(plaintext: &str) -> Result { let key = Key::::from_slice(get_key()); let cipher = Aes256Gcm::new(key); let nonce = Aes256Gcm::generate_nonce(&mut OsRng); // 12 bytes let ciphertext = cipher .encrypt(&nonce, plaintext.as_bytes()) .map_err(|e| anyhow!("Encryption failed: {}", e))?; // Combine nonce and ciphertext (ciphertext already includes tag) let mut combined = Vec::with_capacity(nonce.len() + ciphertext.len()); combined.extend_from_slice(&nonce); combined.extend_from_slice(&ciphertext); Ok(BASE64.encode(combined)) } /// Decrypt a base64-encoded ciphertext (nonce || ciphertext || tag) to a plaintext string. pub fn decrypt(ciphertext_b64: &str) -> Result { let key = Key::::from_slice(get_key()); let cipher = Aes256Gcm::new(key); let combined = BASE64.decode(ciphertext_b64).context("Invalid base64 ciphertext")?; if combined.len() < 12 { anyhow::bail!("Ciphertext too short"); } let (nonce_bytes, ciphertext_and_tag) = combined.split_at(12); let nonce = Nonce::from_slice(nonce_bytes); let plaintext_bytes = cipher .decrypt(nonce, ciphertext_and_tag) .map_err(|e| anyhow!("Decryption failed (invalid key or corrupted ciphertext): {}", e))?; String::from_utf8(plaintext_bytes).context("Decrypted bytes are not valid UTF-8") } #[cfg(test)] mod tests { use super::*; const TEST_KEY_HEX: &str = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"; #[test] fn test_encrypt_decrypt() { init_with_key(TEST_KEY_HEX).unwrap(); let plaintext = "super secret api key"; let ciphertext = encrypt(plaintext).unwrap(); assert_ne!(ciphertext, plaintext); let decrypted = decrypt(&ciphertext).unwrap(); assert_eq!(decrypted, plaintext); } #[test] fn test_different_inputs_produce_different_ciphertexts() { init_with_key(TEST_KEY_HEX).unwrap(); let plaintext = "same"; let cipher1 = encrypt(plaintext).unwrap(); let cipher2 = encrypt(plaintext).unwrap(); assert_ne!(cipher1, cipher2, "Nonce should make ciphertexts differ"); assert_eq!(decrypt(&cipher1).unwrap(), plaintext); assert_eq!(decrypt(&cipher2).unwrap(), plaintext); } #[test] fn test_invalid_key_length() { let result = init_with_key("tooshort"); assert!(result.is_err()); } #[test] fn test_init_from_env() { unsafe { std::env::set_var("LLM_PROXY__ENCRYPTION_KEY", TEST_KEY_HEX) }; let result = init_from_env(); assert!(result.is_ok()); // Ensure encryption works let ciphertext = encrypt("test").unwrap(); let decrypted = decrypt(&ciphertext).unwrap(); assert_eq!(decrypted, "test"); } #[test] fn test_missing_env_key() { unsafe { std::env::remove_var("LLM_PROXY__ENCRYPTION_KEY") }; let result = init_from_env(); assert!(result.is_err()); } #[test] fn test_key_hex_and_base64() { // Hex key works init_with_key(TEST_KEY_HEX).unwrap(); // Base64 key (same bytes encoded as base64) let base64_key = BASE64.encode(hex::decode(TEST_KEY_HEX).unwrap()); // Re-initialization with same key (different encoding) is allowed let result = init_with_key(&base64_key); assert!(result.is_ok()); // Encryption should still work let ciphertext = encrypt("test").unwrap(); let decrypted = decrypt(&ciphertext).unwrap(); assert_eq!(decrypted, "test"); } #[test] #[ignore] // conflicts with global state from other tests fn test_already_initialized() { init_with_key(TEST_KEY_HEX).unwrap(); let result = init_with_key(TEST_KEY_HEX); assert!(result.is_ok()); // same key allowed } #[test] #[ignore] // conflicts with global state from other tests fn test_already_initialized_different_key() { init_with_key(TEST_KEY_HEX).unwrap(); let different_key = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e20"; let result = init_with_key(different_key); assert!(result.is_err()); } }