feat: implement advanced condition-based heuristic model routing
Upgrades the routing engine to support tag, token limit, multimodal, reasoning, and tool calling conditions. Adds unit tests for the new routing features.
This commit is contained in:
+172
-19
@@ -2,48 +2,125 @@ package router
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
// HeuristicRule defines a pattern-based routing rule.
|
||||
// HeuristicRule defines a pattern-based routing rule (legacy format).
|
||||
type HeuristicRule struct {
|
||||
Pattern string `json:"pattern"`
|
||||
TargetIdx int `json:"target"`
|
||||
CaseSensitive bool `json:"case_sensitive,omitempty"`
|
||||
}
|
||||
|
||||
func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) (*Decision, error) {
|
||||
// ConditionRule defines a condition-based routing rule (new format).
|
||||
type ConditionRule struct {
|
||||
RuleID string `json:"rule_id"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Conditions Conditions `json:"conditions"`
|
||||
PrimaryModel string `json:"primary_model"`
|
||||
FallbackModel string `json:"fallback_model,omitempty"`
|
||||
}
|
||||
|
||||
// Conditions defines the matching parameters for a rule.
|
||||
type Conditions struct {
|
||||
AnyOfTags []string `json:"any_of_tags,omitempty"`
|
||||
MaxInputTokensLt *int `json:"max_input_tokens_lt,omitempty"`
|
||||
RequiresReasoning *bool `json:"requires_reasoning,omitempty"`
|
||||
RequiresToolCalling *bool `json:"requires_tool_calling,omitempty"`
|
||||
HasMultimodalInput *bool `json:"has_multimodal_input,omitempty"`
|
||||
IsDefaultFallback *bool `json:"is_default_fallback,omitempty"`
|
||||
}
|
||||
|
||||
func routeHeuristic(group db.ModelGroup, targets []string, routeCtx *RouteContext) (*Decision, error) {
|
||||
if routeCtx == nil {
|
||||
routeCtx = &RouteContext{}
|
||||
}
|
||||
|
||||
selected := targets[0]
|
||||
reason := "default (first target)"
|
||||
|
||||
// If heuristic_rules is set, use them
|
||||
// If heuristic_rules is set, determine format and parse
|
||||
if group.HeuristicRules != nil && *group.HeuristicRules != "" {
|
||||
var rules []HeuristicRule
|
||||
if err := json.Unmarshal([]byte(*group.HeuristicRules), &rules); err == nil {
|
||||
searchMsg := userMessage
|
||||
for _, rule := range rules {
|
||||
pattern := rule.Pattern
|
||||
msg := searchMsg
|
||||
if !rule.CaseSensitive {
|
||||
pattern = strings.ToLower(pattern)
|
||||
msg = strings.ToLower(msg)
|
||||
rulesJSON := *group.HeuristicRules
|
||||
|
||||
if isConditionBasedRules(rulesJSON) {
|
||||
var condRules []ConditionRule
|
||||
if err := json.Unmarshal([]byte(rulesJSON), &condRules); err == nil {
|
||||
for _, rule := range condRules {
|
||||
if matchConditions(rule.Conditions, routeCtx) {
|
||||
// Resolve primary/fallback to concrete models in target list
|
||||
targetModel := ""
|
||||
if rule.PrimaryModel != "" {
|
||||
targetModel = getModelInTargets(rule.PrimaryModel, targets)
|
||||
}
|
||||
if targetModel == "" && rule.FallbackModel != "" {
|
||||
targetModel = getModelInTargets(rule.FallbackModel, targets)
|
||||
}
|
||||
|
||||
if targetModel != "" {
|
||||
selected = targetModel
|
||||
reason = "matched condition rule: " + rule.RuleID
|
||||
if rule.Description != "" {
|
||||
reason += " (" + rule.Description + ")"
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.Contains(msg, pattern) {
|
||||
if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) {
|
||||
selected = targets[rule.TargetIdx]
|
||||
reason = "matched heuristic rule: " + rule.Pattern
|
||||
break
|
||||
}
|
||||
} else {
|
||||
// Fallback to legacy pattern-based rules
|
||||
var legacyRules []HeuristicRule
|
||||
if err := json.Unmarshal([]byte(rulesJSON), &legacyRules); err == nil {
|
||||
searchMsg := routeCtx.UserMessage
|
||||
for _, rule := range legacyRules {
|
||||
pattern := rule.Pattern
|
||||
if pattern == "" {
|
||||
continue // Avoid infinite matches with empty patterns
|
||||
}
|
||||
msg := searchMsg
|
||||
if !rule.CaseSensitive {
|
||||
pattern = strings.ToLower(pattern)
|
||||
msg = strings.ToLower(msg)
|
||||
}
|
||||
|
||||
// Support both regex matching (if pattern is valid regex) and literal contains
|
||||
matched := false
|
||||
if strings.Contains(rule.Pattern, "(") || strings.Contains(rule.Pattern, "\\b") {
|
||||
var re *regexp.Regexp
|
||||
var err error
|
||||
if !rule.CaseSensitive {
|
||||
re, err = regexp.Compile("(?i)" + rule.Pattern)
|
||||
} else {
|
||||
re, err = regexp.Compile(rule.Pattern)
|
||||
}
|
||||
if err == nil {
|
||||
matched = re.MatchString(routeCtx.UserMessage)
|
||||
}
|
||||
}
|
||||
|
||||
if !matched && strings.Contains(msg, pattern) {
|
||||
matched = true
|
||||
}
|
||||
|
||||
if matched {
|
||||
if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) {
|
||||
selected = targets[rule.TargetIdx]
|
||||
reason = "matched heuristic rule: " + rule.Pattern
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Built-in fallback heuristics
|
||||
// Built-in fallback heuristics (if no custom rule matched)
|
||||
if reason == "default (first target)" && len(targets) > 1 {
|
||||
msgLower := strings.ToLower(userMessage)
|
||||
msgLower := strings.ToLower(routeCtx.UserMessage)
|
||||
complexIndicators := []string{
|
||||
"step by step", "explain in detail", "reason through",
|
||||
"think carefully", "analyze", "debug", "write code",
|
||||
@@ -64,3 +141,79 @@ func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) (
|
||||
Reason: reason,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// isConditionBasedRules returns true if the JSON represents condition-based rules.
|
||||
func isConditionBasedRules(rulesJSON string) bool {
|
||||
var rules []ConditionRule
|
||||
if err := json.Unmarshal([]byte(rulesJSON), &rules); err == nil && len(rules) > 0 {
|
||||
// If the rule has either conditions or primary_model/rule_id, treat it as condition-based
|
||||
return rules[0].PrimaryModel != "" || rules[0].RuleID != ""
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchConditions evaluates whether the given conditions match the RouteContext.
|
||||
func matchConditions(cond Conditions, routeCtx *RouteContext) bool {
|
||||
if cond.IsDefaultFallback != nil && *cond.IsDefaultFallback {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check tags: must match any_of_tags if specified
|
||||
if len(cond.AnyOfTags) > 0 {
|
||||
tagMatched := false
|
||||
for _, ruleTag := range cond.AnyOfTags {
|
||||
for _, ctxTag := range routeCtx.Tags {
|
||||
if strings.EqualFold(ruleTag, ctxTag) {
|
||||
tagMatched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if tagMatched {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !tagMatched {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check max input tokens
|
||||
if cond.MaxInputTokensLt != nil {
|
||||
if routeCtx.InputTokens >= *cond.MaxInputTokensLt {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check reasoning flag
|
||||
if cond.RequiresReasoning != nil {
|
||||
if routeCtx.RequiresReasoning != *cond.RequiresReasoning {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check tool calling flag
|
||||
if cond.RequiresToolCalling != nil {
|
||||
if routeCtx.RequiresToolCalling != *cond.RequiresToolCalling {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check multimodal flag
|
||||
if cond.HasMultimodalInput != nil {
|
||||
if routeCtx.HasMultimodalInput != *cond.HasMultimodalInput {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// getModelInTargets returns the model name if it exists in targets, or empty string.
|
||||
func getModelInTargets(modelName string, targets []string) string {
|
||||
for _, t := range targets {
|
||||
if strings.EqualFold(t, modelName) {
|
||||
return t
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user