fix(gemini): group adjacent tool messages and ensure correct role sequence
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled

- Group consecutive 'tool' messages into a single Gemini content message with multiple 'functionResponse' parts
- Ensure assistant tool calls are properly mapped and sent
- Maintain v1beta for preview and newer models
- Add debug logging for API errors (status code, response body, and the request body that was sent)
This commit is contained in:
2026-04-07 18:50:48 +00:00
parent e67aafdac1
commit be4ec3482a
+40 -17
View File
@@ -77,36 +77,38 @@ type GeminiFunctionResponse struct {
func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) { func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
// Gemini mapping // Gemini mapping
var contents []GeminiContent var contents []GeminiContent
for _, msg := range req.Messages {
// Group tool messages together for Gemini
for i := 0; i < len(req.Messages); i++ {
msg := req.Messages[i]
role := "user" role := "user"
if msg.Role == "assistant" { if msg.Role == "assistant" {
role = "model" role = "model"
} else if msg.Role == "tool" { } else if msg.Role == "tool" {
role = "function" // Function results use 'function' role in Gemini contents role = "function"
} }
var parts []GeminiPart var parts []GeminiPart
// Handle tool responses
if msg.Role == "tool" { if msg.Role == "tool" {
// Collect this tool message and every tool message that immediately follows it,
// so they are all emitted as parts of a single Gemini content entry.
for j := i; j < len(req.Messages) && req.Messages[j].Role == "tool"; j++ {
m := req.Messages[j]
text := "" text := ""
if len(msg.Content) > 0 { if len(m.Content) > 0 {
text = msg.Content[0].Text text = m.Content[0].Text
} }
// Gemini expects functionResponse to be an object
name := "unknown_function" name := "unknown_function"
if msg.Name != nil { if m.Name != nil {
name = *msg.Name name = *m.Name
} }
// Try to parse text as JSON if it looks like it, Gemini expects an object
var responseObj interface{} var responseObj interface{}
if err := json.Unmarshal([]byte(text), &responseObj); err != nil { if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
// If not valid JSON, wrap it in an object
responseObj = map[string]interface{}{"result": text} responseObj = map[string]interface{}{"result": text}
} }
respBytes, _ := json.Marshal(responseObj) respBytes, _ := json.Marshal(responseObj)
parts = append(parts, GeminiPart{ parts = append(parts, GeminiPart{
@@ -115,6 +117,8 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
Response: json.RawMessage(respBytes), Response: json.RawMessage(respBytes),
}, },
}) })
i = j // Advance outer loop
}
} else { } else {
for _, cp := range msg.Content { for _, cp := range msg.Content {
if cp.Type == "text" { if cp.Type == "text" {
@@ -208,6 +212,10 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
} }
if !resp.IsSuccess() { if !resp.IsSuccess() {
fmt.Printf("[Gemini] API Error %d: %s\n", resp.StatusCode(), resp.String())
// Also log the request body for debugging (careful with API keys if logged elsewhere)
reqJSON, _ := json.Marshal(body)
fmt.Printf("[Gemini] Request Body: %s\n", string(reqJSON))
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String()) return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
} }
@@ -292,7 +300,8 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) { func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
// Simplified Gemini mapping // Simplified Gemini mapping
var contents []GeminiContent var contents []GeminiContent
for _, msg := range req.Messages { for i := 0; i < len(req.Messages); i++ {
msg := req.Messages[i]
role := "user" role := "user"
if msg.Role == "assistant" { if msg.Role == "assistant" {
role = "model" role = "model"
@@ -303,13 +312,15 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
var parts []GeminiPart var parts []GeminiPart
if msg.Role == "tool" { if msg.Role == "tool" {
for j := i; j < len(req.Messages) && req.Messages[j].Role == "tool"; j++ {
m := req.Messages[j]
text := "" text := ""
if len(msg.Content) > 0 { if len(m.Content) > 0 {
text = msg.Content[0].Text text = m.Content[0].Text
} }
name := "unknown" name := "unknown_function"
if msg.Name != nil { if m.Name != nil {
name = *msg.Name name = *m.Name
} }
var responseObj interface{} var responseObj interface{}
@@ -324,9 +335,21 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
Response: json.RawMessage(respBytes), Response: json.RawMessage(respBytes),
}, },
}) })
i = j
}
} else { } else {
for _, p := range msg.Content { for _, p := range msg.Content {
if p.Type == "text" {
parts = append(parts, GeminiPart{Text: p.Text}) parts = append(parts, GeminiPart{Text: p.Text})
} else if p.Image != nil {
base64Data, mimeType, _ := p.Image.ToBase64()
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: base64Data,
},
})
}
} }
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 { if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
for _, tc := range msg.ToolCalls { for _, tc := range msg.ToolCalls {