Files
GopherGate/internal/providers/gemini.go
T
hobokenchicken 5ee539d95c
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
feat: add image generation for OpenAI DALL-E and Gemini Imagen
New `/v1/images/generations` endpoint proxies DALL-E 2/3 (OpenAI)
and Imagen 3 (Gemini). Same auth/logging as chat completions.

- Add ImageGenerationRequest/Response models
- Extend Provider interface with ImageGeneration()
- OpenAI: forward to /v1/images/generations
- Gemini: call /v1beta/models/{model}:predict, map OpenAI params
- Circuit breaker wraps image gen like chat completions
- Model routing: dall-e* -> openai, imagen*/gemini* -> gemini
- Unsupported providers (deepseek/moonshot/grok/ollama) return error
- Fix pre-existing CachedContentTokenCount bug in StreamGemini
2026-04-27 10:06:07 -04:00

616 lines
16 KiB
Go

package providers
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/go-resty/resty/v2"
"gophergate/internal/config"
"gophergate/internal/models"
)
// GeminiProvider talks to Google's Gemini generateContent /
// streamGenerateContent and Imagen predict HTTP endpoints and converts between
// them and the gateway's OpenAI-style request/response models.
type GeminiProvider struct {
	// client is the resty client requests are sent with.
	client *resty.Client
	// config supplies BaseURL, the prefix of every request URL.
	config config.GeminiConfig
	// apiKey is sent to Gemini as the "key" query parameter.
	apiKey string
}
func NewGeminiProvider(cfg config.GeminiConfig, apiKey string) *GeminiProvider {
return &GeminiProvider{
client: resty.New().SetTimeout(30 * time.Second),
config: cfg,
apiKey: apiKey,
}
}
// Name reports the identifier this provider is registered under: "gemini".
func (p *GeminiProvider) Name() string { return "gemini" }
// GeminiRequest is the JSON body posted to the generateContent and
// streamGenerateContent endpoints.
type GeminiRequest struct {
	// Contents holds the conversation turns in the order the messages appeared.
	Contents []GeminiContent `json:"contents"`
	// Tools carries function declarations; omitted from the JSON when empty.
	Tools []GeminiTool `json:"tools,omitempty"`
	// GenerationConfig holds sampling/length options; omitted when nil.
	GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
}
// GeminiTool is one entry of GeminiRequest.Tools. The OpenAI-style function
// definitions from the incoming request are forwarded unchanged as Gemini
// function declarations.
type GeminiTool struct {
	FunctionDeclarations []models.FunctionDef `json:"functionDeclarations"`
}
// GeminiGenerationConfig mirrors the subset of OpenAI sampling options that
// are forwarded to Gemini. Pointer fields stay nil, and are therefore omitted
// from the JSON, when the caller did not set the corresponding option.
type GeminiGenerationConfig struct {
	Temperature     *float32 `json:"temperature,omitempty"`
	TopP            *float32 `json:"topP,omitempty"`
	TopK            *int     `json:"topK,omitempty"`
	MaxOutputTokens *int     `json:"maxOutputTokens,omitempty"` // from OpenAI max_tokens
	StopSequences   []string `json:"stopSequences,omitempty"`   // from OpenAI stop
}
// GeminiContent is a single conversation turn. The code in this file uses the
// roles "user", "model" (assistant output) and "function" (tool results).
type GeminiContent struct {
	Role  string       `json:"role,omitempty"`
	Parts []GeminiPart `json:"parts"`
}
// GeminiPart is one piece of a turn. The code in this file populates exactly
// one of the fields per part; the rest stay empty and are omitted.
type GeminiPart struct {
	Text             string                  `json:"text,omitempty"`
	InlineData       *GeminiInlineData       `json:"inlineData,omitempty"`       // embedded image bytes
	FunctionCall     *GeminiFunctionCall     `json:"functionCall,omitempty"`     // model-issued tool call
	FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"` // result of a tool call
}
// GeminiInlineData embeds binary content, such as an image, directly in a
// request part.
type GeminiInlineData struct {
	// MimeType describes the encoded bytes in Data.
	MimeType string `json:"mimeType"`
	// Data is the base64 string produced by the image's ToBase64 method.
	Data string `json:"data"`
}
// GeminiFunctionCall is a tool invocation. It is used both to replay an
// assistant's earlier tool calls in a request and to decode functionCall
// parts from a response.
type GeminiFunctionCall struct {
	Name string `json:"name"`
	// Args holds the call arguments as raw JSON; it is converted to and from
	// the OpenAI tool call's Arguments string.
	Args json.RawMessage `json:"args"`
}
// GeminiFunctionResponse returns a tool's result to the model in a
// "function" turn.
type GeminiFunctionResponse struct {
	// Name is the name of the function whose call is being answered.
	Name string `json:"name"`
	// Response is the tool result as raw JSON; plain-text results are wrapped
	// as {"result": ...} before being stored here.
	Response json.RawMessage `json:"response"`
}
// ImageGeneration serves an OpenAI-style image generation request using the
// Imagen ":predict" endpoint (POST <BaseURL as v1beta>/models/{model}:predict).
//
// OpenAI parameters are mapped as follows:
//   - N becomes sampleCount (default 1).
//   - Size becomes aspectRatio via sizeToGeminiAspectRatio (default "1:1").
//   - ResponseFormat "b64_json" returns the raw base64 bytes. Any other value
//     returns a data: URI in the URL field, because Imagen returns image bytes
//     rather than hosted URLs.
//
// An empty model defaults to "imagen-3.0-generate-001".
//
// It returns an error if:
//   - the model name contains URL-significant characters,
//   - the HTTP call fails or returns a non-2xx status,
//   - the response body cannot be decoded, or
//   - the response contains no image data.
func (p *GeminiProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
	sampleCount := uint32(1)
	if req.N != nil && *req.N > 0 {
		sampleCount = *req.N
	}
	aspectRatio := "1:1"
	if req.Size != nil {
		aspectRatio = sizeToGeminiAspectRatio(*req.Size)
	}

	imagenReq := map[string]interface{}{
		"instances": []map[string]interface{}{
			{"prompt": req.Prompt},
		},
		"parameters": map[string]interface{}{
			"sampleCount": sampleCount,
			"aspectRatio": aspectRatio,
		},
	}

	model := req.Model
	if model == "" {
		model = "imagen-3.0-generate-001"
	}
	// The model name comes from the client request. Refuse URL-significant
	// characters instead of splicing them into the path, where they could
	// redirect the call (made with our API key) to a different endpoint.
	if strings.ContainsAny(model, "/?#%") {
		return nil, fmt.Errorf("gemini imagen: invalid model name %q", model)
	}

	// Imagen is called through the v1beta API.
	baseURL := p.config.BaseURL
	if !strings.Contains(baseURL, "v1beta") {
		baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
	}
	endpoint := fmt.Sprintf("%s/models/%s:predict", baseURL, model)

	resp, err := p.client.R().
		SetContext(ctx).
		SetHeader("Content-Type", "application/json").
		SetQueryParam("key", p.apiKey).
		SetBody(imagenReq).
		Post(endpoint)
	if err != nil {
		return nil, fmt.Errorf("gemini imagen request failed: %w", err)
	}
	if !resp.IsSuccess() {
		return nil, fmt.Errorf("Gemini Imagen API error (%d): %s", resp.StatusCode(), resp.String())
	}

	var imagenResp struct {
		Predictions []struct {
			MimeType           string `json:"mimeType"`
			BytesBase64Encoded string `json:"bytesBase64Encoded"`
		} `json:"predictions"`
	}
	if err := json.Unmarshal(resp.Body(), &imagenResp); err != nil {
		return nil, fmt.Errorf("failed to parse Imagen response: %w", err)
	}

	wantB64 := req.ResponseFormat != nil && *req.ResponseFormat == "b64_json"
	data := make([]models.ImageData, 0, len(imagenResp.Predictions))
	for _, pred := range imagenResp.Predictions {
		if pred.BytesBase64Encoded == "" {
			// Entries without image bytes have nothing to return.
			continue
		}
		var img models.ImageData
		if wantB64 {
			img.B64JSON = pred.BytesBase64Encoded
		} else {
			// Imagen returns base64 rather than a URL, so build a data URI.
			mime := pred.MimeType
			if mime == "" {
				mime = "image/png"
			}
			img.URL = fmt.Sprintf("data:%s;base64,%s", mime, pred.BytesBase64Encoded)
		}
		data = append(data, img)
	}
	if len(data) == 0 {
		// Without this check callers would get an empty (or null) data array
		// and no indication of why.
		return nil, fmt.Errorf("gemini imagen returned no images (the prompt may have been filtered)")
	}

	return &models.ImageGenerationResponse{
		Created: time.Now().Unix(),
		Data:    data,
	}, nil
}
// sizeToGeminiAspectRatio converts OpenAI size format (e.g. "1024x1024") to Gemini aspect ratio (e.g. "1:1")
func sizeToGeminiAspectRatio(size string) string {
switch size {
case "1024x1024":
return "1:1"
case "1024x1792":
return "9:16"
case "1792x1024":
return "16:9"
case "256x256", "512x512":
return "1:1"
default:
return "1:1"
}
}
// ChatCompletion sends a non-streaming chat request to Gemini's
// generateContent endpoint. It converts the first candidate of the reply into
// an OpenAI-style chat completion.
//
// How the reply is converted:
//   - Text parts are concatenated into the message content.
//   - functionCall parts become tool calls. Gemini assigns no call IDs, so IDs
//     are synthesized as "call_<name>". Repeated calls to the same function get
//     a numeric suffix ("call_<name>_1", ...) so IDs stay unique.
//   - The finish reason is translated by geminiFinishReason.
//
// It returns an error if:
//   - the request cannot be built (an image fails to encode, tool-call
//     arguments are not valid JSON, or the model name is empty/invalid),
//   - the HTTP call fails or returns a non-2xx status,
//   - the body cannot be decoded, or
//   - Gemini returns no candidates.
func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
	body, err := geminiBuildRequest(req)
	if err != nil {
		return nil, err
	}
	endpoint, err := p.chatURL(req.Model, "generateContent")
	if err != nil {
		return nil, err
	}

	// The API key is sent as the "key" query parameter but is kept out of the
	// logged URL.
	fmt.Printf("[Gemini] POST %s\n", endpoint)
	resp, err := p.client.R().
		SetContext(ctx).
		SetQueryParam("key", p.apiKey).
		SetBody(body).
		Post(endpoint)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}
	if !resp.IsSuccess() {
		fmt.Printf("[Gemini] API Error %d: %s\n", resp.StatusCode(), resp.String())
		// Dump the request body for debugging. It contains the conversation
		// (including inline image data) but not the API key.
		if reqJSON, mErr := json.Marshal(body); mErr == nil {
			fmt.Printf("[Gemini] Request Body: %s\n", string(reqJSON))
		}
		return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
	}

	var geminiResp struct {
		Candidates []struct {
			Content struct {
				Role  string `json:"role"`
				Parts []struct {
					Text         string              `json:"text"`
					FunctionCall *GeminiFunctionCall `json:"functionCall"`
				} `json:"parts"`
			} `json:"content"`
			FinishReason string `json:"finishReason"`
		} `json:"candidates"`
		PromptFeedback struct {
			BlockReason string `json:"blockReason"`
		} `json:"promptFeedback"`
		UsageMetadata struct {
			PromptTokenCount        uint32 `json:"promptTokenCount"`
			CandidatesTokenCount    uint32 `json:"candidatesTokenCount"`
			TotalTokenCount         uint32 `json:"totalTokenCount"`
			CachedContentTokenCount uint32 `json:"cachedContentTokenCount"`
		} `json:"usageMetadata"`
	}
	if err := json.Unmarshal(resp.Body(), &geminiResp); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}
	if len(geminiResp.Candidates) == 0 {
		if reason := geminiResp.PromptFeedback.BlockReason; reason != "" {
			return nil, fmt.Errorf("no candidates in Gemini response (prompt blocked: %s)", reason)
		}
		return nil, fmt.Errorf("no candidates in Gemini response")
	}

	candidate := geminiResp.Candidates[0]
	var content strings.Builder
	var toolCalls []models.ToolCall
	callsPerName := make(map[string]int)
	for _, part := range candidate.Content.Parts {
		content.WriteString(part.Text)
		if part.FunctionCall == nil {
			continue
		}
		name := part.FunctionCall.Name
		id := "call_" + name
		if seen := callsPerName[name]; seen > 0 {
			id = fmt.Sprintf("call_%s_%d", name, seen)
		}
		callsPerName[name]++

		// OpenAI clients expect a JSON object string even for no-arg calls.
		args := strings.TrimSpace(string(part.FunctionCall.Args))
		if args == "" || args == "null" {
			args = "{}"
		}
		toolCalls = append(toolCalls, models.ToolCall{
			ID:   id,
			Type: "function",
			Function: models.FunctionCall{
				Name:      name,
				Arguments: args,
			},
		})
	}

	finishReason := geminiFinishReason(candidate.FinishReason, len(toolCalls) > 0)
	openAIResp := &models.ChatCompletionResponse{
		// NOTE(review): ID is not unique per response and Created stays 0 as
		// before; consider a real timestamp once the field's type is confirmed.
		ID:      "gemini-" + req.Model,
		Object:  "chat.completion",
		Created: 0,
		Model:   req.Model,
		Choices: []models.ChatChoice{
			{
				Index: 0,
				Message: models.ChatMessage{
					Role:      "assistant",
					Content:   content.String(),
					ToolCalls: toolCalls,
				},
				FinishReason: &finishReason,
			},
		},
		Usage: &models.Usage{
			PromptTokens:     geminiResp.UsageMetadata.PromptTokenCount,
			CompletionTokens: geminiResp.UsageMetadata.CandidatesTokenCount,
			TotalTokens:      geminiResp.UsageMetadata.TotalTokenCount,
			CacheReadTokens:  uint32Ptr(geminiResp.UsageMetadata.CachedContentTokenCount),
		},
	}
	return openAIResp, nil
}

