73a82e6175
Upgrades the routing engine to support tag, token limit, multimodal, reasoning, and tool calling conditions. Adds unit tests for the new routing features.
140 lines
3.9 KiB
Go
140 lines
3.9 KiB
Go
package router
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"sort"
	"strings"

	"gophergate/internal/db"
)
|
|
|
|
// Decision describes the outcome of resolving a model name: which model was
// picked, the strategy label attached to the pick, and a human-readable
// explanation.
type Decision struct {
	// SelectedModel is the model (or nested group) chosen by the router.
	SelectedModel string `json:"selected_model"`
	// Strategy labels how the choice was made, e.g. "heuristic" or
	// "classifier"; RouteToConcrete reports "hierarchical".
	Strategy string `json:"strategy"`
	// Reason is a free-form explanation of the choice.
	Reason string `json:"reason"`
}
|
|
|
|
// RouteContext carries the per-request metadata that routing condition rules
// are evaluated against.
type RouteContext struct {
	// UserMessage is the user message text of the request.
	UserMessage string `json:"user_message"`
	// InputTokens is the token count of the request input.
	InputTokens int `json:"input_tokens"`
	// HasMultimodalInput flags requests that include multimodal input.
	HasMultimodalInput bool `json:"has_multimodal_input"`
	// RequiresToolCalling flags requests that need tool calling.
	RequiresToolCalling bool `json:"requires_tool_calling"`
	// RequiresReasoning flags requests that need reasoning.
	RequiresReasoning bool `json:"requires_reasoning"`
	// Tags lists the tags attached to the request.
	Tags []string `json:"tags"`
}
|
|
|
|
// ClassifierFunc is the callback invoked for classifier-based routing. It
// receives the request context, the selector model name, a system prompt and
// the user message, and returns a string result or an error.
type ClassifierFunc func(
	ctx context.Context,
	selectorModel string,
	systemPrompt string,
	userMessage string,
) (string, error)
|
|
|
|
// Router resolves model groups to concrete models.
|
|
type Router struct {
|
|
groups map[string]db.ModelGroup
|
|
classify ClassifierFunc
|
|
}
|
|
|
|
// New creates a Router. classify may be nil if no classifier groups exist.
|
|
func New(groups []db.ModelGroup, classify ClassifierFunc) *Router {
|
|
r := &Router{
|
|
groups: make(map[string]db.ModelGroup),
|
|
classify: classify,
|
|
}
|
|
for _, g := range groups {
|
|
r.groups[g.ID] = g
|
|
}
|
|
return r
|
|
}
|
|
|
|
// Groups returns all registered model group IDs.
|
|
func (r *Router) Groups() []string {
|
|
ids := make([]string, 0, len(r.groups))
|
|
for id := range r.groups {
|
|
ids = append(ids, id)
|
|
}
|
|
return ids
|
|
}
|
|
|
|
// IsGroup returns true if the model name is a group ID.
|
|
func (r *Router) IsGroup(modelID string) bool {
|
|
_, ok := r.groups[modelID]
|
|
return ok
|
|
}
|
|
|
|
// Route resolves a group to a concrete model.
|
|
func (r *Router) Route(ctx context.Context, groupID string, routeCtx *RouteContext) (*Decision, error) {
|
|
group, ok := r.groups[groupID]
|
|
if !ok {
|
|
return nil, fmt.Errorf("unknown model group: %s", groupID)
|
|
}
|
|
|
|
var targets []string
|
|
if err := json.Unmarshal([]byte(group.Targets), &targets); err != nil || len(targets) == 0 {
|
|
return nil, fmt.Errorf("invalid or empty targets for group %s", groupID)
|
|
}
|
|
|
|
switch group.Strategy {
|
|
case "heuristic":
|
|
return routeHeuristic(group, targets, routeCtx)
|
|
case "classifier":
|
|
if r.classify == nil {
|
|
return routeHeuristic(group, targets, routeCtx)
|
|
}
|
|
return routeClassifier(ctx, r.classify, group, targets, routeCtx)
|
|
default:
|
|
return nil, fmt.Errorf("unknown strategy: %s", group.Strategy)
|
|
}
|
|
}
|
|
|
|
// RouteToConcrete resolves a model name to a concrete model, following group
|
|
// chains recursively until a non-group target is reached. Returns the original
|
|
// name unchanged if it is not a group.
|
|
func (r *Router) RouteToConcrete(ctx context.Context, modelID string, routeCtx *RouteContext) (*Decision, error) {
|
|
const maxDepth = 10
|
|
visited := make(map[string]bool)
|
|
current := modelID
|
|
var chain []*Decision
|
|
|
|
for depth := 0; depth < maxDepth; depth++ {
|
|
if !r.IsGroup(current) {
|
|
// Build a composite reason showing the chain traversed
|
|
reason := "direct"
|
|
if len(chain) > 0 {
|
|
parts := make([]string, len(chain))
|
|
for i, d := range chain {
|
|
parts[i] = d.SelectedModel + " (" + d.Reason + ")"
|
|
}
|
|
reason = strings.Join(parts, " -> ")
|
|
}
|
|
return &Decision{
|
|
SelectedModel: current,
|
|
Strategy: "hierarchical",
|
|
Reason: reason,
|
|
}, nil
|
|
}
|
|
|
|
if visited[current] {
|
|
return nil, fmt.Errorf("routing cycle detected: group %s already visited", current)
|
|
}
|
|
visited[current] = true
|
|
|
|
decision, err := r.Route(ctx, current, routeCtx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
chain = append(chain, decision)
|
|
current = decision.SelectedModel
|
|
}
|
|
|
|
return nil, fmt.Errorf("routing depth exceeded: reached max depth of %d", maxDepth)
|
|
}
|
|
|
|
// Reload replaces the group definitions without recreating the router.
|
|
func (r *Router) Reload(groups []db.ModelGroup) {
|
|
r.groups = make(map[string]db.ModelGroup)
|
|
for _, g := range groups {
|
|
r.groups[g.ID] = g
|
|
}
|
|
}
|