Files
GopherGate/internal/router/classifier.go
T
hobokenchicken 73a82e6175
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
feat: implement advanced condition-based heuristic model routing
Upgrades the routing engine to support tag, token limit, multimodal, reasoning, and tool calling conditions. Adds unit tests for the new routing features.
2026-06-05 15:05:13 +00:00

77 lines
2.1 KiB
Go

package router
import (
"context"
"fmt"
"strconv"
"strings"
"gophergate/internal/db"
)
// classifierSystemPrompt is the prompt template handed to the classifier.
// It contains two %d verbs; routeClassifier fills both with the upper bound
// of the rating scale so the model knows which range to answer in.
const classifierSystemPrompt = `You are a task complexity classifier. Rate the following user message on a scale of 1 to %d, where:
1 = trivial/simple (basic facts, greetings, simple math)
%d = highly complex (multi-step reasoning, code generation, architecture design)
Reply with ONLY the number. No explanation.`
// routeClassifier selects a target model by asking a classifier to rate the
// complexity of the user's message and mapping that rating onto targets.
//
// The rating scale runs from 1 to maxRating, where maxRating is len(targets)
// (but never less than 2), or group.ComplexityThreshold when that is set to a
// positive value ("bucket mode"). The rating is mapped to an index as follows:
//   - without a threshold, rating r selects targets[r-1];
//   - in bucket mode, ratings are spread proportionally across the targets.
//
// In both modes the resulting index is clamped to the last target, so a
// single-target group (whose scale is still 1-2) can never index past the
// end of targets.
//
// If classify returns an error the decision is delegated to routeHeuristic.
// A reply that is not an integer, or is below 1, is treated as rating 1; a
// reply above maxRating is capped at maxRating.
//
// An error is returned when targets is empty, since there is nothing to
// route to.
func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, routeCtx *RouteContext) (*Decision, error) {
	if len(targets) == 0 {
		return nil, fmt.Errorf("classifier routing: no target models to choose from")
	}

	// Determine the rating scale. A scale of at least 2 keeps the prompt
	// meaningful even when only one target is configured.
	maxRating := len(targets)
	if maxRating < 2 {
		maxRating = 2
	}

	// When complexity_threshold is set, use it as a wider scale (e.g., 1-10)
	// and map ratings proportionally to target buckets.
	bucketMode := group.ComplexityThreshold != nil && *group.ComplexityThreshold > 0
	if bucketMode {
		maxRating = *group.ComplexityThreshold
	}

	prompt := fmt.Sprintf(classifierSystemPrompt, maxRating, maxRating)

	userMsg := ""
	if routeCtx != nil {
		userMsg = routeCtx.UserMessage
	}

	ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMsg)
	if err != nil {
		// Classifier failed — fall back to heuristic.
		return routeHeuristic(group, targets, routeCtx)
	}

	// Unparseable or out-of-range replies are pulled back into [1, maxRating].
	rating, err := strconv.Atoi(strings.TrimSpace(ratingStr))
	if err != nil || rating < 1 {
		rating = 1
	}
	if rating > maxRating {
		rating = maxRating
	}

	// Default is the 1:1 mapping (rating r -> targets[r-1]).
	idx := rating - 1
	if bucketMode {
		// Proportional mapping: wider scale → N target buckets
		// e.g., threshold=10, 3 targets: 1-3→0, 4-7→1, 8-10→2
		idx = rating * len(targets) / (maxRating + 1)
	}
	// Clamp in both modes: the scale can be wider than the target list
	// (single target with the minimum scale of 2, or a small threshold).
	if idx >= len(targets) {
		idx = len(targets) - 1
	}

	return &Decision{
		SelectedModel: targets[idx],
		Strategy:      "classifier",
		Reason:        fmt.Sprintf("complexity rating: %d/%d", rating, maxRating),
	}, nil
}
// getSelectorModel returns the name of the model that should perform the
// classification for group.
//
// group.SelectorModel is used when it is non-nil and non-empty. Otherwise the
// first entry of targets is used. If targets is also empty, the empty string
// is returned instead of panicking on the index.
func getSelectorModel(group db.ModelGroup, targets []string) string {
	if group.SelectorModel != nil && *group.SelectorModel != "" {
		return *group.SelectorModel
	}
	if len(targets) == 0 {
		return ""
	}
	// Default: use the first target model as the selector.
	// NOTE(review): this presumes targets are ordered cheapest-first — confirm with callers.
	return targets[0]
}