Files
GopherGate/internal/server/server.go
T
hobokenchicken 0ae30036f0
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
fix: classifier selector model now routes to correct provider
Extracted selectProvider() method from handleChatCompletions' inline
logic. The classifier callback now calls selectProvider(selectorModel)
instead of hardcoding openaiProvider.

This fixes the 'circuit breaker is open' error when dispatcher tries
to use deepseek-v4-flash as its selector model.
2026-05-07 13:37:19 -04:00

935 lines
28 KiB
Go

package server
import (
"encoding/json"
"context"
"fmt"
"io"
"log"
"net/http"
"strings"
"sync"
"time"
"gophergate/internal/config"
"gophergate/internal/db"
"gophergate/internal/middleware"
"gophergate/internal/models"
"gophergate/internal/providers"
"gophergate/internal/router"
"gophergate/internal/utils"
"github.com/gin-gonic/gin"
)
// Server is the GopherGate HTTP gateway. It ties the gin engine to the
// configuration, the database handle, the currently enabled upstream
// providers and the session, websocket, logging and model-routing helpers.
type Server struct {
	// router is the gin engine every HTTP route is registered on.
	router *gin.Engine
	// cfg is the static configuration the server was constructed with.
	cfg *config.Config
	// database is the handle used to load provider configs and model groups.
	database *db.DB
	// providers maps a provider ID (e.g. "openai") to its enabled instance.
	// NOTE(review): no mutex guards this map; confirm RefreshProviders is
	// never called concurrently with request handlers.
	providers map[string]providers.Provider
	// sessions is the session manager created in NewServer with a 24h duration.
	sessions *SessionManager
	// hub is started by Run and is also handed to the request logger.
	hub *Hub
	// logger receives one RequestLog entry per proxied request.
	logger *RequestLogger
	// registry holds model metadata used for cost and output-limit lookups.
	registry *models.ModelRegistry
	// registryMu guards registry, which is replaced by the periodic refresher.
	registryMu sync.RWMutex
	// modelRouter resolves model-group names to concrete models.
	modelRouter *router.Router
}
// NewServer builds a Server from cfg and database: it creates the gin engine,
// session manager, websocket hub and request logger, kicks off a background
// fetch of the model registry, loads the enabled providers, registers all
// routes and initializes the model-group router.
//
// Failures to fetch the registry or to load providers are reported as
// warnings on stdout and do not prevent the server from being returned.
func NewServer(cfg *config.Config, database *db.DB) *Server {
	engine := gin.Default()
	hub := NewHub()
	s := &Server{
		router:    engine,
		cfg:       cfg,
		database:  database,
		providers: make(map[string]providers.Provider),
		sessions:  NewSessionManager(cfg.KeyBytes, 24*time.Hour),
		hub:       hub,
		logger:    NewRequestLogger(database, hub),
		registry:  &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)},
	}
	s.sessions.StartCleanup()

	// Fetch the registry in the background so startup is not blocked on it.
	go func() {
		registry, err := utils.FetchRegistry()
		if err != nil {
			fmt.Printf("Warning: Failed to fetch initial model registry: %v\n", err)
			return
		}
		// Readers take registryMu, so the swap must be done under the lock
		// as well; an unguarded write here is a data race.
		s.registryMu.Lock()
		s.registry = registry
		s.registryMu.Unlock()
	}()

	// Initialize providers from DB and config.
	if err := s.RefreshProviders(); err != nil {
		fmt.Printf("Warning: Failed to initial refresh providers: %v\n", err)
	}

	s.setupRoutes()

	// RefreshProviders returns before touching the router when the DB query
	// fails, so make sure the model-group router exists in every case.
	s.refreshRouter()
	return s
}
// RefreshProviders rebuilds s.providers from the static configuration and the
// provider_configs table, then reloads the model-group router.
//
// For every known provider ID the config supplies the defaults (enabled flag,
// base URL, API key); a matching DB row overrides the enabled flag and, when
// non-empty, the base URL and API key. Disabled providers are removed from
// the map; enabled ones are wrapped in a circuit breaker.
//
// It returns an error only when the provider_configs query fails.
func (s *Server) RefreshProviders() error {
	var dbConfigs []db.ProviderConfig
	if err := s.database.Select(&dbConfigs, "SELECT * FROM provider_configs"); err != nil {
		return fmt.Errorf("failed to fetch provider configs from db: %w", err)
	}

	dbMap := make(map[string]db.ProviderConfig, len(dbConfigs))
	for _, cfg := range dbConfigs {
		dbMap[cfg.ID] = cfg
	}

	providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama"}
	for _, id := range providerIDs {
		// Default values from config. A GetAPIKey error simply leaves the
		// key empty; the DB row may still supply one below.
		enabled := false
		baseURL := ""
		apiKey := ""
		switch id {
		case "openai":
			enabled = s.cfg.Providers.OpenAI.Enabled
			baseURL = s.cfg.Providers.OpenAI.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("openai")
		case "gemini":
			enabled = s.cfg.Providers.Gemini.Enabled
			baseURL = s.cfg.Providers.Gemini.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("gemini")
		case "deepseek":
			enabled = s.cfg.Providers.DeepSeek.Enabled
			baseURL = s.cfg.Providers.DeepSeek.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("deepseek")
		case "moonshot":
			enabled = s.cfg.Providers.Moonshot.Enabled
			baseURL = s.cfg.Providers.Moonshot.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("moonshot")
		case "grok":
			enabled = s.cfg.Providers.Grok.Enabled
			baseURL = s.cfg.Providers.Grok.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("grok")
		case "ollama":
			// Ollama takes no API key, but its enabled flag and base URL
			// must still come from config; otherwise a config-only Ollama
			// is never started and its base URL is overwritten with "".
			enabled = s.cfg.Providers.Ollama.Enabled
			baseURL = s.cfg.Providers.Ollama.BaseURL
		}

		// Overrides from DB.
		if dbCfg, ok := dbMap[id]; ok {
			enabled = dbCfg.Enabled
			if dbCfg.BaseURL != nil && *dbCfg.BaseURL != "" {
				baseURL = *dbCfg.BaseURL
			}
			if dbCfg.APIKey != nil && *dbCfg.APIKey != "" {
				key := *dbCfg.APIKey
				usable := true
				if dbCfg.APIKeyEncrypted {
					decrypted, err := utils.Decrypt(key, s.cfg.KeyBytes)
					if err != nil {
						// Never send undecryptable ciphertext upstream as
						// a credential; keep the config key instead.
						fmt.Printf("Warning: Failed to decrypt API key for %s: %v\n", id, err)
						usable = false
					} else {
						key = decrypted
					}
				}
				if usable {
					apiKey = key
				}
			}
		}

		if !enabled {
			delete(s.providers, id)
			continue
		}

		// Initialize the provider from a copy of its config section with
		// the effective base URL applied.
		var p providers.Provider
		switch id {
		case "openai":
			cfg := s.cfg.Providers.OpenAI
			cfg.BaseURL = baseURL
			p = providers.NewOpenAIProvider(cfg, apiKey)
		case "gemini":
			cfg := s.cfg.Providers.Gemini
			cfg.BaseURL = baseURL
			p = providers.NewGeminiProvider(cfg, apiKey)
		case "deepseek":
			cfg := s.cfg.Providers.DeepSeek
			cfg.BaseURL = baseURL
			p = providers.NewDeepSeekProvider(cfg, apiKey)
		case "moonshot":
			cfg := s.cfg.Providers.Moonshot
			cfg.BaseURL = baseURL
			p = providers.NewMoonshotProvider(cfg, apiKey)
		case "grok":
			cfg := s.cfg.Providers.Grok
			cfg.BaseURL = baseURL
			p = providers.NewGrokProvider(cfg, apiKey)
		case "ollama":
			cfg := s.cfg.Providers.Ollama
			cfg.BaseURL = baseURL
			p = providers.NewOllamaProvider(cfg)
		}
		if p != nil {
			s.providers[id] = providers.NewCircuitBreakerProvider(p)
		}
	}

	s.refreshRouter()
	return nil
}
// refreshRouter loads the model groups from the database and either creates
// the model router (first call) or reloads the existing one. A failed query
// is reported as a warning and treated as "no groups".
func (s *Server) refreshRouter() {
	var groups []db.ModelGroup
	if loadErr := s.database.Select(&groups, "SELECT * FROM model_groups"); loadErr != nil {
		fmt.Printf("Warning: Failed to load model groups: %v\n", loadErr)
		groups = nil
	}

	// classify asks the selector model for a short, non-streamed completion
	// (capped at 5 tokens) and returns its text content.
	classify := router.ClassifierFunc(func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) {
		selector, _, err := s.selectProvider(selectorModel)
		if err != nil {
			return "", err
		}

		textMessage := func(role, text string) models.UnifiedMessage {
			return models.UnifiedMessage{
				Role:    role,
				Content: []models.UnifiedContentPart{{Type: "text", Text: text}},
			}
		}
		classifyReq := &models.UnifiedRequest{
			Model: selectorModel,
			Messages: []models.UnifiedMessage{
				textMessage("system", systemPrompt),
				textMessage("user", userMessage),
			},
			MaxTokens: uint32Ptr(5),
			Stream:    false,
		}

		answer, err := selector.ChatCompletion(ctx, classifyReq)
		if err != nil {
			return "", err
		}
		if len(answer.Choices) == 0 {
			return "", fmt.Errorf("no choices in classifier response")
		}
		if text, isString := answer.Choices[0].Message.Content.(string); isString {
			return text, nil
		}
		return "", fmt.Errorf("classifier response content is not a string")
	})

	if s.modelRouter != nil {
		s.modelRouter.Reload(groups)
		return
	}
	s.modelRouter = router.New(groups, classify)
}
// setupRoutes registers every HTTP route on the gin engine: static assets,
// the websocket endpoint, the token-authenticated /v1 LLM API, the dashboard
// /api group (public auth endpoints plus admin-only routes) and /health.
func (s *Server) setupRoutes() {
	r := s.router

	// Static dashboard assets.
	r.StaticFile("/", "./static/index.html")
	r.StaticFile("/favicon.ico", "./static/favicon.ico")
	r.Static("/css", "./static/css")
	r.Static("/js", "./static/js")
	r.Static("/img", "./static/img")

	// WebSocket endpoint.
	r.GET("/ws", s.handleWebSocket)

	// External LLM API, secured with AuthMiddleware.
	llm := r.Group("/v1")
	llm.Use(middleware.AuthMiddleware(s.database, true))
	llm.POST("/chat/completions", s.handleChatCompletions)
	llm.POST("/images/generations", s.handleImageGenerations)
	llm.GET("/models", s.handleListModels)
	llm.POST("/responses", s.handleResponses)

	// Dashboard API: authentication endpoints are reachable without a session.
	dashboard := r.Group("/api")
	dashboard.POST("/auth/login", s.handleLogin)
	dashboard.GET("/auth/status", s.handleAuthStatus)
	dashboard.POST("/auth/logout", s.handleLogout)
	dashboard.POST("/auth/change-password", s.handleChangePassword)

	// Everything below requires an admin session.
	protected := dashboard.Group("/")
	protected.Use(s.adminAuthMiddleware())

	// Usage and analytics.
	protected.GET("/usage/summary", s.handleUsageSummary)
	protected.GET("/usage/time-series", s.handleTimeSeries)
	protected.GET("/usage/providers", s.handleProvidersUsage)
	protected.GET("/usage/clients", s.handleClientsUsage)
	protected.GET("/usage/detailed", s.handleDetailedUsage)
	protected.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)

	// Clients and their tokens.
	protected.GET("/clients", s.handleGetClients)
	protected.POST("/clients", s.handleCreateClient)
	protected.GET("/clients/:id", s.handleGetClient)
	protected.PUT("/clients/:id", s.handleUpdateClient)
	protected.DELETE("/clients/:id", s.handleDeleteClient)
	protected.GET("/clients/:id/tokens", s.handleGetClientTokens)
	protected.POST("/clients/:id/tokens", s.handleCreateClientToken)
	protected.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)

	// Providers, models and model groups.
	protected.GET("/providers", s.handleGetProviders)
	protected.PUT("/providers/:name", s.handleUpdateProvider)
	protected.POST("/providers/:name/test", s.handleTestProvider)
	protected.GET("/models", s.handleGetModels)
	protected.PUT("/models/:id", s.handleUpdateModel)
	protected.GET("/model-groups", s.handleGetModelGroups)
	protected.POST("/model-groups", s.handleCreateModelGroup)
	protected.PUT("/model-groups/:id", s.handleUpdateModelGroup)
	protected.DELETE("/model-groups/:id", s.handleDeleteModelGroup)

	// Dashboard users.
	protected.GET("/users", s.handleGetUsers)
	protected.POST("/users", s.handleCreateUser)
	protected.PUT("/users/:id", s.handleUpdateUser)
	protected.DELETE("/users/:id", s.handleDeleteUser)

	// System endpoints.
	protected.GET("/system/health", s.handleSystemHealth)
	protected.GET("/system/metrics", s.handleSystemMetrics)
	protected.GET("/system/settings", s.handleGetSettings)
	protected.POST("/system/backup", s.handleCreateBackup)
	protected.GET("/system/logs", s.handleGetLogs)

	// Unauthenticated liveness probe.
	r.GET("/health", func(ctx *gin.Context) {
		ctx.JSON(http.StatusOK, gin.H{"status": "ok"})
	})
}
// handleResponses serves POST /v1/responses. It picks the upstream provider
// from the model name, strips a leading provider prefix from the model, and
// forwards the request either as a server-sent-event stream or as a single
// JSON response. Every outcome is recorded through logRequest.
//
// Responses: 400 on an unparsable body, 500 when the provider is not enabled
// or the upstream call fails, 200 otherwise.
func (s *Server) handleResponses(c *gin.Context) {
	startTime := time.Now()
	var req models.ResponsesRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Share the model→provider mapping with the other handlers instead of
	// keeping a private copy of the same chain (default: openai).
	provider, providerName, err := s.selectProvider(req.Model)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Strip a leading provider prefix. Provider selection is
	// case-insensitive, so the strip must be too; otherwise "OpenAI/gpt-4o"
	// would be routed correctly but sent upstream with its prefix.
	prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/"}
	for _, p := range prefixes {
		if len(req.Model) >= len(p) && strings.EqualFold(req.Model[:len(p)], p) {
			req.Model = req.Model[len(p):]
			break
		}
	}

	clientID := "default"
	if auth, ok := c.Get("auth"); ok {
		if authInfo, ok := auth.(models.AuthInfo); ok {
			clientID = authInfo.ClientID
		}
	}

	if req.Stream != nil && *req.Stream {
		ch, err := provider.ResponsesStream(c.Request.Context(), &req)
		if err != nil {
			s.logRequest(startTime, clientID, providerName, req.Model, "", nil, err, false)
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}

		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")

		var lastUsage *models.ResponsesUsage
		// logStream records the stream outcome with the last usage seen.
		logStream := func(streamErr error) {
			if lastUsage != nil {
				s.logRequest(startTime, clientID, providerName, req.Model, "", lastUsage.ToUsage(), streamErr, false)
				return
			}
			s.logRequest(startTime, clientID, providerName, req.Model, "", nil, streamErr, false)
		}

		disconnected := c.Stream(func(w io.Writer) bool {
			chunk, ok := <-ch
			if !ok {
				fmt.Fprintf(w, "data: [DONE]\n\n")
				logStream(nil)
				return false
			}
			// Capture usage from the response payload in streaming chunks.
			if chunk.Response != nil && chunk.Response.Usage != nil {
				lastUsage = chunk.Response.Usage
			}
			data, err := json.Marshal(chunk)
			if err != nil {
				// Record the aborted stream rather than dropping it.
				logStream(err)
				return false
			}
			fmt.Fprintf(w, "data: %s\n\n", data)
			return true
		})
		// Stream reports true when the client went away mid-stream; the
		// callback never reached its logging branch in that case.
		if disconnected {
			logStream(fmt.Errorf("client disconnected before the stream completed"))
		}
		return
	}

	resp, err := provider.Responses(c.Request.Context(), &req)
	if err != nil {
		s.logRequest(startTime, clientID, providerName, req.Model, "", nil, err, false)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	if resp.Usage != nil {
		s.logRequest(startTime, clientID, providerName, req.Model, "", resp.Usage.ToUsage(), nil, false)
	} else {
		s.logRequest(startTime, clientID, providerName, req.Model, "", nil, nil, false)
	}
	c.JSON(http.StatusOK, resp)
}
// handleListModels serves GET /v1/models in the OpenAI list format. The list
// is the union of registry models from the allowed providers, the configured
// Ollama models (when Ollama is enabled in config) and the model-group IDs
// known to the router. The first source to contribute an ID determines its
// owned_by value. Entries are emitted in map iteration order.
func (s *Server) handleListModels(c *gin.Context) {
	type OpenAIModel struct {
		ID      string `json:"id"`
		Object  string `json:"object"`
		Created int64  `json:"created"`
		OwnedBy string `json:"owned_by"`
	}

	// Fixed placeholder timestamp reported for every model.
	const created int64 = 1700000000

	modelMap := make(map[string]OpenAIModel)
	// add registers id under owner unless an earlier source already did.
	add := func(id, owner string) {
		if _, exists := modelMap[id]; exists {
			return
		}
		modelMap[id] = OpenAIModel{ID: id, Object: "model", Created: created, OwnedBy: owner}
	}

	allowedProviders := map[string]bool{
		"openai":        true,
		"google":        true, // Models from models.dev use 'google' ID for Gemini
		"deepseek":      true,
		"moonshot":      true,
		"moonshotai":    true, // Official moonshotai ID in models.dev
		"moonshotai-cn": true, // Official moonshotai-cn ID in models.dev
		"xai":           true, // Models from models.dev use 'xai' ID for Grok
		"llmgateway":    true, // Catch-all for newer models
		"ollama":        true,
	}

	s.registryMu.RLock()
	if s.registry != nil {
		for pID, pInfo := range s.registry.Providers {
			if !allowedProviders[pID] {
				continue
			}
			for mID := range pInfo.Models {
				add(mID, pID)
			}
		}
	}
	s.registryMu.RUnlock()

	// Add configured Ollama models.
	if s.cfg.Providers.Ollama.Enabled {
		for _, mID := range s.cfg.Providers.Ollama.Models {
			add(mID, "ollama")
		}
	}

	// Add model groups so clients can discover them.
	if s.modelRouter != nil {
		for _, gid := range s.modelRouter.Groups() {
			add(gid, "gophergate")
		}
	}

	// Start from a non-nil slice: a nil slice is encoded as JSON null, and
	// clients of the list endpoint expect "data" to always be an array.
	data := make([]OpenAIModel, 0, len(modelMap))
	for _, m := range modelMap {
		data = append(data, m)
	}
	c.JSON(http.StatusOK, gin.H{
		"object": "list",
		"data":   data,
	})
}
// selectProvider maps a model identifier to one of the enabled providers using
// case-insensitive prefix/substring rules, evaluated in this order: gemini,
// deepseek, moonshot, grok, ollama. Anything that matches no rule goes to
// "openai". It returns the provider, its name, and an error when the chosen
// provider is not present in s.providers.
func (s *Server) selectProvider(modelID string) (providers.Provider, string, error) {
	lowered := strings.ToLower(modelID)
	startsWith := func(prefix string) bool { return strings.HasPrefix(lowered, prefix) }
	has := func(fragment string) bool { return strings.Contains(lowered, fragment) }
	hasAny := func(fragments ...string) bool {
		for _, fragment := range fragments {
			if has(fragment) {
				return true
			}
		}
		return false
	}

	var name string
	switch {
	case startsWith("gemini/"), has("gemini"), startsWith("google/"):
		name = "gemini"
	case startsWith("deepseek/"), has("deepseek") && !has("ollama"):
		// A deepseek model explicitly tagged for ollama is not sent here.
		name = "deepseek"
	case startsWith("moonshot/"), has("kimi"), has("moonshot"):
		name = "moonshot"
	case startsWith("grok/"), has("grok"):
		name = "grok"
	case startsWith("ollama/"),
		hasAny("glm-", "qwen", "gemma", "llama", "mistral", "phi", "yi", "codellama", "command-r"):
		name = "ollama"
	default:
		name = "openai"
	}

	if chosen, enabled := s.providers[name]; enabled {
		return chosen, name, nil
	}
	return nil, "", fmt.Errorf("Provider %s not enabled or supported", name)
}
// handleChatCompletions serves POST /v1/chat/completions.
//
// Flow: parse the request, strip a provider prefix from the model name,
// resolve model groups to a concrete model, choose the provider for that
// concrete model, convert the request into a UnifiedRequest (injecting the
// registry's output limit when the client sent no max_tokens), and forward it
// either as a server-sent-event stream or as a single JSON response. Every
// outcome is recorded through logRequest.
//
// Responses: 400 on an unparsable body, 500 when routing fails, the provider
// is not enabled, or the upstream call fails; 200 otherwise.
func (s *Server) handleChatCompletions(c *gin.Context) {
	startTime := time.Now()
	var req models.ChatCompletionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// stripPrefix removes one leading provider prefix. Provider selection is
	// case-insensitive, so the strip is as well.
	stripPrefix := func(model string) string {
		for _, p := range []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/"} {
			if len(model) >= len(p) && strings.EqualFold(model[:len(p)], p) {
				return model[len(p):]
			}
		}
		return model
	}

	// routedName is the name the provider is chosen from. It keeps any
	// provider prefix, because selectProvider uses the prefix as a hint.
	routedName := req.Model
	modelID := stripPrefix(req.Model)

	// Resolve model groups to concrete models (hierarchical — groups can target groups).
	modelGroup := ""
	if s.modelRouter != nil {
		userMessage := extractUserMessage(req.Messages)
		decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, userMessage)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
			return
		}
		if decision.SelectedModel != modelID {
			modelGroup = modelID
			routedName = decision.SelectedModel
			modelID = stripPrefix(decision.SelectedModel)
		}
		log.Printf("[ROUTER] %s (%s: %s)", modelID, decision.Strategy, decision.Reason)
	}

	// Choose the provider only after group resolution. Selecting it from the
	// requested name would send a group that resolves to, say, a deepseek
	// model to whatever provider the group's own name happens to match.
	provider, providerName, err := s.selectProvider(routedName)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Convert ChatCompletionRequest to UnifiedRequest.
	unifiedReq := &models.UnifiedRequest{
		Model:            modelID,
		Messages:         []models.UnifiedMessage{},
		Temperature:      req.Temperature,
		TopP:             req.TopP,
		TopK:             req.TopK,
		N:                req.N,
		MaxTokens:        req.MaxTokens,
		PresencePenalty:  req.PresencePenalty,
		FrequencyPenalty: req.FrequencyPenalty,
		Stream:           req.Stream != nil && *req.Stream,
		Tools:            req.Tools,
		ToolChoice:       req.ToolChoice,
	}

	// Inject max_tokens from the model registry when the client doesn't
	// specify one, so providers don't apply a low default output cap.
	// DEBUG: trace max_tokens through the proxy.
	clientMaxTokens := "nil"
	if unifiedReq.MaxTokens != nil {
		clientMaxTokens = fmt.Sprintf("%d", *unifiedReq.MaxTokens)
	}
	log.Printf("[DEBUG] %s: client max_tokens=%s", modelID, clientMaxTokens)
	if unifiedReq.MaxTokens == nil {
		// Copy the limit out while holding the lock: handing the request a
		// pointer into the shared registry would let a provider that
		// adjusts MaxTokens corrupt the registry for later requests.
		var limit uint32
		s.registryMu.RLock()
		if s.registry != nil {
			if meta := s.registry.FindModel(modelID); meta != nil && meta.Limit != nil {
				limit = meta.Limit.Output
			}
		}
		s.registryMu.RUnlock()
		if limit > 0 {
			unifiedReq.MaxTokens = &limit
			log.Printf("[DEBUG] %s: injected registry max_tokens=%d", modelID, limit)
		} else {
			log.Printf("[DEBUG] %s: no registry limit found, leaving max_tokens nil (provider default)", modelID)
		}
	} else {
		log.Printf("[DEBUG] %s: using client's max_tokens=%d", modelID, *unifiedReq.MaxTokens)
	}

	// Stop may be a JSON array of strings or a single string; any other
	// shape is ignored.
	if req.Stop != nil {
		var stop []string
		if err := json.Unmarshal(req.Stop, &stop); err == nil {
			unifiedReq.Stop = stop
		} else {
			var singleStop string
			if err := json.Unmarshal(req.Stop, &singleStop); err == nil {
				unifiedReq.Stop = []string{singleStop}
			}
		}
	}

	// Convert messages.
	for _, msg := range req.Messages {
		unifiedMsg, hasImages := chatMessageToUnified(msg)
		if hasImages {
			unifiedReq.HasImages = true
		}
		unifiedReq.Messages = append(unifiedReq.Messages, unifiedMsg)
	}

	clientID := "default"
	if auth, ok := c.Get("auth"); ok {
		if authInfo, ok := auth.(models.AuthInfo); ok {
			clientID = authInfo.ClientID
		}
	}
	// Always set the request's client ID, including when the "auth" value
	// exists but has an unexpected type (previously left empty).
	unifiedReq.ClientID = clientID

	if unifiedReq.Stream {
		ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq)
		if err != nil {
			s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, unifiedReq.HasImages)
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}

		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")

		var lastUsage *models.Usage
		disconnected := c.Stream(func(w io.Writer) bool {
			chunk, ok := <-ch
			if !ok {
				fmt.Fprintf(w, "data: [DONE]\n\n")
				s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage, nil, unifiedReq.HasImages)
				return false
			}
			if chunk.Usage != nil {
				lastUsage = chunk.Usage
			}
			data, err := json.Marshal(chunk)
			if err != nil {
				// Record the aborted stream rather than dropping it.
				s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage, err, unifiedReq.HasImages)
				return false
			}
			fmt.Fprintf(w, "data: %s\n\n", data)
			return true
		})
		// Stream reports true when the client went away mid-stream; the
		// callback never reached its logging branch in that case.
		if disconnected {
			s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage,
				fmt.Errorf("client disconnected before the stream completed"), unifiedReq.HasImages)
		}
		return
	}

	resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq)
	if err != nil {
		s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, unifiedReq.HasImages)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	s.logRequest(startTime, clientID, providerName, modelID, modelGroup, resp.Usage, nil, unifiedReq.HasImages)
	c.JSON(http.StatusOK, resp)
}