// ChatCompletionStream sends the chat request to Gemini's
// streamGenerateContent endpoint. It returns a channel of OpenAI-style stream
// chunks produced by StreamGemini; the channel is closed when StreamGemini
// returns.
//
// Details:
//   - The request goes through streamClient, which has no overall timeout, so
//     long streams are not cut off. Cancel ctx to abort a stream.
//   - The response body is closed once StreamGemini returns.
//   - On a non-2xx status, up to 64 KiB of the body is read into the returned
//     error and the body is closed, so the connection is not leaked.
func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
	body, err := geminiBuildRequest(req)
	if err != nil {
		return nil, err
	}
	endpoint, err := p.chatURL(req.Model, "streamGenerateContent")
	if err != nil {
		return nil, err
	}

	// The API key is sent as the "key" query parameter but is kept out of the
	// logged URL.
	fmt.Printf("[Gemini-Stream] POST %s\n", endpoint)
	resp, err := p.streamClient().R().
		SetContext(ctx).
		SetQueryParam("key", p.apiKey).
		SetBody(body).
		SetDoNotParseResponse(true).
		Post(endpoint)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}

	rawBody := resp.RawBody()
	if !resp.IsSuccess() {
		// With SetDoNotParseResponse resty never reads the body, so
		// resp.String() would be empty; read the error payload ourselves.
		msg := ""
		if rawBody != nil {
			msg = geminiReadErrorBody(rawBody, 64<<10)
			rawBody.Close() // best effort; the error below is what matters
		}
		return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), msg)
	}

	ch := make(chan *models.ChatCompletionStreamResponse)
	go func() {
		defer close(ch)
		defer rawBody.Close() // runs before close(ch)
		if err := StreamGemini(rawBody, ch, req.Model); err != nil {
			fmt.Printf("Gemini Stream error: %v\n", err)
		}
	}()
	return ch, nil
}

