From d345f8c41dffc33413bac95d1a2ff099abf6a890 Mon Sep 17 00:00:00 2001 From: hobokenchicken Date: Tue, 5 May 2026 10:40:26 -0400 Subject: [PATCH] feat: add classifier routing strategy with LLM complexity rating --- internal/router/classifier.go | 53 +++++++++++++++++++++++++++++++++++ internal/router/heuristic.go | 7 ----- 2 files changed, 53 insertions(+), 7 deletions(-) create mode 100644 internal/router/classifier.go diff --git a/internal/router/classifier.go b/internal/router/classifier.go new file mode 100644 index 00000000..49d54f97 --- /dev/null +++ b/internal/router/classifier.go @@ -0,0 +1,53 @@ +package router + +import ( + "context" + "fmt" + "strconv" + "strings" + + "gophergate/internal/db" +) + +const classifierSystemPrompt = `You are a task complexity classifier. Rate the following user message on a scale of 1 to %d, where: +1 = trivial/simple (basic facts, greetings, simple math) +%d = highly complex (multi-step reasoning, code generation, architecture design) + +Reply with ONLY the number. No explanation.` + +func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, userMessage string) (*Decision, error) { + maxRating := len(targets) + if maxRating < 2 { + maxRating = 2 + } + + prompt := fmt.Sprintf(classifierSystemPrompt, maxRating, maxRating) + ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMessage) + if err != nil { + // Classifier failed — fall back to heuristic + return routeHeuristic(group, targets, userMessage) + } + + rating, err := strconv.Atoi(strings.TrimSpace(ratingStr)) + if err != nil || rating < 1 { + rating = 1 + } + if rating > maxRating { + rating = maxRating + } + + idx := rating - 1 // 0-based index into targets + return &Decision{ + SelectedModel: targets[idx], + Strategy: "classifier", + Reason: fmt.Sprintf("complexity rating: %d/%d", rating, maxRating), + }, nil +} + +func getSelectorModel(group db.ModelGroup, targets []string) string { + if group.SelectorModel != nil && *group.SelectorModel != "" { + return *group.SelectorModel + } + // Default: use the first (cheapest) target model as the selector + return targets[0] +} diff --git a/internal/router/heuristic.go b/internal/router/heuristic.go index f02cc75f..4c9f7ec6 100644 --- a/internal/router/heuristic.go +++ b/internal/router/heuristic.go @@ -1,7 +1,6 @@ package router import ( - "context" "encoding/json" "strings" @@ -65,9 +64,3 @@ func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) ( Reason: reason, }, nil } - -// routeClassifier is a stub — real implementation in classifier.go (Task 3). -// Falls back to heuristic routing for now. -func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, userMessage string) (*Decision, error) { - return routeHeuristic(group, targets, userMessage) -}