feat: implement advanced condition-based heuristic model routing
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled

Upgrades the routing engine to support tag, token limit, multimodal, reasoning, and tool calling conditions. Adds unit tests for the new routing features.
This commit is contained in:
2026-06-05 15:05:13 +00:00
parent b3354a1bbc
commit 73a82e6175
12 changed files with 731 additions and 45 deletions
+3
View File
@@ -46,6 +46,9 @@ Implements HMAC-SHA256 signed tokens for dashboard authentication. Tokens secure
### 5. WebSocket Hub (`internal/server/websocket.go`)
A centralized hub for managing WebSocket connections, allowing real-time broadcast of system events, system metrics, and request logs to the dashboard.
### 6. Model Group Router (`internal/router/`)
Resolves model groups (e.g., `deepseek-auto`, `dustins_stack`) into concrete models. It supports a Classifier strategy (uses a cheap LLM to rate complexity) and an upgraded Heuristic strategy (evaluates custom condition rules like tags, token counts, multimodal inputs, reasoning, and tool calling flags or legacy keyword patterns).
## Concurrency Model
Go's goroutines and channels are used extensively:
+1 -1
View File
@@ -22,7 +22,7 @@ A unified, high-performance LLM proxy gateway built in Go. It provides OpenAI-co
- **Image Generation:** DALL-E 2/3 (OpenAI) and Imagen 3 (Gemini) via OpenAI-compatible `/v1/images/generations` endpoint.
- **Automatic Model Routing:**
- **Hierarchical Routing:** Groups can target other groups, cascading through multiple levels until a concrete model is reached. Cycle detection and depth limiting (max 10) prevent infinite loops.
- **Heuristic strategy:** Free, zero-latency keyword matching (e.g. "debug" or "step by step" routes to the reasoning model).
- **Heuristic strategy:** Free, zero-latency routing supporting both keyword matching (regex/substrings) and condition-based checks (evaluating tags, token limits, multimodal inputs, reasoning, and tool calling requirements).
- **Classifier strategy:** Uses a cheap LLM to rate task complexity on a configurable scale (1-10), then selects the appropriate model. Bucket mapping distributes ratings proportionally across targets.
- **Two-Level Dispatch:** A `dispatcher` group (classifier, threshold=10) auto-routes to tier groups by complexity score, which then apply their own internal strategies.
- **Metadata:** Groups support `logic_level` (1-10 complexity scale) and `primary_use` (description) fields for organizational clarity.
+8
View File
@@ -32,6 +32,14 @@ func Init(path string) (*DB, error) {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
// Enable Write-Ahead Logging (WAL) and set a busy timeout to handle concurrent access
if _, err := db.Exec("PRAGMA journal_mode=WAL;"); err != nil {
log.Printf("failed to enable WAL mode: %v", err)
}
if _, err := db.Exec("PRAGMA busy_timeout=5000;"); err != nil {
log.Printf("failed to set busy timeout: %v", err)
}
instance := &DB{db}
// Run migrations
+44 -5
View File
@@ -14,9 +14,21 @@ import (
func AuthMiddleware(database *db.DB, requireAuth bool) gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
// Fallback to checking "Authentication" header in case the client library used the wrong name
authHeader = c.GetHeader("Authentication")
}
if authHeader == "" {
if requireAuth {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing authorization header"})
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"message": "Missing Authorization or Authentication header.",
"type": "invalid_request_error",
"param": nil,
"code": "401",
},
})
return
}
c.Next()
@@ -25,23 +37,50 @@ func AuthMiddleware(database *db.DB, requireAuth bool) gin.HandlerFunc {
token := strings.TrimPrefix(authHeader, "Bearer ")
if token == authHeader { // No "Bearer " prefix
if requireAuth {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"message": "Invalid authorization header format. Bearer token required.",
"type": "invalid_request_error",
"param": nil,
"code": "401",
},
})
return
}
c.Next()
return
}
// Try to resolve client from database
// Try to resolve client from database with a read-only SELECT
var clientID string
err := database.Get(&clientID, "UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = 1 RETURNING client_id", token)
err := database.Get(&clientID, "SELECT client_id FROM client_tokens WHERE token = ? AND is_active = 1", token)
if err == nil {
c.Set("auth", models.AuthInfo{
Token: token,
ClientID: clientID,
})
// Update last_used_at asynchronously so that database locks or write delays
// do not block or fail the client's request authentication.
go func(t string) {
if _, updateErr := database.Exec("UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ?", t); updateErr != nil {
log.Printf("Warning: failed to update client token last_used_at: %v", updateErr)
}
}(token)
c.Next()
} else {
log.Printf("Token not found or inactive in DB: %s", token)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid or inactive token"})
log.Printf("Token not found, inactive or error in DB: %s (err: %v)", token, err)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"message": "Invalid or inactive client token.",
"type": "invalid_request_error",
"param": nil,
"code": "401",
},
})
}
}
}
+12 -4
View File
@@ -338,6 +338,7 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
}
// Map Tools
hasMappedTools := false
if len(req.Tools) > 0 {
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
for _, t := range req.Tools {
@@ -345,13 +346,16 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
}
}
if len(geminiTool.FunctionDeclarations) > 0 {
body.Tools = []GeminiTool{geminiTool}
hasMappedTools = true
}
}
baseURL := p.config.BaseURL
lowerModel := strings.ToLower(req.Model)
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
// Use v1beta for preview and newer models
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") || hasMappedTools {
// Use v1beta for preview, newer models, or when using tools
if !strings.Contains(baseURL, "v1beta") {
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
}
@@ -578,6 +582,7 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
GenerationConfig: genConfig,
}
hasMappedTools := false
if len(req.Tools) > 0 {
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
for _, t := range req.Tools {
@@ -585,13 +590,16 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
}
}
if len(geminiTool.FunctionDeclarations) > 0 {
body.Tools = []GeminiTool{geminiTool}
hasMappedTools = true
}
}
baseURL := p.config.BaseURL
lowerModel := strings.ToLower(req.Model)
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
// Use v1beta for preview and newer models
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") || hasMappedTools {
// Use v1beta for preview, newer models, or when using tools
if !strings.Contains(baseURL, "v1beta") {
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
}
+33 -2
View File
@@ -10,6 +10,22 @@ import (
"gophergate/internal/models"
)
func sanitizeFunctionName(name string) string {
var sb strings.Builder
for _, ch := range name {
if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' || ch == '-' {
sb.WriteRune(ch)
} else {
sb.WriteRune('_')
}
}
res := sb.String()
if res == "" {
return "function"
}
return res
}
// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON, including tools and images.
func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) {
var result []interface{}
@@ -35,7 +51,7 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
msg["tool_call_id"] = id
if m.Name != nil {
msg["name"] = *m.Name
msg["name"] = sanitizeFunctionName(*m.Name)
}
result = append(result, msg)
continue
@@ -91,6 +107,7 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
if sanitizedCalls[i].Type == "" {
sanitizedCalls[i].Type = "function"
}
sanitizedCalls[i].Function.Name = sanitizeFunctionName(sanitizedCalls[i].Function.Name)
}
msg["tool_calls"] = sanitizedCalls
msg["content"] = "" // OpenAI requirement: content must be string if tool_calls present
@@ -124,11 +141,25 @@ func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{},
body["max_tokens"] = *request.MaxTokens
}
if len(request.Tools) > 0 {
body["tools"] = request.Tools
sanitizedTools := make([]models.Tool, len(request.Tools))
copy(sanitizedTools, request.Tools)
for i := range sanitizedTools {
if sanitizedTools[i].Type == "function" {
sanitizedTools[i].Function.Name = sanitizeFunctionName(sanitizedTools[i].Function.Name)
}
}
body["tools"] = sanitizedTools
}
if request.ToolChoice != nil {
var toolChoice interface{}
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
if tcMap, ok := toolChoice.(map[string]interface{}); ok {
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
if name, ok := funcMap["name"].(string); ok {
funcMap["name"] = sanitizeFunctionName(name)
}
}
}
body["tool_choice"] = toolChoice
}
}
+127
View File
@@ -0,0 +1,127 @@
package providers
import (
"encoding/json"
"testing"
"gophergate/internal/models"
)
func TestSanitizeFunctionName(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"google-search", "google-search"},
{"google.search", "google_search"},
{"google search", "google_search"},
{"web_search(query)", "web_search_query_"},
{"", "function"},
{"123_abc-XYZ", "123_abc-XYZ"},
{"invalid.name.with.dots", "invalid_name_with_dots"},
}
for _, tc := range tests {
actual := sanitizeFunctionName(tc.input)
if actual != tc.expected {
t.Errorf("sanitizeFunctionName(%q) = %q; expected %q", tc.input, actual, tc.expected)
}
}
}
func TestMessagesToOpenAIJSON_SanitizeToolCalls(t *testing.T) {
messages := []models.UnifiedMessage{
{
Role: "assistant",
Content: []models.UnifiedContentPart{
{Type: "text", Text: "I will use search."},
},
ToolCalls: []models.ToolCall{
{
ID: "call_1",
Type: "function",
Function: models.FunctionCall{
Name: "google.search",
Arguments: `{"query": "hello"}`,
},
},
},
},
{
Role: "tool",
Content: []models.UnifiedContentPart{
{Type: "text", Text: `{"result": "success"}`},
},
ToolCallID: stringPtr("call_1"),
Name: stringPtr("google.search"),
},
}
res, err := MessagesToOpenAIJSON(messages)
if err != nil {
t.Fatalf("MessagesToOpenAIJSON failed: %v", err)
}
if len(res) != 2 {
t.Fatalf("expected 2 messages, got %d", len(res))
}
// Verify assistant message
msg1 := res[0].(map[string]interface{})
if msg1["role"] != "assistant" {
t.Errorf("expected role assistant, got %v", msg1["role"])
}
calls := msg1["tool_calls"].([]models.ToolCall)
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
if calls[0].Function.Name != "google_search" {
t.Errorf("expected function name google_search, got %q", calls[0].Function.Name)
}
// Verify tool response message
msg2 := res[1].(map[string]interface{})
if msg2["role"] != "tool" {
t.Errorf("expected role tool, got %v", msg2["role"])
}
if msg2["name"] != "google_search" {
t.Errorf("expected tool name google_search, got %v", msg2["name"])
}
}
func TestBuildOpenAIBody_SanitizeToolsAndChoice(t *testing.T) {
req := &models.UnifiedRequest{
Model: "gpt-4o",
Tools: []models.Tool{
{
Type: "function",
Function: models.FunctionDef{
Name: "google.search",
},
},
},
ToolChoice: json.RawMessage(`{"type": "function", "function": {"name": "google.search"}}`),
}
body := BuildOpenAIBody(req, nil, false)
// Verify tools
tools := body["tools"].([]models.Tool)
if len(tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(tools))
}
if tools[0].Function.Name != "google_search" {
t.Errorf("expected tool function name google_search, got %q", tools[0].Function.Name)
}
// Verify tool_choice
toolChoice := body["tool_choice"].(map[string]interface{})
funcObj := toolChoice["function"].(map[string]interface{})
if funcObj["name"] != "google_search" {
t.Errorf("expected tool_choice function name google_search, got %q", funcObj["name"])
}
}
func stringPtr(s string) *string {
return &s
}
+7 -3
View File
@@ -15,7 +15,7 @@ const classifierSystemPrompt = `You are a task complexity classifier. Rate the f
Reply with ONLY the number. No explanation.`
func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, userMessage string) (*Decision, error) {
func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, routeCtx *RouteContext) (*Decision, error) {
// Determine the rating scale
maxRating := len(targets)
if maxRating < 2 {
@@ -30,10 +30,14 @@ func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.Mode
}
prompt := fmt.Sprintf(classifierSystemPrompt, maxRating, maxRating)
ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMessage)
userMsg := ""
if routeCtx != nil {
userMsg = routeCtx.UserMessage
}
ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMsg)
if err != nil {
// Classifier failed — fall back to heuristic
return routeHeuristic(group, targets, userMessage)
return routeHeuristic(group, targets, routeCtx)
}
rating, err := strconv.Atoi(strings.TrimSpace(ratingStr))
+163 -10
View File
@@ -2,35 +2,111 @@ package router
import (
"encoding/json"
"regexp"
"strings"
"gophergate/internal/db"
)
// HeuristicRule defines a pattern-based routing rule.
// HeuristicRule defines a pattern-based routing rule (legacy format).
type HeuristicRule struct {
Pattern string `json:"pattern"`
TargetIdx int `json:"target"`
CaseSensitive bool `json:"case_sensitive,omitempty"`
}
func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) (*Decision, error) {
// ConditionRule defines a condition-based routing rule (new format).
type ConditionRule struct {
RuleID string `json:"rule_id"`
Description string `json:"description,omitempty"`
Conditions Conditions `json:"conditions"`
PrimaryModel string `json:"primary_model"`
FallbackModel string `json:"fallback_model,omitempty"`
}
// Conditions defines the matching parameters for a rule.
type Conditions struct {
AnyOfTags []string `json:"any_of_tags,omitempty"`
MaxInputTokensLt *int `json:"max_input_tokens_lt,omitempty"`
RequiresReasoning *bool `json:"requires_reasoning,omitempty"`
RequiresToolCalling *bool `json:"requires_tool_calling,omitempty"`
HasMultimodalInput *bool `json:"has_multimodal_input,omitempty"`
IsDefaultFallback *bool `json:"is_default_fallback,omitempty"`
}
func routeHeuristic(group db.ModelGroup, targets []string, routeCtx *RouteContext) (*Decision, error) {
if routeCtx == nil {
routeCtx = &RouteContext{}
}
selected := targets[0]
reason := "default (first target)"
// If heuristic_rules is set, use them
// If heuristic_rules is set, determine format and parse
if group.HeuristicRules != nil && *group.HeuristicRules != "" {
var rules []HeuristicRule
if err := json.Unmarshal([]byte(*group.HeuristicRules), &rules); err == nil {
searchMsg := userMessage
for _, rule := range rules {
rulesJSON := *group.HeuristicRules
if isConditionBasedRules(rulesJSON) {
var condRules []ConditionRule
if err := json.Unmarshal([]byte(rulesJSON), &condRules); err == nil {
for _, rule := range condRules {
if matchConditions(rule.Conditions, routeCtx) {
// Resolve primary/fallback to concrete models in target list
targetModel := ""
if rule.PrimaryModel != "" {
targetModel = getModelInTargets(rule.PrimaryModel, targets)
}
if targetModel == "" && rule.FallbackModel != "" {
targetModel = getModelInTargets(rule.FallbackModel, targets)
}
if targetModel != "" {
selected = targetModel
reason = "matched condition rule: " + rule.RuleID
if rule.Description != "" {
reason += " (" + rule.Description + ")"
}
break
}
}
}
}
} else {
// Fallback to legacy pattern-based rules
var legacyRules []HeuristicRule
if err := json.Unmarshal([]byte(rulesJSON), &legacyRules); err == nil {
searchMsg := routeCtx.UserMessage
for _, rule := range legacyRules {
pattern := rule.Pattern
if pattern == "" {
continue // Avoid infinite matches with empty patterns
}
msg := searchMsg
if !rule.CaseSensitive {
pattern = strings.ToLower(pattern)
msg = strings.ToLower(msg)
}
if strings.Contains(msg, pattern) {
// Support both regex matching (if pattern is valid regex) and literal contains
matched := false
if strings.Contains(rule.Pattern, "(") || strings.Contains(rule.Pattern, "\\b") {
var re *regexp.Regexp
var err error
if !rule.CaseSensitive {
re, err = regexp.Compile("(?i)" + rule.Pattern)
} else {
re, err = regexp.Compile(rule.Pattern)
}
if err == nil {
matched = re.MatchString(routeCtx.UserMessage)
}
}
if !matched && strings.Contains(msg, pattern) {
matched = true
}
if matched {
if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) {
selected = targets[rule.TargetIdx]
reason = "matched heuristic rule: " + rule.Pattern
@@ -40,10 +116,11 @@ func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) (
}
}
}
}
// Built-in fallback heuristics
// Built-in fallback heuristics (if no custom rule matched)
if reason == "default (first target)" && len(targets) > 1 {
msgLower := strings.ToLower(userMessage)
msgLower := strings.ToLower(routeCtx.UserMessage)
complexIndicators := []string{
"step by step", "explain in detail", "reason through",
"think carefully", "analyze", "debug", "write code",
@@ -64,3 +141,79 @@ func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) (
Reason: reason,
}, nil
}
// isConditionBasedRules returns true if the JSON represents condition-based rules.
func isConditionBasedRules(rulesJSON string) bool {
var rules []ConditionRule
if err := json.Unmarshal([]byte(rulesJSON), &rules); err == nil && len(rules) > 0 {
// If the rule has either conditions or primary_model/rule_id, treat it as condition-based
return rules[0].PrimaryModel != "" || rules[0].RuleID != ""
}
return false
}
// matchConditions evaluates whether the given conditions match the RouteContext.
func matchConditions(cond Conditions, routeCtx *RouteContext) bool {
if cond.IsDefaultFallback != nil && *cond.IsDefaultFallback {
return true
}
// Check tags: must match any_of_tags if specified
if len(cond.AnyOfTags) > 0 {
tagMatched := false
for _, ruleTag := range cond.AnyOfTags {
for _, ctxTag := range routeCtx.Tags {
if strings.EqualFold(ruleTag, ctxTag) {
tagMatched = true
break
}
}
if tagMatched {
break
}
}
if !tagMatched {
return false
}
}
// Check max input tokens
if cond.MaxInputTokensLt != nil {
if routeCtx.InputTokens >= *cond.MaxInputTokensLt {
return false
}
}
// Check reasoning flag
if cond.RequiresReasoning != nil {
if routeCtx.RequiresReasoning != *cond.RequiresReasoning {
return false
}
}
// Check tool calling flag
if cond.RequiresToolCalling != nil {
if routeCtx.RequiresToolCalling != *cond.RequiresToolCalling {
return false
}
}
// Check multimodal flag
if cond.HasMultimodalInput != nil {
if routeCtx.HasMultimodalInput != *cond.HasMultimodalInput {
return false
}
}
return true
}
// getModelInTargets returns the model name if it exists in targets, or empty string.
func getModelInTargets(modelName string, targets []string) string {
for _, t := range targets {
if strings.EqualFold(t, modelName) {
return t
}
}
return ""
}
+142
View File
@@ -0,0 +1,142 @@
package router
import (
"testing"
"gophergate/internal/db"
)
func TestRouteHeuristic_ConditionRules(t *testing.T) {
targets := []string{
"deepseek-v4-flash", // index 0
"gemini-3-flash", // index 1
"grok-build-0.1", // index 2
"kimi-k2.6", // index 3
"mimo-v2.5-pro", // index 4
"grok-4.3", // index 5
"deepseek-v4-pro", // index 6
}
rulesJSON := `[
{
"rule_id": "fast_flow_extraction",
"conditions": {
"any_of_tags": ["fast-flow", "classification"],
"max_input_tokens_lt": 8000,
"requires_reasoning": false
},
"primary_model": "deepseek-v4-flash",
"fallback_model": "grok-build-0.1"
},
{
"rule_id": "multimodal_long_context",
"conditions": {
"any_of_tags": ["standard-pro", "long-doc"],
"has_multimodal_input": true
},
"primary_model": "gemini-3-flash",
"fallback_model": "mimo-v2.5-pro"
},
{
"rule_id": "regional_fallback_general",
"conditions": {
"is_default_fallback": true
},
"primary_model": "kimi-k2.6"
}
]`
group := db.ModelGroup{
ID: "dustins_stack",
Strategy: "heuristic",
HeuristicRules: &rulesJSON,
}
// 1. Test Match Fast Flow (condition success)
ctx1 := &RouteContext{
UserMessage: "classify this JSON",
InputTokens: 500,
HasMultimodalInput: false,
RequiresReasoning: false,
Tags: []string{"fast-flow", "classification"},
}
dec1, err := routeHeuristic(group, targets, ctx1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if dec1.SelectedModel != "deepseek-v4-flash" {
t.Fatalf("expected deepseek-v4-flash, got %s", dec1.SelectedModel)
}
// 2. Test Multimodal Long Context (condition success)
ctx2 := &RouteContext{
UserMessage: "explain this video",
InputTokens: 15000,
HasMultimodalInput: true,
RequiresReasoning: false,
Tags: []string{"standard-pro", "video-analysis"},
}
dec2, err := routeHeuristic(group, targets, ctx2)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if dec2.SelectedModel != "gemini-3-flash" {
t.Fatalf("expected gemini-3-flash, got %s", dec2.SelectedModel)
}
// 3. Test Fallback general rule
ctx3 := &RouteContext{
UserMessage: "hello there",
InputTokens: 100,
HasMultimodalInput: false,
RequiresReasoning: false,
Tags: []string{"general"},
}
dec3, err := routeHeuristic(group, targets, ctx3)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if dec3.SelectedModel != "kimi-k2.6" {
t.Fatalf("expected kimi-k2.6, got %s", dec3.SelectedModel)
}
}
func TestRouteHeuristic_LegacyRules(t *testing.T) {
targets := []string{"gpt-4o-mini", "deepseek-v4-pro", "kimi-k2.6"}
// Legacy pattern-based rule with regex
rulesJSON := `[
{"pattern": "\\b(agent|agents|tool use)\\b", "target": 1},
{"pattern": "summarize", "target": 2}
]`
group := db.ModelGroup{
ID: "heavy-logic",
Strategy: "heuristic",
HeuristicRules: &rulesJSON,
}
// 1. Test regex match
ctx1 := &RouteContext{
UserMessage: "We need an agent to do tool use",
}
dec1, err := routeHeuristic(group, targets, ctx1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if dec1.SelectedModel != "deepseek-v4-pro" {
t.Fatalf("expected deepseek-v4-pro, got %s", dec1.SelectedModel)
}
// 2. Test literal match
ctx2 := &RouteContext{
UserMessage: "Please summarize this text",
}
dec2, err := routeHeuristic(group, targets, ctx2)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if dec2.SelectedModel != "kimi-k2.6" {
t.Fatalf("expected kimi-k2.6, got %s", dec2.SelectedModel)
}
}
+16 -6
View File
@@ -16,6 +16,16 @@ type Decision struct {
Reason string `json:"reason"`
}
// RouteContext holds metadata of the request to evaluate condition rules.
type RouteContext struct {
UserMessage string `json:"user_message"`
InputTokens int `json:"input_tokens"`
HasMultimodalInput bool `json:"has_multimodal_input"`
RequiresToolCalling bool `json:"requires_tool_calling"`
RequiresReasoning bool `json:"requires_reasoning"`
Tags []string `json:"tags"`
}
// ClassifierFunc is the callback for classifier-based routing.
type ClassifierFunc func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error)
@@ -53,7 +63,7 @@ func (r *Router) IsGroup(modelID string) bool {
}
// Route resolves a group to a concrete model.
func (r *Router) Route(ctx context.Context, groupID string, userMessage string) (*Decision, error) {
func (r *Router) Route(ctx context.Context, groupID string, routeCtx *RouteContext) (*Decision, error) {
group, ok := r.groups[groupID]
if !ok {
return nil, fmt.Errorf("unknown model group: %s", groupID)
@@ -66,12 +76,12 @@ func (r *Router) Route(ctx context.Context, groupID string, userMessage string)
switch group.Strategy {
case "heuristic":
return routeHeuristic(group, targets, userMessage)
return routeHeuristic(group, targets, routeCtx)
case "classifier":
if r.classify == nil {
return routeHeuristic(group, targets, userMessage)
return routeHeuristic(group, targets, routeCtx)
}
return routeClassifier(ctx, r.classify, group, targets, userMessage)
return routeClassifier(ctx, r.classify, group, targets, routeCtx)
default:
return nil, fmt.Errorf("unknown strategy: %s", group.Strategy)
}
@@ -80,7 +90,7 @@ func (r *Router) Route(ctx context.Context, groupID string, userMessage string)
// RouteToConcrete resolves a model name to a concrete model, following group
// chains recursively until a non-group target is reached. Returns the original
// name unchanged if it is not a group.
func (r *Router) RouteToConcrete(ctx context.Context, modelID string, userMessage string) (*Decision, error) {
func (r *Router) RouteToConcrete(ctx context.Context, modelID string, routeCtx *RouteContext) (*Decision, error) {
const maxDepth = 10
visited := make(map[string]bool)
current := modelID
@@ -109,7 +119,7 @@ func (r *Router) RouteToConcrete(ctx context.Context, modelID string, userMessag
}
visited[current] = true
decision, err := r.Route(ctx, current, userMessage)
decision, err := r.Route(ctx, current, routeCtx)
if err != nil {
return nil, err
}
+164 -3
View File
@@ -329,7 +329,8 @@ func (s *Server) handleResponses(c *gin.Context) {
}
}
if s.modelRouter != nil {
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, "")
routeCtx := s.buildRouteContextFromResponses(req)
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
return
@@ -573,8 +574,8 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
log.Printf("[DEBUG] Incoming Msg[%d]: role=%s, hasToolCalls=%v, hasContent=%v", i, m.Role, len(m.ToolCalls) > 0, m.Content != nil)
}
if s.modelRouter != nil {
userMessage := extractUserMessage(req.Messages)
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, userMessage)
routeCtx := s.buildRouteContextFromChat(req)
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
return
@@ -941,3 +942,163 @@ func (s *Server) Run() error {
}
func uint32Ptr(v uint32) *uint32 { return &v }
func (s *Server) buildRouteContextFromChat(req models.ChatCompletionRequest) *router.RouteContext {
userMessage := extractUserMessage(req.Messages)
requiresToolCalling := len(req.Tools) > 0
hasMultimodal := false
inputTokens := 0
for _, msg := range req.Messages {
if strContent, ok := msg.Content.(string); ok {
inputTokens += len(strContent) / 4
} else if parts, ok := msg.Content.([]interface{}); ok {
for _, part := range parts {
if partMap, ok := part.(map[string]interface{}); ok {
partType, _ := partMap["type"].(string)
if partType == "text" {
text, _ := partMap["text"].(string)
inputTokens += len(text) / 4
} else if partType == "image_url" {
hasMultimodal = true
inputTokens += 1000 // Approximate cost of an image in tokens
}
}
}
}
}
msgLower := strings.ToLower(userMessage)
requiresReasoning := strings.Contains(msgLower, "reason") ||
strings.Contains(msgLower, "think step by step") ||
strings.Contains(msgLower, "mathematics") ||
strings.Contains(msgLower, "architecture") ||
strings.Contains(msgLower, "explain in detail")
routeCtx := &router.RouteContext{
UserMessage: userMessage,
InputTokens: inputTokens,
HasMultimodalInput: hasMultimodal,
RequiresToolCalling: requiresToolCalling,
RequiresReasoning: requiresReasoning,
}
routeCtx.Tags = s.getRouteCtxTags(routeCtx)
return routeCtx
}
func (s *Server) buildRouteContextFromResponses(req models.ResponsesRequest) *router.RouteContext {
var userMessage string
hasMultimodal := false
inputTokens := len(req.Instructions) / 4
requiresToolCalling := len(req.Tools) > 0 && string(req.Tools) != "null" && string(req.Tools) != ""
var strInput string
if err := json.Unmarshal(req.Input, &strInput); err == nil {
userMessage = strInput
inputTokens += len(userMessage) / 4
} else {
var msgs []models.ResponseInputMessage
if err := json.Unmarshal(req.Input, &msgs); err == nil {
for _, m := range msgs {
var contentStr string
if err := json.Unmarshal(m.Content, &contentStr); err == nil {
if m.Role == "user" {
userMessage = contentStr
}
inputTokens += len(contentStr) / 4
} else {
var parts []models.ContentPart
if err := json.Unmarshal(m.Content, &parts); err == nil {
for _, p := range parts {
if p.Type == "text" {
if m.Role == "user" {
userMessage = p.Text
}
inputTokens += len(p.Text) / 4
} else if p.Type == "image_url" {
hasMultimodal = true
inputTokens += 1000
}
}
}
}
}
}
}
msgLower := strings.ToLower(userMessage)
requiresReasoning := strings.Contains(msgLower, "reason") ||
strings.Contains(msgLower, "think step by step") ||
strings.Contains(msgLower, "mathematics") ||
strings.Contains(msgLower, "architecture") ||
strings.Contains(msgLower, "explain in detail")
routeCtx := &router.RouteContext{
UserMessage: userMessage,
InputTokens: inputTokens,
HasMultimodalInput: hasMultimodal,
RequiresToolCalling: requiresToolCalling,
RequiresReasoning: requiresReasoning,
}
routeCtx.Tags = s.getRouteCtxTags(routeCtx)
return routeCtx
}
func (s *Server) getRouteCtxTags(routeCtx *router.RouteContext) []string {
var tags []string
msgLower := strings.ToLower(routeCtx.UserMessage)
// fast-flow keywords
fastFlowKeywords := []string{
"classify", "classification", "label", "tag", "route", "routing", "intent",
"json", "yaml", "csv", "schema", "parse", "extract", "transform", "format", "regex",
"short answer", "brief", "concise", "tl;dr", "one line", "simple",
"fix this", "small bug", "quick fix", "typo", "syntax error",
}
for _, kw := range fastFlowKeywords {
if strings.Contains(msgLower, kw) {
tags = append(tags, "fast-flow", "classification", "json-extraction", "basic-qa")
break
}
}
// standard-pro keywords
standardProKeywords := []string{
"explain", "summarize", "rewrite", "draft", "edit", "polish", "outline",
"long doc", "document", "email", "memo", "proposal", "report", "handout", "notes",
"compare", "choose", "recommend", "tradeoff", "pros and cons", "analysis",
"code review", "debug", "bug", "feature", "api", "endpoint", "implement",
"plan", "planning", "workflow", "integration",
}
for _, kw := range standardProKeywords {
if strings.Contains(msgLower, kw) {
tags = append(tags, "standard-pro", "long-doc")
break
}
}
if routeCtx.HasMultimodalInput {
tags = append(tags, "video-analysis", "multimodal-qa")
}
// heavy-logic keywords
heavyLogicKeywords := []string{
"agent", "agents", "tool use", "function calling", "multi-agent", "orchestrate",
"system design", "scaling", "performance", "architecture review", "distributed",
"hard bug", "race condition", "deadlock", "memory leak", "crash", "production outage",
"long context", "large codebase", "many files", "complex refactor", "migration",
"research", "deep dive", "literature", "paper", "scholarly", "thorough analysis",
"deep reasoning", "think step by step", "reason through", "careful analysis",
}
for _, kw := range heavyLogicKeywords {
if strings.Contains(msgLower, kw) {
tags = append(tags, "heavy-logic", "deep-reasoning", "architecture", "hard-debugging")
break
}
}
if routeCtx.RequiresToolCalling {
tags = append(tags, "tool-heavy", "multi-step-agent", "swe-bench")
}
return tags
}