fix(gemini): improve tool-calling support and handle function_call response
- Support tool definitions in Gemini requests
- Map the tool role to 'function' in Gemini content
- Ensure tool results are wrapped in JSON objects for Gemini compatibility
- Parse FunctionCall from the Gemini response and map it to OpenAI-compatible ToolCalls
- Correctly map finish_reason for tool calls
This commit is contained in:
@@ -31,9 +31,14 @@ func (p *GeminiProvider) Name() string {
|
|||||||
|
|
||||||
type GeminiRequest struct {
|
type GeminiRequest struct {
|
||||||
Contents []GeminiContent `json:"contents"`
|
Contents []GeminiContent `json:"contents"`
|
||||||
|
Tools []GeminiTool `json:"tools,omitempty"`
|
||||||
GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
|
GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GeminiTool is the element type of GeminiRequest.Tools (JSON key "tools").
// It carries a list of function definitions, serialized under the
// "functionDeclarations" key.
type GeminiTool struct {
|
||||||
|
FunctionDeclarations []models.FunctionDef `json:"functionDeclarations"`
|
||||||
|
}
|
||||||
|
|
||||||
type GeminiGenerationConfig struct {
|
type GeminiGenerationConfig struct {
|
||||||
Temperature *float32 `json:"temperature,omitempty"`
|
Temperature *float32 `json:"temperature,omitempty"`
|
||||||
TopP *float32 `json:"topP,omitempty"`
|
TopP *float32 `json:"topP,omitempty"`
|
||||||
@@ -77,7 +82,7 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
|||||||
if msg.Role == "assistant" {
|
if msg.Role == "assistant" {
|
||||||
role = "model"
|
role = "model"
|
||||||
} else if msg.Role == "tool" {
|
} else if msg.Role == "tool" {
|
||||||
role = "user" // Tool results are user-side in Gemini
|
role = "function" // Function results use 'function' role in Gemini contents
|
||||||
}
|
}
|
||||||
|
|
||||||
var parts []GeminiPart
|
var parts []GeminiPart
|
||||||
@@ -95,10 +100,19 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
|||||||
name = *msg.Name
|
name = *msg.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Try to parse text as JSON if it looks like it, Gemini expects an object
|
||||||
|
var responseObj interface{}
|
||||||
|
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
|
||||||
|
// If not valid JSON, wrap it in an object
|
||||||
|
responseObj = map[string]interface{}{"result": text}
|
||||||
|
}
|
||||||
|
|
||||||
|
respBytes, _ := json.Marshal(responseObj)
|
||||||
|
|
||||||
parts = append(parts, GeminiPart{
|
parts = append(parts, GeminiPart{
|
||||||
FunctionResponse: &GeminiFunctionResponse{
|
FunctionResponse: &GeminiFunctionResponse{
|
||||||
Name: name,
|
Name: name,
|
||||||
Response: json.RawMessage(text),
|
Response: json.RawMessage(respBytes),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
@@ -161,6 +175,17 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
|||||||
GenerationConfig: genConfig,
|
GenerationConfig: genConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Map Tools
|
||||||
|
if len(req.Tools) > 0 {
|
||||||
|
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
|
||||||
|
for _, t := range req.Tools {
|
||||||
|
if t.Type == "function" {
|
||||||
|
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
body.Tools = []GeminiTool{geminiTool}
|
||||||
|
}
|
||||||
|
|
||||||
baseURL := p.config.BaseURL
|
baseURL := p.config.BaseURL
|
||||||
lowerModel := strings.ToLower(req.Model)
|
lowerModel := strings.ToLower(req.Model)
|
||||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
|
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
|
||||||
@@ -190,8 +215,10 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
|||||||
var geminiResp struct {
|
var geminiResp struct {
|
||||||
Candidates []struct {
|
Candidates []struct {
|
||||||
Content struct {
|
Content struct {
|
||||||
|
Role string `json:"role"`
|
||||||
Parts []struct {
|
Parts []struct {
|
||||||
Text string `json:"text"`
|
Text string `json:"text"`
|
||||||
|
FunctionCall *GeminiFunctionCall `json:"functionCall"`
|
||||||
} `json:"parts"`
|
} `json:"parts"`
|
||||||
} `json:"content"`
|
} `json:"content"`
|
||||||
FinishReason string `json:"finishReason"`
|
FinishReason string `json:"finishReason"`
|
||||||
@@ -212,14 +239,34 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
|||||||
}
|
}
|
||||||
|
|
||||||
content := ""
|
content := ""
|
||||||
for _, p := range geminiResp.Candidates[0].Content.Parts {
|
var toolCalls []models.ToolCall
|
||||||
content += p.Text
|
for _, part := range geminiResp.Candidates[0].Content.Parts {
|
||||||
|
if part.Text != "" {
|
||||||
|
content += part.Text
|
||||||
|
}
|
||||||
|
if part.FunctionCall != nil {
|
||||||
|
toolCalls = append(toolCalls, models.ToolCall{
|
||||||
|
ID: fmt.Sprintf("call_%s", part.FunctionCall.Name), // Gemini doesn't have call IDs
|
||||||
|
Type: "function",
|
||||||
|
Function: models.FunctionCall{
|
||||||
|
Name: part.FunctionCall.Name,
|
||||||
|
Arguments: string(part.FunctionCall.Args),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
finishReason := strings.ToLower(geminiResp.Candidates[0].FinishReason)
|
||||||
|
if finishReason == "stop" {
|
||||||
|
finishReason = "stop"
|
||||||
|
} else if len(toolCalls) > 0 {
|
||||||
|
finishReason = "tool_calls"
|
||||||
}
|
}
|
||||||
|
|
||||||
openAIResp := &models.ChatCompletionResponse{
|
openAIResp := &models.ChatCompletionResponse{
|
||||||
ID: "gemini-" + req.Model,
|
ID: "gemini-" + req.Model,
|
||||||
Object: "chat.completion",
|
Object: "chat.completion",
|
||||||
Created: 0, // Should be current timestamp
|
Created: 0,
|
||||||
Model: req.Model,
|
Model: req.Model,
|
||||||
Choices: []models.ChatChoice{
|
Choices: []models.ChatChoice{
|
||||||
{
|
{
|
||||||
@@ -227,8 +274,9 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
|||||||
Message: models.ChatMessage{
|
Message: models.ChatMessage{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: content,
|
Content: content,
|
||||||
|
ToolCalls: toolCalls,
|
||||||
},
|
},
|
||||||
FinishReason: &geminiResp.Candidates[0].FinishReason,
|
FinishReason: &finishReason,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Usage: &models.Usage{
|
Usage: &models.Usage{
|
||||||
@@ -248,12 +296,49 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
|||||||
role := "user"
|
role := "user"
|
||||||
if msg.Role == "assistant" {
|
if msg.Role == "assistant" {
|
||||||
role = "model"
|
role = "model"
|
||||||
|
} else if msg.Role == "tool" {
|
||||||
|
role = "function"
|
||||||
}
|
}
|
||||||
|
|
||||||
var parts []GeminiPart
|
var parts []GeminiPart
|
||||||
|
|
||||||
|
if msg.Role == "tool" {
|
||||||
|
text := ""
|
||||||
|
if len(msg.Content) > 0 {
|
||||||
|
text = msg.Content[0].Text
|
||||||
|
}
|
||||||
|
name := "unknown"
|
||||||
|
if msg.Name != nil {
|
||||||
|
name = *msg.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
var responseObj interface{}
|
||||||
|
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
|
||||||
|
responseObj = map[string]interface{}{"result": text}
|
||||||
|
}
|
||||||
|
respBytes, _ := json.Marshal(responseObj)
|
||||||
|
|
||||||
|
parts = append(parts, GeminiPart{
|
||||||
|
FunctionResponse: &GeminiFunctionResponse{
|
||||||
|
Name: name,
|
||||||
|
Response: json.RawMessage(respBytes),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
} else {
|
||||||
for _, p := range msg.Content {
|
for _, p := range msg.Content {
|
||||||
parts = append(parts, GeminiPart{Text: p.Text})
|
parts = append(parts, GeminiPart{Text: p.Text})
|
||||||
}
|
}
|
||||||
|
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
parts = append(parts, GeminiPart{
|
||||||
|
FunctionCall: &GeminiFunctionCall{
|
||||||
|
Name: tc.Function.Name,
|
||||||
|
Args: json.RawMessage(tc.Function.Arguments),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
contents = append(contents, GeminiContent{
|
contents = append(contents, GeminiContent{
|
||||||
Role: role,
|
Role: role,
|
||||||
@@ -287,6 +372,16 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
|||||||
GenerationConfig: genConfig,
|
GenerationConfig: genConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(req.Tools) > 0 {
|
||||||
|
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
|
||||||
|
for _, t := range req.Tools {
|
||||||
|
if t.Type == "function" {
|
||||||
|
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
body.Tools = []GeminiTool{geminiTool}
|
||||||
|
}
|
||||||
|
|
||||||
baseURL := p.config.BaseURL
|
baseURL := p.config.BaseURL
|
||||||
lowerModel := strings.ToLower(req.Model)
|
lowerModel := strings.ToLower(req.Model)
|
||||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
|
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
|
||||||
|
|||||||
Reference in New Issue
Block a user