73a82e6175
Upgrades the routing engine to support tag, token limit, multimodal, reasoning, and tool calling conditions. Adds unit tests for the new routing features.
1105 lines
33 KiB
Go
1105 lines
33 KiB
Go
package server
|
|
|
|
import (
|
|
"encoding/json"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"gophergate/internal/config"
|
|
"gophergate/internal/db"
|
|
"gophergate/internal/middleware"
|
|
"gophergate/internal/models"
|
|
"gophergate/internal/providers"
|
|
"gophergate/internal/router"
|
|
"gophergate/internal/utils"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
type Server struct {
|
|
router *gin.Engine
|
|
cfg *config.Config
|
|
database *db.DB
|
|
providers map[string]providers.Provider
|
|
sessions *SessionManager
|
|
hub *Hub
|
|
logger *RequestLogger
|
|
registry *models.ModelRegistry
|
|
registryMu sync.RWMutex
|
|
modelRouter *router.Router
|
|
}
|
|
|
|
func NewServer(cfg *config.Config, database *db.DB) *Server {
|
|
router := gin.Default()
|
|
hub := NewHub()
|
|
|
|
s := &Server{
|
|
router: router,
|
|
cfg: cfg,
|
|
database: database,
|
|
providers: make(map[string]providers.Provider),
|
|
sessions: NewSessionManager(cfg.KeyBytes, 24*time.Hour),
|
|
hub: hub,
|
|
logger: NewRequestLogger(database, hub),
|
|
registry: &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)},
|
|
}
|
|
|
|
s.sessions.StartCleanup()
|
|
// Fetch registry in background
|
|
go func() {
|
|
registry, err := utils.FetchRegistry()
|
|
if err != nil {
|
|
fmt.Printf("Warning: Failed to fetch initial model registry: %v\n", err)
|
|
} else {
|
|
s.registry = registry
|
|
}
|
|
}()
|
|
|
|
// Initialize providers from DB and Config
|
|
if err := s.RefreshProviders(); err != nil {
|
|
fmt.Printf("Warning: Failed to initial refresh providers: %v\n", err)
|
|
}
|
|
|
|
s.setupRoutes()
|
|
|
|
// Initialize model group router
|
|
s.refreshRouter()
|
|
return s
|
|
}
|
|
|
|
func (s *Server) RefreshProviders() error {
|
|
var dbConfigs []db.ProviderConfig
|
|
err := s.database.Select(&dbConfigs, "SELECT * FROM provider_configs")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to fetch provider configs from db: %w", err)
|
|
}
|
|
|
|
dbMap := make(map[string]db.ProviderConfig)
|
|
for _, cfg := range dbConfigs {
|
|
dbMap[cfg.ID] = cfg
|
|
}
|
|
|
|
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama", "xiaomi"}
|
|
for _, id := range providerIDs {
|
|
// Default values from config
|
|
enabled := false
|
|
baseURL := ""
|
|
apiKey := ""
|
|
|
|
switch id {
|
|
case "openai":
|
|
enabled = s.cfg.Providers.OpenAI.Enabled
|
|
baseURL = s.cfg.Providers.OpenAI.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("openai")
|
|
case "gemini":
|
|
enabled = s.cfg.Providers.Gemini.Enabled
|
|
baseURL = s.cfg.Providers.Gemini.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("gemini")
|
|
case "deepseek":
|
|
enabled = s.cfg.Providers.DeepSeek.Enabled
|
|
baseURL = s.cfg.Providers.DeepSeek.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("deepseek")
|
|
case "moonshot":
|
|
enabled = s.cfg.Providers.Moonshot.Enabled
|
|
baseURL = s.cfg.Providers.Moonshot.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("moonshot")
|
|
case "grok":
|
|
enabled = s.cfg.Providers.Grok.Enabled
|
|
baseURL = s.cfg.Providers.Grok.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("grok")
|
|
case "xiaomi":
|
|
enabled = s.cfg.Providers.Xiaomi.Enabled
|
|
baseURL = s.cfg.Providers.Xiaomi.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("xiaomi")
|
|
}
|
|
|
|
// Overrides from DB
|
|
if dbCfg, ok := dbMap[id]; ok {
|
|
enabled = dbCfg.Enabled
|
|
if dbCfg.BaseURL != nil && *dbCfg.BaseURL != "" {
|
|
baseURL = *dbCfg.BaseURL
|
|
}
|
|
if dbCfg.APIKey != nil && *dbCfg.APIKey != "" {
|
|
key := *dbCfg.APIKey
|
|
if dbCfg.APIKeyEncrypted {
|
|
decrypted, err := utils.Decrypt(key, s.cfg.KeyBytes)
|
|
if err == nil {
|
|
key = decrypted
|
|
} else {
|
|
fmt.Printf("Warning: Failed to decrypt API key for %s: %v\n", id, err)
|
|
}
|
|
}
|
|
apiKey = key
|
|
}
|
|
}
|
|
|
|
if !enabled {
|
|
delete(s.providers, id)
|
|
continue
|
|
}
|
|
|
|
// Initialize provider
|
|
var p providers.Provider
|
|
switch id {
|
|
case "openai":
|
|
cfg := s.cfg.Providers.OpenAI
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewOpenAIProvider(cfg, apiKey)
|
|
case "gemini":
|
|
cfg := s.cfg.Providers.Gemini
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewGeminiProvider(cfg, apiKey)
|
|
case "deepseek":
|
|
cfg := s.cfg.Providers.DeepSeek
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewDeepSeekProvider(cfg, apiKey)
|
|
case "moonshot":
|
|
cfg := s.cfg.Providers.Moonshot
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewMoonshotProvider(cfg, apiKey)
|
|
case "grok":
|
|
cfg := s.cfg.Providers.Grok
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewGrokProvider(cfg, apiKey)
|
|
case "ollama":
|
|
cfg := s.cfg.Providers.Ollama
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewOllamaProvider(cfg)
|
|
case "xiaomi":
|
|
cfg := s.cfg.Providers.Xiaomi
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewXiaomiProvider(cfg, apiKey)
|
|
}
|
|
|
|
if p != nil {
|
|
s.providers[id] = providers.NewCircuitBreakerProvider(p)
|
|
}
|
|
}
|
|
|
|
s.refreshRouter()
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) refreshRouter() {
|
|
var groups []db.ModelGroup
|
|
if err := s.database.Select(&groups, "SELECT * FROM model_groups"); err != nil {
|
|
fmt.Printf("Warning: Failed to load model groups: %v\n", err)
|
|
groups = nil
|
|
}
|
|
|
|
var classifyFn router.ClassifierFunc
|
|
classifyFn = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) {
|
|
provider, _, err := s.selectProvider(selectorModel)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
req := &models.UnifiedRequest{
|
|
Model: selectorModel,
|
|
Messages: []models.UnifiedMessage{
|
|
{Role: "system", Content: []models.UnifiedContentPart{{Type: "text", Text: systemPrompt}}},
|
|
{Role: "user", Content: []models.UnifiedContentPart{{Type: "text", Text: userMessage}}},
|
|
},
|
|
MaxTokens: uint32Ptr(5),
|
|
Stream: false,
|
|
}
|
|
resp, err := provider.ChatCompletion(ctx, req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if len(resp.Choices) == 0 {
|
|
return "", fmt.Errorf("no choices in classifier response")
|
|
}
|
|
content, ok := resp.Choices[0].Message.Content.(string)
|
|
if !ok {
|
|
return "", fmt.Errorf("classifier response content is not a string")
|
|
}
|
|
return content, nil
|
|
}
|
|
|
|
if s.modelRouter == nil {
|
|
s.modelRouter = router.New(groups, classifyFn)
|
|
} else {
|
|
s.modelRouter.Reload(groups)
|
|
}
|
|
}
|
|
|
|
func (s *Server) setupRoutes() {
|
|
// Static files
|
|
s.router.StaticFile("/", "./static/index.html")
|
|
s.router.StaticFile("/favicon.ico", "./static/favicon.ico")
|
|
s.router.Static("/css", "./static/css")
|
|
s.router.Static("/js", "./static/js")
|
|
s.router.Static("/img", "./static/img")
|
|
|
|
// WebSocket
|
|
s.router.GET("/ws", s.handleWebSocket)
|
|
|
|
// API V1 (External LLM Access) - Secured with AuthMiddleware
|
|
v1 := s.router.Group("/v1")
|
|
v1.Use(middleware.AuthMiddleware(s.database, true))
|
|
{
|
|
v1.POST("/chat/completions", s.handleChatCompletions)
|
|
v1.POST("/images/generations", s.handleImageGenerations)
|
|
v1.GET("/models", s.handleListModels)
|
|
v1.POST("/responses", s.handleResponses)
|
|
}
|
|
|
|
// Dashboard API Group
|
|
api := s.router.Group("/api")
|
|
{
|
|
api.POST("/auth/login", s.handleLogin)
|
|
api.GET("/auth/status", s.handleAuthStatus)
|
|
api.POST("/auth/logout", s.handleLogout)
|
|
api.POST("/auth/change-password", s.handleChangePassword)
|
|
|
|
// Protected dashboard routes (need admin session)
|
|
admin := api.Group("/")
|
|
admin.Use(s.adminAuthMiddleware())
|
|
{
|
|
admin.GET("/usage/summary", s.handleUsageSummary)
|
|
admin.GET("/usage/time-series", s.handleTimeSeries)
|
|
admin.GET("/usage/providers", s.handleProvidersUsage)
|
|
admin.GET("/usage/clients", s.handleClientsUsage)
|
|
admin.GET("/usage/detailed", s.handleDetailedUsage)
|
|
admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)
|
|
|
|
admin.GET("/clients", s.handleGetClients)
|
|
admin.POST("/clients", s.handleCreateClient)
|
|
admin.GET("/clients/:id", s.handleGetClient)
|
|
admin.PUT("/clients/:id", s.handleUpdateClient)
|
|
admin.DELETE("/clients/:id", s.handleDeleteClient)
|
|
|
|
admin.GET("/clients/:id/tokens", s.handleGetClientTokens)
|
|
admin.POST("/clients/:id/tokens", s.handleCreateClientToken)
|
|
admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)
|
|
|
|
admin.GET("/providers", s.handleGetProviders)
|
|
admin.PUT("/providers/:name", s.handleUpdateProvider)
|
|
admin.POST("/providers/:name/test", s.handleTestProvider)
|
|
|
|
admin.GET("/models", s.handleGetModels)
|
|
admin.PUT("/models/:id", s.handleUpdateModel)
|
|
|
|
admin.GET("/model-groups", s.handleGetModelGroups)
|
|
admin.POST("/model-groups", s.handleCreateModelGroup)
|
|
admin.PUT("/model-groups/:id", s.handleUpdateModelGroup)
|
|
admin.DELETE("/model-groups/:id", s.handleDeleteModelGroup)
|
|
|
|
admin.GET("/users", s.handleGetUsers)
|
|
admin.POST("/users", s.handleCreateUser)
|
|
admin.PUT("/users/:id", s.handleUpdateUser)
|
|
admin.DELETE("/users/:id", s.handleDeleteUser)
|
|
|
|
admin.GET("/system/health", s.handleSystemHealth)
|
|
admin.GET("/system/metrics", s.handleSystemMetrics)
|
|
admin.GET("/system/settings", s.handleGetSettings)
|
|
admin.POST("/system/backup", s.handleCreateBackup)
|
|
admin.GET("/system/logs", s.handleGetLogs)
|
|
}
|
|
}
|
|
|
|
s.router.GET("/health", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
|
})
|
|
}
|
|
|
|
func (s *Server) handleResponses(c *gin.Context) {
|
|
startTime := time.Now()
|
|
var req models.ResponsesRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Strip common prefixes and resolve model groups to concrete models
|
|
// (same pattern as handleChatCompletions).
|
|
modelGroup := ""
|
|
modelID := req.Model
|
|
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/", "xiaomi/"}
|
|
for _, p := range prefixes {
|
|
if strings.HasPrefix(modelID, p) {
|
|
modelID = strings.TrimPrefix(modelID, p)
|
|
break
|
|
}
|
|
}
|
|
if s.modelRouter != nil {
|
|
routeCtx := s.buildRouteContextFromResponses(req)
|
|
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
|
|
return
|
|
}
|
|
if decision.SelectedModel != modelID {
|
|
modelGroup = modelID
|
|
}
|
|
modelID = decision.SelectedModel
|
|
}
|
|
|
|
// Select provider based on resolved model name
|
|
providerName := "openai" // default for Responses API
|
|
modelLower := strings.ToLower(modelID)
|
|
if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
|
|
providerName = "gemini"
|
|
} else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) {
|
|
providerName = "deepseek"
|
|
} else if strings.HasPrefix(modelLower, "moonshot/") || strings.Contains(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
|
|
providerName = "moonshot"
|
|
} else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") {
|
|
providerName = "grok"
|
|
} else if strings.HasPrefix(modelLower, "ollama/") ||
|
|
strings.Contains(modelLower, "glm-") ||
|
|
strings.Contains(modelLower, "qwen") ||
|
|
strings.Contains(modelLower, "gemma") ||
|
|
strings.Contains(modelLower, "llama") ||
|
|
strings.Contains(modelLower, "mistral") ||
|
|
strings.Contains(modelLower, "phi") ||
|
|
strings.Contains(modelLower, "yi") ||
|
|
strings.Contains(modelLower, "codellama") ||
|
|
strings.Contains(modelLower, "command-r") {
|
|
providerName = "ollama"
|
|
}
|
|
|
|
provider, ok := s.providers[providerName]
|
|
if !ok {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
|
|
return
|
|
}
|
|
|
|
// Use resolved model for the actual API call
|
|
req.Model = modelID
|
|
|
|
clientID := "default"
|
|
if auth, ok := c.Get("auth"); ok {
|
|
if authInfo, ok := auth.(models.AuthInfo); ok {
|
|
clientID = authInfo.ClientID
|
|
}
|
|
}
|
|
|
|
stream := req.Stream != nil && *req.Stream
|
|
|
|
if stream {
|
|
ch, err := provider.ResponsesStream(c.Request.Context(), &req)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
c.Header("Content-Type", "text/event-stream")
|
|
c.Header("Cache-Control", "no-cache")
|
|
c.Header("Connection", "keep-alive")
|
|
|
|
var lastUsage *models.ResponsesUsage
|
|
c.Stream(func(w io.Writer) bool {
|
|
chunk, ok := <-ch
|
|
if !ok {
|
|
fmt.Fprintf(w, "data: [DONE]\n\n")
|
|
if lastUsage != nil {
|
|
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage.ToUsage(), nil, false)
|
|
} else {
|
|
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, nil, false)
|
|
}
|
|
return false
|
|
}
|
|
// Capture usage from the response payload in streaming chunks
|
|
if chunk.Response != nil && chunk.Response.Usage != nil {
|
|
lastUsage = chunk.Response.Usage
|
|
}
|
|
data, err := json.Marshal(chunk)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
return true
|
|
})
|
|
return
|
|
}
|
|
|
|
resp, err := provider.Responses(c.Request.Context(), &req)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
if resp.Usage != nil {
|
|
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, resp.Usage.ToUsage(), nil, false)
|
|
} else {
|
|
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, nil, false)
|
|
}
|
|
c.JSON(http.StatusOK, resp)
|
|
}
|
|
|
|
func (s *Server) handleListModels(c *gin.Context) {
|
|
type OpenAIModel struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
OwnedBy string `json:"owned_by"`
|
|
}
|
|
|
|
modelMap := make(map[string]OpenAIModel)
|
|
allowedProviders := map[string]bool{
|
|
"openai": true,
|
|
"google": true, // Models from models.dev use 'google' ID for Gemini
|
|
"deepseek": true,
|
|
"moonshot": true,
|
|
"moonshotai": true, // Official moonshotai ID in models.dev
|
|
"moonshotai-cn": true, // Official moonshotai-cn ID in models.dev
|
|
"xai": true, // Models from models.dev use 'xai' ID for Grok
|
|
"llmgateway": true, // Catch-all for newer models
|
|
"ollama": true,
|
|
"xiaomi": true, // Xiaomi MiMo models
|
|
}
|
|
|
|
s.registryMu.RLock()
|
|
if s.registry != nil {
|
|
for pID, pInfo := range s.registry.Providers {
|
|
if !allowedProviders[pID] {
|
|
continue
|
|
}
|
|
for mID := range pInfo.Models {
|
|
if _, exists := modelMap[mID]; !exists {
|
|
modelMap[mID] = OpenAIModel{
|
|
ID: mID,
|
|
Object: "model",
|
|
Created: 1700000000,
|
|
OwnedBy: pID,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
s.registryMu.RUnlock()
|
|
|
|
// Add configured Ollama models
|
|
if s.cfg.Providers.Ollama.Enabled {
|
|
for _, mID := range s.cfg.Providers.Ollama.Models {
|
|
if _, exists := modelMap[mID]; !exists {
|
|
modelMap[mID] = OpenAIModel{
|
|
ID: mID,
|
|
Object: "model",
|
|
Created: 1700000000,
|
|
OwnedBy: "ollama",
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add model groups so clients can discover them
|
|
if s.modelRouter != nil {
|
|
for _, gid := range s.modelRouter.Groups() {
|
|
if _, exists := modelMap[gid]; !exists {
|
|
modelMap[gid] = OpenAIModel{
|
|
ID: gid,
|
|
Object: "model",
|
|
Created: 1700000000,
|
|
OwnedBy: "gophergate",
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
var data []OpenAIModel
|
|
for _, m := range modelMap {
|
|
data = append(data, m)
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"object": "list",
|
|
"data": data,
|
|
})
|
|
}
|
|
|
|
func (s *Server) selectProvider(modelID string) (providers.Provider, string, error) {
|
|
providerName := "openai" // default
|
|
modelLower := strings.ToLower(modelID)
|
|
if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
|
|
providerName = "gemini"
|
|
} else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) {
|
|
providerName = "deepseek"
|
|
} else if strings.HasPrefix(modelLower, "moonshot/") || strings.Contains(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
|
|
providerName = "moonshot"
|
|
} else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") {
|
|
providerName = "grok"
|
|
} else if strings.HasPrefix(modelLower, "ollama/") ||
|
|
strings.Contains(modelLower, "glm-") ||
|
|
strings.Contains(modelLower, "qwen") ||
|
|
strings.Contains(modelLower, "gemma") ||
|
|
strings.Contains(modelLower, "llama") ||
|
|
strings.Contains(modelLower, "mistral") ||
|
|
strings.Contains(modelLower, "phi") ||
|
|
strings.Contains(modelLower, "yi") ||
|
|
strings.Contains(modelLower, "codellama") ||
|
|
strings.Contains(modelLower, "command-r") {
|
|
providerName = "ollama"
|
|
} else if strings.HasPrefix(modelLower, "xiaomi/") || strings.Contains(modelLower, "mimo") || strings.Contains(modelLower, "xiaomi") {
|
|
providerName = "xiaomi"
|
|
}
|
|
|
|
p, ok := s.providers[providerName]
|
|
if !ok {
|
|
return nil, "", fmt.Errorf("Provider %s not enabled or supported", providerName)
|
|
}
|
|
return p, providerName, nil
|
|
}
|
|
|
|
func (s *Server) handleChatCompletions(c *gin.Context) {
|
|
startTime := time.Now()
|
|
var req models.ChatCompletionRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Strip common prefixes and prepare model ID
|
|
modelID := req.Model
|
|
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/", "xiaomi/"}
|
|
for _, p := range prefixes {
|
|
if strings.HasPrefix(modelID, p) {
|
|
modelID = strings.TrimPrefix(modelID, p)
|
|
break
|
|
}
|
|
}
|
|
|
|
// Resolve model groups to concrete models (hierarchical — groups can target groups)
|
|
modelGroup := ""
|
|
for i, m := range req.Messages {
|
|
log.Printf("[DEBUG] Incoming Msg[%d]: role=%s, hasToolCalls=%v, hasContent=%v", i, m.Role, len(m.ToolCalls) > 0, m.Content != nil)
|
|
}
|
|
if s.modelRouter != nil {
|
|
routeCtx := s.buildRouteContextFromChat(req)
|
|
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
|
|
return
|
|
}
|
|
if decision.SelectedModel != modelID {
|
|
modelGroup = modelID
|
|
}
|
|
modelID = decision.SelectedModel
|
|
log.Printf("[ROUTER] %s (%s: %s)", modelID, decision.Strategy, decision.Reason)
|
|
}
|
|
|
|
// Select provider based on the resolved model name
|
|
provider, providerName, err := s.selectProvider(modelID)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Convert ChatCompletionRequest to UnifiedRequest
|
|
unifiedReq := &models.UnifiedRequest{
|
|
Model: modelID,
|
|
Messages: []models.UnifiedMessage{},
|
|
Temperature: req.Temperature,
|
|
TopP: req.TopP,
|
|
TopK: req.TopK,
|
|
N: req.N,
|
|
MaxTokens: req.MaxTokens,
|
|
PresencePenalty: req.PresencePenalty,
|
|
FrequencyPenalty: req.FrequencyPenalty,
|
|
Stream: req.Stream != nil && *req.Stream,
|
|
Tools: req.Tools,
|
|
ToolChoice: req.ToolChoice,
|
|
}
|
|
|
|
// Inject or cap max_tokens from model registry.
|
|
s.registryMu.RLock()
|
|
meta := s.registry.FindModel(modelID)
|
|
s.registryMu.RUnlock()
|
|
|
|
if meta != nil && meta.Limit != nil && meta.Limit.Output > 0 {
|
|
if unifiedReq.MaxTokens == nil {
|
|
unifiedReq.MaxTokens = &meta.Limit.Output
|
|
log.Printf("[DEBUG] %s: injected registry max_tokens=%d", modelID, meta.Limit.Output)
|
|
} else if *unifiedReq.MaxTokens > meta.Limit.Output {
|
|
log.Printf("[DEBUG] %s: capping client max_tokens (%d) to registry limit (%d)", modelID, *unifiedReq.MaxTokens, meta.Limit.Output)
|
|
unifiedReq.MaxTokens = &meta.Limit.Output
|
|
} else {
|
|
log.Printf("[DEBUG] %s: using client max_tokens (%d)", modelID, *unifiedReq.MaxTokens)
|
|
}
|
|
} else {
|
|
if unifiedReq.MaxTokens == nil {
|
|
log.Printf("[DEBUG] %s: no registry limit found, leaving max_tokens nil", modelID)
|
|
} else {
|
|
log.Printf("[DEBUG] %s: using client max_tokens (%d), no registry limit to cap", modelID, *unifiedReq.MaxTokens)
|
|
}
|
|
}
|
|
|
|
// Handle Stop sequences
|
|
if req.Stop != nil {
|
|
var stop []string
|
|
if err := json.Unmarshal(req.Stop, &stop); err == nil {
|
|
unifiedReq.Stop = stop
|
|
} else {
|
|
var singleStop string
|
|
if err := json.Unmarshal(req.Stop, &singleStop); err == nil {
|
|
unifiedReq.Stop = []string{singleStop}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Convert messages
|
|
for _, msg := range req.Messages {
|
|
unifiedMsg := models.UnifiedMessage{
|
|
Role: msg.Role,
|
|
Content: []models.UnifiedContentPart{},
|
|
ReasoningContent: msg.ReasoningContent,
|
|
ToolCalls: msg.ToolCalls,
|
|
Name: msg.Name,
|
|
ToolCallID: msg.ToolCallID,
|
|
}
|
|
|
|
// Handle multimodal content
|
|
if strContent, ok := msg.Content.(string); ok {
|
|
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
|
Type: "text",
|
|
Text: strContent,
|
|
})
|
|
} else if parts, ok := msg.Content.([]interface{}); ok {
|
|
for _, part := range parts {
|
|
if partMap, ok := part.(map[string]interface{}); ok {
|
|
partType, _ := partMap["type"].(string)
|
|
if partType == "text" {
|
|
text, _ := partMap["text"].(string)
|
|
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
|
Type: "text",
|
|
Text: text,
|
|
})
|
|
} else if partType == "image_url" {
|
|
if imgURLMap, ok := partMap["image_url"].(map[string]interface{}); ok {
|
|
url, _ := imgURLMap["url"].(string)
|
|
imageInput := &models.ImageInput{}
|
|
if strings.HasPrefix(url, "data:") {
|
|
mime, data, err := utils.ParseDataURL(url)
|
|
if err == nil {
|
|
imageInput.Base64 = data
|
|
imageInput.MimeType = mime
|
|
}
|
|
} else {
|
|
imageInput.URL = url
|
|
}
|
|
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
|
Type: "image",
|
|
Image: imageInput,
|
|
})
|
|
unifiedReq.HasImages = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
unifiedReq.Messages = append(unifiedReq.Messages, unifiedMsg)
|
|
}
|
|
|
|
clientID := "default"
|
|
if auth, ok := c.Get("auth"); ok {
|
|
if authInfo, ok := auth.(models.AuthInfo); ok {
|
|
unifiedReq.ClientID = authInfo.ClientID
|
|
clientID = authInfo.ClientID
|
|
}
|
|
} else {
|
|
unifiedReq.ClientID = clientID
|
|
}
|
|
|
|
if unifiedReq.Stream {
|
|
ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, unifiedReq.HasImages)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
c.Header("Content-Type", "text/event-stream")
|
|
c.Header("Cache-Control", "no-cache")
|
|
c.Header("Connection", "keep-alive")
|
|
|
|
var lastUsage *models.Usage
|
|
c.Stream(func(w io.Writer) bool {
|
|
chunk, ok := <-ch
|
|
if !ok {
|
|
fmt.Fprintf(w, "data: [DONE]\n\n")
|
|
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage, nil, unifiedReq.HasImages)
|
|
return false
|
|
}
|
|
if chunk.Usage != nil {
|
|
lastUsage = chunk.Usage
|
|
}
|
|
data, err := json.Marshal(chunk)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
return true
|
|
})
|
|
return
|
|
}
|
|
|
|
resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, unifiedReq.HasImages)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, resp.Usage, nil, unifiedReq.HasImages)
|
|
c.JSON(http.StatusOK, resp)
|
|
}
|
|
|
|
func extractUserMessage(messages []models.ChatMessage) string {
|
|
for i := len(messages) - 1; i >= 0; i-- {
|
|
if messages[i].Role == "user" {
|
|
switch c := messages[i].Content.(type) {
|
|
case string:
|
|
return c
|
|
default:
|
|
return ""
|
|
}
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (s *Server) handleImageGenerations(c *gin.Context) {
|
|
startTime := time.Now()
|
|
var req models.ImageGenerationRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Determine provider based on model name
|
|
providerName := "openai"
|
|
modelLower := strings.ToLower(req.Model)
|
|
switch {
|
|
case strings.Contains(modelLower, "imagen"), strings.Contains(modelLower, "gemini"):
|
|
providerName = "gemini"
|
|
case strings.Contains(modelLower, "dall"), strings.HasPrefix(modelLower, "openai/"):
|
|
providerName = "openai"
|
|
}
|
|
|
|
// Default model for each provider if not specified
|
|
if req.Model == "" {
|
|
if providerName == "openai" {
|
|
req.Model = "dall-e-3"
|
|
} else {
|
|
req.Model = "imagen-3.0-generate-001"
|
|
}
|
|
}
|
|
|
|
// Strip common prefixes
|
|
prefixes := []string{"openai/", "gemini/", "google/"}
|
|
for _, p := range prefixes {
|
|
if strings.HasPrefix(req.Model, p) {
|
|
req.Model = strings.TrimPrefix(req.Model, p)
|
|
break
|
|
}
|
|
}
|
|
|
|
provider, ok := s.providers[providerName]
|
|
if !ok {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
|
|
return
|
|
}
|
|
|
|
clientID := "default"
|
|
if auth, ok := c.Get("auth"); ok {
|
|
if authInfo, ok := auth.(models.AuthInfo); ok {
|
|
clientID = authInfo.ClientID
|
|
}
|
|
}
|
|
|
|
resp, err := provider.ImageGeneration(c.Request.Context(), &req)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, "", nil, err, false)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Estimate tokens from prompt text (~4 chars per token)
|
|
promptTokens := uint32(len(req.Prompt) / 4)
|
|
if promptTokens < 1 {
|
|
promptTokens = 1
|
|
}
|
|
|
|
// Calculate per-image cost (not per-token like chat)
|
|
cost := imageGenCost(providerName, req.Model, req.Size, uint32(len(resp.Data)))
|
|
|
|
s.logRequest(startTime, clientID, providerName, req.Model, "", &models.Usage{
|
|
PromptTokens: promptTokens,
|
|
CompletionTokens: uint32(len(resp.Data)),
|
|
TotalTokens: promptTokens + uint32(len(resp.Data)),
|
|
}, nil, false)
|
|
|
|
// Update cost in DB — image gen is per-image, not per-token
|
|
if cost > 0 {
|
|
s.database.Exec("UPDATE llm_requests SET cost = ? WHERE id = (SELECT MAX(id) FROM llm_requests)", cost)
|
|
}
|
|
|
|
c.JSON(http.StatusOK, resp)
|
|
}
|
|
|
|
// imageGenCost prices an image-generation call as (price per image) × n.
//
// The per-image price is picked from the lower-cased model name: DALL·E 3
// costs 0.040, or 0.080 for the two non-square sizes; DALL·E 2 costs 0.020;
// Imagen is approximated at 0.040. Unknown models, and n == 0, cost 0.
// The provider argument does not influence the result.
func imageGenCost(provider, model string, size *string, n uint32) float64 {
	if n == 0 {
		return 0
	}

	name := strings.ToLower(model)
	unitPrice := 0.0

	if strings.Contains(name, "dall-e-3") {
		unitPrice = 0.040 // standard 1024x1024
		if size != nil && (*size == "1024x1792" || *size == "1792x1024") {
			unitPrice = 0.080
		}
	} else if strings.Contains(name, "dall-e-2") {
		unitPrice = 0.020
	} else if strings.Contains(name, "imagen") {
		unitPrice = 0.040 // approximate
	} else {
		return 0
	}

	return unitPrice * float64(n)
}
|
|
|
|
func (s *Server) logRequest(start time.Time, clientID, provider, model, modelGroup string, usage *models.Usage, err error, hasImages bool) {
|
|
entry := RequestLog{
|
|
Timestamp: start,
|
|
ClientID: clientID,
|
|
Provider: provider,
|
|
Model: model,
|
|
ModelGroup: modelGroup,
|
|
Status: "success",
|
|
DurationMS: time.Since(start).Milliseconds(),
|
|
HasImages: hasImages,
|
|
}
|
|
|
|
if err != nil {
|
|
entry.Status = "error"
|
|
entry.ErrorMessage = err.Error()
|
|
}
|
|
|
|
if usage != nil {
|
|
entry.PromptTokens = usage.PromptTokens
|
|
entry.CompletionTokens = usage.CompletionTokens
|
|
entry.TotalTokens = usage.TotalTokens
|
|
if usage.ReasoningTokens != nil {
|
|
entry.ReasoningTokens = *usage.ReasoningTokens
|
|
}
|
|
if usage.CacheReadTokens != nil {
|
|
entry.CacheReadTokens = *usage.CacheReadTokens
|
|
}
|
|
if usage.CacheWriteTokens != nil {
|
|
entry.CacheWriteTokens = *usage.CacheWriteTokens
|
|
}
|
|
|
|
// Calculate cost using registry; if the resolved model is unknown,
|
|
// fall back to the model group so group requests still get priced.
|
|
s.registryMu.RLock()
|
|
pricingModel := model
|
|
if s.registry != nil && s.registry.FindModel(pricingModel) == nil && modelGroup != "" {
|
|
pricingModel = modelGroup
|
|
}
|
|
entry.Cost = utils.CalculateCost(s.registry, pricingModel, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
|
|
s.registryMu.RUnlock()
|
|
}
|
|
|
|
s.logger.LogRequest(entry)
|
|
}
|
|
|
|
func (s *Server) Run() error {
|
|
go s.hub.Run()
|
|
s.logger.Start()
|
|
|
|
// Start registry refresher
|
|
go func() {
|
|
ticker := time.NewTicker(24 * time.Hour)
|
|
for range ticker.C {
|
|
newRegistry, err := utils.FetchRegistry()
|
|
if err == nil {
|
|
s.registryMu.Lock()
|
|
s.registry = newRegistry
|
|
s.registryMu.Unlock()
|
|
}
|
|
}
|
|
}()
|
|
|
|
addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
|
|
return s.router.Run(addr)
|
|
}
|
|
|
|
// uint32Ptr returns a pointer to a newly allocated uint32 holding v, which is
// handy for filling optional *uint32 fields from a literal.
func uint32Ptr(v uint32) *uint32 {
	p := new(uint32)
	*p = v
	return p
}
|
|
|
|
func (s *Server) buildRouteContextFromChat(req models.ChatCompletionRequest) *router.RouteContext {
|
|
userMessage := extractUserMessage(req.Messages)
|
|
requiresToolCalling := len(req.Tools) > 0
|
|
hasMultimodal := false
|
|
inputTokens := 0
|
|
|
|
for _, msg := range req.Messages {
|
|
if strContent, ok := msg.Content.(string); ok {
|
|
inputTokens += len(strContent) / 4
|
|
} else if parts, ok := msg.Content.([]interface{}); ok {
|
|
for _, part := range parts {
|
|
if partMap, ok := part.(map[string]interface{}); ok {
|
|
partType, _ := partMap["type"].(string)
|
|
if partType == "text" {
|
|
text, _ := partMap["text"].(string)
|
|
inputTokens += len(text) / 4
|
|
} else if partType == "image_url" {
|
|
hasMultimodal = true
|
|
inputTokens += 1000 // Approximate cost of an image in tokens
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
msgLower := strings.ToLower(userMessage)
|
|
requiresReasoning := strings.Contains(msgLower, "reason") ||
|
|
strings.Contains(msgLower, "think step by step") ||
|
|
strings.Contains(msgLower, "mathematics") ||
|
|
strings.Contains(msgLower, "architecture") ||
|
|
strings.Contains(msgLower, "explain in detail")
|
|
|
|
routeCtx := &router.RouteContext{
|
|
UserMessage: userMessage,
|
|
InputTokens: inputTokens,
|
|
HasMultimodalInput: hasMultimodal,
|
|
RequiresToolCalling: requiresToolCalling,
|
|
RequiresReasoning: requiresReasoning,
|
|
}
|
|
routeCtx.Tags = s.getRouteCtxTags(routeCtx)
|
|
return routeCtx
|
|
}
|
|
|
|
func (s *Server) buildRouteContextFromResponses(req models.ResponsesRequest) *router.RouteContext {
|
|
var userMessage string
|
|
hasMultimodal := false
|
|
inputTokens := len(req.Instructions) / 4
|
|
requiresToolCalling := len(req.Tools) > 0 && string(req.Tools) != "null" && string(req.Tools) != ""
|
|
|
|
var strInput string
|
|
if err := json.Unmarshal(req.Input, &strInput); err == nil {
|
|
userMessage = strInput
|
|
inputTokens += len(userMessage) / 4
|
|
} else {
|
|
var msgs []models.ResponseInputMessage
|
|
if err := json.Unmarshal(req.Input, &msgs); err == nil {
|
|
for _, m := range msgs {
|
|
var contentStr string
|
|
if err := json.Unmarshal(m.Content, &contentStr); err == nil {
|
|
if m.Role == "user" {
|
|
userMessage = contentStr
|
|
}
|
|
inputTokens += len(contentStr) / 4
|
|
} else {
|
|
var parts []models.ContentPart
|
|
if err := json.Unmarshal(m.Content, &parts); err == nil {
|
|
for _, p := range parts {
|
|
if p.Type == "text" {
|
|
if m.Role == "user" {
|
|
userMessage = p.Text
|
|
}
|
|
inputTokens += len(p.Text) / 4
|
|
} else if p.Type == "image_url" {
|
|
hasMultimodal = true
|
|
inputTokens += 1000
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
msgLower := strings.ToLower(userMessage)
|
|
requiresReasoning := strings.Contains(msgLower, "reason") ||
|
|
strings.Contains(msgLower, "think step by step") ||
|
|
strings.Contains(msgLower, "mathematics") ||
|
|
strings.Contains(msgLower, "architecture") ||
|
|
strings.Contains(msgLower, "explain in detail")
|
|
|
|
routeCtx := &router.RouteContext{
|
|
UserMessage: userMessage,
|
|
InputTokens: inputTokens,
|
|
HasMultimodalInput: hasMultimodal,
|
|
RequiresToolCalling: requiresToolCalling,
|
|
RequiresReasoning: requiresReasoning,
|
|
}
|
|
routeCtx.Tags = s.getRouteCtxTags(routeCtx)
|
|
return routeCtx
|
|
}
|
|
|
|
func (s *Server) getRouteCtxTags(routeCtx *router.RouteContext) []string {
|
|
var tags []string
|
|
msgLower := strings.ToLower(routeCtx.UserMessage)
|
|
|
|
// fast-flow keywords
|
|
fastFlowKeywords := []string{
|
|
"classify", "classification", "label", "tag", "route", "routing", "intent",
|
|
"json", "yaml", "csv", "schema", "parse", "extract", "transform", "format", "regex",
|
|
"short answer", "brief", "concise", "tl;dr", "one line", "simple",
|
|
"fix this", "small bug", "quick fix", "typo", "syntax error",
|
|
}
|
|
for _, kw := range fastFlowKeywords {
|
|
if strings.Contains(msgLower, kw) {
|
|
tags = append(tags, "fast-flow", "classification", "json-extraction", "basic-qa")
|
|
break
|
|
}
|
|
}
|
|
|
|
// standard-pro keywords
|
|
standardProKeywords := []string{
|
|
"explain", "summarize", "rewrite", "draft", "edit", "polish", "outline",
|
|
"long doc", "document", "email", "memo", "proposal", "report", "handout", "notes",
|
|
"compare", "choose", "recommend", "tradeoff", "pros and cons", "analysis",
|
|
"code review", "debug", "bug", "feature", "api", "endpoint", "implement",
|
|
"plan", "planning", "workflow", "integration",
|
|
}
|
|
for _, kw := range standardProKeywords {
|
|
if strings.Contains(msgLower, kw) {
|
|
tags = append(tags, "standard-pro", "long-doc")
|
|
break
|
|
}
|
|
}
|
|
if routeCtx.HasMultimodalInput {
|
|
tags = append(tags, "video-analysis", "multimodal-qa")
|
|
}
|
|
|
|
// heavy-logic keywords
|
|
heavyLogicKeywords := []string{
|
|
"agent", "agents", "tool use", "function calling", "multi-agent", "orchestrate",
|
|
"system design", "scaling", "performance", "architecture review", "distributed",
|
|
"hard bug", "race condition", "deadlock", "memory leak", "crash", "production outage",
|
|
"long context", "large codebase", "many files", "complex refactor", "migration",
|
|
"research", "deep dive", "literature", "paper", "scholarly", "thorough analysis",
|
|
"deep reasoning", "think step by step", "reason through", "careful analysis",
|
|
}
|
|
for _, kw := range heavyLogicKeywords {
|
|
if strings.Contains(msgLower, kw) {
|
|
tags = append(tags, "heavy-logic", "deep-reasoning", "architecture", "hard-debugging")
|
|
break
|
|
}
|
|
}
|
|
|
|
if routeCtx.RequiresToolCalling {
|
|
tags = append(tags, "tool-heavy", "multi-step-agent", "swe-bench")
|
|
}
|
|
|
|
return tags
|
|
}
|