fix(gemini): ensure strict 1:1 pairing of model calls and function responses
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled

- Gemini requires function results to immediately follow the model message that called them
- Implemented look-ahead grouping to pair assistant calls with their tool results
- Standardized system and orphaned tool message handling for Gemini compatibility
This commit is contained in:
2026-04-07 18:57:13 +00:00
parent be4ec3482a
commit e12418cc4c
+113 -64
View File
@@ -78,28 +78,48 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
// Gemini mapping // Gemini mapping
var contents []GeminiContent var contents []GeminiContent
// Group tool messages together for Gemini
for i := 0; i < len(req.Messages); i++ { for i := 0; i < len(req.Messages); i++ {
msg := req.Messages[i] msg := req.Messages[i]
role := "user"
if msg.Role == "assistant" { if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
role = "model" // 1. Add the assistant (model) message with tool calls
} else if msg.Role == "tool" { parts := []GeminiPart{}
role = "function" for _, cp := range msg.Content {
if cp.Type == "text" && cp.Text != "" {
parts = append(parts, GeminiPart{Text: cp.Text})
}
}
for _, tc := range msg.ToolCalls {
parts = append(parts, GeminiPart{
FunctionCall: &GeminiFunctionCall{
Name: tc.Function.Name,
Args: json.RawMessage(tc.Function.Arguments),
},
})
}
contents = append(contents, GeminiContent{Role: "model", Parts: parts})
// 2. The VERY NEXT message MUST be the "function" results for THESE EXACT calls.
// Look ahead for tool messages.
var functionParts []GeminiPart
toolCallIDs := make(map[string]bool)
for _, tc := range msg.ToolCalls {
toolCallIDs[tc.ID] = true
} }
var parts []GeminiPart // We need to find tool messages that correspond to these calls.
// In many patterns, they follow immediately.
if msg.Role == "tool" { j := i + 1
// Check if we can group this with previous tool message foundAny := false
// Actually, it's easier to just collect all current and subsequent tool messages for j < len(req.Messages) && req.Messages[j].Role == "tool" {
for j := i; j < len(req.Messages) && req.Messages[j].Role == "tool"; j++ {
m := req.Messages[j] m := req.Messages[j]
// Try to match by ID or just take them in order if IDs are missing/mismatched
// Gemini is strict: you must respond to EVERY call in the previous message.
text := "" text := ""
if len(m.Content) > 0 { if len(m.Content) > 0 {
text = m.Content[0].Text text = m.Content[0].Text
} }
name := "unknown_function" name := "unknown_function"
if m.Name != nil { if m.Name != nil {
name = *m.Name name = *m.Name
@@ -111,17 +131,43 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
} }
respBytes, _ := json.Marshal(responseObj) respBytes, _ := json.Marshal(responseObj)
parts = append(parts, GeminiPart{ functionParts = append(functionParts, GeminiPart{
FunctionResponse: &GeminiFunctionResponse{ FunctionResponse: &GeminiFunctionResponse{
Name: name, Name: name,
Response: json.RawMessage(respBytes), Response: json.RawMessage(respBytes),
}, },
}) })
i = j // Advance outer loop foundAny = true
j++
} }
if foundAny {
contents = append(contents, GeminiContent{Role: "function", Parts: functionParts})
i = j - 1 // Advance outer loop past the tool messages we consumed
} else { } else {
// If no tool results found but assistant made calls, Gemini WILL error.
// We should probably skip the calls or provide dummy results,
// but usually this means the conversation is incomplete.
// For now, don't add a "function" message if none found.
}
continue
}
// Standard message handling (System/User/Assistant without tools)
role := "user"
if msg.Role == "assistant" {
role = "model"
} else if msg.Role == "system" {
role = "user" // Gemini uses 'user' for system prompts in some versions, or handles it via systemInstruction
} else if msg.Role == "tool" {
// Orphaned tool message (not following an assistant call) - Gemini doesn't like this.
// Skip or map to user? Skipping is safer for API stability.
continue
}
var parts []GeminiPart
for _, cp := range msg.Content { for _, cp := range msg.Content {
if cp.Type == "text" { if cp.Type == "text" && cp.Text != "" {
parts = append(parts, GeminiPart{Text: cp.Text}) parts = append(parts, GeminiPart{Text: cp.Text})
} else if cp.Image != nil { } else if cp.Image != nil {
base64Data, mimeType, _ := cp.Image.ToBase64() base64Data, mimeType, _ := cp.Image.ToBase64()
@@ -134,24 +180,10 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
} }
} }
// Handle assistant tool calls if len(parts) > 0 {
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 { contents = append(contents, GeminiContent{Role: role, Parts: parts})
for _, tc := range msg.ToolCalls {
parts = append(parts, GeminiPart{
FunctionCall: &GeminiFunctionCall{
Name: tc.Function.Name,
Args: json.RawMessage(tc.Function.Arguments),
},
})
} }
} }
}
contents = append(contents, GeminiContent{
Role: role,
Parts: parts,
})
}
genConfig := &GeminiGenerationConfig{} genConfig := &GeminiGenerationConfig{}
if req.Temperature != nil { if req.Temperature != nil {
@@ -302,17 +334,28 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
var contents []GeminiContent var contents []GeminiContent
for i := 0; i < len(req.Messages); i++ { for i := 0; i < len(req.Messages); i++ {
msg := req.Messages[i] msg := req.Messages[i]
role := "user"
if msg.Role == "assistant" { if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
role = "model" parts := []GeminiPart{}
} else if msg.Role == "tool" { for _, cp := range msg.Content {
role = "function" if cp.Type == "text" && cp.Text != "" {
parts = append(parts, GeminiPart{Text: cp.Text})
} }
}
for _, tc := range msg.ToolCalls {
parts = append(parts, GeminiPart{
FunctionCall: &GeminiFunctionCall{
Name: tc.Function.Name,
Args: json.RawMessage(tc.Function.Arguments),
},
})
}
contents = append(contents, GeminiContent{Role: "model", Parts: parts})
var parts []GeminiPart var functionParts []GeminiPart
j := i + 1
if msg.Role == "tool" { foundAny := false
for j := i; j < len(req.Messages) && req.Messages[j].Role == "tool"; j++ { for j < len(req.Messages) && req.Messages[j].Role == "tool" {
m := req.Messages[j] m := req.Messages[j]
text := "" text := ""
if len(m.Content) > 0 { if len(m.Content) > 0 {
@@ -329,20 +372,38 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
} }
respBytes, _ := json.Marshal(responseObj) respBytes, _ := json.Marshal(responseObj)
parts = append(parts, GeminiPart{ functionParts = append(functionParts, GeminiPart{
FunctionResponse: &GeminiFunctionResponse{ FunctionResponse: &GeminiFunctionResponse{
Name: name, Name: name,
Response: json.RawMessage(respBytes), Response: json.RawMessage(respBytes),
}, },
}) })
i = j foundAny = true
j++
} }
} else {
for _, p := range msg.Content { if foundAny {
if p.Type == "text" { contents = append(contents, GeminiContent{Role: "function", Parts: functionParts})
parts = append(parts, GeminiPart{Text: p.Text}) i = j - 1
} else if p.Image != nil { }
base64Data, mimeType, _ := p.Image.ToBase64() continue
}
role := "user"
if msg.Role == "assistant" {
role = "model"
} else if msg.Role == "system" {
role = "user"
} else if msg.Role == "tool" {
continue
}
var parts []GeminiPart
for _, cp := range msg.Content {
if cp.Type == "text" && cp.Text != "" {
parts = append(parts, GeminiPart{Text: cp.Text})
} else if cp.Image != nil {
base64Data, mimeType, _ := cp.Image.ToBase64()
parts = append(parts, GeminiPart{ parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{ InlineData: &GeminiInlineData{
MimeType: mimeType, MimeType: mimeType,
@@ -351,22 +412,10 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
}) })
} }
} }
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
for _, tc := range msg.ToolCalls {
parts = append(parts, GeminiPart{
FunctionCall: &GeminiFunctionCall{
Name: tc.Function.Name,
Args: json.RawMessage(tc.Function.Arguments),
},
})
}
}
}
contents = append(contents, GeminiContent{ if len(parts) > 0 {
Role: role, contents = append(contents, GeminiContent{Role: role, Parts: parts})
Parts: parts, }
})
} }
genConfig := &GeminiGenerationConfig{} genConfig := &GeminiGenerationConfig{}