// streamClient returns a resty client for streaming requests. It shares the
// provider client's transport (and so its connection pool) but sets no overall
// timeout. resty's SetTimeout sets http.Client.Timeout, which also bounds
// reading the response body and would cut long streams off.
func (p *GeminiProvider) streamClient() *resty.Client {
	return resty.New().SetTransport(p.client.GetClient().Transport)
}

// chatURL returns "<BaseURL>/models/<model>:<method>" for the chat endpoints.
//
// Models whose names contain "preview", "3.1", "2.0" or "thinking" are routed
// to the v1beta API by rewriting the first "/v1" in the base URL.
//
// The API key is not included; callers add it as a query parameter, so the
// returned URL is safe to log. The model name comes from the client request,
// so empty names and names containing URL-significant characters are
// rejected rather than spliced into the path.
func (p *GeminiProvider) chatURL(model, method string) (string, error) {
	if model == "" {
		return "", fmt.Errorf("gemini: model name is required")
	}
	if strings.ContainsAny(model, "/?#%") {
		return "", fmt.Errorf("gemini: invalid model name %q", model)
	}

	baseURL := p.config.BaseURL
	lower := strings.ToLower(model)
	useBeta := strings.Contains(lower, "preview") ||
		strings.Contains(lower, "3.1") ||
		strings.Contains(lower, "2.0") ||
		strings.Contains(lower, "thinking")
	if useBeta && !strings.Contains(baseURL, "v1beta") {
		baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
	}
	return fmt.Sprintf("%s/models/%s:%s", baseURL, model, method), nil
}

