feat: implement advanced condition-based heuristic model routing
Upgrades the routing engine to support tag, token-limit, multimodal, reasoning, and tool-calling conditions. Adds unit tests for the new routing features.
This commit is contained in:
+164
-3
@@ -329,7 +329,8 @@ func (s *Server) handleResponses(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
if s.modelRouter != nil {
|
||||
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, "")
|
||||
routeCtx := s.buildRouteContextFromResponses(req)
|
||||
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
|
||||
return
|
||||
@@ -573,8 +574,8 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
|
||||
log.Printf("[DEBUG] Incoming Msg[%d]: role=%s, hasToolCalls=%v, hasContent=%v", i, m.Role, len(m.ToolCalls) > 0, m.Content != nil)
|
||||
}
|
||||
if s.modelRouter != nil {
|
||||
userMessage := extractUserMessage(req.Messages)
|
||||
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, userMessage)
|
||||
routeCtx := s.buildRouteContextFromChat(req)
|
||||
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
|
||||
return
|
||||
@@ -941,3 +942,163 @@ func (s *Server) Run() error {
|
||||
}
|
||||
|
||||
// uint32Ptr returns a pointer to a freshly allocated uint32 holding v.
// Each call yields a distinct pointer, so callers may mutate the result
// without affecting other values obtained from this helper.
func uint32Ptr(v uint32) *uint32 {
	p := new(uint32)
	*p = v
	return p
}
|
||||
|
||||
func (s *Server) buildRouteContextFromChat(req models.ChatCompletionRequest) *router.RouteContext {
|
||||
userMessage := extractUserMessage(req.Messages)
|
||||
requiresToolCalling := len(req.Tools) > 0
|
||||
hasMultimodal := false
|
||||
inputTokens := 0
|
||||
|
||||
for _, msg := range req.Messages {
|
||||
if strContent, ok := msg.Content.(string); ok {
|
||||
inputTokens += len(strContent) / 4
|
||||
} else if parts, ok := msg.Content.([]interface{}); ok {
|
||||
for _, part := range parts {
|
||||
if partMap, ok := part.(map[string]interface{}); ok {
|
||||
partType, _ := partMap["type"].(string)
|
||||
if partType == "text" {
|
||||
text, _ := partMap["text"].(string)
|
||||
inputTokens += len(text) / 4
|
||||
} else if partType == "image_url" {
|
||||
hasMultimodal = true
|
||||
inputTokens += 1000 // Approximate cost of an image in tokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msgLower := strings.ToLower(userMessage)
|
||||
requiresReasoning := strings.Contains(msgLower, "reason") ||
|
||||
strings.Contains(msgLower, "think step by step") ||
|
||||
strings.Contains(msgLower, "mathematics") ||
|
||||
strings.Contains(msgLower, "architecture") ||
|
||||
strings.Contains(msgLower, "explain in detail")
|
||||
|
||||
routeCtx := &router.RouteContext{
|
||||
UserMessage: userMessage,
|
||||
InputTokens: inputTokens,
|
||||
HasMultimodalInput: hasMultimodal,
|
||||
RequiresToolCalling: requiresToolCalling,
|
||||
RequiresReasoning: requiresReasoning,
|
||||
}
|
||||
routeCtx.Tags = s.getRouteCtxTags(routeCtx)
|
||||
return routeCtx
|
||||
}
|
||||
|
||||
func (s *Server) buildRouteContextFromResponses(req models.ResponsesRequest) *router.RouteContext {
|
||||
var userMessage string
|
||||
hasMultimodal := false
|
||||
inputTokens := len(req.Instructions) / 4
|
||||
requiresToolCalling := len(req.Tools) > 0 && string(req.Tools) != "null" && string(req.Tools) != ""
|
||||
|
||||
var strInput string
|
||||
if err := json.Unmarshal(req.Input, &strInput); err == nil {
|
||||
userMessage = strInput
|
||||
inputTokens += len(userMessage) / 4
|
||||
} else {
|
||||
var msgs []models.ResponseInputMessage
|
||||
if err := json.Unmarshal(req.Input, &msgs); err == nil {
|
||||
for _, m := range msgs {
|
||||
var contentStr string
|
||||
if err := json.Unmarshal(m.Content, &contentStr); err == nil {
|
||||
if m.Role == "user" {
|
||||
userMessage = contentStr
|
||||
}
|
||||
inputTokens += len(contentStr) / 4
|
||||
} else {
|
||||
var parts []models.ContentPart
|
||||
if err := json.Unmarshal(m.Content, &parts); err == nil {
|
||||
for _, p := range parts {
|
||||
if p.Type == "text" {
|
||||
if m.Role == "user" {
|
||||
userMessage = p.Text
|
||||
}
|
||||
inputTokens += len(p.Text) / 4
|
||||
} else if p.Type == "image_url" {
|
||||
hasMultimodal = true
|
||||
inputTokens += 1000
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msgLower := strings.ToLower(userMessage)
|
||||
requiresReasoning := strings.Contains(msgLower, "reason") ||
|
||||
strings.Contains(msgLower, "think step by step") ||
|
||||
strings.Contains(msgLower, "mathematics") ||
|
||||
strings.Contains(msgLower, "architecture") ||
|
||||
strings.Contains(msgLower, "explain in detail")
|
||||
|
||||
routeCtx := &router.RouteContext{
|
||||
UserMessage: userMessage,
|
||||
InputTokens: inputTokens,
|
||||
HasMultimodalInput: hasMultimodal,
|
||||
RequiresToolCalling: requiresToolCalling,
|
||||
RequiresReasoning: requiresReasoning,
|
||||
}
|
||||
routeCtx.Tags = s.getRouteCtxTags(routeCtx)
|
||||
return routeCtx
|
||||
}
|
||||
|
||||
func (s *Server) getRouteCtxTags(routeCtx *router.RouteContext) []string {
|
||||
var tags []string
|
||||
msgLower := strings.ToLower(routeCtx.UserMessage)
|
||||
|
||||
// fast-flow keywords
|
||||
fastFlowKeywords := []string{
|
||||
"classify", "classification", "label", "tag", "route", "routing", "intent",
|
||||
"json", "yaml", "csv", "schema", "parse", "extract", "transform", "format", "regex",
|
||||
"short answer", "brief", "concise", "tl;dr", "one line", "simple",
|
||||
"fix this", "small bug", "quick fix", "typo", "syntax error",
|
||||
}
|
||||
for _, kw := range fastFlowKeywords {
|
||||
if strings.Contains(msgLower, kw) {
|
||||
tags = append(tags, "fast-flow", "classification", "json-extraction", "basic-qa")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// standard-pro keywords
|
||||
standardProKeywords := []string{
|
||||
"explain", "summarize", "rewrite", "draft", "edit", "polish", "outline",
|
||||
"long doc", "document", "email", "memo", "proposal", "report", "handout", "notes",
|
||||
"compare", "choose", "recommend", "tradeoff", "pros and cons", "analysis",
|
||||
"code review", "debug", "bug", "feature", "api", "endpoint", "implement",
|
||||
"plan", "planning", "workflow", "integration",
|
||||
}
|
||||
for _, kw := range standardProKeywords {
|
||||
if strings.Contains(msgLower, kw) {
|
||||
tags = append(tags, "standard-pro", "long-doc")
|
||||
break
|
||||
}
|
||||
}
|
||||
if routeCtx.HasMultimodalInput {
|
||||
tags = append(tags, "video-analysis", "multimodal-qa")
|
||||
}
|
||||
|
||||
// heavy-logic keywords
|
||||
heavyLogicKeywords := []string{
|
||||
"agent", "agents", "tool use", "function calling", "multi-agent", "orchestrate",
|
||||
"system design", "scaling", "performance", "architecture review", "distributed",
|
||||
"hard bug", "race condition", "deadlock", "memory leak", "crash", "production outage",
|
||||
"long context", "large codebase", "many files", "complex refactor", "migration",
|
||||
"research", "deep dive", "literature", "paper", "scholarly", "thorough analysis",
|
||||
"deep reasoning", "think step by step", "reason through", "careful analysis",
|
||||
}
|
||||
for _, kw := range heavyLogicKeywords {
|
||||
if strings.Contains(msgLower, kw) {
|
||||
tags = append(tags, "heavy-logic", "deep-reasoning", "architecture", "hard-debugging")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if routeCtx.RequiresToolCalling {
|
||||
tags = append(tags, "tool-heavy", "multi-step-agent", "swe-bench")
|
||||
}
|
||||
|
||||
return tags
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user