e5ef39f327
Add full Responses API endpoint alongside existing Chat Completions, with identical logging/tracking/cost pipeline. New: - internal/models/responses.go — request/response/stream types + ToUsage() bridge - internal/providers/openai_responses.go — OpenAI Responses/ResponsesStream Modified: - provider.go — Responses()+ResponsesStream() added to Provider interface - helpers.go — BuildOpenAIResponsesBody, parsers, SSE stream reader - circuit_breaker.go — CB wraps Responses, passthrough for stream - server.go — POST /v1/responses route + handleResponses handler - all non-OpenAI providers — stub methods with clear error messages Logging: ResponsesUsage.ToUsage() bridges to models.Usage, feeding same logRequest() -> DB insert -> dashboard WS -> client stats -> cost calc pipeline. No schema or logger changes needed.
802 lines
23 KiB
Go
802 lines
23 KiB
Go
package server
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"gophergate/internal/config"
|
|
"gophergate/internal/db"
|
|
"gophergate/internal/middleware"
|
|
"gophergate/internal/models"
|
|
"gophergate/internal/providers"
|
|
"gophergate/internal/utils"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
type Server struct {
|
|
router *gin.Engine
|
|
cfg *config.Config
|
|
database *db.DB
|
|
providers map[string]providers.Provider
|
|
sessions *SessionManager
|
|
hub *Hub
|
|
logger *RequestLogger
|
|
registry *models.ModelRegistry
|
|
registryMu sync.RWMutex
|
|
}
|
|
|
|
func NewServer(cfg *config.Config, database *db.DB) *Server {
|
|
router := gin.Default()
|
|
hub := NewHub()
|
|
|
|
s := &Server{
|
|
router: router,
|
|
cfg: cfg,
|
|
database: database,
|
|
providers: make(map[string]providers.Provider),
|
|
sessions: NewSessionManager(cfg.KeyBytes, 24*time.Hour),
|
|
hub: hub,
|
|
logger: NewRequestLogger(database, hub),
|
|
registry: &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)},
|
|
}
|
|
|
|
s.sessions.StartCleanup()
|
|
// Fetch registry in background
|
|
go func() {
|
|
registry, err := utils.FetchRegistry()
|
|
if err != nil {
|
|
fmt.Printf("Warning: Failed to fetch initial model registry: %v\n", err)
|
|
} else {
|
|
s.registry = registry
|
|
}
|
|
}()
|
|
|
|
// Initialize providers from DB and Config
|
|
if err := s.RefreshProviders(); err != nil {
|
|
fmt.Printf("Warning: Failed to initial refresh providers: %v\n", err)
|
|
}
|
|
|
|
s.setupRoutes()
|
|
return s
|
|
}
|
|
|
|
func (s *Server) RefreshProviders() error {
|
|
var dbConfigs []db.ProviderConfig
|
|
err := s.database.Select(&dbConfigs, "SELECT * FROM provider_configs")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to fetch provider configs from db: %w", err)
|
|
}
|
|
|
|
dbMap := make(map[string]db.ProviderConfig)
|
|
for _, cfg := range dbConfigs {
|
|
dbMap[cfg.ID] = cfg
|
|
}
|
|
|
|
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama"}
|
|
for _, id := range providerIDs {
|
|
// Default values from config
|
|
enabled := false
|
|
baseURL := ""
|
|
apiKey := ""
|
|
|
|
switch id {
|
|
case "openai":
|
|
enabled = s.cfg.Providers.OpenAI.Enabled
|
|
baseURL = s.cfg.Providers.OpenAI.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("openai")
|
|
case "gemini":
|
|
enabled = s.cfg.Providers.Gemini.Enabled
|
|
baseURL = s.cfg.Providers.Gemini.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("gemini")
|
|
case "deepseek":
|
|
enabled = s.cfg.Providers.DeepSeek.Enabled
|
|
baseURL = s.cfg.Providers.DeepSeek.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("deepseek")
|
|
case "moonshot":
|
|
enabled = s.cfg.Providers.Moonshot.Enabled
|
|
baseURL = s.cfg.Providers.Moonshot.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("moonshot")
|
|
case "grok":
|
|
enabled = s.cfg.Providers.Grok.Enabled
|
|
baseURL = s.cfg.Providers.Grok.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("grok")
|
|
}
|
|
|
|
// Overrides from DB
|
|
if dbCfg, ok := dbMap[id]; ok {
|
|
enabled = dbCfg.Enabled
|
|
if dbCfg.BaseURL != nil && *dbCfg.BaseURL != "" {
|
|
baseURL = *dbCfg.BaseURL
|
|
}
|
|
if dbCfg.APIKey != nil && *dbCfg.APIKey != "" {
|
|
key := *dbCfg.APIKey
|
|
if dbCfg.APIKeyEncrypted {
|
|
decrypted, err := utils.Decrypt(key, s.cfg.KeyBytes)
|
|
if err == nil {
|
|
key = decrypted
|
|
} else {
|
|
fmt.Printf("Warning: Failed to decrypt API key for %s: %v\n", id, err)
|
|
}
|
|
}
|
|
apiKey = key
|
|
}
|
|
}
|
|
|
|
if !enabled {
|
|
delete(s.providers, id)
|
|
continue
|
|
}
|
|
|
|
// Initialize provider
|
|
var p providers.Provider
|
|
switch id {
|
|
case "openai":
|
|
cfg := s.cfg.Providers.OpenAI
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewOpenAIProvider(cfg, apiKey)
|
|
case "gemini":
|
|
cfg := s.cfg.Providers.Gemini
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewGeminiProvider(cfg, apiKey)
|
|
case "deepseek":
|
|
cfg := s.cfg.Providers.DeepSeek
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewDeepSeekProvider(cfg, apiKey)
|
|
case "moonshot":
|
|
cfg := s.cfg.Providers.Moonshot
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewMoonshotProvider(cfg, apiKey)
|
|
case "grok":
|
|
cfg := s.cfg.Providers.Grok
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewGrokProvider(cfg, apiKey)
|
|
case "ollama":
|
|
cfg := s.cfg.Providers.Ollama
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewOllamaProvider(cfg)
|
|
}
|
|
|
|
if p != nil {
|
|
s.providers[id] = providers.NewCircuitBreakerProvider(p)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) setupRoutes() {
|
|
// Static files
|
|
s.router.StaticFile("/", "./static/index.html")
|
|
s.router.StaticFile("/favicon.ico", "./static/favicon.ico")
|
|
s.router.Static("/css", "./static/css")
|
|
s.router.Static("/js", "./static/js")
|
|
s.router.Static("/img", "./static/img")
|
|
|
|
// WebSocket
|
|
s.router.GET("/ws", s.handleWebSocket)
|
|
|
|
// API V1 (External LLM Access) - Secured with AuthMiddleware
|
|
v1 := s.router.Group("/v1")
|
|
v1.Use(middleware.AuthMiddleware(s.database, true))
|
|
{
|
|
v1.POST("/chat/completions", s.handleChatCompletions)
|
|
v1.POST("/images/generations", s.handleImageGenerations)
|
|
v1.GET("/models", s.handleListModels)
|
|
v1.POST("/responses", s.handleResponses)
|
|
}
|
|
|
|
// Dashboard API Group
|
|
api := s.router.Group("/api")
|
|
{
|
|
api.POST("/auth/login", s.handleLogin)
|
|
api.GET("/auth/status", s.handleAuthStatus)
|
|
api.POST("/auth/logout", s.handleLogout)
|
|
api.POST("/auth/change-password", s.handleChangePassword)
|
|
|
|
// Protected dashboard routes (need admin session)
|
|
admin := api.Group("/")
|
|
admin.Use(s.adminAuthMiddleware())
|
|
{
|
|
admin.GET("/usage/summary", s.handleUsageSummary)
|
|
admin.GET("/usage/time-series", s.handleTimeSeries)
|
|
admin.GET("/usage/providers", s.handleProvidersUsage)
|
|
admin.GET("/usage/clients", s.handleClientsUsage)
|
|
admin.GET("/usage/detailed", s.handleDetailedUsage)
|
|
admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)
|
|
|
|
admin.GET("/clients", s.handleGetClients)
|
|
admin.POST("/clients", s.handleCreateClient)
|
|
admin.GET("/clients/:id", s.handleGetClient)
|
|
admin.PUT("/clients/:id", s.handleUpdateClient)
|
|
admin.DELETE("/clients/:id", s.handleDeleteClient)
|
|
|
|
admin.GET("/clients/:id/tokens", s.handleGetClientTokens)
|
|
admin.POST("/clients/:id/tokens", s.handleCreateClientToken)
|
|
admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)
|
|
|
|
admin.GET("/providers", s.handleGetProviders)
|
|
admin.PUT("/providers/:name", s.handleUpdateProvider)
|
|
admin.POST("/providers/:name/test", s.handleTestProvider)
|
|
|
|
admin.GET("/models", s.handleGetModels)
|
|
admin.PUT("/models/:id", s.handleUpdateModel)
|
|
|
|
admin.GET("/users", s.handleGetUsers)
|
|
admin.POST("/users", s.handleCreateUser)
|
|
admin.PUT("/users/:id", s.handleUpdateUser)
|
|
admin.DELETE("/users/:id", s.handleDeleteUser)
|
|
|
|
admin.GET("/system/health", s.handleSystemHealth)
|
|
admin.GET("/system/metrics", s.handleSystemMetrics)
|
|
admin.GET("/system/settings", s.handleGetSettings)
|
|
admin.POST("/system/backup", s.handleCreateBackup)
|
|
admin.GET("/system/logs", s.handleGetLogs)
|
|
}
|
|
}
|
|
|
|
s.router.GET("/health", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
|
})
|
|
}
|
|
|
|
func (s *Server) handleResponses(c *gin.Context) {
|
|
startTime := time.Now()
|
|
var req models.ResponsesRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Select provider based on model name
|
|
providerName := "openai" // default for Responses API
|
|
modelLower := strings.ToLower(req.Model)
|
|
if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
|
|
providerName = "gemini"
|
|
} else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) {
|
|
providerName = "deepseek"
|
|
} else if strings.HasPrefix(modelLower, "moonshot/") || strings.Contains(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
|
|
providerName = "moonshot"
|
|
} else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") {
|
|
providerName = "grok"
|
|
} else if strings.HasPrefix(modelLower, "ollama/") ||
|
|
strings.Contains(modelLower, "glm-") ||
|
|
strings.Contains(modelLower, "qwen") ||
|
|
strings.Contains(modelLower, "gemma") ||
|
|
strings.Contains(modelLower, "llama") ||
|
|
strings.Contains(modelLower, "mistral") ||
|
|
strings.Contains(modelLower, "phi") ||
|
|
strings.Contains(modelLower, "yi") ||
|
|
strings.Contains(modelLower, "codellama") ||
|
|
strings.Contains(modelLower, "command-r") {
|
|
providerName = "ollama"
|
|
}
|
|
|
|
provider, ok := s.providers[providerName]
|
|
if !ok {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
|
|
return
|
|
}
|
|
|
|
// Strip common prefixes from model name
|
|
modelID := req.Model
|
|
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/"}
|
|
for _, p := range prefixes {
|
|
if strings.HasPrefix(modelID, p) {
|
|
modelID = strings.TrimPrefix(modelID, p)
|
|
break
|
|
}
|
|
}
|
|
|
|
// Use the stripped model name for the actual API call
|
|
req.Model = modelID
|
|
|
|
clientID := "default"
|
|
if auth, ok := c.Get("auth"); ok {
|
|
if authInfo, ok := auth.(models.AuthInfo); ok {
|
|
clientID = authInfo.ClientID
|
|
}
|
|
}
|
|
|
|
stream := req.Stream != nil && *req.Stream
|
|
|
|
if stream {
|
|
ch, err := provider.ResponsesStream(c.Request.Context(), &req)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, false)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
c.Header("Content-Type", "text/event-stream")
|
|
c.Header("Cache-Control", "no-cache")
|
|
c.Header("Connection", "keep-alive")
|
|
|
|
var lastUsage *models.ResponsesUsage
|
|
c.Stream(func(w io.Writer) bool {
|
|
chunk, ok := <-ch
|
|
if !ok {
|
|
fmt.Fprintf(w, "data: [DONE]\n\n")
|
|
if lastUsage != nil {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, lastUsage.ToUsage(), nil, false)
|
|
} else {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, nil, nil, false)
|
|
}
|
|
return false
|
|
}
|
|
// Capture usage from the response payload in streaming chunks
|
|
if chunk.Response != nil && chunk.Response.Usage != nil {
|
|
lastUsage = chunk.Response.Usage
|
|
}
|
|
data, err := json.Marshal(chunk)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
return true
|
|
})
|
|
return
|
|
}
|
|
|
|
resp, err := provider.Responses(c.Request.Context(), &req)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, false)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
if resp.Usage != nil {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage.ToUsage(), nil, false)
|
|
} else {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, nil, nil, false)
|
|
}
|
|
c.JSON(http.StatusOK, resp)
|
|
}
|
|
|
|
func (s *Server) handleListModels(c *gin.Context) {
|
|
type OpenAIModel struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
OwnedBy string `json:"owned_by"`
|
|
}
|
|
|
|
modelMap := make(map[string]OpenAIModel)
|
|
allowedProviders := map[string]bool{
|
|
"openai": true,
|
|
"google": true, // Models from models.dev use 'google' ID for Gemini
|
|
"deepseek": true,
|
|
"moonshot": true,
|
|
"moonshotai": true, // Official moonshotai ID in models.dev
|
|
"moonshotai-cn": true, // Official moonshotai-cn ID in models.dev
|
|
"xai": true, // Models from models.dev use 'xai' ID for Grok
|
|
"llmgateway": true, // Catch-all for newer models
|
|
"ollama": true,
|
|
}
|
|
|
|
s.registryMu.RLock()
|
|
if s.registry != nil {
|
|
for pID, pInfo := range s.registry.Providers {
|
|
if !allowedProviders[pID] {
|
|
continue
|
|
}
|
|
for mID := range pInfo.Models {
|
|
if _, exists := modelMap[mID]; !exists {
|
|
modelMap[mID] = OpenAIModel{
|
|
ID: mID,
|
|
Object: "model",
|
|
Created: 1700000000,
|
|
OwnedBy: pID,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
s.registryMu.RUnlock()
|
|
|
|
// Add configured Ollama models
|
|
if s.cfg.Providers.Ollama.Enabled {
|
|
for _, mID := range s.cfg.Providers.Ollama.Models {
|
|
if _, exists := modelMap[mID]; !exists {
|
|
modelMap[mID] = OpenAIModel{
|
|
ID: mID,
|
|
Object: "model",
|
|
Created: 1700000000,
|
|
OwnedBy: "ollama",
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
var data []OpenAIModel
|
|
for _, m := range modelMap {
|
|
data = append(data, m)
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"object": "list",
|
|
"data": data,
|
|
})
|
|
}
|
|
|
|
func (s *Server) handleChatCompletions(c *gin.Context) {
|
|
startTime := time.Now()
|
|
var req models.ChatCompletionRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Select provider based on model name
|
|
providerName := "openai" // default
|
|
modelLower := strings.ToLower(req.Model)
|
|
if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
|
|
providerName = "gemini"
|
|
} else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) {
|
|
// Only use deepseek provider if it's not explicitly tagged for ollama
|
|
providerName = "deepseek"
|
|
} else if strings.HasPrefix(modelLower, "moonshot/") || strings.Contains(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
|
|
providerName = "moonshot"
|
|
} else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") {
|
|
providerName = "grok"
|
|
} else if strings.HasPrefix(modelLower, "ollama/") ||
|
|
strings.Contains(modelLower, "glm-") ||
|
|
strings.Contains(modelLower, "qwen") ||
|
|
strings.Contains(modelLower, "gemma") ||
|
|
strings.Contains(modelLower, "llama") ||
|
|
strings.Contains(modelLower, "mistral") ||
|
|
strings.Contains(modelLower, "phi") ||
|
|
strings.Contains(modelLower, "yi") ||
|
|
strings.Contains(modelLower, "codellama") ||
|
|
strings.Contains(modelLower, "command-r") {
|
|
providerName = "ollama"
|
|
}
|
|
|
|
provider, ok := s.providers[providerName]
|
|
if !ok {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
|
|
return
|
|
}
|
|
|
|
// Strip common prefixes
|
|
modelID := req.Model
|
|
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/"}
|
|
for _, p := range prefixes {
|
|
if strings.HasPrefix(modelID, p) {
|
|
modelID = strings.TrimPrefix(modelID, p)
|
|
break
|
|
}
|
|
}
|
|
|
|
// Convert ChatCompletionRequest to UnifiedRequest
|
|
unifiedReq := &models.UnifiedRequest{
|
|
Model: modelID,
|
|
Messages: []models.UnifiedMessage{},
|
|
Temperature: req.Temperature,
|
|
TopP: req.TopP,
|
|
TopK: req.TopK,
|
|
N: req.N,
|
|
MaxTokens: req.MaxTokens,
|
|
PresencePenalty: req.PresencePenalty,
|
|
FrequencyPenalty: req.FrequencyPenalty,
|
|
Stream: req.Stream != nil && *req.Stream,
|
|
Tools: req.Tools,
|
|
ToolChoice: req.ToolChoice,
|
|
}
|
|
|
|
// Inject max_tokens from model registry when client doesn't specify one.
|
|
// Prevents providers from applying a low default output cap.
|
|
// DEBUG: Trace max_tokens through the proxy
|
|
clientMaxTokens := "nil"
|
|
if unifiedReq.MaxTokens != nil {
|
|
clientMaxTokens = fmt.Sprintf("%d", *unifiedReq.MaxTokens)
|
|
}
|
|
log.Printf("[DEBUG] %s: client max_tokens=%s", modelID, clientMaxTokens)
|
|
if unifiedReq.MaxTokens == nil {
|
|
s.registryMu.RLock()
|
|
meta := s.registry.FindModel(modelID)
|
|
s.registryMu.RUnlock()
|
|
if meta != nil && meta.Limit != nil && meta.Limit.Output > 0 {
|
|
unifiedReq.MaxTokens = &meta.Limit.Output
|
|
log.Printf("[DEBUG] %s: injected registry max_tokens=%d", modelID, meta.Limit.Output)
|
|
} else {
|
|
log.Printf("[DEBUG] %s: no registry limit found, leaving max_tokens nil (provider default)", modelID)
|
|
}
|
|
} else {
|
|
log.Printf("[DEBUG] %s: using client's max_tokens=%d", modelID, *unifiedReq.MaxTokens)
|
|
}
|
|
|
|
// Handle Stop sequences
|
|
if req.Stop != nil {
|
|
var stop []string
|
|
if err := json.Unmarshal(req.Stop, &stop); err == nil {
|
|
unifiedReq.Stop = stop
|
|
} else {
|
|
var singleStop string
|
|
if err := json.Unmarshal(req.Stop, &singleStop); err == nil {
|
|
unifiedReq.Stop = []string{singleStop}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Convert messages
|
|
for _, msg := range req.Messages {
|
|
unifiedMsg := models.UnifiedMessage{
|
|
Role: msg.Role,
|
|
Content: []models.UnifiedContentPart{},
|
|
ReasoningContent: msg.ReasoningContent,
|
|
ToolCalls: msg.ToolCalls,
|
|
Name: msg.Name,
|
|
ToolCallID: msg.ToolCallID,
|
|
}
|
|
|
|
// Handle multimodal content
|
|
if strContent, ok := msg.Content.(string); ok {
|
|
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
|
Type: "text",
|
|
Text: strContent,
|
|
})
|
|
} else if parts, ok := msg.Content.([]interface{}); ok {
|
|
for _, part := range parts {
|
|
if partMap, ok := part.(map[string]interface{}); ok {
|
|
partType, _ := partMap["type"].(string)
|
|
if partType == "text" {
|
|
text, _ := partMap["text"].(string)
|
|
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
|
Type: "text",
|
|
Text: text,
|
|
})
|
|
} else if partType == "image_url" {
|
|
if imgURLMap, ok := partMap["image_url"].(map[string]interface{}); ok {
|
|
url, _ := imgURLMap["url"].(string)
|
|
imageInput := &models.ImageInput{}
|
|
if strings.HasPrefix(url, "data:") {
|
|
mime, data, err := utils.ParseDataURL(url)
|
|
if err == nil {
|
|
imageInput.Base64 = data
|
|
imageInput.MimeType = mime
|
|
}
|
|
} else {
|
|
imageInput.URL = url
|
|
}
|
|
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
|
Type: "image",
|
|
Image: imageInput,
|
|
})
|
|
unifiedReq.HasImages = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
unifiedReq.Messages = append(unifiedReq.Messages, unifiedMsg)
|
|
}
|
|
|
|
clientID := "default"
|
|
if auth, ok := c.Get("auth"); ok {
|
|
if authInfo, ok := auth.(models.AuthInfo); ok {
|
|
unifiedReq.ClientID = authInfo.ClientID
|
|
clientID = authInfo.ClientID
|
|
}
|
|
} else {
|
|
unifiedReq.ClientID = clientID
|
|
}
|
|
|
|
if unifiedReq.Stream {
|
|
ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
c.Header("Content-Type", "text/event-stream")
|
|
c.Header("Cache-Control", "no-cache")
|
|
c.Header("Connection", "keep-alive")
|
|
|
|
var lastUsage *models.Usage
|
|
c.Stream(func(w io.Writer) bool {
|
|
chunk, ok := <-ch
|
|
if !ok {
|
|
fmt.Fprintf(w, "data: [DONE]\n\n")
|
|
s.logRequest(startTime, clientID, providerName, req.Model, lastUsage, nil, unifiedReq.HasImages)
|
|
return false
|
|
}
|
|
if chunk.Usage != nil {
|
|
lastUsage = chunk.Usage
|
|
}
|
|
data, err := json.Marshal(chunk)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
return true
|
|
})
|
|
return
|
|
}
|
|
|
|
resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage, nil, unifiedReq.HasImages)
|
|
c.JSON(http.StatusOK, resp)
|
|
}
|
|
|
|
func (s *Server) handleImageGenerations(c *gin.Context) {
|
|
startTime := time.Now()
|
|
var req models.ImageGenerationRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Determine provider based on model name
|
|
providerName := "openai"
|
|
modelLower := strings.ToLower(req.Model)
|
|
switch {
|
|
case strings.Contains(modelLower, "imagen"), strings.Contains(modelLower, "gemini"):
|
|
providerName = "gemini"
|
|
case strings.Contains(modelLower, "dall"), strings.HasPrefix(modelLower, "openai/"):
|
|
providerName = "openai"
|
|
}
|
|
|
|
// Default model for each provider if not specified
|
|
if req.Model == "" {
|
|
if providerName == "openai" {
|
|
req.Model = "dall-e-3"
|
|
} else {
|
|
req.Model = "imagen-3.0-generate-001"
|
|
}
|
|
}
|
|
|
|
// Strip common prefixes
|
|
prefixes := []string{"openai/", "gemini/", "google/"}
|
|
for _, p := range prefixes {
|
|
if strings.HasPrefix(req.Model, p) {
|
|
req.Model = strings.TrimPrefix(req.Model, p)
|
|
break
|
|
}
|
|
}
|
|
|
|
provider, ok := s.providers[providerName]
|
|
if !ok {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
|
|
return
|
|
}
|
|
|
|
clientID := "default"
|
|
if auth, ok := c.Get("auth"); ok {
|
|
if authInfo, ok := auth.(models.AuthInfo); ok {
|
|
clientID = authInfo.ClientID
|
|
}
|
|
}
|
|
|
|
resp, err := provider.ImageGeneration(c.Request.Context(), &req)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, false)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Estimate tokens from prompt text (~4 chars per token)
|
|
promptTokens := uint32(len(req.Prompt) / 4)
|
|
if promptTokens < 1 {
|
|
promptTokens = 1
|
|
}
|
|
|
|
// Calculate per-image cost (not per-token like chat)
|
|
cost := imageGenCost(providerName, req.Model, req.Size, uint32(len(resp.Data)))
|
|
|
|
s.logRequest(startTime, clientID, providerName, req.Model, &models.Usage{
|
|
PromptTokens: promptTokens,
|
|
CompletionTokens: uint32(len(resp.Data)),
|
|
TotalTokens: promptTokens + uint32(len(resp.Data)),
|
|
}, nil, false)
|
|
|
|
// Update cost in DB — image gen is per-image, not per-token
|
|
if cost > 0 {
|
|
s.database.Exec("UPDATE llm_requests SET cost = ? WHERE id = (SELECT MAX(id) FROM llm_requests)", cost)
|
|
}
|
|
|
|
c.JSON(http.StatusOK, resp)
|
|
}
|
|
|
|
// imageGenCost returns the total price in USD for n generated images, using
// per-image list prices for the known image models.
//
// The model is matched case-insensitively by substring:
//   - "dall-e-3": 0.040 per image, or 0.080 for 1024x1792 / 1792x1024;
//   - "dall-e-2": 0.016 for 256x256, 0.018 for 512x512, otherwise 0.020;
//   - "imagen":   0.040 per image (approximate).
//
// size may be nil, in which case the model's standard price applies. The
// result is 0 when n is 0 or the model is not recognised; callers treat 0 as
// "no per-image price available". provider is accepted for call-site symmetry
// and is currently not consulted.
func imageGenCost(provider, model string, size *string, n uint32) float64 {
	if n == 0 {
		return 0
	}

	requested := ""
	if size != nil {
		requested = *size
	}

	var perImage float64
	modelLower := strings.ToLower(model)
	switch {
	case strings.Contains(modelLower, "dall-e-3"):
		perImage = 0.040 // standard 1024x1024
		if requested == "1024x1792" || requested == "1792x1024" {
			perImage = 0.080
		}
	case strings.Contains(modelLower, "dall-e-2"):
		// DALL-E 2 is priced by resolution; the smaller sizes are cheaper
		// than the 1024x1024 default.
		switch requested {
		case "256x256":
			perImage = 0.016
		case "512x512":
			perImage = 0.018
		default:
			perImage = 0.020
		}
	case strings.Contains(modelLower, "imagen"):
		perImage = 0.040 // approximate
	default:
		return 0
	}

	return perImage * float64(n)
}
|
|
|
|
func (s *Server) logRequest(start time.Time, clientID, provider, model string, usage *models.Usage, err error, hasImages bool) {
|
|
entry := RequestLog{
|
|
Timestamp: start,
|
|
ClientID: clientID,
|
|
Provider: provider,
|
|
Model: model,
|
|
Status: "success",
|
|
DurationMS: time.Since(start).Milliseconds(),
|
|
HasImages: hasImages,
|
|
}
|
|
|
|
if err != nil {
|
|
entry.Status = "error"
|
|
entry.ErrorMessage = err.Error()
|
|
}
|
|
|
|
if usage != nil {
|
|
entry.PromptTokens = usage.PromptTokens
|
|
entry.CompletionTokens = usage.CompletionTokens
|
|
entry.TotalTokens = usage.TotalTokens
|
|
if usage.ReasoningTokens != nil {
|
|
entry.ReasoningTokens = *usage.ReasoningTokens
|
|
}
|
|
if usage.CacheReadTokens != nil {
|
|
entry.CacheReadTokens = *usage.CacheReadTokens
|
|
}
|
|
if usage.CacheWriteTokens != nil {
|
|
entry.CacheWriteTokens = *usage.CacheWriteTokens
|
|
}
|
|
|
|
// Calculate cost using registry
|
|
s.registryMu.RLock()
|
|
entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
|
|
s.registryMu.RUnlock()
|
|
}
|
|
|
|
s.logger.LogRequest(entry)
|
|
}
|
|
|
|
func (s *Server) Run() error {
|
|
go s.hub.Run()
|
|
s.logger.Start()
|
|
|
|
// Start registry refresher
|
|
go func() {
|
|
ticker := time.NewTicker(24 * time.Hour)
|
|
for range ticker.C {
|
|
newRegistry, err := utils.FetchRegistry()
|
|
if err == nil {
|
|
s.registryMu.Lock()
|
|
s.registry = newRegistry
|
|
s.registryMu.Unlock()
|
|
}
|
|
}
|
|
}()
|
|
|
|
addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
|
|
return s.router.Run(addr)
|
|
}
|