fix(gemini): resolve compilation errors and final parameter alignment
This commit is contained in:
@@ -367,7 +367,13 @@ pub(super) async fn handle_test_provider(
|
|||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
}],
|
}],
|
||||||
temperature: None,
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
top_k: None,
|
||||||
|
n: None,
|
||||||
|
stop: None,
|
||||||
max_tokens: Some(5),
|
max_tokens: Some(5),
|
||||||
|
presence_penalty: None,
|
||||||
|
frequency_penalty: None,
|
||||||
stream: false,
|
stream: false,
|
||||||
has_images: false,
|
has_images: false,
|
||||||
tools: None,
|
tools: None,
|
||||||
|
|||||||
@@ -12,8 +12,20 @@ pub struct ChatCompletionRequest {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub temperature: Option<f64>,
|
pub temperature: Option<f64>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
pub top_p: Option<f64>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub top_k: Option<u32>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub n: Option<u32>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub stop: Option<Value>, // Can be string or array of strings
|
||||||
|
#[serde(default)]
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
pub presence_penalty: Option<f64>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub frequency_penalty: Option<f64>,
|
||||||
|
#[serde(default)]
|
||||||
pub stream: Option<bool>,
|
pub stream: Option<bool>,
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
pub tools: Option<Vec<Tool>>,
|
pub tools: Option<Vec<Tool>>,
|
||||||
@@ -194,7 +206,13 @@ pub struct UnifiedRequest {
|
|||||||
pub model: String,
|
pub model: String,
|
||||||
pub messages: Vec<UnifiedMessage>,
|
pub messages: Vec<UnifiedMessage>,
|
||||||
pub temperature: Option<f64>,
|
pub temperature: Option<f64>,
|
||||||
|
pub top_p: Option<f64>,
|
||||||
|
pub top_k: Option<u32>,
|
||||||
|
pub n: Option<u32>,
|
||||||
|
pub stop: Option<Vec<String>>,
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
|
pub presence_penalty: Option<f64>,
|
||||||
|
pub frequency_penalty: Option<f64>,
|
||||||
pub stream: bool,
|
pub stream: bool,
|
||||||
pub has_images: bool,
|
pub has_images: bool,
|
||||||
pub tools: Option<Vec<Tool>>,
|
pub tools: Option<Vec<Tool>>,
|
||||||
@@ -326,12 +344,28 @@ impl TryFrom<ChatCompletionRequest> for UnifiedRequest {
|
|||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
let stop = match req.stop {
|
||||||
|
Some(Value::String(s)) => Some(vec![s]),
|
||||||
|
Some(Value::Array(a)) => Some(
|
||||||
|
a.into_iter()
|
||||||
|
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
|
||||||
Ok(UnifiedRequest {
|
Ok(UnifiedRequest {
|
||||||
client_id: String::new(), // Will be populated by auth middleware
|
client_id: String::new(), // Will be populated by auth middleware
|
||||||
model: req.model,
|
model: req.model,
|
||||||
messages,
|
messages,
|
||||||
temperature: req.temperature,
|
temperature: req.temperature,
|
||||||
|
top_p: req.top_p,
|
||||||
|
top_k: req.top_k,
|
||||||
|
n: req.n,
|
||||||
|
stop,
|
||||||
max_tokens: req.max_tokens,
|
max_tokens: req.max_tokens,
|
||||||
|
presence_penalty: req.presence_penalty,
|
||||||
|
frequency_penalty: req.frequency_penalty,
|
||||||
stream: req.stream.unwrap_or(false),
|
stream: req.stream.unwrap_or(false),
|
||||||
has_images,
|
has_images,
|
||||||
tools: req.tools,
|
tools: req.tools,
|
||||||
|
|||||||
@@ -27,6 +27,14 @@ struct GeminiRequest {
|
|||||||
tools: Option<Vec<GeminiTool>>,
|
tools: Option<Vec<GeminiTool>>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
tool_config: Option<GeminiToolConfig>,
|
tool_config: Option<GeminiToolConfig>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
safety_settings: Option<Vec<GeminiSafetySetting>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
struct GeminiSafetySetting {
|
||||||
|
category: String,
|
||||||
|
threshold: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -71,8 +79,18 @@ struct GeminiFunctionResponse {
|
|||||||
#[derive(Debug, Clone, Serialize)]
|
#[derive(Debug, Clone, Serialize)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
struct GeminiGenerationConfig {
|
struct GeminiGenerationConfig {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
temperature: Option<f64>,
|
temperature: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
top_p: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
top_k: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
max_output_tokens: Option<u32>,
|
max_output_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
stop_sequences: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
candidate_count: Option<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ========== Gemini Tool Structs ==========
|
// ========== Gemini Tool Structs ==========
|
||||||
@@ -511,6 +529,25 @@ impl GeminiProvider {
|
|||||||
self.config.base_url.clone()
|
self.config.base_url.clone()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Default safety settings to avoid blocking responses.
|
||||||
|
fn get_safety_settings(&self) -> Vec<GeminiSafetySetting> {
|
||||||
|
let categories = vec![
|
||||||
|
"HARM_CATEGORY_HARASSMENT",
|
||||||
|
"HARM_CATEGORY_HATE_SPEECH",
|
||||||
|
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||||
|
"HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||||
|
"HARM_CATEGORY_CIVIC_INTEGRITY",
|
||||||
|
];
|
||||||
|
|
||||||
|
categories
|
||||||
|
.into_iter()
|
||||||
|
.map(|c| GeminiSafetySetting {
|
||||||
|
category: c.to_string(),
|
||||||
|
threshold: "BLOCK_NONE".to_string(),
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@@ -548,23 +585,18 @@ impl super::Provider for GeminiProvider {
|
|||||||
let tool_config = Self::convert_tool_config(&request);
|
let tool_config = Self::convert_tool_config(&request);
|
||||||
let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
|
let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
|
||||||
|
|
||||||
if contents.is_empty() {
|
if contents.is_empty() && system_instruction.is_none() {
|
||||||
return Err(AppError::ProviderError("No valid messages to send".to_string()));
|
return Err(AppError::ProviderError("No valid messages to send".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
|
let generation_config = Some(GeminiGenerationConfig {
|
||||||
// Some Gemini models (especially 1.5) have lower max_output_tokens limits (e.g. 8192)
|
temperature: request.temperature,
|
||||||
// than what clients like opencode might request. Clamp to a safe maximum.
|
top_p: request.top_p,
|
||||||
// Note: Gemini 2.0+ supports much higher limits, but 8192 is a safe universal floor.
|
top_k: request.top_k,
|
||||||
let max_tokens = request.max_tokens.map(|t| t.min(8192));
|
max_output_tokens: request.max_tokens.map(|t| t.min(8192)),
|
||||||
|
stop_sequences: request.stop,
|
||||||
Some(GeminiGenerationConfig {
|
candidate_count: request.n,
|
||||||
temperature: request.temperature,
|
});
|
||||||
max_output_tokens: max_tokens,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
let gemini_request = GeminiRequest {
|
let gemini_request = GeminiRequest {
|
||||||
contents,
|
contents,
|
||||||
@@ -572,6 +604,7 @@ impl super::Provider for GeminiProvider {
|
|||||||
generation_config,
|
generation_config,
|
||||||
tools,
|
tools,
|
||||||
tool_config,
|
tool_config,
|
||||||
|
safety_settings: Some(self.get_safety_settings()),
|
||||||
};
|
};
|
||||||
|
|
||||||
let base_url = self.get_base_url(&model);
|
let base_url = self.get_base_url(&model);
|
||||||
@@ -692,22 +725,18 @@ impl super::Provider for GeminiProvider {
|
|||||||
let tool_config = Self::convert_tool_config(&request);
|
let tool_config = Self::convert_tool_config(&request);
|
||||||
let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
|
let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
|
||||||
|
|
||||||
if contents.is_empty() {
|
if contents.is_empty() && system_instruction.is_none() {
|
||||||
return Err(AppError::ProviderError("No valid messages to send".to_string()));
|
return Err(AppError::ProviderError("No valid messages to send".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
|
let generation_config = Some(GeminiGenerationConfig {
|
||||||
// Some Gemini models (especially 1.5) have lower max_output_tokens limits (e.g. 8192)
|
temperature: request.temperature,
|
||||||
// than what clients like opencode might request. Clamp to a safe maximum.
|
top_p: request.top_p,
|
||||||
let max_tokens = request.max_tokens.map(|t| t.min(8192));
|
top_k: request.top_k,
|
||||||
|
max_output_tokens: request.max_tokens.map(|t| t.min(8192)),
|
||||||
Some(GeminiGenerationConfig {
|
stop_sequences: request.stop,
|
||||||
temperature: request.temperature,
|
candidate_count: request.n,
|
||||||
max_output_tokens: max_tokens,
|
});
|
||||||
})
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
let gemini_request = GeminiRequest {
|
let gemini_request = GeminiRequest {
|
||||||
contents,
|
contents,
|
||||||
@@ -715,6 +744,7 @@ impl super::Provider for GeminiProvider {
|
|||||||
generation_config,
|
generation_config,
|
||||||
tools,
|
tools,
|
||||||
tool_config,
|
tool_config,
|
||||||
|
safety_settings: Some(self.get_safety_settings()),
|
||||||
};
|
};
|
||||||
|
|
||||||
let base_url = self.get_base_url(&model);
|
let base_url = self.get_base_url(&model);
|
||||||
@@ -735,6 +765,7 @@ impl super::Provider for GeminiProvider {
|
|||||||
self.client
|
self.client
|
||||||
.post(&url)
|
.post(&url)
|
||||||
.header("x-goog-api-key", &self.api_key)
|
.header("x-goog-api-key", &self.api_key)
|
||||||
|
.header("Accept", "text/event-stream")
|
||||||
.json(&gemini_request),
|
.json(&gemini_request),
|
||||||
).map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
).map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user