fix(gemini): improve tool-calling support and handle function_call response
- Support tool definitions in Gemini requests
- Map tool role to 'function' in Gemini content
- Ensure tool results are wrapped in JSON objects for Gemini compatibility
- Parse FunctionCall from Gemini response and map to OpenAI-compatible ToolCalls
- Correctly map finish_reason for tool calls
This commit is contained in:
+106
-11
@@ -31,9 +31,14 @@ func (p *GeminiProvider) Name() string {
|
||||
|
||||
// GeminiRequest is the JSON payload that bundles the conversation contents,
// optional tool declarations, and optional generation settings.
// NOTE(review): presumably this is the request body sent to the Gemini API —
// confirm against the code that marshals it.
type GeminiRequest struct {
|
||||
// Contents is always serialized, under the "contents" key.
Contents []GeminiContent `json:"contents"`
|
||||
// Tools is serialized under "tools" and left out of the JSON when empty.
Tools []GeminiTool `json:"tools,omitempty"`
|
||||
// GenerationConfig is serialized under "generationConfig" and left out of
// the JSON when the pointer is nil.
GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiTool groups a list of function definitions into a single tool entry
// of a GeminiRequest.
type GeminiTool struct {
|
||||
// FunctionDeclarations is serialized under "functionDeclarations"; it has
// no omitempty, so the key is always present in the JSON output.
FunctionDeclarations []models.FunctionDef `json:"functionDeclarations"`
|
||||
}
|
||||
|
||||
type GeminiGenerationConfig struct {
|
||||
Temperature *float32 `json:"temperature,omitempty"`
|
||||
TopP *float32 `json:"topP,omitempty"`
|
||||
@@ -77,7 +82,7 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
if msg.Role == "assistant" {
|
||||
role = "model"
|
||||
} else if msg.Role == "tool" {
|
||||
role = "user" // Tool results are user-side in Gemini
|
||||
role = "function" // Function results use 'function' role in Gemini contents
|
||||
}
|
||||
|
||||
var parts []GeminiPart
|
||||
@@ -95,10 +100,19 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
name = *msg.Name
|
||||
}
|
||||
|
||||
// Try to parse text as JSON if it looks like it, Gemini expects an object
|
||||
var responseObj interface{}
|
||||
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
|
||||
// If not valid JSON, wrap it in an object
|
||||
responseObj = map[string]interface{}{"result": text}
|
||||
}
|
||||
|
||||
respBytes, _ := json.Marshal(responseObj)
|
||||
|
||||
parts = append(parts, GeminiPart{
|
||||
FunctionResponse: &GeminiFunctionResponse{
|
||||
Name: name,
|
||||
Response: json.RawMessage(text),
|
||||
Response: json.RawMessage(respBytes),
|
||||
},
|
||||
})
|
||||
} else {
|
||||
@@ -161,6 +175,17 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
GenerationConfig: genConfig,
|
||||
}
|
||||
|
||||
// Map Tools
|
||||
if len(req.Tools) > 0 {
|
||||
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
|
||||
for _, t := range req.Tools {
|
||||
if t.Type == "function" {
|
||||
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
|
||||
}
|
||||
}
|
||||
body.Tools = []GeminiTool{geminiTool}
|
||||
}
|
||||
|
||||
baseURL := p.config.BaseURL
|
||||
lowerModel := strings.ToLower(req.Model)
|
||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
|
||||
@@ -190,8 +215,10 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
var geminiResp struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Role string `json:"role"`
|
||||
Parts []struct {
|
||||
Text string `json:"text"`
|
||||
Text string `json:"text"`
|
||||
FunctionCall *GeminiFunctionCall `json:"functionCall"`
|
||||
} `json:"parts"`
|
||||
} `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
@@ -212,23 +239,44 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
}
|
||||
|
||||
content := ""
|
||||
for _, p := range geminiResp.Candidates[0].Content.Parts {
|
||||
content += p.Text
|
||||
var toolCalls []models.ToolCall
|
||||
for _, part := range geminiResp.Candidates[0].Content.Parts {
|
||||
if part.Text != "" {
|
||||
content += part.Text
|
||||
}
|
||||
if part.FunctionCall != nil {
|
||||
toolCalls = append(toolCalls, models.ToolCall{
|
||||
ID: fmt.Sprintf("call_%s", part.FunctionCall.Name), // Gemini doesn't have call IDs
|
||||
Type: "function",
|
||||
Function: models.FunctionCall{
|
||||
Name: part.FunctionCall.Name,
|
||||
Arguments: string(part.FunctionCall.Args),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
finishReason := strings.ToLower(geminiResp.Candidates[0].FinishReason)
|
||||
if finishReason == "stop" {
|
||||
finishReason = "stop"
|
||||
} else if len(toolCalls) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
|
||||
openAIResp := &models.ChatCompletionResponse{
|
||||
ID: "gemini-" + req.Model,
|
||||
Object: "chat.completion",
|
||||
Created: 0, // Should be current timestamp
|
||||
Created: 0,
|
||||
Model: req.Model,
|
||||
Choices: []models.ChatChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: models.ChatMessage{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
},
|
||||
FinishReason: &geminiResp.Candidates[0].FinishReason,
|
||||
FinishReason: &finishReason,
|
||||
},
|
||||
},
|
||||
Usage: &models.Usage{
|
||||
@@ -248,11 +296,48 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
||||
role := "user"
|
||||
if msg.Role == "assistant" {
|
||||
role = "model"
|
||||
} else if msg.Role == "tool" {
|
||||
role = "function"
|
||||
}
|
||||
|
||||
var parts []GeminiPart
|
||||
for _, p := range msg.Content {
|
||||
parts = append(parts, GeminiPart{Text: p.Text})
|
||||
|
||||
if msg.Role == "tool" {
|
||||
text := ""
|
||||
if len(msg.Content) > 0 {
|
||||
text = msg.Content[0].Text
|
||||
}
|
||||
name := "unknown"
|
||||
if msg.Name != nil {
|
||||
name = *msg.Name
|
||||
}
|
||||
|
||||
var responseObj interface{}
|
||||
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
|
||||
responseObj = map[string]interface{}{"result": text}
|
||||
}
|
||||
respBytes, _ := json.Marshal(responseObj)
|
||||
|
||||
parts = append(parts, GeminiPart{
|
||||
FunctionResponse: &GeminiFunctionResponse{
|
||||
Name: name,
|
||||
Response: json.RawMessage(respBytes),
|
||||
},
|
||||
})
|
||||
} else {
|
||||
for _, p := range msg.Content {
|
||||
parts = append(parts, GeminiPart{Text: p.Text})
|
||||
}
|
||||
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
parts = append(parts, GeminiPart{
|
||||
FunctionCall: &GeminiFunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Args: json.RawMessage(tc.Function.Arguments),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
contents = append(contents, GeminiContent{
|
||||
@@ -287,6 +372,16 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
||||
GenerationConfig: genConfig,
|
||||
}
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
|
||||
for _, t := range req.Tools {
|
||||
if t.Type == "function" {
|
||||
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
|
||||
}
|
||||
}
|
||||
body.Tools = []GeminiTool{geminiTool}
|
||||
}
|
||||
|
||||
baseURL := p.config.BaseURL
|
||||
lowerModel := strings.ToLower(req.Model)
|
||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
|
||||
|
||||
Reference in New Issue
Block a user