feat: implement advanced condition-based heuristic model routing
Upgrades the routing engine to support tag, token limit, multimodal, reasoning, and tool calling conditions. Adds unit tests for the new routing features.
This commit is contained in:
@@ -16,6 +16,16 @@ type Decision struct {
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// RouteContext holds metadata of the request to evaluate condition rules.
|
||||
type RouteContext struct {
|
||||
UserMessage string `json:"user_message"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
HasMultimodalInput bool `json:"has_multimodal_input"`
|
||||
RequiresToolCalling bool `json:"requires_tool_calling"`
|
||||
RequiresReasoning bool `json:"requires_reasoning"`
|
||||
Tags []string `json:"tags"`
|
||||
}
|
||||
|
||||
// ClassifierFunc is the callback for classifier-based routing.
|
||||
type ClassifierFunc func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error)
|
||||
|
||||
@@ -53,7 +63,7 @@ func (r *Router) IsGroup(modelID string) bool {
|
||||
}
|
||||
|
||||
// Route resolves a group to a concrete model.
|
||||
func (r *Router) Route(ctx context.Context, groupID string, userMessage string) (*Decision, error) {
|
||||
func (r *Router) Route(ctx context.Context, groupID string, routeCtx *RouteContext) (*Decision, error) {
|
||||
group, ok := r.groups[groupID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown model group: %s", groupID)
|
||||
@@ -66,12 +76,12 @@ func (r *Router) Route(ctx context.Context, groupID string, userMessage string)
|
||||
|
||||
switch group.Strategy {
|
||||
case "heuristic":
|
||||
return routeHeuristic(group, targets, userMessage)
|
||||
return routeHeuristic(group, targets, routeCtx)
|
||||
case "classifier":
|
||||
if r.classify == nil {
|
||||
return routeHeuristic(group, targets, userMessage)
|
||||
return routeHeuristic(group, targets, routeCtx)
|
||||
}
|
||||
return routeClassifier(ctx, r.classify, group, targets, userMessage)
|
||||
return routeClassifier(ctx, r.classify, group, targets, routeCtx)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown strategy: %s", group.Strategy)
|
||||
}
|
||||
@@ -80,7 +90,7 @@ func (r *Router) Route(ctx context.Context, groupID string, userMessage string)
|
||||
// RouteToConcrete resolves a model name to a concrete model, following group
|
||||
// chains recursively until a non-group target is reached. Returns the original
|
||||
// name unchanged if it is not a group.
|
||||
func (r *Router) RouteToConcrete(ctx context.Context, modelID string, userMessage string) (*Decision, error) {
|
||||
func (r *Router) RouteToConcrete(ctx context.Context, modelID string, routeCtx *RouteContext) (*Decision, error) {
|
||||
const maxDepth = 10
|
||||
visited := make(map[string]bool)
|
||||
current := modelID
|
||||
@@ -109,7 +119,7 @@ func (r *Router) RouteToConcrete(ctx context.Context, modelID string, userMessag
|
||||
}
|
||||
visited[current] = true
|
||||
|
||||
decision, err := r.Route(ctx, current, userMessage)
|
||||
decision, err := r.Route(ctx, current, routeCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user