// chatMessageToUnified converts one chat message into a UnifiedMessage. String
// content becomes a single text part; array content is walked part by part,
// keeping "text" parts and turning "image_url" parts into image parts (data
// URLs are decoded into base64 + MIME type, other URLs are passed through).
// The second result reports whether at least one image part was produced.
func chatMessageToUnified(msg models.ChatMessage) (models.UnifiedMessage, bool) {
	unified := models.UnifiedMessage{
		Role:             msg.Role,
		Content:          []models.UnifiedContentPart{},
		ReasoningContent: msg.ReasoningContent,
		ToolCalls:        msg.ToolCalls,
		Name:             msg.Name,
		ToolCallID:       msg.ToolCallID,
	}
	hasImages := false

	switch content := msg.Content.(type) {
	case string:
		unified.Content = append(unified.Content, models.UnifiedContentPart{
			Type: "text",
			Text: content,
		})
	case []interface{}:
		for _, part := range content {
			partMap, ok := part.(map[string]interface{})
			if !ok {
				continue
			}
			switch partType, _ := partMap["type"].(string); partType {
			case "text":
				text, _ := partMap["text"].(string)
				unified.Content = append(unified.Content, models.UnifiedContentPart{
					Type: "text",
					Text: text,
				})
			case "image_url":
				imgURLMap, ok := partMap["image_url"].(map[string]interface{})
				if !ok {
					continue
				}
				url, _ := imgURLMap["url"].(string)
				imageInput := &models.ImageInput{}
				if strings.HasPrefix(url, "data:") {
					// An unparsable data URL still yields an (empty) image
					// part, as before.
					if mime, data, err := utils.ParseDataURL(url); err == nil {
						imageInput.Base64 = data
						imageInput.MimeType = mime
					}
				} else {
					imageInput.URL = url
				}
				unified.Content = append(unified.Content, models.UnifiedContentPart{
					Type:  "image",
					Image: imageInput,
				})
				hasImages = true
			}
		}
	}
	return unified, hasImages
}
// extractUserMessage returns the text of the most recent message whose role
// is "user". String content is returned as is. For multimodal content (an
// array of parts) the non-empty "text" parts are joined with newlines, so a
// message that also carries images still gives the router something to
// classify. It returns "" when there is no user message or when the latest
// one carries no text.
func extractUserMessage(messages []models.ChatMessage) string {
	for i := len(messages) - 1; i >= 0; i-- {
		if messages[i].Role != "user" {
			continue
		}
		switch content := messages[i].Content.(type) {
		case string:
			return content
		case []interface{}:
			var texts []string
			for _, part := range content {
				partMap, ok := part.(map[string]interface{})
				if !ok {
					continue
				}
				if partType, _ := partMap["type"].(string); partType != "text" {
					continue
				}
				if text, _ := partMap["text"].(string); text != "" {
					texts = append(texts, text)
				}
			}
			return strings.Join(texts, "\n")
		default:
			return ""
		}
	}
	return ""
}
// handleImageGenerations serves POST /v1/images/generations. Models whose
// name contains "imagen" or "gemini", or that carry a "google/" prefix, go to
// the gemini provider; everything else goes to openai, and an empty model
// defaults to "dall-e-3". The request is logged with an estimated prompt
// token count, one completion token per returned image, and a per-image cost
// when the model has known pricing.
//
// Responses: 400 on an unparsable body, 500 when the provider is not enabled
// or the upstream call fails, 200 otherwise.
func (s *Server) handleImageGenerations(c *gin.Context) {
	startTime := time.Now()
	var req models.ImageGenerationRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Determine provider based on model name.
	providerName := "openai"
	modelLower := strings.ToLower(req.Model)
	if strings.Contains(modelLower, "imagen") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
		providerName = "gemini"
	}

	// An empty model never matches the gemini rules above, so the only
	// reachable default is the OpenAI one.
	if req.Model == "" {
		req.Model = "dall-e-3"
	}

	// Strip common prefixes, case-insensitively like the matching above.
	for _, p := range []string{"openai/", "gemini/", "google/"} {
		if len(req.Model) >= len(p) && strings.EqualFold(req.Model[:len(p)], p) {
			req.Model = req.Model[len(p):]
			break
		}
	}

	provider, ok := s.providers[providerName]
	if !ok {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
		return
	}

	clientID := "default"
	if auth, ok := c.Get("auth"); ok {
		if authInfo, ok := auth.(models.AuthInfo); ok {
			clientID = authInfo.ClientID
		}
	}

	resp, err := provider.ImageGeneration(c.Request.Context(), &req)
	if err != nil {
		s.logRequest(startTime, clientID, providerName, req.Model, "", nil, err, false)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Estimate tokens from prompt text (~4 chars per token).
	promptTokens := uint32(len(req.Prompt) / 4)
	if promptTokens < 1 {
		promptTokens = 1
	}
	images := uint32(len(resp.Data))

	// Build the log entry here so the per-image cost travels with it.
	// Patching the cost afterwards with
	// "UPDATE ... WHERE id = (SELECT MAX(id) FROM llm_requests)" can hit
	// another request's row when requests run concurrently or when the
	// logger has not written this entry yet.
	entry := RequestLog{
		Timestamp:        startTime,
		ClientID:         clientID,
		Provider:         providerName,
		Model:            req.Model,
		Status:           "success",
		DurationMS:       time.Since(startTime).Milliseconds(),
		PromptTokens:     promptTokens,
		CompletionTokens: images,
		TotalTokens:      promptTokens + images,
	}
	if cost := imageGenCost(providerName, req.Model, req.Size, images); cost > 0 {
		// Image generation is priced per image, not per token.
		entry.Cost = cost
	} else {
		// Unknown image model: fall back to the registry's token pricing,
		// exactly as logRequest does.
		s.registryMu.RLock()
		entry.Cost = utils.CalculateCost(s.registry, req.Model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
		s.registryMu.RUnlock()
	}
	s.logger.LogRequest(entry)

	c.JSON(http.StatusOK, resp)
}
// imageGenCost returns the total price of n generated images for the known
// image models, matched by case-insensitive substring of model:
//   - "dall-e-3": 0.040 per image, or 0.080 for the 1024x1792 / 1792x1024 sizes
//   - "dall-e-2": 0.020 per image
//   - "imagen":   0.040 per image (approximate)
//
// It returns 0 when n is 0 or the model is not recognized. The provider
// argument is accepted but not consulted.
func imageGenCost(provider, model string, size *string, n uint32) float64 {
	if n == 0 {
		return 0
	}

	// Checked in order; the first marker contained in the model name wins.
	rates := []struct {
		marker string
		base   float64
	}{
		{"dall-e-3", 0.040},
		{"dall-e-2", 0.020},
		{"imagen", 0.040},
	}

	name := strings.ToLower(model)
	for _, rate := range rates {
		if !strings.Contains(name, rate.marker) {
			continue
		}
		price := rate.base
		if rate.marker == "dall-e-3" && size != nil {
			// The non-square DALL-E 3 sizes cost double.
			switch *size {
			case "1024x1792", "1792x1024":
				price = 0.080
			}
		}
		return price * float64(n)
	}
	return 0
}
// logRequest assembles a RequestLog entry for one proxied request and hands it
// to the request logger. The status is "error" (with the error text) when err
// is non-nil and "success" otherwise. When usage is provided, the token
// counts are copied over and the cost is computed from the model registry.
func (s *Server) logRequest(start time.Time, clientID, provider, model, modelGroup string, usage *models.Usage, err error, hasImages bool) {
	record := RequestLog{
		Timestamp:  start,
		ClientID:   clientID,
		Provider:   provider,
		Model:      model,
		ModelGroup: modelGroup,
		Status:     "success",
		DurationMS: time.Since(start).Milliseconds(),
		HasImages:  hasImages,
	}

	if failed := err != nil; failed {
		record.Status = "error"
		record.ErrorMessage = err.Error()
	}

	if haveUsage := usage != nil; haveUsage {
		record.PromptTokens = usage.PromptTokens
		record.CompletionTokens = usage.CompletionTokens
		record.TotalTokens = usage.TotalTokens

		// The optional counters are pointers; missing ones stay at zero.
		if reasoning := usage.ReasoningTokens; reasoning != nil {
			record.ReasoningTokens = *reasoning
		}
		if cacheRead := usage.CacheReadTokens; cacheRead != nil {
			record.CacheReadTokens = *cacheRead
		}
		if cacheWrite := usage.CacheWriteTokens; cacheWrite != nil {
			record.CacheWriteTokens = *cacheWrite
		}

		// Price the request against the registry, read under its lock.
		s.registryMu.RLock()
		record.Cost = utils.CalculateCost(
			s.registry,
			model,
			record.PromptTokens,
			record.CompletionTokens,
			record.ReasoningTokens,
			record.CacheReadTokens,
			record.CacheWriteTokens,
		)
		s.registryMu.RUnlock()
	}

	s.logger.LogRequest(record)
}
// Run starts the websocket hub, the request logger and a background job that
// refreshes the model registry every 24 hours, then serves HTTP on the
// configured host and port. It blocks until the HTTP server stops and returns
// that error.
func (s *Server) Run() error {
	go s.hub.Run()
	s.logger.Start()

	// Registry refresher. A failed fetch keeps the previous registry in
	// place and is reported instead of being silently dropped.
	go func() {
		ticker := time.NewTicker(24 * time.Hour)
		defer ticker.Stop()
		for range ticker.C {
			newRegistry, err := utils.FetchRegistry()
			if err != nil {
				fmt.Printf("Warning: Failed to refresh model registry: %v\n", err)
				continue
			}
			s.registryMu.Lock()
			s.registry = newRegistry
			s.registryMu.Unlock()
		}
	}()

	addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
	return s.router.Run(addr)
}
// uint32Ptr returns a pointer to a fresh uint32 holding v, for populating
// optional numeric request fields.
func uint32Ptr(v uint32) *uint32 {
	p := new(uint32)
	*p = v
	return p
}