// geminiBuildRequest translates a unified (OpenAI-style) request into the
// request body shared by generateContent and streamGenerateContent.
// Function tools are forwarded only when at least one is present, so an empty
// functionDeclarations list is never sent.
func geminiBuildRequest(req *models.UnifiedRequest) (GeminiRequest, error) {
	contents, err := geminiContents(req)
	if err != nil {
		return GeminiRequest{}, err
	}
	body := GeminiRequest{
		Contents:         contents,
		GenerationConfig: geminiGenerationConfig(req),
	}

	var decls []models.FunctionDef
	for _, t := range req.Tools {
		if t.Type == "function" {
			decls = append(decls, t.Function)
		}
	}
	if len(decls) > 0 {
		body.Tools = []GeminiTool{{FunctionDeclarations: decls}}
	}
	return body, nil
}

// geminiContents converts the request's OpenAI-style messages into Gemini
// turns. Role handling:
//
//   - An assistant message with tool calls becomes a "model" turn. It holds the
//     message's text plus one functionCall part per call.
//   - The run of "tool" messages directly after it becomes a single "function"
//     turn, because Gemini expects every call to be answered in the very next
//     turn. Each result is named after the tool message's Name, falling back to
//     the call in the same position of the assistant message.
//   - "tool" messages that do not follow such an assistant message are dropped.
//   - "assistant" maps to "model"; every other role (including "system") is
//     sent as "user".
func geminiContents(req *models.UnifiedRequest) ([]GeminiContent, error) {
	var contents []GeminiContent
	for i := 0; i < len(req.Messages); i++ {
		msg := req.Messages[i]

		if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
			callParts := make([]GeminiPart, 0, len(msg.Content)+len(msg.ToolCalls))
			for _, cp := range msg.Content {
				if cp.Type == "text" && cp.Text != "" {
					callParts = append(callParts, GeminiPart{Text: cp.Text})
				}
			}
			for _, tc := range msg.ToolCalls {
				args, err := geminiFunctionArgs(tc.Function.Arguments)
				if err != nil {
					return nil, fmt.Errorf("tool call %q: %w", tc.Function.Name, err)
				}
				callParts = append(callParts, GeminiPart{
					FunctionCall: &GeminiFunctionCall{Name: tc.Function.Name, Args: args},
				})
			}
			contents = append(contents, GeminiContent{Role: "model", Parts: callParts})

			var responseParts []GeminiPart
			j := i + 1
			for ; j < len(req.Messages) && req.Messages[j].Role == "tool"; j++ {
				toolMsg := req.Messages[j]

				name := "unknown_function"
				if toolMsg.Name != nil && *toolMsg.Name != "" {
					name = *toolMsg.Name
				} else if pos := j - i - 1; pos < len(msg.ToolCalls) {
					// Results usually follow in call order; Gemini requires the
					// name to match the call being answered.
					name = msg.ToolCalls[pos].Function.Name
				}

				var text strings.Builder
				for _, cp := range toolMsg.Content {
					text.WriteString(cp.Text)
				}
				responseParts = append(responseParts, GeminiPart{
					FunctionResponse: &GeminiFunctionResponse{
						Name:     name,
						Response: geminiFunctionResponse(text.String()),
					},
				})
			}
			if len(responseParts) > 0 {
				contents = append(contents, GeminiContent{Role: "function", Parts: responseParts})
			}
			// Skip past the tool messages consumed above.
			i = j - 1
			continue
		}

		role := "user"
		switch msg.Role {
		case "tool":
			// Orphaned tool result with no preceding call; Gemini rejects these.
			continue
		case "assistant":
			role = "model"
		}

		var parts []GeminiPart
		for _, cp := range msg.Content {
			if cp.Type == "text" && cp.Text != "" {
				parts = append(parts, GeminiPart{Text: cp.Text})
			} else if cp.Image != nil {
				data, mimeType, err := cp.Image.ToBase64()
				if err != nil {
					return nil, fmt.Errorf("encoding image for gemini: %w", err)
				}
				parts = append(parts, GeminiPart{
					InlineData: &GeminiInlineData{MimeType: mimeType, Data: data},
				})
			}
		}
		if len(parts) > 0 {
			contents = append(contents, GeminiContent{Role: role, Parts: parts})
		}
	}
	return contents, nil
}

