fix(gemini): group adjacent tool messages and ensure correct role sequence
- Group consecutive 'tool' messages into a single Gemini content message with multiple 'functionResponse' parts
- Ensure assistant tool calls are properly mapped and sent
- Maintain v1beta for preview and newer models
- Add debug logging for API errors
This commit is contained in:
@@ -77,44 +77,48 @@ type GeminiFunctionResponse struct {
|
|||||||
func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||||
// Gemini mapping
|
// Gemini mapping
|
||||||
var contents []GeminiContent
|
var contents []GeminiContent
|
||||||
for _, msg := range req.Messages {
|
|
||||||
|
// Group tool messages together for Gemini
|
||||||
|
for i := 0; i < len(req.Messages); i++ {
|
||||||
|
msg := req.Messages[i]
|
||||||
role := "user"
|
role := "user"
|
||||||
if msg.Role == "assistant" {
|
if msg.Role == "assistant" {
|
||||||
role = "model"
|
role = "model"
|
||||||
} else if msg.Role == "tool" {
|
} else if msg.Role == "tool" {
|
||||||
role = "function" // Function results use 'function' role in Gemini contents
|
role = "function"
|
||||||
}
|
}
|
||||||
|
|
||||||
var parts []GeminiPart
|
var parts []GeminiPart
|
||||||
|
|
||||||
// Handle tool responses
|
|
||||||
if msg.Role == "tool" {
|
if msg.Role == "tool" {
|
||||||
text := ""
|
// Check if we can group this with previous tool message
|
||||||
if len(msg.Content) > 0 {
|
// Actually, it's easier to just collect all current and subsequent tool messages
|
||||||
text = msg.Content[0].Text
|
for j := i; j < len(req.Messages) && req.Messages[j].Role == "tool"; j++ {
|
||||||
|
m := req.Messages[j]
|
||||||
|
text := ""
|
||||||
|
if len(m.Content) > 0 {
|
||||||
|
text = m.Content[0].Text
|
||||||
|
}
|
||||||
|
|
||||||
|
name := "unknown_function"
|
||||||
|
if m.Name != nil {
|
||||||
|
name = *m.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
var responseObj interface{}
|
||||||
|
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
|
||||||
|
responseObj = map[string]interface{}{"result": text}
|
||||||
|
}
|
||||||
|
respBytes, _ := json.Marshal(responseObj)
|
||||||
|
|
||||||
|
parts = append(parts, GeminiPart{
|
||||||
|
FunctionResponse: &GeminiFunctionResponse{
|
||||||
|
Name: name,
|
||||||
|
Response: json.RawMessage(respBytes),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
i = j // Advance outer loop
|
||||||
}
|
}
|
||||||
|
|
||||||
// Gemini expects functionResponse to be an object
|
|
||||||
name := "unknown_function"
|
|
||||||
if msg.Name != nil {
|
|
||||||
name = *msg.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to parse text as JSON if it looks like it, Gemini expects an object
|
|
||||||
var responseObj interface{}
|
|
||||||
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
|
|
||||||
// If not valid JSON, wrap it in an object
|
|
||||||
responseObj = map[string]interface{}{"result": text}
|
|
||||||
}
|
|
||||||
|
|
||||||
respBytes, _ := json.Marshal(responseObj)
|
|
||||||
|
|
||||||
parts = append(parts, GeminiPart{
|
|
||||||
FunctionResponse: &GeminiFunctionResponse{
|
|
||||||
Name: name,
|
|
||||||
Response: json.RawMessage(respBytes),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
} else {
|
} else {
|
||||||
for _, cp := range msg.Content {
|
for _, cp := range msg.Content {
|
||||||
if cp.Type == "text" {
|
if cp.Type == "text" {
|
||||||
@@ -208,6 +212,10 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !resp.IsSuccess() {
|
if !resp.IsSuccess() {
|
||||||
|
fmt.Printf("[Gemini] API Error %d: %s\n", resp.StatusCode(), resp.String())
|
||||||
|
// Also log the request body for debugging (careful with API keys if logged elsewhere)
|
||||||
|
reqJSON, _ := json.Marshal(body)
|
||||||
|
fmt.Printf("[Gemini] Request Body: %s\n", string(reqJSON))
|
||||||
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
|
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -292,7 +300,8 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
|||||||
func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||||
// Simplified Gemini mapping
|
// Simplified Gemini mapping
|
||||||
var contents []GeminiContent
|
var contents []GeminiContent
|
||||||
for _, msg := range req.Messages {
|
for i := 0; i < len(req.Messages); i++ {
|
||||||
|
msg := req.Messages[i]
|
||||||
role := "user"
|
role := "user"
|
||||||
if msg.Role == "assistant" {
|
if msg.Role == "assistant" {
|
||||||
role = "model"
|
role = "model"
|
||||||
@@ -303,30 +312,44 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
|||||||
var parts []GeminiPart
|
var parts []GeminiPart
|
||||||
|
|
||||||
if msg.Role == "tool" {
|
if msg.Role == "tool" {
|
||||||
text := ""
|
for j := i; j < len(req.Messages) && req.Messages[j].Role == "tool"; j++ {
|
||||||
if len(msg.Content) > 0 {
|
m := req.Messages[j]
|
||||||
text = msg.Content[0].Text
|
text := ""
|
||||||
}
|
if len(m.Content) > 0 {
|
||||||
name := "unknown"
|
text = m.Content[0].Text
|
||||||
if msg.Name != nil {
|
}
|
||||||
name = *msg.Name
|
name := "unknown_function"
|
||||||
}
|
if m.Name != nil {
|
||||||
|
name = *m.Name
|
||||||
|
}
|
||||||
|
|
||||||
var responseObj interface{}
|
var responseObj interface{}
|
||||||
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
|
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
|
||||||
responseObj = map[string]interface{}{"result": text}
|
responseObj = map[string]interface{}{"result": text}
|
||||||
}
|
}
|
||||||
respBytes, _ := json.Marshal(responseObj)
|
respBytes, _ := json.Marshal(responseObj)
|
||||||
|
|
||||||
parts = append(parts, GeminiPart{
|
parts = append(parts, GeminiPart{
|
||||||
FunctionResponse: &GeminiFunctionResponse{
|
FunctionResponse: &GeminiFunctionResponse{
|
||||||
Name: name,
|
Name: name,
|
||||||
Response: json.RawMessage(respBytes),
|
Response: json.RawMessage(respBytes),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
i = j
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
for _, p := range msg.Content {
|
for _, p := range msg.Content {
|
||||||
parts = append(parts, GeminiPart{Text: p.Text})
|
if p.Type == "text" {
|
||||||
|
parts = append(parts, GeminiPart{Text: p.Text})
|
||||||
|
} else if p.Image != nil {
|
||||||
|
base64Data, mimeType, _ := p.Image.ToBase64()
|
||||||
|
parts = append(parts, GeminiPart{
|
||||||
|
InlineData: &GeminiInlineData{
|
||||||
|
MimeType: mimeType,
|
||||||
|
Data: base64Data,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
||||||
for _, tc := range msg.ToolCalls {
|
for _, tc := range msg.ToolCalls {
|
||||||
|
|||||||
Reference in New Issue
Block a user