Files
GopherGate/internal/server/server.go
T
hobokenchicken 73a82e6175
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
feat: implement advanced condition-based heuristic model routing
Upgrades the routing engine to support tag, token limit, multimodal, reasoning, and tool calling conditions. Adds unit tests for the new routing features.
2026-06-05 15:05:13 +00:00

1105 lines
33 KiB
Go

package server
import (
"encoding/json"
"context"
"fmt"
"io"
"log"
"net/http"
"strings"
"sync"
"time"
"gophergate/internal/config"
"gophergate/internal/db"
"gophergate/internal/middleware"
"gophergate/internal/models"
"gophergate/internal/providers"
"gophergate/internal/router"
"gophergate/internal/utils"
"github.com/gin-gonic/gin"
)
// Server is the gateway's HTTP server. It holds the gin engine, the enabled
// upstream providers, dashboard session state, the WebSocket hub, the request
// logger, the model registry and the model-group router.
type Server struct {
// router is the gin engine that setupRoutes registers handlers on and Run serves.
router *gin.Engine
// cfg is the static configuration (provider defaults, key bytes, listen address).
cfg *config.Config
// database is queried for provider_configs and model_groups.
database *db.DB
// providers maps a provider ID ("openai", "gemini", ...) to its circuit-breaker
// wrapped implementation; rebuilt by RefreshProviders.
providers map[string]providers.Provider
// sessions is the session manager created in NewServer (presumably backing the
// dashboard login/admin middleware — confirm in the auth handlers).
sessions *SessionManager
// hub is the WebSocket hub; it is also handed to the request logger.
hub *Hub
// logger receives one RequestLog per proxied request (see logRequest).
logger *RequestLogger
// registry is the models.dev model metadata used for max_tokens limits and
// cost calculation. Read it under registryMu.
registry *models.ModelRegistry
// registryMu guards registry.
registryMu sync.RWMutex
// modelRouter resolves model-group names to concrete models; built/reloaded
// by refreshRouter.
modelRouter *router.Router
}
// NewServer wires up a Server from cfg and database.
//
// It performs these steps:
//   - creates the gin engine (gin.Default) and the session manager, created
//     with a 24-hour duration, and starts its cleanup loop;
//   - creates the WebSocket hub and the request logger;
//   - starts a background fetch of the model registry;
//   - loads providers, registers routes and builds the model-group router.
//
// Failures while loading providers are printed as warnings rather than returned.
func NewServer(cfg *config.Config, database *db.DB) *Server {
	engine := gin.Default()
	hub := NewHub()
	s := &Server{
		router:    engine,
		cfg:       cfg,
		database:  database,
		providers: make(map[string]providers.Provider),
		sessions:  NewSessionManager(cfg.KeyBytes, 24*time.Hour),
		hub:       hub,
		logger:    NewRequestLogger(database, hub),
		registry:  &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)},
	}
	s.sessions.StartCleanup()

	// Fetch the registry in the background so startup does not wait on the
	// network; until it lands, the empty registry above is used. The swap must
	// hold registryMu because request handlers read s.registry under RLock.
	go func() {
		registry, err := utils.FetchRegistry()
		if err != nil {
			fmt.Printf("Warning: Failed to fetch initial model registry: %v\n", err)
			return
		}
		s.registryMu.Lock()
		s.registry = registry
		s.registryMu.Unlock()
	}()

	// Initialize providers from DB and config.
	if err := s.RefreshProviders(); err != nil {
		fmt.Printf("Warning: Failed to initial refresh providers: %v\n", err)
	}
	s.setupRoutes()
	// RefreshProviders only rebuilds the router when it succeeds, so make sure
	// the model-group router exists even if that refresh failed.
	s.refreshRouter()
	return s
}
// RefreshProviders rebuilds s.providers from the static config, overlaid with
// the rows of the provider_configs table.
//
// For each known provider ID the config supplies defaults (enabled flag, base
// URL, API key). A matching DB row overrides the enabled flag and any
// non-empty base URL or API key; encrypted keys are decrypted with
// cfg.KeyBytes. Disabled providers are removed from s.providers. Enabled ones
// are re-created and wrapped in a circuit breaker. The model-group router is
// then rebuilt.
//
// It returns an error only when provider_configs cannot be read, in which case
// s.providers is left unchanged.
//
// NOTE(review): s.providers is mutated without a lock. If this runs while
// requests are being served (e.g. from an admin handler), reads in
// selectProvider can race with it — confirm the callers.
func (s *Server) RefreshProviders() error {
	var dbConfigs []db.ProviderConfig
	if err := s.database.Select(&dbConfigs, "SELECT * FROM provider_configs"); err != nil {
		return fmt.Errorf("failed to fetch provider configs from db: %w", err)
	}
	dbMap := make(map[string]db.ProviderConfig, len(dbConfigs))
	for _, pc := range dbConfigs {
		dbMap[pc.ID] = pc
	}

	providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama", "xiaomi"}
	for _, id := range providerIDs {
		// Defaults from the static config.
		enabled := false
		baseURL := ""
		apiKey := ""
		switch id {
		case "openai":
			enabled = s.cfg.Providers.OpenAI.Enabled
			baseURL = s.cfg.Providers.OpenAI.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("openai")
		case "gemini":
			enabled = s.cfg.Providers.Gemini.Enabled
			baseURL = s.cfg.Providers.Gemini.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("gemini")
		case "deepseek":
			enabled = s.cfg.Providers.DeepSeek.Enabled
			baseURL = s.cfg.Providers.DeepSeek.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("deepseek")
		case "moonshot":
			enabled = s.cfg.Providers.Moonshot.Enabled
			baseURL = s.cfg.Providers.Moonshot.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("moonshot")
		case "grok":
			enabled = s.cfg.Providers.Grok.Enabled
			baseURL = s.cfg.Providers.Grok.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("grok")
		case "ollama":
			// Without this case a config-only Ollama was never enabled, and a
			// DB-enabled one lost its configured base URL. Ollama is keyless:
			// NewOllamaProvider below takes no API key.
			enabled = s.cfg.Providers.Ollama.Enabled
			baseURL = s.cfg.Providers.Ollama.BaseURL
		case "xiaomi":
			enabled = s.cfg.Providers.Xiaomi.Enabled
			baseURL = s.cfg.Providers.Xiaomi.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("xiaomi")
		}

		// Overrides from the DB.
		if dbCfg, ok := dbMap[id]; ok {
			enabled = dbCfg.Enabled
			if dbCfg.BaseURL != nil && *dbCfg.BaseURL != "" {
				baseURL = *dbCfg.BaseURL
			}
			if dbCfg.APIKey != nil && *dbCfg.APIKey != "" {
				key := *dbCfg.APIKey
				if dbCfg.APIKeyEncrypted {
					decrypted, err := utils.Decrypt(key, s.cfg.KeyBytes)
					if err != nil {
						// Never send ciphertext upstream as a credential; keep
						// the config-supplied key instead.
						fmt.Printf("Warning: Failed to decrypt API key for %s: %v\n", id, err)
						key = ""
					} else {
						key = decrypted
					}
				}
				if key != "" {
					apiKey = key
				}
			}
		}

		if !enabled {
			delete(s.providers, id)
			continue
		}

		// Initialize the provider with the effective base URL.
		var p providers.Provider
		switch id {
		case "openai":
			cfg := s.cfg.Providers.OpenAI
			cfg.BaseURL = baseURL
			p = providers.NewOpenAIProvider(cfg, apiKey)
		case "gemini":
			cfg := s.cfg.Providers.Gemini
			cfg.BaseURL = baseURL
			p = providers.NewGeminiProvider(cfg, apiKey)
		case "deepseek":
			cfg := s.cfg.Providers.DeepSeek
			cfg.BaseURL = baseURL
			p = providers.NewDeepSeekProvider(cfg, apiKey)
		case "moonshot":
			cfg := s.cfg.Providers.Moonshot
			cfg.BaseURL = baseURL
			p = providers.NewMoonshotProvider(cfg, apiKey)
		case "grok":
			cfg := s.cfg.Providers.Grok
			cfg.BaseURL = baseURL
			p = providers.NewGrokProvider(cfg, apiKey)
		case "ollama":
			cfg := s.cfg.Providers.Ollama
			cfg.BaseURL = baseURL
			p = providers.NewOllamaProvider(cfg)
		case "xiaomi":
			cfg := s.cfg.Providers.Xiaomi
			cfg.BaseURL = baseURL
			p = providers.NewXiaomiProvider(cfg, apiKey)
		}
		if p != nil {
			s.providers[id] = providers.NewCircuitBreakerProvider(p)
		}
	}
	s.refreshRouter()
	return nil
}
// refreshRouter loads all rows of model_groups. On the first call it creates
// the model router; afterwards it reloads the existing router with the new
// groups. If the query fails, a warning is printed and no groups are used.
//
// The classifier given to router.New works as follows:
//   - it sends a short, non-streaming completion (system prompt + user message,
//     at most 5 tokens) to whichever provider selectProvider picks for the
//     selector model;
//   - it returns the first choice's content when that content is a string.
func (s *Server) refreshRouter() {
	var groups []db.ModelGroup
	loadErr := s.database.Select(&groups, "SELECT * FROM model_groups")
	if loadErr != nil {
		fmt.Printf("Warning: Failed to load model groups: %v\n", loadErr)
		groups = nil
	}

	var classify router.ClassifierFunc = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) {
		p, _, err := s.selectProvider(selectorModel)
		if err != nil {
			return "", err
		}
		textPart := func(text string) []models.UnifiedContentPart {
			return []models.UnifiedContentPart{{Type: "text", Text: text}}
		}
		resp, err := p.ChatCompletion(ctx, &models.UnifiedRequest{
			Model: selectorModel,
			Messages: []models.UnifiedMessage{
				{Role: "system", Content: textPart(systemPrompt)},
				{Role: "user", Content: textPart(userMessage)},
			},
			MaxTokens: uint32Ptr(5),
			Stream:    false,
		})
		if err != nil {
			return "", err
		}
		if len(resp.Choices) == 0 {
			return "", fmt.Errorf("no choices in classifier response")
		}
		if answer, isText := resp.Choices[0].Message.Content.(string); isText {
			return answer, nil
		}
		return "", fmt.Errorf("classifier response content is not a string")
	}

	if s.modelRouter != nil {
		s.modelRouter.Reload(groups)
		return
	}
	s.modelRouter = router.New(groups, classify)
}
// setupRoutes registers every HTTP route on s.router:
//   - static dashboard assets under /, /favicon.ico, /css, /js and /img;
//   - the /ws WebSocket endpoint;
//   - the OpenAI-compatible /v1 API, guarded by middleware.AuthMiddleware;
//   - the dashboard /api auth endpoints and, behind adminAuthMiddleware, the
//     admin endpoints;
//   - an unauthenticated /health probe.
func (s *Server) setupRoutes() {
	r := s.router

	// Static files.
	r.StaticFile("/", "./static/index.html")
	r.StaticFile("/favicon.ico", "./static/favicon.ico")
	for _, dir := range []string{"css", "js", "img"} {
		r.Static("/"+dir, "./static/"+dir)
	}

	// WebSocket. NOTE(review): no middleware is attached here; confirm that
	// handleWebSocket checks the session itself.
	r.GET("/ws", s.handleWebSocket)

	// API V1 (external LLM access), secured with AuthMiddleware.
	v1 := r.Group("/v1")
	v1.Use(middleware.AuthMiddleware(s.database, true))
	v1.POST("/chat/completions", s.handleChatCompletions)
	v1.POST("/images/generations", s.handleImageGenerations)
	v1.GET("/models", s.handleListModels)
	v1.POST("/responses", s.handleResponses)

	// Dashboard API: public auth endpoints.
	api := r.Group("/api")
	api.POST("/auth/login", s.handleLogin)
	api.GET("/auth/status", s.handleAuthStatus)
	api.POST("/auth/logout", s.handleLogout)
	api.POST("/auth/change-password", s.handleChangePassword)

	// Dashboard API: admin-session-protected endpoints, registered in order.
	admin := api.Group("/")
	admin.Use(s.adminAuthMiddleware())
	adminRoutes := []struct {
		method  string
		path    string
		handler gin.HandlerFunc
	}{
		{http.MethodGet, "/usage/summary", s.handleUsageSummary},
		{http.MethodGet, "/usage/time-series", s.handleTimeSeries},
		{http.MethodGet, "/usage/providers", s.handleProvidersUsage},
		{http.MethodGet, "/usage/clients", s.handleClientsUsage},
		{http.MethodGet, "/usage/detailed", s.handleDetailedUsage},
		{http.MethodGet, "/analytics/breakdown", s.handleAnalyticsBreakdown},
		{http.MethodGet, "/clients", s.handleGetClients},
		{http.MethodPost, "/clients", s.handleCreateClient},
		{http.MethodGet, "/clients/:id", s.handleGetClient},
		{http.MethodPut, "/clients/:id", s.handleUpdateClient},
		{http.MethodDelete, "/clients/:id", s.handleDeleteClient},
		{http.MethodGet, "/clients/:id/tokens", s.handleGetClientTokens},
		{http.MethodPost, "/clients/:id/tokens", s.handleCreateClientToken},
		{http.MethodDelete, "/clients/:id/tokens/:token_id", s.handleDeleteClientToken},
		{http.MethodGet, "/providers", s.handleGetProviders},
		{http.MethodPut, "/providers/:name", s.handleUpdateProvider},
		{http.MethodPost, "/providers/:name/test", s.handleTestProvider},
		{http.MethodGet, "/models", s.handleGetModels},
		{http.MethodPut, "/models/:id", s.handleUpdateModel},
		{http.MethodGet, "/model-groups", s.handleGetModelGroups},
		{http.MethodPost, "/model-groups", s.handleCreateModelGroup},
		{http.MethodPut, "/model-groups/:id", s.handleUpdateModelGroup},
		{http.MethodDelete, "/model-groups/:id", s.handleDeleteModelGroup},
		{http.MethodGet, "/users", s.handleGetUsers},
		{http.MethodPost, "/users", s.handleCreateUser},
		{http.MethodPut, "/users/:id", s.handleUpdateUser},
		{http.MethodDelete, "/users/:id", s.handleDeleteUser},
		{http.MethodGet, "/system/health", s.handleSystemHealth},
		{http.MethodGet, "/system/metrics", s.handleSystemMetrics},
		{http.MethodGet, "/system/settings", s.handleGetSettings},
		{http.MethodPost, "/system/backup", s.handleCreateBackup},
		{http.MethodGet, "/system/logs", s.handleGetLogs},
	}
	for _, rt := range adminRoutes {
		admin.Handle(rt.method, rt.path, rt.handler)
	}

	r.GET("/health", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"status": "ok"})
	})
}
// handleResponses serves POST /v1/responses (OpenAI Responses API).
//
// Steps:
//   - strip a known provider prefix from req.Model;
//   - resolve model groups through the model router;
//   - pick a provider with selectProvider (the same heuristics as chat
//     completions);
//   - proxy the request either as an SSE stream or as a single JSON response.
//
// Every request is recorded via logRequest, including streams that end early
// because the client disconnected or a chunk could not be encoded.
func (s *Server) handleResponses(c *gin.Context) {
	startTime := time.Now()
	var req models.ResponsesRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Strip common prefixes and resolve model groups to concrete models
	// (same pattern as handleChatCompletions).
	modelGroup := ""
	modelID := req.Model
	prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/", "xiaomi/"}
	for _, p := range prefixes {
		if strings.HasPrefix(modelID, p) {
			modelID = strings.TrimPrefix(modelID, p)
			break
		}
	}
	if s.modelRouter != nil {
		routeCtx := s.buildRouteContextFromResponses(req)
		decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
			return
		}
		if decision.SelectedModel != modelID {
			modelGroup = modelID
		}
		modelID = decision.SelectedModel
	}

	// Reuse selectProvider so both endpoints route identically. The inline copy
	// of the heuristics that used to live here lacked the xiaomi/MiMo branch and
	// sent those models to OpenAI.
	provider, providerName, err := s.selectProvider(modelID)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Use the resolved model for the actual API call.
	req.Model = modelID

	clientID := "default"
	if auth, ok := c.Get("auth"); ok {
		if authInfo, ok := auth.(models.AuthInfo); ok {
			clientID = authInfo.ClientID
		}
	}

	if req.Stream != nil && *req.Stream {
		ch, err := provider.ResponsesStream(c.Request.Context(), &req)
		if err != nil {
			s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false)
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")

		var lastUsage *models.ResponsesUsage
		logged := false
		// logOnce records the request exactly once, with the latest usage seen.
		logOnce := func(streamErr error) {
			if logged {
				return
			}
			logged = true
			var usage *models.Usage
			if lastUsage != nil {
				usage = lastUsage.ToUsage()
			}
			s.logRequest(startTime, clientID, providerName, modelID, modelGroup, usage, streamErr, false)
		}
		c.Stream(func(w io.Writer) bool {
			chunk, ok := <-ch
			if !ok {
				fmt.Fprintf(w, "data: [DONE]\n\n")
				logOnce(nil)
				return false
			}
			// Capture usage from the response payload in streaming chunks.
			if chunk.Response != nil && chunk.Response.Usage != nil {
				lastUsage = chunk.Response.Usage
			}
			data, err := json.Marshal(chunk)
			if err != nil {
				logOnce(fmt.Errorf("encoding stream chunk: %w", err))
				return false
			}
			fmt.Fprintf(w, "data: %s\n\n", data)
			return true
		})
		// c.Stream also returns when the client goes away. In that case the
		// request has not been logged yet, so record it with the partial usage.
		logOnce(c.Request.Context().Err())
		return
	}

	resp, err := provider.Responses(c.Request.Context(), &req)
	if err != nil {
		s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	var usage *models.Usage
	if resp.Usage != nil {
		usage = resp.Usage.ToUsage()
	}
	s.logRequest(startTime, clientID, providerName, modelID, modelGroup, usage, nil, false)
	c.JSON(http.StatusOK, resp)
}
// handleListModels serves GET /v1/models in the OpenAI list format.
//
// It merges three sources, deduplicated by model ID:
//   - models from allow-listed providers in the model registry;
//   - the Ollama models listed in config, when Ollama is enabled there;
//   - the configured model groups, owned by "gophergate".
//
// Earlier sources win on ID clashes. Among registry providers sharing an ID,
// the owner is whichever provider is visited first, because map order is
// random. The "data" field is always a JSON array, never null, and entry
// order is not guaranteed.
func (s *Server) handleListModels(c *gin.Context) {
	type OpenAIModel struct {
		ID      string `json:"id"`
		Object  string `json:"object"`
		Created int64  `json:"created"`
		OwnedBy string `json:"owned_by"`
	}

	modelMap := make(map[string]OpenAIModel)
	// addModel records id with the given owner unless it is already present.
	addModel := func(id, owner string) {
		if _, exists := modelMap[id]; exists {
			return
		}
		modelMap[id] = OpenAIModel{
			ID:      id,
			Object:  "model",
			Created: 1700000000,
			OwnedBy: owner,
		}
	}

	allowedProviders := map[string]bool{
		"openai":        true,
		"google":        true, // Models from models.dev use 'google' ID for Gemini
		"deepseek":      true,
		"moonshot":      true,
		"moonshotai":    true, // Official moonshotai ID in models.dev
		"moonshotai-cn": true, // Official moonshotai-cn ID in models.dev
		"xai":           true, // Models from models.dev use 'xai' ID for Grok
		"llmgateway":    true, // Catch-all for newer models
		"ollama":        true,
		"xiaomi":        true, // Xiaomi MiMo models
	}

	s.registryMu.RLock()
	if s.registry != nil {
		for pID, pInfo := range s.registry.Providers {
			if !allowedProviders[pID] {
				continue
			}
			for mID := range pInfo.Models {
				addModel(mID, pID)
			}
		}
	}
	s.registryMu.RUnlock()

	// Add configured Ollama models.
	if s.cfg.Providers.Ollama.Enabled {
		for _, mID := range s.cfg.Providers.Ollama.Models {
			addModel(mID, "ollama")
		}
	}

	// Add model groups so clients can discover them.
	if s.modelRouter != nil {
		for _, gid := range s.modelRouter.Groups() {
			addModel(gid, "gophergate")
		}
	}

	// Use a non-nil slice so an empty catalogue encodes as "data": [] rather than
	// "data": null, which OpenAI-compatible clients do not expect.
	data := make([]OpenAIModel, 0, len(modelMap))
	for _, m := range modelMap {
		data = append(data, m)
	}
	c.JSON(http.StatusOK, gin.H{
		"object": "list",
		"data":   data,
	})
}
// selectProvider maps a model ID to one of the enabled providers. It uses
// case-insensitive prefix/substring heuristics, checked in this order:
//   - gemini/google;
//   - deepseek (unless the name mentions ollama);
//   - moonshot/kimi;
//   - grok;
//   - common open-weight families, served via ollama;
//   - xiaomi/mimo.
//
// Anything else goes to openai.
//
// It returns the provider and its ID, or an error if that provider is not
// currently in s.providers.
func (s *Server) selectProvider(modelID string) (providers.Provider, string, error) {
	m := strings.ToLower(modelID)
	mentions := func(fragments ...string) bool {
		for _, f := range fragments {
			if strings.Contains(m, f) {
				return true
			}
		}
		return false
	}

	var providerName string
	switch {
	case strings.HasPrefix(m, "gemini/") || mentions("gemini") || strings.HasPrefix(m, "google/"):
		providerName = "gemini"
	case strings.HasPrefix(m, "deepseek/") || (mentions("deepseek") && !mentions("ollama")):
		providerName = "deepseek"
	case strings.HasPrefix(m, "moonshot/") || mentions("kimi", "moonshot"):
		providerName = "moonshot"
	case strings.HasPrefix(m, "grok/") || mentions("grok"):
		providerName = "grok"
	case strings.HasPrefix(m, "ollama/") ||
		mentions("glm-", "qwen", "gemma", "llama", "mistral", "phi", "yi", "codellama", "command-r"):
		providerName = "ollama"
	case strings.HasPrefix(m, "xiaomi/") || mentions("mimo", "xiaomi"):
		providerName = "xiaomi"
	default:
		providerName = "openai"
	}

	if p, found := s.providers[providerName]; found {
		return p, providerName, nil
	}
	return nil, "", fmt.Errorf("Provider %s not enabled or supported", providerName)
}
// handleChatCompletions serves POST /v1/chat/completions.
//
// Steps:
//   - bind the OpenAI-style request and strip a known provider prefix from the
//     model name;
//   - resolve model groups via the model router and select a provider;
//   - translate the request into a models.UnifiedRequest. max_tokens is
//     injected from, or capped at, the registry's output limit when known.
//   - proxy the request as an SSE stream or as a single JSON response.
//
// A malformed data: URL in an image part is rejected with 400. Every request
// is recorded via logRequest, including streams that end early.
func (s *Server) handleChatCompletions(c *gin.Context) {
	startTime := time.Now()
	var req models.ChatCompletionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Strip common prefixes and prepare the model ID.
	modelID := req.Model
	prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/", "xiaomi/"}
	for _, p := range prefixes {
		if strings.HasPrefix(modelID, p) {
			modelID = strings.TrimPrefix(modelID, p)
			break
		}
	}

	for i, m := range req.Messages {
		log.Printf("[DEBUG] Incoming Msg[%d]: role=%s, hasToolCalls=%v, hasContent=%v", i, m.Role, len(m.ToolCalls) > 0, m.Content != nil)
	}

	// Resolve model groups to concrete models (hierarchical: groups can target groups).
	modelGroup := ""
	if s.modelRouter != nil {
		routeCtx := s.buildRouteContextFromChat(req)
		decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
			return
		}
		if decision.SelectedModel != modelID {
			modelGroup = modelID
		}
		modelID = decision.SelectedModel
		log.Printf("[ROUTER] %s (%s: %s)", modelID, decision.Strategy, decision.Reason)
	}

	// Select the provider based on the resolved model name.
	provider, providerName, err := s.selectProvider(modelID)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Convert ChatCompletionRequest to UnifiedRequest.
	unifiedReq := &models.UnifiedRequest{
		Model:            modelID,
		Messages:         []models.UnifiedMessage{},
		Temperature:      req.Temperature,
		TopP:             req.TopP,
		TopK:             req.TopK,
		N:                req.N,
		MaxTokens:        req.MaxTokens,
		PresencePenalty:  req.PresencePenalty,
		FrequencyPenalty: req.FrequencyPenalty,
		Stream:           req.Stream != nil && *req.Stream,
		Tools:            req.Tools,
		ToolChoice:       req.ToolChoice,
	}

	// Inject or cap max_tokens from the model registry. The limit is copied out
	// rather than pointing MaxTokens into the shared registry entry, so nothing
	// downstream can mutate registry metadata through the request.
	var outputLimit uint32
	s.registryMu.RLock()
	if s.registry != nil {
		if meta := s.registry.FindModel(modelID); meta != nil && meta.Limit != nil {
			outputLimit = meta.Limit.Output
		}
	}
	s.registryMu.RUnlock()
	switch {
	case outputLimit > 0 && unifiedReq.MaxTokens == nil:
		unifiedReq.MaxTokens = uint32Ptr(outputLimit)
		log.Printf("[DEBUG] %s: injected registry max_tokens=%d", modelID, outputLimit)
	case outputLimit > 0 && *unifiedReq.MaxTokens > outputLimit:
		log.Printf("[DEBUG] %s: capping client max_tokens (%d) to registry limit (%d)", modelID, *unifiedReq.MaxTokens, outputLimit)
		unifiedReq.MaxTokens = uint32Ptr(outputLimit)
	case outputLimit > 0:
		log.Printf("[DEBUG] %s: using client max_tokens (%d)", modelID, *unifiedReq.MaxTokens)
	case unifiedReq.MaxTokens == nil:
		log.Printf("[DEBUG] %s: no registry limit found, leaving max_tokens nil", modelID)
	default:
		log.Printf("[DEBUG] %s: using client max_tokens (%d), no registry limit to cap", modelID, *unifiedReq.MaxTokens)
	}

	// Stop may be a list of strings or a single string.
	if req.Stop != nil {
		var stop []string
		if err := json.Unmarshal(req.Stop, &stop); err == nil {
			unifiedReq.Stop = stop
		} else {
			var singleStop string
			if err := json.Unmarshal(req.Stop, &singleStop); err == nil {
				unifiedReq.Stop = []string{singleStop}
			}
		}
	}

	// Convert messages; content is either a plain string or a list of parts.
	for _, msg := range req.Messages {
		unifiedMsg := models.UnifiedMessage{
			Role:             msg.Role,
			Content:          []models.UnifiedContentPart{},
			ReasoningContent: msg.ReasoningContent,
			ToolCalls:        msg.ToolCalls,
			Name:             msg.Name,
			ToolCallID:       msg.ToolCallID,
		}
		switch content := msg.Content.(type) {
		case string:
			unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
				Type: "text",
				Text: content,
			})
		case []interface{}:
			for _, part := range content {
				partMap, ok := part.(map[string]interface{})
				if !ok {
					continue
				}
				switch partType, _ := partMap["type"].(string); partType {
				case "text":
					text, _ := partMap["text"].(string)
					unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
						Type: "text",
						Text: text,
					})
				case "image_url":
					imgURLMap, ok := partMap["image_url"].(map[string]interface{})
					if !ok {
						continue
					}
					imageURL, _ := imgURLMap["url"].(string)
					imageInput := &models.ImageInput{}
					if strings.HasPrefix(imageURL, "data:") {
						mime, data, err := utils.ParseDataURL(imageURL)
						if err != nil {
							// Reject rather than forwarding an empty image part upstream.
							c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid image data URL: %v", err)})
							return
						}
						imageInput.Base64 = data
						imageInput.MimeType = mime
					} else {
						imageInput.URL = imageURL
					}
					unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
						Type:  "image",
						Image: imageInput,
					})
					unifiedReq.HasImages = true
				}
			}
		}
		unifiedReq.Messages = append(unifiedReq.Messages, unifiedMsg)
	}

	// Resolve the client once and use the same value for the upstream request
	// and for logging. Previously an "auth" value that was not an AuthInfo left
	// unifiedReq.ClientID empty while logging "default".
	clientID := "default"
	if auth, ok := c.Get("auth"); ok {
		if authInfo, ok := auth.(models.AuthInfo); ok {
			clientID = authInfo.ClientID
		}
	}
	unifiedReq.ClientID = clientID

	if unifiedReq.Stream {
		ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq)
		if err != nil {
			s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, unifiedReq.HasImages)
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")

		var lastUsage *models.Usage
		logged := false
		// logOnce records the request exactly once, with the latest usage seen.
		logOnce := func(streamErr error) {
			if logged {
				return
			}
			logged = true
			s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage, streamErr, unifiedReq.HasImages)
		}
		c.Stream(func(w io.Writer) bool {
			chunk, ok := <-ch
			if !ok {
				fmt.Fprintf(w, "data: [DONE]\n\n")
				logOnce(nil)
				return false
			}
			if chunk.Usage != nil {
				lastUsage = chunk.Usage
			}
			data, err := json.Marshal(chunk)
			if err != nil {
				logOnce(fmt.Errorf("encoding stream chunk: %w", err))
				return false
			}
			fmt.Fprintf(w, "data: %s\n\n", data)
			return true
		})
		// c.Stream also returns when the client goes away. In that case the
		// request has not been logged yet, so record it with the partial usage.
		logOnce(c.Request.Context().Err())
		return
	}

	resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq)
	if err != nil {
		s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, unifiedReq.HasImages)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	s.logRequest(startTime, clientID, providerName, modelID, modelGroup, resp.Usage, nil, unifiedReq.HasImages)
	c.JSON(http.StatusOK, resp)
}
// extractUserMessage returns the text of the most recent message with role
// "user", for use by the model router.
//
// String content is returned as-is. Multipart content (a JSON array of parts)
// yields the "text" of each part of type "text", joined with newlines, so
// multimodal prompts (text plus images) are still routed on their text. It
// returns "" when there is no user message or when that message has no text.
func extractUserMessage(messages []models.ChatMessage) string {
	for i := len(messages) - 1; i >= 0; i-- {
		if messages[i].Role != "user" {
			continue
		}
		switch content := messages[i].Content.(type) {
		case string:
			return content
		case []interface{}:
			var texts []string
			for _, part := range content {
				partMap, ok := part.(map[string]interface{})
				if !ok {
					continue
				}
				if partType, _ := partMap["type"].(string); partType != "text" {
					continue
				}
				if text, ok := partMap["text"].(string); ok && text != "" {
					texts = append(texts, text)
				}
			}
			return strings.Join(texts, "\n")
		default:
			return ""
		}
	}
	return ""
}
// handleImageGenerations serves POST /v1/images/generations.
//
// Steps:
//   - choose the provider from the model name: "imagen"/"gemini" goes to
//     gemini, everything else to openai;
//   - fill in a default model when none was given and strip a known provider
//     prefix;
//   - forward the request.
//
// Usage is recorded with a rough prompt-token estimate (about 4 bytes per
// token, minimum 1) and one "completion token" per returned image. When
// imageGenCost knows the model, its per-image price becomes the logged cost.
// Otherwise logRequest's registry-based pricing applies.
func (s *Server) handleImageGenerations(c *gin.Context) {
	startTime := time.Now()
	var req models.ImageGenerationRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Determine the provider based on the model name.
	providerName := "openai"
	modelLower := strings.ToLower(req.Model)
	switch {
	case strings.Contains(modelLower, "imagen"), strings.Contains(modelLower, "gemini"):
		providerName = "gemini"
	case strings.Contains(modelLower, "dall"), strings.HasPrefix(modelLower, "openai/"):
		providerName = "openai"
	}

	// Default model for each provider if not specified.
	if req.Model == "" {
		if providerName == "openai" {
			req.Model = "dall-e-3"
		} else {
			req.Model = "imagen-3.0-generate-001"
		}
	}

	// Strip common prefixes.
	prefixes := []string{"openai/", "gemini/", "google/"}
	for _, p := range prefixes {
		if strings.HasPrefix(req.Model, p) {
			req.Model = strings.TrimPrefix(req.Model, p)
			break
		}
	}

	provider, ok := s.providers[providerName]
	if !ok {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
		return
	}

	clientID := "default"
	if auth, ok := c.Get("auth"); ok {
		if authInfo, ok := auth.(models.AuthInfo); ok {
			clientID = authInfo.ClientID
		}
	}

	resp, err := provider.ImageGeneration(c.Request.Context(), &req)
	if err != nil {
		s.logRequest(startTime, clientID, providerName, req.Model, "", nil, err, false)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Estimate tokens from prompt text (~4 chars per token).
	promptTokens := uint32(len(req.Prompt) / 4)
	if promptTokens < 1 {
		promptTokens = 1
	}
	images := uint32(len(resp.Data))

	// Image generation is priced per image, not per token. Put that cost on the
	// log entry itself. Patching the row afterwards with
	// "UPDATE ... WHERE id = (SELECT MAX(id) ...)" can overwrite another
	// request's row under concurrency, or run before this entry is persisted.
	if cost := imageGenCost(providerName, req.Model, req.Size, images); cost > 0 {
		s.logger.LogRequest(RequestLog{
			Timestamp:        startTime,
			ClientID:         clientID,
			Provider:         providerName,
			Model:            req.Model,
			Status:           "success",
			DurationMS:       time.Since(startTime).Milliseconds(),
			PromptTokens:     promptTokens,
			CompletionTokens: images,
			TotalTokens:      promptTokens + images,
			Cost:             cost,
		})
	} else {
		s.logRequest(startTime, clientID, providerName, req.Model, "", &models.Usage{
			PromptTokens:     promptTokens,
			CompletionTokens: images,
			TotalTokens:      promptTokens + images,
		}, nil, false)
	}

	c.JSON(http.StatusOK, resp)
}
// imageGenCost returns the total list price (USD) for generating n images with
// model.
//
// Pricing rules:
//   - DALL·E 3 and DALL·E 2 use OpenAI's standard-quality per-image prices, and
//     size (e.g. "1024x1024") selects the price tier. A nil size means the
//     model's default 1024x1024 price.
//   - Imagen uses an approximate flat per-image price.
//
// DALL·E 3 HD pricing is not distinguished because the quality setting is not
// passed in. provider is currently unused. It returns 0 when n is 0 or the
// model is not recognized, so callers can fall back to other pricing.
func imageGenCost(provider, model string, size *string, n uint32) float64 {
	if n == 0 {
		return 0
	}
	sizeStr := ""
	if size != nil {
		sizeStr = *size
	}

	modelLower := strings.ToLower(model)
	var perImage float64
	switch {
	case strings.Contains(modelLower, "dall-e-3"):
		perImage = 0.040 // standard 1024x1024
		if sizeStr == "1024x1792" || sizeStr == "1792x1024" {
			perImage = 0.080
		}
	case strings.Contains(modelLower, "dall-e-2"):
		// DALL·E 2 is priced by output size.
		switch sizeStr {
		case "256x256":
			perImage = 0.016
		case "512x512":
			perImage = 0.018
		default:
			perImage = 0.020 // 1024x1024
		}
	case strings.Contains(modelLower, "imagen"):
		perImage = 0.040 // approximate
	default:
		return 0
	}
	return perImage * float64(n)
}
// logRequest hands one RequestLog entry for a proxied request to s.logger.
//
// start is the request's start time and is also used to compute DurationMS. A
// non-nil err marks the entry as "error" and stores its message. When usage is
// provided, its token counts are copied and the cost is computed from the
// model registry. Pricing falls back to modelGroup when the registry does not
// know the concrete model.
func (s *Server) logRequest(start time.Time, clientID, provider, model, modelGroup string, usage *models.Usage, err error, hasImages bool) {
	entry := RequestLog{
		Timestamp:  start,
		ClientID:   clientID,
		Provider:   provider,
		Model:      model,
		ModelGroup: modelGroup,
		Status:     "success",
		DurationMS: time.Since(start).Milliseconds(),
		HasImages:  hasImages,
	}
	if err != nil {
		entry.Status, entry.ErrorMessage = "error", err.Error()
	}

	if usage != nil {
		entry.PromptTokens, entry.CompletionTokens, entry.TotalTokens =
			usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens
		if n := usage.ReasoningTokens; n != nil {
			entry.ReasoningTokens = *n
		}
		if n := usage.CacheReadTokens; n != nil {
			entry.CacheReadTokens = *n
		}
		if n := usage.CacheWriteTokens; n != nil {
			entry.CacheWriteTokens = *n
		}

		// Price against the concrete model. If the registry does not know it,
		// use the model group instead so group requests are still priced.
		s.registryMu.RLock()
		pricingModel := model
		unknownModel := s.registry != nil && s.registry.FindModel(model) == nil
		if unknownModel && modelGroup != "" {
			pricingModel = modelGroup
		}
		entry.Cost = utils.CalculateCost(s.registry, pricingModel, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
		s.registryMu.RUnlock()
	}

	s.logger.LogRequest(entry)
}
// Run starts the WebSocket hub and the request logger and launches a goroutine
// that refetches the model registry every 24 hours. It then serves HTTP on
// cfg.Server.Host:cfg.Server.Port, blocks in s.router.Run, and returns the
// error that call returns.
func (s *Server) Run() error {
	go s.hub.Run()
	s.logger.Start()

	// Refresh the registry daily. A failed fetch keeps the previous registry;
	// it is logged so stale limits or pricing can be diagnosed. This goroutine
	// lives for the rest of the process.
	go func() {
		ticker := time.NewTicker(24 * time.Hour)
		defer ticker.Stop()
		for range ticker.C {
			newRegistry, err := utils.FetchRegistry()
			if err != nil {
				fmt.Printf("Warning: Failed to refresh model registry: %v\n", err)
				continue
			}
			s.registryMu.Lock()
			s.registry = newRegistry
			s.registryMu.Unlock()
		}
	}()

	addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
	return s.router.Run(addr)
}
// uint32Ptr returns a pointer to a fresh copy of v, for filling optional
// *uint32 fields such as MaxTokens.
func uint32Ptr(v uint32) *uint32 {
	p := new(uint32)
	*p = v
	return p
}
// buildRouteContextFromChat derives routing signals from a chat completion
// request:
//   - the latest user message text (see extractUserMessage);
//   - a rough input-token estimate: 4 bytes of text per token, plus a flat
//     1000 per image_url part;
//   - whether any image part is present and whether tools were supplied;
//   - whether the user message contains one of a few "needs reasoning"
//     phrases.
//
// Tags are then derived from these signals by getRouteCtxTags.
func (s *Server) buildRouteContextFromChat(req models.ChatCompletionRequest) *router.RouteContext {
	const approxImageTokens = 1000

	var (
		tokens     int
		multimodal bool
	)
	for _, msg := range req.Messages {
		switch content := msg.Content.(type) {
		case string:
			tokens += len(content) / 4
		case []interface{}:
			for _, part := range content {
				pm, ok := part.(map[string]interface{})
				if !ok {
					continue
				}
				switch kind, _ := pm["type"].(string); kind {
				case "text":
					text, _ := pm["text"].(string)
					tokens += len(text) / 4
				case "image_url":
					multimodal = true
					tokens += approxImageTokens
				}
			}
		}
	}

	user := extractUserMessage(req.Messages)
	lower := strings.ToLower(user)
	reasoning := false
	for _, cue := range []string{"reason", "think step by step", "mathematics", "architecture", "explain in detail"} {
		if strings.Contains(lower, cue) {
			reasoning = true
			break
		}
	}

	rc := &router.RouteContext{
		UserMessage:         user,
		InputTokens:         tokens,
		HasMultimodalInput:  multimodal,
		RequiresToolCalling: len(req.Tools) > 0,
		RequiresReasoning:   reasoning,
	}
	rc.Tags = s.getRouteCtxTags(rc)
	return rc
}
// buildRouteContextFromResponses derives routing signals from a Responses API
// request.
//
// req.Input is either a bare string or a list of messages. Each message's
// content is a string or a list of typed parts. The signals are:
//   - the latest user text;
//   - a rough token estimate: 4 bytes of text per token, including
//     Instructions, plus a flat 1000 per image part;
//   - whether any image part is present;
//   - whether non-empty tools were supplied;
//   - whether the user text contains one of a few "needs reasoning" phrases.
//
// Tags come from getRouteCtxTags.
func (s *Server) buildRouteContextFromResponses(req models.ResponsesRequest) *router.RouteContext {
	var userMessage string
	hasMultimodal := false
	inputTokens := len(req.Instructions) / 4

	// req.Tools is raw JSON: absent, null and an empty array all mean "no tools".
	toolsJSON := strings.TrimSpace(string(req.Tools))
	requiresToolCalling := toolsJSON != "" && toolsJSON != "null" && toolsJSON != "[]"

	var strInput string
	if err := json.Unmarshal(req.Input, &strInput); err == nil {
		userMessage = strInput
		inputTokens += len(userMessage) / 4
	} else {
		var msgs []models.ResponseInputMessage
		if err := json.Unmarshal(req.Input, &msgs); err == nil {
			for _, m := range msgs {
				var contentStr string
				if err := json.Unmarshal(m.Content, &contentStr); err == nil {
					if m.Role == "user" {
						userMessage = contentStr
					}
					inputTokens += len(contentStr) / 4
					continue
				}
				var parts []models.ContentPart
				if err := json.Unmarshal(m.Content, &parts); err != nil {
					continue
				}
				for _, p := range parts {
					// The Responses API names its parts input_text / input_image.
					// The chat-style text / image_url names are accepted too.
					// Assumes ContentPart.Text decodes the "text" field of
					// input_text parts (TODO confirm).
					switch p.Type {
					case "text", "input_text":
						if m.Role == "user" {
							userMessage = p.Text
						}
						inputTokens += len(p.Text) / 4
					case "image_url", "input_image":
						hasMultimodal = true
						inputTokens += 1000
					}
				}
			}
		}
	}

	msgLower := strings.ToLower(userMessage)
	requiresReasoning := strings.Contains(msgLower, "reason") ||
		strings.Contains(msgLower, "think step by step") ||
		strings.Contains(msgLower, "mathematics") ||
		strings.Contains(msgLower, "architecture") ||
		strings.Contains(msgLower, "explain in detail")

	routeCtx := &router.RouteContext{
		UserMessage:         userMessage,
		InputTokens:         inputTokens,
		HasMultimodalInput:  hasMultimodal,
		RequiresToolCalling: requiresToolCalling,
		RequiresReasoning:   requiresReasoning,
	}
	routeCtx.Tags = s.getRouteCtxTags(routeCtx)
	return routeCtx
}
// getRouteCtxTags derives routing tags for routeCtx. Each rule contributes its
// tags once when it fires. The rules are evaluated in a fixed order:
//   - fast-flow keywords;
//   - standard-pro keywords;
//   - multimodal input;
//   - heavy-logic keywords;
//   - tool calling.
//
// Keyword rules do a case-insensitive substring match against the user
// message. The result is nil when no rule fires.
//
// NOTE(review): substring matching is broad, e.g. "api" also matches
// "capital" and "bug" matches "debug".
func (s *Server) getRouteCtxTags(routeCtx *router.RouteContext) []string {
	text := strings.ToLower(routeCtx.UserMessage)
	mentionsAny := func(keywords ...string) bool {
		for _, kw := range keywords {
			if strings.Contains(text, kw) {
				return true
			}
		}
		return false
	}

	rules := []struct {
		fired bool
		tags  []string
	}{
		{
			fired: mentionsAny(
				"classify", "classification", "label", "tag", "route", "routing", "intent",
				"json", "yaml", "csv", "schema", "parse", "extract", "transform", "format", "regex",
				"short answer", "brief", "concise", "tl;dr", "one line", "simple",
				"fix this", "small bug", "quick fix", "typo", "syntax error",
			),
			tags: []string{"fast-flow", "classification", "json-extraction", "basic-qa"},
		},
		{
			fired: mentionsAny(
				"explain", "summarize", "rewrite", "draft", "edit", "polish", "outline",
				"long doc", "document", "email", "memo", "proposal", "report", "handout", "notes",
				"compare", "choose", "recommend", "tradeoff", "pros and cons", "analysis",
				"code review", "debug", "bug", "feature", "api", "endpoint", "implement",
				"plan", "planning", "workflow", "integration",
			),
			tags: []string{"standard-pro", "long-doc"},
		},
		{
			fired: routeCtx.HasMultimodalInput,
			tags:  []string{"video-analysis", "multimodal-qa"},
		},
		{
			fired: mentionsAny(
				"agent", "agents", "tool use", "function calling", "multi-agent", "orchestrate",
				"system design", "scaling", "performance", "architecture review", "distributed",
				"hard bug", "race condition", "deadlock", "memory leak", "crash", "production outage",
				"long context", "large codebase", "many files", "complex refactor", "migration",
				"research", "deep dive", "literature", "paper", "scholarly", "thorough analysis",
				"deep reasoning", "think step by step", "reason through", "careful analysis",
			),
			tags: []string{"heavy-logic", "deep-reasoning", "architecture", "hard-debugging"},
		},
		{
			fired: routeCtx.RequiresToolCalling,
			tags:  []string{"tool-heavy", "multi-step-agent", "swe-bench"},
		},
	}

	var tags []string
	for _, rule := range rules {
		if rule.fired {
			tags = append(tags, rule.tags...)
		}
	}
	return tags
}