aeffeb8c03
Truncating tool call IDs to 40 characters in helper.go caused collisions when models (such as deepseek-v4-flash) generated longer IDs, producing "Duplicate value for 'tool_call_id'" errors; the limit has been removed so the full, unique IDs are preserved. DeepSeek: reasoning_content is now injected as an empty string instead of a single space, which better matches the provider's expectations for conversation history. API error reporting across all providers now includes the raw response body when parsing fails or yields an empty result.
915 lines
27 KiB
Go
915 lines
27 KiB
Go
package server
|
|
|
|
import (
|
|
"encoding/json"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"gophergate/internal/config"
|
|
"gophergate/internal/db"
|
|
"gophergate/internal/middleware"
|
|
"gophergate/internal/models"
|
|
"gophergate/internal/providers"
|
|
"gophergate/internal/router"
|
|
"gophergate/internal/utils"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
type Server struct {
|
|
router *gin.Engine
|
|
cfg *config.Config
|
|
database *db.DB
|
|
providers map[string]providers.Provider
|
|
sessions *SessionManager
|
|
hub *Hub
|
|
logger *RequestLogger
|
|
registry *models.ModelRegistry
|
|
registryMu sync.RWMutex
|
|
modelRouter *router.Router
|
|
}
|
|
|
|
func NewServer(cfg *config.Config, database *db.DB) *Server {
|
|
router := gin.Default()
|
|
hub := NewHub()
|
|
|
|
s := &Server{
|
|
router: router,
|
|
cfg: cfg,
|
|
database: database,
|
|
providers: make(map[string]providers.Provider),
|
|
sessions: NewSessionManager(cfg.KeyBytes, 24*time.Hour),
|
|
hub: hub,
|
|
logger: NewRequestLogger(database, hub),
|
|
registry: &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)},
|
|
}
|
|
|
|
s.sessions.StartCleanup()
|
|
// Fetch registry in background
|
|
go func() {
|
|
registry, err := utils.FetchRegistry()
|
|
if err != nil {
|
|
fmt.Printf("Warning: Failed to fetch initial model registry: %v\n", err)
|
|
} else {
|
|
s.registry = registry
|
|
}
|
|
}()
|
|
|
|
// Initialize providers from DB and Config
|
|
if err := s.RefreshProviders(); err != nil {
|
|
fmt.Printf("Warning: Failed to initial refresh providers: %v\n", err)
|
|
}
|
|
|
|
s.setupRoutes()
|
|
|
|
// Initialize model group router
|
|
s.refreshRouter()
|
|
return s
|
|
}
|
|
|
|
func (s *Server) RefreshProviders() error {
|
|
var dbConfigs []db.ProviderConfig
|
|
err := s.database.Select(&dbConfigs, "SELECT * FROM provider_configs")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to fetch provider configs from db: %w", err)
|
|
}
|
|
|
|
dbMap := make(map[string]db.ProviderConfig)
|
|
for _, cfg := range dbConfigs {
|
|
dbMap[cfg.ID] = cfg
|
|
}
|
|
|
|
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama"}
|
|
for _, id := range providerIDs {
|
|
// Default values from config
|
|
enabled := false
|
|
baseURL := ""
|
|
apiKey := ""
|
|
|
|
switch id {
|
|
case "openai":
|
|
enabled = s.cfg.Providers.OpenAI.Enabled
|
|
baseURL = s.cfg.Providers.OpenAI.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("openai")
|
|
case "gemini":
|
|
enabled = s.cfg.Providers.Gemini.Enabled
|
|
baseURL = s.cfg.Providers.Gemini.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("gemini")
|
|
case "deepseek":
|
|
enabled = s.cfg.Providers.DeepSeek.Enabled
|
|
baseURL = s.cfg.Providers.DeepSeek.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("deepseek")
|
|
case "moonshot":
|
|
enabled = s.cfg.Providers.Moonshot.Enabled
|
|
baseURL = s.cfg.Providers.Moonshot.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("moonshot")
|
|
case "grok":
|
|
enabled = s.cfg.Providers.Grok.Enabled
|
|
baseURL = s.cfg.Providers.Grok.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("grok")
|
|
}
|
|
|
|
// Overrides from DB
|
|
if dbCfg, ok := dbMap[id]; ok {
|
|
enabled = dbCfg.Enabled
|
|
if dbCfg.BaseURL != nil && *dbCfg.BaseURL != "" {
|
|
baseURL = *dbCfg.BaseURL
|
|
}
|
|
if dbCfg.APIKey != nil && *dbCfg.APIKey != "" {
|
|
key := *dbCfg.APIKey
|
|
if dbCfg.APIKeyEncrypted {
|
|
decrypted, err := utils.Decrypt(key, s.cfg.KeyBytes)
|
|
if err == nil {
|
|
key = decrypted
|
|
} else {
|
|
fmt.Printf("Warning: Failed to decrypt API key for %s: %v\n", id, err)
|
|
}
|
|
}
|
|
apiKey = key
|
|
}
|
|
}
|
|
|
|
if !enabled {
|
|
delete(s.providers, id)
|
|
continue
|
|
}
|
|
|
|
// Initialize provider
|
|
var p providers.Provider
|
|
switch id {
|
|
case "openai":
|
|
cfg := s.cfg.Providers.OpenAI
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewOpenAIProvider(cfg, apiKey)
|
|
case "gemini":
|
|
cfg := s.cfg.Providers.Gemini
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewGeminiProvider(cfg, apiKey)
|
|
case "deepseek":
|
|
cfg := s.cfg.Providers.DeepSeek
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewDeepSeekProvider(cfg, apiKey)
|
|
case "moonshot":
|
|
cfg := s.cfg.Providers.Moonshot
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewMoonshotProvider(cfg, apiKey)
|
|
case "grok":
|
|
cfg := s.cfg.Providers.Grok
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewGrokProvider(cfg, apiKey)
|
|
case "ollama":
|
|
cfg := s.cfg.Providers.Ollama
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewOllamaProvider(cfg)
|
|
}
|
|
|
|
if p != nil {
|
|
s.providers[id] = providers.NewCircuitBreakerProvider(p)
|
|
}
|
|
}
|
|
|
|
s.refreshRouter()
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) refreshRouter() {
|
|
var groups []db.ModelGroup
|
|
if err := s.database.Select(&groups, "SELECT * FROM model_groups"); err != nil {
|
|
fmt.Printf("Warning: Failed to load model groups: %v\n", err)
|
|
groups = nil
|
|
}
|
|
|
|
var classifyFn router.ClassifierFunc
|
|
classifyFn = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) {
|
|
provider, _, err := s.selectProvider(selectorModel)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
req := &models.UnifiedRequest{
|
|
Model: selectorModel,
|
|
Messages: []models.UnifiedMessage{
|
|
{Role: "system", Content: []models.UnifiedContentPart{{Type: "text", Text: systemPrompt}}},
|
|
{Role: "user", Content: []models.UnifiedContentPart{{Type: "text", Text: userMessage}}},
|
|
},
|
|
MaxTokens: uint32Ptr(5),
|
|
Stream: false,
|
|
}
|
|
resp, err := provider.ChatCompletion(ctx, req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if len(resp.Choices) == 0 {
|
|
return "", fmt.Errorf("no choices in classifier response")
|
|
}
|
|
content, ok := resp.Choices[0].Message.Content.(string)
|
|
if !ok {
|
|
return "", fmt.Errorf("classifier response content is not a string")
|
|
}
|
|
return content, nil
|
|
}
|
|
|
|
if s.modelRouter == nil {
|
|
s.modelRouter = router.New(groups, classifyFn)
|
|
} else {
|
|
s.modelRouter.Reload(groups)
|
|
}
|
|
}
|
|
|
|
func (s *Server) setupRoutes() {
|
|
// Static files
|
|
s.router.StaticFile("/", "./static/index.html")
|
|
s.router.StaticFile("/favicon.ico", "./static/favicon.ico")
|
|
s.router.Static("/css", "./static/css")
|
|
s.router.Static("/js", "./static/js")
|
|
s.router.Static("/img", "./static/img")
|
|
|
|
// WebSocket
|
|
s.router.GET("/ws", s.handleWebSocket)
|
|
|
|
// API V1 (External LLM Access) - Secured with AuthMiddleware
|
|
v1 := s.router.Group("/v1")
|
|
v1.Use(middleware.AuthMiddleware(s.database, true))
|
|
{
|
|
v1.POST("/chat/completions", s.handleChatCompletions)
|
|
v1.POST("/images/generations", s.handleImageGenerations)
|
|
v1.GET("/models", s.handleListModels)
|
|
v1.POST("/responses", s.handleResponses)
|
|
}
|
|
|
|
// Dashboard API Group
|
|
api := s.router.Group("/api")
|
|
{
|
|
api.POST("/auth/login", s.handleLogin)
|
|
api.GET("/auth/status", s.handleAuthStatus)
|
|
api.POST("/auth/logout", s.handleLogout)
|
|
api.POST("/auth/change-password", s.handleChangePassword)
|
|
|
|
// Protected dashboard routes (need admin session)
|
|
admin := api.Group("/")
|
|
admin.Use(s.adminAuthMiddleware())
|
|
{
|
|
admin.GET("/usage/summary", s.handleUsageSummary)
|
|
admin.GET("/usage/time-series", s.handleTimeSeries)
|
|
admin.GET("/usage/providers", s.handleProvidersUsage)
|
|
admin.GET("/usage/clients", s.handleClientsUsage)
|
|
admin.GET("/usage/detailed", s.handleDetailedUsage)
|
|
admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)
|
|
|
|
admin.GET("/clients", s.handleGetClients)
|
|
admin.POST("/clients", s.handleCreateClient)
|
|
admin.GET("/clients/:id", s.handleGetClient)
|
|
admin.PUT("/clients/:id", s.handleUpdateClient)
|
|
admin.DELETE("/clients/:id", s.handleDeleteClient)
|
|
|
|
admin.GET("/clients/:id/tokens", s.handleGetClientTokens)
|
|
admin.POST("/clients/:id/tokens", s.handleCreateClientToken)
|
|
admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)
|
|
|
|
admin.GET("/providers", s.handleGetProviders)
|
|
admin.PUT("/providers/:name", s.handleUpdateProvider)
|
|
admin.POST("/providers/:name/test", s.handleTestProvider)
|
|
|
|
admin.GET("/models", s.handleGetModels)
|
|
admin.PUT("/models/:id", s.handleUpdateModel)
|
|
|
|
admin.GET("/model-groups", s.handleGetModelGroups)
|
|
admin.POST("/model-groups", s.handleCreateModelGroup)
|
|
admin.PUT("/model-groups/:id", s.handleUpdateModelGroup)
|
|
admin.DELETE("/model-groups/:id", s.handleDeleteModelGroup)
|
|
|
|
admin.GET("/users", s.handleGetUsers)
|
|
admin.POST("/users", s.handleCreateUser)
|
|
admin.PUT("/users/:id", s.handleUpdateUser)
|
|
admin.DELETE("/users/:id", s.handleDeleteUser)
|
|
|
|
admin.GET("/system/health", s.handleSystemHealth)
|
|
admin.GET("/system/metrics", s.handleSystemMetrics)
|
|
admin.GET("/system/settings", s.handleGetSettings)
|
|
admin.POST("/system/backup", s.handleCreateBackup)
|
|
admin.GET("/system/logs", s.handleGetLogs)
|
|
}
|
|
}
|
|
|
|
s.router.GET("/health", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
|
})
|
|
}
|
|
|
|
func (s *Server) handleResponses(c *gin.Context) {
|
|
startTime := time.Now()
|
|
var req models.ResponsesRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Select provider based on model name
|
|
providerName := "openai" // default for Responses API
|
|
modelLower := strings.ToLower(req.Model)
|
|
if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
|
|
providerName = "gemini"
|
|
} else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) {
|
|
providerName = "deepseek"
|
|
} else if strings.HasPrefix(modelLower, "moonshot/") || strings.Contains(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
|
|
providerName = "moonshot"
|
|
} else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") {
|
|
providerName = "grok"
|
|
} else if strings.HasPrefix(modelLower, "ollama/") ||
|
|
strings.Contains(modelLower, "glm-") ||
|
|
strings.Contains(modelLower, "qwen") ||
|
|
strings.Contains(modelLower, "gemma") ||
|
|
strings.Contains(modelLower, "llama") ||
|
|
strings.Contains(modelLower, "mistral") ||
|
|
strings.Contains(modelLower, "phi") ||
|
|
strings.Contains(modelLower, "yi") ||
|
|
strings.Contains(modelLower, "codellama") ||
|
|
strings.Contains(modelLower, "command-r") {
|
|
providerName = "ollama"
|
|
}
|
|
|
|
provider, ok := s.providers[providerName]
|
|
if !ok {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
|
|
return
|
|
}
|
|
|
|
// Strip common prefixes from model name
|
|
modelID := req.Model
|
|
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/"}
|
|
for _, p := range prefixes {
|
|
if strings.HasPrefix(modelID, p) {
|
|
modelID = strings.TrimPrefix(modelID, p)
|
|
break
|
|
}
|
|
}
|
|
|
|
// Use the stripped model name for the actual API call
|
|
req.Model = modelID
|
|
|
|
clientID := "default"
|
|
if auth, ok := c.Get("auth"); ok {
|
|
if authInfo, ok := auth.(models.AuthInfo); ok {
|
|
clientID = authInfo.ClientID
|
|
}
|
|
}
|
|
|
|
stream := req.Stream != nil && *req.Stream
|
|
|
|
if stream {
|
|
ch, err := provider.ResponsesStream(c.Request.Context(), &req)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, "", nil, err, false)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
c.Header("Content-Type", "text/event-stream")
|
|
c.Header("Cache-Control", "no-cache")
|
|
c.Header("Connection", "keep-alive")
|
|
|
|
var lastUsage *models.ResponsesUsage
|
|
c.Stream(func(w io.Writer) bool {
|
|
chunk, ok := <-ch
|
|
if !ok {
|
|
fmt.Fprintf(w, "data: [DONE]\n\n")
|
|
if lastUsage != nil {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, "", lastUsage.ToUsage(), nil, false)
|
|
} else {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, "", nil, nil, false)
|
|
}
|
|
return false
|
|
}
|
|
// Capture usage from the response payload in streaming chunks
|
|
if chunk.Response != nil && chunk.Response.Usage != nil {
|
|
lastUsage = chunk.Response.Usage
|
|
}
|
|
data, err := json.Marshal(chunk)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
return true
|
|
})
|
|
return
|
|
}
|
|
|
|
resp, err := provider.Responses(c.Request.Context(), &req)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, "", nil, err, false)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
if resp.Usage != nil {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, "", resp.Usage.ToUsage(), nil, false)
|
|
} else {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, "", nil, nil, false)
|
|
}
|
|
c.JSON(http.StatusOK, resp)
|
|
}
|
|
|
|
func (s *Server) handleListModels(c *gin.Context) {
|
|
type OpenAIModel struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
OwnedBy string `json:"owned_by"`
|
|
}
|
|
|
|
modelMap := make(map[string]OpenAIModel)
|
|
allowedProviders := map[string]bool{
|
|
"openai": true,
|
|
"google": true, // Models from models.dev use 'google' ID for Gemini
|
|
"deepseek": true,
|
|
"moonshot": true,
|
|
"moonshotai": true, // Official moonshotai ID in models.dev
|
|
"moonshotai-cn": true, // Official moonshotai-cn ID in models.dev
|
|
"xai": true, // Models from models.dev use 'xai' ID for Grok
|
|
"llmgateway": true, // Catch-all for newer models
|
|
"ollama": true,
|
|
}
|
|
|
|
s.registryMu.RLock()
|
|
if s.registry != nil {
|
|
for pID, pInfo := range s.registry.Providers {
|
|
if !allowedProviders[pID] {
|
|
continue
|
|
}
|
|
for mID := range pInfo.Models {
|
|
if _, exists := modelMap[mID]; !exists {
|
|
modelMap[mID] = OpenAIModel{
|
|
ID: mID,
|
|
Object: "model",
|
|
Created: 1700000000,
|
|
OwnedBy: pID,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
s.registryMu.RUnlock()
|
|
|
|
// Add configured Ollama models
|
|
if s.cfg.Providers.Ollama.Enabled {
|
|
for _, mID := range s.cfg.Providers.Ollama.Models {
|
|
if _, exists := modelMap[mID]; !exists {
|
|
modelMap[mID] = OpenAIModel{
|
|
ID: mID,
|
|
Object: "model",
|
|
Created: 1700000000,
|
|
OwnedBy: "ollama",
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add model groups so clients can discover them
|
|
if s.modelRouter != nil {
|
|
for _, gid := range s.modelRouter.Groups() {
|
|
if _, exists := modelMap[gid]; !exists {
|
|
modelMap[gid] = OpenAIModel{
|
|
ID: gid,
|
|
Object: "model",
|
|
Created: 1700000000,
|
|
OwnedBy: "gophergate",
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
var data []OpenAIModel
|
|
for _, m := range modelMap {
|
|
data = append(data, m)
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"object": "list",
|
|
"data": data,
|
|
})
|
|
}
|
|
|
|
func (s *Server) selectProvider(modelID string) (providers.Provider, string, error) {
|
|
providerName := "openai" // default
|
|
modelLower := strings.ToLower(modelID)
|
|
if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
|
|
providerName = "gemini"
|
|
} else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) {
|
|
providerName = "deepseek"
|
|
} else if strings.HasPrefix(modelLower, "moonshot/") || strings.Contains(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
|
|
providerName = "moonshot"
|
|
} else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") {
|
|
providerName = "grok"
|
|
} else if strings.HasPrefix(modelLower, "ollama/") ||
|
|
strings.Contains(modelLower, "glm-") ||
|
|
strings.Contains(modelLower, "qwen") ||
|
|
strings.Contains(modelLower, "gemma") ||
|
|
strings.Contains(modelLower, "llama") ||
|
|
strings.Contains(modelLower, "mistral") ||
|
|
strings.Contains(modelLower, "phi") ||
|
|
strings.Contains(modelLower, "yi") ||
|
|
strings.Contains(modelLower, "codellama") ||
|
|
strings.Contains(modelLower, "command-r") {
|
|
providerName = "ollama"
|
|
}
|
|
|
|
p, ok := s.providers[providerName]
|
|
if !ok {
|
|
return nil, "", fmt.Errorf("Provider %s not enabled or supported", providerName)
|
|
}
|
|
return p, providerName, nil
|
|
}
|
|
|
|
func (s *Server) handleChatCompletions(c *gin.Context) {
|
|
startTime := time.Now()
|
|
var req models.ChatCompletionRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Strip common prefixes and prepare model ID
|
|
modelID := req.Model
|
|
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/"}
|
|
for _, p := range prefixes {
|
|
if strings.HasPrefix(modelID, p) {
|
|
modelID = strings.TrimPrefix(modelID, p)
|
|
break
|
|
}
|
|
}
|
|
|
|
// Resolve model groups to concrete models (hierarchical — groups can target groups)
|
|
modelGroup := ""
|
|
for i, m := range req.Messages {
|
|
log.Printf("[DEBUG] Incoming Msg[%d]: role=%s, hasToolCalls=%v, hasContent=%v", i, m.Role, len(m.ToolCalls) > 0, m.Content != nil)
|
|
}
|
|
if s.modelRouter != nil {
|
|
userMessage := extractUserMessage(req.Messages)
|
|
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, userMessage)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
|
|
return
|
|
}
|
|
if decision.SelectedModel != modelID {
|
|
modelGroup = modelID
|
|
}
|
|
modelID = decision.SelectedModel
|
|
log.Printf("[ROUTER] %s (%s: %s)", modelID, decision.Strategy, decision.Reason)
|
|
}
|
|
|
|
// Select provider based on the resolved model name
|
|
provider, providerName, err := s.selectProvider(modelID)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Convert ChatCompletionRequest to UnifiedRequest
|
|
unifiedReq := &models.UnifiedRequest{
|
|
Model: modelID,
|
|
Messages: []models.UnifiedMessage{},
|
|
Temperature: req.Temperature,
|
|
TopP: req.TopP,
|
|
TopK: req.TopK,
|
|
N: req.N,
|
|
MaxTokens: req.MaxTokens,
|
|
PresencePenalty: req.PresencePenalty,
|
|
FrequencyPenalty: req.FrequencyPenalty,
|
|
Stream: req.Stream != nil && *req.Stream,
|
|
Tools: req.Tools,
|
|
ToolChoice: req.ToolChoice,
|
|
}
|
|
|
|
// Inject or cap max_tokens from model registry.
|
|
s.registryMu.RLock()
|
|
meta := s.registry.FindModel(modelID)
|
|
s.registryMu.RUnlock()
|
|
|
|
if meta != nil && meta.Limit != nil && meta.Limit.Output > 0 {
|
|
if unifiedReq.MaxTokens == nil {
|
|
unifiedReq.MaxTokens = &meta.Limit.Output
|
|
log.Printf("[DEBUG] %s: injected registry max_tokens=%d", modelID, meta.Limit.Output)
|
|
} else if *unifiedReq.MaxTokens > meta.Limit.Output {
|
|
log.Printf("[DEBUG] %s: capping client max_tokens (%d) to registry limit (%d)", modelID, *unifiedReq.MaxTokens, meta.Limit.Output)
|
|
unifiedReq.MaxTokens = &meta.Limit.Output
|
|
} else {
|
|
log.Printf("[DEBUG] %s: using client max_tokens (%d)", modelID, *unifiedReq.MaxTokens)
|
|
}
|
|
} else {
|
|
if unifiedReq.MaxTokens == nil {
|
|
log.Printf("[DEBUG] %s: no registry limit found, leaving max_tokens nil", modelID)
|
|
} else {
|
|
log.Printf("[DEBUG] %s: using client max_tokens (%d), no registry limit to cap", modelID, *unifiedReq.MaxTokens)
|
|
}
|
|
}
|
|
|
|
// Handle Stop sequences
|
|
if req.Stop != nil {
|
|
var stop []string
|
|
if err := json.Unmarshal(req.Stop, &stop); err == nil {
|
|
unifiedReq.Stop = stop
|
|
} else {
|
|
var singleStop string
|
|
if err := json.Unmarshal(req.Stop, &singleStop); err == nil {
|
|
unifiedReq.Stop = []string{singleStop}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Convert messages
|
|
for _, msg := range req.Messages {
|
|
unifiedMsg := models.UnifiedMessage{
|
|
Role: msg.Role,
|
|
Content: []models.UnifiedContentPart{},
|
|
ReasoningContent: msg.ReasoningContent,
|
|
ToolCalls: msg.ToolCalls,
|
|
Name: msg.Name,
|
|
ToolCallID: msg.ToolCallID,
|
|
}
|
|
|
|
// Handle multimodal content
|
|
if strContent, ok := msg.Content.(string); ok {
|
|
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
|
Type: "text",
|
|
Text: strContent,
|
|
})
|
|
} else if parts, ok := msg.Content.([]interface{}); ok {
|
|
for _, part := range parts {
|
|
if partMap, ok := part.(map[string]interface{}); ok {
|
|
partType, _ := partMap["type"].(string)
|
|
if partType == "text" {
|
|
text, _ := partMap["text"].(string)
|
|
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
|
Type: "text",
|
|
Text: text,
|
|
})
|
|
} else if partType == "image_url" {
|
|
if imgURLMap, ok := partMap["image_url"].(map[string]interface{}); ok {
|
|
url, _ := imgURLMap["url"].(string)
|
|
imageInput := &models.ImageInput{}
|
|
if strings.HasPrefix(url, "data:") {
|
|
mime, data, err := utils.ParseDataURL(url)
|
|
if err == nil {
|
|
imageInput.Base64 = data
|
|
imageInput.MimeType = mime
|
|
}
|
|
} else {
|
|
imageInput.URL = url
|
|
}
|
|
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
|
Type: "image",
|
|
Image: imageInput,
|
|
})
|
|
unifiedReq.HasImages = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
unifiedReq.Messages = append(unifiedReq.Messages, unifiedMsg)
|
|
}
|
|
|
|
clientID := "default"
|
|
if auth, ok := c.Get("auth"); ok {
|
|
if authInfo, ok := auth.(models.AuthInfo); ok {
|
|
unifiedReq.ClientID = authInfo.ClientID
|
|
clientID = authInfo.ClientID
|
|
}
|
|
} else {
|
|
unifiedReq.ClientID = clientID
|
|
}
|
|
|
|
if unifiedReq.Stream {
|
|
ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, unifiedReq.HasImages)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
c.Header("Content-Type", "text/event-stream")
|
|
c.Header("Cache-Control", "no-cache")
|
|
c.Header("Connection", "keep-alive")
|
|
|
|
var lastUsage *models.Usage
|
|
c.Stream(func(w io.Writer) bool {
|
|
chunk, ok := <-ch
|
|
if !ok {
|
|
fmt.Fprintf(w, "data: [DONE]\n\n")
|
|
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage, nil, unifiedReq.HasImages)
|
|
return false
|
|
}
|
|
if chunk.Usage != nil {
|
|
lastUsage = chunk.Usage
|
|
}
|
|
data, err := json.Marshal(chunk)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
return true
|
|
})
|
|
return
|
|
}
|
|
|
|
resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, unifiedReq.HasImages)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, resp.Usage, nil, unifiedReq.HasImages)
|
|
c.JSON(http.StatusOK, resp)
|
|
}
|
|
|
|
func extractUserMessage(messages []models.ChatMessage) string {
|
|
for i := len(messages) - 1; i >= 0; i-- {
|
|
if messages[i].Role == "user" {
|
|
switch c := messages[i].Content.(type) {
|
|
case string:
|
|
return c
|
|
default:
|
|
return ""
|
|
}
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (s *Server) handleImageGenerations(c *gin.Context) {
|
|
startTime := time.Now()
|
|
var req models.ImageGenerationRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Determine provider based on model name
|
|
providerName := "openai"
|
|
modelLower := strings.ToLower(req.Model)
|
|
switch {
|
|
case strings.Contains(modelLower, "imagen"), strings.Contains(modelLower, "gemini"):
|
|
providerName = "gemini"
|
|
case strings.Contains(modelLower, "dall"), strings.HasPrefix(modelLower, "openai/"):
|
|
providerName = "openai"
|
|
}
|
|
|
|
// Default model for each provider if not specified
|
|
if req.Model == "" {
|
|
if providerName == "openai" {
|
|
req.Model = "dall-e-3"
|
|
} else {
|
|
req.Model = "imagen-3.0-generate-001"
|
|
}
|
|
}
|
|
|
|
// Strip common prefixes
|
|
prefixes := []string{"openai/", "gemini/", "google/"}
|
|
for _, p := range prefixes {
|
|
if strings.HasPrefix(req.Model, p) {
|
|
req.Model = strings.TrimPrefix(req.Model, p)
|
|
break
|
|
}
|
|
}
|
|
|
|
provider, ok := s.providers[providerName]
|
|
if !ok {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
|
|
return
|
|
}
|
|
|
|
clientID := "default"
|
|
if auth, ok := c.Get("auth"); ok {
|
|
if authInfo, ok := auth.(models.AuthInfo); ok {
|
|
clientID = authInfo.ClientID
|
|
}
|
|
}
|
|
|
|
resp, err := provider.ImageGeneration(c.Request.Context(), &req)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, "", nil, err, false)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Estimate tokens from prompt text (~4 chars per token)
|
|
promptTokens := uint32(len(req.Prompt) / 4)
|
|
if promptTokens < 1 {
|
|
promptTokens = 1
|
|
}
|
|
|
|
// Calculate per-image cost (not per-token like chat)
|
|
cost := imageGenCost(providerName, req.Model, req.Size, uint32(len(resp.Data)))
|
|
|
|
s.logRequest(startTime, clientID, providerName, req.Model, "", &models.Usage{
|
|
PromptTokens: promptTokens,
|
|
CompletionTokens: uint32(len(resp.Data)),
|
|
TotalTokens: promptTokens + uint32(len(resp.Data)),
|
|
}, nil, false)
|
|
|
|
// Update cost in DB — image gen is per-image, not per-token
|
|
if cost > 0 {
|
|
s.database.Exec("UPDATE llm_requests SET cost = ? WHERE id = (SELECT MAX(id) FROM llm_requests)", cost)
|
|
}
|
|
|
|
c.JSON(http.StatusOK, resp)
|
|
}
|
|
|
|
// imageGenCost returns the estimated USD cost of generating n images with
// model, using per-image list prices for known models:
//   - dall-e-3: $0.040 per image (standard quality), $0.080 for 1024x1792 or
//     1792x1024,
//   - dall-e-2: $0.016 (256x256), $0.018 (512x512), otherwise $0.020 (1024x1024),
//   - imagen*:  $0.040 (approximate).
//
// Model matching is case-insensitive. It returns 0 when n is 0 or the model is
// not recognized. size may be nil. provider is accepted for call-site symmetry
// but is not currently used.
func imageGenCost(provider, model string, size *string, n uint32) float64 {
	if n == 0 {
		return 0
	}

	modelLower := strings.ToLower(model)
	sz := ""
	if size != nil {
		sz = *size
	}

	var perImage float64
	switch {
	case strings.Contains(modelLower, "dall-e-3"):
		perImage = 0.040 // standard 1024x1024
		if sz == "1024x1792" || sz == "1792x1024" {
			perImage = 0.080
		}
	case strings.Contains(modelLower, "dall-e-2"):
		// DALL·E 2 is priced by size; previously every size was charged $0.020.
		switch sz {
		case "256x256":
			perImage = 0.016
		case "512x512":
			perImage = 0.018
		default:
			perImage = 0.020
		}
	case strings.Contains(modelLower, "imagen"):
		perImage = 0.040 // approximate
	default:
		return 0
	}

	return perImage * float64(n)
}
|
|
|
|
func (s *Server) logRequest(start time.Time, clientID, provider, model, modelGroup string, usage *models.Usage, err error, hasImages bool) {
|
|
entry := RequestLog{
|
|
Timestamp: start,
|
|
ClientID: clientID,
|
|
Provider: provider,
|
|
Model: model,
|
|
ModelGroup: modelGroup,
|
|
Status: "success",
|
|
DurationMS: time.Since(start).Milliseconds(),
|
|
HasImages: hasImages,
|
|
}
|
|
|
|
if err != nil {
|
|
entry.Status = "error"
|
|
entry.ErrorMessage = err.Error()
|
|
}
|
|
|
|
if usage != nil {
|
|
entry.PromptTokens = usage.PromptTokens
|
|
entry.CompletionTokens = usage.CompletionTokens
|
|
entry.TotalTokens = usage.TotalTokens
|
|
if usage.ReasoningTokens != nil {
|
|
entry.ReasoningTokens = *usage.ReasoningTokens
|
|
}
|
|
if usage.CacheReadTokens != nil {
|
|
entry.CacheReadTokens = *usage.CacheReadTokens
|
|
}
|
|
if usage.CacheWriteTokens != nil {
|
|
entry.CacheWriteTokens = *usage.CacheWriteTokens
|
|
}
|
|
|
|
// Calculate cost using registry
|
|
s.registryMu.RLock()
|
|
entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
|
|
s.registryMu.RUnlock()
|
|
}
|
|
|
|
s.logger.LogRequest(entry)
|
|
}
|
|
|
|
func (s *Server) Run() error {
|
|
go s.hub.Run()
|
|
s.logger.Start()
|
|
|
|
// Start registry refresher
|
|
go func() {
|
|
ticker := time.NewTicker(24 * time.Hour)
|
|
for range ticker.C {
|
|
newRegistry, err := utils.FetchRegistry()
|
|
if err == nil {
|
|
s.registryMu.Lock()
|
|
s.registry = newRegistry
|
|
s.registryMu.Unlock()
|
|
}
|
|
}
|
|
}()
|
|
|
|
addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
|
|
return s.router.Run(addr)
|
|
}
|
|
|
|
// uint32Ptr returns a pointer to a fresh copy of v, for filling optional
// *uint32 fields such as max_tokens.
func uint32Ptr(v uint32) *uint32 {
	p := new(uint32)
	*p = v
	return p
}
|