fix(gemini): improve tool-calling support and handle function_call response
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled

- Support tool definitions in Gemini requests
- Map tool role to 'function' in Gemini content
- Ensure tool results are wrapped in JSON objects for Gemini compatibility
- Parse FunctionCall from Gemini response and map to OpenAI-compatible ToolCalls
- Correctly map finish_reason for tool calls
This commit is contained in:
2026-04-07 18:37:57 +00:00
parent 21e5204abd
commit e67aafdac1
+101 -6
View File
@@ -31,9 +31,14 @@ func (p *GeminiProvider) Name() string {
type GeminiRequest struct { type GeminiRequest struct {
Contents []GeminiContent `json:"contents"` Contents []GeminiContent `json:"contents"`
Tools []GeminiTool `json:"tools,omitempty"`
GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"` GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
} }
type GeminiTool struct {
FunctionDeclarations []models.FunctionDef `json:"functionDeclarations"`
}
type GeminiGenerationConfig struct { type GeminiGenerationConfig struct {
Temperature *float32 `json:"temperature,omitempty"` Temperature *float32 `json:"temperature,omitempty"`
TopP *float32 `json:"topP,omitempty"` TopP *float32 `json:"topP,omitempty"`
@@ -77,7 +82,7 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
if msg.Role == "assistant" { if msg.Role == "assistant" {
role = "model" role = "model"
} else if msg.Role == "tool" { } else if msg.Role == "tool" {
role = "user" // Tool results are user-side in Gemini role = "function" // Function results use 'function' role in Gemini contents
} }
var parts []GeminiPart var parts []GeminiPart
@@ -95,10 +100,19 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
name = *msg.Name name = *msg.Name
} }
// Try to parse text as JSON if it looks like it, Gemini expects an object
var responseObj interface{}
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
// If not valid JSON, wrap it in an object
responseObj = map[string]interface{}{"result": text}
}
respBytes, _ := json.Marshal(responseObj)
parts = append(parts, GeminiPart{ parts = append(parts, GeminiPart{
FunctionResponse: &GeminiFunctionResponse{ FunctionResponse: &GeminiFunctionResponse{
Name: name, Name: name,
Response: json.RawMessage(text), Response: json.RawMessage(respBytes),
}, },
}) })
} else { } else {
@@ -161,6 +175,17 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
GenerationConfig: genConfig, GenerationConfig: genConfig,
} }
// Map Tools
if len(req.Tools) > 0 {
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
for _, t := range req.Tools {
if t.Type == "function" {
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
}
}
body.Tools = []GeminiTool{geminiTool}
}
baseURL := p.config.BaseURL baseURL := p.config.BaseURL
lowerModel := strings.ToLower(req.Model) lowerModel := strings.ToLower(req.Model)
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") { if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
@@ -190,8 +215,10 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
var geminiResp struct { var geminiResp struct {
Candidates []struct { Candidates []struct {
Content struct { Content struct {
Role string `json:"role"`
Parts []struct { Parts []struct {
Text string `json:"text"` Text string `json:"text"`
FunctionCall *GeminiFunctionCall `json:"functionCall"`
} `json:"parts"` } `json:"parts"`
} `json:"content"` } `json:"content"`
FinishReason string `json:"finishReason"` FinishReason string `json:"finishReason"`
@@ -212,14 +239,34 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
} }
content := "" content := ""
for _, p := range geminiResp.Candidates[0].Content.Parts { var toolCalls []models.ToolCall
content += p.Text for _, part := range geminiResp.Candidates[0].Content.Parts {
if part.Text != "" {
content += part.Text
}
if part.FunctionCall != nil {
toolCalls = append(toolCalls, models.ToolCall{
ID: fmt.Sprintf("call_%s", part.FunctionCall.Name), // Gemini doesn't have call IDs
Type: "function",
Function: models.FunctionCall{
Name: part.FunctionCall.Name,
Arguments: string(part.FunctionCall.Args),
},
})
}
}
finishReason := strings.ToLower(geminiResp.Candidates[0].FinishReason)
if finishReason == "stop" {
finishReason = "stop"
} else if len(toolCalls) > 0 {
finishReason = "tool_calls"
} }
openAIResp := &models.ChatCompletionResponse{ openAIResp := &models.ChatCompletionResponse{
ID: "gemini-" + req.Model, ID: "gemini-" + req.Model,
Object: "chat.completion", Object: "chat.completion",
Created: 0, // Should be current timestamp Created: 0,
Model: req.Model, Model: req.Model,
Choices: []models.ChatChoice{ Choices: []models.ChatChoice{
{ {
@@ -227,8 +274,9 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
Message: models.ChatMessage{ Message: models.ChatMessage{
Role: "assistant", Role: "assistant",
Content: content, Content: content,
ToolCalls: toolCalls,
}, },
FinishReason: &geminiResp.Candidates[0].FinishReason, FinishReason: &finishReason,
}, },
}, },
Usage: &models.Usage{ Usage: &models.Usage{
@@ -248,12 +296,49 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
role := "user" role := "user"
if msg.Role == "assistant" { if msg.Role == "assistant" {
role = "model" role = "model"
} else if msg.Role == "tool" {
role = "function"
} }
var parts []GeminiPart var parts []GeminiPart
if msg.Role == "tool" {
text := ""
if len(msg.Content) > 0 {
text = msg.Content[0].Text
}
name := "unknown"
if msg.Name != nil {
name = *msg.Name
}
var responseObj interface{}
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
responseObj = map[string]interface{}{"result": text}
}
respBytes, _ := json.Marshal(responseObj)
parts = append(parts, GeminiPart{
FunctionResponse: &GeminiFunctionResponse{
Name: name,
Response: json.RawMessage(respBytes),
},
})
} else {
for _, p := range msg.Content { for _, p := range msg.Content {
parts = append(parts, GeminiPart{Text: p.Text}) parts = append(parts, GeminiPart{Text: p.Text})
} }
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
for _, tc := range msg.ToolCalls {
parts = append(parts, GeminiPart{
FunctionCall: &GeminiFunctionCall{
Name: tc.Function.Name,
Args: json.RawMessage(tc.Function.Arguments),
},
})
}
}
}
contents = append(contents, GeminiContent{ contents = append(contents, GeminiContent{
Role: role, Role: role,
@@ -287,6 +372,16 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
GenerationConfig: genConfig, GenerationConfig: genConfig,
} }
if len(req.Tools) > 0 {
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
for _, t := range req.Tools {
if t.Type == "function" {
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
}
}
body.Tools = []GeminiTool{geminiTool}
}
baseURL := p.config.BaseURL baseURL := p.config.BaseURL
lowerModel := strings.ToLower(req.Model) lowerModel := strings.ToLower(req.Model)
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") { if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {