From 10262c0e5adc28089c4704138a42491d567e4676 Mon Sep 17 00:00:00 2001 From: hobokenchicken Date: Tue, 5 May 2026 10:47:32 -0400 Subject: [PATCH] feat: wire model group router into chat completions handler --- internal/server/server.go | 76 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/internal/server/server.go b/internal/server/server.go index ec4c1453..95e297a9 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,6 +2,7 @@ package server import ( "encoding/json" + "context" "fmt" "io" "log" @@ -15,6 +16,7 @@ import ( "gophergate/internal/middleware" "gophergate/internal/models" "gophergate/internal/providers" + "gophergate/internal/router" "gophergate/internal/utils" "github.com/gin-gonic/gin" @@ -30,6 +32,7 @@ type Server struct { logger *RequestLogger registry *models.ModelRegistry registryMu sync.RWMutex + modelRouter *router.Router } func NewServer(cfg *config.Config, database *db.DB) *Server { @@ -64,6 +67,9 @@ func NewServer(cfg *config.Config, database *db.DB) *Server { } s.setupRoutes() + + // Initialize model group router + s.refreshRouter() return s } @@ -168,9 +174,51 @@ func (s *Server) RefreshProviders() error { } } + s.refreshRouter() return nil } +func (s *Server) refreshRouter() { + var groups []db.ModelGroup + if err := s.database.Select(&groups, "SELECT * FROM model_groups"); err != nil { + fmt.Printf("Warning: Failed to load model groups: %v\n", err) + groups = nil + } + + var classifyFn router.ClassifierFunc + if openaiProvider, ok := s.providers["openai"]; ok { + classifyFn = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) { + req := &models.UnifiedRequest{ + Model: selectorModel, + Messages: []models.UnifiedMessage{ + {Role: "system", Content: []models.UnifiedContentPart{{Type: "text", Text: systemPrompt}}}, + {Role: "user", Content: []models.UnifiedContentPart{{Type: "text", Text: userMessage}}}, + }, + MaxTokens: uint32Ptr(5), + Stream: false, + } + resp, err := openaiProvider.ChatCompletion(ctx, req) + if err != nil { + return "", err + } + if len(resp.Choices) == 0 { + return "", fmt.Errorf("no choices in classifier response") + } + content, ok := resp.Choices[0].Message.Content.(string) + if !ok { + return "", fmt.Errorf("classifier response content is not a string") + } + return content, nil + } + } + + if s.modelRouter == nil { + s.modelRouter = router.New(groups, classifyFn) + } else { + s.modelRouter.Reload(groups) + } +} + func (s *Server) setupRoutes() { // Static files s.router.StaticFile("/", "./static/index.html") @@ -474,6 +522,18 @@ func (s *Server) handleChatCompletions(c *gin.Context) { } } + // Check if model is a group and route to a concrete model + if s.modelRouter != nil && s.modelRouter.IsGroup(modelID) { + userMessage := extractUserMessage(req.Messages) + decision, err := s.modelRouter.Route(c.Request.Context(), modelID, userMessage) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)}) + return + } + modelID = decision.SelectedModel + log.Printf("[ROUTER] %s -> %s (%s: %s)", req.Model, modelID, decision.Strategy, decision.Reason) + } + // Convert ChatCompletionRequest to UnifiedRequest unifiedReq := &models.UnifiedRequest{ Model: modelID, @@ -633,6 +693,20 @@ if unifiedReq.MaxTokens == nil { c.JSON(http.StatusOK, resp) } +func extractUserMessage(messages []models.ChatMessage) string { + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == "user" { + switch c := messages[i].Content.(type) { + case string: + return c + default: + return "" + } + } + } + return "" +} + func (s *Server) handleImageGenerations(c *gin.Context) { startTime := time.Now() var req models.ImageGenerationRequest @@ -799,3 +873,5 @@ func (s *Server) Run() error { addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port) return s.router.Run(addr) } + +func uint32Ptr(v uint32) *uint32 { return &v }