fix(gemini): resolve compilation errors and final parameter alignment
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled

This commit is contained in:
2026-03-05 15:57:33 +00:00
parent 3086a3b6d9
commit f8598060f9
3 changed files with 98 additions and 27 deletions

View File

@@ -367,7 +367,13 @@ pub(super) async fn handle_test_provider(
tool_call_id: None,
}],
temperature: None,
top_p: None,
top_k: None,
n: None,
stop: None,
max_tokens: Some(5),
presence_penalty: None,
frequency_penalty: None,
stream: false,
has_images: false,
tools: None,

View File

@@ -12,8 +12,20 @@ pub struct ChatCompletionRequest {
#[serde(default)]
pub temperature: Option<f64>,
#[serde(default)]
pub top_p: Option<f64>,
#[serde(default)]
pub top_k: Option<u32>,
#[serde(default)]
pub n: Option<u32>,
#[serde(default)]
pub stop: Option<Value>, // Can be string or array of strings
#[serde(default)]
pub max_tokens: Option<u32>,
#[serde(default)]
pub presence_penalty: Option<f64>,
#[serde(default)]
pub frequency_penalty: Option<f64>,
#[serde(default)]
pub stream: Option<bool>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
@@ -194,7 +206,13 @@ pub struct UnifiedRequest {
pub model: String,
pub messages: Vec<UnifiedMessage>,
pub temperature: Option<f64>,
pub top_p: Option<f64>,
pub top_k: Option<u32>,
pub n: Option<u32>,
pub stop: Option<Vec<String>>,
pub max_tokens: Option<u32>,
pub presence_penalty: Option<f64>,
pub frequency_penalty: Option<f64>,
pub stream: bool,
pub has_images: bool,
pub tools: Option<Vec<Tool>>,
@@ -326,12 +344,28 @@ impl TryFrom<ChatCompletionRequest> for UnifiedRequest {
})
.collect();
let stop = match req.stop {
Some(Value::String(s)) => Some(vec![s]),
Some(Value::Array(a)) => Some(
a.into_iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect(),
),
_ => None,
};
Ok(UnifiedRequest {
client_id: String::new(), // Will be populated by auth middleware
model: req.model,
messages,
temperature: req.temperature,
top_p: req.top_p,
top_k: req.top_k,
n: req.n,
stop,
max_tokens: req.max_tokens,
presence_penalty: req.presence_penalty,
frequency_penalty: req.frequency_penalty,
stream: req.stream.unwrap_or(false),
has_images,
tools: req.tools,

View File

@@ -27,6 +27,14 @@ struct GeminiRequest {
tools: Option<Vec<GeminiTool>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_config: Option<GeminiToolConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
safety_settings: Option<Vec<GeminiSafetySetting>>,
}
#[derive(Debug, Clone, Serialize)]
struct GeminiSafetySetting {
category: String,
threshold: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -71,8 +79,18 @@ struct GeminiFunctionResponse {
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiGenerationConfig {
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
top_k: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
max_output_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
stop_sequences: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
candidate_count: Option<u32>,
}
// ========== Gemini Tool Structs ==========
@@ -511,6 +529,25 @@ impl GeminiProvider {
self.config.base_url.clone()
}
}
/// Default safety settings to avoid blocking responses.
fn get_safety_settings(&self) -> Vec<GeminiSafetySetting> {
let categories = vec![
"HARM_CATEGORY_HARASSMENT",
"HARM_CATEGORY_HATE_SPEECH",
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
"HARM_CATEGORY_DANGEROUS_CONTENT",
"HARM_CATEGORY_CIVIC_INTEGRITY",
];
categories
.into_iter()
.map(|c| GeminiSafetySetting {
category: c.to_string(),
threshold: "BLOCK_NONE".to_string(),
})
.collect()
}
}
#[async_trait]
@@ -548,23 +585,18 @@ impl super::Provider for GeminiProvider {
let tool_config = Self::convert_tool_config(&request);
let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
if contents.is_empty() {
if contents.is_empty() && system_instruction.is_none() {
return Err(AppError::ProviderError("No valid messages to send".to_string()));
}
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
// Some Gemini models (especially 1.5) have lower max_output_tokens limits (e.g. 8192)
// than what clients like opencode might request. Clamp to a safe maximum.
// Note: Gemini 2.0+ supports much higher limits, but 8192 is a safe universal floor.
let max_tokens = request.max_tokens.map(|t| t.min(8192));
Some(GeminiGenerationConfig {
let generation_config = Some(GeminiGenerationConfig {
temperature: request.temperature,
max_output_tokens: max_tokens,
})
} else {
None
};
top_p: request.top_p,
top_k: request.top_k,
max_output_tokens: request.max_tokens.map(|t| t.min(8192)),
stop_sequences: request.stop,
candidate_count: request.n,
});
let gemini_request = GeminiRequest {
contents,
@@ -572,6 +604,7 @@ impl super::Provider for GeminiProvider {
generation_config,
tools,
tool_config,
safety_settings: Some(self.get_safety_settings()),
};
let base_url = self.get_base_url(&model);
@@ -692,22 +725,18 @@ impl super::Provider for GeminiProvider {
let tool_config = Self::convert_tool_config(&request);
let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
if contents.is_empty() {
if contents.is_empty() && system_instruction.is_none() {
return Err(AppError::ProviderError("No valid messages to send".to_string()));
}
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
// Some Gemini models (especially 1.5) have lower max_output_tokens limits (e.g. 8192)
// than what clients like opencode might request. Clamp to a safe maximum.
let max_tokens = request.max_tokens.map(|t| t.min(8192));
Some(GeminiGenerationConfig {
let generation_config = Some(GeminiGenerationConfig {
temperature: request.temperature,
max_output_tokens: max_tokens,
})
} else {
None
};
top_p: request.top_p,
top_k: request.top_k,
max_output_tokens: request.max_tokens.map(|t| t.min(8192)),
stop_sequences: request.stop,
candidate_count: request.n,
});
let gemini_request = GeminiRequest {
contents,
@@ -715,6 +744,7 @@ impl super::Provider for GeminiProvider {
generation_config,
tools,
tool_config,
safety_settings: Some(self.get_safety_settings()),
};
let base_url = self.get_base_url(&model);
@@ -735,6 +765,7 @@ impl super::Provider for GeminiProvider {
self.client
.post(&url)
.header("x-goog-api-key", &self.api_key)
.header("Accept", "text/event-stream")
.json(&gemini_request),
).map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;