73a82e6175
Upgrades the routing engine to support tag, token limit, multimodal, reasoning, and tool calling conditions. Adds unit tests for the new routing features.
77 lines
2.1 KiB
Go
77 lines
2.1 KiB
Go
package router
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"gophergate/internal/db"
|
|
)
|
|
|
|
// classifierSystemPrompt is the system prompt that routeClassifier sends, via
// its ClassifierFunc, to the selector model. It is a fmt format string: both
// %d verbs are filled with the top of the rating scale (routeClassifier passes
// maxRating for each). The reply is expected to be a bare number, which
// routeClassifier parses and clamps into 1..maxRating.
const classifierSystemPrompt = `You are a task complexity classifier. Rate the following user message on a scale of 1 to %d, where:
|
|
1 = trivial/simple (basic facts, greetings, simple math)
|
|
%d = highly complex (multi-step reasoning, code generation, architecture design)
|
|
|
|
Reply with ONLY the number. No explanation.`
|
|
|
|
func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, routeCtx *RouteContext) (*Decision, error) {
|
|
// Determine the rating scale
|
|
maxRating := len(targets)
|
|
if maxRating < 2 {
|
|
maxRating = 2
|
|
}
|
|
|
|
// When complexity_threshold is set, use it as a wider scale (e.g., 1-10)
|
|
// and map ratings proportionally to target buckets.
|
|
bucketMode := group.ComplexityThreshold != nil && *group.ComplexityThreshold > 0
|
|
if bucketMode {
|
|
maxRating = *group.ComplexityThreshold
|
|
}
|
|
|
|
prompt := fmt.Sprintf(classifierSystemPrompt, maxRating, maxRating)
|
|
userMsg := ""
|
|
if routeCtx != nil {
|
|
userMsg = routeCtx.UserMessage
|
|
}
|
|
ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMsg)
|
|
if err != nil {
|
|
// Classifier failed — fall back to heuristic
|
|
return routeHeuristic(group, targets, routeCtx)
|
|
}
|
|
|
|
rating, err := strconv.Atoi(strings.TrimSpace(ratingStr))
|
|
if err != nil || rating < 1 {
|
|
rating = 1
|
|
}
|
|
if rating > maxRating {
|
|
rating = maxRating
|
|
}
|
|
|
|
var idx int
|
|
if bucketMode {
|
|
// Proportional mapping: wider scale → N target buckets
|
|
// e.g., threshold=10, 3 targets: 1-3→0, 4-7→1, 8-10→2
|
|
idx = rating * len(targets) / (maxRating + 1)
|
|
if idx >= len(targets) {
|
|
idx = len(targets) - 1
|
|
}
|
|
} else {
|
|
idx = rating - 1 // 1:1 mapping
|
|
}
|
|
|
|
return &Decision{
|
|
SelectedModel: targets[idx],
|
|
Strategy: "classifier",
|
|
Reason: fmt.Sprintf("complexity rating: %d/%d", rating, maxRating),
|
|
}, nil
|
|
}
|
|
|
|
func getSelectorModel(group db.ModelGroup, targets []string) string {
|
|
if group.SelectorModel != nil && *group.SelectorModel != "" {
|
|
return *group.SelectorModel
|
|
}
|
|
// Default: use the first (cheapest) target model as the selector
|
|
return targets[0]
|
|
}
|