feat: implement advanced condition-based heuristic model routing
Upgrades the routing engine to support tag, token limit, multimodal, reasoning, and tool calling conditions. Adds unit tests for the new routing features.
This commit is contained in:
@@ -15,7 +15,7 @@ const classifierSystemPrompt = `You are a task complexity classifier. Rate the f
|
||||
|
||||
Reply with ONLY the number. No explanation.`
|
||||
|
||||
func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, userMessage string) (*Decision, error) {
|
||||
func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, routeCtx *RouteContext) (*Decision, error) {
|
||||
// Determine the rating scale
|
||||
maxRating := len(targets)
|
||||
if maxRating < 2 {
|
||||
@@ -30,10 +30,14 @@ func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.Mode
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf(classifierSystemPrompt, maxRating, maxRating)
|
||||
ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMessage)
|
||||
userMsg := ""
|
||||
if routeCtx != nil {
|
||||
userMsg = routeCtx.UserMessage
|
||||
}
|
||||
ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMsg)
|
||||
if err != nil {
|
||||
// Classifier failed — fall back to heuristic
|
||||
return routeHeuristic(group, targets, userMessage)
|
||||
return routeHeuristic(group, targets, routeCtx)
|
||||
}
|
||||
|
||||
rating, err := strconv.Atoi(strings.TrimSpace(ratingStr))
|
||||
|
||||
+172
-19
@@ -2,48 +2,125 @@ package router
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
// HeuristicRule is the legacy, pattern-based routing rule format.
//
// Rules of this shape are decoded from the JSON array stored in a model
// group's heuristic rules.
type HeuristicRule struct {
	// Pattern is the text looked for in the user message.
	Pattern string `json:"pattern"`

	// TargetIdx is the index into the group's target list that is selected
	// when the pattern matches.
	TargetIdx int `json:"target"`

	// CaseSensitive, when false (the default), makes matching ignore case.
	CaseSensitive bool `json:"case_sensitive,omitempty"`
}
|
||||
|
||||
func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) (*Decision, error) {
|
||||
// ConditionRule defines a condition-based routing rule (new format).
|
||||
type ConditionRule struct {
|
||||
RuleID string `json:"rule_id"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Conditions Conditions `json:"conditions"`
|
||||
PrimaryModel string `json:"primary_model"`
|
||||
FallbackModel string `json:"fallback_model,omitempty"`
|
||||
}
|
||||
|
||||
// Conditions holds the matching parameters of a ConditionRule.
//
// Pointer fields distinguish "not specified" (nil) from an explicit value;
// an unspecified field places no constraint on the request.
type Conditions struct {
	// AnyOfTags matches when the request carries at least one of these tags.
	AnyOfTags []string `json:"any_of_tags,omitempty"`

	// MaxInputTokensLt matches when the request's input token count is
	// strictly less than this value.
	MaxInputTokensLt *int `json:"max_input_tokens_lt,omitempty"`

	// RequiresReasoning must equal the request's reasoning flag.
	RequiresReasoning *bool `json:"requires_reasoning,omitempty"`

	// RequiresToolCalling must equal the request's tool-calling flag.
	RequiresToolCalling *bool `json:"requires_tool_calling,omitempty"`

	// HasMultimodalInput must equal the request's multimodal flag.
	HasMultimodalInput *bool `json:"has_multimodal_input,omitempty"`

	// IsDefaultFallback, when true, makes the rule match every request.
	IsDefaultFallback *bool `json:"is_default_fallback,omitempty"`
}
|
||||
|
||||
func routeHeuristic(group db.ModelGroup, targets []string, routeCtx *RouteContext) (*Decision, error) {
|
||||
if routeCtx == nil {
|
||||
routeCtx = &RouteContext{}
|
||||
}
|
||||
|
||||
selected := targets[0]
|
||||
reason := "default (first target)"
|
||||
|
||||
// If heuristic_rules is set, use them
|
||||
// If heuristic_rules is set, determine format and parse
|
||||
if group.HeuristicRules != nil && *group.HeuristicRules != "" {
|
||||
var rules []HeuristicRule
|
||||
if err := json.Unmarshal([]byte(*group.HeuristicRules), &rules); err == nil {
|
||||
searchMsg := userMessage
|
||||
for _, rule := range rules {
|
||||
pattern := rule.Pattern
|
||||
msg := searchMsg
|
||||
if !rule.CaseSensitive {
|
||||
pattern = strings.ToLower(pattern)
|
||||
msg = strings.ToLower(msg)
|
||||
rulesJSON := *group.HeuristicRules
|
||||
|
||||
if isConditionBasedRules(rulesJSON) {
|
||||
var condRules []ConditionRule
|
||||
if err := json.Unmarshal([]byte(rulesJSON), &condRules); err == nil {
|
||||
for _, rule := range condRules {
|
||||
if matchConditions(rule.Conditions, routeCtx) {
|
||||
// Resolve primary/fallback to concrete models in target list
|
||||
targetModel := ""
|
||||
if rule.PrimaryModel != "" {
|
||||
targetModel = getModelInTargets(rule.PrimaryModel, targets)
|
||||
}
|
||||
if targetModel == "" && rule.FallbackModel != "" {
|
||||
targetModel = getModelInTargets(rule.FallbackModel, targets)
|
||||
}
|
||||
|
||||
if targetModel != "" {
|
||||
selected = targetModel
|
||||
reason = "matched condition rule: " + rule.RuleID
|
||||
if rule.Description != "" {
|
||||
reason += " (" + rule.Description + ")"
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.Contains(msg, pattern) {
|
||||
if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) {
|
||||
selected = targets[rule.TargetIdx]
|
||||
reason = "matched heuristic rule: " + rule.Pattern
|
||||
break
|
||||
}
|
||||
} else {
|
||||
// Fallback to legacy pattern-based rules
|
||||
var legacyRules []HeuristicRule
|
||||
if err := json.Unmarshal([]byte(rulesJSON), &legacyRules); err == nil {
|
||||
searchMsg := routeCtx.UserMessage
|
||||
for _, rule := range legacyRules {
|
||||
pattern := rule.Pattern
|
||||
if pattern == "" {
|
||||
continue // Avoid infinite matches with empty patterns
|
||||
}
|
||||
msg := searchMsg
|
||||
if !rule.CaseSensitive {
|
||||
pattern = strings.ToLower(pattern)
|
||||
msg = strings.ToLower(msg)
|
||||
}
|
||||
|
||||
// Support both regex matching (if pattern is valid regex) and literal contains
|
||||
matched := false
|
||||
if strings.Contains(rule.Pattern, "(") || strings.Contains(rule.Pattern, "\\b") {
|
||||
var re *regexp.Regexp
|
||||
var err error
|
||||
if !rule.CaseSensitive {
|
||||
re, err = regexp.Compile("(?i)" + rule.Pattern)
|
||||
} else {
|
||||
re, err = regexp.Compile(rule.Pattern)
|
||||
}
|
||||
if err == nil {
|
||||
matched = re.MatchString(routeCtx.UserMessage)
|
||||
}
|
||||
}
|
||||
|
||||
if !matched && strings.Contains(msg, pattern) {
|
||||
matched = true
|
||||
}
|
||||
|
||||
if matched {
|
||||
if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) {
|
||||
selected = targets[rule.TargetIdx]
|
||||
reason = "matched heuristic rule: " + rule.Pattern
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Built-in fallback heuristics
|
||||
// Built-in fallback heuristics (if no custom rule matched)
|
||||
if reason == "default (first target)" && len(targets) > 1 {
|
||||
msgLower := strings.ToLower(userMessage)
|
||||
msgLower := strings.ToLower(routeCtx.UserMessage)
|
||||
complexIndicators := []string{
|
||||
"step by step", "explain in detail", "reason through",
|
||||
"think carefully", "analyze", "debug", "write code",
|
||||
@@ -64,3 +141,79 @@ func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) (
|
||||
Reason: reason,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// isConditionBasedRules reports whether rulesJSON holds condition-based rules
// rather than legacy pattern rules.
//
// rulesJSON is expected to be a JSON array of rule objects. It is classified
// as condition-based when any element has a non-empty rule_id, a non-empty
// primary_model, or a non-null conditions value. Every element is inspected,
// so a rule set whose first entry happens to omit rule_id/primary_model is
// still recognised. Invalid JSON, a non-array value, or an empty array yields
// false.
func isConditionBasedRules(rulesJSON string) bool {
	// Decode only the discriminating fields. Keeping conditions as raw JSON
	// lets us detect its presence without interpreting its contents.
	var probes []struct {
		RuleID       string          `json:"rule_id"`
		PrimaryModel string          `json:"primary_model"`
		Conditions   json.RawMessage `json:"conditions"`
	}
	if err := json.Unmarshal([]byte(rulesJSON), &probes); err != nil {
		return false
	}

	for _, p := range probes {
		if p.RuleID != "" || p.PrimaryModel != "" {
			return true
		}
		// An explicit JSON null is treated the same as an absent field.
		if len(p.Conditions) > 0 && string(p.Conditions) != "null" {
			return true
		}
	}
	return false
}
|
||||
|
||||
// matchConditions evaluates whether the given conditions match the RouteContext.
|
||||
func matchConditions(cond Conditions, routeCtx *RouteContext) bool {
|
||||
if cond.IsDefaultFallback != nil && *cond.IsDefaultFallback {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check tags: must match any_of_tags if specified
|
||||
if len(cond.AnyOfTags) > 0 {
|
||||
tagMatched := false
|
||||
for _, ruleTag := range cond.AnyOfTags {
|
||||
for _, ctxTag := range routeCtx.Tags {
|
||||
if strings.EqualFold(ruleTag, ctxTag) {
|
||||
tagMatched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if tagMatched {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !tagMatched {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check max input tokens
|
||||
if cond.MaxInputTokensLt != nil {
|
||||
if routeCtx.InputTokens >= *cond.MaxInputTokensLt {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check reasoning flag
|
||||
if cond.RequiresReasoning != nil {
|
||||
if routeCtx.RequiresReasoning != *cond.RequiresReasoning {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check tool calling flag
|
||||
if cond.RequiresToolCalling != nil {
|
||||
if routeCtx.RequiresToolCalling != *cond.RequiresToolCalling {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check multimodal flag
|
||||
if cond.HasMultimodalInput != nil {
|
||||
if routeCtx.HasMultimodalInput != *cond.HasMultimodalInput {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// getModelInTargets returns the entry of targets that names modelName, or the
// empty string when there is none.
//
// An exact match always wins. Failing that, the first entry that matches
// case-insensitively is returned, spelled as it appears in targets. Preferring
// the exact match keeps the lookup unambiguous when targets contains entries
// that differ only by case.
func getModelInTargets(modelName string, targets []string) string {
	foldMatch := ""
	for _, candidate := range targets {
		if candidate == modelName {
			return candidate
		}
		// Remember only the first case-insensitive hit as the fallback answer.
		if foldMatch == "" && strings.EqualFold(candidate, modelName) {
			foldMatch = candidate
		}
	}
	return foldMatch
}
|
||||
|
||||
@@ -0,0 +1,142 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
func TestRouteHeuristic_ConditionRules(t *testing.T) {
|
||||
targets := []string{
|
||||
"deepseek-v4-flash", // index 0
|
||||
"gemini-3-flash", // index 1
|
||||
"grok-build-0.1", // index 2
|
||||
"kimi-k2.6", // index 3
|
||||
"mimo-v2.5-pro", // index 4
|
||||
"grok-4.3", // index 5
|
||||
"deepseek-v4-pro", // index 6
|
||||
}
|
||||
|
||||
rulesJSON := `[
|
||||
{
|
||||
"rule_id": "fast_flow_extraction",
|
||||
"conditions": {
|
||||
"any_of_tags": ["fast-flow", "classification"],
|
||||
"max_input_tokens_lt": 8000,
|
||||
"requires_reasoning": false
|
||||
},
|
||||
"primary_model": "deepseek-v4-flash",
|
||||
"fallback_model": "grok-build-0.1"
|
||||
},
|
||||
{
|
||||
"rule_id": "multimodal_long_context",
|
||||
"conditions": {
|
||||
"any_of_tags": ["standard-pro", "long-doc"],
|
||||
"has_multimodal_input": true
|
||||
},
|
||||
"primary_model": "gemini-3-flash",
|
||||
"fallback_model": "mimo-v2.5-pro"
|
||||
},
|
||||
{
|
||||
"rule_id": "regional_fallback_general",
|
||||
"conditions": {
|
||||
"is_default_fallback": true
|
||||
},
|
||||
"primary_model": "kimi-k2.6"
|
||||
}
|
||||
]`
|
||||
|
||||
group := db.ModelGroup{
|
||||
ID: "dustins_stack",
|
||||
Strategy: "heuristic",
|
||||
HeuristicRules: &rulesJSON,
|
||||
}
|
||||
|
||||
// 1. Test Match Fast Flow (condition success)
|
||||
ctx1 := &RouteContext{
|
||||
UserMessage: "classify this JSON",
|
||||
InputTokens: 500,
|
||||
HasMultimodalInput: false,
|
||||
RequiresReasoning: false,
|
||||
Tags: []string{"fast-flow", "classification"},
|
||||
}
|
||||
dec1, err := routeHeuristic(group, targets, ctx1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec1.SelectedModel != "deepseek-v4-flash" {
|
||||
t.Fatalf("expected deepseek-v4-flash, got %s", dec1.SelectedModel)
|
||||
}
|
||||
|
||||
// 2. Test Multimodal Long Context (condition success)
|
||||
ctx2 := &RouteContext{
|
||||
UserMessage: "explain this video",
|
||||
InputTokens: 15000,
|
||||
HasMultimodalInput: true,
|
||||
RequiresReasoning: false,
|
||||
Tags: []string{"standard-pro", "video-analysis"},
|
||||
}
|
||||
dec2, err := routeHeuristic(group, targets, ctx2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec2.SelectedModel != "gemini-3-flash" {
|
||||
t.Fatalf("expected gemini-3-flash, got %s", dec2.SelectedModel)
|
||||
}
|
||||
|
||||
// 3. Test Fallback general rule
|
||||
ctx3 := &RouteContext{
|
||||
UserMessage: "hello there",
|
||||
InputTokens: 100,
|
||||
HasMultimodalInput: false,
|
||||
RequiresReasoning: false,
|
||||
Tags: []string{"general"},
|
||||
}
|
||||
dec3, err := routeHeuristic(group, targets, ctx3)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec3.SelectedModel != "kimi-k2.6" {
|
||||
t.Fatalf("expected kimi-k2.6, got %s", dec3.SelectedModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteHeuristic_LegacyRules(t *testing.T) {
|
||||
targets := []string{"gpt-4o-mini", "deepseek-v4-pro", "kimi-k2.6"}
|
||||
|
||||
// Legacy pattern-based rule with regex
|
||||
rulesJSON := `[
|
||||
{"pattern": "\\b(agent|agents|tool use)\\b", "target": 1},
|
||||
{"pattern": "summarize", "target": 2}
|
||||
]`
|
||||
|
||||
group := db.ModelGroup{
|
||||
ID: "heavy-logic",
|
||||
Strategy: "heuristic",
|
||||
HeuristicRules: &rulesJSON,
|
||||
}
|
||||
|
||||
// 1. Test regex match
|
||||
ctx1 := &RouteContext{
|
||||
UserMessage: "We need an agent to do tool use",
|
||||
}
|
||||
dec1, err := routeHeuristic(group, targets, ctx1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec1.SelectedModel != "deepseek-v4-pro" {
|
||||
t.Fatalf("expected deepseek-v4-pro, got %s", dec1.SelectedModel)
|
||||
}
|
||||
|
||||
// 2. Test literal match
|
||||
ctx2 := &RouteContext{
|
||||
UserMessage: "Please summarize this text",
|
||||
}
|
||||
dec2, err := routeHeuristic(group, targets, ctx2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec2.SelectedModel != "kimi-k2.6" {
|
||||
t.Fatalf("expected kimi-k2.6, got %s", dec2.SelectedModel)
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,16 @@ type Decision struct {
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// RouteContext carries the request metadata that routing rules are evaluated
// against.
type RouteContext struct {
	// UserMessage is the user's message text; pattern rules and the
	// classifier operate on it.
	UserMessage string `json:"user_message"`

	// InputTokens is the request's input token count, compared against a
	// rule's max_input_tokens_lt.
	InputTokens int `json:"input_tokens"`

	// HasMultimodalInput is compared against a rule's has_multimodal_input.
	HasMultimodalInput bool `json:"has_multimodal_input"`

	// RequiresToolCalling is compared against a rule's requires_tool_calling.
	RequiresToolCalling bool `json:"requires_tool_calling"`

	// RequiresReasoning is compared against a rule's requires_reasoning.
	RequiresReasoning bool `json:"requires_reasoning"`

	// Tags are the request's labels, matched against a rule's any_of_tags.
	Tags []string `json:"tags"`
}
|
||||
|
||||
// ClassifierFunc is the callback used by classifier-based routing.
//
// It is given the selector model to query, the system prompt, and the user
// message, and returns the classifier's raw textual reply or an error.
type ClassifierFunc func(
	ctx context.Context,
	selectorModel, systemPrompt, userMessage string,
) (string, error)
|
||||
|
||||
@@ -53,7 +63,7 @@ func (r *Router) IsGroup(modelID string) bool {
|
||||
}
|
||||
|
||||
// Route resolves a group to a concrete model.
|
||||
func (r *Router) Route(ctx context.Context, groupID string, userMessage string) (*Decision, error) {
|
||||
func (r *Router) Route(ctx context.Context, groupID string, routeCtx *RouteContext) (*Decision, error) {
|
||||
group, ok := r.groups[groupID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown model group: %s", groupID)
|
||||
@@ -66,12 +76,12 @@ func (r *Router) Route(ctx context.Context, groupID string, userMessage string)
|
||||
|
||||
switch group.Strategy {
|
||||
case "heuristic":
|
||||
return routeHeuristic(group, targets, userMessage)
|
||||
return routeHeuristic(group, targets, routeCtx)
|
||||
case "classifier":
|
||||
if r.classify == nil {
|
||||
return routeHeuristic(group, targets, userMessage)
|
||||
return routeHeuristic(group, targets, routeCtx)
|
||||
}
|
||||
return routeClassifier(ctx, r.classify, group, targets, userMessage)
|
||||
return routeClassifier(ctx, r.classify, group, targets, routeCtx)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown strategy: %s", group.Strategy)
|
||||
}
|
||||
@@ -80,7 +90,7 @@ func (r *Router) Route(ctx context.Context, groupID string, userMessage string)
|
||||
// RouteToConcrete resolves a model name to a concrete model, following group
|
||||
// chains recursively until a non-group target is reached. Returns the original
|
||||
// name unchanged if it is not a group.
|
||||
func (r *Router) RouteToConcrete(ctx context.Context, modelID string, userMessage string) (*Decision, error) {
|
||||
func (r *Router) RouteToConcrete(ctx context.Context, modelID string, routeCtx *RouteContext) (*Decision, error) {
|
||||
const maxDepth = 10
|
||||
visited := make(map[string]bool)
|
||||
current := modelID
|
||||
@@ -109,7 +119,7 @@ func (r *Router) RouteToConcrete(ctx context.Context, modelID string, userMessag
|
||||
}
|
||||
visited[current] = true
|
||||
|
||||
decision, err := r.Route(ctx, current, userMessage)
|
||||
decision, err := r.Route(ctx, current, routeCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user