diff --git a/BACKEND_ARCHITECTURE.md b/BACKEND_ARCHITECTURE.md index cf432628..5ed04184 100644 --- a/BACKEND_ARCHITECTURE.md +++ b/BACKEND_ARCHITECTURE.md @@ -46,6 +46,9 @@ Implements HMAC-SHA256 signed tokens for dashboard authentication. Tokens secure ### 5. WebSocket Hub (`internal/server/websocket.go`) A centralized hub for managing WebSocket connections, allowing real-time broadcast of system events, system metrics, and request logs to the dashboard. +### 6. Model Group Router (`internal/router/`) +Resolves model groups (e.g., `deepseek-auto`, `dustins_stack`) into concrete models. It supports a Classifier strategy (uses a cheap LLM to rate complexity) and an upgraded Heuristic strategy (evaluates custom condition rules like tags, token counts, multimodal inputs, reasoning, and tool calling flags or legacy keyword patterns). + ## Concurrency Model Go's goroutines and channels are used extensively: diff --git a/README.md b/README.md index 3b37846c..4db8b052 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ A unified, high-performance LLM proxy gateway built in Go. It provides OpenAI-co - **Image Generation:** DALL-E 2/3 (OpenAI) and Imagen 3 (Gemini) via OpenAI-compatible `/v1/images/generations` endpoint. - **Automatic Model Routing:** - **Hierarchical Routing:** Groups can target other groups, cascading through multiple levels until a concrete model is reached. Cycle detection and depth limiting (max 10) prevent infinite loops. - - **Heuristic strategy:** Free, zero-latency keyword matching (e.g. "debug" or "step by step" routes to the reasoning model). + - **Heuristic strategy:** Free, zero-latency routing supporting both keyword matching (regex/substrings) and condition-based checks (evaluating tags, token limits, multimodal inputs, reasoning, and tool calling requirements). - **Classifier strategy:** Uses a cheap LLM to rate task complexity on a configurable scale (1-10), then selects the appropriate model. 
Bucket mapping distributes ratings proportionally across targets. - **Two-Level Dispatch:** A `dispatcher` group (classifier, threshold=10) auto-routes to tier groups by complexity score, which then apply their own internal strategies. - **Metadata:** Groups support `logic_level` (1-10 complexity scale) and `primary_use` (description) fields for organizational clarity. diff --git a/internal/db/db.go b/internal/db/db.go index 1a7d05d6..18d40668 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -32,6 +32,14 @@ func Init(path string) (*DB, error) { return nil, fmt.Errorf("failed to connect to database: %w", err) } + // Enable Write-Ahead Logging (WAL) and set a busy timeout to handle concurrent access + if _, err := db.Exec("PRAGMA journal_mode=WAL;"); err != nil { + log.Printf("failed to enable WAL mode: %v", err) + } + if _, err := db.Exec("PRAGMA busy_timeout=5000;"); err != nil { + log.Printf("failed to set busy timeout: %v", err) + } + instance := &DB{db} // Run migrations diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index 727d7767..3e9934bc 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -14,9 +14,21 @@ import ( func AuthMiddleware(database *db.DB, requireAuth bool) gin.HandlerFunc { return func(c *gin.Context) { authHeader := c.GetHeader("Authorization") + if authHeader == "" { + // Fallback to checking "Authentication" header in case the client library used the wrong name + authHeader = c.GetHeader("Authentication") + } + if authHeader == "" { if requireAuth { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing authorization header"}) + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "error": gin.H{ + "message": "Missing Authorization or Authentication header.", + "type": "invalid_request_error", + "param": nil, + "code": "401", + }, + }) return } c.Next() @@ -25,23 +37,50 @@ func AuthMiddleware(database *db.DB, requireAuth bool) gin.HandlerFunc { token := 
strings.TrimPrefix(authHeader, "Bearer ") if token == authHeader { // No "Bearer " prefix + if requireAuth { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "error": gin.H{ + "message": "Invalid authorization header format. Bearer token required.", + "type": "invalid_request_error", + "param": nil, + "code": "401", + }, + }) + return + } c.Next() return } - // Try to resolve client from database + // Try to resolve client from database with a read-only SELECT var clientID string - err := database.Get(&clientID, "UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = 1 RETURNING client_id", token) + err := database.Get(&clientID, "SELECT client_id FROM client_tokens WHERE token = ? AND is_active = 1", token) if err == nil { c.Set("auth", models.AuthInfo{ Token: token, ClientID: clientID, }) + + // Update last_used_at asynchronously so that database locks or write delays + // do not block or fail the client's request authentication. + go func(t string) { + if _, updateErr := database.Exec("UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ?", t); updateErr != nil { + log.Printf("Warning: failed to update client token last_used_at: %v", updateErr) + } + }(token) + c.Next() } else { - log.Printf("Token not found or inactive in DB: %s", token) - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid or inactive token"}) + log.Printf("Token not found, inactive or error in DB: %s (err: %v)", token, err) + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "error": gin.H{ + "message": "Invalid or inactive client token.", + "type": "invalid_request_error", + "param": nil, + "code": "401", + }, + }) } } } diff --git a/internal/providers/gemini.go b/internal/providers/gemini.go index 1f77c3f0..b72f6df2 100644 --- a/internal/providers/gemini.go +++ b/internal/providers/gemini.go @@ -338,6 +338,7 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified } // Map 
Tools + hasMappedTools := false if len(req.Tools) > 0 { geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}} for _, t := range req.Tools { @@ -345,13 +346,16 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function) } } - body.Tools = []GeminiTool{geminiTool} + if len(geminiTool.FunctionDeclarations) > 0 { + body.Tools = []GeminiTool{geminiTool} + hasMappedTools = true + } } baseURL := p.config.BaseURL lowerModel := strings.ToLower(req.Model) - if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") { - // Use v1beta for preview and newer models + if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") || hasMappedTools { + // Use v1beta for preview, newer models, or when using tools if !strings.Contains(baseURL, "v1beta") { baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1) } @@ -578,6 +582,7 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U GenerationConfig: genConfig, } + hasMappedTools := false if len(req.Tools) > 0 { geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}} for _, t := range req.Tools { @@ -585,13 +590,16 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function) } } - body.Tools = []GeminiTool{geminiTool} + if len(geminiTool.FunctionDeclarations) > 0 { + body.Tools = []GeminiTool{geminiTool} + hasMappedTools = true + } } baseURL := p.config.BaseURL lowerModel := strings.ToLower(req.Model) - if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, 
// sanitizeFunctionName rewrites name into a form accepted as an OpenAI-style
// function/tool name.
//
// ASCII letters, digits, '_' and '-' are kept. Every other rune (each
// multi-byte rune counts once) is replaced by a single '_'. The result is
// capped at 64 characters, the maximum the OpenAI API accepts for a function
// name, so an over-long name is shortened instead of making the upstream
// request fail validation. An empty input yields the placeholder "function".
func sanitizeFunctionName(name string) string {
	// Longest function name accepted by the OpenAI API.
	const maxLen = 64

	size := len(name)
	if size > maxLen {
		size = maxLen
	}

	var sb strings.Builder
	sb.Grow(size)
	for _, ch := range name {
		// Each input rune is emitted as exactly one ASCII byte, so Len is
		// also the number of output characters.
		if sb.Len() >= maxLen {
			break
		}
		switch {
		case ch >= 'a' && ch <= 'z',
			ch >= 'A' && ch <= 'Z',
			ch >= '0' && ch <= '9',
			ch == '_', ch == '-':
			sb.WriteByte(byte(ch))
		default:
			sb.WriteByte('_')
		}
	}

	if sb.Len() == 0 {
		return "function"
	}
	return sb.String()
}

// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON, including tools and images.
func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) { var result []interface{} @@ -35,7 +51,7 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro msg["tool_call_id"] = id if m.Name != nil { - msg["name"] = *m.Name + msg["name"] = sanitizeFunctionName(*m.Name) } result = append(result, msg) continue @@ -91,6 +107,7 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro if sanitizedCalls[i].Type == "" { sanitizedCalls[i].Type = "function" } + sanitizedCalls[i].Function.Name = sanitizeFunctionName(sanitizedCalls[i].Function.Name) } msg["tool_calls"] = sanitizedCalls msg["content"] = "" // OpenAI requirement: content must be string if tool_calls present @@ -124,11 +141,25 @@ func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{}, body["max_tokens"] = *request.MaxTokens } if len(request.Tools) > 0 { - body["tools"] = request.Tools + sanitizedTools := make([]models.Tool, len(request.Tools)) + copy(sanitizedTools, request.Tools) + for i := range sanitizedTools { + if sanitizedTools[i].Type == "function" { + sanitizedTools[i].Function.Name = sanitizeFunctionName(sanitizedTools[i].Function.Name) + } + } + body["tools"] = sanitizedTools } if request.ToolChoice != nil { var toolChoice interface{} if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil { + if tcMap, ok := toolChoice.(map[string]interface{}); ok { + if funcMap, ok := tcMap["function"].(map[string]interface{}); ok { + if name, ok := funcMap["name"].(string); ok { + funcMap["name"] = sanitizeFunctionName(name) + } + } + } body["tool_choice"] = toolChoice } } diff --git a/internal/providers/helpers_test.go b/internal/providers/helpers_test.go new file mode 100644 index 00000000..0778c5ba --- /dev/null +++ b/internal/providers/helpers_test.go @@ -0,0 +1,127 @@ +package providers + +import ( + "encoding/json" + "testing" + + "gophergate/internal/models" +) + +func 
TestSanitizeFunctionName(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"google-search", "google-search"}, + {"google.search", "google_search"}, + {"google search", "google_search"}, + {"web_search(query)", "web_search_query_"}, + {"", "function"}, + {"123_abc-XYZ", "123_abc-XYZ"}, + {"invalid.name.with.dots", "invalid_name_with_dots"}, + } + + for _, tc := range tests { + actual := sanitizeFunctionName(tc.input) + if actual != tc.expected { + t.Errorf("sanitizeFunctionName(%q) = %q; expected %q", tc.input, actual, tc.expected) + } + } +} + +func TestMessagesToOpenAIJSON_SanitizeToolCalls(t *testing.T) { + messages := []models.UnifiedMessage{ + { + Role: "assistant", + Content: []models.UnifiedContentPart{ + {Type: "text", Text: "I will use search."}, + }, + ToolCalls: []models.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: models.FunctionCall{ + Name: "google.search", + Arguments: `{"query": "hello"}`, + }, + }, + }, + }, + { + Role: "tool", + Content: []models.UnifiedContentPart{ + {Type: "text", Text: `{"result": "success"}`}, + }, + ToolCallID: stringPtr("call_1"), + Name: stringPtr("google.search"), + }, + } + + res, err := MessagesToOpenAIJSON(messages) + if err != nil { + t.Fatalf("MessagesToOpenAIJSON failed: %v", err) + } + + if len(res) != 2 { + t.Fatalf("expected 2 messages, got %d", len(res)) + } + + // Verify assistant message + msg1 := res[0].(map[string]interface{}) + if msg1["role"] != "assistant" { + t.Errorf("expected role assistant, got %v", msg1["role"]) + } + calls := msg1["tool_calls"].([]models.ToolCall) + if len(calls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(calls)) + } + if calls[0].Function.Name != "google_search" { + t.Errorf("expected function name google_search, got %q", calls[0].Function.Name) + } + + // Verify tool response message + msg2 := res[1].(map[string]interface{}) + if msg2["role"] != "tool" { + t.Errorf("expected role tool, got %v", msg2["role"]) + } + if 
msg2["name"] != "google_search" { + t.Errorf("expected tool name google_search, got %v", msg2["name"]) + } +} + +func TestBuildOpenAIBody_SanitizeToolsAndChoice(t *testing.T) { + req := &models.UnifiedRequest{ + Model: "gpt-4o", + Tools: []models.Tool{ + { + Type: "function", + Function: models.FunctionDef{ + Name: "google.search", + }, + }, + }, + ToolChoice: json.RawMessage(`{"type": "function", "function": {"name": "google.search"}}`), + } + + body := BuildOpenAIBody(req, nil, false) + + // Verify tools + tools := body["tools"].([]models.Tool) + if len(tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(tools)) + } + if tools[0].Function.Name != "google_search" { + t.Errorf("expected tool function name google_search, got %q", tools[0].Function.Name) + } + + // Verify tool_choice + toolChoice := body["tool_choice"].(map[string]interface{}) + funcObj := toolChoice["function"].(map[string]interface{}) + if funcObj["name"] != "google_search" { + t.Errorf("expected tool_choice function name google_search, got %q", funcObj["name"]) + } +} + +func stringPtr(s string) *string { + return &s +} diff --git a/internal/router/classifier.go b/internal/router/classifier.go index 17b18923..b672213f 100644 --- a/internal/router/classifier.go +++ b/internal/router/classifier.go @@ -15,7 +15,7 @@ const classifierSystemPrompt = `You are a task complexity classifier. Rate the f Reply with ONLY the number. 
// Rule formats accepted in a model group's heuristic_rules JSON array.
type (
	// HeuristicRule defines a pattern-based routing rule (legacy format).
	HeuristicRule struct {
		// Pattern is the text looked for in the user message.
		Pattern string `json:"pattern"`
		// TargetIdx is the index into the group's target list that is
		// selected when the pattern matches.
		TargetIdx int `json:"target"`
		// CaseSensitive turns off the default case-insensitive matching.
		CaseSensitive bool `json:"case_sensitive,omitempty"`
	}

	// ConditionRule defines a condition-based routing rule (new format).
	ConditionRule struct {
		// RuleID identifies the rule; it is reported in the routing reason.
		RuleID string `json:"rule_id"`
		// Description is optional free text appended to the routing reason.
		Description string `json:"description,omitempty"`
		// Conditions are the checks a request must pass for the rule to apply.
		Conditions Conditions `json:"conditions"`
		// PrimaryModel is the preferred model; it is only used when it is one
		// of the group's targets.
		PrimaryModel string `json:"primary_model"`
		// FallbackModel is tried when PrimaryModel is not among the targets.
		FallbackModel string `json:"fallback_model,omitempty"`
	}

	// Conditions defines the matching parameters for a rule. A nil pointer or
	// empty slice leaves that aspect of the request unconstrained.
	Conditions struct {
		// AnyOfTags requires at least one of these tags on the request
		// (compared case-insensitively).
		AnyOfTags []string `json:"any_of_tags,omitempty"`
		// MaxInputTokensLt requires the request's input token count to be
		// strictly below this value.
		MaxInputTokensLt *int `json:"max_input_tokens_lt,omitempty"`
		// RequiresReasoning must equal the request's reasoning flag.
		RequiresReasoning *bool `json:"requires_reasoning,omitempty"`
		// RequiresToolCalling must equal the request's tool-calling flag.
		RequiresToolCalling *bool `json:"requires_tool_calling,omitempty"`
		// HasMultimodalInput must equal the request's multimodal flag.
		HasMultimodalInput *bool `json:"has_multimodal_input,omitempty"`
		// IsDefaultFallback, when true, makes the rule match every request
		// regardless of the other conditions.
		IsDefaultFallback *bool `json:"is_default_fallback,omitempty"`
	}
)
targetModel != "" { + selected = targetModel + reason = "matched condition rule: " + rule.RuleID + if rule.Description != "" { + reason += " (" + rule.Description + ")" + } + break + } + } } - if strings.Contains(msg, pattern) { - if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) { - selected = targets[rule.TargetIdx] - reason = "matched heuristic rule: " + rule.Pattern - break + } + } else { + // Fallback to legacy pattern-based rules + var legacyRules []HeuristicRule + if err := json.Unmarshal([]byte(rulesJSON), &legacyRules); err == nil { + searchMsg := routeCtx.UserMessage + for _, rule := range legacyRules { + pattern := rule.Pattern + if pattern == "" { + continue // Avoid infinite matches with empty patterns + } + msg := searchMsg + if !rule.CaseSensitive { + pattern = strings.ToLower(pattern) + msg = strings.ToLower(msg) + } + + // Support both regex matching (if pattern is valid regex) and literal contains + matched := false + if strings.Contains(rule.Pattern, "(") || strings.Contains(rule.Pattern, "\\b") { + var re *regexp.Regexp + var err error + if !rule.CaseSensitive { + re, err = regexp.Compile("(?i)" + rule.Pattern) + } else { + re, err = regexp.Compile(rule.Pattern) + } + if err == nil { + matched = re.MatchString(routeCtx.UserMessage) + } + } + + if !matched && strings.Contains(msg, pattern) { + matched = true + } + + if matched { + if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) { + selected = targets[rule.TargetIdx] + reason = "matched heuristic rule: " + rule.Pattern + break + } } } } } } - // Built-in fallback heuristics + // Built-in fallback heuristics (if no custom rule matched) if reason == "default (first target)" && len(targets) > 1 { - msgLower := strings.ToLower(userMessage) + msgLower := strings.ToLower(routeCtx.UserMessage) complexIndicators := []string{ "step by step", "explain in detail", "reason through", "think carefully", "analyze", "debug", "write code", @@ -64,3 +141,79 @@ func routeHeuristic(group db.ModelGroup, 
targets []string, userMessage string) ( Reason: reason, }, nil } + +// isConditionBasedRules returns true if the JSON represents condition-based rules. +func isConditionBasedRules(rulesJSON string) bool { + var rules []ConditionRule + if err := json.Unmarshal([]byte(rulesJSON), &rules); err == nil && len(rules) > 0 { + // If the rule has either conditions or primary_model/rule_id, treat it as condition-based + return rules[0].PrimaryModel != "" || rules[0].RuleID != "" + } + return false +} + +// matchConditions evaluates whether the given conditions match the RouteContext. +func matchConditions(cond Conditions, routeCtx *RouteContext) bool { + if cond.IsDefaultFallback != nil && *cond.IsDefaultFallback { + return true + } + + // Check tags: must match any_of_tags if specified + if len(cond.AnyOfTags) > 0 { + tagMatched := false + for _, ruleTag := range cond.AnyOfTags { + for _, ctxTag := range routeCtx.Tags { + if strings.EqualFold(ruleTag, ctxTag) { + tagMatched = true + break + } + } + if tagMatched { + break + } + } + if !tagMatched { + return false + } + } + + // Check max input tokens + if cond.MaxInputTokensLt != nil { + if routeCtx.InputTokens >= *cond.MaxInputTokensLt { + return false + } + } + + // Check reasoning flag + if cond.RequiresReasoning != nil { + if routeCtx.RequiresReasoning != *cond.RequiresReasoning { + return false + } + } + + // Check tool calling flag + if cond.RequiresToolCalling != nil { + if routeCtx.RequiresToolCalling != *cond.RequiresToolCalling { + return false + } + } + + // Check multimodal flag + if cond.HasMultimodalInput != nil { + if routeCtx.HasMultimodalInput != *cond.HasMultimodalInput { + return false + } + } + + return true +} + +// getModelInTargets returns the model name if it exists in targets, or empty string. 
func getModelInTargets(modelName string, targets []string) string {
	// Case-insensitive lookup: the first target equal to modelName wins and is
	// returned with the spelling used in targets, not the caller's spelling.
	for idx := 0; idx < len(targets); idx++ {
		if candidate := targets[idx]; strings.EqualFold(candidate, modelName) {
			return candidate
		}
	}
	// Not one of the group's targets.
	return ""
}
Test Multimodal Long Context (condition success) + ctx2 := &RouteContext{ + UserMessage: "explain this video", + InputTokens: 15000, + HasMultimodalInput: true, + RequiresReasoning: false, + Tags: []string{"standard-pro", "video-analysis"}, + } + dec2, err := routeHeuristic(group, targets, ctx2) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if dec2.SelectedModel != "gemini-3-flash" { + t.Fatalf("expected gemini-3-flash, got %s", dec2.SelectedModel) + } + + // 3. Test Fallback general rule + ctx3 := &RouteContext{ + UserMessage: "hello there", + InputTokens: 100, + HasMultimodalInput: false, + RequiresReasoning: false, + Tags: []string{"general"}, + } + dec3, err := routeHeuristic(group, targets, ctx3) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if dec3.SelectedModel != "kimi-k2.6" { + t.Fatalf("expected kimi-k2.6, got %s", dec3.SelectedModel) + } +} + +func TestRouteHeuristic_LegacyRules(t *testing.T) { + targets := []string{"gpt-4o-mini", "deepseek-v4-pro", "kimi-k2.6"} + + // Legacy pattern-based rule with regex + rulesJSON := `[ + {"pattern": "\\b(agent|agents|tool use)\\b", "target": 1}, + {"pattern": "summarize", "target": 2} + ]` + + group := db.ModelGroup{ + ID: "heavy-logic", + Strategy: "heuristic", + HeuristicRules: &rulesJSON, + } + + // 1. Test regex match + ctx1 := &RouteContext{ + UserMessage: "We need an agent to do tool use", + } + dec1, err := routeHeuristic(group, targets, ctx1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if dec1.SelectedModel != "deepseek-v4-pro" { + t.Fatalf("expected deepseek-v4-pro, got %s", dec1.SelectedModel) + } + + // 2. 
// RouteContext holds metadata of the request to evaluate condition rules.
type RouteContext struct {
	// UserMessage is the user text that keyword heuristics and the classifier
	// inspect.
	UserMessage string `json:"user_message"`

	// InputTokens is the request's input size in tokens, compared against a
	// rule's max_input_tokens_lt.
	InputTokens int `json:"input_tokens"`

	// HasMultimodalInput is compared against a rule's has_multimodal_input.
	HasMultimodalInput bool `json:"has_multimodal_input"`

	// RequiresToolCalling is compared against a rule's requires_tool_calling.
	RequiresToolCalling bool `json:"requires_tool_calling"`

	// RequiresReasoning is compared against a rule's requires_reasoning.
	RequiresReasoning bool `json:"requires_reasoning"`

	// Tags are labels attached to the request; a rule's any_of_tags is matched
	// against them case-insensitively.
	Tags []string `json:"tags"`
}

