// Commit e12418cc4c (Go, 496 lines, 14 KiB)
//
// - Gemini requires function results to immediately follow the model message that called them.
// - Implemented look-ahead grouping to pair assistant calls with their tool results.
// - Standardized system and orphaned tool message handling for Gemini compatibility.

package providers
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"strings"

	"gophergate/internal/config"
	"gophergate/internal/models"

	"github.com/go-resty/resty/v2"
)
|
|
|
|
type GeminiProvider struct {
|
|
client *resty.Client
|
|
config config.GeminiConfig
|
|
apiKey string
|
|
}
|
|
|
|
func NewGeminiProvider(cfg config.GeminiConfig, apiKey string) *GeminiProvider {
|
|
return &GeminiProvider{
|
|
client: resty.New(),
|
|
config: cfg,
|
|
apiKey: apiKey,
|
|
}
|
|
}
|
|
|
|
func (p *GeminiProvider) Name() string {
|
|
return "gemini"
|
|
}
|
|
|
|
type GeminiRequest struct {
|
|
Contents []GeminiContent `json:"contents"`
|
|
Tools []GeminiTool `json:"tools,omitempty"`
|
|
GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
|
|
}
|
|
|
|
type GeminiTool struct {
|
|
FunctionDeclarations []models.FunctionDef `json:"functionDeclarations"`
|
|
}
|
|
|
|
// GeminiGenerationConfig carries the optional sampling parameters of a
// request. Every field is a pointer or a slice tagged omitempty, so only the
// values that were explicitly set are encoded.
type GeminiGenerationConfig struct {
	Temperature     *float32 `json:"temperature,omitempty"`
	TopP            *float32 `json:"topP,omitempty"`
	TopK            *int     `json:"topK,omitempty"`
	MaxOutputTokens *int     `json:"maxOutputTokens,omitempty"`
	StopSequences   []string `json:"stopSequences,omitempty"`
}
|
|
|
|
type GeminiContent struct {
|
|
Role string `json:"role,omitempty"`
|
|
Parts []GeminiPart `json:"parts"`
|
|
}
|
|
|
|
type GeminiPart struct {
|
|
Text string `json:"text,omitempty"`
|
|
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
|
FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
|
|
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
|
}
|
|
|
|
// GeminiInlineData is binary content embedded directly in a request part:
// its MIME type and the payload as a base64 string.
type GeminiInlineData struct {
	MimeType string `json:"mimeType"`
	Data     string `json:"data"`
}
|
|
|
|
// GeminiFunctionCall is a function invocation: the function name and its
// arguments as raw JSON.
//
// Args is tagged omitempty so that a call without arguments encodes cleanly.
// Without it, an empty non-nil json.RawMessage makes json.Marshal fail and a
// nil one is sent as "args": null.
type GeminiFunctionCall struct {
	Name string          `json:"name"`
	Args json.RawMessage `json:"args,omitempty"`
}
|
|
|
|
// GeminiFunctionResponse is the result of a function call handed back to the
// model: the function name and the result as raw JSON.
type GeminiFunctionResponse struct {
	Name     string          `json:"name"`
	Response json.RawMessage `json:"response"`
}
|
|
|
|
func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
|
// Gemini mapping
|
|
var contents []GeminiContent
|
|
|
|
for i := 0; i < len(req.Messages); i++ {
|
|
msg := req.Messages[i]
|
|
|
|
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
|
// 1. Add the assistant (model) message with tool calls
|
|
parts := []GeminiPart{}
|
|
for _, cp := range msg.Content {
|
|
if cp.Type == "text" && cp.Text != "" {
|
|
parts = append(parts, GeminiPart{Text: cp.Text})
|
|
}
|
|
}
|
|
for _, tc := range msg.ToolCalls {
|
|
parts = append(parts, GeminiPart{
|
|
FunctionCall: &GeminiFunctionCall{
|
|
Name: tc.Function.Name,
|
|
Args: json.RawMessage(tc.Function.Arguments),
|
|
},
|
|
})
|
|
}
|
|
contents = append(contents, GeminiContent{Role: "model", Parts: parts})
|
|
|
|
// 2. The VERY NEXT message MUST be the "function" results for THESE EXACT calls.
|
|
// Look ahead for tool messages.
|
|
var functionParts []GeminiPart
|
|
toolCallIDs := make(map[string]bool)
|
|
for _, tc := range msg.ToolCalls {
|
|
toolCallIDs[tc.ID] = true
|
|
}
|
|
|
|
// We need to find tool messages that correspond to these calls.
|
|
// In many patterns, they follow immediately.
|
|
j := i + 1
|
|
foundAny := false
|
|
for j < len(req.Messages) && req.Messages[j].Role == "tool" {
|
|
m := req.Messages[j]
|
|
|
|
// Try to match by ID or just take them in order if IDs are missing/mismatched
|
|
// Gemini is strict: you must respond to EVERY call in the previous message.
|
|
text := ""
|
|
if len(m.Content) > 0 {
|
|
text = m.Content[0].Text
|
|
}
|
|
name := "unknown_function"
|
|
if m.Name != nil {
|
|
name = *m.Name
|
|
}
|
|
|
|
var responseObj interface{}
|
|
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
|
|
responseObj = map[string]interface{}{"result": text}
|
|
}
|
|
respBytes, _ := json.Marshal(responseObj)
|
|
|
|
functionParts = append(functionParts, GeminiPart{
|
|
FunctionResponse: &GeminiFunctionResponse{
|
|
Name: name,
|
|
Response: json.RawMessage(respBytes),
|
|
},
|
|
})
|
|
foundAny = true
|
|
j++
|
|
}
|
|
|
|
if foundAny {
|
|
contents = append(contents, GeminiContent{Role: "function", Parts: functionParts})
|
|
i = j - 1 // Advance outer loop past the tool messages we consumed
|
|
} else {
|
|
// If no tool results found but assistant made calls, Gemini WILL error.
|
|
// We should probably skip the calls or provide dummy results,
|
|
// but usually this means the conversation is incomplete.
|
|
// For now, don't add a "function" message if none found.
|
|
}
|
|
continue
|
|
}
|
|
|
|
// Standard message handling (System/User/Assistant without tools)
|
|
role := "user"
|
|
if msg.Role == "assistant" {
|
|
role = "model"
|
|
} else if msg.Role == "system" {
|
|
role = "user" // Gemini uses 'user' for system prompts in some versions, or handles it via systemInstruction
|
|
} else if msg.Role == "tool" {
|
|
// Orphaned tool message (not following an assistant call) - Gemini doesn't like this.
|
|
// Skip or map to user? Skipping is safer for API stability.
|
|
continue
|
|
}
|
|
|
|
var parts []GeminiPart
|
|
for _, cp := range msg.Content {
|
|
if cp.Type == "text" && cp.Text != "" {
|
|
parts = append(parts, GeminiPart{Text: cp.Text})
|
|
} else if cp.Image != nil {
|
|
base64Data, mimeType, _ := cp.Image.ToBase64()
|
|
parts = append(parts, GeminiPart{
|
|
InlineData: &GeminiInlineData{
|
|
MimeType: mimeType,
|
|
Data: base64Data,
|
|
},
|
|
})
|
|
}
|
|
}
|
|
|
|
if len(parts) > 0 {
|
|
contents = append(contents, GeminiContent{Role: role, Parts: parts})
|
|
}
|
|
}
|
|
|
|
genConfig := &GeminiGenerationConfig{}
|
|
if req.Temperature != nil {
|
|
t := float32(*req.Temperature)
|
|
genConfig.Temperature = &t
|
|
}
|
|
if req.TopP != nil {
|
|
tp := float32(*req.TopP)
|
|
genConfig.TopP = &tp
|
|
}
|
|
if req.TopK != nil {
|
|
tk := int(*req.TopK)
|
|
genConfig.TopK = &tk
|
|
}
|
|
if req.MaxTokens != nil {
|
|
mt := int(*req.MaxTokens)
|
|
genConfig.MaxOutputTokens = &mt
|
|
}
|
|
if len(req.Stop) > 0 {
|
|
genConfig.StopSequences = req.Stop
|
|
}
|
|
|
|
body := GeminiRequest{
|
|
Contents: contents,
|
|
GenerationConfig: genConfig,
|
|
}
|
|
|
|
// Map Tools
|
|
if len(req.Tools) > 0 {
|
|
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
|
|
for _, t := range req.Tools {
|
|
if t.Type == "function" {
|
|
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
|
|
}
|
|
}
|
|
body.Tools = []GeminiTool{geminiTool}
|
|
}
|
|
|
|
baseURL := p.config.BaseURL
|
|
lowerModel := strings.ToLower(req.Model)
|
|
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
|
|
// Use v1beta for preview and newer models
|
|
if !strings.Contains(baseURL, "v1beta") {
|
|
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
|
|
}
|
|
}
|
|
|
|
url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", baseURL, req.Model, p.apiKey)
|
|
fmt.Printf("[Gemini] POST %s\n", url)
|
|
|
|
resp, err := p.client.R().
|
|
SetContext(ctx).
|
|
SetBody(body).
|
|
Post(url)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("request failed: %w", err)
|
|
}
|
|
|
|
if !resp.IsSuccess() {
|
|
fmt.Printf("[Gemini] API Error %d: %s\n", resp.StatusCode(), resp.String())
|
|
// Also log the request body for debugging (careful with API keys if logged elsewhere)
|
|
reqJSON, _ := json.Marshal(body)
|
|
fmt.Printf("[Gemini] Request Body: %s\n", string(reqJSON))
|
|
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
|
|
}
|
|
|
|
// Parse Gemini response and convert to OpenAI format
|
|
var geminiResp struct {
|
|
Candidates []struct {
|
|
Content struct {
|
|
Role string `json:"role"`
|
|
Parts []struct {
|
|
Text string `json:"text"`
|
|
FunctionCall *GeminiFunctionCall `json:"functionCall"`
|
|
} `json:"parts"`
|
|
} `json:"content"`
|
|
FinishReason string `json:"finishReason"`
|
|
} `json:"candidates"`
|
|
UsageMetadata struct {
|
|
PromptTokenCount uint32 `json:"promptTokenCount"`
|
|
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
|
|
TotalTokenCount uint32 `json:"totalTokenCount"`
|
|
} `json:"usageMetadata"`
|
|
}
|
|
|
|
if err := json.Unmarshal(resp.Body(), &geminiResp); err != nil {
|
|
return nil, fmt.Errorf("failed to parse response: %w", err)
|
|
}
|
|
|
|
if len(geminiResp.Candidates) == 0 {
|
|
return nil, fmt.Errorf("no candidates in Gemini response")
|
|
}
|
|
|
|
content := ""
|
|
var toolCalls []models.ToolCall
|
|
for _, part := range geminiResp.Candidates[0].Content.Parts {
|
|
if part.Text != "" {
|
|
content += part.Text
|
|
}
|
|
if part.FunctionCall != nil {
|
|
toolCalls = append(toolCalls, models.ToolCall{
|
|
ID: fmt.Sprintf("call_%s", part.FunctionCall.Name), // Gemini doesn't have call IDs
|
|
Type: "function",
|
|
Function: models.FunctionCall{
|
|
Name: part.FunctionCall.Name,
|
|
Arguments: string(part.FunctionCall.Args),
|
|
},
|
|
})
|
|
}
|
|
}
|
|
|
|
finishReason := strings.ToLower(geminiResp.Candidates[0].FinishReason)
|
|
if finishReason == "stop" {
|
|
finishReason = "stop"
|
|
} else if len(toolCalls) > 0 {
|
|
finishReason = "tool_calls"
|
|
}
|
|
|
|
openAIResp := &models.ChatCompletionResponse{
|
|
ID: "gemini-" + req.Model,
|
|
Object: "chat.completion",
|
|
Created: 0,
|
|
Model: req.Model,
|
|
Choices: []models.ChatChoice{
|
|
{
|
|
Index: 0,
|
|
Message: models.ChatMessage{
|
|
Role: "assistant",
|
|
Content: content,
|
|
ToolCalls: toolCalls,
|
|
},
|
|
FinishReason: &finishReason,
|
|
},
|
|
},
|
|
Usage: &models.Usage{
|
|
PromptTokens: geminiResp.UsageMetadata.PromptTokenCount,
|
|
CompletionTokens: geminiResp.UsageMetadata.CandidatesTokenCount,
|
|
TotalTokens: geminiResp.UsageMetadata.TotalTokenCount,
|
|
},
|
|
}
|
|
|
|
return openAIResp, nil
|
|
}
|
|
|
|
func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
|
// Simplified Gemini mapping
|
|
var contents []GeminiContent
|
|
for i := 0; i < len(req.Messages); i++ {
|
|
msg := req.Messages[i]
|
|
|
|
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
|
parts := []GeminiPart{}
|
|
for _, cp := range msg.Content {
|
|
if cp.Type == "text" && cp.Text != "" {
|
|
parts = append(parts, GeminiPart{Text: cp.Text})
|
|
}
|
|
}
|
|
for _, tc := range msg.ToolCalls {
|
|
parts = append(parts, GeminiPart{
|
|
FunctionCall: &GeminiFunctionCall{
|
|
Name: tc.Function.Name,
|
|
Args: json.RawMessage(tc.Function.Arguments),
|
|
},
|
|
})
|
|
}
|
|
contents = append(contents, GeminiContent{Role: "model", Parts: parts})
|
|
|
|
var functionParts []GeminiPart
|
|
j := i + 1
|
|
foundAny := false
|
|
for j < len(req.Messages) && req.Messages[j].Role == "tool" {
|
|
m := req.Messages[j]
|
|
text := ""
|
|
if len(m.Content) > 0 {
|
|
text = m.Content[0].Text
|
|
}
|
|
name := "unknown_function"
|
|
if m.Name != nil {
|
|
name = *m.Name
|
|
}
|
|
|
|
var responseObj interface{}
|
|
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
|
|
responseObj = map[string]interface{}{"result": text}
|
|
}
|
|
respBytes, _ := json.Marshal(responseObj)
|
|
|
|
functionParts = append(functionParts, GeminiPart{
|
|
FunctionResponse: &GeminiFunctionResponse{
|
|
Name: name,
|
|
Response: json.RawMessage(respBytes),
|
|
},
|
|
})
|
|
foundAny = true
|
|
j++
|
|
}
|
|
|
|
if foundAny {
|
|
contents = append(contents, GeminiContent{Role: "function", Parts: functionParts})
|
|
i = j - 1
|
|
}
|
|
continue
|
|
}
|
|
|
|
role := "user"
|
|
if msg.Role == "assistant" {
|
|
role = "model"
|
|
} else if msg.Role == "system" {
|
|
role = "user"
|
|
} else if msg.Role == "tool" {
|
|
continue
|
|
}
|
|
|
|
var parts []GeminiPart
|
|
for _, cp := range msg.Content {
|
|
if cp.Type == "text" && cp.Text != "" {
|
|
parts = append(parts, GeminiPart{Text: cp.Text})
|
|
} else if cp.Image != nil {
|
|
base64Data, mimeType, _ := cp.Image.ToBase64()
|
|
parts = append(parts, GeminiPart{
|
|
InlineData: &GeminiInlineData{
|
|
MimeType: mimeType,
|
|
Data: base64Data,
|
|
},
|
|
})
|
|
}
|
|
}
|
|
|
|
if len(parts) > 0 {
|
|
contents = append(contents, GeminiContent{Role: role, Parts: parts})
|
|
}
|
|
}
|
|
|
|
genConfig := &GeminiGenerationConfig{}
|
|
if req.Temperature != nil {
|
|
t := float32(*req.Temperature)
|
|
genConfig.Temperature = &t
|
|
}
|
|
if req.TopP != nil {
|
|
tp := float32(*req.TopP)
|
|
genConfig.TopP = &tp
|
|
}
|
|
if req.TopK != nil {
|
|
tk := int(*req.TopK)
|
|
genConfig.TopK = &tk
|
|
}
|
|
if req.MaxTokens != nil {
|
|
mt := int(*req.MaxTokens)
|
|
genConfig.MaxOutputTokens = &mt
|
|
}
|
|
if len(req.Stop) > 0 {
|
|
genConfig.StopSequences = req.Stop
|
|
}
|
|
|
|
body := GeminiRequest{
|
|
Contents: contents,
|
|
GenerationConfig: genConfig,
|
|
}
|
|
|
|
if len(req.Tools) > 0 {
|
|
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
|
|
for _, t := range req.Tools {
|
|
if t.Type == "function" {
|
|
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
|
|
}
|
|
}
|
|
body.Tools = []GeminiTool{geminiTool}
|
|
}
|
|
|
|
baseURL := p.config.BaseURL
|
|
lowerModel := strings.ToLower(req.Model)
|
|
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
|
|
// Use v1beta for preview and newer models
|
|
if !strings.Contains(baseURL, "v1beta") {
|
|
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
|
|
}
|
|
}
|
|
|
|
// Use streamGenerateContent for streaming
|
|
url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s", baseURL, req.Model, p.apiKey)
|
|
fmt.Printf("[Gemini-Stream] POST %s\n", url)
|
|
|
|
resp, err := p.client.R().
|
|
SetContext(ctx).
|
|
SetBody(body).
|
|
SetDoNotParseResponse(true).
|
|
Post(url)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("request failed: %w", err)
|
|
}
|
|
|
|
if !resp.IsSuccess() {
|
|
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
|
|
}
|
|
|
|
ch := make(chan *models.ChatCompletionStreamResponse)
|
|
|
|
go func() {
|
|
defer close(ch)
|
|
err := StreamGemini(resp.RawBody(), ch, req.Model)
|
|
if err != nil {
|
|
fmt.Printf("Gemini Stream error: %v\n", err)
|
|
}
|
|
}()
|
|
|
|
return ch, nil
|
|
}
|