Files
GopherGate/src/providers/gemini.rs
hobokenchicken 9318336f62
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled
feat: add tool-calling passthrough for all providers
Implement full OpenAI-compatible tool-calling support across the proxy,
enabling OpenCode to use llm-proxy as its sole LLM backend.

- Add 9 tool-calling types (Tool, FunctionDef, ToolChoice, ToolCall, etc.)
- Update ChatCompletionRequest/ChatMessage/ChatStreamDelta with tool fields
- Update UnifiedRequest/UnifiedMessage to carry tool data through the pipeline
- Shared helpers: messages_to_openai_json handles tool messages, build_openai_body
  includes tools/tool_choice, parse/stream extract tool_calls from responses
- Gemini: full OpenAI<->Gemini format translation (functionDeclarations,
  functionCall/functionResponse, synthetic call IDs, tool_config mapping)
- Gemini: extract duplicated message-conversion into shared convert_messages()
- Server: SSE streams include tool_calls deltas, finish_reason='tool_calls'
- AggregatingStream: accumulate tool call deltas across stream chunks
- OpenAI provider: add o4- prefix to supports_model()
2026-03-02 09:40:57 -05:00

574 lines
19 KiB
Rust

use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use uuid::Uuid;
use super::{ProviderResponse, ProviderStreamChunk};
use crate::{
config::AppConfig,
errors::AppError,
models::{ContentPart, FunctionCall, FunctionCallDelta, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest},
};
// ========== Gemini Request Structs ==========
/// Request body sent to Gemini's `:generateContent` and
/// `:streamGenerateContent` endpoints.
///
/// Field names serialize in camelCase; optional sections are left out of the
/// JSON entirely when they are `None`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiRequest {
// Conversation turns, in order.
contents: Vec<GeminiContent>,
// Sampling settings; only built when the caller set temperature or max_tokens.
#[serde(skip_serializing_if = "Option::is_none")]
generation_config: Option<GeminiGenerationConfig>,
// Function declarations translated from the OpenAI `tools` array.
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<GeminiTool>>,
// Function-calling mode translated from the OpenAI `tool_choice`.
#[serde(skip_serializing_if = "Option::is_none")]
tool_config: Option<GeminiToolConfig>,
}
#[derive(Debug, Serialize, Deserialize)]
struct GeminiContent {
parts: Vec<GeminiPart>,
role: String,
}
/// A single piece of a turn's content.
///
/// Every field is optional and omitted from the JSON when `None`. The parts this
/// provider builds always set exactly one of them.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiPart {
// Plain text.
#[serde(skip_serializing_if = "Option::is_none")]
text: Option<String>,
// Base64 binary payload (images from multimodal messages).
#[serde(skip_serializing_if = "Option::is_none")]
inline_data: Option<GeminiInlineData>,
// A function invocation (assistant tool call in requests, model output in responses).
#[serde(skip_serializing_if = "Option::is_none")]
function_call: Option<GeminiFunctionCall>,
// The result of a function invocation (from OpenAI `tool` messages).
#[serde(skip_serializing_if = "Option::is_none")]
function_response: Option<GeminiFunctionResponse>,
}
#[derive(Debug, Serialize, Deserialize)]
struct GeminiInlineData {
mime_type: String,
data: String,
}
/// A function invocation carried inside a part.
///
/// Built from OpenAI `tool_calls` when sending assistant history, and read back
/// from model responses to produce OpenAI-format tool calls.
#[derive(Debug, Serialize, Deserialize)]
struct GeminiFunctionCall {
// Name of the function being called.
name: String,
// Call arguments as a JSON value; when built from an OpenAI `arguments`
// string that is not valid JSON, this is `{}`.
args: Value,
}
/// The result of a function invocation, sent back to the model inside a
/// `"user"` turn.
#[derive(Debug, Serialize, Deserialize)]
struct GeminiFunctionResponse {
// Name of the function whose result this is.
name: String,
// Result payload, derived from the tool message's text content.
response: Value,
}
/// Sampling settings for a request.
///
/// Only built when at least one setting is present. The unset one is omitted
/// from the JSON, matching how every other optional field in this file is
/// serialized, rather than sent as an explicit `null`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiGenerationConfig {
    /// Sampling temperature, forwarded from the OpenAI request.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Output-token cap, forwarded from the OpenAI `max_tokens`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u32>,
}
// ========== Gemini Tool Structs ==========
/// One entry of Gemini's `tools` array: the function declarations the model may call.
#[derive(Debug, Serialize)]
struct GeminiTool {
    #[serde(rename = "functionDeclarations")]
    function_declarations: Vec<GeminiFunctionDeclaration>,
}

