Files
hobokenchicken 73a82e6175
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
feat: implement advanced condition-based heuristic model routing
Upgrades the routing engine to support tag, token limit, multimodal, reasoning, and tool calling conditions. Adds unit tests for the new routing features.
2026-06-05 15:05:13 +00:00

140 lines
3.9 KiB
Go

package router
import (
	"context"
	"encoding/json"
	"fmt"
	"sort"
	"strings"

	"gophergate/internal/db"
)
// Decision is the outcome of a routing step: which model was picked, how it
// was picked, and why.
type Decision struct {
	// SelectedModel is the name chosen by the router.
	SelectedModel string `json:"selected_model"`
	// Strategy records how the choice was made, e.g. "heuristic" or
	// "classifier" ("hierarchical" for results of RouteToConcrete).
	Strategy string `json:"strategy"` // "heuristic" or "classifier"
	// Reason is a human-readable explanation of the choice.
	Reason string `json:"reason"`
}
// RouteContext carries request metadata that routing condition rules are
// evaluated against.
type RouteContext struct {
	// UserMessage is the text of the user's message.
	UserMessage string `json:"user_message"`
	// InputTokens is the token count attributed to the request input.
	InputTokens int `json:"input_tokens"`
	// HasMultimodalInput is set when the request carries multimodal input.
	HasMultimodalInput bool `json:"has_multimodal_input"`
	// RequiresToolCalling is set when the request needs tool calling.
	RequiresToolCalling bool `json:"requires_tool_calling"`
	// RequiresReasoning is set when the request needs reasoning support.
	RequiresReasoning bool `json:"requires_reasoning"`
	// Tags are free-form labels attached to the request.
	Tags []string `json:"tags"`
}
// ClassifierFunc is the callback used by classifier-strategy groups. It is
// handed the selector model, a system prompt and the user's message, and
// returns the classifier's textual answer or an error.
type ClassifierFunc func(
	ctx context.Context,
	selectorModel, systemPrompt, userMessage string,
) (string, error)
// Router resolves model groups to concrete models.
type Router struct {
	// groups indexes the model group definitions by group ID.
	groups map[string]db.ModelGroup
	// classify is the optional callback for "classifier" groups; when it is
	// nil, Route falls back to heuristic routing for those groups.
	classify ClassifierFunc
}
// New builds a Router over the given group definitions, indexed by group ID
// (a later entry with a duplicate ID overwrites an earlier one). classify may
// be nil if no classifier groups exist.
func New(groups []db.ModelGroup, classify ClassifierFunc) *Router {
	byID := make(map[string]db.ModelGroup, len(groups))
	for i := range groups {
		byID[groups[i].ID] = groups[i]
	}
	return &Router{groups: byID, classify: classify}
}
// Groups returns all registered model group IDs in ascending lexical order.
// A new slice is allocated on every call, so callers may modify it freely.
func (r *Router) Groups() []string {
	ids := make([]string, 0, len(r.groups))
	for id := range r.groups {
		ids = append(ids, id)
	}
	// Map iteration order is randomized; sort so listings are stable across
	// calls and process restarts.
	sort.Strings(ids)
	return ids
}
// IsGroup reports whether modelID is the ID of a registered model group
// (as opposed to a name the router does not know as a group).
func (r *Router) IsGroup(modelID string) bool {
	if _, found := r.groups[modelID]; found {
		return true
	}
	return false
}
// Route resolves a group to a model using the group's configured strategy.
//
// The group's Targets are decoded as a JSON array of strings and must be
// non-empty. Strategy "heuristic" routes heuristically; "classifier" uses the
// router's classifier callback, falling back to heuristic routing when no
// classifier was supplied. Any other strategy is an error. The selected name
// may itself be another group; RouteToConcrete follows such chains.
func (r *Router) Route(ctx context.Context, groupID string, routeCtx *RouteContext) (*Decision, error) {
	group, ok := r.groups[groupID]
	if !ok {
		return nil, fmt.Errorf("unknown model group: %s", groupID)
	}

	var targets []string
	if err := json.Unmarshal([]byte(group.Targets), &targets); err != nil {
		// Keep the decode error in the chain so a malformed Targets value is
		// diagnosable instead of being indistinguishable from an empty list.
		return nil, fmt.Errorf("invalid or empty targets for group %s: %w", groupID, err)
	}
	if len(targets) == 0 {
		return nil, fmt.Errorf("invalid or empty targets for group %s", groupID)
	}

	switch group.Strategy {
	case "heuristic":
		return routeHeuristic(group, targets, routeCtx)
	case "classifier":
		if r.classify == nil {
			// No classifier configured: degrade to heuristics rather than fail.
			return routeHeuristic(group, targets, routeCtx)
		}
		return routeClassifier(ctx, r.classify, group, targets, routeCtx)
	default:
		return nil, fmt.Errorf("unknown strategy: %s", group.Strategy)
	}
}
// RouteToConcrete resolves a model name to a concrete model, following group
// chains recursively until a non-group target is reached. Returns the original
// name unchanged if it is not a group.
//
// At most maxDepth group hops are followed; a chain that reaches a concrete
// model on exactly the maxDepth-th hop succeeds. Revisiting a group within one
// resolution is reported as a cycle. Errors from Route are returned as-is.
// The result always has Strategy "hierarchical". Its Reason is "direct" when
// modelID was already concrete. Otherwise it lists every hop as
// "<selected> (<reason>)", joined with " -> ".
func (r *Router) RouteToConcrete(ctx context.Context, modelID string, routeCtx *RouteContext) (*Decision, error) {
	const maxDepth = 10

	visited := make(map[string]bool)
	current := modelID
	var chain []*Decision

	for r.IsGroup(current) {
		if visited[current] {
			return nil, fmt.Errorf("routing cycle detected: group %s already visited", current)
		}
		// Enforce the hop budget only when another hop is actually needed, so
		// resolving on the final permitted hop is not misreported as overflow.
		if len(chain) >= maxDepth {
			return nil, fmt.Errorf("routing depth exceeded: reached max depth of %d", maxDepth)
		}
		visited[current] = true

		decision, err := r.Route(ctx, current, routeCtx)
		if err != nil {
			return nil, err
		}
		chain = append(chain, decision)
		current = decision.SelectedModel
	}

	// Build a composite reason showing the chain traversed.
	reason := "direct"
	if len(chain) > 0 {
		parts := make([]string, len(chain))
		for i, d := range chain {
			parts[i] = d.SelectedModel + " (" + d.Reason + ")"
		}
		reason = strings.Join(parts, " -> ")
	}
	return &Decision{
		SelectedModel: current,
		Strategy:      "hierarchical",
		Reason:        reason,
	}, nil
}
// Reload replaces the group definitions without recreating the router.
// Groups are indexed by ID; a later entry with a duplicate ID overwrites an
// earlier one. The classifier callback is left unchanged.
//
// The new index is fully built in a local map and then installed with a
// single assignment, so the map referenced by r.groups is never mutated while
// installed.
// NOTE(review): this is still unsynchronized. If Reload can run concurrently
// with Route/IsGroup/Groups, guard r.groups with a sync.RWMutex or store it
// behind an atomic.Pointer.
func (r *Router) Reload(groups []db.ModelGroup) {
	next := make(map[string]db.ModelGroup, len(groups))
	for _, g := range groups {
		next[g.ID] = g
	}
	r.groups = next
}