diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs index dc783856..5a642523 100644 --- a/src/providers/gemini.rs +++ b/src/providers/gemini.rs @@ -531,15 +531,19 @@ impl GeminiProvider { } /// Default safety settings to avoid blocking responses. - fn get_safety_settings(&self) -> Vec { - let categories = vec![ + fn get_safety_settings(&self, base_url: &str) -> Vec { + let mut categories = vec![ "HARM_CATEGORY_HARASSMENT", "HARM_CATEGORY_HATE_SPEECH", "HARM_CATEGORY_SEXUALLY_EXPLICIT", "HARM_CATEGORY_DANGEROUS_CONTENT", - "HARM_CATEGORY_CIVIC_INTEGRITY", ]; + // Civic integrity is only available in v1beta + if base_url.contains("v1beta") { + categories.push("HARM_CATEGORY_CIVIC_INTEGRITY"); + } + categories .into_iter() .map(|c| GeminiSafetySetting { @@ -589,12 +593,21 @@ impl super::Provider for GeminiProvider { return Err(AppError::ProviderError("No valid messages to send".to_string())); } + let base_url = self.get_base_url(&model); + + // Sanitize stop sequences: Gemini rejects empty strings + let stop_sequences = request.stop.map(|s| { + s.into_iter() + .filter(|seq| !seq.is_empty()) + .collect::>() + }).filter(|s| !s.is_empty()); + let generation_config = Some(GeminiGenerationConfig { temperature: request.temperature, top_p: request.top_p, top_k: request.top_k, max_output_tokens: request.max_tokens.map(|t| t.min(8192)), - stop_sequences: request.stop, + stop_sequences, candidate_count: request.n, }); @@ -604,10 +617,9 @@ impl super::Provider for GeminiProvider { generation_config, tools, tool_config, - safety_settings: Some(self.get_safety_settings()), + safety_settings: Some(self.get_safety_settings(&base_url)), }; - let base_url = self.get_base_url(&model); let url = format!("{}/models/{}:generateContent", base_url, model); tracing::debug!("Calling Gemini API: {}", url); @@ -729,12 +741,21 @@ impl super::Provider for GeminiProvider { return Err(AppError::ProviderError("No valid messages to send".to_string())); } + let base_url = self.get_base_url(&model); + + // Sanitize stop sequences: Gemini rejects empty strings + let stop_sequences = request.stop.map(|s| { + s.into_iter() + .filter(|seq| !seq.is_empty()) + .collect::>() + }).filter(|s| !s.is_empty()); + let generation_config = Some(GeminiGenerationConfig { temperature: request.temperature, top_p: request.top_p, top_k: request.top_k, max_output_tokens: request.max_tokens.map(|t| t.min(8192)), - stop_sequences: request.stop, + stop_sequences, candidate_count: request.n, }); @@ -744,10 +765,9 @@ impl super::Provider for GeminiProvider { generation_config, tools, tool_config, - safety_settings: Some(self.get_safety_settings()), + safety_settings: Some(self.get_safety_settings(&base_url)), }; - let base_url = self.get_base_url(&model); let url = format!( "{}/models/{}:streamGenerateContent?alt=sse", base_url, model, @@ -757,7 +777,8 @@ impl super::Provider for GeminiProvider { // Capture a clone of the request to probe for errors (Gemini 400s are common) let probe_request = gemini_request.clone(); let probe_client = self.client.clone(); - let probe_url = url.clone(); + // Use non-streaming URL for probing to get a valid JSON error body + let probe_url = format!("{}/models/{}:generateContent", base_url, model); let probe_api_key = self.api_key.clone(); // Create the EventSource first (it doesn't send until polled)