diff --git a/src/config/mod.rs b/src/config/mod.rs index 649a35e8..b0144aec 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -8,6 +8,7 @@ use std::sync::Arc; pub struct ServerConfig { pub port: u16, pub host: String, + #[serde(deserialize_with = "deserialize_vec_or_string")] pub auth_tokens: Vec, } @@ -62,6 +63,7 @@ pub struct GrokConfig { pub struct OllamaConfig { pub base_url: String, pub enabled: bool, + #[serde(deserialize_with = "deserialize_vec_or_string")] pub models: Vec, } @@ -185,7 +187,48 @@ impl AppConfig { _ => return Err(anyhow::anyhow!("Unknown provider: {}", provider)), }; - std::env::var(env_var) - .map_err(|_| anyhow::anyhow!("Environment variable {} not set for {}", env_var, provider)) - } -} \ No newline at end of file + std::env::var(env_var) + .map_err(|_| anyhow::anyhow!("Environment variable {} not set for {}", env_var, provider)) + } + } + + /// Helper function to deserialize a Vec from either a sequence or a comma-separated string + fn deserialize_vec_or_string<'de, D>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + struct VecOrString; + + impl<'de> serde::de::Visitor<'de> for VecOrString { + type Value = Vec; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a sequence or a comma-separated string") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + Ok(value + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect()) + } + + fn visit_seq(self, mut seq: S) -> Result + where + S: serde::de::SeqAccess<'de>, + { + let mut vec = Vec::new(); + while let Some(element) = seq.next_element()? { + vec.push(element); + } + Ok(vec) + } + } + + deserializer.deserialize_any(VecOrString) + } + \ No newline at end of file