diff --git a/src/dashboard/providers.rs b/src/dashboard/providers.rs index c5c3ce24..5713a3e8 100644 --- a/src/dashboard/providers.rs +++ b/src/dashboard/providers.rs @@ -367,7 +367,13 @@ pub(super) async fn handle_test_provider( tool_call_id: None, }], temperature: None, + top_p: None, + top_k: None, + n: None, + stop: None, max_tokens: Some(5), + presence_penalty: None, + frequency_penalty: None, stream: false, has_images: false, tools: None, diff --git a/src/models/mod.rs b/src/models/mod.rs index 2eea0b8f..491edcea 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -12,8 +12,20 @@ pub struct ChatCompletionRequest { #[serde(default)] pub temperature: Option<f32>, #[serde(default)] + pub top_p: Option<f32>, + #[serde(default)] + pub top_k: Option<u32>, + #[serde(default)] + pub n: Option<u32>, + #[serde(default)] + pub stop: Option<Value>, // Can be string or array of strings + #[serde(default)] pub max_tokens: Option<u32>, #[serde(default)] + pub presence_penalty: Option<f32>, + #[serde(default)] + pub frequency_penalty: Option<f32>, + #[serde(default)] pub stream: Option<bool>, #[serde(default, skip_serializing_if = "Option::is_none")] pub tools: Option<Vec<Tool>>, @@ -194,7 +206,13 @@ pub struct UnifiedRequest { pub model: String, pub messages: Vec<UnifiedMessage>, pub temperature: Option<f32>, + pub top_p: Option<f32>, + pub top_k: Option<u32>, + pub n: Option<u32>, + pub stop: Option<Vec<String>>, pub max_tokens: Option<u32>, + pub presence_penalty: Option<f32>, + pub frequency_penalty: Option<f32>, pub stream: bool, pub has_images: bool, pub tools: Option<Vec<Tool>>, @@ -326,12 +344,28 @@ impl TryFrom<ChatCompletionRequest> for UnifiedRequest { }) .collect(); + let stop = match req.stop { + Some(Value::String(s)) => Some(vec![s]), + Some(Value::Array(a)) => Some( + a.into_iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect(), + ), + _ => None, + }; + Ok(UnifiedRequest { client_id: String::new(), // Will be populated by auth middleware model: req.model, messages, temperature: req.temperature, + top_p: req.top_p, + top_k: req.top_k, + n: req.n, + stop, max_tokens: req.max_tokens, + 
presence_penalty: req.presence_penalty, + frequency_penalty: req.frequency_penalty, stream: req.stream.unwrap_or(false), has_images, tools: req.tools, diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs index 1b123b8d..dc783856 100644 --- a/src/providers/gemini.rs +++ b/src/providers/gemini.rs @@ -27,6 +27,14 @@ struct GeminiRequest { tools: Option<Vec<GeminiTool>>, #[serde(skip_serializing_if = "Option::is_none")] tool_config: Option<GeminiToolConfig>, + #[serde(skip_serializing_if = "Option::is_none")] + safety_settings: Option<Vec<GeminiSafetySetting>>, +} + +#[derive(Debug, Clone, Serialize)] +struct GeminiSafetySetting { + category: String, + threshold: String, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -71,8 +79,18 @@ struct GeminiFunctionResponse { #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "camelCase")] struct GeminiGenerationConfig { + #[serde(skip_serializing_if = "Option::is_none")] temperature: Option<f32>, + #[serde(skip_serializing_if = "Option::is_none")] + top_p: Option<f32>, + #[serde(skip_serializing_if = "Option::is_none")] + top_k: Option<u32>, + #[serde(skip_serializing_if = "Option::is_none")] max_output_tokens: Option<u32>, + #[serde(skip_serializing_if = "Option::is_none")] + stop_sequences: Option<Vec<String>>, + #[serde(skip_serializing_if = "Option::is_none")] + candidate_count: Option<u32>, } // ========== Gemini Tool Structs ========== @@ -511,6 +529,25 @@ impl GeminiProvider { self.config.base_url.clone() } } + + /// Default safety settings to avoid blocking responses. 
+ fn get_safety_settings(&self) -> Vec<GeminiSafetySetting> { + let categories = vec![ + "HARM_CATEGORY_HARASSMENT", + "HARM_CATEGORY_HATE_SPEECH", + "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "HARM_CATEGORY_DANGEROUS_CONTENT", + "HARM_CATEGORY_CIVIC_INTEGRITY", + ]; + + categories + .into_iter() + .map(|c| GeminiSafetySetting { + category: c.to_string(), + threshold: "BLOCK_NONE".to_string(), + }) + .collect() + } } #[async_trait] @@ -548,23 +585,18 @@ impl super::Provider for GeminiProvider { let tool_config = Self::convert_tool_config(&request); let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?; - if contents.is_empty() { + if contents.is_empty() && system_instruction.is_none() { return Err(AppError::ProviderError("No valid messages to send".to_string())); } - let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() { - // Some Gemini models (especially 1.5) have lower max_output_tokens limits (e.g. 8192) - // than what clients like opencode might request. Clamp to a safe maximum. - // Note: Gemini 2.0+ supports much higher limits, but 8192 is a safe universal floor. 
- let max_tokens = request.max_tokens.map(|t| t.min(8192)); - - Some(GeminiGenerationConfig { - temperature: request.temperature, - max_output_tokens: max_tokens, - }) - } else { - None - }; + let generation_config = Some(GeminiGenerationConfig { + temperature: request.temperature, + top_p: request.top_p, + top_k: request.top_k, + max_output_tokens: request.max_tokens.map(|t| t.min(8192)), + stop_sequences: request.stop, + candidate_count: request.n, + }); let gemini_request = GeminiRequest { contents, @@ -572,6 +604,7 @@ impl super::Provider for GeminiProvider { generation_config, tools, tool_config, + safety_settings: Some(self.get_safety_settings()), }; let base_url = self.get_base_url(&model); @@ -692,22 +725,18 @@ impl super::Provider for GeminiProvider { let tool_config = Self::convert_tool_config(&request); let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?; - if contents.is_empty() { + if contents.is_empty() && system_instruction.is_none() { return Err(AppError::ProviderError("No valid messages to send".to_string())); } - let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() { - // Some Gemini models (especially 1.5) have lower max_output_tokens limits (e.g. 8192) - // than what clients like opencode might request. Clamp to a safe maximum. 
- let max_tokens = request.max_tokens.map(|t| t.min(8192)); - - Some(GeminiGenerationConfig { - temperature: request.temperature, - max_output_tokens: max_tokens, - }) - } else { - None - }; + let generation_config = Some(GeminiGenerationConfig { + temperature: request.temperature, + top_p: request.top_p, + top_k: request.top_k, + max_output_tokens: request.max_tokens.map(|t| t.min(8192)), + stop_sequences: request.stop, + candidate_count: request.n, + }); let gemini_request = GeminiRequest { contents, @@ -715,6 +744,7 @@ impl super::Provider for GeminiProvider { generation_config, tools, tool_config, + safety_settings: Some(self.get_safety_settings()), }; let base_url = self.get_base_url(&model); @@ -735,6 +765,7 @@ impl super::Provider for GeminiProvider { self.client .post(&url) .header("x-goog-api-key", &self.api_key) + .header("Accept", "text/event-stream") .json(&gemini_request), ).map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;