feat: implement advanced condition-based heuristic model routing
Upgrades the routing engine to support tag, token limit, multimodal, reasoning, and tool calling conditions. Adds unit tests for the new routing features.
This commit is contained in:
@@ -46,6 +46,9 @@ Implements HMAC-SHA256 signed tokens for dashboard authentication. Tokens secure
|
||||
### 5. WebSocket Hub (`internal/server/websocket.go`)
|
||||
A centralized hub for managing WebSocket connections, allowing real-time broadcast of system events, system metrics, and request logs to the dashboard.
|
||||
|
||||
### 6. Model Group Router (`internal/router/`)
|
||||
Resolves model groups (e.g., `deepseek-auto`, `dustins_stack`) into concrete models. It supports a Classifier strategy (uses a cheap LLM to rate complexity) and an upgraded Heuristic strategy, which evaluates either custom condition rules (tags, token counts, multimodal inputs, and reasoning or tool-calling flags) or legacy keyword patterns.
|
||||
|
||||
## Concurrency Model
|
||||
|
||||
Go's goroutines and channels are used extensively:
|
||||
|
||||
@@ -22,7 +22,7 @@ A unified, high-performance LLM proxy gateway built in Go. It provides OpenAI-co
|
||||
- **Image Generation:** DALL-E 2/3 (OpenAI) and Imagen 3 (Gemini) via OpenAI-compatible `/v1/images/generations` endpoint.
|
||||
- **Automatic Model Routing:**
|
||||
- **Hierarchical Routing:** Groups can target other groups, cascading through multiple levels until a concrete model is reached. Cycle detection and depth limiting (max 10) prevent infinite loops.
|
||||
- **Heuristic strategy:** Free, zero-latency keyword matching (e.g. "debug" or "step by step" routes to the reasoning model).
|
||||
- **Heuristic strategy:** Free, zero-latency routing supporting both keyword matching (regex/substrings) and condition-based checks (evaluating tags, token limits, multimodal inputs, reasoning, and tool calling requirements).
|
||||
- **Classifier strategy:** Uses a cheap LLM to rate task complexity on a configurable scale (1-10), then selects the appropriate model. Bucket mapping distributes ratings proportionally across targets.
|
||||
- **Two-Level Dispatch:** A `dispatcher` group (classifier, threshold=10) auto-routes to tier groups by complexity score, which then apply their own internal strategies.
|
||||
- **Metadata:** Groups support `logic_level` (1-10 complexity scale) and `primary_use` (description) fields for organizational clarity.
|
||||
|
||||
@@ -32,6 +32,14 @@ func Init(path string) (*DB, error) {
|
||||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
// Enable Write-Ahead Logging (WAL) and set a busy timeout to handle concurrent access
|
||||
if _, err := db.Exec("PRAGMA journal_mode=WAL;"); err != nil {
|
||||
log.Printf("failed to enable WAL mode: %v", err)
|
||||
}
|
||||
if _, err := db.Exec("PRAGMA busy_timeout=5000;"); err != nil {
|
||||
log.Printf("failed to set busy timeout: %v", err)
|
||||
}
|
||||
|
||||
instance := &DB{db}
|
||||
|
||||
// Run migrations
|
||||
|
||||
@@ -14,9 +14,21 @@ import (
|
||||
func AuthMiddleware(database *db.DB, requireAuth bool) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
// Fallback to checking "Authentication" header in case the client library used the wrong name
|
||||
authHeader = c.GetHeader("Authentication")
|
||||
}
|
||||
|
||||
if authHeader == "" {
|
||||
if requireAuth {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing authorization header"})
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Missing Authorization or Authentication header.",
|
||||
"type": "invalid_request_error",
|
||||
"param": nil,
|
||||
"code": "401",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
@@ -25,23 +37,50 @@ func AuthMiddleware(database *db.DB, requireAuth bool) gin.HandlerFunc {
|
||||
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if token == authHeader { // No "Bearer " prefix
|
||||
if requireAuth {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Invalid authorization header format. Bearer token required.",
|
||||
"type": "invalid_request_error",
|
||||
"param": nil,
|
||||
"code": "401",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Try to resolve client from database
|
||||
// Try to resolve client from database with a read-only SELECT
|
||||
var clientID string
|
||||
err := database.Get(&clientID, "UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = 1 RETURNING client_id", token)
|
||||
err := database.Get(&clientID, "SELECT client_id FROM client_tokens WHERE token = ? AND is_active = 1", token)
|
||||
|
||||
if err == nil {
|
||||
c.Set("auth", models.AuthInfo{
|
||||
Token: token,
|
||||
ClientID: clientID,
|
||||
})
|
||||
|
||||
// Update last_used_at asynchronously so that database locks or write delays
|
||||
// do not block or fail the client's request authentication.
|
||||
go func(t string) {
|
||||
if _, updateErr := database.Exec("UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ?", t); updateErr != nil {
|
||||
log.Printf("Warning: failed to update client token last_used_at: %v", updateErr)
|
||||
}
|
||||
}(token)
|
||||
|
||||
c.Next()
|
||||
} else {
|
||||
log.Printf("Token not found or inactive in DB: %s", token)
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid or inactive token"})
|
||||
log.Printf("Token not found, inactive or error in DB: %s (err: %v)", token, err)
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Invalid or inactive client token.",
|
||||
"type": "invalid_request_error",
|
||||
"param": nil,
|
||||
"code": "401",
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -338,6 +338,7 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
}
|
||||
|
||||
// Map Tools
|
||||
hasMappedTools := false
|
||||
if len(req.Tools) > 0 {
|
||||
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
|
||||
for _, t := range req.Tools {
|
||||
@@ -345,13 +346,16 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
|
||||
}
|
||||
}
|
||||
body.Tools = []GeminiTool{geminiTool}
|
||||
if len(geminiTool.FunctionDeclarations) > 0 {
|
||||
body.Tools = []GeminiTool{geminiTool}
|
||||
hasMappedTools = true
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := p.config.BaseURL
|
||||
lowerModel := strings.ToLower(req.Model)
|
||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
|
||||
// Use v1beta for preview and newer models
|
||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") || hasMappedTools {
|
||||
// Use v1beta for preview, newer models, or when using tools
|
||||
if !strings.Contains(baseURL, "v1beta") {
|
||||
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
|
||||
}
|
||||
@@ -578,6 +582,7 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
||||
GenerationConfig: genConfig,
|
||||
}
|
||||
|
||||
hasMappedTools := false
|
||||
if len(req.Tools) > 0 {
|
||||
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
|
||||
for _, t := range req.Tools {
|
||||
@@ -585,13 +590,16 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
||||
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
|
||||
}
|
||||
}
|
||||
body.Tools = []GeminiTool{geminiTool}
|
||||
if len(geminiTool.FunctionDeclarations) > 0 {
|
||||
body.Tools = []GeminiTool{geminiTool}
|
||||
hasMappedTools = true
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := p.config.BaseURL
|
||||
lowerModel := strings.ToLower(req.Model)
|
||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
|
||||
// Use v1beta for preview and newer models
|
||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") || hasMappedTools {
|
||||
// Use v1beta for preview, newer models, or when using tools
|
||||
if !strings.Contains(baseURL, "v1beta") {
|
||||
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
|
||||
}
|
||||
|
||||
@@ -10,6 +10,22 @@ import (
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
// sanitizeFunctionName rewrites name so that it contains only ASCII letters,
// digits, underscores and hyphens. Every other rune (including each invalid
// UTF-8 byte) is replaced by exactly one underscore. An empty name yields the
// placeholder "function".
func sanitizeFunctionName(name string) string {
	// Each input rune maps to exactly one output rune, so the result can only
	// be empty when the input is empty.
	if name == "" {
		return "function"
	}

	return strings.Map(func(r rune) rune {
		switch {
		case r >= 'a' && r <= 'z',
			r >= 'A' && r <= 'Z',
			r >= '0' && r <= '9',
			r == '_',
			r == '-':
			return r
		default:
			return '_'
		}
	}, name)
}
|
||||
|
||||
// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON, including tools and images.
|
||||
func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) {
|
||||
var result []interface{}
|
||||
@@ -35,7 +51,7 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
|
||||
msg["tool_call_id"] = id
|
||||
|
||||
if m.Name != nil {
|
||||
msg["name"] = *m.Name
|
||||
msg["name"] = sanitizeFunctionName(*m.Name)
|
||||
}
|
||||
result = append(result, msg)
|
||||
continue
|
||||
@@ -91,6 +107,7 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
|
||||
if sanitizedCalls[i].Type == "" {
|
||||
sanitizedCalls[i].Type = "function"
|
||||
}
|
||||
sanitizedCalls[i].Function.Name = sanitizeFunctionName(sanitizedCalls[i].Function.Name)
|
||||
}
|
||||
msg["tool_calls"] = sanitizedCalls
|
||||
msg["content"] = "" // OpenAI requirement: content must be string if tool_calls present
|
||||
@@ -124,11 +141,25 @@ func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{},
|
||||
body["max_tokens"] = *request.MaxTokens
|
||||
}
|
||||
if len(request.Tools) > 0 {
|
||||
body["tools"] = request.Tools
|
||||
sanitizedTools := make([]models.Tool, len(request.Tools))
|
||||
copy(sanitizedTools, request.Tools)
|
||||
for i := range sanitizedTools {
|
||||
if sanitizedTools[i].Type == "function" {
|
||||
sanitizedTools[i].Function.Name = sanitizeFunctionName(sanitizedTools[i].Function.Name)
|
||||
}
|
||||
}
|
||||
body["tools"] = sanitizedTools
|
||||
}
|
||||
if request.ToolChoice != nil {
|
||||
var toolChoice interface{}
|
||||
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
|
||||
if tcMap, ok := toolChoice.(map[string]interface{}); ok {
|
||||
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
|
||||
if name, ok := funcMap["name"].(string); ok {
|
||||
funcMap["name"] = sanitizeFunctionName(name)
|
||||
}
|
||||
}
|
||||
}
|
||||
body["tool_choice"] = toolChoice
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
func TestSanitizeFunctionName(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"google-search", "google-search"},
|
||||
{"google.search", "google_search"},
|
||||
{"google search", "google_search"},
|
||||
{"web_search(query)", "web_search_query_"},
|
||||
{"", "function"},
|
||||
{"123_abc-XYZ", "123_abc-XYZ"},
|
||||
{"invalid.name.with.dots", "invalid_name_with_dots"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
actual := sanitizeFunctionName(tc.input)
|
||||
if actual != tc.expected {
|
||||
t.Errorf("sanitizeFunctionName(%q) = %q; expected %q", tc.input, actual, tc.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesToOpenAIJSON_SanitizeToolCalls(t *testing.T) {
|
||||
messages := []models.UnifiedMessage{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []models.UnifiedContentPart{
|
||||
{Type: "text", Text: "I will use search."},
|
||||
},
|
||||
ToolCalls: []models.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: models.FunctionCall{
|
||||
Name: "google.search",
|
||||
Arguments: `{"query": "hello"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
Content: []models.UnifiedContentPart{
|
||||
{Type: "text", Text: `{"result": "success"}`},
|
||||
},
|
||||
ToolCallID: stringPtr("call_1"),
|
||||
Name: stringPtr("google.search"),
|
||||
},
|
||||
}
|
||||
|
||||
res, err := MessagesToOpenAIJSON(messages)
|
||||
if err != nil {
|
||||
t.Fatalf("MessagesToOpenAIJSON failed: %v", err)
|
||||
}
|
||||
|
||||
if len(res) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(res))
|
||||
}
|
||||
|
||||
// Verify assistant message
|
||||
msg1 := res[0].(map[string]interface{})
|
||||
if msg1["role"] != "assistant" {
|
||||
t.Errorf("expected role assistant, got %v", msg1["role"])
|
||||
}
|
||||
calls := msg1["tool_calls"].([]models.ToolCall)
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "google_search" {
|
||||
t.Errorf("expected function name google_search, got %q", calls[0].Function.Name)
|
||||
}
|
||||
|
||||
// Verify tool response message
|
||||
msg2 := res[1].(map[string]interface{})
|
||||
if msg2["role"] != "tool" {
|
||||
t.Errorf("expected role tool, got %v", msg2["role"])
|
||||
}
|
||||
if msg2["name"] != "google_search" {
|
||||
t.Errorf("expected tool name google_search, got %v", msg2["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIBody_SanitizeToolsAndChoice(t *testing.T) {
|
||||
req := &models.UnifiedRequest{
|
||||
Model: "gpt-4o",
|
||||
Tools: []models.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: models.FunctionDef{
|
||||
Name: "google.search",
|
||||
},
|
||||
},
|
||||
},
|
||||
ToolChoice: json.RawMessage(`{"type": "function", "function": {"name": "google.search"}}`),
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, nil, false)
|
||||
|
||||
// Verify tools
|
||||
tools := body["tools"].([]models.Tool)
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(tools))
|
||||
}
|
||||
if tools[0].Function.Name != "google_search" {
|
||||
t.Errorf("expected tool function name google_search, got %q", tools[0].Function.Name)
|
||||
}
|
||||
|
||||
// Verify tool_choice
|
||||
toolChoice := body["tool_choice"].(map[string]interface{})
|
||||
funcObj := toolChoice["function"].(map[string]interface{})
|
||||
if funcObj["name"] != "google_search" {
|
||||
t.Errorf("expected tool_choice function name google_search, got %q", funcObj["name"])
|
||||
}
|
||||
}
|
||||
|
||||
// stringPtr returns a pointer to a newly allocated string holding s. It is a
// test helper for populating optional *string fields.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
|
||||
@@ -15,7 +15,7 @@ const classifierSystemPrompt = `You are a task complexity classifier. Rate the f
|
||||
|
||||
Reply with ONLY the number. No explanation.`
|
||||
|
||||
func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, userMessage string) (*Decision, error) {
|
||||
func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, routeCtx *RouteContext) (*Decision, error) {
|
||||
// Determine the rating scale
|
||||
maxRating := len(targets)
|
||||
if maxRating < 2 {
|
||||
@@ -30,10 +30,14 @@ func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.Mode
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf(classifierSystemPrompt, maxRating, maxRating)
|
||||
ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMessage)
|
||||
userMsg := ""
|
||||
if routeCtx != nil {
|
||||
userMsg = routeCtx.UserMessage
|
||||
}
|
||||
ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMsg)
|
||||
if err != nil {
|
||||
// Classifier failed — fall back to heuristic
|
||||
return routeHeuristic(group, targets, userMessage)
|
||||
return routeHeuristic(group, targets, routeCtx)
|
||||
}
|
||||
|
||||
rating, err := strconv.Atoi(strings.TrimSpace(ratingStr))
|
||||
|
||||
+172
-19
@@ -2,48 +2,125 @@ package router
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
// HeuristicRule defines a pattern-based routing rule.
|
||||
// HeuristicRule is the legacy, pattern-based routing rule format: when the
// user message matches Pattern, the group's target at TargetIdx is selected.
type HeuristicRule struct {
	// Pattern is the text looked for in the user message.
	Pattern string `json:"pattern"`

	// TargetIdx is an index into the group's target list.
	TargetIdx int `json:"target"`

	// CaseSensitive turns off the default case-insensitive matching.
	CaseSensitive bool `json:"case_sensitive,omitempty"`
}
|
||||
|
||||
func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) (*Decision, error) {
|
||||
// ConditionRule defines a condition-based routing rule (new format).
|
||||
type ConditionRule struct {
|
||||
RuleID string `json:"rule_id"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Conditions Conditions `json:"conditions"`
|
||||
PrimaryModel string `json:"primary_model"`
|
||||
FallbackModel string `json:"fallback_model,omitempty"`
|
||||
}
|
||||
|
||||
// Conditions lists the optional predicates of a ConditionRule. Pointer fields
// distinguish "not specified" (nil) from an explicit value; unspecified
// predicates are not evaluated.
type Conditions struct {
	// AnyOfTags, when non-empty, requires at least one request tag to equal
	// one of these tags (compared case-insensitively).
	AnyOfTags []string `json:"any_of_tags,omitempty"`

	// MaxInputTokensLt requires the request's input token count to be
	// strictly below this value.
	MaxInputTokensLt *int `json:"max_input_tokens_lt,omitempty"`

	// RequiresReasoning must equal the request's reasoning flag.
	RequiresReasoning *bool `json:"requires_reasoning,omitempty"`

	// RequiresToolCalling must equal the request's tool-calling flag.
	RequiresToolCalling *bool `json:"requires_tool_calling,omitempty"`

	// HasMultimodalInput must equal the request's multimodal flag.
	HasMultimodalInput *bool `json:"has_multimodal_input,omitempty"`

	// IsDefaultFallback, when true, makes the rule match unconditionally.
	IsDefaultFallback *bool `json:"is_default_fallback,omitempty"`
}
|
||||
|
||||
func routeHeuristic(group db.ModelGroup, targets []string, routeCtx *RouteContext) (*Decision, error) {
|
||||
if routeCtx == nil {
|
||||
routeCtx = &RouteContext{}
|
||||
}
|
||||
|
||||
selected := targets[0]
|
||||
reason := "default (first target)"
|
||||
|
||||
// If heuristic_rules is set, use them
|
||||
// If heuristic_rules is set, determine format and parse
|
||||
if group.HeuristicRules != nil && *group.HeuristicRules != "" {
|
||||
var rules []HeuristicRule
|
||||
if err := json.Unmarshal([]byte(*group.HeuristicRules), &rules); err == nil {
|
||||
searchMsg := userMessage
|
||||
for _, rule := range rules {
|
||||
pattern := rule.Pattern
|
||||
msg := searchMsg
|
||||
if !rule.CaseSensitive {
|
||||
pattern = strings.ToLower(pattern)
|
||||
msg = strings.ToLower(msg)
|
||||
rulesJSON := *group.HeuristicRules
|
||||
|
||||
if isConditionBasedRules(rulesJSON) {
|
||||
var condRules []ConditionRule
|
||||
if err := json.Unmarshal([]byte(rulesJSON), &condRules); err == nil {
|
||||
for _, rule := range condRules {
|
||||
if matchConditions(rule.Conditions, routeCtx) {
|
||||
// Resolve primary/fallback to concrete models in target list
|
||||
targetModel := ""
|
||||
if rule.PrimaryModel != "" {
|
||||
targetModel = getModelInTargets(rule.PrimaryModel, targets)
|
||||
}
|
||||
if targetModel == "" && rule.FallbackModel != "" {
|
||||
targetModel = getModelInTargets(rule.FallbackModel, targets)
|
||||
}
|
||||
|
||||
if targetModel != "" {
|
||||
selected = targetModel
|
||||
reason = "matched condition rule: " + rule.RuleID
|
||||
if rule.Description != "" {
|
||||
reason += " (" + rule.Description + ")"
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.Contains(msg, pattern) {
|
||||
if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) {
|
||||
selected = targets[rule.TargetIdx]
|
||||
reason = "matched heuristic rule: " + rule.Pattern
|
||||
break
|
||||
}
|
||||
} else {
|
||||
// Fallback to legacy pattern-based rules
|
||||
var legacyRules []HeuristicRule
|
||||
if err := json.Unmarshal([]byte(rulesJSON), &legacyRules); err == nil {
|
||||
searchMsg := routeCtx.UserMessage
|
||||
for _, rule := range legacyRules {
|
||||
pattern := rule.Pattern
|
||||
if pattern == "" {
|
||||
continue // Avoid infinite matches with empty patterns
|
||||
}
|
||||
msg := searchMsg
|
||||
if !rule.CaseSensitive {
|
||||
pattern = strings.ToLower(pattern)
|
||||
msg = strings.ToLower(msg)
|
||||
}
|
||||
|
||||
// Support both regex matching (if pattern is valid regex) and literal contains
|
||||
matched := false
|
||||
if strings.Contains(rule.Pattern, "(") || strings.Contains(rule.Pattern, "\\b") {
|
||||
var re *regexp.Regexp
|
||||
var err error
|
||||
if !rule.CaseSensitive {
|
||||
re, err = regexp.Compile("(?i)" + rule.Pattern)
|
||||
} else {
|
||||
re, err = regexp.Compile(rule.Pattern)
|
||||
}
|
||||
if err == nil {
|
||||
matched = re.MatchString(routeCtx.UserMessage)
|
||||
}
|
||||
}
|
||||
|
||||
if !matched && strings.Contains(msg, pattern) {
|
||||
matched = true
|
||||
}
|
||||
|
||||
if matched {
|
||||
if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) {
|
||||
selected = targets[rule.TargetIdx]
|
||||
reason = "matched heuristic rule: " + rule.Pattern
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Built-in fallback heuristics
|
||||
// Built-in fallback heuristics (if no custom rule matched)
|
||||
if reason == "default (first target)" && len(targets) > 1 {
|
||||
msgLower := strings.ToLower(userMessage)
|
||||
msgLower := strings.ToLower(routeCtx.UserMessage)
|
||||
complexIndicators := []string{
|
||||
"step by step", "explain in detail", "reason through",
|
||||
"think carefully", "analyze", "debug", "write code",
|
||||
@@ -64,3 +141,79 @@ func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) (
|
||||
Reason: reason,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// isConditionBasedRules returns true if the JSON represents condition-based rules.
|
||||
func isConditionBasedRules(rulesJSON string) bool {
|
||||
var rules []ConditionRule
|
||||
if err := json.Unmarshal([]byte(rulesJSON), &rules); err == nil && len(rules) > 0 {
|
||||
// If the rule has either conditions or primary_model/rule_id, treat it as condition-based
|
||||
return rules[0].PrimaryModel != "" || rules[0].RuleID != ""
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchConditions evaluates whether the given conditions match the RouteContext.
|
||||
func matchConditions(cond Conditions, routeCtx *RouteContext) bool {
|
||||
if cond.IsDefaultFallback != nil && *cond.IsDefaultFallback {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check tags: must match any_of_tags if specified
|
||||
if len(cond.AnyOfTags) > 0 {
|
||||
tagMatched := false
|
||||
for _, ruleTag := range cond.AnyOfTags {
|
||||
for _, ctxTag := range routeCtx.Tags {
|
||||
if strings.EqualFold(ruleTag, ctxTag) {
|
||||
tagMatched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if tagMatched {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !tagMatched {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check max input tokens
|
||||
if cond.MaxInputTokensLt != nil {
|
||||
if routeCtx.InputTokens >= *cond.MaxInputTokensLt {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check reasoning flag
|
||||
if cond.RequiresReasoning != nil {
|
||||
if routeCtx.RequiresReasoning != *cond.RequiresReasoning {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check tool calling flag
|
||||
if cond.RequiresToolCalling != nil {
|
||||
if routeCtx.RequiresToolCalling != *cond.RequiresToolCalling {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check multimodal flag
|
||||
if cond.HasMultimodalInput != nil {
|
||||
if routeCtx.HasMultimodalInput != *cond.HasMultimodalInput {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// getModelInTargets looks modelName up in targets, ignoring case, and returns
// the first matching entry exactly as it is spelled in targets. It returns the
// empty string when no target matches.
func getModelInTargets(modelName string, targets []string) string {
	for i := 0; i < len(targets); i++ {
		if candidate := targets[i]; strings.EqualFold(candidate, modelName) {
			return candidate
		}
	}
	return ""
}
|
||||
|
||||
@@ -0,0 +1,142 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
func TestRouteHeuristic_ConditionRules(t *testing.T) {
|
||||
targets := []string{
|
||||
"deepseek-v4-flash", // index 0
|
||||
"gemini-3-flash", // index 1
|
||||
"grok-build-0.1", // index 2
|
||||
"kimi-k2.6", // index 3
|
||||
"mimo-v2.5-pro", // index 4
|
||||
"grok-4.3", // index 5
|
||||
"deepseek-v4-pro", // index 6
|
||||
}
|
||||
|
||||
rulesJSON := `[
|
||||
{
|
||||
"rule_id": "fast_flow_extraction",
|
||||
"conditions": {
|
||||
"any_of_tags": ["fast-flow", "classification"],
|
||||
"max_input_tokens_lt": 8000,
|
||||
"requires_reasoning": false
|
||||
},
|
||||
"primary_model": "deepseek-v4-flash",
|
||||
"fallback_model": "grok-build-0.1"
|
||||
},
|
||||
{
|
||||
"rule_id": "multimodal_long_context",
|
||||
"conditions": {
|
||||
"any_of_tags": ["standard-pro", "long-doc"],
|
||||
"has_multimodal_input": true
|
||||
},
|
||||
"primary_model": "gemini-3-flash",
|
||||
"fallback_model": "mimo-v2.5-pro"
|
||||
},
|
||||
{
|
||||
"rule_id": "regional_fallback_general",
|
||||
"conditions": {
|
||||
"is_default_fallback": true
|
||||
},
|
||||
"primary_model": "kimi-k2.6"
|
||||
}
|
||||
]`
|
||||
|
||||
group := db.ModelGroup{
|
||||
ID: "dustins_stack",
|
||||
Strategy: "heuristic",
|
||||
HeuristicRules: &rulesJSON,
|
||||
}
|
||||
|
||||
// 1. Test Match Fast Flow (condition success)
|
||||
ctx1 := &RouteContext{
|
||||
UserMessage: "classify this JSON",
|
||||
InputTokens: 500,
|
||||
HasMultimodalInput: false,
|
||||
RequiresReasoning: false,
|
||||
Tags: []string{"fast-flow", "classification"},
|
||||
}
|
||||
dec1, err := routeHeuristic(group, targets, ctx1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec1.SelectedModel != "deepseek-v4-flash" {
|
||||
t.Fatalf("expected deepseek-v4-flash, got %s", dec1.SelectedModel)
|
||||
}
|
||||
|
||||
// 2. Test Multimodal Long Context (condition success)
|
||||
ctx2 := &RouteContext{
|
||||
UserMessage: "explain this video",
|
||||
InputTokens: 15000,
|
||||
HasMultimodalInput: true,
|
||||
RequiresReasoning: false,
|
||||
Tags: []string{"standard-pro", "video-analysis"},
|
||||
}
|
||||
dec2, err := routeHeuristic(group, targets, ctx2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec2.SelectedModel != "gemini-3-flash" {
|
||||
t.Fatalf("expected gemini-3-flash, got %s", dec2.SelectedModel)
|
||||
}
|
||||
|
||||
// 3. Test Fallback general rule
|
||||
ctx3 := &RouteContext{
|
||||
UserMessage: "hello there",
|
||||
InputTokens: 100,
|
||||
HasMultimodalInput: false,
|
||||
RequiresReasoning: false,
|
||||
Tags: []string{"general"},
|
||||
}
|
||||
dec3, err := routeHeuristic(group, targets, ctx3)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec3.SelectedModel != "kimi-k2.6" {
|
||||
t.Fatalf("expected kimi-k2.6, got %s", dec3.SelectedModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteHeuristic_LegacyRules(t *testing.T) {
|
||||
targets := []string{"gpt-4o-mini", "deepseek-v4-pro", "kimi-k2.6"}
|
||||
|
||||
// Legacy pattern-based rule with regex
|
||||
rulesJSON := `[
|
||||
{"pattern": "\\b(agent|agents|tool use)\\b", "target": 1},
|
||||
{"pattern": "summarize", "target": 2}
|
||||
]`
|
||||
|
||||
group := db.ModelGroup{
|
||||
ID: "heavy-logic",
|
||||
Strategy: "heuristic",
|
||||
HeuristicRules: &rulesJSON,
|
||||
}
|
||||
|
||||
// 1. Test regex match
|
||||
ctx1 := &RouteContext{
|
||||
UserMessage: "We need an agent to do tool use",
|
||||
}
|
||||
dec1, err := routeHeuristic(group, targets, ctx1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec1.SelectedModel != "deepseek-v4-pro" {
|
||||
t.Fatalf("expected deepseek-v4-pro, got %s", dec1.SelectedModel)
|
||||
}
|
||||
|
||||
// 2. Test literal match
|
||||
ctx2 := &RouteContext{
|
||||
UserMessage: "Please summarize this text",
|
||||
}
|
||||
dec2, err := routeHeuristic(group, targets, ctx2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec2.SelectedModel != "kimi-k2.6" {
|
||||
t.Fatalf("expected kimi-k2.6, got %s", dec2.SelectedModel)
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,16 @@ type Decision struct {
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// RouteContext carries the per-request metadata that routing strategies
// evaluate when resolving a model group.
type RouteContext struct {
	// UserMessage is the user's message text, used for keyword/pattern
	// matching and as the classifier's input.
	UserMessage string `json:"user_message"`

	// InputTokens is the request's input token count, compared against a
	// rule's max_input_tokens_lt.
	InputTokens int `json:"input_tokens"`

	// HasMultimodalInput is compared against a rule's has_multimodal_input.
	HasMultimodalInput bool `json:"has_multimodal_input"`

	// RequiresToolCalling is compared against a rule's requires_tool_calling.
	RequiresToolCalling bool `json:"requires_tool_calling"`

	// RequiresReasoning is compared against a rule's requires_reasoning.
	RequiresReasoning bool `json:"requires_reasoning"`

	// Tags are matched case-insensitively against a rule's any_of_tags.
	Tags []string `json:"tags"`
}
|
||||
|
||||
// ClassifierFunc is the callback for classifier-based routing.
|
||||
type ClassifierFunc func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error)
|
||||
|
||||
@@ -53,7 +63,7 @@ func (r *Router) IsGroup(modelID string) bool {
|
||||
}
|
||||
|
||||
// Route resolves a group to a concrete model.
|
||||
func (r *Router) Route(ctx context.Context, groupID string, userMessage string) (*Decision, error) {
|
||||
func (r *Router) Route(ctx context.Context, groupID string, routeCtx *RouteContext) (*Decision, error) {
|
||||
group, ok := r.groups[groupID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown model group: %s", groupID)
|
||||
@@ -66,12 +76,12 @@ func (r *Router) Route(ctx context.Context, groupID string, userMessage string)
|
||||
|
||||
switch group.Strategy {
|
||||
case "heuristic":
|
||||
return routeHeuristic(group, targets, userMessage)
|
||||
return routeHeuristic(group, targets, routeCtx)
|
||||
case "classifier":
|
||||
if r.classify == nil {
|
||||
return routeHeuristic(group, targets, userMessage)
|
||||
return routeHeuristic(group, targets, routeCtx)
|
||||
}
|
||||
return routeClassifier(ctx, r.classify, group, targets, userMessage)
|
||||
return routeClassifier(ctx, r.classify, group, targets, routeCtx)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown strategy: %s", group.Strategy)
|
||||
}
|
||||
@@ -80,7 +90,7 @@ func (r *Router) Route(ctx context.Context, groupID string, userMessage string)
|
||||
// RouteToConcrete resolves a model name to a concrete model, following group
|
||||
// chains recursively until a non-group target is reached. Returns the original
|
||||
// name unchanged if it is not a group.
|
||||
func (r *Router) RouteToConcrete(ctx context.Context, modelID string, userMessage string) (*Decision, error) {
|
||||
func (r *Router) RouteToConcrete(ctx context.Context, modelID string, routeCtx *RouteContext) (*Decision, error) {
|
||||
const maxDepth = 10
|
||||
visited := make(map[string]bool)
|
||||
current := modelID
|
||||
@@ -109,7 +119,7 @@ func (r *Router) RouteToConcrete(ctx context.Context, modelID string, userMessag
|
||||
}
|
||||
visited[current] = true
|
||||
|
||||
decision, err := r.Route(ctx, current, userMessage)
|
||||
decision, err := r.Route(ctx, current, routeCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
+164
-3
@@ -329,7 +329,8 @@ func (s *Server) handleResponses(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
if s.modelRouter != nil {
|
||||
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, "")
|
||||
routeCtx := s.buildRouteContextFromResponses(req)
|
||||
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
|
||||
return
|
||||
@@ -573,8 +574,8 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
|
||||
log.Printf("[DEBUG] Incoming Msg[%d]: role=%s, hasToolCalls=%v, hasContent=%v", i, m.Role, len(m.ToolCalls) > 0, m.Content != nil)
|
||||
}
|
||||
if s.modelRouter != nil {
|
||||
userMessage := extractUserMessage(req.Messages)
|
||||
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, userMessage)
|
||||
routeCtx := s.buildRouteContextFromChat(req)
|
||||
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
|
||||
return
|
||||
@@ -941,3 +942,163 @@ func (s *Server) Run() error {
|
||||
}
|
||||
|
||||
func uint32Ptr(v uint32) *uint32 { return &v }
|
||||
|
||||
func (s *Server) buildRouteContextFromChat(req models.ChatCompletionRequest) *router.RouteContext {
|
||||
userMessage := extractUserMessage(req.Messages)
|
||||
requiresToolCalling := len(req.Tools) > 0
|
||||
hasMultimodal := false
|
||||
inputTokens := 0
|
||||
|
||||
for _, msg := range req.Messages {
|
||||
if strContent, ok := msg.Content.(string); ok {
|
||||
inputTokens += len(strContent) / 4
|
||||
} else if parts, ok := msg.Content.([]interface{}); ok {
|
||||
for _, part := range parts {
|
||||
if partMap, ok := part.(map[string]interface{}); ok {
|
||||
partType, _ := partMap["type"].(string)
|
||||
if partType == "text" {
|
||||
text, _ := partMap["text"].(string)
|
||||
inputTokens += len(text) / 4
|
||||
} else if partType == "image_url" {
|
||||
hasMultimodal = true
|
||||
inputTokens += 1000 // Approximate cost of an image in tokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msgLower := strings.ToLower(userMessage)
|
||||
requiresReasoning := strings.Contains(msgLower, "reason") ||
|
||||
strings.Contains(msgLower, "think step by step") ||
|
||||
strings.Contains(msgLower, "mathematics") ||
|
||||
strings.Contains(msgLower, "architecture") ||
|
||||
strings.Contains(msgLower, "explain in detail")
|
||||
|
||||
routeCtx := &router.RouteContext{
|
||||
UserMessage: userMessage,
|
||||
InputTokens: inputTokens,
|
||||
HasMultimodalInput: hasMultimodal,
|
||||
RequiresToolCalling: requiresToolCalling,
|
||||
RequiresReasoning: requiresReasoning,
|
||||
}
|
||||
routeCtx.Tags = s.getRouteCtxTags(routeCtx)
|
||||
return routeCtx
|
||||
}
|
||||
|
||||
func (s *Server) buildRouteContextFromResponses(req models.ResponsesRequest) *router.RouteContext {
|
||||
var userMessage string
|
||||
hasMultimodal := false
|
||||
inputTokens := len(req.Instructions) / 4
|
||||
requiresToolCalling := len(req.Tools) > 0 && string(req.Tools) != "null" && string(req.Tools) != ""
|
||||
|
||||
var strInput string
|
||||
if err := json.Unmarshal(req.Input, &strInput); err == nil {
|
||||
userMessage = strInput
|
||||
inputTokens += len(userMessage) / 4
|
||||
} else {
|
||||
var msgs []models.ResponseInputMessage
|
||||
if err := json.Unmarshal(req.Input, &msgs); err == nil {
|
||||
for _, m := range msgs {
|
||||
var contentStr string
|
||||
if err := json.Unmarshal(m.Content, &contentStr); err == nil {
|
||||
if m.Role == "user" {
|
||||
userMessage = contentStr
|
||||
}
|
||||
inputTokens += len(contentStr) / 4
|
||||
} else {
|
||||
var parts []models.ContentPart
|
||||
if err := json.Unmarshal(m.Content, &parts); err == nil {
|
||||
for _, p := range parts {
|
||||
if p.Type == "text" {
|
||||
if m.Role == "user" {
|
||||
userMessage = p.Text
|
||||
}
|
||||
inputTokens += len(p.Text) / 4
|
||||
} else if p.Type == "image_url" {
|
||||
hasMultimodal = true
|
||||
inputTokens += 1000
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msgLower := strings.ToLower(userMessage)
|
||||
requiresReasoning := strings.Contains(msgLower, "reason") ||
|
||||
strings.Contains(msgLower, "think step by step") ||
|
||||
strings.Contains(msgLower, "mathematics") ||
|
||||
strings.Contains(msgLower, "architecture") ||
|
||||
strings.Contains(msgLower, "explain in detail")
|
||||
|
||||
routeCtx := &router.RouteContext{
|
||||
UserMessage: userMessage,
|
||||
InputTokens: inputTokens,
|
||||
HasMultimodalInput: hasMultimodal,
|
||||
RequiresToolCalling: requiresToolCalling,
|
||||
RequiresReasoning: requiresReasoning,
|
||||
}
|
||||
routeCtx.Tags = s.getRouteCtxTags(routeCtx)
|
||||
return routeCtx
|
||||
}
|
||||
|
||||
func (s *Server) getRouteCtxTags(routeCtx *router.RouteContext) []string {
|
||||
var tags []string
|
||||
msgLower := strings.ToLower(routeCtx.UserMessage)
|
||||
|
||||
// fast-flow keywords
|
||||
fastFlowKeywords := []string{
|
||||
"classify", "classification", "label", "tag", "route", "routing", "intent",
|
||||
"json", "yaml", "csv", "schema", "parse", "extract", "transform", "format", "regex",
|
||||
"short answer", "brief", "concise", "tl;dr", "one line", "simple",
|
||||
"fix this", "small bug", "quick fix", "typo", "syntax error",
|
||||
}
|
||||
for _, kw := range fastFlowKeywords {
|
||||
if strings.Contains(msgLower, kw) {
|
||||
tags = append(tags, "fast-flow", "classification", "json-extraction", "basic-qa")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// standard-pro keywords
|
||||
standardProKeywords := []string{
|
||||
"explain", "summarize", "rewrite", "draft", "edit", "polish", "outline",
|
||||
"long doc", "document", "email", "memo", "proposal", "report", "handout", "notes",
|
||||
"compare", "choose", "recommend", "tradeoff", "pros and cons", "analysis",
|
||||
"code review", "debug", "bug", "feature", "api", "endpoint", "implement",
|
||||
"plan", "planning", "workflow", "integration",
|
||||
}
|
||||
for _, kw := range standardProKeywords {
|
||||
if strings.Contains(msgLower, kw) {
|
||||
tags = append(tags, "standard-pro", "long-doc")
|
||||
break
|
||||
}
|
||||
}
|
||||
if routeCtx.HasMultimodalInput {
|
||||
tags = append(tags, "video-analysis", "multimodal-qa")
|
||||
}
|
||||
|
||||
// heavy-logic keywords
|
||||
heavyLogicKeywords := []string{
|
||||
"agent", "agents", "tool use", "function calling", "multi-agent", "orchestrate",
|
||||
"system design", "scaling", "performance", "architecture review", "distributed",
|
||||
"hard bug", "race condition", "deadlock", "memory leak", "crash", "production outage",
|
||||
"long context", "large codebase", "many files", "complex refactor", "migration",
|
||||
"research", "deep dive", "literature", "paper", "scholarly", "thorough analysis",
|
||||
"deep reasoning", "think step by step", "reason through", "careful analysis",
|
||||
}
|
||||
for _, kw := range heavyLogicKeywords {
|
||||
if strings.Contains(msgLower, kw) {
|
||||
tags = append(tags, "heavy-logic", "deep-reasoning", "architecture", "hard-debugging")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if routeCtx.RequiresToolCalling {
|
||||
tags = append(tags, "tool-heavy", "multi-step-agent", "swe-bench")
|
||||
}
|
||||
|
||||
return tags
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user