fix(gemini): resolve compilation errors and final parameter alignment
This commit is contained in:
@@ -367,7 +367,13 @@ pub(super) async fn handle_test_provider(
|
||||
tool_call_id: None,
|
||||
}],
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
top_k: None,
|
||||
n: None,
|
||||
stop: None,
|
||||
max_tokens: Some(5),
|
||||
presence_penalty: None,
|
||||
frequency_penalty: None,
|
||||
stream: false,
|
||||
has_images: false,
|
||||
tools: None,
|
||||
|
||||
@@ -12,8 +12,20 @@ pub struct ChatCompletionRequest {
|
||||
#[serde(default)]
|
||||
pub temperature: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub top_p: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub top_k: Option<u32>,
|
||||
#[serde(default)]
|
||||
pub n: Option<u32>,
|
||||
#[serde(default)]
|
||||
pub stop: Option<Value>, // Can be string or array of strings
|
||||
#[serde(default)]
|
||||
pub max_tokens: Option<u32>,
|
||||
#[serde(default)]
|
||||
pub presence_penalty: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub frequency_penalty: Option<f64>,
|
||||
#[serde(default)]
|
||||
pub stream: Option<bool>,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<Tool>>,
|
||||
@@ -194,7 +206,13 @@ pub struct UnifiedRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<UnifiedMessage>,
|
||||
pub temperature: Option<f64>,
|
||||
pub top_p: Option<f64>,
|
||||
pub top_k: Option<u32>,
|
||||
pub n: Option<u32>,
|
||||
pub stop: Option<Vec<String>>,
|
||||
pub max_tokens: Option<u32>,
|
||||
pub presence_penalty: Option<f64>,
|
||||
pub frequency_penalty: Option<f64>,
|
||||
pub stream: bool,
|
||||
pub has_images: bool,
|
||||
pub tools: Option<Vec<Tool>>,
|
||||
@@ -326,12 +344,28 @@ impl TryFrom<ChatCompletionRequest> for UnifiedRequest {
|
||||
})
|
||||
.collect();
|
||||
|
||||
let stop = match req.stop {
|
||||
Some(Value::String(s)) => Some(vec![s]),
|
||||
Some(Value::Array(a)) => Some(
|
||||
a.into_iter()
|
||||
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
||||
.collect(),
|
||||
),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
Ok(UnifiedRequest {
|
||||
client_id: String::new(), // Will be populated by auth middleware
|
||||
model: req.model,
|
||||
messages,
|
||||
temperature: req.temperature,
|
||||
top_p: req.top_p,
|
||||
top_k: req.top_k,
|
||||
n: req.n,
|
||||
stop,
|
||||
max_tokens: req.max_tokens,
|
||||
presence_penalty: req.presence_penalty,
|
||||
frequency_penalty: req.frequency_penalty,
|
||||
stream: req.stream.unwrap_or(false),
|
||||
has_images,
|
||||
tools: req.tools,
|
||||
|
||||
@@ -27,6 +27,14 @@ struct GeminiRequest {
|
||||
tools: Option<Vec<GeminiTool>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_config: Option<GeminiToolConfig>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
safety_settings: Option<Vec<GeminiSafetySetting>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct GeminiSafetySetting {
|
||||
category: String,
|
||||
threshold: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -71,8 +79,18 @@ struct GeminiFunctionResponse {
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct GeminiGenerationConfig {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
temperature: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
top_p: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
top_k: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
max_output_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
stop_sequences: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
candidate_count: Option<u32>,
|
||||
}
|
||||
|
||||
// ========== Gemini Tool Structs ==========
|
||||
@@ -511,6 +529,25 @@ impl GeminiProvider {
|
||||
self.config.base_url.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Default safety settings to avoid blocking responses.
|
||||
fn get_safety_settings(&self) -> Vec<GeminiSafetySetting> {
|
||||
let categories = vec![
|
||||
"HARM_CATEGORY_HARASSMENT",
|
||||
"HARM_CATEGORY_HATE_SPEECH",
|
||||
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"HARM_CATEGORY_CIVIC_INTEGRITY",
|
||||
];
|
||||
|
||||
categories
|
||||
.into_iter()
|
||||
.map(|c| GeminiSafetySetting {
|
||||
category: c.to_string(),
|
||||
threshold: "BLOCK_NONE".to_string(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -548,23 +585,18 @@ impl super::Provider for GeminiProvider {
|
||||
let tool_config = Self::convert_tool_config(&request);
|
||||
let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
|
||||
|
||||
if contents.is_empty() {
|
||||
if contents.is_empty() && system_instruction.is_none() {
|
||||
return Err(AppError::ProviderError("No valid messages to send".to_string()));
|
||||
}
|
||||
|
||||
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
|
||||
// Some Gemini models (especially 1.5) have lower max_output_tokens limits (e.g. 8192)
|
||||
// than what clients like opencode might request. Clamp to a safe maximum.
|
||||
// Note: Gemini 2.0+ supports much higher limits, but 8192 is a cap that every model accepts.
|
||||
let max_tokens = request.max_tokens.map(|t| t.min(8192));
|
||||
|
||||
Some(GeminiGenerationConfig {
|
||||
let generation_config = Some(GeminiGenerationConfig {
|
||||
temperature: request.temperature,
|
||||
max_output_tokens: max_tokens,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
top_p: request.top_p,
|
||||
top_k: request.top_k,
|
||||
max_output_tokens: request.max_tokens.map(|t| t.min(8192)),
|
||||
stop_sequences: request.stop,
|
||||
candidate_count: request.n,
|
||||
});
|
||||
|
||||
let gemini_request = GeminiRequest {
|
||||
contents,
|
||||
@@ -572,6 +604,7 @@ impl super::Provider for GeminiProvider {
|
||||
generation_config,
|
||||
tools,
|
||||
tool_config,
|
||||
safety_settings: Some(self.get_safety_settings()),
|
||||
};
|
||||
|
||||
let base_url = self.get_base_url(&model);
|
||||
@@ -692,22 +725,18 @@ impl super::Provider for GeminiProvider {
|
||||
let tool_config = Self::convert_tool_config(&request);
|
||||
let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
|
||||
|
||||
if contents.is_empty() {
|
||||
if contents.is_empty() && system_instruction.is_none() {
|
||||
return Err(AppError::ProviderError("No valid messages to send".to_string()));
|
||||
}
|
||||
|
||||
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
|
||||
// Some Gemini models (especially 1.5) have lower max_output_tokens limits (e.g. 8192)
|
||||
// than what clients like opencode might request. Clamp to a safe maximum.
|
||||
let max_tokens = request.max_tokens.map(|t| t.min(8192));
|
||||
|
||||
Some(GeminiGenerationConfig {
|
||||
let generation_config = Some(GeminiGenerationConfig {
|
||||
temperature: request.temperature,
|
||||
max_output_tokens: max_tokens,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
top_p: request.top_p,
|
||||
top_k: request.top_k,
|
||||
max_output_tokens: request.max_tokens.map(|t| t.min(8192)),
|
||||
stop_sequences: request.stop,
|
||||
candidate_count: request.n,
|
||||
});
|
||||
|
||||
let gemini_request = GeminiRequest {
|
||||
contents,
|
||||
@@ -715,6 +744,7 @@ impl super::Provider for GeminiProvider {
|
||||
generation_config,
|
||||
tools,
|
||||
tool_config,
|
||||
safety_settings: Some(self.get_safety_settings()),
|
||||
};
|
||||
|
||||
let base_url = self.get_base_url(&model);
|
||||
@@ -735,6 +765,7 @@ impl super::Provider for GeminiProvider {
|
||||
self.client
|
||||
.post(&url)
|
||||
.header("x-goog-api-key", &self.api_key)
|
||||
.header("Accept", "text/event-stream")
|
||||
.json(&gemini_request),
|
||||
).map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user