feat: wire model group router into chat completions handler
This commit is contained in:
@@ -2,6 +2,7 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
@@ -15,6 +16,7 @@ import (
|
|||||||
"gophergate/internal/middleware"
|
"gophergate/internal/middleware"
|
||||||
"gophergate/internal/models"
|
"gophergate/internal/models"
|
||||||
"gophergate/internal/providers"
|
"gophergate/internal/providers"
|
||||||
|
"gophergate/internal/router"
|
||||||
"gophergate/internal/utils"
|
"gophergate/internal/utils"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -30,6 +32,7 @@ type Server struct {
|
|||||||
logger *RequestLogger
|
logger *RequestLogger
|
||||||
registry *models.ModelRegistry
|
registry *models.ModelRegistry
|
||||||
registryMu sync.RWMutex
|
registryMu sync.RWMutex
|
||||||
|
modelRouter *router.Router
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(cfg *config.Config, database *db.DB) *Server {
|
func NewServer(cfg *config.Config, database *db.DB) *Server {
|
||||||
@@ -64,6 +67,9 @@ func NewServer(cfg *config.Config, database *db.DB) *Server {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.setupRoutes()
|
s.setupRoutes()
|
||||||
|
|
||||||
|
// Initialize model group router
|
||||||
|
s.refreshRouter()
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,9 +174,51 @@ func (s *Server) RefreshProviders() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.refreshRouter()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) refreshRouter() {
|
||||||
|
var groups []db.ModelGroup
|
||||||
|
if err := s.database.Select(&groups, "SELECT * FROM model_groups"); err != nil {
|
||||||
|
fmt.Printf("Warning: Failed to load model groups: %v\n", err)
|
||||||
|
groups = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var classifyFn router.ClassifierFunc
|
||||||
|
if openaiProvider, ok := s.providers["openai"]; ok {
|
||||||
|
classifyFn = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) {
|
||||||
|
req := &models.UnifiedRequest{
|
||||||
|
Model: selectorModel,
|
||||||
|
Messages: []models.UnifiedMessage{
|
||||||
|
{Role: "system", Content: []models.UnifiedContentPart{{Type: "text", Text: systemPrompt}}},
|
||||||
|
{Role: "user", Content: []models.UnifiedContentPart{{Type: "text", Text: userMessage}}},
|
||||||
|
},
|
||||||
|
MaxTokens: uint32Ptr(5),
|
||||||
|
Stream: false,
|
||||||
|
}
|
||||||
|
resp, err := openaiProvider.ChatCompletion(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if len(resp.Choices) == 0 {
|
||||||
|
return "", fmt.Errorf("no choices in classifier response")
|
||||||
|
}
|
||||||
|
content, ok := resp.Choices[0].Message.Content.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("classifier response content is not a string")
|
||||||
|
}
|
||||||
|
return content, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.modelRouter == nil {
|
||||||
|
s.modelRouter = router.New(groups, classifyFn)
|
||||||
|
} else {
|
||||||
|
s.modelRouter.Reload(groups)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) setupRoutes() {
|
func (s *Server) setupRoutes() {
|
||||||
// Static files
|
// Static files
|
||||||
s.router.StaticFile("/", "./static/index.html")
|
s.router.StaticFile("/", "./static/index.html")
|
||||||
@@ -474,6 +522,18 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if model is a group and route to a concrete model
|
||||||
|
if s.modelRouter != nil && s.modelRouter.IsGroup(modelID) {
|
||||||
|
userMessage := extractUserMessage(req.Messages)
|
||||||
|
decision, err := s.modelRouter.Route(c.Request.Context(), modelID, userMessage)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
modelID = decision.SelectedModel
|
||||||
|
log.Printf("[ROUTER] %s -> %s (%s: %s)", req.Model, modelID, decision.Strategy, decision.Reason)
|
||||||
|
}
|
||||||
|
|
||||||
// Convert ChatCompletionRequest to UnifiedRequest
|
// Convert ChatCompletionRequest to UnifiedRequest
|
||||||
unifiedReq := &models.UnifiedRequest{
|
unifiedReq := &models.UnifiedRequest{
|
||||||
Model: modelID,
|
Model: modelID,
|
||||||
@@ -633,6 +693,20 @@ if unifiedReq.MaxTokens == nil {
|
|||||||
c.JSON(http.StatusOK, resp)
|
c.JSON(http.StatusOK, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func extractUserMessage(messages []models.ChatMessage) string {
|
||||||
|
for i := len(messages) - 1; i >= 0; i-- {
|
||||||
|
if messages[i].Role == "user" {
|
||||||
|
switch c := messages[i].Content.(type) {
|
||||||
|
case string:
|
||||||
|
return c
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) handleImageGenerations(c *gin.Context) {
|
func (s *Server) handleImageGenerations(c *gin.Context) {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
var req models.ImageGenerationRequest
|
var req models.ImageGenerationRequest
|
||||||
@@ -799,3 +873,5 @@ func (s *Server) Run() error {
|
|||||||
addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
|
addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
|
||||||
return s.router.Run(addr)
|
return s.router.Run(addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func uint32Ptr(v uint32) *uint32 { return &v }
|
||||||
|
|||||||
Reference in New Issue
Block a user