fix(gemini): group adjacent tool messages and ensure correct role sequence
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled

- Group consecutive 'tool' messages into a single Gemini content message with multiple 'functionResponse' parts
- Ensure assistant tool calls are properly mapped and sent
- Keep using the v1beta API version for preview and newer models
- Add debug logging for API errors (status code, response body, and request body)
This commit is contained in:
2026-04-07 18:50:48 +00:00
parent e67aafdac1
commit be4ec3482a
+73 -50
View File
@@ -77,44 +77,48 @@ type GeminiFunctionResponse struct {
func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) { func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
// Gemini mapping // Gemini mapping
var contents []GeminiContent var contents []GeminiContent
for _, msg := range req.Messages {
// Group tool messages together for Gemini
for i := 0; i < len(req.Messages); i++ {
msg := req.Messages[i]
role := "user" role := "user"
if msg.Role == "assistant" { if msg.Role == "assistant" {
role = "model" role = "model"
} else if msg.Role == "tool" { } else if msg.Role == "tool" {
role = "function" // Function results use 'function' role in Gemini contents role = "function"
} }
var parts []GeminiPart var parts []GeminiPart
// Handle tool responses
if msg.Role == "tool" { if msg.Role == "tool" {
text := "" // Check if we can group this with previous tool message
if len(msg.Content) > 0 { // Actually, it's easier to just collect all current and subsequent tool messages
text = msg.Content[0].Text for j := i; j < len(req.Messages) && req.Messages[j].Role == "tool"; j++ {
} m := req.Messages[j]
text := ""
// Gemini expects functionResponse to be an object if len(m.Content) > 0 {
name := "unknown_function" text = m.Content[0].Text
if msg.Name != nil { }
name = *msg.Name
} name := "unknown_function"
if m.Name != nil {
// Try to parse text as JSON if it looks like it, Gemini expects an object name = *m.Name
var responseObj interface{} }
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
// If not valid JSON, wrap it in an object var responseObj interface{}
responseObj = map[string]interface{}{"result": text} if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
} responseObj = map[string]interface{}{"result": text}
}
respBytes, _ := json.Marshal(responseObj)
respBytes, _ := json.Marshal(responseObj) parts = append(parts, GeminiPart{
FunctionResponse: &GeminiFunctionResponse{
parts = append(parts, GeminiPart{ Name: name,
FunctionResponse: &GeminiFunctionResponse{ Response: json.RawMessage(respBytes),
Name: name, },
Response: json.RawMessage(respBytes), })
}, i = j // Advance outer loop
}) }
} else { } else {
for _, cp := range msg.Content { for _, cp := range msg.Content {
if cp.Type == "text" { if cp.Type == "text" {
@@ -208,6 +212,10 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
} }
if !resp.IsSuccess() { if !resp.IsSuccess() {
fmt.Printf("[Gemini] API Error %d: %s\n", resp.StatusCode(), resp.String())
// Also log the request body for debugging (careful with API keys if logged elsewhere)
reqJSON, _ := json.Marshal(body)
fmt.Printf("[Gemini] Request Body: %s\n", string(reqJSON))
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String()) return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
} }
@@ -292,7 +300,8 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) { func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
// Simplified Gemini mapping // Simplified Gemini mapping
var contents []GeminiContent var contents []GeminiContent
for _, msg := range req.Messages { for i := 0; i < len(req.Messages); i++ {
msg := req.Messages[i]
role := "user" role := "user"
if msg.Role == "assistant" { if msg.Role == "assistant" {
role = "model" role = "model"
@@ -303,30 +312,44 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
var parts []GeminiPart var parts []GeminiPart
if msg.Role == "tool" { if msg.Role == "tool" {
text := "" for j := i; j < len(req.Messages) && req.Messages[j].Role == "tool"; j++ {
if len(msg.Content) > 0 { m := req.Messages[j]
text = msg.Content[0].Text text := ""
} if len(m.Content) > 0 {
name := "unknown" text = m.Content[0].Text
if msg.Name != nil { }
name = *msg.Name name := "unknown_function"
} if m.Name != nil {
name = *m.Name
var responseObj interface{} }
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
responseObj = map[string]interface{}{"result": text} var responseObj interface{}
} if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
respBytes, _ := json.Marshal(responseObj) responseObj = map[string]interface{}{"result": text}
}
respBytes, _ := json.Marshal(responseObj)
parts = append(parts, GeminiPart{ parts = append(parts, GeminiPart{
FunctionResponse: &GeminiFunctionResponse{ FunctionResponse: &GeminiFunctionResponse{
Name: name, Name: name,
Response: json.RawMessage(respBytes), Response: json.RawMessage(respBytes),
}, },
}) })
i = j
}
} else { } else {
for _, p := range msg.Content { for _, p := range msg.Content {
parts = append(parts, GeminiPart{Text: p.Text}) if p.Type == "text" {
parts = append(parts, GeminiPart{Text: p.Text})
} else if p.Image != nil {
base64Data, mimeType, _ := p.Image.ToBase64()
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: base64Data,
},
})
}
} }
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 { if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
for _, tc := range msg.ToolCalls { for _, tc := range msg.ToolCalls {