feat: implement advanced condition-based heuristic model routing
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled

Upgrades the routing engine to support tag, token limit, multimodal, reasoning, and tool calling conditions. Adds unit tests for the new routing features.
This commit is contained in:
2026-06-05 15:05:13 +00:00
parent b3354a1bbc
commit 73a82e6175
12 changed files with 731 additions and 45 deletions
+3
View File
@@ -46,6 +46,9 @@ Implements HMAC-SHA256 signed tokens for dashboard authentication. Tokens secure
### 5. WebSocket Hub (`internal/server/websocket.go`) ### 5. WebSocket Hub (`internal/server/websocket.go`)
A centralized hub for managing WebSocket connections, allowing real-time broadcast of system events, system metrics, and request logs to the dashboard. A centralized hub for managing WebSocket connections, allowing real-time broadcast of system events, system metrics, and request logs to the dashboard.
### 6. Model Group Router (`internal/router/`)
Resolves model groups (e.g., `deepseek-auto`, `dustins_stack`) into concrete models. It supports two strategies: a Classifier strategy, which uses a cheap LLM to rate task complexity, and an upgraded Heuristic strategy, which evaluates either custom condition rules (tags, token counts, multimodal input, reasoning, and tool-calling flags) or legacy keyword patterns.
## Concurrency Model ## Concurrency Model
Go's goroutines and channels are used extensively: Go's goroutines and channels are used extensively:
+1 -1
View File
@@ -22,7 +22,7 @@ A unified, high-performance LLM proxy gateway built in Go. It provides OpenAI-co
- **Image Generation:** DALL-E 2/3 (OpenAI) and Imagen 3 (Gemini) via OpenAI-compatible `/v1/images/generations` endpoint. - **Image Generation:** DALL-E 2/3 (OpenAI) and Imagen 3 (Gemini) via OpenAI-compatible `/v1/images/generations` endpoint.
- **Automatic Model Routing:** - **Automatic Model Routing:**
- **Hierarchical Routing:** Groups can target other groups, cascading through multiple levels until a concrete model is reached. Cycle detection and depth limiting (max 10) prevent infinite loops. - **Hierarchical Routing:** Groups can target other groups, cascading through multiple levels until a concrete model is reached. Cycle detection and depth limiting (max 10) prevent infinite loops.
- **Heuristic strategy:** Free, zero-latency keyword matching (e.g. "debug" or "step by step" routes to the reasoning model). - **Heuristic strategy:** Free, zero-latency routing supporting both keyword matching (regex/substrings) and condition-based checks (evaluating tags, token limits, multimodal inputs, reasoning, and tool calling requirements).
- **Classifier strategy:** Uses a cheap LLM to rate task complexity on a configurable scale (1-10), then selects the appropriate model. Bucket mapping distributes ratings proportionally across targets. - **Classifier strategy:** Uses a cheap LLM to rate task complexity on a configurable scale (1-10), then selects the appropriate model. Bucket mapping distributes ratings proportionally across targets.
- **Two-Level Dispatch:** A `dispatcher` group (classifier, threshold=10) auto-routes to tier groups by complexity score, which then apply their own internal strategies. - **Two-Level Dispatch:** A `dispatcher` group (classifier, threshold=10) auto-routes to tier groups by complexity score, which then apply their own internal strategies.
- **Metadata:** Groups support `logic_level` (1-10 complexity scale) and `primary_use` (description) fields for organizational clarity. - **Metadata:** Groups support `logic_level` (1-10 complexity scale) and `primary_use` (description) fields for organizational clarity.
+8
View File
@@ -32,6 +32,14 @@ func Init(path string) (*DB, error) {
return nil, fmt.Errorf("failed to connect to database: %w", err) return nil, fmt.Errorf("failed to connect to database: %w", err)
} }
// Enable Write-Ahead Logging (WAL) and set a busy timeout to handle concurrent access
if _, err := db.Exec("PRAGMA journal_mode=WAL;"); err != nil {
log.Printf("failed to enable WAL mode: %v", err)
}
if _, err := db.Exec("PRAGMA busy_timeout=5000;"); err != nil {
log.Printf("failed to set busy timeout: %v", err)
}
instance := &DB{db} instance := &DB{db}
// Run migrations // Run migrations
+44 -5
View File
@@ -14,9 +14,21 @@ import (
func AuthMiddleware(database *db.DB, requireAuth bool) gin.HandlerFunc { func AuthMiddleware(database *db.DB, requireAuth bool) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
if authHeader == "" {
// Fallback to checking "Authentication" header in case the client library used the wrong name
authHeader = c.GetHeader("Authentication")
}
if authHeader == "" { if authHeader == "" {
if requireAuth { if requireAuth {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing authorization header"}) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"message": "Missing Authorization or Authentication header.",
"type": "invalid_request_error",
"param": nil,
"code": "401",
},
})
return return
} }
c.Next() c.Next()
@@ -25,23 +37,50 @@ func AuthMiddleware(database *db.DB, requireAuth bool) gin.HandlerFunc {
token := strings.TrimPrefix(authHeader, "Bearer ") token := strings.TrimPrefix(authHeader, "Bearer ")
if token == authHeader { // No "Bearer " prefix if token == authHeader { // No "Bearer " prefix
if requireAuth {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"message": "Invalid authorization header format. Bearer token required.",
"type": "invalid_request_error",
"param": nil,
"code": "401",
},
})
return
}
c.Next() c.Next()
return return
} }
// Try to resolve client from database // Try to resolve client from database with a read-only SELECT
var clientID string var clientID string
err := database.Get(&clientID, "UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = 1 RETURNING client_id", token) err := database.Get(&clientID, "SELECT client_id FROM client_tokens WHERE token = ? AND is_active = 1", token)
if err == nil { if err == nil {
c.Set("auth", models.AuthInfo{ c.Set("auth", models.AuthInfo{
Token: token, Token: token,
ClientID: clientID, ClientID: clientID,
}) })
// Update last_used_at asynchronously so that database locks or write delays
// do not block or fail the client's request authentication.
go func(t string) {
if _, updateErr := database.Exec("UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ?", t); updateErr != nil {
log.Printf("Warning: failed to update client token last_used_at: %v", updateErr)
}
}(token)
c.Next() c.Next()
} else { } else {
log.Printf("Token not found or inactive in DB: %s", token) log.Printf("Token not found, inactive or error in DB: %s (err: %v)", token, err)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid or inactive token"}) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"message": "Invalid or inactive client token.",
"type": "invalid_request_error",
"param": nil,
"code": "401",
},
})
} }
} }
} }
+14 -6
View File
@@ -338,6 +338,7 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
} }
// Map Tools // Map Tools
hasMappedTools := false
if len(req.Tools) > 0 { if len(req.Tools) > 0 {
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}} geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
for _, t := range req.Tools { for _, t := range req.Tools {
@@ -345,13 +346,16 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function) geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
} }
} }
body.Tools = []GeminiTool{geminiTool} if len(geminiTool.FunctionDeclarations) > 0 {
body.Tools = []GeminiTool{geminiTool}
hasMappedTools = true
}
} }
baseURL := p.config.BaseURL baseURL := p.config.BaseURL
lowerModel := strings.ToLower(req.Model) lowerModel := strings.ToLower(req.Model)
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") { if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") || hasMappedTools {
// Use v1beta for preview and newer models // Use v1beta for preview, newer models, or when using tools
if !strings.Contains(baseURL, "v1beta") { if !strings.Contains(baseURL, "v1beta") {
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1) baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
} }
@@ -578,6 +582,7 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
GenerationConfig: genConfig, GenerationConfig: genConfig,
} }
hasMappedTools := false
if len(req.Tools) > 0 { if len(req.Tools) > 0 {
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}} geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
for _, t := range req.Tools { for _, t := range req.Tools {
@@ -585,13 +590,16 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function) geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
} }
} }
body.Tools = []GeminiTool{geminiTool} if len(geminiTool.FunctionDeclarations) > 0 {
body.Tools = []GeminiTool{geminiTool}
hasMappedTools = true
}
} }
baseURL := p.config.BaseURL baseURL := p.config.BaseURL
lowerModel := strings.ToLower(req.Model) lowerModel := strings.ToLower(req.Model)
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") { if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") || hasMappedTools {
// Use v1beta for preview and newer models // Use v1beta for preview, newer models, or when using tools
if !strings.Contains(baseURL, "v1beta") { if !strings.Contains(baseURL, "v1beta") {
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1) baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
} }
+33 -2
View File
@@ -10,6 +10,22 @@ import (
"gophergate/internal/models" "gophergate/internal/models"
) )
// sanitizeFunctionName returns name with every rune that is not an ASCII
// letter, an ASCII digit, '_' or '-' replaced by a single underscore.
// If the result is empty (which only happens for an empty input), the
// placeholder "function" is returned instead.
func sanitizeFunctionName(name string) string {
	cleaned := strings.Map(func(r rune) rune {
		switch {
		case r >= 'a' && r <= 'z',
			r >= 'A' && r <= 'Z',
			r >= '0' && r <= '9',
			r == '_',
			r == '-':
			return r
		default:
			// Anything outside the allowed set collapses to one underscore.
			return '_'
		}
	}, name)

	if len(cleaned) == 0 {
		return "function"
	}
	return cleaned
}
// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON, including tools and images. // MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON, including tools and images.
func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) { func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) {
var result []interface{} var result []interface{}
@@ -35,7 +51,7 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
msg["tool_call_id"] = id msg["tool_call_id"] = id
if m.Name != nil { if m.Name != nil {
msg["name"] = *m.Name msg["name"] = sanitizeFunctionName(*m.Name)
} }
result = append(result, msg) result = append(result, msg)
continue continue
@@ -91,6 +107,7 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
if sanitizedCalls[i].Type == "" { if sanitizedCalls[i].Type == "" {
sanitizedCalls[i].Type = "function" sanitizedCalls[i].Type = "function"
} }
sanitizedCalls[i].Function.Name = sanitizeFunctionName(sanitizedCalls[i].Function.Name)
} }
msg["tool_calls"] = sanitizedCalls msg["tool_calls"] = sanitizedCalls
msg["content"] = "" // OpenAI requirement: content must be string if tool_calls present msg["content"] = "" // OpenAI requirement: content must be string if tool_calls present
@@ -124,11 +141,25 @@ func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{},
body["max_tokens"] = *request.MaxTokens body["max_tokens"] = *request.MaxTokens
} }
if len(request.Tools) > 0 { if len(request.Tools) > 0 {
body["tools"] = request.Tools sanitizedTools := make([]models.Tool, len(request.Tools))
copy(sanitizedTools, request.Tools)
for i := range sanitizedTools {
if sanitizedTools[i].Type == "function" {
sanitizedTools[i].Function.Name = sanitizeFunctionName(sanitizedTools[i].Function.Name)
}
}
body["tools"] = sanitizedTools
} }
if request.ToolChoice != nil { if request.ToolChoice != nil {
var toolChoice interface{} var toolChoice interface{}
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil { if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
if tcMap, ok := toolChoice.(map[string]interface{}); ok {
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
if name, ok := funcMap["name"].(string); ok {
funcMap["name"] = sanitizeFunctionName(name)
}
}
}
body["tool_choice"] = toolChoice body["tool_choice"] = toolChoice
} }
} }
+127
View File
@@ -0,0 +1,127 @@
package providers
import (
"encoding/json"
"testing"
"gophergate/internal/models"
)
// TestSanitizeFunctionName feeds a fixed set of names through
// sanitizeFunctionName and reports every input whose result differs from the
// expected sanitized form.
func TestSanitizeFunctionName(t *testing.T) {
	type sample struct {
		in   string
		want string
	}

	samples := []sample{
		{in: "google-search", want: "google-search"},
		{in: "google.search", want: "google_search"},
		{in: "google search", want: "google_search"},
		{in: "web_search(query)", want: "web_search_query_"},
		{in: "", want: "function"},
		{in: "123_abc-XYZ", want: "123_abc-XYZ"},
		{in: "invalid.name.with.dots", want: "invalid_name_with_dots"},
	}

	for i := range samples {
		s := samples[i]
		if got := sanitizeFunctionName(s.in); got != s.want {
			t.Errorf("sanitizeFunctionName(%q) = %q; expected %q", s.in, got, s.want)
		}
	}
}
// TestMessagesToOpenAIJSON_SanitizeToolCalls converts an assistant message
// carrying a tool call named "google.search", followed by the matching tool
// response, and expects the name to come out as "google_search" in both the
// tool call and the tool message.
func TestMessagesToOpenAIJSON_SanitizeToolCalls(t *testing.T) {
	assistantMsg := models.UnifiedMessage{
		Role: "assistant",
		Content: []models.UnifiedContentPart{
			{Type: "text", Text: "I will use search."},
		},
		ToolCalls: []models.ToolCall{
			{
				ID:   "call_1",
				Type: "function",
				Function: models.FunctionCall{
					Name:      "google.search",
					Arguments: `{"query": "hello"}`,
				},
			},
		},
	}
	toolMsg := models.UnifiedMessage{
		Role: "tool",
		Content: []models.UnifiedContentPart{
			{Type: "text", Text: `{"result": "success"}`},
		},
		ToolCallID: stringPtr("call_1"),
		Name:       stringPtr("google.search"),
	}

	converted, err := MessagesToOpenAIJSON([]models.UnifiedMessage{assistantMsg, toolMsg})
	if err != nil {
		t.Fatalf("MessagesToOpenAIJSON failed: %v", err)
	}
	if count := len(converted); count != 2 {
		t.Fatalf("expected 2 messages, got %d", count)
	}

	// Assistant message: the tool call's function name must be sanitized.
	assistantOut := converted[0].(map[string]interface{})
	if role := assistantOut["role"]; role != "assistant" {
		t.Errorf("expected role assistant, got %v", role)
	}
	toolCalls := assistantOut["tool_calls"].([]models.ToolCall)
	if count := len(toolCalls); count != 1 {
		t.Fatalf("expected 1 tool call, got %d", count)
	}
	if fn := toolCalls[0].Function.Name; fn != "google_search" {
		t.Errorf("expected function name google_search, got %q", fn)
	}

	// Tool response message: the "name" field must be sanitized too.
	toolOut := converted[1].(map[string]interface{})
	if role := toolOut["role"]; role != "tool" {
		t.Errorf("expected role tool, got %v", role)
	}
	if name := toolOut["name"]; name != "google_search" {
		t.Errorf("expected tool name google_search, got %v", name)
	}
}
// TestBuildOpenAIBody_SanitizeToolsAndChoice builds a request body from a
// request whose tool definition and tool_choice both name "google.search",
// and expects "google_search" in both places of the resulting body.
func TestBuildOpenAIBody_SanitizeToolsAndChoice(t *testing.T) {
	searchTool := models.Tool{
		Type: "function",
		Function: models.FunctionDef{
			Name: "google.search",
		},
	}
	request := &models.UnifiedRequest{
		Model:      "gpt-4o",
		Tools:      []models.Tool{searchTool},
		ToolChoice: json.RawMessage(`{"type": "function", "function": {"name": "google.search"}}`),
	}

	built := BuildOpenAIBody(request, nil, false)

	// The tool definition's function name must be sanitized.
	gotTools := built["tools"].([]models.Tool)
	if count := len(gotTools); count != 1 {
		t.Fatalf("expected 1 tool, got %d", count)
	}
	if fn := gotTools[0].Function.Name; fn != "google_search" {
		t.Errorf("expected tool function name google_search, got %q", fn)
	}

	// The forced tool_choice must reference the same sanitized name.
	choice := built["tool_choice"].(map[string]interface{})
	choiceFn := choice["function"].(map[string]interface{})
	if choiceFn["name"] != "google_search" {
		t.Errorf("expected tool_choice function name google_search, got %q", choiceFn["name"])
	}
}
// stringPtr returns a pointer to a freshly allocated copy of s.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
+7 -3
View File
@@ -15,7 +15,7 @@ const classifierSystemPrompt = `You are a task complexity classifier. Rate the f
Reply with ONLY the number. No explanation.` Reply with ONLY the number. No explanation.`
func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, userMessage string) (*Decision, error) { func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, routeCtx *RouteContext) (*Decision, error) {
// Determine the rating scale // Determine the rating scale
maxRating := len(targets) maxRating := len(targets)
if maxRating < 2 { if maxRating < 2 {
@@ -30,10 +30,14 @@ func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.Mode
} }
prompt := fmt.Sprintf(classifierSystemPrompt, maxRating, maxRating) prompt := fmt.Sprintf(classifierSystemPrompt, maxRating, maxRating)
ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMessage) userMsg := ""
if routeCtx != nil {
userMsg = routeCtx.UserMessage
}
ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMsg)
if err != nil { if err != nil {
// Classifier failed — fall back to heuristic // Classifier failed — fall back to heuristic
return routeHeuristic(group, targets, userMessage) return routeHeuristic(group, targets, routeCtx)
} }
rating, err := strconv.Atoi(strings.TrimSpace(ratingStr)) rating, err := strconv.Atoi(strings.TrimSpace(ratingStr))
+172 -19
View File
@@ -2,48 +2,125 @@ package router
import ( import (
"encoding/json" "encoding/json"
"regexp"
"strings" "strings"
"gophergate/internal/db" "gophergate/internal/db"
) )
// HeuristicRule defines a pattern-based routing rule. // HeuristicRule defines a pattern-based routing rule (legacy format).
type HeuristicRule struct { type HeuristicRule struct {
Pattern string `json:"pattern"` Pattern string `json:"pattern"`
TargetIdx int `json:"target"` TargetIdx int `json:"target"`
CaseSensitive bool `json:"case_sensitive,omitempty"` CaseSensitive bool `json:"case_sensitive,omitempty"`
} }
func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) (*Decision, error) { // ConditionRule defines a condition-based routing rule (new format).
type ConditionRule struct {
// RuleID identifies the rule; routeHeuristic includes it in the decision reason.
RuleID string `json:"rule_id"`
// Description is optional; when non-empty it is appended in parentheses to the decision reason.
Description string `json:"description,omitempty"`
// Conditions are evaluated by matchConditions against the request's RouteContext.
Conditions Conditions `json:"conditions"`
// PrimaryModel is looked up (case-insensitively) in the group's targets first.
PrimaryModel string `json:"primary_model"`
// FallbackModel is tried only when PrimaryModel is not found in the group's targets.
// If neither is found, the rule is passed over and the next rule is considered.
FallbackModel string `json:"fallback_model,omitempty"`
}
// Conditions defines the matching parameters for a rule.
// Every field is optional: a nil pointer or empty slice means "not checked".
// All specified checks must pass for the rule to match, except that
// IsDefaultFallback set to true matches unconditionally.
type Conditions struct {
// AnyOfTags matches when at least one entry equals (case-insensitively) one of the request's tags.
AnyOfTags []string `json:"any_of_tags,omitempty"`
// MaxInputTokensLt is an exclusive upper bound: the rule fails when the request's input token count is >= this value.
MaxInputTokensLt *int `json:"max_input_tokens_lt,omitempty"`
// RequiresReasoning, when set, must equal the request's RequiresReasoning flag.
RequiresReasoning *bool `json:"requires_reasoning,omitempty"`
// RequiresToolCalling, when set, must equal the request's RequiresToolCalling flag.
RequiresToolCalling *bool `json:"requires_tool_calling,omitempty"`
// HasMultimodalInput, when set, must equal the request's HasMultimodalInput flag.
HasMultimodalInput *bool `json:"has_multimodal_input,omitempty"`
// IsDefaultFallback, when true, makes the rule match without evaluating any other condition.
IsDefaultFallback *bool `json:"is_default_fallback,omitempty"`
}
func routeHeuristic(group db.ModelGroup, targets []string, routeCtx *RouteContext) (*Decision, error) {
if routeCtx == nil {
routeCtx = &RouteContext{}
}
selected := targets[0] selected := targets[0]
reason := "default (first target)" reason := "default (first target)"
// If heuristic_rules is set, use them // If heuristic_rules is set, determine format and parse
if group.HeuristicRules != nil && *group.HeuristicRules != "" { if group.HeuristicRules != nil && *group.HeuristicRules != "" {
var rules []HeuristicRule rulesJSON := *group.HeuristicRules
if err := json.Unmarshal([]byte(*group.HeuristicRules), &rules); err == nil {
searchMsg := userMessage if isConditionBasedRules(rulesJSON) {
for _, rule := range rules { var condRules []ConditionRule
pattern := rule.Pattern if err := json.Unmarshal([]byte(rulesJSON), &condRules); err == nil {
msg := searchMsg for _, rule := range condRules {
if !rule.CaseSensitive { if matchConditions(rule.Conditions, routeCtx) {
pattern = strings.ToLower(pattern) // Resolve primary/fallback to concrete models in target list
msg = strings.ToLower(msg) targetModel := ""
if rule.PrimaryModel != "" {
targetModel = getModelInTargets(rule.PrimaryModel, targets)
}
if targetModel == "" && rule.FallbackModel != "" {
targetModel = getModelInTargets(rule.FallbackModel, targets)
}
if targetModel != "" {
selected = targetModel
reason = "matched condition rule: " + rule.RuleID
if rule.Description != "" {
reason += " (" + rule.Description + ")"
}
break
}
}
} }
if strings.Contains(msg, pattern) { }
if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) { } else {
selected = targets[rule.TargetIdx] // Fallback to legacy pattern-based rules
reason = "matched heuristic rule: " + rule.Pattern var legacyRules []HeuristicRule
break if err := json.Unmarshal([]byte(rulesJSON), &legacyRules); err == nil {
searchMsg := routeCtx.UserMessage
for _, rule := range legacyRules {
pattern := rule.Pattern
if pattern == "" {
continue // Avoid infinite matches with empty patterns
}
msg := searchMsg
if !rule.CaseSensitive {
pattern = strings.ToLower(pattern)
msg = strings.ToLower(msg)
}
// Support both regex matching (if pattern is valid regex) and literal contains
matched := false
if strings.Contains(rule.Pattern, "(") || strings.Contains(rule.Pattern, "\\b") {
var re *regexp.Regexp
var err error
if !rule.CaseSensitive {
re, err = regexp.Compile("(?i)" + rule.Pattern)
} else {
re, err = regexp.Compile(rule.Pattern)
}
if err == nil {
matched = re.MatchString(routeCtx.UserMessage)
}
}
if !matched && strings.Contains(msg, pattern) {
matched = true
}
if matched {
if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) {
selected = targets[rule.TargetIdx]
reason = "matched heuristic rule: " + rule.Pattern
break
}
} }
} }
} }
} }
} }
// Built-in fallback heuristics // Built-in fallback heuristics (if no custom rule matched)
if reason == "default (first target)" && len(targets) > 1 { if reason == "default (first target)" && len(targets) > 1 {
msgLower := strings.ToLower(userMessage) msgLower := strings.ToLower(routeCtx.UserMessage)
complexIndicators := []string{ complexIndicators := []string{
"step by step", "explain in detail", "reason through", "step by step", "explain in detail", "reason through",
"think carefully", "analyze", "debug", "write code", "think carefully", "analyze", "debug", "write code",
@@ -64,3 +141,79 @@ func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) (
Reason: reason, Reason: reason,
}, nil }, nil
} }
// isConditionBasedRules reports whether rulesJSON holds condition-based rules
// (the ConditionRule format) rather than legacy pattern rules.
//
// rulesJSON must decode as a JSON array. It is treated as condition-based when
// any entry sets rule_id, primary_model, or fallback_model; the legacy
// HeuristicRule format defines none of these keys. Every entry is inspected,
// so a rule set whose first entry omits rule_id and primary_model (for
// example, one that relies on fallback_model alone) is still recognised
// instead of being parsed as legacy rules and silently ignored.
//
// Malformed JSON and empty arrays return false, which sends the caller down
// the legacy parsing path.
func isConditionBasedRules(rulesJSON string) bool {
	var rules []ConditionRule
	if err := json.Unmarshal([]byte(rulesJSON), &rules); err != nil {
		return false
	}
	for i := range rules {
		if rules[i].RuleID != "" || rules[i].PrimaryModel != "" || rules[i].FallbackModel != "" {
			return true
		}
	}
	return false
}
// matchConditions reports whether routeCtx satisfies cond.
//
// A rule flagged is_default_fallback=true matches immediately. Otherwise every
// condition that is present must hold: at least one any_of_tags entry must
// equal (case-insensitively) one of the context's tags, the input token count
// must be strictly below max_input_tokens_lt, and each boolean flag that is set
// must equal the corresponding flag on the context. Conditions that are absent
// are not checked, so an empty Conditions value matches everything.
func matchConditions(cond Conditions, routeCtx *RouteContext) bool {
	if fallback := cond.IsDefaultFallback; fallback != nil && *fallback {
		return true
	}

	// sharesTag is true when any context tag equals any rule tag, ignoring case.
	sharesTag := func() bool {
		for _, have := range routeCtx.Tags {
			for _, want := range cond.AnyOfTags {
				if strings.EqualFold(want, have) {
					return true
				}
			}
		}
		return false
	}

	switch {
	case len(cond.AnyOfTags) > 0 && !sharesTag():
		return false
	case cond.MaxInputTokensLt != nil && routeCtx.InputTokens >= *cond.MaxInputTokensLt:
		return false
	case cond.RequiresReasoning != nil && *cond.RequiresReasoning != routeCtx.RequiresReasoning:
		return false
	case cond.RequiresToolCalling != nil && *cond.RequiresToolCalling != routeCtx.RequiresToolCalling:
		return false
	case cond.HasMultimodalInput != nil && *cond.HasMultimodalInput != routeCtx.HasMultimodalInput:
		return false
	}
	return true
}
// getModelInTargets looks modelName up in targets, ignoring case, and returns
// the first matching entry exactly as it is spelled in targets. It returns the
// empty string when no entry matches.
func getModelInTargets(modelName string, targets []string) string {
	for idx := 0; idx < len(targets); idx++ {
		if candidate := targets[idx]; strings.EqualFold(candidate, modelName) {
			return candidate
		}
	}
	return ""
}
+142
View File
@@ -0,0 +1,142 @@
package router
import (
"testing"
"gophergate/internal/db"
)
// TestRouteHeuristic_ConditionRules runs three request contexts through a
// heuristic group configured with condition-based rules and checks which
// model each one resolves to: a tag/token/reasoning match, a tag/multimodal
// match, and the default-fallback rule.
func TestRouteHeuristic_ConditionRules(t *testing.T) {
	targets := []string{
		"deepseek-v4-flash", // index 0
		"gemini-3-flash",    // index 1
		"grok-build-0.1",    // index 2
		"kimi-k2.6",         // index 3
		"mimo-v2.5-pro",     // index 4
		"grok-4.3",          // index 5
		"deepseek-v4-pro",   // index 6
	}

	rulesJSON := `[
{
"rule_id": "fast_flow_extraction",
"conditions": {
"any_of_tags": ["fast-flow", "classification"],
"max_input_tokens_lt": 8000,
"requires_reasoning": false
},
"primary_model": "deepseek-v4-flash",
"fallback_model": "grok-build-0.1"
},
{
"rule_id": "multimodal_long_context",
"conditions": {
"any_of_tags": ["standard-pro", "long-doc"],
"has_multimodal_input": true
},
"primary_model": "gemini-3-flash",
"fallback_model": "mimo-v2.5-pro"
},
{
"rule_id": "regional_fallback_general",
"conditions": {
"is_default_fallback": true
},
"primary_model": "kimi-k2.6"
}
]`

	group := db.ModelGroup{
		ID:             "dustins_stack",
		Strategy:       "heuristic",
		HeuristicRules: &rulesJSON,
	}

	scenarios := []struct {
		routeCtx *RouteContext
		want     string
	}{
		{
			// 1. Fast-flow rule: tag matches, tokens under the limit, no reasoning.
			routeCtx: &RouteContext{
				UserMessage:        "classify this JSON",
				InputTokens:        500,
				HasMultimodalInput: false,
				RequiresReasoning:  false,
				Tags:               []string{"fast-flow", "classification"},
			},
			want: "deepseek-v4-flash",
		},
		{
			// 2. Multimodal rule: tag matches and the input is multimodal.
			routeCtx: &RouteContext{
				UserMessage:        "explain this video",
				InputTokens:        15000,
				HasMultimodalInput: true,
				RequiresReasoning:  false,
				Tags:               []string{"standard-pro", "video-analysis"},
			},
			want: "gemini-3-flash",
		},
		{
			// 3. No specific rule matches, so the default-fallback rule applies.
			routeCtx: &RouteContext{
				UserMessage:        "hello there",
				InputTokens:        100,
				HasMultimodalInput: false,
				RequiresReasoning:  false,
				Tags:               []string{"general"},
			},
			want: "kimi-k2.6",
		},
	}

	for _, sc := range scenarios {
		decision, err := routeHeuristic(group, targets, sc.routeCtx)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if decision.SelectedModel != sc.want {
			t.Fatalf("expected %s, got %s", sc.want, decision.SelectedModel)
		}
	}
}
// TestRouteHeuristic_LegacyRules checks the legacy pattern-based rule format:
// one rule written as a regular expression and one plain substring rule, each
// selecting a target by index.
func TestRouteHeuristic_LegacyRules(t *testing.T) {
	targets := []string{"gpt-4o-mini", "deepseek-v4-pro", "kimi-k2.6"}

	// Legacy pattern-based rules; the first one is a regex.
	rulesJSON := `[
{"pattern": "\\b(agent|agents|tool use)\\b", "target": 1},
{"pattern": "summarize", "target": 2}
]`

	group := db.ModelGroup{
		ID:             "heavy-logic",
		Strategy:       "heuristic",
		HeuristicRules: &rulesJSON,
	}

	scenarios := []struct {
		message string
		want    string
	}{
		// 1. Regex match.
		{message: "We need an agent to do tool use", want: "deepseek-v4-pro"},
		// 2. Literal substring match.
		{message: "Please summarize this text", want: "kimi-k2.6"},
	}

	for _, sc := range scenarios {
		routeCtx := &RouteContext{UserMessage: sc.message}
		decision, err := routeHeuristic(group, targets, routeCtx)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if decision.SelectedModel != sc.want {
			t.Fatalf("expected %s, got %s", sc.want, decision.SelectedModel)
		}
	}
}
+16 -6
View File
@@ -16,6 +16,16 @@ type Decision struct {
Reason string `json:"reason"` Reason string `json:"reason"`
} }
// RouteContext holds metadata of the request to evaluate condition rules.
// It is passed through Route/RouteToConcrete to the heuristic and classifier
// strategies; a nil *RouteContext is treated as an empty one by routeHeuristic.
type RouteContext struct {
// UserMessage is the text matched by legacy pattern rules and the built-in
// complexity indicators, and the message handed to the classifier.
UserMessage string `json:"user_message"`
// InputTokens is compared against a rule's max_input_tokens_lt condition.
// NOTE(review): appears to be a rough estimate computed by the server, not an exact tokenizer count — confirm against the caller.
InputTokens int `json:"input_tokens"`
// HasMultimodalInput is compared against a rule's has_multimodal_input condition.
HasMultimodalInput bool `json:"has_multimodal_input"`
// RequiresToolCalling is compared against a rule's requires_tool_calling condition.
RequiresToolCalling bool `json:"requires_tool_calling"`
// RequiresReasoning is compared against a rule's requires_reasoning condition.
RequiresReasoning bool `json:"requires_reasoning"`
// Tags are matched case-insensitively against a rule's any_of_tags condition.
Tags []string `json:"tags"`
}
// ClassifierFunc is the callback for classifier-based routing. // ClassifierFunc is the callback for classifier-based routing.
type ClassifierFunc func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) type ClassifierFunc func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error)
@@ -53,7 +63,7 @@ func (r *Router) IsGroup(modelID string) bool {
} }
// Route resolves a group to a concrete model. // Route resolves a group to a concrete model.
func (r *Router) Route(ctx context.Context, groupID string, userMessage string) (*Decision, error) { func (r *Router) Route(ctx context.Context, groupID string, routeCtx *RouteContext) (*Decision, error) {
group, ok := r.groups[groupID] group, ok := r.groups[groupID]
if !ok { if !ok {
return nil, fmt.Errorf("unknown model group: %s", groupID) return nil, fmt.Errorf("unknown model group: %s", groupID)
@@ -66,12 +76,12 @@ func (r *Router) Route(ctx context.Context, groupID string, userMessage string)
switch group.Strategy { switch group.Strategy {
case "heuristic": case "heuristic":
return routeHeuristic(group, targets, userMessage) return routeHeuristic(group, targets, routeCtx)
case "classifier": case "classifier":
if r.classify == nil { if r.classify == nil {
return routeHeuristic(group, targets, userMessage) return routeHeuristic(group, targets, routeCtx)
} }
return routeClassifier(ctx, r.classify, group, targets, userMessage) return routeClassifier(ctx, r.classify, group, targets, routeCtx)
default: default:
return nil, fmt.Errorf("unknown strategy: %s", group.Strategy) return nil, fmt.Errorf("unknown strategy: %s", group.Strategy)
} }
@@ -80,7 +90,7 @@ func (r *Router) Route(ctx context.Context, groupID string, userMessage string)
// RouteToConcrete resolves a model name to a concrete model, following group // RouteToConcrete resolves a model name to a concrete model, following group
// chains recursively until a non-group target is reached. Returns the original // chains recursively until a non-group target is reached. Returns the original
// name unchanged if it is not a group. // name unchanged if it is not a group.
func (r *Router) RouteToConcrete(ctx context.Context, modelID string, userMessage string) (*Decision, error) { func (r *Router) RouteToConcrete(ctx context.Context, modelID string, routeCtx *RouteContext) (*Decision, error) {
const maxDepth = 10 const maxDepth = 10
visited := make(map[string]bool) visited := make(map[string]bool)
current := modelID current := modelID
@@ -109,7 +119,7 @@ func (r *Router) RouteToConcrete(ctx context.Context, modelID string, userMessag
} }
visited[current] = true visited[current] = true
decision, err := r.Route(ctx, current, userMessage) decision, err := r.Route(ctx, current, routeCtx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
+164 -3
View File
@@ -329,7 +329,8 @@ func (s *Server) handleResponses(c *gin.Context) {
} }
} }
if s.modelRouter != nil { if s.modelRouter != nil {
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, "") routeCtx := s.buildRouteContextFromResponses(req)
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)}) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
return return
@@ -573,8 +574,8 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
log.Printf("[DEBUG] Incoming Msg[%d]: role=%s, hasToolCalls=%v, hasContent=%v", i, m.Role, len(m.ToolCalls) > 0, m.Content != nil) log.Printf("[DEBUG] Incoming Msg[%d]: role=%s, hasToolCalls=%v, hasContent=%v", i, m.Role, len(m.ToolCalls) > 0, m.Content != nil)
} }
if s.modelRouter != nil { if s.modelRouter != nil {
userMessage := extractUserMessage(req.Messages) routeCtx := s.buildRouteContextFromChat(req)
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, userMessage) decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)}) c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
return return
@@ -941,3 +942,163 @@ func (s *Server) Run() error {
} }
func uint32Ptr(v uint32) *uint32 { return &v } func uint32Ptr(v uint32) *uint32 { return &v }
// buildRouteContextFromChat derives the routing signals for a Chat Completions
// request.
//
// The returned context carries:
//   - UserMessage: whatever extractUserMessage returns for req.Messages.
//   - InputTokens: a rough size estimate, one token per four bytes of text
//     (computed per message/part, so each piece rounds down on its own) plus
//     a flat charge for every image_url part.
//   - HasMultimodalInput: true when any message contains an image_url part.
//   - RequiresToolCalling: true when the request declares at least one tool.
//   - RequiresReasoning: true when the user message contains one of a small
//     set of reasoning cue phrases (case-insensitive substring match).
//   - Tags: the result of getRouteCtxTags applied to the fields above.
func (s *Server) buildRouteContextFromChat(req models.ChatCompletionRequest) *router.RouteContext {
	// Flat token charge applied to each image part; a coarse approximation.
	const imageTokenEstimate = 1000

	prompt := extractUserMessage(req.Messages)

	var (
		estimatedTokens int
		sawImage        bool
	)
	for _, message := range req.Messages {
		switch content := message.Content.(type) {
		case string:
			// Plain string content.
			estimatedTokens += len(content) / 4
		case []interface{}:
			// Structured content: a list of typed parts decoded from JSON.
			for _, raw := range content {
				fields, isMap := raw.(map[string]interface{})
				if !isMap {
					continue
				}
				kind, _ := fields["type"].(string)
				switch kind {
				case "text":
					body, _ := fields["text"].(string)
					estimatedTokens += len(body) / 4
				case "image_url":
					sawImage = true
					estimatedTokens += imageTokenEstimate
				}
			}
		}
	}

	// Substring cues that mark the request as needing a reasoning-capable model.
	lowered := strings.ToLower(prompt)
	needsReasoning := false
	for _, cue := range []string{
		"reason",
		"think step by step",
		"mathematics",
		"architecture",
		"explain in detail",
	} {
		if strings.Contains(lowered, cue) {
			needsReasoning = true
			break
		}
	}

	result := &router.RouteContext{
		UserMessage:         prompt,
		InputTokens:         estimatedTokens,
		HasMultimodalInput:  sawImage,
		RequiresToolCalling: len(req.Tools) > 0,
		RequiresReasoning:   needsReasoning,
	}
	// Tags are derived from the fields populated above, so compute them last.
	result.Tags = s.getRouteCtxTags(result)
	return result
}
// buildRouteContextFromResponses derives the routing signals for a Responses
// API request.
//
// req.Input is decoded either as a bare JSON string or as a list of input
// messages whose content is itself a string or a list of typed parts. Any
// other shape is ignored and contributes nothing to the context.
//
// The returned context carries:
//   - UserMessage: the string input, or the last text seen in a "user" message.
//   - InputTokens: a rough estimate, one token per four bytes of instructions
//     and text (each piece rounds down on its own) plus a flat 1000 per image.
//   - HasMultimodalInput: true when any content part is an image.
//   - RequiresToolCalling: true when req.Tools holds at least one tool.
//   - RequiresReasoning: true when the user message contains one of a small
//     set of reasoning cue phrases (case-insensitive substring match).
//   - Tags: the result of getRouteCtxTags applied to the fields above.
func (s *Server) buildRouteContextFromResponses(req models.ResponsesRequest) *router.RouteContext {
	var userMessage string
	hasMultimodal := false
	inputTokens := len(req.Instructions) / 4

	// An empty tools array ("[]") or JSON null means no tools were offered.
	// If the payload is not an array at all, keep the lenient legacy rule:
	// any non-null payload counts as tool usage.
	requiresToolCalling := false
	if len(req.Tools) > 0 {
		var declared []json.RawMessage
		if err := json.Unmarshal(req.Tools, &declared); err == nil {
			requiresToolCalling = len(declared) > 0
		} else {
			requiresToolCalling = string(req.Tools) != "null"
		}
	}

	var strInput string
	var msgs []models.ResponseInputMessage
	if err := json.Unmarshal(req.Input, &strInput); err == nil {
		// Simple form: the whole input is a single string.
		userMessage = strInput
		inputTokens += len(strInput) / 4
	} else if err := json.Unmarshal(req.Input, &msgs); err == nil {
		for _, m := range msgs {
			// Message content may be a plain string...
			var contentStr string
			if err := json.Unmarshal(m.Content, &contentStr); err == nil {
				if m.Role == "user" {
					userMessage = contentStr
				}
				inputTokens += len(contentStr) / 4
				continue
			}

			// ...or a list of typed parts.
			var parts []models.ContentPart
			if err := json.Unmarshal(m.Content, &parts); err != nil {
				continue // unrecognised content shape: nothing to extract
			}
			for _, p := range parts {
				switch p.Type {
				// "text" is the Chat Completions spelling; the Responses API
				// uses "input_text" for inputs and "output_text" for prior
				// assistant turns replayed as history.
				case "text", "input_text", "output_text":
					if m.Role == "user" {
						userMessage = p.Text
					}
					inputTokens += len(p.Text) / 4
				// "image_url" is the Chat Completions spelling; the Responses
				// API uses "input_image".
				case "image_url", "input_image":
					hasMultimodal = true
					inputTokens += 1000 // approximate cost of an image in tokens
				}
			}
		}
	}

	// Substring cues that mark the request as needing a reasoning-capable model.
	msgLower := strings.ToLower(userMessage)
	requiresReasoning := strings.Contains(msgLower, "reason") ||
		strings.Contains(msgLower, "think step by step") ||
		strings.Contains(msgLower, "mathematics") ||
		strings.Contains(msgLower, "architecture") ||
		strings.Contains(msgLower, "explain in detail")

	routeCtx := &router.RouteContext{
		UserMessage:         userMessage,
		InputTokens:         inputTokens,
		HasMultimodalInput:  hasMultimodal,
		RequiresToolCalling: requiresToolCalling,
		RequiresReasoning:   requiresReasoning,
	}
	// Tags are derived from the fields populated above, so compute them last.
	routeCtx.Tags = s.getRouteCtxTags(routeCtx)
	return routeCtx
}
// getRouteCtxTags maps a route context onto the tag names used by routing
// conditions.
//
// Keyword groups are matched as case-insensitive substrings of
// routeCtx.UserMessage; a group contributes its tags at most once no matter
// how many of its keywords appear. Tags are appended in a fixed order:
// fast-flow, standard-pro, multimodal, heavy-logic, then tool-calling. The
// result is nil when nothing matches.
func (s *Server) getRouteCtxTags(routeCtx *router.RouteContext) []string {
	lowered := strings.ToLower(routeCtx.UserMessage)

	// mentionsAny reports whether the lower-cased message contains at least
	// one of the given keywords.
	mentionsAny := func(keywords ...string) bool {
		for _, keyword := range keywords {
			if strings.Contains(lowered, keyword) {
				return true
			}
		}
		return false
	}

	var tags []string

	// Fast flow: classification, structured extraction and quick fixes.
	if mentionsAny(
		"classify", "classification", "label", "tag", "route", "routing", "intent",
		"json", "yaml", "csv", "schema", "parse", "extract", "transform", "format", "regex",
		"short answer", "brief", "concise", "tl;dr", "one line", "simple",
		"fix this", "small bug", "quick fix", "typo", "syntax error",
	) {
		tags = append(tags, "fast-flow", "classification", "json-extraction", "basic-qa")
	}

	// Standard pro: writing, comparison and everyday engineering work.
	if mentionsAny(
		"explain", "summarize", "rewrite", "draft", "edit", "polish", "outline",
		"long doc", "document", "email", "memo", "proposal", "report", "handout", "notes",
		"compare", "choose", "recommend", "tradeoff", "pros and cons", "analysis",
		"code review", "debug", "bug", "feature", "api", "endpoint", "implement",
		"plan", "planning", "workflow", "integration",
	) {
		tags = append(tags, "standard-pro", "long-doc")
	}

	// Multimodal input is flagged by the caller, not inferred from the text.
	if routeCtx.HasMultimodalInput {
		tags = append(tags, "video-analysis", "multimodal-qa")
	}

	// Heavy logic: agents, system design, hard debugging and deep research.
	if mentionsAny(
		"agent", "agents", "tool use", "function calling", "multi-agent", "orchestrate",
		"system design", "scaling", "performance", "architecture review", "distributed",
		"hard bug", "race condition", "deadlock", "memory leak", "crash", "production outage",
		"long context", "large codebase", "many files", "complex refactor", "migration",
		"research", "deep dive", "literature", "paper", "scholarly", "thorough analysis",
		"deep reasoning", "think step by step", "reason through", "careful analysis",
	) {
		tags = append(tags, "heavy-logic", "deep-reasoning", "architecture", "hard-debugging")
	}

	// Tool calling is likewise flagged by the caller.
	if routeCtx.RequiresToolCalling {
		tags = append(tags, "tool-heavy", "multi-step-agent", "swe-bench")
	}

	return tags
}