feat: add image generation for OpenAI DALL-E and Gemini Imagen
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled

New `/v1/images/generations` endpoint proxies DALL-E 2/3 (OpenAI)
and Imagen 3 (Gemini). Same auth/logging as chat completions.

- Add ImageGenerationRequest/Response models
- Extend Provider interface with ImageGeneration()
- OpenAI: forward to /v1/images/generations
- Gemini: call /v1beta/models/{model}:predict, map OpenAI params
- Circuit breaker wraps image gen like chat completions
- Model routing: dall-e* -> openai, imagen*/gemini* -> gemini
- Unsupported providers (deepseek/moonshot/grok/ollama) return error
- Fix pre-existing CachedContentTokenCount bug in StreamGemini
This commit is contained in:
2026-04-27 10:06:07 -04:00
parent 14e26a4323
commit 5ee539d95c
12 changed files with 330 additions and 21 deletions
+10
View File
@@ -54,3 +54,13 @@ func (cbp *CircuitBreakerProvider) ChatCompletionStream(ctx context.Context, req
// Future: Implement a way to track stream failures in the circuit breaker.
return cbp.provider.ChatCompletionStream(ctx, req)
}
func (cbp *CircuitBreakerProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
result, err := cbp.cb.Execute(func() (interface{}, error) {
return cbp.provider.ImageGeneration(ctx, req)
})
if err != nil {
return nil, err
}
return result.(*models.ImageGenerationResponse), nil
}
+14 -10
View File
@@ -3,15 +3,15 @@ package providers
import (
"bufio"
"context"
"time"
"encoding/json"
"fmt"
"io"
"strings"
"time"
"github.com/go-resty/resty/v2"
"gophergate/internal/config"
"gophergate/internal/models"
"github.com/go-resty/resty/v2"
)
type DeepSeekProvider struct {
@@ -33,11 +33,11 @@ func (p *DeepSeekProvider) Name() string {
}
type deepSeekUsage struct {
PromptTokens uint32 `json:"prompt_tokens"`
CompletionTokens uint32 `json:"completion_tokens"`
TotalTokens uint32 `json:"total_tokens"`
PromptCacheHitTokens uint32 `json:"prompt_cache_hit_tokens"`
PromptCacheMissTokens uint32 `json:"prompt_cache_miss_tokens"`
PromptTokens uint32 `json:"prompt_tokens"`
CompletionTokens uint32 `json:"completion_tokens"`
TotalTokens uint32 `json:"total_tokens"`
PromptCacheHitTokens uint32 `json:"prompt_cache_hit_tokens"`
PromptCacheMissTokens uint32 `json:"prompt_cache_miss_tokens"`
CompletionTokensDetails *struct {
ReasoningTokens uint32 `json:"reasoning_tokens"`
} `json:"completion_tokens_details"`
@@ -75,7 +75,7 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi
delete(body, "top_p")
delete(body, "presence_penalty")
delete(body, "frequency_penalty")
if msgs, ok := body["messages"].([]interface{}); ok {
for _, m := range msgs {
if msg, ok := m.(map[string]interface{}); ok {
@@ -142,7 +142,7 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models
delete(body, "top_p")
delete(body, "presence_penalty")
delete(body, "frequency_penalty")
if msgs, ok := body["messages"].([]interface{}); ok {
for _, m := range msgs {
if msg, ok := m.(map[string]interface{}); ok {
@@ -175,7 +175,7 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
// Custom scanner loop to handle DeepSeek specific usage in chunks
@@ -219,3 +219,7 @@ func StreamDeepSeek(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRes
}
return scanner.Err()
}
func (p *DeepSeekProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
return nil, fmt.Errorf("deepseek does not support image generation")
}
+110
View File
@@ -75,6 +75,116 @@ type GeminiFunctionResponse struct {
Response json.RawMessage `json:"response"`
}
func (p *GeminiProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
// Gemini Imagen API: POST https://generativelanguage.googleapis.com/v1beta/models/{model}:predict
// Map OpenAI-style params to Gemini Imagen params
n := uint32(1)
if req.N != nil && *req.N > 0 {
n = *req.N
}
aspectRatio := "1:1"
if req.Size != nil {
aspectRatio = sizeToGeminiAspectRatio(*req.Size)
}
// Build Imagen request
imagenReq := map[string]interface{}{
"instances": []map[string]interface{}{
{"prompt": req.Prompt},
},
"parameters": map[string]interface{}{
"sampleCount": n,
"aspectRatio": aspectRatio,
},
}
// Model defaults to imagen-3.0-generate-001 if empty
model := req.Model
if model == "" {
model = "imagen-3.0-generate-001"
}
// Use v1beta for Imagen
baseURL := p.config.BaseURL
if !strings.Contains(baseURL, "v1beta") {
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
}
url := fmt.Sprintf("%s/models/%s:predict?key=%s", baseURL, model, p.apiKey)
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Content-Type", "application/json").
SetBody(imagenReq).
Post(url)
if err != nil {
return nil, fmt.Errorf("gemini imagen request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Gemini Imagen API error (%d): %s", resp.StatusCode(), resp.String())
}
// Parse Imagen response
var imagenResp struct {
Predictions []struct {
MimeType string `json:"mimeType"`
BytesBase64Encoded string `json:"bytesBase64Encoded"`
} `json:"predictions"`
}
if err := json.Unmarshal(resp.Body(), &imagenResp); err != nil {
return nil, fmt.Errorf("failed to parse Imagen response: %w", err)
}
respFormat := "url"
if req.ResponseFormat != nil && *req.ResponseFormat == "b64_json" {
respFormat = "b64_json"
}
var data []models.ImageData
for _, pred := range imagenResp.Predictions {
imgData := models.ImageData{}
if respFormat == "b64_json" {
imgData.B64JSON = pred.BytesBase64Encoded
} else {
// Build a data URI since Gemini returns base64, not a URL
mime := pred.MimeType
if mime == "" {
mime = "image/png"
}
imgData.URL = fmt.Sprintf("data:%s;base64,%s", mime, pred.BytesBase64Encoded)
}
data = append(data, imgData)
}
result := &models.ImageGenerationResponse{
Created: time.Now().Unix(),
Data: data,
}
return result, nil
}
// sizeToGeminiAspectRatio converts OpenAI size format (e.g. "1024x1024") to Gemini aspect ratio (e.g. "1:1")
func sizeToGeminiAspectRatio(size string) string {
switch size {
case "1024x1024":
return "1:1"
case "1024x1792":
return "9:16"
case "1792x1024":
return "16:9"
case "256x256", "512x512":
return "1:1"
default:
return "1:1"
}
}
func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
// Gemini mapping
var contents []GeminiContent
+7 -3
View File
@@ -2,13 +2,13 @@ package providers
import (
"context"
"time"
"encoding/json"
"fmt"
"time"
"github.com/go-resty/resty/v2"
"gophergate/internal/config"
"gophergate/internal/models"
"github.com/go-resty/resty/v2"
)
type GrokProvider struct {
@@ -83,7 +83,7 @@ func (p *GrokProvider) ChatCompletionStream(ctx context.Context, req *models.Uni
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
err := StreamOpenAI(resp.RawBody(), ch)
@@ -94,3 +94,7 @@ func (p *GrokProvider) ChatCompletionStream(ctx context.Context, req *models.Uni
return ch, nil
}
func (p *GrokProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
return nil, fmt.Errorf("grok does not support image generation")
}
+4 -3
View File
@@ -256,9 +256,10 @@ func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
FinishReason string `json:"finishReason"`
} `json:"candidates"`
UsageMetadata struct {
PromptTokenCount uint32 `json:"promptTokenCount"`
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
TotalTokenCount uint32 `json:"totalTokenCount"`
PromptTokenCount uint32 `json:"promptTokenCount"`
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
TotalTokenCount uint32 `json:"totalTokenCount"`
CachedContentTokenCount uint32 `json:"cachedContentTokenCount"`
} `json:"usageMetadata"`
}
+6 -2
View File
@@ -2,14 +2,14 @@ package providers
import (
"context"
"time"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/go-resty/resty/v2"
"gophergate/internal/config"
"gophergate/internal/models"
"github.com/go-resty/resty/v2"
)
type MoonshotProvider struct {
@@ -113,3 +113,7 @@ func (p *MoonshotProvider) ChatCompletionStream(ctx context.Context, req *models
return ch, nil
}
func (p *MoonshotProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
return nil, fmt.Errorf("moonshot does not support image generation")
}
+4
View File
@@ -249,3 +249,7 @@ func StreamOllama(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
}
return scanner.Err()
}
func (p *OllamaProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
return nil, fmt.Errorf("ollama does not support image generation")
}
+47
View File
@@ -68,6 +68,53 @@ func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *models.Unified
return ParseOpenAIResponse(respJSON, req.Model)
}
func (p *OpenAIProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
body := map[string]interface{}{
"prompt": req.Prompt,
"model": req.Model,
}
if req.N != nil {
body["n"] = *req.N
}
if req.Quality != nil {
body["quality"] = *req.Quality
}
if req.ResponseFormat != nil {
body["response_format"] = *req.ResponseFormat
}
if req.Size != nil {
body["size"] = *req.Size
}
if req.Style != nil {
body["style"] = *req.Style
}
if req.User != nil {
body["user"] = *req.User
}
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetBody(body).
Post(fmt.Sprintf("%s/images/generations", p.config.BaseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("OpenAI image API error (%d): %s", resp.StatusCode(), resp.String())
}
var result models.ImageGenerationResponse
if err := json.Unmarshal(resp.Body(), &result); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return &result, nil
}
func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
+1
View File
@@ -10,4 +10,5 @@ type Provider interface {
Name() string
ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error)
ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error)
ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error)
}