From 5ee539d95ca4c6f47af9c74f5cc9d7452451f7c6 Mon Sep 17 00:00:00 2001 From: hobokenchicken Date: Mon, 27 Apr 2026 10:06:07 -0400 Subject: [PATCH] feat: add image generation for OpenAI DALL-E and Gemini Imagen New `/v1/images/generations` endpoint proxies DALL-E 2/3 (OpenAI) and Imagen 3 (Gemini). Same auth/logging as chat completions. - Add ImageGenerationRequest/Response models - Extend Provider interface with ImageGeneration() - OpenAI: forward to /v1/images/generations - Gemini: call /v1beta/models/{model}:predict, map OpenAI params - Circuit breaker wraps image gen like chat completions - Model routing: dall-e* -> openai, imagen*/gemini* -> gemini - Unsupported providers (deepseek/moonshot/grok/ollama) return error - Fix pre-existing CachedContentTokenCount bug in StreamGemini --- README.md | 40 +++++++++- internal/models/models.go | 24 ++++++ internal/providers/circuit_breaker.go | 10 +++ internal/providers/deepseek.go | 24 +++--- internal/providers/gemini.go | 110 ++++++++++++++++++++++++++ internal/providers/grok.go | 10 ++- internal/providers/helpers.go | 7 +- internal/providers/moonshot.go | 8 +- internal/providers/ollama.go | 4 + internal/providers/openai.go | 47 +++++++++++ internal/providers/provider.go | 1 + internal/server/server.go | 66 ++++++++++++++++ 12 files changed, 330 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 64294a29..5d8fddfd 100644 --- a/README.md +++ b/README.md @@ -4,10 +4,10 @@ A unified, high-performance LLM proxy gateway built in Go. It provides a single ## Features -- **Unified API:** OpenAI-compatible `/v1/chat/completions` and `/v1/models` endpoints. +- **Unified API:** OpenAI-compatible `/v1/chat/completions`, `/v1/images/generations`, and `/v1/models` endpoints. - **Multi-Provider Support:** - - **OpenAI:** GPT-4o, GPT-4o Mini, o1, o3 reasoning models. - - **Google Gemini:** Gemini 2.0 Flash, Pro, and vision models (with native CoT support). + - **OpenAI:** GPT-4o, GPT-4o Mini, o1, o3 reasoning models, DALL-E 2/3 image generation. + - **Google Gemini:** Gemini 2.0 Flash, Pro, and vision models (with native CoT support), Imagen 3 image generation. - **DeepSeek:** DeepSeek Chat and Reasoner (R1) models. - **Moonshot:** Kimi K2.5 and other Kimi models. - **xAI Grok:** Grok-4 models. @@ -18,6 +18,7 @@ A unified, high-performance LLM proxy gateway built in Go. It provides a single - **Database Persistence:** Every request logged to SQLite for historical analysis and dashboard analytics. - **Streaming Support:** Full SSE (Server-Sent Events) support for all providers. - **Multimodal (Vision):** Image processing (Base64 and remote URLs) across compatible providers. +- **Image Generation:** DALL-E 2/3 (OpenAI) and Imagen 3 (Gemini) via OpenAI-compatible `/v1/images/generations` endpoint. - **Multi-User Access Control:** - **Admin Role:** Full access to all dashboard features, user management, and system configuration. - **Viewer Role:** Read-only access to usage analytics, costs, and monitoring. @@ -54,6 +55,7 @@ GopherGate is designed with security in mind: ### Quick Start 1. Clone and build: + ```bash git clone cd gophergate @@ -61,6 +63,7 @@ GopherGate is designed with security in mind: ``` 2. Configure environment: + ```bash cp .env.example .env # Edit .env and add your configuration: @@ -112,6 +115,7 @@ Access the dashboard at `http://localhost:8080`. **Forgot Password?** You can reset the admin password to default by running: + ```bash ./gophergate -reset-admin ``` @@ -129,6 +133,7 @@ endpoint after enabling Ollama in configuration and setting the base URL to your Ollama server (default: `http://localhost:11434/v1`). ### Python + ```python from openai import OpenAI @@ -141,6 +146,35 @@ response = client.chat.completions.create( model="gpt-4o", messages=[{"role": "user", "content": "Hello!"}] ) + +### Image Generation (DALL-E / Imagen) + +```python +from openai import OpenAI + +client = OpenAI( + base_url="http://localhost:8080/v1", + api_key="YOUR_CLIENT_API_KEY" +) + +# DALL-E 3 (OpenAI) +resp = client.images.generate( + model="dall-e-3", + prompt="A cute gopher wearing a top hat", + n=1, + size="1024x1024" +) +print(resp.data[0].url) + +# Imagen 3 (Gemini) — uses same endpoint +resp = client.images.generate( + model="imagen-3.0-generate-001", + prompt="A gopher coding in Go", + n=1, + size="1024x1024" +) +print(resp.data[0].url) # Returns data URI (Gemini returns base64) +``` ``` ## License diff --git a/internal/models/models.go b/internal/models/models.go index 5ac75ad2..29286989 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -210,6 +210,30 @@ func (i *ImageInput) ToBase64() (string, string, error) { return "", "", fmt.Errorf("empty image input") } +// Image Generation (DALL-E, Imagen) + +type ImageGenerationRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + N *uint32 `json:"n,omitempty"` + Quality *string `json:"quality,omitempty"` + ResponseFormat *string `json:"response_format,omitempty"` + Size *string `json:"size,omitempty"` + Style *string `json:"style,omitempty"` + User *string `json:"user,omitempty"` +} + +type ImageGenerationResponse struct { + Created int64 `json:"created"` + Data []ImageData `json:"data"` +} + +type ImageData struct { + URL string `json:"url,omitempty"` + B64JSON string `json:"b64_json,omitempty"` + RevisedPrompt string `json:"revised_prompt,omitempty"` +} + // AuthInfo for context type AuthInfo struct { Token string diff --git a/internal/providers/circuit_breaker.go b/internal/providers/circuit_breaker.go index a9d4099b..2bda1492 100644 --- a/internal/providers/circuit_breaker.go +++ b/internal/providers/circuit_breaker.go @@ -54,3 +54,13 @@ func (cbp *CircuitBreakerProvider) ChatCompletionStream(ctx context.Context, req // Future: Implement a way to track stream failures in the circuit breaker. return cbp.provider.ChatCompletionStream(ctx, req) } + +func (cbp *CircuitBreakerProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) { + result, err := cbp.cb.Execute(func() (interface{}, error) { + return cbp.provider.ImageGeneration(ctx, req) + }) + if err != nil { + return nil, err + } + return result.(*models.ImageGenerationResponse), nil +} diff --git a/internal/providers/deepseek.go b/internal/providers/deepseek.go index 98d7c810..191fb5e3 100644 --- a/internal/providers/deepseek.go +++ b/internal/providers/deepseek.go @@ -3,15 +3,15 @@ package providers import ( "bufio" "context" - "time" "encoding/json" "fmt" "io" "strings" + "time" + "github.com/go-resty/resty/v2" "gophergate/internal/config" "gophergate/internal/models" - "github.com/go-resty/resty/v2" ) type DeepSeekProvider struct { @@ -33,11 +33,11 @@ func (p *DeepSeekProvider) Name() string { } type deepSeekUsage struct { - PromptTokens uint32 `json:"prompt_tokens"` - CompletionTokens uint32 `json:"completion_tokens"` - TotalTokens uint32 `json:"total_tokens"` - PromptCacheHitTokens uint32 `json:"prompt_cache_hit_tokens"` - PromptCacheMissTokens uint32 `json:"prompt_cache_miss_tokens"` + PromptTokens uint32 `json:"prompt_tokens"` + CompletionTokens uint32 `json:"completion_tokens"` + TotalTokens uint32 `json:"total_tokens"` + PromptCacheHitTokens uint32 `json:"prompt_cache_hit_tokens"` + PromptCacheMissTokens uint32 `json:"prompt_cache_miss_tokens"` CompletionTokensDetails *struct { ReasoningTokens uint32 `json:"reasoning_tokens"` } `json:"completion_tokens_details"` @@ -75,7 +75,7 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi delete(body, "top_p") delete(body, "presence_penalty") delete(body, "frequency_penalty") - + if msgs, ok := body["messages"].([]interface{}); ok { for _, m := range msgs { if msg, ok := m.(map[string]interface{}); ok { @@ -142,7 +142,7 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models delete(body, "top_p") delete(body, "presence_penalty") delete(body, "frequency_penalty") - + if msgs, ok := body["messages"].([]interface{}); ok { for _, m := range msgs { if msg, ok := m.(map[string]interface{}); ok { @@ -175,7 +175,7 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models } ch := make(chan *models.ChatCompletionStreamResponse) - + go func() { defer close(ch) // Custom scanner loop to handle DeepSeek specific usage in chunks @@ -219,3 +219,7 @@ func StreamDeepSeek(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRes } return scanner.Err() } + +func (p *DeepSeekProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) { + return nil, fmt.Errorf("deepseek does not support image generation") +} diff --git a/internal/providers/gemini.go b/internal/providers/gemini.go index adf9728a..880bfed8 100644 --- a/internal/providers/gemini.go +++ b/internal/providers/gemini.go @@ -75,6 +75,116 @@ type GeminiFunctionResponse struct { Response json.RawMessage `json:"response"` } +func (p *GeminiProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) { + // Gemini Imagen API: POST https://generativelanguage.googleapis.com/v1beta/models/{model}:predict + // Map OpenAI-style params to Gemini Imagen params + + n := uint32(1) + if req.N != nil && *req.N > 0 { + n = *req.N + } + + aspectRatio := "1:1" + if req.Size != nil { + aspectRatio = sizeToGeminiAspectRatio(*req.Size) + } + + // Build Imagen request + imagenReq := map[string]interface{}{ + "instances": []map[string]interface{}{ + {"prompt": req.Prompt}, + }, + "parameters": map[string]interface{}{ + "sampleCount": n, + "aspectRatio": aspectRatio, + }, + } + + // Model defaults to imagen-3.0-generate-001 if empty + model := req.Model + if model == "" { + model = "imagen-3.0-generate-001" + } + + // Use v1beta for Imagen + baseURL := p.config.BaseURL + if !strings.Contains(baseURL, "v1beta") { + baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1) + } + + url := fmt.Sprintf("%s/models/%s:predict?key=%s", baseURL, model, p.apiKey) + + resp, err := p.client.R(). + SetContext(ctx). + SetHeader("Content-Type", "application/json"). + SetBody(imagenReq). + Post(url) + + if err != nil { + return nil, fmt.Errorf("gemini imagen request failed: %w", err) + } + + if !resp.IsSuccess() { + return nil, fmt.Errorf("Gemini Imagen API error (%d): %s", resp.StatusCode(), resp.String()) + } + + // Parse Imagen response + var imagenResp struct { + Predictions []struct { + MimeType string `json:"mimeType"` + BytesBase64Encoded string `json:"bytesBase64Encoded"` + } `json:"predictions"` + } + + if err := json.Unmarshal(resp.Body(), &imagenResp); err != nil { + return nil, fmt.Errorf("failed to parse Imagen response: %w", err) + } + + respFormat := "url" + if req.ResponseFormat != nil && *req.ResponseFormat == "b64_json" { + respFormat = "b64_json" + } + + var data []models.ImageData + for _, pred := range imagenResp.Predictions { + imgData := models.ImageData{} + if respFormat == "b64_json" { + imgData.B64JSON = pred.BytesBase64Encoded + } else { + // Build a data URI since Gemini returns base64, not a URL + mime := pred.MimeType + if mime == "" { + mime = "image/png" + } + imgData.URL = fmt.Sprintf("data:%s;base64,%s", mime, pred.BytesBase64Encoded) + } + data = append(data, imgData) + } + + result := &models.ImageGenerationResponse{ + Created: time.Now().Unix(), + Data: data, + } + + return result, nil +} + +// sizeToGeminiAspectRatio converts OpenAI size format (e.g. "1024x1024") to Gemini aspect ratio (e.g. "1:1") +func sizeToGeminiAspectRatio(size string) string { + switch size { + case "1024x1024": + return "1:1" + case "1024x1792": + return "9:16" + case "1792x1024": + return "16:9" + case "256x256", "512x512": + return "1:1" + default: + return "1:1" + } +} + func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) { // Gemini mapping var contents []GeminiContent diff --git a/internal/providers/grok.go b/internal/providers/grok.go index 1482bb83..5c5106f0 100644 --- a/internal/providers/grok.go +++ b/internal/providers/grok.go @@ -2,13 +2,13 @@ package providers import ( "context" - "time" "encoding/json" "fmt" + "time" + "github.com/go-resty/resty/v2" "gophergate/internal/config" "gophergate/internal/models" - "github.com/go-resty/resty/v2" ) type GrokProvider struct { @@ -83,7 +83,7 @@ func (p *GrokProvider) ChatCompletionStream(ctx context.Context, req *models.Uni } ch := make(chan *models.ChatCompletionStreamResponse) - + go func() { defer close(ch) err := StreamOpenAI(resp.RawBody(), ch) @@ -94,3 +94,7 @@ func (p *GrokProvider) ChatCompletionStream(ctx context.Context, req *models.Uni return ch, nil } + +func (p *GrokProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) { + return nil, fmt.Errorf("grok does not support image generation") +} diff --git a/internal/providers/helpers.go b/internal/providers/helpers.go index a66f86f4..a0369aea 100644 --- a/internal/providers/helpers.go +++ b/internal/providers/helpers.go @@ -256,9 +256,10 @@ func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo FinishReason string `json:"finishReason"` } `json:"candidates"` UsageMetadata struct { - PromptTokenCount uint32 `json:"promptTokenCount"` - CandidatesTokenCount uint32 `json:"candidatesTokenCount"` - TotalTokenCount uint32 `json:"totalTokenCount"` + PromptTokenCount uint32 `json:"promptTokenCount"` + CandidatesTokenCount uint32 `json:"candidatesTokenCount"` + TotalTokenCount uint32 `json:"totalTokenCount"` + CachedContentTokenCount uint32 `json:"cachedContentTokenCount"` } `json:"usageMetadata"` } diff --git a/internal/providers/moonshot.go b/internal/providers/moonshot.go index 0d40533c..f2a01c3b 100644 --- a/internal/providers/moonshot.go +++ b/internal/providers/moonshot.go @@ -2,14 +2,14 @@ package providers import ( "context" - "time" "encoding/json" "fmt" "strings" + "time" + "github.com/go-resty/resty/v2" "gophergate/internal/config" "gophergate/internal/models" - "github.com/go-resty/resty/v2" ) type MoonshotProvider struct { @@ -113,3 +113,7 @@ func (p *MoonshotProvider) ChatCompletionStream(ctx context.Context, req *models return ch, nil } + +func (p *MoonshotProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) { + return nil, fmt.Errorf("moonshot does not support image generation") +} diff --git a/internal/providers/ollama.go b/internal/providers/ollama.go index 08175a17..528b8581 100644 --- a/internal/providers/ollama.go +++ b/internal/providers/ollama.go @@ -249,3 +249,7 @@ func StreamOllama(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo } return scanner.Err() } + +func (p *OllamaProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) { + return nil, fmt.Errorf("ollama does not support image generation") +} diff --git a/internal/providers/openai.go b/internal/providers/openai.go index 840502b8..25baf2ad 100644 --- a/internal/providers/openai.go +++ b/internal/providers/openai.go @@ -68,6 +68,53 @@ func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *models.Unified return ParseOpenAIResponse(respJSON, req.Model) } +func (p *OpenAIProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) { + body := map[string]interface{}{ + "prompt": req.Prompt, + "model": req.Model, + } + + if req.N != nil { + body["n"] = *req.N + } + if req.Quality != nil { + body["quality"] = *req.Quality + } + if req.ResponseFormat != nil { + body["response_format"] = *req.ResponseFormat + } + if req.Size != nil { + body["size"] = *req.Size + } + if req.Style != nil { + body["style"] = *req.Style + } + if req.User != nil { + body["user"] = *req.User + } + + resp, err := p.client.R(). + SetContext(ctx). + SetHeader("Authorization", "Bearer "+p.apiKey). + SetBody(body). + Post(fmt.Sprintf("%s/images/generations", p.config.BaseURL)) + + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + + if !resp.IsSuccess() { + return nil, fmt.Errorf("OpenAI image API error (%d): %s", resp.StatusCode(), resp.String()) + } + + var result models.ImageGenerationResponse + if err := json.Unmarshal(resp.Body(), &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &result, nil +} + func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) { messagesJSON, err := MessagesToOpenAIJSON(req.Messages) if err != nil { diff --git a/internal/providers/provider.go b/internal/providers/provider.go index 175b261b..a43b1bd4 100644 --- a/internal/providers/provider.go +++ b/internal/providers/provider.go @@ -10,4 +10,5 @@ type Provider interface { Name() string ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) + ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) } diff --git a/internal/server/server.go b/internal/server/server.go index 9ae3a8f5..bf30d41a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -186,6 +186,7 @@ func (s *Server) setupRoutes() { v1.Use(middleware.AuthMiddleware(s.database, true)) { v1.POST("/chat/completions", s.handleChatCompletions) + v1.POST("/images/generations", s.handleImageGenerations) v1.GET("/models", s.handleListModels) v1.GET("/responses", s.handleListResponses) } @@ -501,6 +502,71 @@ func (s *Server) handleChatCompletions(c *gin.Context) { c.JSON(http.StatusOK, resp) } +func (s *Server) handleImageGenerations(c *gin.Context) { + startTime := time.Now() + var req models.ImageGenerationRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Determine provider based on model name + providerName := "openai" + modelLower := strings.ToLower(req.Model) + switch { + case strings.Contains(modelLower, "imagen"), strings.Contains(modelLower, "gemini"): + providerName = "gemini" + case strings.Contains(modelLower, "dall"), strings.HasPrefix(modelLower, "openai/"): + providerName = "openai" + } + + // Default model for each provider if not specified + if req.Model == "" { + if providerName == "openai" { + req.Model = "dall-e-3" + } else { + req.Model = "imagen-3.0-generate-001" + } + } + + // Strip common prefixes + prefixes := []string{"openai/", "gemini/", "google/"} + for _, p := range prefixes { + if strings.HasPrefix(req.Model, p) { + req.Model = strings.TrimPrefix(req.Model, p) + break + } + } + + provider, ok := s.providers[providerName] + if !ok { + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)}) + return + } + + clientID := "default" + if auth, ok := c.Get("auth"); ok { + if authInfo, ok := auth.(models.AuthInfo); ok { + clientID = authInfo.ClientID + } + } + + resp, err := provider.ImageGeneration(c.Request.Context(), &req) + if err != nil { + s.logRequest(startTime, clientID, providerName, req.Model, nil, err, false) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + s.logRequest(startTime, clientID, providerName, req.Model, &models.Usage{ + PromptTokens: 1, + CompletionTokens: uint32(len(resp.Data)), + TotalTokens: 1 + uint32(len(resp.Data)), + }, nil, false) + + c.JSON(http.StatusOK, resp) +} + func (s *Server) logRequest(start time.Time, clientID, provider, model string, usage *models.Usage, err error, hasImages bool) { entry := RequestLog{ Timestamp: start,