// geminiGenerationConfig copies the OpenAI sampling options that are set on
// req into a Gemini generation config; unset options stay nil.
func geminiGenerationConfig(req *models.UnifiedRequest) *GeminiGenerationConfig {
	cfg := &GeminiGenerationConfig{}
	if req.Temperature != nil {
		t := float32(*req.Temperature)
		cfg.Temperature = &t
	}
	if req.TopP != nil {
		tp := float32(*req.TopP)
		cfg.TopP = &tp
	}
	if req.TopK != nil {
		tk := int(*req.TopK)
		cfg.TopK = &tk
	}
	if req.MaxTokens != nil {
		mt := int(*req.MaxTokens)
		cfg.MaxOutputTokens = &mt
	}
	if len(req.Stop) > 0 {
		cfg.StopSequences = req.Stop
	}
	return cfg
}

// geminiFunctionArgs converts an OpenAI tool-call argument string into raw
// JSON for a functionCall part.
//   - An empty string means "no arguments" and becomes "{}".
//   - Anything else must be valid JSON. Otherwise encoding the request body
//     would later fail with a much less specific error.
func geminiFunctionArgs(arguments string) (json.RawMessage, error) {
	trimmed := strings.TrimSpace(arguments)
	if trimmed == "" {
		return json.RawMessage("{}"), nil
	}
	if !json.Valid([]byte(trimmed)) {
		return nil, fmt.Errorf("arguments are not valid JSON")
	}
	return json.RawMessage(trimmed), nil
}

