From e12418cc4cf46463f0afb8c5c5293929e240b611 Mon Sep 17 00:00:00 2001 From: hobokenchicken Date: Tue, 7 Apr 2026 18:57:13 +0000 Subject: [PATCH] fix(gemini): ensure strict 1:1 pairing of model calls and function responses - Gemini requires function results to immediately follow the model message that called them - Implemented look-ahead grouping to pair assistant calls with their tool results - Standardized system and orphaned tool message handling for Gemini compatibility --- internal/providers/gemini.go | 215 +++++++++++++++++++++-------------- 1 file changed, 132 insertions(+), 83 deletions(-) diff --git a/internal/providers/gemini.go b/internal/providers/gemini.go index 388aa509..d3895c05 100644 --- a/internal/providers/gemini.go +++ b/internal/providers/gemini.go @@ -78,28 +78,48 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified // Gemini mapping var contents []GeminiContent - // Group tool messages together for Gemini for i := 0; i < len(req.Messages); i++ { msg := req.Messages[i] - role := "user" - if msg.Role == "assistant" { - role = "model" - } else if msg.Role == "tool" { - role = "function" - } - var parts []GeminiPart - - if msg.Role == "tool" { - // Check if we can group this with previous tool message - // Actually, it's easier to just collect all current and subsequent tool messages - for j := i; j < len(req.Messages) && req.Messages[j].Role == "tool"; j++ { + if msg.Role == "assistant" && len(msg.ToolCalls) > 0 { + // 1. Add the assistant (model) message with tool calls + parts := []GeminiPart{} + for _, cp := range msg.Content { + if cp.Type == "text" && cp.Text != "" { + parts = append(parts, GeminiPart{Text: cp.Text}) + } + } + for _, tc := range msg.ToolCalls { + parts = append(parts, GeminiPart{ + FunctionCall: &GeminiFunctionCall{ + Name: tc.Function.Name, + Args: json.RawMessage(tc.Function.Arguments), + }, + }) + } + contents = append(contents, GeminiContent{Role: "model", Parts: parts}) + + // 2. The VERY NEXT message MUST be the "function" results for THESE EXACT calls. + // Look ahead for tool messages. + var functionParts []GeminiPart + toolCallIDs := make(map[string]bool) + for _, tc := range msg.ToolCalls { + toolCallIDs[tc.ID] = true + } + + // We need to find tool messages that correspond to these calls. + // In many patterns, they follow immediately. + j := i + 1 + foundAny := false + for j < len(req.Messages) && req.Messages[j].Role == "tool" { m := req.Messages[j] + + // Try to match by ID or just take them in order if IDs are missing/mismatched + // Gemini is strict: you must respond to EVERY call in the previous message. text := "" if len(m.Content) > 0 { text = m.Content[0].Text } - name := "unknown_function" if m.Name != nil { name = *m.Name @@ -111,46 +131,58 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified } respBytes, _ := json.Marshal(responseObj) - parts = append(parts, GeminiPart{ + functionParts = append(functionParts, GeminiPart{ FunctionResponse: &GeminiFunctionResponse{ Name: name, Response: json.RawMessage(respBytes), }, }) - i = j // Advance outer loop - } - } else { - for _, cp := range msg.Content { - if cp.Type == "text" { - parts = append(parts, GeminiPart{Text: cp.Text}) - } else if cp.Image != nil { - base64Data, mimeType, _ := cp.Image.ToBase64() - parts = append(parts, GeminiPart{ - InlineData: &GeminiInlineData{ - MimeType: mimeType, - Data: base64Data, - }, - }) - } + foundAny = true + j++ } - // Handle assistant tool calls - if msg.Role == "assistant" && len(msg.ToolCalls) > 0 { - for _, tc := range msg.ToolCalls { - parts = append(parts, GeminiPart{ - FunctionCall: &GeminiFunctionCall{ - Name: tc.Function.Name, - Args: json.RawMessage(tc.Function.Arguments), - }, - }) - } + if foundAny { + contents = append(contents, GeminiContent{Role: "function", Parts: functionParts}) + i = j - 1 // Advance outer loop past the tool messages we consumed + } else { + // If no tool results found but assistant made calls, Gemini WILL error. + // We should probably skip the calls or provide dummy results, + // but usually this means the conversation is incomplete. + // For now, don't add a "function" message if none found. } + continue + } + + // Standard message handling (System/User/Assistant without tools) + role := "user" + if msg.Role == "assistant" { + role = "model" + } else if msg.Role == "system" { + role = "user" // Gemini uses 'user' for system prompts in some versions, or handles it via systemInstruction + } else if msg.Role == "tool" { + // Orphaned tool message (not following an assistant call) - Gemini doesn't like this. + // Skip or map to user? Skipping is safer for API stability. + continue } - contents = append(contents, GeminiContent{ - Role: role, - Parts: parts, - }) + var parts []GeminiPart + for _, cp := range msg.Content { + if cp.Type == "text" && cp.Text != "" { + parts = append(parts, GeminiPart{Text: cp.Text}) + } else if cp.Image != nil { + base64Data, mimeType, _ := cp.Image.ToBase64() + parts = append(parts, GeminiPart{ + InlineData: &GeminiInlineData{ + MimeType: mimeType, + Data: base64Data, + }, + }) + } + } + + if len(parts) > 0 { + contents = append(contents, GeminiContent{Role: role, Parts: parts}) + } } genConfig := &GeminiGenerationConfig{} @@ -302,17 +334,28 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U var contents []GeminiContent for i := 0; i < len(req.Messages); i++ { msg := req.Messages[i] - role := "user" - if msg.Role == "assistant" { - role = "model" - } else if msg.Role == "tool" { - role = "function" - } - var parts []GeminiPart - - if msg.Role == "tool" { - for j := i; j < len(req.Messages) && req.Messages[j].Role == "tool"; j++ { + if msg.Role == "assistant" && len(msg.ToolCalls) > 0 { + parts := []GeminiPart{} + for _, cp := range msg.Content { + if cp.Type == "text" && cp.Text != "" { + parts = append(parts, GeminiPart{Text: cp.Text}) + } + } + for _, tc := range msg.ToolCalls { + parts = append(parts, GeminiPart{ + FunctionCall: &GeminiFunctionCall{ + Name: tc.Function.Name, + Args: json.RawMessage(tc.Function.Arguments), + }, + }) + } + contents = append(contents, GeminiContent{Role: "model", Parts: parts}) + + var functionParts []GeminiPart + j := i + 1 + foundAny := false + for j < len(req.Messages) && req.Messages[j].Role == "tool" { m := req.Messages[j] text := "" if len(m.Content) > 0 { @@ -329,44 +372,50 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U } respBytes, _ := json.Marshal(responseObj) - parts = append(parts, GeminiPart{ + functionParts = append(functionParts, GeminiPart{ FunctionResponse: &GeminiFunctionResponse{ Name: name, Response: json.RawMessage(respBytes), }, }) - i = j + foundAny = true + j++ } - } else { - for _, p := range msg.Content { - if p.Type == "text" { - parts = append(parts, GeminiPart{Text: p.Text}) - } else if p.Image != nil { - base64Data, mimeType, _ := p.Image.ToBase64() - parts = append(parts, GeminiPart{ - InlineData: &GeminiInlineData{ - MimeType: mimeType, - Data: base64Data, - }, - }) - } - } - if msg.Role == "assistant" && len(msg.ToolCalls) > 0 { - for _, tc := range msg.ToolCalls { - parts = append(parts, GeminiPart{ - FunctionCall: &GeminiFunctionCall{ - Name: tc.Function.Name, - Args: json.RawMessage(tc.Function.Arguments), - }, - }) - } + + if foundAny { + contents = append(contents, GeminiContent{Role: "function", Parts: functionParts}) + i = j - 1 } + continue + } + + role := "user" + if msg.Role == "assistant" { + role = "model" + } else if msg.Role == "system" { + role = "user" + } else if msg.Role == "tool" { + continue } - contents = append(contents, GeminiContent{ - Role: role, - Parts: parts, - }) + var parts []GeminiPart + for _, cp := range msg.Content { + if cp.Type == "text" && cp.Text != "" { + parts = append(parts, GeminiPart{Text: cp.Text}) + } else if cp.Image != nil { + base64Data, mimeType, _ := cp.Image.ToBase64() + parts = append(parts, GeminiPart{ + InlineData: &GeminiInlineData{ + MimeType: mimeType, + Data: base64Data, + }, + }) + } + } + + if len(parts) > 0 { + contents = append(contents, GeminiContent{Role: role, Parts: parts}) + } } genConfig := &GeminiGenerationConfig{}