// ClassifierFunc is the callback for classifier-based routing.
type ClassifierFunc func(
	ctx context.Context,
	selectorModel string,
	systemPrompt string,
	userMessage string,
) (string, error)

// Route resolves a group to a concrete model.
-func (r *Router) Route(ctx context.Context, groupID string, userMessage string) (*Decision, error) { +func (r *Router) Route(ctx context.Context, groupID string, routeCtx *RouteContext) (*Decision, error) { group, ok := r.groups[groupID] if !ok { return nil, fmt.Errorf("unknown model group: %s", groupID) @@ -66,12 +76,12 @@ func (r *Router) Route(ctx context.Context, groupID string, userMessage string) switch group.Strategy { case "heuristic": - return routeHeuristic(group, targets, userMessage) + return routeHeuristic(group, targets, routeCtx) case "classifier": if r.classify == nil { - return routeHeuristic(group, targets, userMessage) + return routeHeuristic(group, targets, routeCtx) } - return routeClassifier(ctx, r.classify, group, targets, userMessage) + return routeClassifier(ctx, r.classify, group, targets, routeCtx) default: return nil, fmt.Errorf("unknown strategy: %s", group.Strategy) } @@ -80,7 +90,7 @@ func (r *Router) Route(ctx context.Context, groupID string, userMessage string) // RouteToConcrete resolves a model name to a concrete model, following group // chains recursively until a non-group target is reached. Returns the original // name unchanged if it is not a group. 
-func (r *Router) RouteToConcrete(ctx context.Context, modelID string, userMessage string) (*Decision, error) { +func (r *Router) RouteToConcrete(ctx context.Context, modelID string, routeCtx *RouteContext) (*Decision, error) { const maxDepth = 10 visited := make(map[string]bool) current := modelID @@ -109,7 +119,7 @@ func (r *Router) RouteToConcrete(ctx context.Context, modelID string, userMessag } visited[current] = true - decision, err := r.Route(ctx, current, userMessage) + decision, err := r.Route(ctx, current, routeCtx) if err != nil { return nil, err } diff --git a/internal/server/server.go b/internal/server/server.go index 989589f6..668e9209 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -329,7 +329,8 @@ func (s *Server) handleResponses(c *gin.Context) { } } if s.modelRouter != nil { - decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, "") + routeCtx := s.buildRouteContextFromResponses(req) + decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)}) return @@ -573,8 +574,8 @@ func (s *Server) handleChatCompletions(c *gin.Context) { log.Printf("[DEBUG] Incoming Msg[%d]: role=%s, hasToolCalls=%v, hasContent=%v", i, m.Role, len(m.ToolCalls) > 0, m.Content != nil) } if s.modelRouter != nil { - userMessage := extractUserMessage(req.Messages) - decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, userMessage) + routeCtx := s.buildRouteContextFromChat(req) + decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)}) return @@ -941,3 +942,163 @@ func (s *Server) Run() error { } func uint32Ptr(v uint32) *uint32 { return &v } + +func (s *Server) buildRouteContextFromChat(req models.ChatCompletionRequest) 
*router.RouteContext { + userMessage := extractUserMessage(req.Messages) + requiresToolCalling := len(req.Tools) > 0 + hasMultimodal := false + inputTokens := 0 + + for _, msg := range req.Messages { + if strContent, ok := msg.Content.(string); ok { + inputTokens += len(strContent) / 4 + } else if parts, ok := msg.Content.([]interface{}); ok { + for _, part := range parts { + if partMap, ok := part.(map[string]interface{}); ok { + partType, _ := partMap["type"].(string) + if partType == "text" { + text, _ := partMap["text"].(string) + inputTokens += len(text) / 4 + } else if partType == "image_url" { + hasMultimodal = true + inputTokens += 1000 // Approximate cost of an image in tokens + } + } + } + } + } + + msgLower := strings.ToLower(userMessage) + requiresReasoning := strings.Contains(msgLower, "reason") || + strings.Contains(msgLower, "think step by step") || + strings.Contains(msgLower, "mathematics") || + strings.Contains(msgLower, "architecture") || + strings.Contains(msgLower, "explain in detail") + + routeCtx := &router.RouteContext{ + UserMessage: userMessage, + InputTokens: inputTokens, + HasMultimodalInput: hasMultimodal, + RequiresToolCalling: requiresToolCalling, + RequiresReasoning: requiresReasoning, + } + routeCtx.Tags = s.getRouteCtxTags(routeCtx) + return routeCtx +} + +func (s *Server) buildRouteContextFromResponses(req models.ResponsesRequest) *router.RouteContext { + var userMessage string + hasMultimodal := false + inputTokens := len(req.Instructions) / 4 + requiresToolCalling := len(req.Tools) > 0 && string(req.Tools) != "null" && string(req.Tools) != "" + + var strInput string + if err := json.Unmarshal(req.Input, &strInput); err == nil { + userMessage = strInput + inputTokens += len(userMessage) / 4 + } else { + var msgs []models.ResponseInputMessage + if err := json.Unmarshal(req.Input, &msgs); err == nil { + for _, m := range msgs { + var contentStr string + if err := json.Unmarshal(m.Content, &contentStr); err == nil { + if m.Role == 
"user" { + userMessage = contentStr + } + inputTokens += len(contentStr) / 4 + } else { + var parts []models.ContentPart + if err := json.Unmarshal(m.Content, &parts); err == nil { + for _, p := range parts { + if p.Type == "text" { + if m.Role == "user" { + userMessage = p.Text + } + inputTokens += len(p.Text) / 4 + } else if p.Type == "image_url" { + hasMultimodal = true + inputTokens += 1000 + } + } + } + } + } + } + } + + msgLower := strings.ToLower(userMessage) + requiresReasoning := strings.Contains(msgLower, "reason") || + strings.Contains(msgLower, "think step by step") || + strings.Contains(msgLower, "mathematics") || + strings.Contains(msgLower, "architecture") || + strings.Contains(msgLower, "explain in detail") + + routeCtx := &router.RouteContext{ + UserMessage: userMessage, + InputTokens: inputTokens, + HasMultimodalInput: hasMultimodal, + RequiresToolCalling: requiresToolCalling, + RequiresReasoning: requiresReasoning, + } + routeCtx.Tags = s.getRouteCtxTags(routeCtx) + return routeCtx +} + +func (s *Server) getRouteCtxTags(routeCtx *router.RouteContext) []string { + var tags []string + msgLower := strings.ToLower(routeCtx.UserMessage) + + // fast-flow keywords + fastFlowKeywords := []string{ + "classify", "classification", "label", "tag", "route", "routing", "intent", + "json", "yaml", "csv", "schema", "parse", "extract", "transform", "format", "regex", + "short answer", "brief", "concise", "tl;dr", "one line", "simple", + "fix this", "small bug", "quick fix", "typo", "syntax error", + } + for _, kw := range fastFlowKeywords { + if strings.Contains(msgLower, kw) { + tags = append(tags, "fast-flow", "classification", "json-extraction", "basic-qa") + break + } + } + + // standard-pro keywords + standardProKeywords := []string{ + "explain", "summarize", "rewrite", "draft", "edit", "polish", "outline", + "long doc", "document", "email", "memo", "proposal", "report", "handout", "notes", + "compare", "choose", "recommend", "tradeoff", "pros and cons", 
"analysis", + "code review", "debug", "bug", "feature", "api", "endpoint", "implement", + "plan", "planning", "workflow", "integration", + } + for _, kw := range standardProKeywords { + if strings.Contains(msgLower, kw) { + tags = append(tags, "standard-pro", "long-doc") + break + } + } + if routeCtx.HasMultimodalInput { + tags = append(tags, "video-analysis", "multimodal-qa") + } + + // heavy-logic keywords + heavyLogicKeywords := []string{ + "agent", "agents", "tool use", "function calling", "multi-agent", "orchestrate", + "system design", "scaling", "performance", "architecture review", "distributed", + "hard bug", "race condition", "deadlock", "memory leak", "crash", "production outage", + "long context", "large codebase", "many files", "complex refactor", "migration", + "research", "deep dive", "literature", "paper", "scholarly", "thorough analysis", + "deep reasoning", "think step by step", "reason through", "careful analysis", + } + for _, kw := range heavyLogicKeywords { + if strings.Contains(msgLower, kw) { + tags = append(tags, "heavy-logic", "deep-reasoning", "architecture", "hard-debugging") + break + } + } + + if routeCtx.RequiresToolCalling { + tags = append(tags, "tool-heavy", "multi-step-agent", "swe-bench") + } + + return tags +}