feat: wire model group router into chat completions handler
This commit is contained in:
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"gophergate/internal/middleware"
|
||||
"gophergate/internal/models"
|
||||
"gophergate/internal/providers"
|
||||
"gophergate/internal/router"
|
||||
"gophergate/internal/utils"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -30,6 +32,7 @@ type Server struct {
|
||||
logger *RequestLogger
|
||||
registry *models.ModelRegistry
|
||||
registryMu sync.RWMutex
|
||||
modelRouter *router.Router
|
||||
}
|
||||
|
||||
func NewServer(cfg *config.Config, database *db.DB) *Server {
|
||||
@@ -64,6 +67,9 @@ func NewServer(cfg *config.Config, database *db.DB) *Server {
|
||||
}
|
||||
|
||||
s.setupRoutes()
|
||||
|
||||
// Initialize model group router
|
||||
s.refreshRouter()
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -168,9 +174,51 @@ func (s *Server) RefreshProviders() error {
|
||||
}
|
||||
}
|
||||
|
||||
s.refreshRouter()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) refreshRouter() {
|
||||
var groups []db.ModelGroup
|
||||
if err := s.database.Select(&groups, "SELECT * FROM model_groups"); err != nil {
|
||||
fmt.Printf("Warning: Failed to load model groups: %v\n", err)
|
||||
groups = nil
|
||||
}
|
||||
|
||||
var classifyFn router.ClassifierFunc
|
||||
if openaiProvider, ok := s.providers["openai"]; ok {
|
||||
classifyFn = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) {
|
||||
req := &models.UnifiedRequest{
|
||||
Model: selectorModel,
|
||||
Messages: []models.UnifiedMessage{
|
||||
{Role: "system", Content: []models.UnifiedContentPart{{Type: "text", Text: systemPrompt}}},
|
||||
{Role: "user", Content: []models.UnifiedContentPart{{Type: "text", Text: userMessage}}},
|
||||
},
|
||||
MaxTokens: uint32Ptr(5),
|
||||
Stream: false,
|
||||
}
|
||||
resp, err := openaiProvider.ChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return "", fmt.Errorf("no choices in classifier response")
|
||||
}
|
||||
content, ok := resp.Choices[0].Message.Content.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("classifier response content is not a string")
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
}
|
||||
|
||||
if s.modelRouter == nil {
|
||||
s.modelRouter = router.New(groups, classifyFn)
|
||||
} else {
|
||||
s.modelRouter.Reload(groups)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) setupRoutes() {
|
||||
// Static files
|
||||
s.router.StaticFile("/", "./static/index.html")
|
||||
@@ -474,6 +522,18 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Check if model is a group and route to a concrete model
|
||||
if s.modelRouter != nil && s.modelRouter.IsGroup(modelID) {
|
||||
userMessage := extractUserMessage(req.Messages)
|
||||
decision, err := s.modelRouter.Route(c.Request.Context(), modelID, userMessage)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
|
||||
return
|
||||
}
|
||||
modelID = decision.SelectedModel
|
||||
log.Printf("[ROUTER] %s -> %s (%s: %s)", req.Model, modelID, decision.Strategy, decision.Reason)
|
||||
}
|
||||
|
||||
// Convert ChatCompletionRequest to UnifiedRequest
|
||||
unifiedReq := &models.UnifiedRequest{
|
||||
Model: modelID,
|
||||
@@ -633,6 +693,20 @@ if unifiedReq.MaxTokens == nil {
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func extractUserMessage(messages []models.ChatMessage) string {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == "user" {
|
||||
switch c := messages[i].Content.(type) {
|
||||
case string:
|
||||
return c
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *Server) handleImageGenerations(c *gin.Context) {
|
||||
startTime := time.Now()
|
||||
var req models.ImageGenerationRequest
|
||||
@@ -799,3 +873,5 @@ func (s *Server) Run() error {
|
||||
addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
|
||||
return s.router.Run(addr)
|
||||
}
|
||||
|
||||
// uint32Ptr returns a pointer to a freshly allocated uint32 holding v.
// Each call yields a distinct pointer, so callers may mutate the result
// without affecting other values.
func uint32Ptr(v uint32) *uint32 {
	p := new(uint32)
	*p = v
	return p
}
|
||||
|
||||
Reference in New Issue
Block a user