feat: wire model group router into chat completions handler

This commit is contained in:
2026-05-05 10:47:32 -04:00
parent d345f8c41d
commit 10262c0e5a
+76
View File
@@ -2,6 +2,7 @@ package server
import (
"encoding/json"
"context"
"fmt"
"io"
"log"
@@ -15,6 +16,7 @@ import (
"gophergate/internal/middleware"
"gophergate/internal/models"
"gophergate/internal/providers"
"gophergate/internal/router"
"gophergate/internal/utils"
"github.com/gin-gonic/gin"
@@ -30,6 +32,7 @@ type Server struct {
logger *RequestLogger
registry *models.ModelRegistry
registryMu sync.RWMutex
modelRouter *router.Router
}
func NewServer(cfg *config.Config, database *db.DB) *Server {
@@ -64,6 +67,9 @@ func NewServer(cfg *config.Config, database *db.DB) *Server {
}
s.setupRoutes()
// Initialize model group router
s.refreshRouter()
return s
}
@@ -168,9 +174,51 @@ func (s *Server) RefreshProviders() error {
}
}
s.refreshRouter()
return nil
}
// refreshRouter (re)builds the model group routing state from the database.
//
// It reads every row of the model_groups table. If that query fails, a warning
// is printed to stdout and routing continues with no groups. On the first call
// a new router is constructed, together with a classifier backed by the
// provider registered under "openai" (nil when no such provider exists). On
// later calls only the group list is reloaded into the existing router.
//
// NOTE(review): on reload the classifier handed to router.New is left in
// place; if the "openai" provider can be replaced at runtime the router may
// keep calling the old one — confirm against router.Reload.
func (s *Server) refreshRouter() {
	var modelGroups []db.ModelGroup
	loadErr := s.database.Select(&modelGroups, "SELECT * FROM model_groups")
	if loadErr != nil {
		fmt.Printf("Warning: Failed to load model groups: %v\n", loadErr)
		// Discard anything partially scanned so the router sees no groups.
		modelGroups = nil
	}

	// An existing router only needs its group list swapped.
	if s.modelRouter != nil {
		s.modelRouter.Reload(modelGroups)
		return
	}

	var classifier router.ClassifierFunc
	if provider, found := s.providers["openai"]; found {
		classifier = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) {
			prompt := []models.UnifiedMessage{
				{
					Role:    "system",
					Content: []models.UnifiedContentPart{{Type: "text", Text: systemPrompt}},
				},
				{
					Role:    "user",
					Content: []models.UnifiedContentPart{{Type: "text", Text: userMessage}},
				},
			}
			classifyReq := &models.UnifiedRequest{
				Model:    selectorModel,
				Messages: prompt,
				// The request is capped at 5 tokens.
				MaxTokens: uint32Ptr(5),
				Stream:    false,
			}

			resp, err := provider.ChatCompletion(ctx, classifyReq)
			if err != nil {
				return "", err
			}
			if len(resp.Choices) == 0 {
				return "", fmt.Errorf("no choices in classifier response")
			}
			if answer, isString := resp.Choices[0].Message.Content.(string); isString {
				return answer, nil
			}
			return "", fmt.Errorf("classifier response content is not a string")
		}
	}

	s.modelRouter = router.New(modelGroups, classifier)
}
func (s *Server) setupRoutes() {
// Static files
s.router.StaticFile("/", "./static/index.html")
@@ -474,6 +522,18 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
}
}
// Check if model is a group and route to a concrete model
if s.modelRouter != nil && s.modelRouter.IsGroup(modelID) {
userMessage := extractUserMessage(req.Messages)
decision, err := s.modelRouter.Route(c.Request.Context(), modelID, userMessage)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
return
}
modelID = decision.SelectedModel
log.Printf("[ROUTER] %s -> %s (%s: %s)", req.Model, modelID, decision.Strategy, decision.Reason)
}
// Convert ChatCompletionRequest to UnifiedRequest
unifiedReq := &models.UnifiedRequest{
Model: modelID,
@@ -633,6 +693,20 @@ if unifiedReq.MaxTokens == nil {
c.JSON(http.StatusOK, resp)
}
// extractUserMessage returns the text of the most recent message whose role is
// "user". It returns "" when there is no user message or when the latest one
// carries no text.
//
// Content may be a plain string or a list of content parts. For a list, the
// "text" values of all parts whose "type" is "text" are joined with newlines,
// so multi-part (e.g. text + image) messages still give the router something
// to classify instead of an empty string. Only the latest user message is
// consulted; earlier ones are never used as a fallback.
//
// NOTE(review): assumes multi-part content arrives as []interface{} of
// map[string]interface{} (the shape encoding/json produces for an
// interface-typed field) — confirm against models.ChatMessage.
func extractUserMessage(messages []models.ChatMessage) string {
	for i := len(messages) - 1; i >= 0; i-- {
		if messages[i].Role != "user" {
			continue
		}
		switch content := messages[i].Content.(type) {
		case string:
			return content
		case []interface{}:
			return userMessageTextParts(content)
		default:
			return ""
		}
	}
	return ""
}

// userMessageTextParts concatenates the non-empty "text" fields of the parts
// typed "text", separated by newlines. Parts of any other shape are ignored.
func userMessageTextParts(parts []interface{}) string {
	var text string
	for _, raw := range parts {
		part, ok := raw.(map[string]interface{})
		if !ok {
			continue
		}
		if kind, _ := part["type"].(string); kind != "text" {
			continue
		}
		fragment, _ := part["text"].(string)
		if fragment == "" {
			continue
		}
		if text != "" {
			text += "\n"
		}
		text += fragment
	}
	return text
}
func (s *Server) handleImageGenerations(c *gin.Context) {
startTime := time.Now()
var req models.ImageGenerationRequest
@@ -799,3 +873,5 @@ func (s *Server) Run() error {
addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
return s.router.Run(addr)
}
// uint32Ptr returns a pointer to a fresh uint32 holding v, for populating
// optional pointer-typed fields from a literal.
func uint32Ptr(v uint32) *uint32 {
	p := new(uint32)
	*p = v
	return p
}