diff --git a/.hermes/plans/auto-model-routing.md b/.hermes/plans/auto-model-routing.md new file mode 100644 index 00000000..13b75049 --- /dev/null +++ b/.hermes/plans/auto-model-routing.md @@ -0,0 +1,919 @@ +# Automatic Model Routing — Implementation Plan + +> **For Hermes:** Use subagent-driven-development skill to implement this plan task-by-task. + +**Goal:** Add a model-group router that lets clients send `model: "deepseek-auto"` and have gophergate pick the best concrete model based on heuristic rules or an optional classifier LLM. + +**Architecture:** A new `internal/router/` package with heuristic and classifier strategies, backed by a `model_groups` DB table. The router injects into `handleChatCompletions` after provider resolution but before the provider call — zero changes to the Provider interface. Admin CRUD endpoints and a dashboard tab for management. + +**Tech Stack:** Go 1.22+, Gin, sqlx (SQLite), resty, existing OpenAI provider for classifier calls. + +--- + +## Task 1: Add `model_groups` DB migration and struct + +**Objective:** Create the `model_groups` table and Go struct. 
+ +**Files:** +- Modify: `internal/db/db.go` + +**Step 1: Add CREATE TABLE to migrations** + +In `RunMigrations()`, add to the `queries` slice (after `client_tokens`): + +```go +`CREATE TABLE IF NOT EXISTS model_groups ( + id TEXT PRIMARY KEY, + strategy TEXT NOT NULL DEFAULT 'heuristic', + selector_model TEXT, + targets TEXT NOT NULL DEFAULT '[]', + complexity_threshold INTEGER, + heuristic_rules TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP +)`, +``` + +**Step 2: Add the Go struct** + +After the `ClientToken` struct (around line 264), add: + +```go +type ModelGroup struct { + ID string `db:"id" json:"id"` + Strategy string `db:"strategy" json:"strategy"` + SelectorModel *string `db:"selector_model" json:"selector_model"` + Targets string `db:"targets" json:"targets"` // JSON array + ComplexityThreshold *int `db:"complexity_threshold" json:"complexity_threshold"` + HeuristicRules *string `db:"heuristic_rules" json:"heuristic_rules"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} +``` + +**Step 3: Seed default groups** + +After the "Default client" block in `RunMigrations()`, add: + +```go +// Seed default model groups +defaultGroups := []struct { + id, strategy, targets string +}{ + {"deepseek-auto", "heuristic", `["deepseek-chat","deepseek-reasoner"]`}, + {"openai-auto", "heuristic", `["gpt-4o-mini","gpt-4o"]`}, + {"gemini-auto", "heuristic", `["gemini-2.0-flash","gemini-2.5-pro"]`}, +} +for _, g := range defaultGroups { + db.Exec(`INSERT OR IGNORE INTO model_groups (id, strategy, targets) VALUES (?, ?, ?)`, + g.id, g.strategy, g.targets) +} +``` + +**Step 4: Build and verify** + +```bash +cd ~/Documents/projects/web_projects/gophergate && go build ./... 
+``` + +**Step 5: Commit** + +```bash +git add internal/db/db.go +git commit -m "feat: add model_groups table and default seed data" +``` + +--- + +## Task 2: Create router package — interface and heuristic router + +**Objective:** Create `internal/router/` with the Router interface and heuristic implementation. + +**Files:** +- Create: `internal/router/router.go` +- Create: `internal/router/heuristic.go` + +**Step 1: Create `internal/router/router.go`** + +```go +package router + +import ( + "context" + "encoding/json" + + "gophergate/internal/db" +) + +// Decision holds the result of a routing decision. +type Decision struct { + SelectedModel string `json:"selected_model"` + Strategy string `json:"strategy"` // "heuristic" or "classifier" + Reason string `json:"reason"` +} + +// ClassifierFunc is the callback for classifier-based routing. +// Takes the selector model, a system prompt, and the user message. +// Returns a complexity rating string (e.g. "3"). +type ClassifierFunc func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) + +// Router resolves model groups to concrete models. +type Router struct { + groups map[string]db.ModelGroup + classify ClassifierFunc +} + +// New creates a Router. classify may be nil if no classifier groups exist. +func New(groups []db.ModelGroup, classify ClassifierFunc) *Router { + r := &Router{ + groups: make(map[string]db.ModelGroup), + classify: classify, + } + for _, g := range groups { + r.groups[g.ID] = g + } + return r +} + +// IsGroup returns true if the model name is a group ID. +func (r *Router) IsGroup(modelID string) bool { + _, ok := r.groups[modelID] + return ok +} + +// Route resolves a group to a concrete model. +// userMessage is the user message text, already extracted by the caller; it drives the heuristic or classifier decision. 
+func (r *Router) Route(ctx context.Context, groupID string, userMessage string) (*Decision, error) { + group, ok := r.groups[groupID] + if !ok { + return nil, fmt.Errorf("unknown model group: %s", groupID) + } + + var targets []string + if err := json.Unmarshal([]byte(group.Targets), &targets); err != nil || len(targets) == 0 { + return nil, fmt.Errorf("invalid or empty targets for group %s", groupID) + } + + switch group.Strategy { + case "heuristic": + return routeHeuristic(group, targets, userMessage) + case "classifier": + if r.classify == nil { + // Fall back to heuristic if no classifier is available + return routeHeuristic(group, targets, userMessage) + } + return routeClassifier(ctx, r.classify, group, targets, userMessage) + default: + return nil, fmt.Errorf("unknown strategy: %s", group.Strategy) + } +} + +// Reload replaces the group definitions without recreating the router. +func (r *Router) Reload(groups []db.ModelGroup) { + r.groups = make(map[string]db.ModelGroup) + for _, g := range groups { + r.groups[g.ID] = g + } +} +``` + +**Step 2: Create `internal/router/heuristic.go`** + +```go +package router + +import ( + "context" + "encoding/json" + "strings" + + "gophergate/internal/db" +) + +// HeuristicRule defines a pattern-based routing rule. 
+type HeuristicRule struct { + Pattern string `json:"pattern"` // substring to match in user message + TargetIdx int `json:"target"` // index into targets array (0-based) + CaseSensitive bool `json:"case_sensitive,omitempty"` +} + +func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) (*Decision, error) { + // Default to first target (cheapest/fastest) + selected := targets[0] + reason := "default (first target)" + + // If heuristic_rules is set, use them + if group.HeuristicRules != nil && *group.HeuristicRules != "" { + var rules []HeuristicRule + if err := json.Unmarshal([]byte(*group.HeuristicRules), &rules); err == nil { + searchMsg := userMessage + for _, rule := range rules { + pattern := rule.Pattern + msg := searchMsg + if !rule.CaseSensitive { + pattern = strings.ToLower(pattern) + msg = strings.ToLower(msg) + } + if strings.Contains(msg, pattern) { + if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) { + selected = targets[rule.TargetIdx] + reason = "matched heuristic rule: " + rule.Pattern + break + } + } + } + } + } + + // Built-in fallback heuristics (apply even without custom rules) + if reason == "default (first target)" && len(targets) > 1 { + msgLower := strings.ToLower(userMessage) + // Complex task indicators → last target (usually the smarter model) + complexIndicators := []string{ + "step by step", "explain in detail", "reason through", + "think carefully", "analyze", "debug", "write code", + "implement", "refactor", "architecture", + } + for _, indicator := range complexIndicators { + if strings.Contains(msgLower, indicator) { + selected = targets[len(targets)-1] + reason = "complex task indicator: " + indicator + break + } + } + } + + return &Decision{ + SelectedModel: selected, + Strategy: "heuristic", + Reason: reason, + }, nil +} + +// routeHeuristic exists as a package-level func for direct use. 
+var _ = routeHeuristic // suppress unused warning when classifier is the only caller +``` + +Note: `routeHeuristic` is already called by `Router.Route()`, so the trailing `var _ = routeHeuristic` line and its comment in the listing above are unnecessary — omit them when creating the file. + +**Step 3: Build** + +```bash +cd ~/Documents/projects/web_projects/gophergate && go build ./... +``` + +Fix any compilation errors (missing imports, etc.). + +**Step 4: Commit** + +```bash +git add internal/router/ +git commit -m "feat: add router package with heuristic strategy" +``` + +--- + +## Task 3: Add classifier router + +**Objective:** Implement the classifier strategy that uses a cheap LLM to rate task complexity. + +**Files:** +- Create: `internal/router/classifier.go` + +**Step 1: Create `internal/router/classifier.go`** + +```go +package router + +import ( + "context" + "fmt" + "strconv" + "strings" + + "gophergate/internal/db" +) + +const classifierSystemPrompt = `You are a task complexity classifier. Rate the following user message on a scale of 1 to %d, where: +1 = trivial/simple (basic facts, greetings, simple math) +%d = highly complex (multi-step reasoning, code generation, architecture design) + +Reply with ONLY the number. 
No explanation.` + +func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, userMessage string) (*Decision, error) { + maxRating := len(targets) + if maxRating < 2 { + maxRating = 2 + } + + prompt := fmt.Sprintf(classifierSystemPrompt, maxRating, maxRating) + ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMessage) + if err != nil { + // Classifier failed — fall back to heuristic + return routeHeuristic(group, targets, userMessage) + } + + rating, err := strconv.Atoi(strings.TrimSpace(ratingStr)) + if err != nil || rating < 1 { + rating = 1 + } + if rating > maxRating { + rating = maxRating + } + + idx := rating - 1 // 0-based index into targets + return &Decision{ + SelectedModel: targets[idx], + Strategy: "classifier", + Reason: fmt.Sprintf("complexity rating: %d/%d", rating, maxRating), + }, nil +} + +func getSelectorModel(group db.ModelGroup, targets []string) string { + if group.SelectorModel != nil && *group.SelectorModel != "" { + return *group.SelectorModel + } + // Default: use the first (cheapest) target model as the selector + return targets[0] +} +``` + +**Step 2: Build** + +```bash +cd ~/Documents/projects/web_projects/gophergate && go build ./... +``` + +**Step 3: Commit** + +```bash +git add internal/router/classifier.go +git commit -m "feat: add classifier routing strategy with LLM complexity rating" +``` + +--- + +## Task 4: Wire router into the server + +**Objective:** Add the Router to the Server struct, initialize it, and inject it into `handleChatCompletions`. 
+ +**Files:** +- Modify: `internal/server/server.go` + +**Step 1: Add router field to Server struct** + +In the `Server` struct (around line 23), add after the `registryMu` field: + +```go +router *router.Router +``` + +**Step 2: Add import** + +Add to the imports block: + +```go +"gophergate/internal/router" +``` + +**Step 3: Initialize router in NewServer** + +After `s.setupRoutes()` (line 66), add: + +```go +// Initialize model group router +s.refreshRouter() +``` + +**Step 4: Add refreshRouter method** + +Add a new method on Server: + +```go +func (s *Server) refreshRouter() { + var groups []db.ModelGroup + if err := s.database.Select(&groups, "SELECT * FROM model_groups"); err != nil { + fmt.Printf("Warning: Failed to load model groups: %v\n", err) + groups = nil + } + + // Build classifier function using the OpenAI provider + var classifyFn router.ClassifierFunc + if openaiProvider, ok := s.providers["openai"]; ok { + classifyFn = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) { + req := &models.UnifiedRequest{ + Model: selectorModel, + Messages: []models.UnifiedMessage{ + {Role: "system", Content: []models.ContentPart{{Type: "text", Text: systemPrompt}}}, + {Role: "user", Content: []models.ContentPart{{Type: "text", Text: userMessage}}}, + }, + MaxTokens: uint32Ptr(5), + Stream: false, + } + resp, err := openaiProvider.ChatCompletion(ctx, req) + if err != nil { + return "", err + } + if len(resp.Choices) == 0 { + return "", fmt.Errorf("no choices in classifier response") + } + return resp.Choices[0].Message.Content, nil + } + } + + if s.router == nil { + s.router = router.New(groups, classifyFn) + } else { + s.router.Reload(groups) + } +} +``` + +**Step 5: Add uint32Ptr helper (if not already in the codebase)** + +At the bottom of server.go, add: + +```go +func uint32Ptr(v uint32) *uint32 { return &v } +``` + +**Step 6: Inject router into handleChatCompletions** + +In `handleChatCompletions`, after the model prefix 
stripping block (after line 475) and before building the UnifiedRequest (line 478), add: + +```go +// Check if model is a group and route to a concrete model +if s.router != nil && s.router.IsGroup(modelID) { + userMessage := extractUserMessage(req.Messages) + decision, err := s.router.Route(c.Request.Context(), modelID, userMessage) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)}) + return + } + modelID = decision.SelectedModel + log.Printf("[ROUTER] %s → %s (%s: %s)", req.Model, modelID, decision.Strategy, decision.Reason) +} +``` + +**Step 7: Add extractUserMessage helper** + +```go +func extractUserMessage(messages []models.ChatCompletionMessage) string { + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == "user" { + if s, ok := messages[i].Content.(string); ok { + return s + } + // It might be a content array — grab text from first part + if parts, ok := messages[i].Content.([]interface{}); ok && len(parts) > 0 { + if part, ok := parts[0].(map[string]interface{}); ok { + if text, ok := part["text"].(string); ok { + return text + } + } + } + return "" + } + } + return "" +} +``` + +**Step 8: Add router refresh to RefreshProviders** + +At the end of `RefreshProviders()` (before `return nil` at line 171), add: + +```go +s.refreshRouter() +``` + +**Step 9: Build** + +```bash +cd ~/Documents/projects/web_projects/gophergate && go build ./... +``` + +Expect compilation errors — need to check the `ChatCompletionMessage` type. The handler uses `models.ChatCompletionRequest` which has `Messages []ChatCompletionMessage`. Let me verify the type. If it's `[]models.ChatCompletionMessage` with `Content` as a string field, the helper is simpler. Fix as needed. 
+ +**Step 10: Commit** + +```bash +git add internal/server/server.go +git commit -m "feat: wire model group router into chat completions handler" +``` + +--- + +## Task 5: Add admin API endpoints for model groups + +**Objective:** CRUD endpoints at `/api/model-groups` for dashboard management. + +**Files:** +- Create: `internal/server/model_groups_admin.go` + +**Step 1: Create `internal/server/model_groups_admin.go`** + +```go +package server + +import ( + "net/http" + + "gophergate/internal/db" + + "github.com/gin-gonic/gin" +) + +func (s *Server) handleGetModelGroups(c *gin.Context) { + var groups []db.ModelGroup + if err := s.database.Select(&groups, "SELECT * FROM model_groups ORDER BY id"); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if groups == nil { + groups = []db.ModelGroup{} + } + c.JSON(http.StatusOK, groups) +} + +func (s *Server) handleCreateModelGroup(c *gin.Context) { + var group db.ModelGroup + if err := c.ShouldBindJSON(&group); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + _, err := s.database.Exec(` + INSERT INTO model_groups (id, strategy, selector_model, targets, complexity_threshold, heuristic_rules) + VALUES (?, ?, ?, ?, ?, ?)`, + group.ID, group.Strategy, group.SelectorModel, group.Targets, + group.ComplexityThreshold, group.HeuristicRules) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + s.refreshRouter() + c.JSON(http.StatusCreated, group) +} + +func (s *Server) handleUpdateModelGroup(c *gin.Context) { + id := c.Param("id") + var group db.ModelGroup + if err := c.ShouldBindJSON(&group); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + _, err := s.database.Exec(` + UPDATE model_groups SET strategy=?, selector_model=?, targets=?, complexity_threshold=?, heuristic_rules=?, updated_at=CURRENT_TIMESTAMP + WHERE id=?`, + group.Strategy, 
group.SelectorModel, group.Targets, + group.ComplexityThreshold, group.HeuristicRules, id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + s.refreshRouter() + c.JSON(http.StatusOK, group) +} + +func (s *Server) handleDeleteModelGroup(c *gin.Context) { + id := c.Param("id") + _, err := s.database.Exec("DELETE FROM model_groups WHERE id=?", id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + s.refreshRouter() + c.JSON(http.StatusOK, gin.H{"status": "deleted"}) +} +``` + +**Step 2: Register routes in setupRoutes()** + +In `setupRoutes()`, add under the admin group (after the models endpoints around line 229): + +```go +admin.GET("/model-groups", s.handleGetModelGroups) +admin.POST("/model-groups", s.handleCreateModelGroup) +admin.PUT("/model-groups/:id", s.handleUpdateModelGroup) +admin.DELETE("/model-groups/:id", s.handleDeleteModelGroup) +``` + +**Step 3: Build** + +```bash +cd ~/Documents/projects/web_projects/gophergate && go build ./... +``` + +**Step 4: Commit** + +```bash +git add internal/server/model_groups_admin.go internal/server/server.go +git commit -m "feat: add model groups CRUD admin API endpoints" +``` + +--- + +## Task 6: Add dashboard UI — sidebar entry and page module + +**Objective:** Add a "Model Groups" tab to the dashboard sidebar and a page module for CRUD management. 
+ +**Files:** +- Modify: `static/index.html` +- Create: `static/js/pages/model_groups.js` + +**Step 1: Add sidebar menu item in index.html** + +In the MANAGEMENT section (after line 91, before ``), add: + +```html + +``` + +**Step 2: Add script tag in index.html** + +After the users.js script (line 179), add: + +```html + +``` + +**Step 3: Create `static/js/pages/model_groups.js`** + +```javascript +// Model Groups Management Page + +class ModelGroupsPage { + constructor() { + this.container = document.getElementById('page-content'); + } + + async render() { + this.container.innerHTML = ` + +
+ + `; + await this.loadGroups(); + } + + async loadGroups() { + try { + const groups = await api.get('/api/model-groups'); + const list = document.getElementById('model-groups-list'); + if (!groups || groups.length === 0) { + list.innerHTML = '
No model groups defined. Create one to enable auto-routing.
'; + return; + } + + let targets; + try { targets = JSON.parse(g.targets); } catch { targets = []; } + const heuristicRules = g.heuristic_rules ? JSON.parse(g.heuristic_rules) : null; + + let html = ''; + html += ''; + html += ''; + + groups.forEach(g => { + html += ` + + + + + `; + }); + + html += '
Group IDStrategyTargetsActions
${this.esc(g.id)}${this.esc(g.strategy)}${this.esc(g.targets)} + + +
'; + list.innerHTML = html; + } catch (err) { + document.getElementById('model-groups-list').innerHTML = + `
Failed to load model groups: ${this.esc(err.message)}
`; + } + } + + showCreateForm() { + this.renderForm(null); + } + + async showEditForm(id) { + const groups = await api.get('/api/model-groups'); + const group = groups.find(g => g.id === id); + if (group) this.renderForm(group); + } + + renderForm(group) { + const isEdit = !!group; + const form = document.getElementById('model-group-form'); + form.style.display = 'block'; + form.innerHTML = ` +

${isEdit ? 'Edit' : 'Create'} Model Group

+
+
+ + + Clients use this as the model name. +
+
+ + +
+
+ + + First target = cheapest/fastest. Last target = smartest/most expensive. +
+
+ + +
+
+ + +
+
+ + + Pattern to match in user messages. target = index into targets array. +
+
+ + +
+
+ `; + + // Toggle strategy-specific fields + document.getElementById('mg-strategy').onchange = function() { + const isClassifier = this.value === 'classifier'; + document.getElementById('mg-selector-row').style.display = isClassifier ? '' : 'none'; + document.getElementById('mg-threshold-row').style.display = isClassifier ? '' : 'none'; + document.getElementById('mg-rules-row').style.display = isClassifier ? 'none' : ''; + }; + } + + async saveGroup(event, isEdit) { + event.preventDefault(); + const id = document.getElementById('mg-id').value.trim(); + const strategy = document.getElementById('mg-strategy').value; + const targets = document.getElementById('mg-targets').value; + const selectorModel = document.getElementById('mg-selector-model').value.trim() || null; + const thresholdVal = document.getElementById('mg-threshold').value; + const rules = document.getElementById('mg-rules').value.trim() || null; + + // Validate JSON + try { JSON.parse(targets); } catch { alert('Targets must be valid JSON array'); return; } + if (rules) { try { JSON.parse(rules); } catch { alert('Heuristic rules must be valid JSON'); return; } } + + const body = { id, strategy, targets, selector_model: selectorModel, heuristic_rules: rules }; + if (thresholdVal) body.complexity_threshold = parseInt(thresholdVal); + + try { + if (isEdit) { + await api.put(`/api/model-groups/${encodeURIComponent(id)}`, body); + } else { + await api.post('/api/model-groups', body); + } + document.getElementById('model-group-form').style.display = 'none'; + await this.loadGroups(); + } catch (err) { + alert('Failed to save: ' + err.message); + } + } + + async deleteGroup(id) { + if (!confirm(`Delete model group "${id}"?`)) return; + try { + await api.delete(`/api/model-groups/${encodeURIComponent(id)}`); + await this.loadGroups(); + } catch (err) { + alert('Failed to delete: ' + err.message); + } + } + + esc(str) { + if (!str) return ''; + return 
String(str).replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;').replace(/"/g,'&quot;'); + } +} + +const modelGroupsPage = new ModelGroupsPage(); +``` + +**Step 4: Register page in dashboard.js** + +In `static/js/dashboard.js`, find the page loading logic in the `loadPage` method, which selects what to show based on `this.currentPage`. The naming convention uses hyphens in `data-page` attributes (e.g., `data-page="model-groups"`). Check how the existing pages are loaded and ensure "model-groups" maps to the new module. + +Based on index.html, pages are loaded as separate script files via script tags rather than through dynamic imports, and the dashboard dispatches to them when a page is navigated to. Confirm whether `dashboard.js` uses a page registry before adding a new branch. + +The pattern appears to be: each page script defines a class or object, and the dashboard's `loadPage` method calls a `render()` or `init()` method on it when that page is selected. Add the dispatch logic for the new page accordingly. + +In `dashboard.js`, find the `loadPage` method and ensure it handles "model-groups": + +```javascript +// In the loadPage switch/if-else, add: +else if (page === 'model-groups') { + if (typeof modelGroupsPage !== 'undefined') { + modelGroupsPage.render(); + } +} +``` + +**Step 5: Commit** + +```bash +git add static/index.html static/js/pages/model_groups.js static/js/dashboard.js +git commit -m "feat: add model groups dashboard page with CRUD UI" +``` + +--- + +## Task 7: Integration test — build, run, verify + +**Objective:** Ensure everything compiles and the routing works end-to-end. 
+ +**Step 1: Full build** + +```bash +cd ~/Documents/projects/web_projects/gophergate && go build -o gophergate ./cmd/gophergate +``` + +**Step 2: Start server and test** + +```bash +# In one terminal: +./gophergate + +# In another terminal, verify that the default groups loaded: +curl -s -u admin:admin123 http://localhost:8080/api/model-groups | jq + +# Expected: array with the deepseek-auto, gemini-auto, and openai-auto groups +``` + +**Step 3: Test routing via API** + +```bash +# Send a request using a model group +curl -s http://localhost:8080/v1/chat/completions \ + -H "Authorization: Bearer YOUR_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "openai-auto", + "messages": [{"role": "user", "content": "What is 2+2?"}] + }' | jq + +# Check server logs for [ROUTER] line showing the decision +``` + +**Step 4: Commit any fixes** + +If any issues are found during testing, fix and commit. + +--- + +## Architecture Notes + +### Why this approach + +- **No Provider interface changes** — the router is a pre-processing step in the handler, transparent to providers +- **Groups stored in DB** — manageable from the dashboard, no config file sprawl +- **Classifier is optional** — heuristic mode works with zero added latency or cost +- **Fallback chain** — classifier failure falls back to heuristic; missing router falls back to direct passthrough + +### Edge cases handled + +- No groups defined → router never activates, all models pass through as before +- Unknown group ID → returns error to client +- Empty targets → returns error +- Classifier call fails → falls back to heuristic +- Classifier returns garbage → clamped to valid range +- OpenAI provider disabled → classifier groups fall back to heuristic mode + +### What's NOT in this plan (future work) + +- Streaming classifier support (the ~300ms classifier call happens before streaming begins — acceptable for now) +- responses endpoint routing (`handleResponses` could also use the router but needs a different message extraction) 
+- Per-client group overrides +- A/B testing / multi-armed bandit routing +- Caching classifier decisions for identical messages diff --git a/internal/models/registry.go b/internal/models/registry.go index 4f64d739..a51ae851 100644 --- a/internal/models/registry.go +++ b/internal/models/registry.go @@ -2,6 +2,24 @@ package models import "strings" +// CanonicalProviders lists the original model creators in priority order. +// When a model name exists in multiple providers (e.g. deepseek-v4-pro in +// deepseek, ollama-cloud, openrouter, etc.), these providers take precedence +// so the proxy uses authoritative metadata (pricing, limits) rather than a +// reseller's values. +var CanonicalProviders = []string{ + "openai", + "google", + "deepseek", + "xai", + "moonshotai", + "moonshotai-cn", + "anthropic", + "mistral", + "cohere", + "minimax", +} + type ModelRegistry struct { Providers map[string]ProviderInfo `json:"-"` } @@ -39,40 +57,154 @@ type ModelModalities struct { Output []string `json:"output"` } -func (r *ModelRegistry) FindModel(modelID string) *ModelMetadata { - // First try exact match in models map - for _, provider := range r.Providers { - if model, ok := provider.Models[modelID]; ok { - return &model - } - } - - // Try searching by ID in metadata - for _, provider := range r.Providers { - for _, model := range provider.Models { - if model.ID == modelID { - return &model +// findInCanonical searches the canonical providers in order for an exact model +// key match. Returns the metadata and true if found. +func (r *ModelRegistry) findInCanonical(modelID string) (*ModelMetadata, bool) { + for _, key := range CanonicalProviders { + if p, ok := r.Providers[key]; ok { + if m, ok := p.Models[modelID]; ok { + return &m, true } } } + return nil, false +} - // Try reverse fuzzy matching (e.g. 
'gpt-5.4-mini' matching 'gpt-5.4-mini-2026-04-01') - for _, provider := range r.Providers { - for id, model := range provider.Models { +// findInAll searches all providers (map iteration, random order) for an exact +// model key match. Used as fallback when canonical search fails. +func (r *ModelRegistry) findInAll(modelID string) (*ModelMetadata, bool) { + for _, p := range r.Providers { + if m, ok := p.Models[modelID]; ok { + return &m, true + } + } + return nil, false +} + +// findInCanonicalByID searches canonical providers for a model whose metadata +// ID field matches modelID. +func (r *ModelRegistry) findInCanonicalByID(modelID string) (*ModelMetadata, bool) { + for _, key := range CanonicalProviders { + if p, ok := r.Providers[key]; ok { + for _, m := range p.Models { + if m.ID == modelID { + return &m, true + } + } + } + } + return nil, false +} + +// findInAllByID searches all providers for a model whose metadata ID field +// matches modelID. +func (r *ModelRegistry) findInAllByID(modelID string) (*ModelMetadata, bool) { + for _, p := range r.Providers { + for _, m := range p.Models { + if m.ID == modelID { + return &m, true + } + } + } + return nil, false +} + +// findCanonicalReverseFuzzy searches canonical providers for any model whose +// key starts with modelID. +func (r *ModelRegistry) findCanonicalReverseFuzzy(modelID string) (*ModelMetadata, bool) { + for _, key := range CanonicalProviders { + if p, ok := r.Providers[key]; ok { + for id, m := range p.Models { + if strings.HasPrefix(id, modelID) { + return &m, true + } + } + } + } + return nil, false +} + +// findAllReverseFuzzy searches all providers for any model whose key starts +// with modelID. +func (r *ModelRegistry) findAllReverseFuzzy(modelID string) (*ModelMetadata, bool) { + for _, p := range r.Providers { + for id, m := range p.Models { if strings.HasPrefix(id, modelID) { - return &model + return &m, true } } } + return nil, false +} - // Try fuzzy matching (e.g. 
'gpt-4o-2024-05-13' matching 'gpt-4o') - for _, provider := range r.Providers { - for id, model := range provider.Models { - if strings.HasPrefix(modelID, id) { - return &model +// findCanonicalForwardFuzzy searches canonical providers for any model whose +// key is a prefix of modelID. +func (r *ModelRegistry) findCanonicalForwardFuzzy(modelID string) (*ModelMetadata, bool) { + for _, key := range CanonicalProviders { + if p, ok := r.Providers[key]; ok { + for id, m := range p.Models { + if strings.HasPrefix(modelID, id) { + return &m, true + } } } } + return nil, false +} + +// findAllForwardFuzzy searches all providers for any model whose key is a +// prefix of modelID. +func (r *ModelRegistry) findAllForwardFuzzy(modelID string) (*ModelMetadata, bool) { + for _, p := range r.Providers { + for id, m := range p.Models { + if strings.HasPrefix(modelID, id) { + return &m, true + } + } + } + return nil, false +} + +// FindModel looks up model metadata by ID. It searches canonical providers +// first at each strategy level (exact key, metadata ID, reverse fuzzy, +// forward fuzzy) and falls back to all providers only when canonical search +// yields no result. This prevents reseller entries (ollama-cloud, openrouter, +// etc.) from overriding the original provider's authoritative pricing and +// limits. +func (r *ModelRegistry) FindModel(modelID string) *ModelMetadata { + // 1. Exact key match — canonical first, then all + if m, ok := r.findInCanonical(modelID); ok { + return m + } + if m, ok := r.findInAll(modelID); ok { + return m + } + + // 2. Match by metadata ID field — canonical first, then all + if m, ok := r.findInCanonicalByID(modelID); ok { + return m + } + if m, ok := r.findInAllByID(modelID); ok { + return m + } + + // 3. Reverse fuzzy: model key starts with modelID + // e.g. 
'gpt-5.4-mini' matching 'gpt-5.4-mini-2026-04-01' + if m, ok := r.findCanonicalReverseFuzzy(modelID); ok { + return m + } + if m, ok := r.findAllReverseFuzzy(modelID); ok { + return m + } + + // 4. Forward fuzzy: modelID starts with model key + // e.g. 'gpt-4o-2024-05-13' matching 'gpt-4o' + if m, ok := r.findCanonicalForwardFuzzy(modelID); ok { + return m + } + if m, ok := r.findAllForwardFuzzy(modelID); ok { + return m + } return nil } diff --git a/internal/models/registry_test.go b/internal/models/registry_test.go index 6d0434e6..594698d1 100644 --- a/internal/models/registry_test.go +++ b/internal/models/registry_test.go @@ -59,6 +59,35 @@ func TestModelRegistry_FindModel_NotFound(t *testing.T) { } } +func TestModelRegistry_FindModel_CanonicalPriority(t *testing.T) { + // Same model name in canonical (deepseek) and reseller (ollama-cloud). + // Canonical must win so the proxy uses authoritative limits. + r := &ModelRegistry{ + Providers: map[string]ProviderInfo{ + "ollama-cloud": { + Models: map[string]ModelMetadata{ + "deepseek-v4-pro": {ID: "deepseek-v4-pro", Name: "DSv4 Pro (Ollama Cloud)", Limit: &ModelLimit{Context: 1048576, Output: 1048576}}, + }, + }, + "deepseek": { + Models: map[string]ModelMetadata{ + "deepseek-v4-pro": {ID: "deepseek-v4-pro", Name: "DeepSeek v4 Pro", Limit: &ModelLimit{Context: 1000000, Output: 384000}}, + }, + }, + }, + } + m := r.FindModel("deepseek-v4-pro") + if m == nil { + t.Fatal("expected to find deepseek-v4-pro") + } + if m.Name != "DeepSeek v4 Pro" { + t.Fatalf("expected DeepSeek v4 Pro (canonical), got %s", m.Name) + } + if m.Limit.Output != 384000 { + t.Fatalf("expected output limit 384000 (canonical), got %d", m.Limit.Output) + } +} + func TestModelRegistry_FindModel_ReverseFuzzy(t *testing.T) { r := &ModelRegistry{ Providers: map[string]ProviderInfo{