// geminiFunctionResponse wraps a tool result for a functionResponse part.
// Gemini requires the response to be a JSON object, so object results pass
// through unchanged. Any other JSON value, or plain text, is wrapped as
// {"result": ...}.
func geminiFunctionResponse(text string) json.RawMessage {
	trimmed := strings.TrimSpace(text)
	if strings.HasPrefix(trimmed, "{") && json.Valid([]byte(trimmed)) {
		return json.RawMessage(trimmed)
	}
	var result interface{} = text
	if trimmed != "" && json.Valid([]byte(trimmed)) {
		result = json.RawMessage(trimmed)
	}
	// Marshalling a string or already-validated raw JSON cannot fail.
	wrapped, _ := json.Marshal(map[string]interface{}{"result": result})
	return wrapped
}

// geminiFinishReason maps a Gemini finishReason to the OpenAI finish_reason
// vocabulary.
//   - Gemini reports "STOP" even when the model ends its turn with function
//     calls, so any response containing tool calls is reported as "tool_calls".
//   - MAX_TOKENS maps to "length" and the safety-style reasons map to
//     "content_filter".
//   - Anything else is passed through lower-cased.
func geminiFinishReason(reason string, hasToolCalls bool) string {
	if hasToolCalls {
		return "tool_calls"
	}
	switch strings.ToUpper(reason) {
	case "STOP":
		return "stop"
	case "MAX_TOKENS":
		return "length"
	case "SAFETY", "RECITATION", "BLOCKLIST", "PROHIBITED_CONTENT", "SPII", "IMAGE_SAFETY":
		return "content_filter"
	default:
		return strings.ToLower(reason)
	}
}

// geminiReadErrorBody reads at most limit bytes of an error response body so
// they can be included in an error message. A read error, or a read that
// returns no data, ends the read early.
func geminiReadErrorBody(r interface{ Read([]byte) (int, error) }, limit int) string {
	var sb strings.Builder
	buf := make([]byte, 4096)
	for sb.Len() < limit {
		n, err := r.Read(buf)
		if rest := limit - sb.Len(); n > rest {
			n = rest
		}
		sb.Write(buf[:n])
		if err != nil || n == 0 {
			break
		}
	}
	return sb.String()
}
// uint32Ptr returns a pointer to a fresh copy of v, or nil when v is zero, so
// callers can treat a zero count as "not reported".
func uint32Ptr(v uint32) *uint32 {
	if v == 0 {
		return nil
	}
	out := v
	return &out
}