From d2b9da89d9d8a5a409bed27d7b25e5004e0b499b Mon Sep 17 00:00:00 2001 From: hobokenchicken Date: Thu, 7 May 2026 14:47:17 -0400 Subject: [PATCH] fix FindModel: prioritize canonical providers to prevent reseller limit overrides FindModel iterates providers in random map order, so when deepseek-v4-pro exists in both 'deepseek' (output=384000) and 'ollama-cloud' (output=1048576), it sometimes returned the wrong metadata. The proxy then injected max_tokens=1048576 into DeepSeek's API, which rejected it with 400 (valid range is [1, 393216]). Fix: define CanonicalProviders list (deepseek, openai, google, xai, etc.) and search them in priority order before falling back to all providers. Each of the four lookup strategies (exact key, metadata ID, reverse fuzzy, forward fuzzy) checks canonical providers first. --- .hermes/plans/auto-model-routing.md | 919 ++++++++++++++++++++++++++++ internal/models/registry.go | 176 +++++- internal/models/registry_test.go | 29 + 3 files changed, 1102 insertions(+), 22 deletions(-) create mode 100644 .hermes/plans/auto-model-routing.md diff --git a/.hermes/plans/auto-model-routing.md b/.hermes/plans/auto-model-routing.md new file mode 100644 index 00000000..13b75049 --- /dev/null +++ b/.hermes/plans/auto-model-routing.md @@ -0,0 +1,919 @@ +# Automatic Model Routing — Implementation Plan + +> **For Hermes:** Use subagent-driven-development skill to implement this plan task-by-task. + +**Goal:** Add a model-group router that lets clients send `model: "deepseek-auto"` and have gophergate pick the best concrete model based on heuristic rules or an optional classifier LLM. + +**Architecture:** A new `internal/router/` package with heuristic and classifier strategies, backed by a `model_groups` DB table. The router injects into `handleChatCompletions` after provider resolution but before the provider call — zero changes to the Provider interface. Admin CRUD endpoints and a dashboard tab for management. 
+ +**Tech Stack:** Go 1.22+, Gin, sqlx (SQLite), resty, existing OpenAI provider for classifier calls. + +--- + +## Task 1: Add `model_groups` DB migration and struct + +**Objective:** Create the `model_groups` table and Go struct. + +**Files:** +- Modify: `internal/db/db.go` + +**Step 1: Add CREATE TABLE to migrations** + +In `RunMigrations()`, add to the `queries` slice (after `client_tokens`): + +```go +`CREATE TABLE IF NOT EXISTS model_groups ( + id TEXT PRIMARY KEY, + strategy TEXT NOT NULL DEFAULT 'heuristic', + selector_model TEXT, + targets TEXT NOT NULL DEFAULT '[]', + complexity_threshold INTEGER, + heuristic_rules TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP +)`, +``` + +**Step 2: Add the Go struct** + +After the `ClientToken` struct (around line 264), add: + +```go +type ModelGroup struct { + ID string `db:"id" json:"id"` + Strategy string `db:"strategy" json:"strategy"` + SelectorModel *string `db:"selector_model" json:"selector_model"` + Targets string `db:"targets" json:"targets"` // JSON array + ComplexityThreshold *int `db:"complexity_threshold" json:"complexity_threshold"` + HeuristicRules *string `db:"heuristic_rules" json:"heuristic_rules"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} +``` + +**Step 3: Seed default groups** + +After the "Default client" block in `RunMigrations()`, add: + +```go +// Seed default model groups +defaultGroups := []struct { + id, strategy, targets string +}{ + {"deepseek-auto", "heuristic", `["deepseek-chat","deepseek-reasoner"]`}, + {"openai-auto", "heuristic", `["gpt-4o-mini","gpt-4o"]`}, + {"gemini-auto", "heuristic", `["gemini-2.0-flash","gemini-2.5-pro"]`}, +} +for _, g := range defaultGroups { + db.Exec(`INSERT OR IGNORE INTO model_groups (id, strategy, targets) VALUES (?, ?, ?)`, + g.id, g.strategy, g.targets) +} +``` + +**Step 4: Build and verify** + +```bash +cd 
~/Documents/projects/web_projects/gophergate && go build ./... +``` + +**Step 5: Commit** + +```bash +git add internal/db/db.go +git commit -m "feat: add model_groups table and default seed data" +``` + +--- + +## Task 2: Create router package — interface and heuristic router + +**Objective:** Create `internal/router/` with the Router interface and heuristic implementation. + +**Files:** +- Create: `internal/router/router.go` +- Create: `internal/router/heuristic.go` + +**Step 1: Create `internal/router/router.go`** + +```go +package router + +import ( + "context" + "encoding/json" + + "gophergate/internal/db" +) + +// Decision holds the result of a routing decision. +type Decision struct { + SelectedModel string `json:"selected_model"` + Strategy string `json:"strategy"` // "heuristic" or "classifier" + Reason string `json:"reason"` +} + +// ClassifierFunc is the callback for classifier-based routing. +// Takes a system prompt, user message, and selector model. +// Returns a complexity rating string (e.g. "3"). +type ClassifierFunc func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) + +// Router resolves model groups to concrete models. +type Router struct { + groups map[string]db.ModelGroup + classify ClassifierFunc +} + +// New creates a Router. classify may be nil if no classifier groups exist. +func New(groups []db.ModelGroup, classify ClassifierFunc) *Router { + r := &Router{ + groups: make(map[string]db.ModelGroup), + classify: classify, + } + for _, g := range groups { + r.groups[g.ID] = g + } + return r +} + +// IsGroup returns true if the model name is a group ID. +func (r *Router) IsGroup(modelID string) bool { + _, ok := r.groups[modelID] + return ok +} + +// Route resolves a group to a concrete model. +// Extracts the user message from the request body JSON bytes. 
+func (r *Router) Route(ctx context.Context, groupID string, userMessage string) (*Decision, error) { + group, ok := r.groups[groupID] + if !ok { + return nil, fmt.Errorf("unknown model group: %s", groupID) + } + + var targets []string + if err := json.Unmarshal([]byte(group.Targets), &targets); err != nil || len(targets) == 0 { + return nil, fmt.Errorf("invalid or empty targets for group %s", groupID) + } + + switch group.Strategy { + case "heuristic": + return routeHeuristic(group, targets, userMessage) + case "classifier": + if r.classify == nil { + // Fall back to heuristic if no classifier is available + return routeHeuristic(group, targets, userMessage) + } + return routeClassifier(ctx, r.classify, group, targets, userMessage) + default: + return nil, fmt.Errorf("unknown strategy: %s", group.Strategy) + } +} + +// Reload replaces the group definitions without recreating the router. +func (r *Router) Reload(groups []db.ModelGroup) { + r.groups = make(map[string]db.ModelGroup) + for _, g := range groups { + r.groups[g.ID] = g + } +} +``` + +**Step 2: Create `internal/router/heuristic.go`** + +```go +package router + +import ( + "context" + "encoding/json" + "strings" + + "gophergate/internal/db" +) + +// HeuristicRule defines a pattern-based routing rule. 
+type HeuristicRule struct { + Pattern string `json:"pattern"` // substring to match in user message + TargetIdx int `json:"target"` // index into targets array (0-based) + CaseSensitive bool `json:"case_sensitive,omitempty"` +} + +func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) (*Decision, error) { + // Default to first target (cheapest/fastest) + selected := targets[0] + reason := "default (first target)" + + // If heuristic_rules is set, use them + if group.HeuristicRules != nil && *group.HeuristicRules != "" { + var rules []HeuristicRule + if err := json.Unmarshal([]byte(*group.HeuristicRules), &rules); err == nil { + searchMsg := userMessage + for _, rule := range rules { + pattern := rule.Pattern + msg := searchMsg + if !rule.CaseSensitive { + pattern = strings.ToLower(pattern) + msg = strings.ToLower(msg) + } + if strings.Contains(msg, pattern) { + if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) { + selected = targets[rule.TargetIdx] + reason = "matched heuristic rule: " + rule.Pattern + break + } + } + } + } + } + + // Built-in fallback heuristics (apply even without custom rules) + if reason == "default (first target)" && len(targets) > 1 { + msgLower := strings.ToLower(userMessage) + // Complex task indicators → last target (usually the smarter model) + complexIndicators := []string{ + "step by step", "explain in detail", "reason through", + "think carefully", "analyze", "debug", "write code", + "implement", "refactor", "architecture", + } + for _, indicator := range complexIndicators { + if strings.Contains(msgLower, indicator) { + selected = targets[len(targets)-1] + reason = "complex task indicator: " + indicator + break + } + } + } + + return &Decision{ + SelectedModel: selected, + Strategy: "heuristic", + Reason: reason, + }, nil +} + +// routeHeuristic exists as a package-level func for direct use. 
+var _ = routeHeuristic // suppress unused warning when classifier is the only caller +``` + +Note: `routeHeuristic` is already called by `Router.Route()`, so the `var _ = routeHeuristic` line (and the comment directly above it) is unnecessary; omit both when creating the file. + +**Step 3: Build** + +```bash +cd ~/Documents/projects/web_projects/gophergate && go build ./... +``` + +Fix any compilation errors (missing imports, etc.). + +**Step 4: Commit** + +```bash +git add internal/router/ +git commit -m "feat: add router package with heuristic strategy" +``` + +--- + +## Task 3: Add classifier router + +**Objective:** Implement the classifier strategy that uses a cheap LLM to rate task complexity. + +**Files:** +- Create: `internal/router/classifier.go` + +**Step 1: Create `internal/router/classifier.go`** + +```go +package router + +import ( + "context" + "fmt" + "strconv" + "strings" + + "gophergate/internal/db" +) + +const classifierSystemPrompt = `You are a task complexity classifier. Rate the following user message on a scale of 1 to %d, where: +1 = trivial/simple (basic facts, greetings, simple math) +%d = highly complex (multi-step reasoning, code generation, architecture design) + +Reply with ONLY the number. 
No explanation.` + +func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, userMessage string) (*Decision, error) { + maxRating := len(targets) + if maxRating < 2 { + maxRating = 2 + } + + prompt := fmt.Sprintf(classifierSystemPrompt, maxRating, maxRating) + ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMessage) + if err != nil { + // Classifier failed — fall back to heuristic + return routeHeuristic(group, targets, userMessage) + } + + rating, err := strconv.Atoi(strings.TrimSpace(ratingStr)) + if err != nil || rating < 1 { + rating = 1 + } + if rating > maxRating { + rating = maxRating + } + + idx := rating - 1 // 0-based index into targets + return &Decision{ + SelectedModel: targets[idx], + Strategy: "classifier", + Reason: fmt.Sprintf("complexity rating: %d/%d", rating, maxRating), + }, nil +} + +func getSelectorModel(group db.ModelGroup, targets []string) string { + if group.SelectorModel != nil && *group.SelectorModel != "" { + return *group.SelectorModel + } + // Default: use the first (cheapest) target model as the selector + return targets[0] +} +``` + +**Step 2: Build** + +```bash +cd ~/Documents/projects/web_projects/gophergate && go build ./... +``` + +**Step 3: Commit** + +```bash +git add internal/router/classifier.go +git commit -m "feat: add classifier routing strategy with LLM complexity rating" +``` + +--- + +## Task 4: Wire router into the server + +**Objective:** Add the Router to the Server struct, initialize it, and inject it into `handleChatCompletions`. 
+ +**Files:** +- Modify: `internal/server/server.go` + +**Step 1: Add router field to Server struct** + +In the `Server` struct (around line 23), add after the `registryMu` field: + +```go +router *router.Router +``` + +**Step 2: Add import** + +Add to the imports block: + +```go +"gophergate/internal/router" +``` + +**Step 3: Initialize router in NewServer** + +After `s.setupRoutes()` (line 66), add: + +```go +// Initialize model group router +s.refreshRouter() +``` + +**Step 4: Add refreshRouter method** + +Add a new method on Server: + +```go +func (s *Server) refreshRouter() { + var groups []db.ModelGroup + if err := s.database.Select(&groups, "SELECT * FROM model_groups"); err != nil { + fmt.Printf("Warning: Failed to load model groups: %v\n", err) + groups = nil + } + + // Build classifier function using the OpenAI provider + var classifyFn router.ClassifierFunc + if openaiProvider, ok := s.providers["openai"]; ok { + classifyFn = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) { + req := &models.UnifiedRequest{ + Model: selectorModel, + Messages: []models.UnifiedMessage{ + {Role: "system", Content: []models.ContentPart{{Type: "text", Text: systemPrompt}}}, + {Role: "user", Content: []models.ContentPart{{Type: "text", Text: userMessage}}}, + }, + MaxTokens: uint32Ptr(5), + Stream: false, + } + resp, err := openaiProvider.ChatCompletion(ctx, req) + if err != nil { + return "", err + } + if len(resp.Choices) == 0 { + return "", fmt.Errorf("no choices in classifier response") + } + return resp.Choices[0].Message.Content, nil + } + } + + if s.router == nil { + s.router = router.New(groups, classifyFn) + } else { + s.router.Reload(groups) + } +} +``` + +**Step 5: Add uint32Ptr helper (if not already in the codebase)** + +At the bottom of server.go, add: + +```go +func uint32Ptr(v uint32) *uint32 { return &v } +``` + +**Step 6: Inject router into handleChatCompletions** + +In `handleChatCompletions`, after the model prefix 
stripping block (after line 475) and before building the UnifiedRequest (line 478), add: + +```go +// Check if model is a group and route to a concrete model +if s.router != nil && s.router.IsGroup(modelID) { + userMessage := extractUserMessage(req.Messages) + decision, err := s.router.Route(c.Request.Context(), modelID, userMessage) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)}) + return + } + modelID = decision.SelectedModel + log.Printf("[ROUTER] %s → %s (%s: %s)", req.Model, modelID, decision.Strategy, decision.Reason) +} +``` + +**Step 7: Add extractUserMessage helper** + +```go +func extractUserMessage(messages []models.ChatCompletionMessage) string { + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == "user" { + if s, ok := messages[i].Content.(string); ok { + return s + } + // It might be a content array — grab text from first part + if parts, ok := messages[i].Content.([]interface{}); ok && len(parts) > 0 { + if part, ok := parts[0].(map[string]interface{}); ok { + if text, ok := part["text"].(string); ok { + return text + } + } + } + return "" + } + } + return "" +} +``` + +**Step 8: Add router refresh to RefreshProviders** + +At the end of `RefreshProviders()` (before `return nil` at line 171), add: + +```go +s.refreshRouter() +``` + +**Step 9: Build** + +```bash +cd ~/Documents/projects/web_projects/gophergate && go build ./... +``` + +Expect compilation errors — need to check the `ChatCompletionMessage` type. The handler uses `models.ChatCompletionRequest` which has `Messages []ChatCompletionMessage`. Let me verify the type. If it's `[]models.ChatCompletionMessage` with `Content` as a string field, the helper is simpler. Fix as needed. 
+ +**Step 10: Commit** + +```bash +git add internal/server/server.go +git commit -m "feat: wire model group router into chat completions handler" +``` + +--- + +## Task 5: Add admin API endpoints for model groups + +**Objective:** CRUD endpoints at `/api/model-groups` for dashboard management. + +**Files:** +- Create: `internal/server/model_groups_admin.go` + +**Step 1: Create `internal/server/model_groups_admin.go`** + +```go +package server + +import ( + "net/http" + + "gophergate/internal/db" + + "github.com/gin-gonic/gin" +) + +func (s *Server) handleGetModelGroups(c *gin.Context) { + var groups []db.ModelGroup + if err := s.database.Select(&groups, "SELECT * FROM model_groups ORDER BY id"); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if groups == nil { + groups = []db.ModelGroup{} + } + c.JSON(http.StatusOK, groups) +} + +func (s *Server) handleCreateModelGroup(c *gin.Context) { + var group db.ModelGroup + if err := c.ShouldBindJSON(&group); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + _, err := s.database.Exec(` + INSERT INTO model_groups (id, strategy, selector_model, targets, complexity_threshold, heuristic_rules) + VALUES (?, ?, ?, ?, ?, ?)`, + group.ID, group.Strategy, group.SelectorModel, group.Targets, + group.ComplexityThreshold, group.HeuristicRules) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + s.refreshRouter() + c.JSON(http.StatusCreated, group) +} + +func (s *Server) handleUpdateModelGroup(c *gin.Context) { + id := c.Param("id") + var group db.ModelGroup + if err := c.ShouldBindJSON(&group); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + _, err := s.database.Exec(` + UPDATE model_groups SET strategy=?, selector_model=?, targets=?, complexity_threshold=?, heuristic_rules=?, updated_at=CURRENT_TIMESTAMP + WHERE id=?`, + group.Strategy, 
group.SelectorModel, group.Targets, + group.ComplexityThreshold, group.HeuristicRules, id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + s.refreshRouter() + c.JSON(http.StatusOK, group) +} + +func (s *Server) handleDeleteModelGroup(c *gin.Context) { + id := c.Param("id") + _, err := s.database.Exec("DELETE FROM model_groups WHERE id=?", id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + s.refreshRouter() + c.JSON(http.StatusOK, gin.H{"status": "deleted"}) +} +``` + +**Step 2: Register routes in setupRoutes()** + +In `setupRoutes()`, add under the admin group (after the models endpoints around line 229): + +```go +admin.GET("/model-groups", s.handleGetModelGroups) +admin.POST("/model-groups", s.handleCreateModelGroup) +admin.PUT("/model-groups/:id", s.handleUpdateModelGroup) +admin.DELETE("/model-groups/:id", s.handleDeleteModelGroup) +``` + +**Step 3: Build** + +```bash +cd ~/Documents/projects/web_projects/gophergate && go build ./... +``` + +**Step 4: Commit** + +```bash +git add internal/server/model_groups_admin.go internal/server/server.go +git commit -m "feat: add model groups CRUD admin API endpoints" +``` + +--- + +## Task 6: Add dashboard UI — sidebar entry and page module + +**Objective:** Add a "Model Groups" tab to the dashboard sidebar and a page module for CRUD management. 
+ +**Files:** +- Modify: `static/index.html` +- Create: `static/js/pages/model_groups.js` + +**Step 1: Add sidebar menu item in index.html** + +In the MANAGEMENT section (after line 91, before ``), add: + +```html + +``` + +**Step 2: Add script tag in index.html** + +After the users.js script (line 179), add: + +```html + +``` + +**Step 3: Create `static/js/pages/model_groups.js`** + +```javascript +// Model Groups Management Page + +class ModelGroupsPage { + constructor() { + this.container = document.getElementById('page-content'); + } + + async render() { + this.container.innerHTML = ` + +
+ + `; + await this.loadGroups(); + } + + async loadGroups() { + try { + const groups = await api.get('/api/model-groups'); + const list = document.getElementById('model-groups-list'); + if (!groups || groups.length === 0) { + list.innerHTML = '
No model groups defined. Create one to enable auto-routing.
'; + return; + } + + let targets; + try { targets = JSON.parse(g.targets); } catch { targets = []; } + const heuristicRules = g.heuristic_rules ? JSON.parse(g.heuristic_rules) : null; + + let html = ''; + html += ''; + html += ''; + + groups.forEach(g => { + html += ` + + + + + `; + }); + + html += '
Group IDStrategyTargetsActions
${this.esc(g.id)}${this.esc(g.strategy)}${this.esc(g.targets)} + + +
'; + list.innerHTML = html; + } catch (err) { + document.getElementById('model-groups-list').innerHTML = + `
Failed to load model groups: ${this.esc(err.message)}
`; + } + } + + showCreateForm() { + this.renderForm(null); + } + + async showEditForm(id) { + const groups = await api.get('/api/model-groups'); + const group = groups.find(g => g.id === id); + if (group) this.renderForm(group); + } + + renderForm(group) { + const isEdit = !!group; + const form = document.getElementById('model-group-form'); + form.style.display = 'block'; + form.innerHTML = ` +

${isEdit ? 'Edit' : 'Create'} Model Group

+
+
+ + + Clients use this as the model name. +
+
+ + +
+
+ + + First target = cheapest/fastest. Last target = smartest/most expensive. +
+
+ + +
+
+ + +
+
+ + + Pattern to match in user messages. target = index into targets array. +
+
+ + +
+
+ `; + + // Toggle strategy-specific fields + document.getElementById('mg-strategy').onchange = function() { + const isClassifier = this.value === 'classifier'; + document.getElementById('mg-selector-row').style.display = isClassifier ? '' : 'none'; + document.getElementById('mg-threshold-row').style.display = isClassifier ? '' : 'none'; + document.getElementById('mg-rules-row').style.display = isClassifier ? 'none' : ''; + }; + } + + async saveGroup(event, isEdit) { + event.preventDefault(); + const id = document.getElementById('mg-id').value.trim(); + const strategy = document.getElementById('mg-strategy').value; + const targets = document.getElementById('mg-targets').value; + const selectorModel = document.getElementById('mg-selector-model').value.trim() || null; + const thresholdVal = document.getElementById('mg-threshold').value; + const rules = document.getElementById('mg-rules').value.trim() || null; + + // Validate JSON + try { JSON.parse(targets); } catch { alert('Targets must be valid JSON array'); return; } + if (rules) { try { JSON.parse(rules); } catch { alert('Heuristic rules must be valid JSON'); return; } } + + const body = { id, strategy, targets, selector_model: selectorModel, heuristic_rules: rules }; + if (thresholdVal) body.complexity_threshold = parseInt(thresholdVal); + + try { + if (isEdit) { + await api.put(`/api/model-groups/${encodeURIComponent(id)}`, body); + } else { + await api.post('/api/model-groups', body); + } + document.getElementById('model-group-form').style.display = 'none'; + await this.loadGroups(); + } catch (err) { + alert('Failed to save: ' + err.message); + } + } + + async deleteGroup(id) { + if (!confirm(`Delete model group "${id}"?`)) return; + try { + await api.delete(`/api/model-groups/${encodeURIComponent(id)}`); + await this.loadGroups(); + } catch (err) { + alert('Failed to delete: ' + err.message); + } + } + + esc(str) { + if (!str) return ''; + return 
String(str).replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;').replace(/"/g,'&quot;'); + } +} + +const modelGroupsPage = new ModelGroupsPage(); +``` + +**Step 4: Register page in dashboard.js** + +In `static/js/dashboard.js`, find the page loading logic. The `loadPage` method dynamically imports page modules based on `this.currentPage`. The naming convention uses hyphens in `data-page` attributes (e.g., `data-page="model-groups"`). Check how the existing pages are loaded and ensure "model-groups" maps to the new module. + +Looking at the existing pattern, pages are loaded via script tags in index.html and their constructors handle rendering when the page is navigated to. The dashboard.js `loadPage` method calls page-specific init. Let me check if there's a page registry pattern. + +Actually, based on the index.html, pages are loaded as separate script files and the dashboard dispatches to them. The pattern seems to be: each page script defines a class or object, and the dashboard calls a `render()` or `init()` method on it when that page is selected. Let me add the dispatch logic. + +In `dashboard.js`, find the `loadPage` method and ensure it handles "model-groups": + +```javascript +// In the loadPage switch/if-else, add: +else if (page === 'model-groups') { + if (typeof modelGroupsPage !== 'undefined') { + modelGroupsPage.render(); + } +} +``` + +**Step 5: Commit** + +```bash +git add static/index.html static/js/pages/model_groups.js static/js/dashboard.js +git commit -m "feat: add model groups dashboard page with CRUD UI" +``` + +--- + +## Task 7: Integration test — build, run, verify + +**Objective:** Ensure everything compiles and the routing works end-to-end. 
+ +**Step 1: Full build** + +```bash +cd ~/Documents/projects/web_projects/gophergate && go build -o gophergate ./cmd/gophergate +``` + +**Step 2: Start server and test** + +```bash +# In one terminal: +./gophergate + +# In another terminal, test that default groups loaded: +curl -s -u admin:admin123 http://localhost:8080/api/model-groups | jq + +# Expected: array with the three seeded groups: deepseek-auto, gemini-auto, and openai-auto +``` + +**Step 3: Test routing via API** + +```bash +# Send a request using a model group +curl -s http://localhost:8080/v1/chat/completions \ + -H "Authorization: Bearer YOUR_TOKEN" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "openai-auto", + "messages": [{"role": "user", "content": "What is 2+2?"}] + }' | jq + +# Check server logs for [ROUTER] line showing the decision +``` + +**Step 4: Commit any fixes** + +If any issues found during testing, fix and commit. + +--- + +## Architecture Notes + +### Why this approach + +- **No Provider interface changes** — the router is a pre-processing step in the handler, transparent to providers +- **Groups stored in DB** — manageable from the dashboard, no config file sprawl +- **Classifier is optional** — heuristic mode works with zero added latency or cost +- **Fallback chain** — classifier failure falls back to heuristic; missing router falls back to direct passthrough + +### Edge cases handled + +- No groups defined → router never activates, all models pass through as before +- Unknown group ID → returns error to client +- Empty targets → returns error +- Classifier call fails → falls back to heuristic +- Classifier returns garbage → clamped to valid range +- OpenAI provider disabled → classifier groups fall back to heuristic mode + +### What's NOT in this plan (future work) + +- Streaming classifier support (the ~300ms classifier call happens before streaming begins — acceptable for now) +- responses endpoint routing (`handleResponses` could also use the router but needs a different message extraction) 
+- Per-client group overrides +- A/B testing / multi-armed bandit routing +- Caching classifier decisions for identical messages diff --git a/internal/models/registry.go b/internal/models/registry.go index 4f64d739..a51ae851 100644 --- a/internal/models/registry.go +++ b/internal/models/registry.go @@ -2,6 +2,24 @@ package models import "strings" +// CanonicalProviders lists the original model creators in priority order. +// When a model name exists in multiple providers (e.g. deepseek-v4-pro in +// deepseek, ollama-cloud, openrouter, etc.), these providers take precedence +// so the proxy uses authoritative metadata (pricing, limits) rather than a +// reseller's values. +var CanonicalProviders = []string{ + "openai", + "google", + "deepseek", + "xai", + "moonshotai", + "moonshotai-cn", + "anthropic", + "mistral", + "cohere", + "minimax", +} + type ModelRegistry struct { Providers map[string]ProviderInfo `json:"-"` } @@ -39,40 +57,154 @@ type ModelModalities struct { Output []string `json:"output"` } -func (r *ModelRegistry) FindModel(modelID string) *ModelMetadata { - // First try exact match in models map - for _, provider := range r.Providers { - if model, ok := provider.Models[modelID]; ok { - return &model - } - } - - // Try searching by ID in metadata - for _, provider := range r.Providers { - for _, model := range provider.Models { - if model.ID == modelID { - return &model +// findInCanonical searches the canonical providers in order for an exact model +// key match. Returns the metadata and true if found. +func (r *ModelRegistry) findInCanonical(modelID string) (*ModelMetadata, bool) { + for _, key := range CanonicalProviders { + if p, ok := r.Providers[key]; ok { + if m, ok := p.Models[modelID]; ok { + return &m, true } } } + return nil, false +} - // Try reverse fuzzy matching (e.g. 
'gpt-5.4-mini' matching 'gpt-5.4-mini-2026-04-01') - for _, provider := range r.Providers { - for id, model := range provider.Models { +// findInAll searches all providers (map iteration, random order) for an exact +// model key match. Used as fallback when canonical search fails. +func (r *ModelRegistry) findInAll(modelID string) (*ModelMetadata, bool) { + for _, p := range r.Providers { + if m, ok := p.Models[modelID]; ok { + return &m, true + } + } + return nil, false +} + +// findInCanonicalByID searches canonical providers for a model whose metadata +// ID field matches modelID. +func (r *ModelRegistry) findInCanonicalByID(modelID string) (*ModelMetadata, bool) { + for _, key := range CanonicalProviders { + if p, ok := r.Providers[key]; ok { + for _, m := range p.Models { + if m.ID == modelID { + return &m, true + } + } + } + } + return nil, false +} + +// findInAllByID searches all providers for a model whose metadata ID field +// matches modelID. +func (r *ModelRegistry) findInAllByID(modelID string) (*ModelMetadata, bool) { + for _, p := range r.Providers { + for _, m := range p.Models { + if m.ID == modelID { + return &m, true + } + } + } + return nil, false +} + +// findCanonicalReverseFuzzy searches canonical providers for any model whose +// key starts with modelID. +func (r *ModelRegistry) findCanonicalReverseFuzzy(modelID string) (*ModelMetadata, bool) { + for _, key := range CanonicalProviders { + if p, ok := r.Providers[key]; ok { + for id, m := range p.Models { + if strings.HasPrefix(id, modelID) { + return &m, true + } + } + } + } + return nil, false +} + +// findAllReverseFuzzy searches all providers for any model whose key starts +// with modelID. +func (r *ModelRegistry) findAllReverseFuzzy(modelID string) (*ModelMetadata, bool) { + for _, p := range r.Providers { + for id, m := range p.Models { if strings.HasPrefix(id, modelID) { - return &model + return &m, true } } } + return nil, false +} - // Try fuzzy matching (e.g. 
'gpt-4o-2024-05-13' matching 'gpt-4o') - for _, provider := range r.Providers { - for id, model := range provider.Models { - if strings.HasPrefix(modelID, id) { - return &model +// findCanonicalForwardFuzzy searches canonical providers for any model whose +// key is a prefix of modelID. +func (r *ModelRegistry) findCanonicalForwardFuzzy(modelID string) (*ModelMetadata, bool) { + for _, key := range CanonicalProviders { + if p, ok := r.Providers[key]; ok { + for id, m := range p.Models { + if strings.HasPrefix(modelID, id) { + return &m, true + } } } } + return nil, false +} + +// findAllForwardFuzzy searches all providers for any model whose key is a +// prefix of modelID. +func (r *ModelRegistry) findAllForwardFuzzy(modelID string) (*ModelMetadata, bool) { + for _, p := range r.Providers { + for id, m := range p.Models { + if strings.HasPrefix(modelID, id) { + return &m, true + } + } + } + return nil, false +} + +// FindModel looks up model metadata by ID. It searches canonical providers +// first at each strategy level (exact key, metadata ID, reverse fuzzy, +// forward fuzzy) and falls back to all providers only when canonical search +// yields no result. This prevents reseller entries (ollama-cloud, openrouter, +// etc.) from overriding the original provider's authoritative pricing and +// limits. +func (r *ModelRegistry) FindModel(modelID string) *ModelMetadata { + // 1. Exact key match — canonical first, then all + if m, ok := r.findInCanonical(modelID); ok { + return m + } + if m, ok := r.findInAll(modelID); ok { + return m + } + + // 2. Match by metadata ID field — canonical first, then all + if m, ok := r.findInCanonicalByID(modelID); ok { + return m + } + if m, ok := r.findInAllByID(modelID); ok { + return m + } + + // 3. Reverse fuzzy: model key starts with modelID + // e.g. 
'gpt-5.4-mini' matching 'gpt-5.4-mini-2026-04-01' + if m, ok := r.findCanonicalReverseFuzzy(modelID); ok { + return m + } + if m, ok := r.findAllReverseFuzzy(modelID); ok { + return m + } + + // 4. Forward fuzzy: modelID starts with model key + // e.g. 'gpt-4o-2024-05-13' matching 'gpt-4o' + if m, ok := r.findCanonicalForwardFuzzy(modelID); ok { + return m + } + if m, ok := r.findAllForwardFuzzy(modelID); ok { + return m + } return nil } diff --git a/internal/models/registry_test.go b/internal/models/registry_test.go index 6d0434e6..594698d1 100644 --- a/internal/models/registry_test.go +++ b/internal/models/registry_test.go @@ -59,6 +59,35 @@ func TestModelRegistry_FindModel_NotFound(t *testing.T) { } } +func TestModelRegistry_FindModel_CanonicalPriority(t *testing.T) { + // Same model name in canonical (deepseek) and reseller (ollama-cloud). + // Canonical must win so the proxy uses authoritative limits. + r := &ModelRegistry{ + Providers: map[string]ProviderInfo{ + "ollama-cloud": { + Models: map[string]ModelMetadata{ + "deepseek-v4-pro": {ID: "deepseek-v4-pro", Name: "DSv4 Pro (Ollama Cloud)", Limit: &ModelLimit{Context: 1048576, Output: 1048576}}, + }, + }, + "deepseek": { + Models: map[string]ModelMetadata{ + "deepseek-v4-pro": {ID: "deepseek-v4-pro", Name: "DeepSeek v4 Pro", Limit: &ModelLimit{Context: 1000000, Output: 384000}}, + }, + }, + }, + } + m := r.FindModel("deepseek-v4-pro") + if m == nil { + t.Fatal("expected to find deepseek-v4-pro") + } + if m.Name != "DeepSeek v4 Pro" { + t.Fatalf("expected DeepSeek v4 Pro (canonical), got %s", m.Name) + } + if m.Limit.Output != 384000 { + t.Fatalf("expected output limit 384000 (canonical), got %d", m.Limit.Output) + } +} + func TestModelRegistry_FindModel_ReverseFuzzy(t *testing.T) { r := &ModelRegistry{ Providers: map[string]ProviderInfo{