fix(gemini): ensure strict 1:1 pairing of model calls and function responses
- Gemini requires function results to immediately follow the model message that called them - Implemented look-ahead grouping to pair assistant calls with their tool results - Standardized system and orphaned tool message handling for Gemini compatibility
This commit is contained in:
+132
-83
@@ -78,28 +78,48 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
|||||||
// Gemini mapping
|
// Gemini mapping
|
||||||
var contents []GeminiContent
|
var contents []GeminiContent
|
||||||
|
|
||||||
// Group tool messages together for Gemini
|
|
||||||
for i := 0; i < len(req.Messages); i++ {
|
for i := 0; i < len(req.Messages); i++ {
|
||||||
msg := req.Messages[i]
|
msg := req.Messages[i]
|
||||||
role := "user"
|
|
||||||
if msg.Role == "assistant" {
|
|
||||||
role = "model"
|
|
||||||
} else if msg.Role == "tool" {
|
|
||||||
role = "function"
|
|
||||||
}
|
|
||||||
|
|
||||||
var parts []GeminiPart
|
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
||||||
|
// 1. Add the assistant (model) message with tool calls
|
||||||
if msg.Role == "tool" {
|
parts := []GeminiPart{}
|
||||||
// Check if we can group this with previous tool message
|
for _, cp := range msg.Content {
|
||||||
// Actually, it's easier to just collect all current and subsequent tool messages
|
if cp.Type == "text" && cp.Text != "" {
|
||||||
for j := i; j < len(req.Messages) && req.Messages[j].Role == "tool"; j++ {
|
parts = append(parts, GeminiPart{Text: cp.Text})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
parts = append(parts, GeminiPart{
|
||||||
|
FunctionCall: &GeminiFunctionCall{
|
||||||
|
Name: tc.Function.Name,
|
||||||
|
Args: json.RawMessage(tc.Function.Arguments),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
contents = append(contents, GeminiContent{Role: "model", Parts: parts})
|
||||||
|
|
||||||
|
// 2. The VERY NEXT message MUST be the "function" results for THESE EXACT calls.
|
||||||
|
// Look ahead for tool messages.
|
||||||
|
var functionParts []GeminiPart
|
||||||
|
toolCallIDs := make(map[string]bool)
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
toolCallIDs[tc.ID] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// We need to find tool messages that correspond to these calls.
|
||||||
|
// In many patterns, they follow immediately.
|
||||||
|
j := i + 1
|
||||||
|
foundAny := false
|
||||||
|
for j < len(req.Messages) && req.Messages[j].Role == "tool" {
|
||||||
m := req.Messages[j]
|
m := req.Messages[j]
|
||||||
|
|
||||||
|
// Try to match by ID or just take them in order if IDs are missing/mismatched
|
||||||
|
// Gemini is strict: you must respond to EVERY call in the previous message.
|
||||||
text := ""
|
text := ""
|
||||||
if len(m.Content) > 0 {
|
if len(m.Content) > 0 {
|
||||||
text = m.Content[0].Text
|
text = m.Content[0].Text
|
||||||
}
|
}
|
||||||
|
|
||||||
name := "unknown_function"
|
name := "unknown_function"
|
||||||
if m.Name != nil {
|
if m.Name != nil {
|
||||||
name = *m.Name
|
name = *m.Name
|
||||||
@@ -111,46 +131,58 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
|||||||
}
|
}
|
||||||
respBytes, _ := json.Marshal(responseObj)
|
respBytes, _ := json.Marshal(responseObj)
|
||||||
|
|
||||||
parts = append(parts, GeminiPart{
|
functionParts = append(functionParts, GeminiPart{
|
||||||
FunctionResponse: &GeminiFunctionResponse{
|
FunctionResponse: &GeminiFunctionResponse{
|
||||||
Name: name,
|
Name: name,
|
||||||
Response: json.RawMessage(respBytes),
|
Response: json.RawMessage(respBytes),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
i = j // Advance outer loop
|
foundAny = true
|
||||||
}
|
j++
|
||||||
} else {
|
|
||||||
for _, cp := range msg.Content {
|
|
||||||
if cp.Type == "text" {
|
|
||||||
parts = append(parts, GeminiPart{Text: cp.Text})
|
|
||||||
} else if cp.Image != nil {
|
|
||||||
base64Data, mimeType, _ := cp.Image.ToBase64()
|
|
||||||
parts = append(parts, GeminiPart{
|
|
||||||
InlineData: &GeminiInlineData{
|
|
||||||
MimeType: mimeType,
|
|
||||||
Data: base64Data,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle assistant tool calls
|
if foundAny {
|
||||||
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
contents = append(contents, GeminiContent{Role: "function", Parts: functionParts})
|
||||||
for _, tc := range msg.ToolCalls {
|
i = j - 1 // Advance outer loop past the tool messages we consumed
|
||||||
parts = append(parts, GeminiPart{
|
} else {
|
||||||
FunctionCall: &GeminiFunctionCall{
|
// If no tool results found but assistant made calls, Gemini WILL error.
|
||||||
Name: tc.Function.Name,
|
// We should probably skip the calls or provide dummy results,
|
||||||
Args: json.RawMessage(tc.Function.Arguments),
|
// but usually this means the conversation is incomplete.
|
||||||
},
|
// For now, don't add a "function" message if none found.
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Standard message handling (System/User/Assistant without tools)
|
||||||
|
role := "user"
|
||||||
|
if msg.Role == "assistant" {
|
||||||
|
role = "model"
|
||||||
|
} else if msg.Role == "system" {
|
||||||
|
role = "user" // Gemini uses 'user' for system prompts in some versions, or handles it via systemInstruction
|
||||||
|
} else if msg.Role == "tool" {
|
||||||
|
// Orphaned tool message (not following an assistant call) - Gemini doesn't like this.
|
||||||
|
// Skip or map to user? Skipping is safer for API stability.
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
contents = append(contents, GeminiContent{
|
var parts []GeminiPart
|
||||||
Role: role,
|
for _, cp := range msg.Content {
|
||||||
Parts: parts,
|
if cp.Type == "text" && cp.Text != "" {
|
||||||
})
|
parts = append(parts, GeminiPart{Text: cp.Text})
|
||||||
|
} else if cp.Image != nil {
|
||||||
|
base64Data, mimeType, _ := cp.Image.ToBase64()
|
||||||
|
parts = append(parts, GeminiPart{
|
||||||
|
InlineData: &GeminiInlineData{
|
||||||
|
MimeType: mimeType,
|
||||||
|
Data: base64Data,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parts) > 0 {
|
||||||
|
contents = append(contents, GeminiContent{Role: role, Parts: parts})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
genConfig := &GeminiGenerationConfig{}
|
genConfig := &GeminiGenerationConfig{}
|
||||||
@@ -302,17 +334,28 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
|||||||
var contents []GeminiContent
|
var contents []GeminiContent
|
||||||
for i := 0; i < len(req.Messages); i++ {
|
for i := 0; i < len(req.Messages); i++ {
|
||||||
msg := req.Messages[i]
|
msg := req.Messages[i]
|
||||||
role := "user"
|
|
||||||
if msg.Role == "assistant" {
|
|
||||||
role = "model"
|
|
||||||
} else if msg.Role == "tool" {
|
|
||||||
role = "function"
|
|
||||||
}
|
|
||||||
|
|
||||||
var parts []GeminiPart
|
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
||||||
|
parts := []GeminiPart{}
|
||||||
if msg.Role == "tool" {
|
for _, cp := range msg.Content {
|
||||||
for j := i; j < len(req.Messages) && req.Messages[j].Role == "tool"; j++ {
|
if cp.Type == "text" && cp.Text != "" {
|
||||||
|
parts = append(parts, GeminiPart{Text: cp.Text})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
parts = append(parts, GeminiPart{
|
||||||
|
FunctionCall: &GeminiFunctionCall{
|
||||||
|
Name: tc.Function.Name,
|
||||||
|
Args: json.RawMessage(tc.Function.Arguments),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
contents = append(contents, GeminiContent{Role: "model", Parts: parts})
|
||||||
|
|
||||||
|
var functionParts []GeminiPart
|
||||||
|
j := i + 1
|
||||||
|
foundAny := false
|
||||||
|
for j < len(req.Messages) && req.Messages[j].Role == "tool" {
|
||||||
m := req.Messages[j]
|
m := req.Messages[j]
|
||||||
text := ""
|
text := ""
|
||||||
if len(m.Content) > 0 {
|
if len(m.Content) > 0 {
|
||||||
@@ -329,44 +372,50 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
|||||||
}
|
}
|
||||||
respBytes, _ := json.Marshal(responseObj)
|
respBytes, _ := json.Marshal(responseObj)
|
||||||
|
|
||||||
parts = append(parts, GeminiPart{
|
functionParts = append(functionParts, GeminiPart{
|
||||||
FunctionResponse: &GeminiFunctionResponse{
|
FunctionResponse: &GeminiFunctionResponse{
|
||||||
Name: name,
|
Name: name,
|
||||||
Response: json.RawMessage(respBytes),
|
Response: json.RawMessage(respBytes),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
i = j
|
foundAny = true
|
||||||
|
j++
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
for _, p := range msg.Content {
|
if foundAny {
|
||||||
if p.Type == "text" {
|
contents = append(contents, GeminiContent{Role: "function", Parts: functionParts})
|
||||||
parts = append(parts, GeminiPart{Text: p.Text})
|
i = j - 1
|
||||||
} else if p.Image != nil {
|
|
||||||
base64Data, mimeType, _ := p.Image.ToBase64()
|
|
||||||
parts = append(parts, GeminiPart{
|
|
||||||
InlineData: &GeminiInlineData{
|
|
||||||
MimeType: mimeType,
|
|
||||||
Data: base64Data,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
|
||||||
for _, tc := range msg.ToolCalls {
|
|
||||||
parts = append(parts, GeminiPart{
|
|
||||||
FunctionCall: &GeminiFunctionCall{
|
|
||||||
Name: tc.Function.Name,
|
|
||||||
Args: json.RawMessage(tc.Function.Arguments),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
role := "user"
|
||||||
|
if msg.Role == "assistant" {
|
||||||
|
role = "model"
|
||||||
|
} else if msg.Role == "system" {
|
||||||
|
role = "user"
|
||||||
|
} else if msg.Role == "tool" {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
contents = append(contents, GeminiContent{
|
var parts []GeminiPart
|
||||||
Role: role,
|
for _, cp := range msg.Content {
|
||||||
Parts: parts,
|
if cp.Type == "text" && cp.Text != "" {
|
||||||
})
|
parts = append(parts, GeminiPart{Text: cp.Text})
|
||||||
|
} else if cp.Image != nil {
|
||||||
|
base64Data, mimeType, _ := cp.Image.ToBase64()
|
||||||
|
parts = append(parts, GeminiPart{
|
||||||
|
InlineData: &GeminiInlineData{
|
||||||
|
MimeType: mimeType,
|
||||||
|
Data: base64Data,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parts) > 0 {
|
||||||
|
contents = append(contents, GeminiContent{Role: role, Parts: parts})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
genConfig := &GeminiGenerationConfig{}
|
genConfig := &GeminiGenerationConfig{}
|
||||||
|
|||||||
Reference in New Issue
Block a user