/// A single callable function, translated from an OpenAI tool definition.
#[derive(Debug, Serialize)]
struct GeminiFunctionDeclaration {
    /// Function name the model uses when calling it.
    name: String,
    /// Human-readable description, omitted when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    /// Schema of the arguments object, omitted when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    parameters: Option<Value>,
}

/// Gemini `toolConfig`: controls whether and which functions must be called.
#[derive(Debug, Serialize)]
struct GeminiToolConfig {
    #[serde(rename = "functionCallingConfig")]
    function_calling_config: GeminiFunctionCallingConfig,
}

/// Calling mode (`AUTO`, `ANY` or `NONE` as produced by `convert_tool_config`)
/// plus an optional allow-list of function names.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiFunctionCallingConfig {
    mode: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    allowed_function_names: Option<Vec<String>>,
}
// ========== Gemini Response Structs ==========
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiCandidate {
content: GeminiContent,
#[serde(default)]
finish_reason: Option<String>,
}
/// Token accounting reported by Gemini; each count defaults to 0 when absent.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiUsageMetadata {
// Tokens in the prompt; reported as the OpenAI `prompt_tokens`.
#[serde(default)]
prompt_token_count: u32,
// Tokens in the generated candidates; reported as `completion_tokens`.
#[serde(default)]
candidates_token_count: u32,
// Total as reported by Gemini (passed through, not recomputed here).
#[serde(default)]
total_token_count: u32,
}
/// A `generateContent` response, or one SSE chunk of `streamGenerateContent`.
///
/// `candidates` is a required field here: a body without it fails to parse.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiResponse {
// Generated candidates; callers only look at the first one.
candidates: Vec<GeminiCandidate>,
// Token usage, when present.
usage_metadata: Option<GeminiUsageMetadata>,
}
// ========== Provider Implementation ==========
/// Provider that translates OpenAI-style unified requests to the Google Gemini
/// REST API and back.
pub struct GeminiProvider {
// Shared HTTP client used for every request.
client: reqwest::Client,
// Gemini-specific settings; `base_url` is the prefix for `/models/...` URLs.
config: crate::config::GeminiConfig,
// Sent as the `x-goog-api-key` header.
api_key: String,
// Gemini pricing table from the app config, used by `calculate_cost`.
pricing: Vec<crate::config::ModelPricing>,
}
impl GeminiProvider {
pub fn new(config: &crate::config::GeminiConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("gemini")?;
Self::new_with_key(config, app_config, api_key)
}
pub fn new_with_key(config: &crate::config::GeminiConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()?;
Ok(Self {
client,
config: config.clone(),
api_key,
pricing: app_config.pricing.gemini.clone(),
})
}
/// Convert unified messages to Gemini content format.
/// Handles text, images, tool calls (assistant), and tool results.
async fn convert_messages(messages: Vec<UnifiedMessage>) -> Result<Vec<GeminiContent>, AppError> {
let mut contents = Vec::with_capacity(messages.len());
for msg in messages {
// Tool-result messages → functionResponse parts under role "user"
if msg.role == "tool" {
let text_content = msg
.content
.first()
.map(|p| match p {
ContentPart::Text { text } => text.clone(),
ContentPart::Image(_) => "[Image]".to_string(),
})
.unwrap_or_default();
let name = msg.name.unwrap_or_default();
// Parse the content as JSON if possible, otherwise wrap as string
let response_value = serde_json::from_str::<Value>(&text_content)
.unwrap_or_else(|_| serde_json::json!({ "result": text_content }));
contents.push(GeminiContent {
parts: vec![GeminiPart {
text: None,
inline_data: None,
function_call: None,
function_response: Some(GeminiFunctionResponse {
name,
response: response_value,
}),
}],
role: "user".to_string(),
});
continue;
}
// Assistant messages with tool_calls → functionCall parts
if msg.role == "assistant" {
if let Some(tool_calls) = &msg.tool_calls {
let mut parts = Vec::new();
// Include text content if present
for p in &msg.content {
if let ContentPart::Text { text } = p {
if !text.is_empty() {
parts.push(GeminiPart {
text: Some(text.clone()),
inline_data: None,
function_call: None,
function_response: None,
});
}
}
}
// Convert each tool call to a functionCall part
for tc in tool_calls {
let args = serde_json::from_str::<Value>(&tc.function.arguments)
.unwrap_or_else(|_| serde_json::json!({}));
parts.push(GeminiPart {
text: None,
inline_data: None,
function_call: Some(GeminiFunctionCall {
name: tc.function.name.clone(),
args,
}),
function_response: None,
});
}
contents.push(GeminiContent {
parts,
role: "model".to_string(),
});
continue;
}
}
// Regular text/image messages
let mut parts = Vec::with_capacity(msg.content.len());
for part in msg.content {
match part {
ContentPart::Text { text } => {
parts.push(GeminiPart {
text: Some(text),
inline_data: None,
function_call: None,
function_response: None,
});
}
ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input
.to_base64()
.await
.map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;
parts.push(GeminiPart {
text: None,
inline_data: Some(GeminiInlineData {
mime_type,
data: base64_data,
}),
function_call: None,
function_response: None,
});
}
}
}
let role = match msg.role.as_str() {
"assistant" => "model".to_string(),
_ => "user".to_string(),
};
contents.push(GeminiContent { parts, role });
}
Ok(contents)
}
/// Convert OpenAI tools to Gemini function declarations.
fn convert_tools(request: &UnifiedRequest) -> Option<Vec<GeminiTool>> {
request.tools.as_ref().map(|tools| {
let declarations: Vec<GeminiFunctionDeclaration> = tools
.iter()
.map(|t| GeminiFunctionDeclaration {
name: t.function.name.clone(),
description: t.function.description.clone(),
parameters: t.function.parameters.clone(),
})
.collect();
vec![GeminiTool {
function_declarations: declarations,
}]
})
}
/// Convert OpenAI tool_choice to Gemini tool_config.
fn convert_tool_config(request: &UnifiedRequest) -> Option<GeminiToolConfig> {
request.tool_choice.as_ref().map(|tc| {
let (mode, allowed_names) = match tc {
crate::models::ToolChoice::Mode(mode) => {
let gemini_mode = match mode.as_str() {
"auto" => "AUTO",
"none" => "NONE",
"required" => "ANY",
_ => "AUTO",
};
(gemini_mode.to_string(), None)
}
crate::models::ToolChoice::Specific(specific) => {
("ANY".to_string(), Some(vec![specific.function.name.clone()]))
}
};
GeminiToolConfig {
function_calling_config: GeminiFunctionCallingConfig {
mode,
allowed_function_names: allowed_names,
},
}
})
}
/// Extract tool calls from Gemini response parts into OpenAI-format ToolCalls.
fn extract_tool_calls(parts: &[GeminiPart]) -> Option<Vec<ToolCall>> {
let calls: Vec<ToolCall> = parts
.iter()
.filter_map(|p| p.function_call.as_ref())
.map(|fc| ToolCall {
id: format!("call_{}", Uuid::new_v4().simple()),
call_type: "function".to_string(),
function: FunctionCall {
name: fc.name.clone(),
arguments: serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string()),
},
})
.collect();
if calls.is_empty() { None } else { Some(calls) }
}
/// Extract tool call deltas from Gemini response parts for streaming.
fn extract_tool_call_deltas(parts: &[GeminiPart]) -> Option<Vec<ToolCallDelta>> {
let deltas: Vec<ToolCallDelta> = parts
.iter()
.filter_map(|p| p.function_call.as_ref())
.enumerate()
.map(|(i, fc)| ToolCallDelta {
index: i as u32,
id: Some(format!("call_{}", Uuid::new_v4().simple())),
call_type: Some("function".to_string()),
function: Some(FunctionCallDelta {
name: Some(fc.name.clone()),
arguments: Some(serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string())),
}),
})
.collect();
if deltas.is_empty() { None } else { Some(deltas) }
}
}
#[async_trait]
impl super::Provider for GeminiProvider {
fn name(&self) -> &str {
"gemini"
}
fn supports_model(&self, model: &str) -> bool {
model.starts_with("gemini-")
}
fn supports_multimodal(&self) -> bool {
true // Gemini supports vision
}
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
let model = request.model.clone();
let tools = Self::convert_tools(&request);
let tool_config = Self::convert_tool_config(&request);
let contents = Self::convert_messages(request.messages).await?;
if contents.is_empty() {
return Err(AppError::ProviderError("No valid messages to send".to_string()));
}
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
Some(GeminiGenerationConfig {
temperature: request.temperature,
max_output_tokens: request.max_tokens,
})
} else {
None
};
let gemini_request = GeminiRequest {
contents,
generation_config,
tools,
tool_config,
};
let url = format!("{}/models/{}:generateContent", self.config.base_url, model);
let response = self
.client
.post(&url)
.header("x-goog-api-key", &self.api_key)
.json(&gemini_request)
.send()
.await
.map_err(|e| AppError::ProviderError(format!("HTTP request failed: {}", e)))?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!(
"Gemini API error ({}): {}",
status, error_text
)));
}
let gemini_response: GeminiResponse = response
.json()
.await
.map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?;
let candidate = gemini_response.candidates.first();
// Extract text content (may be absent if only function calls)
let content = candidate
.and_then(|c| c.content.parts.iter().find_map(|p| p.text.clone()))
.unwrap_or_default();
// Extract function calls → OpenAI tool_calls
let tool_calls = candidate.and_then(|c| Self::extract_tool_calls(&c.content.parts));
let prompt_tokens = gemini_response
.usage_metadata
.as_ref()
.map(|u| u.prompt_token_count)
.unwrap_or(0);
let completion_tokens = gemini_response
.usage_metadata
.as_ref()
.map(|u| u.candidates_token_count)
.unwrap_or(0);
let total_tokens = gemini_response
.usage_metadata
.as_ref()
.map(|u| u.total_token_count)
.unwrap_or(0);
Ok(ProviderResponse {
content,
reasoning_content: None,
tool_calls,
prompt_tokens,
completion_tokens,
total_tokens,
model,
})
}
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
fn calculate_cost(
&self,
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
registry: &crate::models::registry::ModelRegistry,
) -> f64 {
super::helpers::calculate_cost_with_registry(
model,
prompt_tokens,
completion_tokens,
registry,
&self.pricing,
0.075,
0.30,
)
}
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
let model = request.model.clone();
let tools = Self::convert_tools(&request);
let tool_config = Self::convert_tool_config(&request);
let contents = Self::convert_messages(request.messages).await?;
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
Some(GeminiGenerationConfig {
temperature: request.temperature,
max_output_tokens: request.max_tokens,
})
} else {
None
};
let gemini_request = GeminiRequest {
contents,
generation_config,
tools,
tool_config,
};
let url = format!(
"{}/models/{}:streamGenerateContent?alt=sse",
self.config.base_url, model,
);
use futures::StreamExt;
use reqwest_eventsource::{Event, EventSource};
let es = EventSource::new(
self.client
.post(&url)
.header("x-goog-api-key", &self.api_key)
.json(&gemini_request),
)
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
let gemini_response: GeminiResponse = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(candidate) = gemini_response.candidates.first() {
let content = candidate
.content
.parts
.iter()
.find_map(|p| p.text.clone())
.unwrap_or_default();
let tool_calls = Self::extract_tool_call_deltas(&candidate.content.parts);
// Determine finish_reason
let finish_reason = candidate.finish_reason.as_ref().map(|fr| {
match fr.as_str() {
"STOP" => "stop".to_string(),
_ => fr.to_lowercase(),
}
});
yield ProviderStreamChunk {
content,
reasoning_content: None,
finish_reason,
tool_calls,
model: model.clone(),
};
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Ok(Box::pin(stream))
}
}