Compare commits

28 Commits

Author SHA1 Message Date
hobokenchicken 73a82e6175 feat: implement advanced condition-based heuristic model routing
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
Upgrades the routing engine to support tag, token limit, multimodal, reasoning, and tool calling conditions. Adds unit tests for the new routing features.
2026-06-05 15:05:13 +00:00
newkirk b3354a1bbc Add Xiaomi MiMo provider (mimo-v2.5) support
2026-05-29 12:19:24 -04:00
newkirk 1dc5f586b9 fix: improve OpenAI error body capture and log request body on 400
- Use resp.Body() instead of resp.RawBody() for non-streaming error responses
- Fall back to RawBody() for streaming responses
- Log the full request body on API errors for debugging
2026-05-17 19:57:59 -04:00
newkirk 40f055cb57 fix: correct deepseek pricing, gemini streaming tokens, and group-name logging
- Add promo discount system for deepseek-v4-pro (75% off until 2026-05-31)
- Rewrite StreamGemini to handle both SSE and JSON array response formats,
  fixing 0-token logging for gemini-3-flash and gemini-3-flash-preview
- Fall back to model group name for cost lookup when concrete model
  isn't in the registry (fixes $0 cost on deepseek-auto entries)
- Move registry lock before FindModel call to fix data race
2026-05-17 19:49:37 -04:00
hobokenchicken 970e778703 chore: update .gitignore to ignore nohup.out and bak files
2026-05-11 03:13:54 +00:00
hobokenchicken 477a811999 fix: remove tool call ID truncation and improve DeepSeek reasoning handling
The 40-character truncation of tool call IDs in helper.go caused collisions
when models (like deepseek-v4-flash) generated longer IDs, leading to
"Duplicate value for 'tool_call_id'" errors. Removed the limit to allow
full unique IDs.

DeepSeek: updated reasoning_content injection to use an empty string
instead of a space, better matching provider expectations for history.

Improved API error reporting across all providers by capturing raw body
content when response parsing fails or returns empty strings.
2026-05-11 03:13:33 +00:00
hobokenchicken d2b9da89d9 fix FindModel: prioritize canonical providers to prevent reseller limit overrides
FindModel iterates providers in random map order, so when deepseek-v4-pro
exists in both 'deepseek' (output=384000) and 'ollama-cloud' (output=1048576),
it sometimes returned the wrong metadata. The proxy then injected
max_tokens=1048576 into DeepSeek's API, which rejected it with 400
(valid range is [1, 393216]).

Fix: define CanonicalProviders list (deepseek, openai, google, xai, etc.)
and search them in priority order before falling back to all providers.
Each of the four lookup strategies (exact key, metadata ID, reverse fuzzy,
forward fuzzy) checks canonical providers first.
2026-05-07 14:47:17 -04:00
hobokenchicken b7df3108fa docs: update README, TODO, and deployment docs
README: Added hierarchical routing, classifier bucket mapping, two-level
dispatch, model groups table, DeepSeek language note, deploy script, and
updated model names to match current models.dev registry.

TODO: Added 15 completed items covering model groups, routing, dispatch,
and provider fixes from May 7 session.

deployment.md: Added deploy.sh instructions.
2026-05-07 14:07:52 -04:00
hobokenchicken 28b8271c1d fix: inject English system prompt for DeepSeek provider
DeepSeek models default to Chinese for some prompts. The ensureEnglish()
function prepends 'Always respond in English' as a system message when
no system prompt is already set. Applied to both ChatCompletion and
ChatCompletionStream paths.
2026-05-07 14:03:39 -04:00
hobokenchicken eb585c0001 fix: switch dispatcher classifier to gpt-5.4-nano
gpt-5.4-nano correctly discriminates complexity (1 vs 10)
while deepseek-v4-flash rated everything as 1/10.
2026-05-07 14:00:19 -04:00
hobokenchicken 4aea7a3b4c fix: select provider AFTER routing resolves model groups
Previously, provider selection happened on the raw client-requested model
name (e.g. 'dispatcher') which defaulted to OpenAI. After routing resolved
it to 'deepseek-v4-flash', the provider was never re-selected.

Now prefix-stripping + routing runs first, then selectProvider() picks
the correct provider based on the resolved concrete model.
2026-05-07 13:54:42 -04:00
hobokenchicken 330eaa57d1 fix: update model names to match current models.dev registry
heavy-logic: kimi-k2.5 -> kimi-k2.6
standard-pro: gemini-3-flash -> gemini-3-flash-preview
2026-05-07 13:48:33 -04:00
hobokenchicken 0ae30036f0 fix: classifier selector model now routes to correct provider
Extracted selectProvider() method from handleChatCompletions' inline
logic. The classifier callback now calls selectProvider(selectorModel)
instead of hardcoding openaiProvider.

This fixes the 'circuit breaker is open' error when dispatcher tries
to use deepseek-v4-flash as its selector model.
2026-05-07 13:37:19 -04:00
hobokenchicken 3c0b59622e feat: classifier bucket mapping + dispatcher seed group
Classifier: When complexity_threshold is set (e.g. 10), uses it as the
rating scale and maps ratings proportionally to target buckets instead
of 1:1. Formula: idx = rating * len(targets) / (threshold + 1).

With threshold=10 and 3 targets: 1-3→target[0], 4-7→target[1], 8-10→target[2].

Seed: Added 'dispatcher' group (classifier, threshold=10, selector=deepseek-v4-flash)
that auto-routes to fast-flow/standard-pro/heavy-logic by complexity score.

Combined with hierarchical routing, this enables two-level dispatch:
  dispatcher scores 1-10 → routes to tier group → tier picks concrete model.
2026-05-07 13:18:35 -04:00
hobokenchicken 7517307c11 feat: add hierarchical routing — groups can target other groups
RouteToConcrete() recursively resolves group chains until a concrete
model is reached, with cycle detection and max depth (10) guard.

Example: all-purpose -> fast-flow -> deepseek-v4-flash
The dashboard log shows the full chain: 'deepseek-v4-flash (hierarchical:
fast-flow (default (first target)) -> deepseek-v4-flash (default (first target)))'
2026-05-07 12:28:31 -04:00
hobokenchicken 19517b0847 chore: add deploy.sh for prod restarts
2026-05-07 12:02:28 -04:00
hobokenchicken a3a6f765e7 feat: add logic_level and primary_use metadata to model groups
Schema: Added logic_level (INTEGER) and primary_use (TEXT) columns
to model_groups table with auto-migration for existing databases.

Seed: Three new default groups:
  heavy-logic  (level 9) — Complex Coding, Logic, Agents
  standard-pro (level 5) — General Assistant, Long Docs
  fast-flow    (level 2) — Classification, JSON, Basic Q&A

Admin API: INSERT/UPDATE handlers now accept and persist the new fields.
Dashboard: Table shows Level and Primary Use columns; form includes
both fields with appropriate inputs and placeholders.
2026-05-07 12:01:28 -04:00
hobokenchicken 79dd122b56 feat: expose model groups in /v1/models endpoint
Add Groups() method to Router so handleListModels can append model
group IDs (e.g. 'deepseek-auto', 'openai-auto') to the model list,
marked with owned_by: 'gophergate'. This lets clients discover and
use groups via the standard OpenAI /v1/models endpoint.
2026-05-07 11:26:05 -04:00
hobokenchicken 3021e4b2b4 fix: log resolved model name instead of group name in Recent Activity
When using model groups (e.g. 'deepseek-auto'), the dashboard logged the
group name instead of the concrete resolved model (e.g. 'deepseek-reasoner').

Now:
- logRequest passes the resolved modelID (concrete) + modelGroup (group name)
- RequestLog struct has a new ModelGroup field (omitempty)
- Dashboard displays resolved model (via group) when a group was used

Files changed:
  internal/server/logging.go  - add ModelGroup field
  internal/server/server.go   - pass resolved modelID, capture modelGroup
  static/js/websocket.js      - show group annotation in Recent Activity
  static/js/pages/overview.js - show group annotation in overview table
  static/js/pages/monitoring.js - show group annotation in stream
2026-05-07 11:16:36 -04:00
hobokenchicken 14de7e9ebf fix: wrap model-groups API responses in SuccessResponse for api.js client
2026-05-05 11:41:23 -04:00
hobokenchicken 4fef201e95 fix: remove /api prefix from model-groups API calls (api.js already prepends it)
2026-05-05 11:33:05 -04:00
hobokenchicken bac03de051 docs: add automatic model routing to README
2026-05-05 11:28:59 -04:00
hobokenchicken 37949e560b feat: add model groups dashboard page with CRUD UI
2026-05-05 10:55:25 -04:00
hobokenchicken f04cb6b8f2 feat: add model groups CRUD admin API endpoints 2026-05-05 10:50:33 -04:00
hobokenchicken 10262c0e5a feat: wire model group router into chat completions handler 2026-05-05 10:47:32 -04:00
hobokenchicken d345f8c41d feat: add classifier routing strategy with LLM complexity rating 2026-05-05 10:40:26 -04:00
hobokenchicken d1f7a57f58 feat: add router package with heuristic strategy 2026-05-05 10:37:36 -04:00
hobokenchicken dc9af4d79c feat: add model_groups table and default seed data 2026-05-05 10:33:35 -04:00
39 changed files with 3351 additions and 247 deletions
+3
@@ -18,6 +18,9 @@ DEEPSEEK_API_KEY=sk-...
MOONSHOT_API_KEY=sk-...
GROK_API_KEY=xai-...
# Xiaomi MiMo
XIAOMI_API_KEY=sk-...
# ==============================================================================
# Server Configuration
# ==============================================================================
+2
@@ -14,3 +14,5 @@
.pi-lens/cache/
server.pid
/target
nohup.out
*.bak
+919
@@ -0,0 +1,919 @@
# Automatic Model Routing — Implementation Plan
> **For Hermes:** Use subagent-driven-development skill to implement this plan task-by-task.
**Goal:** Add a model-group router that lets clients send `model: "deepseek-auto"` and have gophergate pick the best concrete model based on heuristic rules or an optional classifier LLM.
**Architecture:** A new `internal/router/` package with heuristic and classifier strategies, backed by a `model_groups` DB table. The router hooks into `handleChatCompletions` after provider resolution but before the provider call, so the Provider interface needs no changes. Admin CRUD endpoints and a dashboard tab handle group management.
**Tech Stack:** Go 1.22+, Gin, sqlx (SQLite), resty, existing OpenAI provider for classifier calls.
---
## Task 1: Add `model_groups` DB migration and struct
**Objective:** Create the `model_groups` table and Go struct.
**Files:**
- Modify: `internal/db/db.go`
**Step 1: Add CREATE TABLE to migrations**
In `RunMigrations()`, add to the `queries` slice (after `client_tokens`):
```go
`CREATE TABLE IF NOT EXISTS model_groups (
id TEXT PRIMARY KEY,
strategy TEXT NOT NULL DEFAULT 'heuristic',
selector_model TEXT,
targets TEXT NOT NULL DEFAULT '[]',
complexity_threshold INTEGER,
heuristic_rules TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`,
```
**Step 2: Add the Go struct**
After the `ClientToken` struct (around line 264), add:
```go
type ModelGroup struct {
ID string `db:"id" json:"id"`
Strategy string `db:"strategy" json:"strategy"`
SelectorModel *string `db:"selector_model" json:"selector_model"`
Targets string `db:"targets" json:"targets"` // JSON array
ComplexityThreshold *int `db:"complexity_threshold" json:"complexity_threshold"`
HeuristicRules *string `db:"heuristic_rules" json:"heuristic_rules"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
```
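`Targets` is stored as a JSON array in a TEXT column, so every consumer has to decode it before use. The following is a minimal sketch of that decoding step. The helper name `decodeTargets` is hypothetical and not part of the plan; it only mirrors the empty-list check that `Route` performs.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// decodeTargets is a hypothetical helper: it parses the JSON array stored in
// model_groups.targets and rejects lists that contain no models.
func decodeTargets(raw string) ([]string, error) {
	var targets []string
	if err := json.Unmarshal([]byte(raw), &targets); err != nil {
		return nil, fmt.Errorf("targets is not a JSON string array: %w", err)
	}
	if len(targets) == 0 {
		return nil, errors.New("targets must contain at least one model")
	}
	return targets, nil
}

func main() {
	targets, err := decodeTargets(`["deepseek-chat","deepseek-reasoner"]`)
	fmt.Println(targets, err)
}
```

Keeping the column as a string keeps the migration simple, at the cost of repeating this decode wherever targets are read.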
**Step 3: Seed default groups**
After the "Default client" block in `RunMigrations()`, add:
```go
// Seed default model groups
defaultGroups := []struct {
id, strategy, targets string
}{
{"deepseek-auto", "heuristic", `["deepseek-chat","deepseek-reasoner"]`},
{"openai-auto", "heuristic", `["gpt-4o-mini","gpt-4o"]`},
{"gemini-auto", "heuristic", `["gemini-2.0-flash","gemini-2.5-pro"]`},
}
for _, g := range defaultGroups {
db.Exec(`INSERT OR IGNORE INTO model_groups (id, strategy, targets) VALUES (?, ?, ?)`,
g.id, g.strategy, g.targets)
}
```
**Step 4: Build and verify**
```bash
cd ~/Documents/projects/web_projects/gophergate && go build ./...
```
**Step 5: Commit**
```bash
git add internal/db/db.go
git commit -m "feat: add model_groups table and default seed data"
```
---
## Task 2: Create router package — interface and heuristic router
**Objective:** Create `internal/router/` with the Router interface and heuristic implementation.
**Files:**
- Create: `internal/router/router.go`
- Create: `internal/router/heuristic.go`
**Step 1: Create `internal/router/router.go`**
```go
package router
import (
"context"
"encoding/json"
"fmt"
"gophergate/internal/db"
)
// Decision holds the result of a routing decision.
type Decision struct {
SelectedModel string `json:"selected_model"`
Strategy string `json:"strategy"` // "heuristic" or "classifier"
Reason string `json:"reason"`
}
// ClassifierFunc is the callback for classifier-based routing.
// Takes the selector model, a system prompt, and the user message, in that order.
// Returns a complexity rating string (e.g. "3").
type ClassifierFunc func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error)
// Router resolves model groups to concrete models.
type Router struct {
groups map[string]db.ModelGroup
classify ClassifierFunc
}
// New creates a Router. classify may be nil if no classifier groups exist.
func New(groups []db.ModelGroup, classify ClassifierFunc) *Router {
r := &Router{
groups: make(map[string]db.ModelGroup),
classify: classify,
}
for _, g := range groups {
r.groups[g.ID] = g
}
return r
}
// IsGroup returns true if the model name is a group ID.
func (r *Router) IsGroup(modelID string) bool {
_, ok := r.groups[modelID]
return ok
}
// Route resolves a group to a concrete model.
// The caller extracts userMessage from the request before calling Route.
func (r *Router) Route(ctx context.Context, groupID string, userMessage string) (*Decision, error) {
group, ok := r.groups[groupID]
if !ok {
return nil, fmt.Errorf("unknown model group: %s", groupID)
}
var targets []string
if err := json.Unmarshal([]byte(group.Targets), &targets); err != nil || len(targets) == 0 {
return nil, fmt.Errorf("invalid or empty targets for group %s", groupID)
}
switch group.Strategy {
case "heuristic":
return routeHeuristic(group, targets, userMessage)
case "classifier":
if r.classify == nil {
// Fall back to heuristic if no classifier is available
return routeHeuristic(group, targets, userMessage)
}
return routeClassifier(ctx, r.classify, group, targets, userMessage)
default:
return nil, fmt.Errorf("unknown strategy: %s", group.Strategy)
}
}
// Reload replaces the group definitions without recreating the router.
// It takes no lock, so callers must not run it concurrently with Route or IsGroup.
func (r *Router) Reload(groups []db.ModelGroup) {
r.groups = make(map[string]db.ModelGroup)
for _, g := range groups {
r.groups[g.ID] = g
}
}
```
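`Reload` swaps the `groups` map while request handlers may be reading it through `Route` and `IsGroup`. If admin updates can arrive during live traffic, that is unsynchronized concurrent access. One possible way to guard it is sketched below with simplified stand-in types; `safeRouter` and `group` are hypothetical names, not the plan's code.

```go
package main

import (
	"fmt"
	"sync"
)

// group is a stand-in for db.ModelGroup, reduced to its ID for illustration.
type group struct{ ID string }

// safeRouter is a hypothetical variant of Router whose map is guarded by a
// read-write mutex so Reload can run while requests call IsGroup.
type safeRouter struct {
	mu     sync.RWMutex
	groups map[string]group
}

// IsGroup takes the read lock, so many requests can check concurrently.
func (r *safeRouter) IsGroup(id string) bool {
	r.mu.RLock()
	defer r.mu.RUnlock()
	_, ok := r.groups[id]
	return ok
}

// Reload builds the new map first, then swaps it in under the write lock.
func (r *safeRouter) Reload(groups []group) {
	next := make(map[string]group, len(groups))
	for _, g := range groups {
		next[g.ID] = g
	}
	r.mu.Lock()
	r.groups = next
	r.mu.Unlock()
}

func main() {
	r := &safeRouter{}
	r.Reload([]group{{ID: "deepseek-auto"}})
	fmt.Println(r.IsGroup("deepseek-auto"), r.IsGroup("gpt-4o"))
}
```

Building the replacement map outside the lock keeps the write-locked section to a single pointer swap, so readers are blocked only briefly.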
**Step 2: Create `internal/router/heuristic.go`**
```go
package router
import (
"encoding/json"
"strings"
"gophergate/internal/db"
)
// HeuristicRule defines a pattern-based routing rule.
type HeuristicRule struct {
Pattern string `json:"pattern"` // substring to match in user message
TargetIdx int `json:"target"` // index into targets array (0-based)
CaseSensitive bool `json:"case_sensitive,omitempty"`
}
func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) (*Decision, error) {
// Default to first target (cheapest/fastest)
selected := targets[0]
reason := "default (first target)"
// If heuristic_rules is set, use them
if group.HeuristicRules != nil && *group.HeuristicRules != "" {
var rules []HeuristicRule
if err := json.Unmarshal([]byte(*group.HeuristicRules), &rules); err == nil {
searchMsg := userMessage
for _, rule := range rules {
pattern := rule.Pattern
msg := searchMsg
if !rule.CaseSensitive {
pattern = strings.ToLower(pattern)
msg = strings.ToLower(msg)
}
if strings.Contains(msg, pattern) {
if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) {
selected = targets[rule.TargetIdx]
reason = "matched heuristic rule: " + rule.Pattern
break
}
}
}
}
}
// Built-in fallback heuristics (apply even without custom rules)
if reason == "default (first target)" && len(targets) > 1 {
msgLower := strings.ToLower(userMessage)
// Complex task indicators → last target (usually the smarter model)
complexIndicators := []string{
"step by step", "explain in detail", "reason through",
"think carefully", "analyze", "debug", "write code",
"implement", "refactor", "architecture",
}
for _, indicator := range complexIndicators {
if strings.Contains(msgLower, indicator) {
selected = targets[len(targets)-1]
reason = "complex task indicator: " + indicator
break
}
}
}
return &Decision{
SelectedModel: selected,
Strategy: "heuristic",
Reason: reason,
}, nil
}
```
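To make the `heuristic_rules` format concrete, here is a standalone sketch of the custom-rule matching only. `pickTarget` and `rule` are hypothetical names that copy the matching logic of `routeHeuristic`; the built-in complexity indicators are left out.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// rule mirrors the JSON shape of HeuristicRule for this standalone example.
type rule struct {
	Pattern       string `json:"pattern"`
	TargetIdx     int    `json:"target"`
	CaseSensitive bool   `json:"case_sensitive,omitempty"`
}

// pickTarget is a simplified stand-in for the custom-rule half of
// routeHeuristic: the first matching rule with a valid index wins,
// otherwise the first target is returned.
func pickTarget(rulesJSON string, targets []string, msg string) string {
	var rules []rule
	if err := json.Unmarshal([]byte(rulesJSON), &rules); err != nil {
		return targets[0]
	}
	for _, r := range rules {
		pattern, text := r.Pattern, msg
		if !r.CaseSensitive {
			pattern, text = strings.ToLower(pattern), strings.ToLower(text)
		}
		if strings.Contains(text, pattern) && r.TargetIdx >= 0 && r.TargetIdx < len(targets) {
			return targets[r.TargetIdx]
		}
	}
	return targets[0]
}

func main() {
	rules := `[{"pattern":"SQL","target":1,"case_sensitive":true},{"pattern":"prove","target":1}]`
	targets := []string{"deepseek-chat", "deepseek-reasoner"}
	fmt.Println(pickTarget(rules, targets, "Prove this lemma")) // deepseek-reasoner
	fmt.Println(pickTarget(rules, targets, "write some sql"))   // deepseek-chat
}
```

Rule order matters because the first match wins. An empty `pattern` matches every message, so it acts as a catch-all.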
**Step 3: Build**
```bash
cd ~/Documents/projects/web_projects/gophergate && go build ./...
```
Fix any compilation errors (missing imports, etc.).
**Step 4: Commit**
```bash
git add internal/router/
git commit -m "feat: add router package with heuristic strategy"
```
---
## Task 3: Add classifier router
**Objective:** Implement the classifier strategy that uses a cheap LLM to rate task complexity.
**Files:**
- Create: `internal/router/classifier.go`
**Step 1: Create `internal/router/classifier.go`**
```go
package router
import (
"context"
"fmt"
"strconv"
"strings"
"gophergate/internal/db"
)
const classifierSystemPrompt = `You are a task complexity classifier. Rate the following user message on a scale of 1 to %d, where:
1 = trivial/simple (basic facts, greetings, simple math)
%d = highly complex (multi-step reasoning, code generation, architecture design)
Reply with ONLY the number. No explanation.`
func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, userMessage string) (*Decision, error) {
maxRating := len(targets)
if maxRating < 2 {
maxRating = 2
}
prompt := fmt.Sprintf(classifierSystemPrompt, maxRating, maxRating)
ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMessage)
if err != nil {
// Classifier failed — fall back to heuristic
return routeHeuristic(group, targets, userMessage)
}
rating, err := strconv.Atoi(strings.TrimSpace(ratingStr))
if err != nil || rating < 1 {
rating = 1
}
if rating > maxRating {
rating = maxRating
}
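idx := rating - 1 // 0-based index into targets
if idx >= len(targets) {
idx = len(targets) - 1 // a single-target group is still rated on a 1-2 scale
}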
return &Decision{
SelectedModel: targets[idx],
Strategy: "classifier",
Reason: fmt.Sprintf("complexity rating: %d/%d", rating, maxRating),
}, nil
}
func getSelectorModel(group db.ModelGroup, targets []string) string {
if group.SelectorModel != nil && *group.SelectorModel != "" {
return *group.SelectorModel
}
// Default: use the first (cheapest) target model as the selector
return targets[0]
}
```
**Step 2: Build**
```bash
cd ~/Documents/projects/web_projects/gophergate && go build ./...
```
**Step 3: Commit**
```bash
git add internal/router/classifier.go
git commit -m "feat: add classifier routing strategy with LLM complexity rating"
```
---
## Task 4: Wire router into the server
**Objective:** Add the Router to the Server struct, initialize it, and inject it into `handleChatCompletions`.
**Files:**
- Modify: `internal/server/server.go`
**Step 1: Add router field to Server struct**
In the `Server` struct (around line 23), add after the `registryMu` field:
```go
router *router.Router
```
**Step 2: Add import**
Add to the imports block:
```go
"gophergate/internal/router"
```
**Step 3: Initialize router in NewServer**
After `s.setupRoutes()` (line 66), add:
```go
// Initialize model group router
s.refreshRouter()
```
**Step 4: Add refreshRouter method**
Add a new method on Server:
```go
func (s *Server) refreshRouter() {
var groups []db.ModelGroup
if err := s.database.Select(&groups, "SELECT * FROM model_groups"); err != nil {
fmt.Printf("Warning: Failed to load model groups: %v\n", err)
groups = nil
}
// Build classifier function using the OpenAI provider
var classifyFn router.ClassifierFunc
if openaiProvider, ok := s.providers["openai"]; ok {
classifyFn = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) {
req := &models.UnifiedRequest{
Model: selectorModel,
Messages: []models.UnifiedMessage{
{Role: "system", Content: []models.ContentPart{{Type: "text", Text: systemPrompt}}},
{Role: "user", Content: []models.ContentPart{{Type: "text", Text: userMessage}}},
},
MaxTokens: uint32Ptr(5),
Stream: false,
}
resp, err := openaiProvider.ChatCompletion(ctx, req)
if err != nil {
return "", err
}
if len(resp.Choices) == 0 {
return "", fmt.Errorf("no choices in classifier response")
}
return resp.Choices[0].Message.Content, nil
}
}
if s.router == nil {
s.router = router.New(groups, classifyFn)
} else {
s.router.Reload(groups)
}
}
```
**Step 5: Add uint32Ptr helper (if not already in the codebase)**
At the bottom of server.go, add:
```go
func uint32Ptr(v uint32) *uint32 { return &v }
```
**Step 6: Inject router into handleChatCompletions**
In `handleChatCompletions`, after the model prefix stripping block (after line 475) and before building the UnifiedRequest (line 478), add:
```go
// Check if model is a group and route to a concrete model
if s.router != nil && s.router.IsGroup(modelID) {
userMessage := extractUserMessage(req.Messages)
decision, err := s.router.Route(c.Request.Context(), modelID, userMessage)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
return
}
modelID = decision.SelectedModel
log.Printf("[ROUTER] %s → %s (%s: %s)", req.Model, modelID, decision.Strategy, decision.Reason)
}
```
**Step 7: Add extractUserMessage helper**
```go
func extractUserMessage(messages []models.ChatCompletionMessage) string {
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == "user" {
if s, ok := messages[i].Content.(string); ok {
return s
}
// It might be a content array — grab text from first part
if parts, ok := messages[i].Content.([]interface{}); ok && len(parts) > 0 {
if part, ok := parts[0].(map[string]interface{}); ok {
if text, ok := part["text"].(string); ok {
return text
}
}
}
return ""
}
}
return ""
}
```
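The helper above reads only the first content part, so a multimodal message whose first part is an image yields an empty string and the heuristics see nothing. A possible alternative that collects every text part is sketched below. `textFromContent` is a hypothetical name, and the sketch assumes `Content` arrives as a decoded JSON value (`string` or `[]interface{}`), as the plan's type assertions suggest.

```go
package main

import (
	"fmt"
	"strings"
)

// textFromContent is a hypothetical helper. It accepts the decoded JSON value
// of a message's "content" field: either a string or an array of parts.
func textFromContent(content interface{}) string {
	switch v := content.(type) {
	case string:
		return v
	case []interface{}:
		var texts []string
		for _, p := range v {
			part, ok := p.(map[string]interface{})
			if !ok {
				continue
			}
			// Non-text parts (for example image_url) carry no "text" key.
			if text, ok := part["text"].(string); ok && text != "" {
				texts = append(texts, text)
			}
		}
		return strings.Join(texts, "\n")
	}
	return ""
}

func main() {
	content := []interface{}{
		map[string]interface{}{"type": "image_url"},
		map[string]interface{}{"type": "text", "text": "explain in detail what this shows"},
	}
	fmt.Println(textFromContent(content))
}
```

Joining all text parts gives the routing rules the full prompt, which matters when clients send the image first and the question second.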
**Step 8: Add router refresh to RefreshProviders**
At the end of `RefreshProviders()` (before `return nil` at line 171), add:
```go
s.refreshRouter()
```
**Step 9: Build**
```bash
cd ~/Documents/projects/web_projects/gophergate && go build ./...
```
Compilation errors are likely at this point, because the helper assumes a shape for `ChatCompletionMessage` that has not been verified. The handler binds `models.ChatCompletionRequest`, whose `Messages` field is `[]ChatCompletionMessage`; check how `Content` is declared on that type. If `Content` is a plain string field, `extractUserMessage` can drop the type assertions and return it directly. Fix as needed.
**Step 10: Commit**
```bash
git add internal/server/server.go
git commit -m "feat: wire model group router into chat completions handler"
```
---
## Task 5: Add admin API endpoints for model groups
**Objective:** CRUD endpoints at `/api/model-groups` for dashboard management.
**Files:**
- Create: `internal/server/model_groups_admin.go`
**Step 1: Create `internal/server/model_groups_admin.go`**
```go
package server
import (
"net/http"
"gophergate/internal/db"
"github.com/gin-gonic/gin"
)
func (s *Server) handleGetModelGroups(c *gin.Context) {
var groups []db.ModelGroup
if err := s.database.Select(&groups, "SELECT * FROM model_groups ORDER BY id"); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if groups == nil {
groups = []db.ModelGroup{}
}
c.JSON(http.StatusOK, groups)
}
func (s *Server) handleCreateModelGroup(c *gin.Context) {
var group db.ModelGroup
if err := c.ShouldBindJSON(&group); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
_, err := s.database.Exec(`
INSERT INTO model_groups (id, strategy, selector_model, targets, complexity_threshold, heuristic_rules)
VALUES (?, ?, ?, ?, ?, ?)`,
group.ID, group.Strategy, group.SelectorModel, group.Targets,
group.ComplexityThreshold, group.HeuristicRules)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
s.refreshRouter()
c.JSON(http.StatusCreated, group)
}
func (s *Server) handleUpdateModelGroup(c *gin.Context) {
id := c.Param("id")
var group db.ModelGroup
if err := c.ShouldBindJSON(&group); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
_, err := s.database.Exec(`
UPDATE model_groups SET strategy=?, selector_model=?, targets=?, complexity_threshold=?, heuristic_rules=?, updated_at=CURRENT_TIMESTAMP
WHERE id=?`,
group.Strategy, group.SelectorModel, group.Targets,
group.ComplexityThreshold, group.HeuristicRules, id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
s.refreshRouter()
c.JSON(http.StatusOK, group)
}
func (s *Server) handleDeleteModelGroup(c *gin.Context) {
id := c.Param("id")
_, err := s.database.Exec("DELETE FROM model_groups WHERE id=?", id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
s.refreshRouter()
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
}
```
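The create and update handlers persist whatever the client sends, so a malformed `targets` string or an unknown `strategy` only surfaces later, when `Route` fails and the chat handler answers with a 500. A sketch of a check the handlers could run before `Exec`, returning 400 on failure, follows. `validateGroup` is a hypothetical helper, not part of the plan.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
)

// validateGroup is a hypothetical pre-insert check. It enforces the same
// constraints that Route relies on: a known strategy and a non-empty
// JSON array of target model names.
func validateGroup(id, strategy, targetsJSON string) error {
	if id == "" {
		return errors.New("id is required")
	}
	if strategy != "heuristic" && strategy != "classifier" {
		return fmt.Errorf("unknown strategy %q", strategy)
	}
	var targets []string
	if err := json.Unmarshal([]byte(targetsJSON), &targets); err != nil {
		return fmt.Errorf("targets must be a JSON array of strings: %w", err)
	}
	if len(targets) == 0 {
		return errors.New("targets must not be empty")
	}
	return nil
}

func main() {
	fmt.Println(validateGroup("deepseek-auto", "heuristic", `["deepseek-chat","deepseek-reasoner"]`))
	fmt.Println(validateGroup("broken", "random", `[]`))
}
```

Rejecting bad input at write time keeps the failure next to the admin action that caused it, rather than on an unrelated client request.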
**Step 2: Register routes in setupRoutes()**
In `setupRoutes()`, add under the admin group (after the models endpoints around line 229):
```go
admin.GET("/model-groups", s.handleGetModelGroups)
admin.POST("/model-groups", s.handleCreateModelGroup)
admin.PUT("/model-groups/:id", s.handleUpdateModelGroup)
admin.DELETE("/model-groups/:id", s.handleDeleteModelGroup)
```
**Step 3: Build**
```bash
cd ~/Documents/projects/web_projects/gophergate && go build ./...
```
**Step 4: Commit**
```bash
git add internal/server/model_groups_admin.go internal/server/server.go
git commit -m "feat: add model groups CRUD admin API endpoints"
```
---
## Task 6: Add dashboard UI — sidebar entry and page module
**Objective:** Add a "Model Groups" tab to the dashboard sidebar and a page module for CRUD management.
**Files:**
- Modify: `static/index.html`
- Create: `static/js/pages/model_groups.js`
**Step 1: Add sidebar menu item in index.html**
In the MANAGEMENT section (after line 91, before `</ul>`), add:
```html
<li class="menu-item" data-page="model-groups">
<i class="fas fa-code-branch"></i>
<span>Model Groups</span>
</li>
```
**Step 2: Add script tag in index.html**
After the users.js script (line 179), add:
```html
<script src="/js/pages/model_groups.js?v=8"></script>
```
**Step 3: Create `static/js/pages/model_groups.js`**
```javascript
// Model Groups Management Page
class ModelGroupsPage {
constructor() {
this.container = document.getElementById('page-content');
}
async render() {
this.container.innerHTML = `
<div class="page-header">
<h3>Model Groups</h3>
<p class="text-muted">Define auto-routing groups that pick the best model for each request.</p>
<button class="btn btn-primary" onclick="modelGroupsPage.showCreateForm()">
<i class="fas fa-plus"></i> Add Group
</button>
</div>
<div id="model-groups-list" class="table-container"></div>
<div id="model-group-form" class="form-container" style="display:none;"></div>
`;
await this.loadGroups();
}
async loadGroups() {
try {
const groups = await api.get('/model-groups'); // api.js already prepends /api
const list = document.getElementById('model-groups-list');
if (!groups || groups.length === 0) {
list.innerHTML = '<div class="empty-state">No model groups defined. Create one to enable auto-routing.</div>';
return;
}
// Targets are rendered below as the raw JSON string stored in the DB.
let html = '<table class="data-table"><thead><tr>';
html += '<th>Group ID</th><th>Strategy</th><th>Targets</th><th>Actions</th>';
html += '</tr></thead><tbody>';
groups.forEach(g => {
html += `<tr>
<td><code>${this.esc(g.id)}</code></td>
<td><span class="badge">${this.esc(g.strategy)}</span></td>
<td><code>${this.esc(g.targets)}</code></td>
<td>
<button class="btn btn-sm" onclick="modelGroupsPage.showEditForm('${this.esc(g.id)}')">Edit</button>
<button class="btn btn-sm btn-danger" onclick="modelGroupsPage.deleteGroup('${this.esc(g.id)}')">Delete</button>
</td>
</tr>`;
});
html += '</tbody></table>';
list.innerHTML = html;
} catch (err) {
document.getElementById('model-groups-list').innerHTML =
`<div class="error-message">Failed to load model groups: ${this.esc(err.message)}</div>`;
}
}
showCreateForm() {
this.renderForm(null);
}
async showEditForm(id) {
const groups = await api.get('/model-groups'); // api.js already prepends /api
const group = groups.find(g => g.id === id);
if (group) this.renderForm(group);
}
renderForm(group) {
const isEdit = !!group;
const form = document.getElementById('model-group-form');
form.style.display = 'block';
form.innerHTML = `
<h4>${isEdit ? 'Edit' : 'Create'} Model Group</h4>
<form onsubmit="modelGroupsPage.saveGroup(event, ${isEdit})">
<div class="form-control">
<label>Group ID</label>
<input type="text" id="mg-id" value="${this.esc(group?.id || '')}" ${isEdit ? 'readonly' : 'required'}
placeholder="e.g. deepseek-auto">
<small>Clients use this as the model name.</small>
</div>
<div class="form-control">
<label>Strategy</label>
<select id="mg-strategy">
<option value="heuristic" ${group?.strategy === 'heuristic' ? 'selected' : ''}>Heuristic (rules-based)</option>
<option value="classifier" ${group?.strategy === 'classifier' ? 'selected' : ''}>Classifier (LLM judge)</option>
</select>
</div>
<div class="form-control">
<label>Targets (JSON array)</label>
<input type="text" id="mg-targets" value='${this.esc(group?.targets || '["cheap-model","smart-model"]')}' required>
<small>First target = cheapest/fastest. Last target = smartest/most expensive.</small>
</div>
<div class="form-control" id="mg-selector-row" ${group?.strategy === 'classifier' ? '' : 'style="display:none"'}>
<label>Selector Model</label>
<input type="text" id="mg-selector-model" value="${this.esc(group?.selector_model || 'gpt-4o-mini')}"
placeholder="Model used to judge task complexity">
</div>
<div class="form-control" id="mg-threshold-row" ${group?.strategy === 'classifier' ? '' : 'style="display:none"'}>
<label>Complexity Threshold</label>
<input type="number" id="mg-threshold" value="${group?.complexity_threshold || ''}" min="1"
placeholder="Tasks rated >= this go to the smart model">
</div>
<div class="form-control" id="mg-rules-row" ${group?.strategy === 'heuristic' ? '' : 'style="display:none"'}>
<label>Heuristic Rules (JSON array)</label>
<textarea id="mg-rules" rows="4" placeholder='[{"pattern":"step by step","target":1}]'>${this.esc(group?.heuristic_rules || '')}</textarea>
<small>Pattern to match in user messages. target = index into targets array.</small>
</div>
<div class="form-actions">
<button type="submit" class="btn btn-primary">Save</button>
<button type="button" class="btn" onclick="document.getElementById('model-group-form').style.display='none'">Cancel</button>
</div>
</form>
`;
// Toggle strategy-specific fields
document.getElementById('mg-strategy').onchange = function() {
const isClassifier = this.value === 'classifier';
document.getElementById('mg-selector-row').style.display = isClassifier ? '' : 'none';
document.getElementById('mg-threshold-row').style.display = isClassifier ? '' : 'none';
document.getElementById('mg-rules-row').style.display = isClassifier ? 'none' : '';
};
}
async saveGroup(event, isEdit) {
event.preventDefault();
const id = document.getElementById('mg-id').value.trim();
const strategy = document.getElementById('mg-strategy').value;
const targets = document.getElementById('mg-targets').value;
const selectorModel = document.getElementById('mg-selector-model').value.trim() || null;
const thresholdVal = document.getElementById('mg-threshold').value;
const rules = document.getElementById('mg-rules').value.trim() || null;
// Validate JSON
try { if (!Array.isArray(JSON.parse(targets))) throw new Error('not an array'); } catch { alert('Targets must be a valid JSON array'); return; }
if (rules) { try { JSON.parse(rules); } catch { alert('Heuristic rules must be valid JSON'); return; } }
const body = { id, strategy, targets, selector_model: selectorModel, heuristic_rules: rules };
if (thresholdVal) body.complexity_threshold = parseInt(thresholdVal, 10);
try {
if (isEdit) {
await api.put(`/api/model-groups/${encodeURIComponent(id)}`, body);
} else {
await api.post('/api/model-groups', body);
}
document.getElementById('model-group-form').style.display = 'none';
await this.loadGroups();
} catch (err) {
alert('Failed to save: ' + err.message);
}
}
async deleteGroup(id) {
if (!confirm(`Delete model group "${id}"?`)) return;
try {
await api.delete(`/api/model-groups/${encodeURIComponent(id)}`);
await this.loadGroups();
} catch (err) {
alert('Failed to delete: ' + err.message);
}
}
esc(str) {
if (!str) return '';
return String(str).replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;').replace(/"/g,'&quot;').replace(/'/g,'&#39;');
}
}
const modelGroupsPage = new ModelGroupsPage();
```
**Step 4: Register page in dashboard.js**
In `static/js/dashboard.js`, find the page loading logic. The `loadPage` method dispatches on `this.currentPage`, and page names follow the hyphenated `data-page` attributes (e.g., `data-page="model-groups"`).
Based on `index.html`, pages appear to be loaded as separate script files rather than imported dynamically: each page script defines a class or object, and the dashboard calls a `render()` or `init()` method on it when that page is selected. Check how the existing pages are dispatched before editing, since this plan does not confirm the exact pattern.
In `dashboard.js`, find the `loadPage` method and ensure it handles "model-groups":
```javascript
// In the loadPage switch/if-else, add:
else if (page === 'model-groups') {
if (typeof modelGroupsPage !== 'undefined') {
modelGroupsPage.render();
}
}
```
**Step 5: Commit**
```bash
git add static/index.html static/js/pages/model_groups.js static/js/dashboard.js
git commit -m "feat: add model groups dashboard page with CRUD UI"
```
---
## Task 7: Integration test — build, run, verify
**Objective:** Ensure everything compiles and the routing works end-to-end.
**Step 1: Full build**
```bash
cd ~/Documents/projects/web_projects/gophergate && go build -o gophergate ./cmd/gophergate
```
**Step 2: Start server and test**
```bash
# In one terminal:
./gophergate
# In another terminal, test that default groups loaded:
curl -s -u admin:admin123 http://localhost:8080/api/model-groups | jq
# Expected: array with deepseek-auto and openai-auto groups
```
**Step 3: Test routing via API**
```bash
# Send a request using a model group
curl -s http://localhost:8080/v1/chat/completions \
-H "Authorization: Bearer YOUR_TOKEN" \
-H "Content-Type: application/json" \
-d '{
"model": "openai-auto",
"messages": [{"role": "user", "content": "What is 2+2?"}]
}' | jq
# Check server logs for [ROUTER] line showing the decision
```
**Step 4: Commit any fixes**
If any issues are found during testing, fix them and commit.
---
## Architecture Notes
### Why this approach
- **No Provider interface changes** — the router is a pre-processing step in the handler, transparent to providers
- **Groups stored in DB** — manageable from the dashboard, no config file sprawl
- **Classifier is optional** — heuristic mode works with zero added latency or cost
- **Fallback chain** — classifier failure falls back to heuristic; missing router falls back to direct passthrough
### Edge cases handled
- No groups defined → router never activates, all models pass through as before
- Unknown group ID → returns error to client
- Empty targets → returns error
- Classifier call fails → falls back to heuristic
- Classifier returns garbage → clamped to valid range
- OpenAI provider disabled → classifier groups fall back to heuristic mode
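The "classifier returns garbage" and "classifier call fails" cases above can be illustrated with a small parsing helper. This is only a sketch under stated assumptions: `parseRating` and its signature are hypothetical names, not taken from the router code, and the rule "no number in the reply means fall back to the heuristic" is one possible reading of the fallback chain.

```go
package main

import (
	"fmt"
	"regexp"
	"strconv"
)

// firstInt matches the first run of digits in the classifier's reply.
var firstInt = regexp.MustCompile(`\d+`)

// parseRating (hypothetical helper) extracts a complexity rating from a
// free-form classifier reply and clamps it to [1, scale]. ok is false when
// the reply holds no usable number, which a caller could treat as the signal
// to fall back to the heuristic strategy.
func parseRating(reply string, scale int) (rating int, ok bool) {
	digits := firstInt.FindString(reply)
	if digits == "" {
		return 0, false
	}
	n, err := strconv.Atoi(digits)
	if err != nil { // for example, overflow on a very long digit run
		return 0, false
	}
	if n < 1 {
		n = 1
	}
	if n > scale {
		n = scale
	}
	return n, true
}

func main() {
	for _, reply := range []string{"7", "Rating: 42", "0", "no idea"} {
		r, ok := parseRating(reply, 10)
		fmt.Printf("%q -> rating=%d ok=%v\n", reply, r, ok)
	}
}
```

Clamping instead of rejecting out-of-range numbers keeps a slightly misbehaving selector model usable, while a reply with no number at all is treated as a failed classification.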
### What's NOT in this plan (future work)
- Streaming classifier support (the ~300ms classifier call happens before streaming begins — acceptable for now)
- `/v1/responses` endpoint routing (`handleResponses` could also use the router, but it needs a different message extraction step)
- Per-client group overrides
- A/B testing / multi-armed bandit routing
- Caching classifier decisions for identical messages
+3
@@ -46,6 +46,9 @@ Implements HMAC-SHA256 signed tokens for dashboard authentication. Tokens secure
### 5. WebSocket Hub (`internal/server/websocket.go`)
A centralized hub for managing WebSocket connections, allowing real-time broadcast of system events, system metrics, and request logs to the dashboard.
### 6. Model Group Router (`internal/router/`)
Resolves model groups (e.g., `deepseek-auto`, `dustins_stack`) into concrete models. It supports a Classifier strategy (uses a cheap LLM to rate complexity) and an upgraded Heuristic strategy, which evaluates either custom condition rules (tags, token counts, multimodal inputs, reasoning, and tool-calling flags) or legacy keyword patterns.
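To make the legacy keyword patterns concrete, here is a minimal sketch of rule evaluation. It assumes the rule shape shown in the dashboard placeholder (`{"pattern": ..., "target": ...}`); the `Rule` type, the `pickTarget` function, case-insensitive substring matching, and the default of falling back to the first target are all assumptions made for illustration, not the router's actual behavior. Condition rules (tags, token counts, and so on) are not modeled.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// Rule mirrors the legacy rule shape from the dashboard placeholder:
// a pattern to look for and an index into the group's targets array.
type Rule struct {
	Pattern string `json:"pattern"`
	Target  int    `json:"target"`
}

// pickTarget (hypothetical) returns the target of the first rule whose
// pattern occurs in the user text, or the first target when nothing matches.
func pickTarget(rulesJSON string, targets []string, userText string) (string, error) {
	if len(targets) == 0 {
		return "", fmt.Errorf("model group has no targets")
	}
	var rules []Rule
	if rulesJSON != "" {
		if err := json.Unmarshal([]byte(rulesJSON), &rules); err != nil {
			return "", fmt.Errorf("invalid heuristic rules: %w", err)
		}
	}
	text := strings.ToLower(userText)
	for _, r := range rules {
		if r.Target < 0 || r.Target >= len(targets) {
			continue // skip rules that point outside the targets array
		}
		if r.Pattern != "" && strings.Contains(text, strings.ToLower(r.Pattern)) {
			return targets[r.Target], nil
		}
	}
	return targets[0], nil // assumed default: the cheapest/fastest target
}

func main() {
	targets := []string{"deepseek-chat", "deepseek-reasoner"}
	rules := `[{"pattern":"step by step","target":1}]`
	for _, msg := range []string{"What is 2+2?", "Explain step by step why the sky is blue."} {
		model, err := pickTarget(rules, targets, msg)
		fmt.Println(model, err)
	}
}
```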
## Concurrency Model
Go's goroutines and channels are used extensively:
+87 -18
@@ -7,11 +7,11 @@ A unified, high-performance LLM proxy gateway built in Go. It provides OpenAI-co
- **Unified API:** OpenAI-compatible `/v1/chat/completions`, `/v1/images/generations`, `/v1/responses`, and `/v1/models` endpoints.
- The `/v1/responses` endpoint (OpenAI Responses API) is currently supported for OpenAI models only. Non-OpenAI providers (Gemini, DeepSeek, Moonshot, Grok, Ollama) return a "not supported" response.
- **Multi-Provider Support:**
- **OpenAI:** GPT-4o, GPT-4o Mini, o1, o3 reasoning models, DALL-E 2/3 image generation.
- **Google Gemini:** Gemini 2.0 Flash, Pro, and vision models (with native CoT support), Imagen 3 image generation.
- **DeepSeek:** DeepSeek Chat and Reasoner (R1) models.
- **Moonshot:** Kimi K2.5 and other Kimi models.
- **xAI Grok:** Grok-4 models.
- **OpenAI:** GPT-4o, GPT-4o Mini, GPT-5, GPT-5.4, o1/o3/o4 reasoning models, DALL-E 2/3 image generation.
- **Google Gemini:** Gemini 2.5 Flash/Pro, Gemini 3 Flash/Pro previews, Imagen 3 image generation.
- **DeepSeek:** DeepSeek Chat, Reasoner, V4 Flash, V4 Pro.
- **Moonshot:** Kimi K2.5, K2.6 reasoning models.
- **xAI Grok:** Grok-3, Grok-4, Grok-4.3 reasoning models.
- **Ollama:** Local LLMs running on your network.
- **Observability & Tracking:**
- **Asynchronous Logging:** Non-blocking request logging to SQLite using background workers.
@@ -20,13 +20,24 @@ A unified, high-performance LLM proxy gateway built in Go. It provides OpenAI-co
- **Streaming Support:** Full SSE (Server-Sent Events) support for all providers.
- **Multimodal (Vision):** Image processing (Base64 and remote URLs) across compatible providers.
- **Image Generation:** DALL-E 2/3 (OpenAI) and Imagen 3 (Gemini) via OpenAI-compatible `/v1/images/generations` endpoint.
- **Automatic Model Routing:**
- **Hierarchical Routing:** Groups can target other groups, cascading through multiple levels until a concrete model is reached. Cycle detection and depth limiting (max 10) prevent infinite loops.
- **Heuristic strategy:** Free, zero-latency routing supporting both keyword matching (regex/substrings) and condition-based checks (evaluating tags, token limits, multimodal inputs, reasoning, and tool calling requirements).
- **Classifier strategy:** Uses a cheap LLM to rate task complexity on a configurable scale (1-10), then selects the appropriate model. Bucket mapping distributes ratings proportionally across targets.
- **Two-Level Dispatch:** A `dispatcher` group (classifier, threshold=10) auto-routes to tier groups by complexity score, which then apply their own internal strategies.
- **Metadata:** Groups support `logic_level` (1-10 complexity scale) and `primary_use` (description) fields for organizational clarity.
- Pre-seeded with provider groups, tier groups (heavy-logic / standard-pro / fast-flow), and a dispatcher. Model groups are exposed in `/v1/models` so clients can discover them.
- **Multi-User Access Control:**
- **Admin Role:** Full access to all dashboard features, user management, and system configuration.
- **Viewer Role:** Read-only access to usage analytics, costs, and monitoring.
- **Client API Keys:** Create and manage multiple client tokens for external integrations.
- **Reliability:**
- **Circuit Breaking:** Automatically protects when providers are down (coming soon).
- **Rate Limiting:** Per-client and global rate limits (coming soon).
- **Circuit Breaking:** Protects providers when they are down and auto-recovers after a timeout.
- **Provider-Aware Classification:** Classifier selector models are routed to the correct provider automatically.
## DeepSeek Language Note
DeepSeek models default to Chinese for some prompts. GopherGate automatically injects an English system prompt ("Always respond in English.") when no system message is present. If the client provides its own system prompt, it is left untouched.
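The injection rule can be sketched as follows. This is a simplified illustration, not the provider's code: the `Message` type and the `withEnglishPrompt` function are hypothetical stand-ins, and the sketch assumes the injected message is placed first in the conversation.

```go
package main

import "fmt"

// Message is a simplified, hypothetical chat message type.
type Message struct {
	Role    string
	Content string
}

const englishPrompt = "Always respond in English."

// withEnglishPrompt (hypothetical) prepends the English system prompt only
// when the conversation has no system message of its own.
func withEnglishPrompt(msgs []Message) []Message {
	for _, m := range msgs {
		if m.Role == "system" {
			return msgs // the client's system prompt is left untouched
		}
	}
	out := make([]Message, 0, len(msgs)+1)
	out = append(out, Message{Role: "system", Content: englishPrompt})
	return append(out, msgs...)
}

func main() {
	plain := []Message{{Role: "user", Content: "Hello"}}
	custom := []Message{
		{Role: "system", Content: "Reply in French."},
		{Role: "user", Content: "Hello"},
	}
	fmt.Println(withEnglishPrompt(plain))
	fmt.Println(withEnglishPrompt(custom))
}
```

Building a new slice rather than modifying the input keeps the caller's messages unchanged, which is a design choice of this sketch only.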
## Security
@@ -71,7 +82,9 @@ GopherGate is designed with security in mind:
# LLM_PROXY__ENCRYPTION_KEY=... (32-byte hex or base64 string)
# OPENAI_API_KEY=sk-...
# GEMINI_API_KEY=AIza...
# DEEPSEEK_API_KEY=sk-...
# MOONSHOT_API_KEY=...
# GROK_API_KEY=xai-...
# For Ollama (optional): Set base URL and enable
# LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL=http://localhost:11434/v1
# LLM_PROXY__PROVIDERS__OLLAMA__ENABLED=true
@@ -83,7 +96,16 @@ GopherGate is designed with security in mind:
./gophergate
```
The server starts on `http://0.0.0.0:8080` by default.
The server starts on `http://0.0.0.0:8080` by default. Configure `LLM_PROXY__SERVER__PORT` in `.env` to change it.
### Quick Deploy Script
A `deploy.sh` script is included for production restarts:
```bash
./deploy.sh
# stop old process -> git pull -> go build -> start new process
```
### Deployment (Docker)
@@ -106,6 +128,8 @@ Access the dashboard at `http://localhost:8080`.
- **Usage:** Summary stats, time-series analytics, and provider breakdown.
- **Clients:** API key management and per-client usage tracking.
- **Providers:** Provider configuration and status monitoring.
- **Model Groups:** Define auto-routing groups with heuristic or classifier strategies. Supports logic level and primary use metadata.
- **Models:** Model enable/disable and cost configuration.
- **Users:** Admin-only user management for dashboard access.
- **Monitoring:** Live request stream via WebSocket.
@@ -125,14 +149,6 @@ You can reset the admin password to default by running:
The proxy is a drop-in replacement for OpenAI. Configure your client:
Moonshot models are available through the same OpenAI-compatible endpoint. For
example, use `kimi-k2.5` as the model name after setting `MOONSHOT_API_KEY` in
your environment.
Ollama models (like `llama3`, `gemma2`, `mistral`) are also available through the same
endpoint after enabling Ollama in configuration and setting the base URL to your
Ollama server (default: `http://localhost:11434/v1`).
### Python
```python
@@ -170,7 +186,60 @@ response = client.responses.create(
print(response.output_text)
```
**Note:** The `/v1/responses` endpoint is currently supported for OpenAI models only. Requests routed to Gemini, DeepSeek, Moonshot, Grok, or Ollama models return a "not supported" error.
**Note:** The `/v1/responses` endpoint is currently supported for OpenAI models only.
### Automatic Model Routing
Use a model group name to let GopherGate pick the best model automatically:
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8080/v1",
api_key="YOUR_CLIENT_API_KEY"
)
# Simple query -- routes to the cheap/fast model
response = client.chat.completions.create(
model="fast-flow",
messages=[{"role": "user", "content": "What is 2+2?"}]
)
# Complex query -- routes to the reasoning model automatically
response = client.chat.completions.create(
model="heavy-logic",
messages=[{"role": "user", "content": "Write a Python red-black tree implementation."}]
)
```
### Two-Level Dispatch
The `dispatcher` group uses a classifier to score prompts 1-10, then routes to the appropriate tier group:
```python
# Automatically routed based on complexity:
# 1-3 -> fast-flow (classification, basic Q&A)
# 4-7 -> standard-pro (general assistant, long docs)
# 8-10 -> heavy-logic (complex coding, logic, agents)
response = client.chat.completions.create(
model="dispatcher",
messages=[{"role": "user", "content": "Debug this race condition in my Go code."}]
)
# This goes: dispatcher -> heavy-logic -> deepseek-v4-pro
```
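One way to turn a 1-10 rating into a target index is to round a linear interpolation between the first and last target. The sketch below is an assumption for illustration: `bucketIndex` is a hypothetical name and this formula is not confirmed to be the router's actual bucket mapping, although it does reproduce the ranges listed above for three targets.

```go
package main

import "fmt"

// bucketIndex (hypothetical) maps a rating in [1, scale] onto one of n
// targets by rounding a linear interpolation between index 0 and index n-1.
// Ratings outside the scale are clamped first.
func bucketIndex(rating, scale, n int) int {
	if n <= 1 || scale <= 1 {
		return 0
	}
	if rating < 1 {
		rating = 1
	}
	if rating > scale {
		rating = scale
	}
	// round((rating-1) * (n-1) / (scale-1)) using integer arithmetic
	return ((rating-1)*(n-1)*2 + (scale - 1)) / (2 * (scale - 1))
}

func main() {
	tiers := []string{"fast-flow", "standard-pro", "heavy-logic"}
	for rating := 1; rating <= 10; rating++ {
		fmt.Printf("%2d -> %s\n", rating, tiers[bucketIndex(rating, 10, len(tiers))])
	}
}
```

With a scale of 10 and three targets, ratings 1-3 land on index 0, 4-7 on index 1, and 8-10 on index 2.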
Pre-seeded groups:
| Group | Level | Strategy | Targets | Primary Use |
|-------|-------|----------|---------|-------------|
| `fast-flow` | 2 | heuristic | deepseek-v4-flash, gpt-5.4-nano | Classification, JSON, Basic Q&A |
| `standard-pro` | 5 | heuristic | gpt-5.4-mini, gemini-3-flash-preview | General Assistant, Long Docs |
| `heavy-logic` | 9 | heuristic | grok-4.3, kimi-k2.6, deepseek-v4-pro | Complex Coding, Logic, Agents |
| `dispatcher` | - | classifier | fast-flow, standard-pro, heavy-logic | Auto-dispatches by complexity |
| `deepseek-auto` | - | heuristic | deepseek-chat, deepseek-reasoner | Legacy provider group |
| `openai-auto` | - | heuristic | gpt-4o-mini, gpt-4o | Legacy provider group |
| `gemini-auto` | - | heuristic | gemini-2.0-flash, gemini-2.5-pro | Legacy provider group |
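The hierarchical resolution behind these groups (a group may target another group until a concrete model is reached) can be sketched with a simple loop. Everything here is illustrative: `resolve`, `pickFunc`, and the map-based group store are hypothetical, and the strategy is replaced by a stand-in that always picks the last target. Only the cycle detection and the depth limit of 10 come from the feature description.

```go
package main

import "fmt"

// maxDepth is the documented limit on group-to-group hops.
const maxDepth = 10

// pickFunc stands in for a group's routing strategy (heuristic or classifier).
type pickFunc func(group string, targets []string) string

// resolve (hypothetical) follows group targets until it reaches a name that
// is not a group, which is treated as a concrete model.
func resolve(name string, groups map[string][]string, pick pickFunc) (string, error) {
	seen := make(map[string]bool)
	for hops := 0; ; hops++ {
		targets, isGroup := groups[name]
		if !isGroup {
			return name, nil // concrete model reached
		}
		if seen[name] {
			return "", fmt.Errorf("cycle detected at group %q", name)
		}
		if hops == maxDepth {
			return "", fmt.Errorf("routing depth exceeds %d", maxDepth)
		}
		if len(targets) == 0 {
			return "", fmt.Errorf("group %q has no targets", name)
		}
		seen[name] = true
		name = pick(name, targets)
	}
}

func main() {
	groups := map[string][]string{
		"dispatcher":  {"fast-flow", "standard-pro", "heavy-logic"},
		"heavy-logic": {"grok-4.3", "kimi-k2.6", "deepseek-v4-pro"},
		"loop-a":      {"loop-b"},
		"loop-b":      {"loop-a"},
	}
	// Stand-in strategy: always choose the last (smartest) target.
	last := func(_ string, targets []string) string { return targets[len(targets)-1] }

	fmt.Println(resolve("dispatcher", groups, last))
	fmt.Println(resolve("loop-a", groups, last))
}
```

Tracking visited groups catches cycles immediately, so the depth limit only matters for long but acyclic chains.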
### Image Generation (DALL-E / Imagen)
@@ -191,7 +260,7 @@ resp = client.images.generate(
)
print(resp.data[0].url)
# Imagen 3 (Gemini) uses same endpoint
# Imagen 3 (Gemini) -- uses same endpoint
resp = client.images.generate(
model="imagen-3.0-generate-001",
prompt="A gopher coding in Go",
+14 -1
@@ -15,11 +15,24 @@
- [x] Dashboard Analytics & Usage Summary (Fixed SQL robustness)
- [x] WebSocket for real-time dashboard updates (Hub with client counting)
- [x] Asynchronous Request Logging to SQLite
- [x] Update documentation (README, deployment, architecture)
- [x] Cost Tracking accuracy (Registry integration with `models.dev`)
- [x] Model Listing endpoint (`/v1/models`) with provider filtering
- [x] System Metrics endpoint (`/api/system/metrics` using `gopsutil`)
- [x] Fixed dashboard 404s and 500s
- [x] Model groups with heuristic and classifier routing strategies
- [x] Hierarchical routing — groups can target other groups with cycle detection
- [x] Classifier bucket mapping via complexity_threshold (1-10 scale -> N targets)
- [x] Two-level dispatch — classifier router delegates to tier groups
- [x] Model groups exposed in /v1/models endpoint (owned_by: gophergate)
- [x] logic_level and primary_use metadata on model groups
- [x] Model group CRUD dashboard page
- [x] dispatcher, heavy-logic, standard-pro, fast-flow seed groups
- [x] Provider selection moved after routing resolution (fixes group routing)
- [x] Classifier selector model routed to correct provider (selectProvider)
- [x] DeepSeek English system prompt injection (ensureEnglish)
- [x] Deploy script (deploy.sh)
- [x] Recent Activity pane shows resolved model + group annotation
- [x] Model names aligned with models.dev registry
## Planned Resolutions (High Priority)
+23
@@ -0,0 +1,23 @@
#!/bin/bash
# Define the service name/path for easy updates
BINARY_NAME="gophergate"
SOURCE_PATH="./cmd/gophergate/main.go"
echo "Stopping existing $BINARY_NAME processes..."
# Using pkill; the || fallback keeps the script going even if no process was found
pkill -9 "$BINARY_NAME" || echo "No running process found."
echo "Pulling latest changes from git..."
git pull
echo "Building the application..."
if go build -o "$BINARY_NAME" "$SOURCE_PATH"; then
echo "Build successful. Starting $BINARY_NAME in the background..."
# Launch with nohup and redirect output to a log file
nohup "./$BINARY_NAME" > gophergate.log 2>&1 &
echo "Service started. PID: $!"
else
echo "Build failed! The binary was not replaced, but the old process has already been stopped."
exit 1
fi
+17
@@ -26,6 +26,22 @@ go build -o gophergate ./cmd/gophergate
./gophergate
```
### Quick Deploy Script
A `deploy.sh` script is provided for production restarts:
```bash
./deploy.sh
```
This script will:
1. Stop any running gophergate process
2. Pull latest changes from git
3. Build the application
4. Start it in the background (logs to `gophergate.log`)
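If the build fails, the previous binary is left untouched and the script exits. Because the running process is stopped before the build, the service stays down until a build succeeds or the old binary is restarted manually.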
## Docker Deployment
The project includes a multi-stage `Dockerfile` for minimal image size.
@@ -50,3 +66,4 @@ docker run -d \
- **SSL/TLS:** It is recommended to run the proxy behind a reverse proxy like Nginx or Caddy for SSL termination.
- **Backups:** Regularly backup the `data/llm_proxy.db` file.
- **Monitoring:** Monitor the `/health` endpoint for system status.
- **Logs:** When started with `deploy.sh` or `nohup`, logs are written to `gophergate.log`.
+15
@@ -37,6 +37,7 @@ type ProviderConfig struct {
Moonshot MoonshotConfig `mapstructure:"moonshot"`
Grok GrokConfig `mapstructure:"grok"`
Ollama OllamaConfig `mapstructure:"ollama"`
Xiaomi XiaomiConfig `mapstructure:"xiaomi"`
}
type OpenAIConfig struct {
@@ -81,6 +82,13 @@ type OllamaConfig struct {
Models []string `mapstructure:"models"`
}
type XiaomiConfig struct {
APIKeyEnv string `mapstructure:"api_key_env"`
BaseURL string `mapstructure:"base_url"`
DefaultModel string `mapstructure:"default_model"`
Enabled bool `mapstructure:"enabled"`
}
func Load() (*Config, error) {
v := viper.New()
@@ -120,6 +128,11 @@ func Load() (*Config, error) {
v.SetDefault("providers.ollama.enabled", false)
v.SetDefault("providers.ollama.models", []string{})
v.SetDefault("providers.xiaomi.api_key_env", "XIAOMI_API_KEY")
v.SetDefault("providers.xiaomi.base_url", "https://api.xiaomimimo.com/v1")
v.SetDefault("providers.xiaomi.default_model", "mimo-v2.5")
v.SetDefault("providers.xiaomi.enabled", true)
// Environment variables
v.SetEnvPrefix("LLM_PROXY")
v.SetEnvKeyReplacer(strings.NewReplacer(".", "__"))
@@ -210,6 +223,8 @@ func (c *Config) GetAPIKey(provider string) (string, error) {
case "ollama":
// Ollama doesn't require an API key
return "", nil
case "xiaomi":
envVar = c.Providers.Xiaomi.APIKeyEnv
default:
return "", fmt.Errorf("unknown provider: %s", provider)
}
+67
@@ -32,6 +32,14 @@ func Init(path string) (*DB, error) {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
// Enable Write-Ahead Logging (WAL) and set a busy timeout to handle concurrent access
if _, err := db.Exec("PRAGMA journal_mode=WAL;"); err != nil {
log.Printf("failed to enable WAL mode: %v", err)
}
if _, err := db.Exec("PRAGMA busy_timeout=5000;"); err != nil {
log.Printf("failed to set busy timeout: %v", err)
}
instance := &DB{db}
// Run migrations
@@ -122,6 +130,18 @@ func (db *DB) RunMigrations() error {
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
last_used_at DATETIME,
FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE CASCADE
)`,
`CREATE TABLE IF NOT EXISTS model_groups (
id TEXT PRIMARY KEY,
strategy TEXT NOT NULL DEFAULT 'heuristic',
selector_model TEXT,
targets TEXT NOT NULL DEFAULT '[]',
complexity_threshold INTEGER,
heuristic_rules TEXT,
logic_level INTEGER,
primary_use TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`,
}
@@ -152,6 +172,10 @@ func (db *DB) RunMigrations() error {
}
}
// Add columns to pre-existing model_groups tables. If a column already exists, SQLite returns a "duplicate column name" error, which is deliberately ignored here.
db.Exec("ALTER TABLE model_groups ADD COLUMN logic_level INTEGER")
db.Exec("ALTER TABLE model_groups ADD COLUMN primary_use TEXT")
// Default admin user
var count int
if err := db.Get(&count, "SELECT COUNT(*) FROM users"); err != nil {
@@ -177,6 +201,25 @@ func (db *DB) RunMigrations() error {
return fmt.Errorf("failed to insert default client: %w", err)
}
// Seed default model groups
defaultGroups := []struct {
id, strategy, targets, selectorModel string
complexityThreshold, logicLevel *int
primaryUse *string
}{
{"deepseek-auto", "heuristic", `["deepseek-chat","deepseek-reasoner"]`, "", nil, nil, nil},
{"openai-auto", "heuristic", `["gpt-4o-mini","gpt-4o"]`, "", nil, nil, nil},
{"gemini-auto", "heuristic", `["gemini-2.0-flash","gemini-2.5-pro"]`, "", nil, nil, nil},
{"heavy-logic", "heuristic", `["grok-4.3","kimi-k2.6","deepseek-v4-pro"]`, "", nil, intPtr(9), strPtr("Complex Coding, Logic, Agents.")},
{"standard-pro", "heuristic", `["gpt-5.4-mini","gemini-3-flash-preview"]`, "", nil, intPtr(5), strPtr("General Assistant, Long Docs.")},
{"fast-flow", "heuristic", `["deepseek-v4-flash","gpt-5.4-nano"]`, "", nil, intPtr(2), strPtr("Classification, JSON, Basic Q&A.")},
{"dispatcher", "classifier", `["fast-flow","standard-pro","heavy-logic"]`, "gpt-5.4-nano", intPtr(10), nil, strPtr("Auto-dispatches to tier groups by complexity.")},
}
for _, g := range defaultGroups {
db.Exec(`INSERT OR IGNORE INTO model_groups (id, strategy, targets, selector_model, complexity_threshold, logic_level, primary_use) VALUES (?, ?, ?, ?, ?, ?, ?)`,
g.id, g.strategy, g.targets, nilStr(g.selectorModel), g.complexityThreshold, g.logicLevel, g.primaryUse)
}
return nil
}
@@ -262,3 +305,27 @@ type ClientToken struct {
CreatedAt time.Time `db:"created_at"`
LastUsedAt *time.Time `db:"last_used_at"`
}
type ModelGroup struct {
ID string `db:"id" json:"id"`
Strategy string `db:"strategy" json:"strategy"`
SelectorModel *string `db:"selector_model" json:"selector_model"`
Targets string `db:"targets" json:"targets"` // JSON array
ComplexityThreshold *int `db:"complexity_threshold" json:"complexity_threshold"`
HeuristicRules *string `db:"heuristic_rules" json:"heuristic_rules"`
LogicLevel *int `db:"logic_level" json:"logic_level"`
PrimaryUse *string `db:"primary_use" json:"primary_use"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
func intPtr(v int) *int { return &v }
func strPtr(v string) *string { return &v }
// nilStr returns a *string for non-empty strings, nil for empty.
func nilStr(v string) *string {
if v == "" {
return nil
}
return &v
}
+44 -5
@@ -14,9 +14,21 @@ import (
func AuthMiddleware(database *db.DB, requireAuth bool) gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
// Fallback to checking "Authentication" header in case the client library used the wrong name
authHeader = c.GetHeader("Authentication")
}
if authHeader == "" {
if requireAuth {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing authorization header"})
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"message": "Missing Authorization or Authentication header.",
"type": "invalid_request_error",
"param": nil,
"code": "401",
},
})
return
}
c.Next()
@@ -25,23 +37,50 @@ func AuthMiddleware(database *db.DB, requireAuth bool) gin.HandlerFunc {
token := strings.TrimPrefix(authHeader, "Bearer ")
if token == authHeader { // No "Bearer " prefix
if requireAuth {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"message": "Invalid authorization header format. Bearer token required.",
"type": "invalid_request_error",
"param": nil,
"code": "401",
},
})
return
}
c.Next()
return
}
// Try to resolve client from database
// Try to resolve client from database with a read-only SELECT
var clientID string
err := database.Get(&clientID, "UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = 1 RETURNING client_id", token)
err := database.Get(&clientID, "SELECT client_id FROM client_tokens WHERE token = ? AND is_active = 1", token)
if err == nil {
c.Set("auth", models.AuthInfo{
Token: token,
ClientID: clientID,
})
// Update last_used_at asynchronously so that database locks or write delays
// do not block or fail the client's request authentication.
go func(t string) {
if _, updateErr := database.Exec("UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ?", t); updateErr != nil {
log.Printf("Warning: failed to update client token last_used_at: %v", updateErr)
}
}(token)
c.Next()
} else {
log.Printf("Token not found or inactive in DB: %s", token)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid or inactive token"})
log.Printf("Token not found, inactive or error in DB: %s (err: %v)", token, err)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"message": "Invalid or inactive client token.",
"type": "invalid_request_error",
"param": nil,
"code": "401",
},
})
}
}
}
+155 -22
@@ -2,6 +2,25 @@ package models
import "strings"
// CanonicalProviders lists the original model creators in priority order.
// When a model name exists in multiple providers (e.g. deepseek-v4-pro in
// deepseek, ollama-cloud, openrouter, etc.), these providers take precedence
// so the proxy uses authoritative metadata (pricing, limits) rather than a
// reseller's values.
var CanonicalProviders = []string{
"openai",
"google",
"deepseek",
"xai",
"moonshotai",
"moonshotai-cn",
"anthropic",
"mistral",
"cohere",
"minimax",
"xiaomi",
}
type ModelRegistry struct {
Providers map[string]ProviderInfo `json:"-"`
}
@@ -39,40 +58,154 @@ type ModelModalities struct {
Output []string `json:"output"`
}
func (r *ModelRegistry) FindModel(modelID string) *ModelMetadata {
// First try exact match in models map
for _, provider := range r.Providers {
if model, ok := provider.Models[modelID]; ok {
return &model
}
}
// Try searching by ID in metadata
for _, provider := range r.Providers {
for _, model := range provider.Models {
if model.ID == modelID {
return &model
// findInCanonical searches the canonical providers in order for an exact model
// key match. Returns the metadata and true if found.
func (r *ModelRegistry) findInCanonical(modelID string) (*ModelMetadata, bool) {
for _, key := range CanonicalProviders {
if p, ok := r.Providers[key]; ok {
if m, ok := p.Models[modelID]; ok {
return &m, true
}
}
}
return nil, false
}
// Try reverse fuzzy matching (e.g. 'gpt-5.4-mini' matching 'gpt-5.4-mini-2026-04-01')
for _, provider := range r.Providers {
for id, model := range provider.Models {
// findInAll searches all providers (map iteration, random order) for an exact
// model key match. Used as fallback when canonical search fails.
func (r *ModelRegistry) findInAll(modelID string) (*ModelMetadata, bool) {
for _, p := range r.Providers {
if m, ok := p.Models[modelID]; ok {
return &m, true
}
}
return nil, false
}
// findInCanonicalByID searches canonical providers for a model whose metadata
// ID field matches modelID.
func (r *ModelRegistry) findInCanonicalByID(modelID string) (*ModelMetadata, bool) {
for _, key := range CanonicalProviders {
if p, ok := r.Providers[key]; ok {
for _, m := range p.Models {
if m.ID == modelID {
return &m, true
}
}
}
}
return nil, false
}
// findInAllByID searches all providers for a model whose metadata ID field
// matches modelID.
func (r *ModelRegistry) findInAllByID(modelID string) (*ModelMetadata, bool) {
for _, p := range r.Providers {
for _, m := range p.Models {
if m.ID == modelID {
return &m, true
}
}
}
return nil, false
}
// findCanonicalReverseFuzzy searches canonical providers for any model whose
// key starts with modelID.
func (r *ModelRegistry) findCanonicalReverseFuzzy(modelID string) (*ModelMetadata, bool) {
for _, key := range CanonicalProviders {
if p, ok := r.Providers[key]; ok {
for id, m := range p.Models {
if strings.HasPrefix(id, modelID) {
return &m, true
}
}
}
}
return nil, false
}
// findAllReverseFuzzy searches all providers for any model whose key starts
// with modelID.
func (r *ModelRegistry) findAllReverseFuzzy(modelID string) (*ModelMetadata, bool) {
for _, p := range r.Providers {
for id, m := range p.Models {
if strings.HasPrefix(id, modelID) {
return &model
return &m, true
}
}
}
return nil, false
}
// Try fuzzy matching (e.g. 'gpt-4o-2024-05-13' matching 'gpt-4o')
for _, provider := range r.Providers {
for id, model := range provider.Models {
if strings.HasPrefix(modelID, id) {
return &model
// findCanonicalForwardFuzzy searches canonical providers for any model whose
// key is a prefix of modelID.
func (r *ModelRegistry) findCanonicalForwardFuzzy(modelID string) (*ModelMetadata, bool) {
for _, key := range CanonicalProviders {
if p, ok := r.Providers[key]; ok {
for id, m := range p.Models {
if strings.HasPrefix(modelID, id) {
return &m, true
}
}
}
}
return nil, false
}
// findAllForwardFuzzy searches all providers for any model whose key is a
// prefix of modelID.
func (r *ModelRegistry) findAllForwardFuzzy(modelID string) (*ModelMetadata, bool) {
for _, p := range r.Providers {
for id, m := range p.Models {
if strings.HasPrefix(modelID, id) {
return &m, true
}
}
}
return nil, false
}
// FindModel looks up model metadata by ID. It searches canonical providers
// first at each strategy level (exact key, metadata ID, reverse fuzzy,
// forward fuzzy) and falls back to all providers only when canonical search
// yields no result. This prevents reseller entries (ollama-cloud, openrouter,
// etc.) from overriding the original provider's authoritative pricing and
// limits.
func (r *ModelRegistry) FindModel(modelID string) *ModelMetadata {
// 1. Exact key match — canonical first, then all
if m, ok := r.findInCanonical(modelID); ok {
return m
}
if m, ok := r.findInAll(modelID); ok {
return m
}
// 2. Match by metadata ID field — canonical first, then all
if m, ok := r.findInCanonicalByID(modelID); ok {
return m
}
if m, ok := r.findInAllByID(modelID); ok {
return m
}
// 3. Reverse fuzzy: model key starts with modelID
// e.g. 'gpt-5.4-mini' matching 'gpt-5.4-mini-2026-04-01'
if m, ok := r.findCanonicalReverseFuzzy(modelID); ok {
return m
}
if m, ok := r.findAllReverseFuzzy(modelID); ok {
return m
}
// 4. Forward fuzzy: modelID starts with model key
// e.g. 'gpt-4o-2024-05-13' matching 'gpt-4o'
if m, ok := r.findCanonicalForwardFuzzy(modelID); ok {
return m
}
if m, ok := r.findAllForwardFuzzy(modelID); ok {
return m
}
return nil
}
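The canonical-first ordering described in the FindModel comment can be sketched in isolation. This is a simplified illustration, not the commit's code: `Model`, `registry`, `find`, and `scan` are hypothetical stand-ins, and the metadata-ID strategy is left out for brevity.

```go
package main

import (
	"fmt"
	"strings"
)

// Model and registry are simplified stand-ins for ModelMetadata and
// ModelRegistry; the names and fields are assumptions for illustration.
type Model struct {
	Name   string
	Output int
}

type registry struct {
	providers map[string]map[string]Model
	canonical []string // provider keys treated as authoritative
}

// find runs each strategy against canonical providers first, then against
// every provider, and only then moves on to the next, looser strategy.
func (r registry) find(id string) (Model, bool) {
	strategies := []func(key, id string) bool{
		func(key, id string) bool { return key == id },                  // exact key
		func(key, id string) bool { return strings.HasPrefix(key, id) }, // reverse fuzzy
		func(key, id string) bool { return strings.HasPrefix(id, key) }, // forward fuzzy
	}
	for _, match := range strategies {
		for _, name := range r.canonical {
			if m, ok := scan(r.providers[name], id, match); ok {
				return m, true
			}
		}
		for _, models := range r.providers {
			if m, ok := scan(models, id, match); ok {
				return m, true
			}
		}
	}
	return Model{}, false
}

// scan returns the first model in one provider whose key satisfies match.
func scan(models map[string]Model, id string, match func(key, id string) bool) (Model, bool) {
	for key, m := range models {
		if match(key, id) {
			return m, true
		}
	}
	return Model{}, false
}

func main() {
	r := registry{
		providers: map[string]map[string]Model{
			"ollama-cloud": {"deepseek-v4-pro": {Name: "reseller", Output: 1048576}},
			"deepseek":     {"deepseek-v4-pro": {Name: "canonical", Output: 384000}},
		},
		canonical: []string{"deepseek"},
	}
	m, _ := r.find("deepseek-v4-pro")
	fmt.Println(m.Name, m.Output)
}
```

Because Go map iteration order is random, only the canonical pass makes the result deterministic when two providers share a key, which is the case the CanonicalPriority test exercises.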
+29
@@ -59,6 +59,35 @@ func TestModelRegistry_FindModel_NotFound(t *testing.T) {
}
}
func TestModelRegistry_FindModel_CanonicalPriority(t *testing.T) {
// Same model name in canonical (deepseek) and reseller (ollama-cloud).
// Canonical must win so the proxy uses authoritative limits.
r := &ModelRegistry{
Providers: map[string]ProviderInfo{
"ollama-cloud": {
Models: map[string]ModelMetadata{
"deepseek-v4-pro": {ID: "deepseek-v4-pro", Name: "DSv4 Pro (Ollama Cloud)", Limit: &ModelLimit{Context: 1048576, Output: 1048576}},
},
},
"deepseek": {
Models: map[string]ModelMetadata{
"deepseek-v4-pro": {ID: "deepseek-v4-pro", Name: "DeepSeek v4 Pro", Limit: &ModelLimit{Context: 1000000, Output: 384000}},
},
},
},
}
m := r.FindModel("deepseek-v4-pro")
if m == nil {
t.Fatal("expected to find deepseek-v4-pro")
}
if m.Name != "DeepSeek v4 Pro" {
t.Fatalf("expected DeepSeek v4 Pro (canonical), got %s", m.Name)
}
if m.Limit.Output != 384000 {
t.Fatalf("expected output limit 384000 (canonical), got %d", m.Limit.Output)
}
}
func TestModelRegistry_FindModel_ReverseFuzzy(t *testing.T) {
r := &ModelRegistry{
Providers: map[string]ProviderInfo{
+51 -16
@@ -62,6 +62,9 @@ func (u *deepSeekUsage) ToUnified() *models.Usage {
}
func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
// Ensure English responses — DeepSeek defaults to Chinese for some prompts
ensureEnglish(req)
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
@@ -69,19 +72,26 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi
body := BuildOpenAIBody(req, messagesJSON, false)
// Sanitize for deepseek-reasoner
if req.Model == "deepseek-reasoner" {
delete(body, "temperature")
delete(body, "top_p")
delete(body, "presence_penalty")
delete(body, "frequency_penalty")
// Sanitize for models that support reasoning/thinking mode
isReasoner := strings.Contains(req.Model, "reasoner") || strings.Contains(req.Model, "v4") || strings.Contains(req.Model, "r1")
if isReasoner {
// deepseek-reasoner (R1) does not support these parameters
if req.Model == "deepseek-reasoner" || strings.HasPrefix(req.Model, "deepseek-r1") {
delete(body, "temperature")
delete(body, "top_p")
delete(body, "presence_penalty")
delete(body, "frequency_penalty")
}
if msgs, ok := body["messages"].([]interface{}); ok {
for _, m := range msgs {
if msg, ok := m.(map[string]interface{}); ok {
if msg["role"] == "assistant" {
// DeepSeek requires reasoning_content to be passed back in history
// if the model is in thinking mode.
if msg["reasoning_content"] == nil {
msg["reasoning_content"] = " "
msg["reasoning_content"] = ""
}
if msg["content"] == nil || msg["content"] == "" {
msg["content"] = ""
@@ -103,7 +113,15 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
var msg string
if resp.RawBody() != nil {
bodyBytes, _ := io.ReadAll(resp.RawBody())
msg = string(bodyBytes)
}
if msg == "" {
msg = resp.String()
}
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), msg)
}
var respJSON map[string]interface{}
@@ -129,6 +147,8 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi
}
func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
ensureEnglish(req)
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
@@ -136,19 +156,26 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models
body := BuildOpenAIBody(req, messagesJSON, true)
// Sanitize for deepseek-reasoner
if req.Model == "deepseek-reasoner" {
delete(body, "temperature")
delete(body, "top_p")
delete(body, "presence_penalty")
delete(body, "frequency_penalty")
// Sanitize for models that support reasoning/thinking mode
isReasoner := strings.Contains(req.Model, "reasoner") || strings.Contains(req.Model, "v4") || strings.Contains(req.Model, "r1")
if isReasoner {
// deepseek-reasoner (R1) does not support these parameters
if req.Model == "deepseek-reasoner" || strings.HasPrefix(req.Model, "deepseek-r1") {
delete(body, "temperature")
delete(body, "top_p")
delete(body, "presence_penalty")
delete(body, "frequency_penalty")
}
if msgs, ok := body["messages"].([]interface{}); ok {
for _, m := range msgs {
if msg, ok := m.(map[string]interface{}); ok {
if msg["role"] == "assistant" {
// DeepSeek requires reasoning_content to be passed back in history
// if the model is in thinking mode.
if msg["reasoning_content"] == nil {
msg["reasoning_content"] = " "
msg["reasoning_content"] = ""
}
if msg["content"] == nil || msg["content"] == "" {
msg["content"] = ""
@@ -171,7 +198,15 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
var msg string
if resp.RawBody() != nil {
bodyBytes, _ := io.ReadAll(resp.RawBody())
msg = string(bodyBytes)
}
if msg == "" {
msg = resp.String()
}
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), msg)
}
ch := make(chan *models.ChatCompletionStreamResponse)
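The assistant-message backfill appears in both DeepSeek hunks above, once per method. A minimal sketch of how it might be factored into one function follows. `backfillAssistant` is a hypothetical name that is not part of the commit, and the empty-string placeholder is assumed to be the intended value.

```go
package main

import "fmt"

// backfillAssistant is a hypothetical helper (not in the commit). It gives
// every assistant message a non-nil "reasoning_content" and "content" so the
// history can be replayed to a model running in thinking mode.
func backfillAssistant(msgs []interface{}) {
	for _, m := range msgs {
		msg, ok := m.(map[string]interface{})
		if !ok || msg["role"] != "assistant" {
			continue
		}
		if msg["reasoning_content"] == nil {
			msg["reasoning_content"] = "" // assumed placeholder value
		}
		if msg["content"] == nil {
			msg["content"] = ""
		}
	}
}

func main() {
	msgs := []interface{}{
		map[string]interface{}{"role": "user", "content": "hi"},
		map[string]interface{}{"role": "assistant", "content": nil},
	}
	backfillAssistant(msgs)
	fmt.Printf("%q\n", msgs[1].(map[string]interface{})["reasoning_content"])
}
```

Sharing one function between ChatCompletion and ChatCompletionStream would keep the two code paths from drifting apart.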
+41 -20
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"time"
@@ -125,7 +126,13 @@ func (p *GeminiProvider) ImageGeneration(ctx context.Context, req *models.ImageG
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Gemini Imagen API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Gemini Imagen API error (%d): %s", resp.StatusCode(), msg)
}
// Parse Imagen response
@@ -331,6 +338,7 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
}
// Map Tools
hasMappedTools := false
if len(req.Tools) > 0 {
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
for _, t := range req.Tools {
@@ -338,13 +346,16 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
}
}
body.Tools = []GeminiTool{geminiTool}
if len(geminiTool.FunctionDeclarations) > 0 {
body.Tools = []GeminiTool{geminiTool}
hasMappedTools = true
}
}
baseURL := p.config.BaseURL
lowerModel := strings.ToLower(req.Model)
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
// Use v1beta for preview and newer models
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") || hasMappedTools {
// Use v1beta for preview, newer models, or when using tools
if !strings.Contains(baseURL, "v1beta") {
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
}
@@ -363,11 +374,17 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
}
if !resp.IsSuccess() {
fmt.Printf("[Gemini] API Error %d: %s\n", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
fmt.Printf("[Gemini] API Error %d: %s\n", resp.StatusCode(), msg)
// Also log the request body for debugging (careful with API keys if logged elsewhere)
reqJSON, _ := json.Marshal(body)
fmt.Printf("[Gemini] Request Body: %s\n", string(reqJSON))
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), msg)
}
// Parse Gemini response and convert to OpenAI format
@@ -565,6 +582,7 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
GenerationConfig: genConfig,
}
hasMappedTools := false
if len(req.Tools) > 0 {
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
for _, t := range req.Tools {
@@ -572,13 +590,16 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
}
}
body.Tools = []GeminiTool{geminiTool}
if len(geminiTool.FunctionDeclarations) > 0 {
body.Tools = []GeminiTool{geminiTool}
hasMappedTools = true
}
}
baseURL := p.config.BaseURL
lowerModel := strings.ToLower(req.Model)
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
// Use v1beta for preview and newer models
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") || hasMappedTools {
// Use v1beta for preview, newer models, or when using tools
if !strings.Contains(baseURL, "v1beta") {
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
}
@@ -599,19 +620,19 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), msg)
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
err := StreamGemini(resp.RawBody(), ch, req.Model)
if err != nil {
fmt.Printf("Gemini Stream error: %v\n", err)
}
}()
ch, err := StreamGemini(resp.RawBody(), req.Model)
if err != nil {
return nil, fmt.Errorf("gemini stream init error: %w", err)
}
return ch, nil
}
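The endpoint selection used in both Gemini methods can be read as a pure function of the base URL, the model name, and whether any tools were mapped. A sketch under that reading follows. `geminiBaseURL` is a hypothetical helper and the URL is a placeholder, not the real endpoint.

```go
package main

import (
	"fmt"
	"strings"
)

// geminiBaseURL is a hypothetical helper mirroring the condition in the hunks
// above: preview, 3.1, 2.0, and thinking models, plus any request that carries
// mapped tools, are routed to the v1beta endpoint.
func geminiBaseURL(baseURL, model string, hasMappedTools bool) string {
	lower := strings.ToLower(model)
	needsBeta := hasMappedTools
	for _, marker := range []string{"preview", "3.1", "2.0", "thinking"} {
		if strings.Contains(lower, marker) {
			needsBeta = true
		}
	}
	if needsBeta && !strings.Contains(baseURL, "v1beta") {
		// Rewrite only the first "/v1" segment, as the original code does.
		return strings.Replace(baseURL, "/v1", "/v1beta", 1)
	}
	return baseURL
}

func main() {
	base := "https://example.invalid/v1" // placeholder URL for illustration
	fmt.Println(geminiBaseURL(base, "gemini-1.5-pro", false))
	fmt.Println(geminiBaseURL(base, "gemini-1.5-pro", true))
}
```

Extracting the rule this way would also make the tool-triggered switch to v1beta testable without an HTTP client.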
+15 -2
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"time"
"github.com/go-resty/resty/v2"
@@ -48,7 +49,13 @@ func (p *GrokProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRe
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), msg)
}
var respJSON map[string]interface{}
@@ -79,7 +86,13 @@ func (p *GrokProvider) ChatCompletionStream(ctx context.Context, req *models.Uni
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), msg)
}
ch := make(chan *models.ChatCompletionStreamResponse)
+251 -88
@@ -10,11 +10,32 @@ import (
"gophergate/internal/models"
)
func sanitizeFunctionName(name string) string {
var sb strings.Builder
for _, ch := range name {
if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' || ch == '-' {
sb.WriteRune(ch)
} else {
sb.WriteRune('_')
}
}
res := sb.String()
if res == "" {
return "function"
}
return res
}
// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON, including tools and images.
func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) {
var result []interface{}
for _, m := range messages {
if m.Role == "tool" {
role := strings.ToLower(m.Role)
if role == "model" {
role = "assistant"
}
if role == "tool" || role == "function" {
text := ""
if len(m.Content) > 0 {
text = m.Content[0].Text
@@ -23,15 +44,14 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
"role": "tool",
"content": text,
}
id := "unknown"
if m.ToolCallID != nil {
id := *m.ToolCallID
if len(id) > 40 {
id = id[:40]
}
msg["tool_call_id"] = id
id = *m.ToolCallID
}
msg["tool_call_id"] = id
if m.Name != nil {
msg["name"] = *m.Name
msg["name"] = sanitizeFunctionName(*m.Name)
}
result = append(result, msg)
continue
@@ -59,7 +79,9 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
}
var finalContent interface{}
if len(parts) == 1 {
if len(parts) == 0 {
finalContent = nil
} else if len(parts) == 1 {
if p, ok := parts[0].(map[string]interface{}); ok && p["type"] == "text" {
finalContent = p["text"]
} else {
@@ -70,7 +92,7 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
}
msg := map[string]interface{}{
"role": m.Role,
"role": role,
"content": finalContent,
}
@@ -82,20 +104,18 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
sanitizedCalls := make([]models.ToolCall, len(m.ToolCalls))
copy(sanitizedCalls, m.ToolCalls)
for i := range sanitizedCalls {
if len(sanitizedCalls[i].ID) > 40 {
sanitizedCalls[i].ID = sanitizedCalls[i].ID[:40]
if sanitizedCalls[i].Type == "" {
sanitizedCalls[i].Type = "function"
}
sanitizedCalls[i].Function.Name = sanitizeFunctionName(sanitizedCalls[i].Function.Name)
}
msg["tool_calls"] = sanitizedCalls
if len(parts) == 0 {
msg["content"] = ""
}
msg["content"] = "" // OpenAI requirement: content must be string if tool_calls present
}
if m.Name != nil {
msg["name"] = *m.Name
}
result = append(result, msg)
}
return result, nil
@@ -121,11 +141,25 @@ func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{},
body["max_tokens"] = *request.MaxTokens
}
if len(request.Tools) > 0 {
body["tools"] = request.Tools
sanitizedTools := make([]models.Tool, len(request.Tools))
copy(sanitizedTools, request.Tools)
for i := range sanitizedTools {
if sanitizedTools[i].Type == "function" {
sanitizedTools[i].Function.Name = sanitizeFunctionName(sanitizedTools[i].Function.Name)
}
}
body["tools"] = sanitizedTools
}
if request.ToolChoice != nil {
var toolChoice interface{}
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
if tcMap, ok := toolChoice.(map[string]interface{}); ok {
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
if name, ok := funcMap["name"].(string); ok {
funcMap["name"] = sanitizeFunctionName(name)
}
}
}
body["tool_choice"] = toolChoice
}
}
@@ -361,87 +395,216 @@ func StreamOpenAI(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
return scanner.Err()
}
func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error {
defer ctx.Close()
// geminiStreamChunk is the shared data structure for parsing Gemini streaming responses.
type geminiStreamChunk struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text,omitempty"`
Thought string `json:"thought,omitempty"`
} `json:"parts"`
} `json:"content"`
FinishReason string `json:"finishReason"`
} `json:"candidates"`
UsageMetadata struct {
PromptTokenCount uint32 `json:"promptTokenCount"`
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
TotalTokenCount uint32 `json:"totalTokenCount"`
CachedContentTokenCount uint32 `json:"cachedContentTokenCount"`
} `json:"usageMetadata"`
}
dec := json.NewDecoder(ctx)
t, err := dec.Token()
if err != nil {
return err
// emitGeminiChunk builds a ChatCompletionStreamResponse from a parsed geminiStreamChunk
// and sends it to the channel. Returns true if anything was emitted.
func emitGeminiChunk(ch chan<- *models.ChatCompletionStreamResponse, chunk *geminiStreamChunk, model string) bool {
if len(chunk.Candidates) == 0 && chunk.UsageMetadata.TotalTokenCount == 0 {
return false
}
if delim, ok := t.(json.Delim); ok && delim == '[' {
for dec.More() {
var geminiChunk struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text,omitempty"`
Thought string `json:"thought,omitempty"`
} `json:"parts"`
} `json:"content"`
FinishReason string `json:"finishReason"`
} `json:"candidates"`
UsageMetadata struct {
PromptTokenCount uint32 `json:"promptTokenCount"`
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
TotalTokenCount uint32 `json:"totalTokenCount"`
CachedContentTokenCount uint32 `json:"cachedContentTokenCount"`
} `json:"usageMetadata"`
content := ""
var reasoning *string
var finishReason *string
if len(chunk.Candidates) > 0 {
for _, p := range chunk.Candidates[0].Content.Parts {
if p.Text != "" {
content += p.Text
}
if err := dec.Decode(&geminiChunk); err != nil {
return err
if p.Thought != "" {
if reasoning == nil {
reasoning = new(string)
}
*reasoning += p.Thought
}
}
fr := strings.ToLower(chunk.Candidates[0].FinishReason)
finishReason = &fr
}
if len(geminiChunk.Candidates) > 0 || geminiChunk.UsageMetadata.TotalTokenCount > 0 {
content := ""
var reasoning *string
if len(geminiChunk.Candidates) > 0 {
for _, p := range geminiChunk.Candidates[0].Content.Parts {
if p.Text != "" {
content += p.Text
}
if p.Thought != "" {
if reasoning == nil {
reasoning = new(string)
}
*reasoning += p.Thought
}
}
}
ch <- &models.ChatCompletionStreamResponse{
ID: "gemini-stream",
Object: "chat.completion.chunk",
Created: 0,
Model: model,
Choices: []models.ChatStreamChoice{
{
Index: 0,
Delta: models.ChatStreamDelta{
Content: &content,
ReasoningContent: reasoning,
},
FinishReason: finishReason,
},
},
Usage: &models.Usage{
PromptTokens: chunk.UsageMetadata.PromptTokenCount,
CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount,
TotalTokens: chunk.UsageMetadata.TotalTokenCount,
CacheReadTokens: uint32Ptr(chunk.UsageMetadata.CachedContentTokenCount),
},
}
return true
}
var finishReason *string
if len(geminiChunk.Candidates) > 0 {
fr := strings.ToLower(geminiChunk.Candidates[0].FinishReason)
finishReason = &fr
}
// StreamGemini handles Gemini streaming responses in two formats:
// 1. SSE format (newer models): each line is "data: {...}"
// 2. JSON array format (older models): response body is [ {...}, {...} ]
//
// Usage metadata is only present in the final chunk, which we accumulate
// and emit so the server can log it on stream end.
func StreamGemini(ctx io.ReadCloser, model string) (<-chan *models.ChatCompletionStreamResponse, error) {
ch := make(chan *models.ChatCompletionStreamResponse)
ch <- &models.ChatCompletionStreamResponse{
ID: "gemini-stream",
Object: "chat.completion.chunk",
Created: 0,
Model: model,
Choices: []models.ChatStreamChoice{
{
Index: 0,
Delta: models.ChatStreamDelta{
Content: &content,
ReasoningContent: reasoning,
},
FinishReason: finishReason,
},
},
Usage: &models.Usage{
PromptTokens: geminiChunk.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiChunk.UsageMetadata.CandidatesTokenCount,
TotalTokens: geminiChunk.UsageMetadata.TotalTokenCount,
CacheReadTokens: uint32Ptr(geminiChunk.UsageMetadata.CachedContentTokenCount),
},
}
go func() {
defer func() {
_ = ctx.Close()
}()
defer close(ch)
// Peek at up to the first 6 bytes to detect the format
peek := make([]byte, 6)
n, _ := io.ReadAtLeast(ctx, peek, 1)
if n == 0 {
return
}
first := string(peek[:n])
if first[0] == '[' {
// JSON array format
rest, _ := io.ReadAll(ctx)
streamGeminiJSONArray(append([]byte(first), rest...), ch, model)
return
} else if strings.HasPrefix(first, "data:") || strings.HasPrefix(first, "data: ") {
// SSE format: prepend the peeked bytes, then run the SSE scanner
combined := io.MultiReader(
strings.NewReader(string(peek[:n])),
ctx,
)
streamGeminiSSE(combined, ch, model)
} else {
// Unknown format: it may still be SSE with other bytes ahead of the first
// "data:" line, so prepend the peeked bytes and try the SSE scanner
combined := io.MultiReader(
strings.NewReader(string(peek[:n])),
ctx,
)
streamGeminiSSE(combined, ch, model)
}
}()
return ch, nil
}
// readAll reads remaining bytes from a reader (keeps the function signature simple
// for the JSON array fallback path).
func readAll(r io.Reader) []byte {
b, _ := io.ReadAll(r)
return b
}
func streamGeminiJSONArray(data []byte, ch chan<- *models.ChatCompletionStreamResponse, model string) {
var chunks []geminiStreamChunk
if err := json.Unmarshal(data, &chunks); err != nil {
fmt.Printf("[Gemini-Stream] JSON array parse error: %v\n", err)
return
}
// Emit every chunk in order. emitGeminiChunk already forwards usage-only
// chunks, so a separate usage pass would send the final usage chunk twice
// (once ahead of the content and once in sequence).
for i := range chunks {
emitGeminiChunk(ch, &chunks[i], model)
}
}
func streamGeminiSSE(r io.Reader, ch chan<- *models.ChatCompletionStreamResponse, model string) {
scanner := bufio.NewScanner(r)
// Track the last seen usage for emission at end of stream
var lastUsage geminiStreamChunk
for scanner.Scan() {
line := scanner.Text()
if line == "" || !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
// Emit final usage if we have one
if lastUsage.UsageMetadata.TotalTokenCount > 0 {
emitGeminiChunk(ch, &lastUsage, model)
}
break
}
var chunk geminiStreamChunk
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
continue
}
// Capture usage from any chunk (Gemini puts it in the final response)
if chunk.UsageMetadata.TotalTokenCount > 0 {
lastUsage = chunk
}
// Emit content chunks as they arrive
if len(chunk.Candidates) > 0 {
emitGeminiChunk(ch, &chunk, model)
}
}
return nil
// If stream ended without [DONE] marker but we collected usage, emit it
if lastUsage.UsageMetadata.TotalTokenCount > 0 {
emitGeminiChunk(ch, &lastUsage, model)
}
if err := scanner.Err(); err != nil {
fmt.Printf("[Gemini-Stream] SSE scan error: %v\n", err)
}
}
// ensureEnglish injects a system message instructing the model to respond in
// English when no system prompt is already present. Some providers (e.g. DeepSeek)
// default to Chinese for certain prompts.
func ensureEnglish(req *models.UnifiedRequest) {
if len(req.Messages) > 0 && req.Messages[0].Role == "system" {
return // already has a system prompt, don't interfere
}
enMsg := models.UnifiedMessage{
Role: "system",
Content: []models.UnifiedContentPart{
{Type: "text", Text: "You are a helpful assistant. Always respond in English."},
},
}
req.Messages = append([]models.UnifiedMessage{enMsg}, req.Messages...)
}
+127
@@ -0,0 +1,127 @@
package providers
import (
"encoding/json"
"testing"
"gophergate/internal/models"
)
func TestSanitizeFunctionName(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"google-search", "google-search"},
{"google.search", "google_search"},
{"google search", "google_search"},
{"web_search(query)", "web_search_query_"},
{"", "function"},
{"123_abc-XYZ", "123_abc-XYZ"},
{"invalid.name.with.dots", "invalid_name_with_dots"},
}
for _, tc := range tests {
actual := sanitizeFunctionName(tc.input)
if actual != tc.expected {
t.Errorf("sanitizeFunctionName(%q) = %q; expected %q", tc.input, actual, tc.expected)
}
}
}
func TestMessagesToOpenAIJSON_SanitizeToolCalls(t *testing.T) {
messages := []models.UnifiedMessage{
{
Role: "assistant",
Content: []models.UnifiedContentPart{
{Type: "text", Text: "I will use search."},
},
ToolCalls: []models.ToolCall{
{
ID: "call_1",
Type: "function",
Function: models.FunctionCall{
Name: "google.search",
Arguments: `{"query": "hello"}`,
},
},
},
},
{
Role: "tool",
Content: []models.UnifiedContentPart{
{Type: "text", Text: `{"result": "success"}`},
},
ToolCallID: stringPtr("call_1"),
Name: stringPtr("google.search"),
},
}
res, err := MessagesToOpenAIJSON(messages)
if err != nil {
t.Fatalf("MessagesToOpenAIJSON failed: %v", err)
}
if len(res) != 2 {
t.Fatalf("expected 2 messages, got %d", len(res))
}
// Verify assistant message
msg1 := res[0].(map[string]interface{})
if msg1["role"] != "assistant" {
t.Errorf("expected role assistant, got %v", msg1["role"])
}
calls := msg1["tool_calls"].([]models.ToolCall)
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
if calls[0].Function.Name != "google_search" {
t.Errorf("expected function name google_search, got %q", calls[0].Function.Name)
}
// Verify tool response message
msg2 := res[1].(map[string]interface{})
if msg2["role"] != "tool" {
t.Errorf("expected role tool, got %v", msg2["role"])
}
if msg2["name"] != "google_search" {
t.Errorf("expected tool name google_search, got %v", msg2["name"])
}
}
func TestBuildOpenAIBody_SanitizeToolsAndChoice(t *testing.T) {
req := &models.UnifiedRequest{
Model: "gpt-4o",
Tools: []models.Tool{
{
Type: "function",
Function: models.FunctionDef{
Name: "google.search",
},
},
},
ToolChoice: json.RawMessage(`{"type": "function", "function": {"name": "google.search"}}`),
}
body := BuildOpenAIBody(req, nil, false)
// Verify tools
tools := body["tools"].([]models.Tool)
if len(tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(tools))
}
if tools[0].Function.Name != "google_search" {
t.Errorf("expected tool function name google_search, got %q", tools[0].Function.Name)
}
// Verify tool_choice
toolChoice := body["tool_choice"].(map[string]interface{})
funcObj := toolChoice["function"].(map[string]interface{})
if funcObj["name"] != "google_search" {
t.Errorf("expected tool_choice function name google_search, got %q", funcObj["name"])
}
}
func stringPtr(s string) *string {
return &s
}
+15 -2
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"time"
@@ -59,7 +60,13 @@ func (p *MoonshotProvider) ChatCompletion(ctx context.Context, req *models.Unifi
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), msg)
}
var respJSON map[string]interface{}
@@ -100,7 +107,13 @@ func (p *MoonshotProvider) ChatCompletionStream(ctx context.Context, req *models
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), msg)
}
ch := make(chan *models.ChatCompletionStreamResponse)
+14 -2
@@ -56,7 +56,13 @@ func (p *OllamaProvider) ChatCompletion(ctx context.Context, req *models.Unified
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), msg)
}
var respJSON map[string]interface{}
@@ -86,7 +92,13 @@ func (p *OllamaProvider) ChatCompletionStream(ctx context.Context, req *models.U
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), msg)
}
ch := make(chan *models.ChatCompletionStreamResponse)
+57 -3
@@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"strings"
"time"
@@ -38,6 +40,17 @@ func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *models.Unified
body := BuildOpenAIBody(req, messagesJSON, false)
// Debug message sequence
for i, m := range messagesJSON {
mMap, _ := m.(map[string]interface{})
role, _ := mMap["role"].(string)
hasToolCalls := false
if tc, ok := mMap["tool_calls"]; ok && tc != nil {
hasToolCalls = true
}
log.Printf("[DEBUG] OpenAI Msg[%d]: role=%s, hasToolCalls=%v", i, role, hasToolCalls)
}
// Transition: Newer models require max_completion_tokens
if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
if maxTokens, ok := body["max_tokens"]; ok {
@@ -57,7 +70,17 @@ func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *models.Unified
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if b := resp.Body(); len(b) > 0 {
msg = string(b)
}
}
// Log the request body for debugging
reqJSON, _ := json.Marshal(body)
log.Printf("OpenAI API Error (%d): %s", resp.StatusCode(), msg)
log.Printf("OpenAI request body: %s", string(reqJSON))
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), msg)
}
var respJSON map[string]interface{}
@@ -104,7 +127,13 @@ func (p *OpenAIProvider) ImageGeneration(ctx context.Context, req *models.ImageG
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("OpenAI image API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("OpenAI image API error (%d): %s", resp.StatusCode(), msg)
}
var result models.ImageGenerationResponse
@@ -123,6 +152,17 @@ func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, req *models.U
body := BuildOpenAIBody(req, messagesJSON, true)
// Debug message sequence
for i, m := range messagesJSON {
mMap, _ := m.(map[string]interface{})
role, _ := mMap["role"].(string)
hasToolCalls := false
if tc, ok := mMap["tool_calls"]; ok && tc != nil {
hasToolCalls = true
}
log.Printf("[DEBUG] OpenAI Stream Msg[%d]: role=%s, hasToolCalls=%v", i, role, hasToolCalls)
}
// Transition: Newer models require max_completion_tokens
if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
if maxTokens, ok := body["max_tokens"]; ok {
@@ -143,7 +183,21 @@ func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, req *models.U
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if b := resp.Body(); len(b) > 0 {
msg = string(b)
}
if msg == "" {
if b, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(b)
}
}
}
reqJSON, _ := json.Marshal(body)
log.Printf("OpenAI API Error (%d): %s", resp.StatusCode(), msg)
log.Printf("OpenAI request body: %s", string(reqJSON))
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), msg)
}
ch := make(chan *models.ChatCompletionStreamResponse)
+15 -2
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"gophergate/internal/models"
)
@@ -26,7 +27,13 @@ func (p *OpenAIProvider) Responses(ctx context.Context, req *models.ResponsesReq
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("OpenAI Responses API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("OpenAI Responses API error (%d): %s", resp.StatusCode(), msg)
}
var respJSON map[string]interface{}
@@ -53,7 +60,13 @@ func (p *OpenAIProvider) ResponsesStream(ctx context.Context, req *models.Respon
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("OpenAI Responses API error (%d): %s", resp.StatusCode(), resp.String())
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("OpenAI Responses API error (%d): %s", resp.StatusCode(), msg)
}
ch := make(chan *models.ResponsesStreamChunk)
+133
@@ -0,0 +1,133 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"time"
"github.com/go-resty/resty/v2"
"gophergate/internal/config"
"gophergate/internal/models"
)
type XiaomiProvider struct {
client *resty.Client
config config.XiaomiConfig
apiKey string
}
func NewXiaomiProvider(cfg config.XiaomiConfig, apiKey string) *XiaomiProvider {
return &XiaomiProvider{
client: resty.New().SetTimeout(10 * time.Minute),
config: cfg,
apiKey: strings.TrimSpace(apiKey),
}
}
func (p *XiaomiProvider) Name() string {
return "xiaomi"
}
func (p *XiaomiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
}
body := BuildOpenAIBody(req, messagesJSON, false)
baseURL := strings.TrimRight(p.config.BaseURL, "/")
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetHeader("Content-Type", "application/json").
SetHeader("Accept", "application/json").
SetBody(body).
Post(fmt.Sprintf("%s/chat/completions", baseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
msg := resp.String()
if msg == "" {
if b := resp.Body(); len(b) > 0 {
msg = string(b)
}
}
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Xiaomi API error (%d): %s", resp.StatusCode(), msg)
}
var respJSON map[string]interface{}
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return ParseOpenAIResponse(respJSON, req.Model)
}
func (p *XiaomiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
}
body := BuildOpenAIBody(req, messagesJSON, true)
baseURL := strings.TrimRight(p.config.BaseURL, "/")
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetHeader("Content-Type", "application/json").
SetHeader("Accept", "text/event-stream").
SetBody(body).
SetDoNotParseResponse(true).
Post(fmt.Sprintf("%s/chat/completions", baseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
msg := resp.String()
if msg == "" {
if body, err := io.ReadAll(resp.RawBody()); err == nil {
msg = string(body)
}
}
return nil, fmt.Errorf("Xiaomi API error (%d): %s", resp.StatusCode(), msg)
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
if err := StreamOpenAI(resp.RawBody(), ch); err != nil {
fmt.Printf("Xiaomi Stream error: %v\n", err)
}
}()
return ch, nil
}
func (p *XiaomiProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
return nil, fmt.Errorf("xiaomi does not support image generation")
}
func (p *XiaomiProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
return nil, fmt.Errorf("responses API not supported by xiaomi")
}
func (p *XiaomiProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
return nil, fmt.Errorf("responses API not supported by xiaomi")
}
+76
@@ -0,0 +1,76 @@
package router
import (
"context"
"fmt"
"strconv"
"strings"
"gophergate/internal/db"
)
const classifierSystemPrompt = `You are a task complexity classifier. Rate the following user message on a scale of 1 to %d, where:
1 = trivial/simple (basic facts, greetings, simple math)
%d = highly complex (multi-step reasoning, code generation, architecture design)
Reply with ONLY the number. No explanation.`
func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, routeCtx *RouteContext) (*Decision, error) {
// Determine the rating scale
maxRating := len(targets)
if maxRating < 2 {
maxRating = 2
}
// When complexity_threshold is set, use it as a wider scale (e.g., 1-10)
// and map ratings proportionally to target buckets.
bucketMode := group.ComplexityThreshold != nil && *group.ComplexityThreshold > 0
if bucketMode {
maxRating = *group.ComplexityThreshold
}
prompt := fmt.Sprintf(classifierSystemPrompt, maxRating, maxRating)
userMsg := ""
if routeCtx != nil {
userMsg = routeCtx.UserMessage
}
ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMsg)
if err != nil {
// Classifier failed — fall back to heuristic
return routeHeuristic(group, targets, routeCtx)
}
rating, err := strconv.Atoi(strings.TrimSpace(ratingStr))
if err != nil || rating < 1 {
rating = 1
}
if rating > maxRating {
rating = maxRating
}
var idx int
if bucketMode {
// Proportional mapping: wider scale → N target buckets
// e.g., threshold=10, 3 targets: 1-3→0, 4-7→1, 8-10→2
idx = rating * len(targets) / (maxRating + 1)
} else {
idx = rating - 1 // 1:1 mapping
}
if idx >= len(targets) { // clamp in both modes: a single-target group can still be rated 2
idx = len(targets) - 1
}
return &Decision{
SelectedModel: targets[idx],
Strategy: "classifier",
Reason: fmt.Sprintf("complexity rating: %d/%d", rating, maxRating),
}, nil
}
func getSelectorModel(group db.ModelGroup, targets []string) string {
if group.SelectorModel != nil && *group.SelectorModel != "" {
return *group.SelectorModel
}
// Default: use the first (cheapest) target model as the selector
return targets[0]
}
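The proportional mapping used in bucket mode can be checked in isolation. The sketch below copies the formula from `routeClassifier` into a hypothetical standalone function, `bucketIndex` (a name introduced here for illustration only), and prints the bucket chosen for each rating on a 1-10 scale with three targets:

```go
package main

import "fmt"

// bucketIndex is a hypothetical extraction of the bucket-mode formula:
// a rating on a 1..maxRating scale is mapped onto numTargets buckets.
func bucketIndex(rating, maxRating, numTargets int) int {
	idx := rating * numTargets / (maxRating + 1)
	if idx >= numTargets {
		idx = numTargets - 1
	}
	return idx
}

func main() {
	for rating := 1; rating <= 10; rating++ {
		fmt.Printf("rating %2d -> target index %d\n", rating, bucketIndex(rating, 10, 3))
	}
}
```

Dividing by `maxRating + 1` rather than `maxRating` keeps the top rating inside the last bucket without relying on the clamp; with a threshold of 10 and three targets, ratings 1-3 land on index 0, 4-7 on index 1, and 8-10 on index 2.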
+219
@@ -0,0 +1,219 @@
package router
import (
"encoding/json"
"regexp"
"strings"
"gophergate/internal/db"
)
// HeuristicRule defines a pattern-based routing rule (legacy format).
type HeuristicRule struct {
Pattern string `json:"pattern"`
TargetIdx int `json:"target"`
CaseSensitive bool `json:"case_sensitive,omitempty"`
}
// ConditionRule defines a condition-based routing rule (new format).
type ConditionRule struct {
RuleID string `json:"rule_id"`
Description string `json:"description,omitempty"`
Conditions Conditions `json:"conditions"`
PrimaryModel string `json:"primary_model"`
FallbackModel string `json:"fallback_model,omitempty"`
}
// Conditions defines the matching parameters for a rule.
type Conditions struct {
AnyOfTags []string `json:"any_of_tags,omitempty"`
MaxInputTokensLt *int `json:"max_input_tokens_lt,omitempty"`
RequiresReasoning *bool `json:"requires_reasoning,omitempty"`
RequiresToolCalling *bool `json:"requires_tool_calling,omitempty"`
HasMultimodalInput *bool `json:"has_multimodal_input,omitempty"`
IsDefaultFallback *bool `json:"is_default_fallback,omitempty"`
}
func routeHeuristic(group db.ModelGroup, targets []string, routeCtx *RouteContext) (*Decision, error) {
if routeCtx == nil {
routeCtx = &RouteContext{}
}
selected := targets[0]
reason := "default (first target)"
// If heuristic_rules is set, determine format and parse
if group.HeuristicRules != nil && *group.HeuristicRules != "" {
rulesJSON := *group.HeuristicRules
if isConditionBasedRules(rulesJSON) {
var condRules []ConditionRule
if err := json.Unmarshal([]byte(rulesJSON), &condRules); err == nil {
for _, rule := range condRules {
if matchConditions(rule.Conditions, routeCtx) {
// Resolve primary/fallback to concrete models in target list
targetModel := ""
if rule.PrimaryModel != "" {
targetModel = getModelInTargets(rule.PrimaryModel, targets)
}
if targetModel == "" && rule.FallbackModel != "" {
targetModel = getModelInTargets(rule.FallbackModel, targets)
}
if targetModel != "" {
selected = targetModel
reason = "matched condition rule: " + rule.RuleID
if rule.Description != "" {
reason += " (" + rule.Description + ")"
}
break
}
}
}
}
} else {
// Fall back to legacy pattern-based rules
var legacyRules []HeuristicRule
if err := json.Unmarshal([]byte(rulesJSON), &legacyRules); err == nil {
searchMsg := routeCtx.UserMessage
for _, rule := range legacyRules {
pattern := rule.Pattern
if pattern == "" {
continue // Skip empty patterns, which would match every message
}
msg := searchMsg
if !rule.CaseSensitive {
pattern = strings.ToLower(pattern)
msg = strings.ToLower(msg)
}
// Try regex matching when the pattern contains "(" or "\b"; otherwise, or if the regex does not match, fall back to a literal substring check
matched := false
if strings.Contains(rule.Pattern, "(") || strings.Contains(rule.Pattern, "\\b") {
var re *regexp.Regexp
var err error
if !rule.CaseSensitive {
re, err = regexp.Compile("(?i)" + rule.Pattern)
} else {
re, err = regexp.Compile(rule.Pattern)
}
if err == nil {
matched = re.MatchString(routeCtx.UserMessage)
}
}
if !matched && strings.Contains(msg, pattern) {
matched = true
}
if matched {
if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) {
selected = targets[rule.TargetIdx]
reason = "matched heuristic rule: " + rule.Pattern
break
}
}
}
}
}
}
// Built-in fallback heuristics (if no custom rule matched)
if reason == "default (first target)" && len(targets) > 1 {
msgLower := strings.ToLower(routeCtx.UserMessage)
complexIndicators := []string{
"step by step", "explain in detail", "reason through",
"think carefully", "analyze", "debug", "write code",
"implement", "refactor", "architecture",
}
for _, indicator := range complexIndicators {
if strings.Contains(msgLower, indicator) {
selected = targets[len(targets)-1]
reason = "complex task indicator: " + indicator
break
}
}
}
return &Decision{
SelectedModel: selected,
Strategy: "heuristic",
Reason: reason,
}, nil
}
// isConditionBasedRules returns true if the JSON represents condition-based rules.
func isConditionBasedRules(rulesJSON string) bool {
var rules []ConditionRule
if err := json.Unmarshal([]byte(rulesJSON), &rules); err == nil && len(rules) > 0 {
// Treat the rules as condition-based when the first rule sets primary_model or rule_id
return rules[0].PrimaryModel != "" || rules[0].RuleID != ""
}
return false
}
// matchConditions evaluates whether the given conditions match the RouteContext.
func matchConditions(cond Conditions, routeCtx *RouteContext) bool {
if cond.IsDefaultFallback != nil && *cond.IsDefaultFallback {
return true
}
// Check tags: must match any_of_tags if specified
if len(cond.AnyOfTags) > 0 {
tagMatched := false
for _, ruleTag := range cond.AnyOfTags {
for _, ctxTag := range routeCtx.Tags {
if strings.EqualFold(ruleTag, ctxTag) {
tagMatched = true
break
}
}
if tagMatched {
break
}
}
if !tagMatched {
return false
}
}
// Check max input tokens
if cond.MaxInputTokensLt != nil {
if routeCtx.InputTokens >= *cond.MaxInputTokensLt {
return false
}
}
// Check reasoning flag
if cond.RequiresReasoning != nil {
if routeCtx.RequiresReasoning != *cond.RequiresReasoning {
return false
}
}
// Check tool calling flag
if cond.RequiresToolCalling != nil {
if routeCtx.RequiresToolCalling != *cond.RequiresToolCalling {
return false
}
}
// Check multimodal flag
if cond.HasMultimodalInput != nil {
if routeCtx.HasMultimodalInput != *cond.HasMultimodalInput {
return false
}
}
return true
}
// getModelInTargets returns the model name if it exists in targets, or empty string.
func getModelInTargets(modelName string, targets []string) string {
for _, t := range targets {
if strings.EqualFold(t, modelName) {
return t
}
}
return ""
}
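The `Conditions` struct uses pointer fields so that an omitted key ("not specified") can be told apart from an explicit `false` or `0`. A minimal sketch of that idea, assuming a trimmed hypothetical type `cond` with only two of the fields and hypothetical helpers `parseCond` and `matches`:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// cond is a trimmed, hypothetical copy of Conditions with two optional fields.
type cond struct {
	MaxInputTokensLt  *int  `json:"max_input_tokens_lt,omitempty"`
	RequiresReasoning *bool `json:"requires_reasoning,omitempty"`
}

// parseCond decodes one conditions object; keys absent from the JSON stay nil.
func parseCond(raw string) (cond, error) {
	var c cond
	err := json.Unmarshal([]byte(raw), &c)
	return c, err
}

// matches treats a nil pointer as "not specified" and skips that check.
func matches(c cond, inputTokens int, requiresReasoning bool) bool {
	if c.MaxInputTokensLt != nil && inputTokens >= *c.MaxInputTokensLt {
		return false
	}
	if c.RequiresReasoning != nil && requiresReasoning != *c.RequiresReasoning {
		return false
	}
	return true
}

func main() {
	explicit, _ := parseCond(`{"max_input_tokens_lt": 8000, "requires_reasoning": false}`)
	open, _ := parseCond(`{}`)

	fmt.Println(matches(explicit, 500, false)) // both checks pass
	fmt.Println(matches(explicit, 500, true))  // explicit false rejects reasoning requests
	fmt.Println(matches(open, 50000, true))    // nothing specified, nothing rejected
}
```

One consequence worth keeping in mind when writing rules: an empty conditions object matches every request, so it behaves like a catch-all even without `is_default_fallback`.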
+142
@@ -0,0 +1,142 @@
package router
import (
"testing"
"gophergate/internal/db"
)
func TestRouteHeuristic_ConditionRules(t *testing.T) {
targets := []string{
"deepseek-v4-flash", // index 0
"gemini-3-flash", // index 1
"grok-build-0.1", // index 2
"kimi-k2.6", // index 3
"mimo-v2.5-pro", // index 4
"grok-4.3", // index 5
"deepseek-v4-pro", // index 6
}
rulesJSON := `[
{
"rule_id": "fast_flow_extraction",
"conditions": {
"any_of_tags": ["fast-flow", "classification"],
"max_input_tokens_lt": 8000,
"requires_reasoning": false
},
"primary_model": "deepseek-v4-flash",
"fallback_model": "grok-build-0.1"
},
{
"rule_id": "multimodal_long_context",
"conditions": {
"any_of_tags": ["standard-pro", "long-doc"],
"has_multimodal_input": true
},
"primary_model": "gemini-3-flash",
"fallback_model": "mimo-v2.5-pro"
},
{
"rule_id": "regional_fallback_general",
"conditions": {
"is_default_fallback": true
},
"primary_model": "kimi-k2.6"
}
]`
group := db.ModelGroup{
ID: "dustins_stack",
Strategy: "heuristic",
HeuristicRules: &rulesJSON,
}
// 1. Test Match Fast Flow (condition success)
ctx1 := &RouteContext{
UserMessage: "classify this JSON",
InputTokens: 500,
HasMultimodalInput: false,
RequiresReasoning: false,
Tags: []string{"fast-flow", "classification"},
}
dec1, err := routeHeuristic(group, targets, ctx1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if dec1.SelectedModel != "deepseek-v4-flash" {
t.Fatalf("expected deepseek-v4-flash, got %s", dec1.SelectedModel)
}
// 2. Test Multimodal Long Context (condition success)
ctx2 := &RouteContext{
UserMessage: "explain this video",
InputTokens: 15000,
HasMultimodalInput: true,
RequiresReasoning: false,
Tags: []string{"standard-pro", "video-analysis"},
}
dec2, err := routeHeuristic(group, targets, ctx2)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if dec2.SelectedModel != "gemini-3-flash" {
t.Fatalf("expected gemini-3-flash, got %s", dec2.SelectedModel)
}
// 3. Test Fallback general rule
ctx3 := &RouteContext{
UserMessage: "hello there",
InputTokens: 100,
HasMultimodalInput: false,
RequiresReasoning: false,
Tags: []string{"general"},
}
dec3, err := routeHeuristic(group, targets, ctx3)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if dec3.SelectedModel != "kimi-k2.6" {
t.Fatalf("expected kimi-k2.6, got %s", dec3.SelectedModel)
}
}
func TestRouteHeuristic_LegacyRules(t *testing.T) {
targets := []string{"gpt-4o-mini", "deepseek-v4-pro", "kimi-k2.6"}
// Legacy pattern-based rule with regex
rulesJSON := `[
{"pattern": "\\b(agent|agents|tool use)\\b", "target": 1},
{"pattern": "summarize", "target": 2}
]`
group := db.ModelGroup{
ID: "heavy-logic",
Strategy: "heuristic",
HeuristicRules: &rulesJSON,
}
// 1. Test regex match
ctx1 := &RouteContext{
UserMessage: "We need an agent to do tool use",
}
dec1, err := routeHeuristic(group, targets, ctx1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if dec1.SelectedModel != "deepseek-v4-pro" {
t.Fatalf("expected deepseek-v4-pro, got %s", dec1.SelectedModel)
}
// 2. Test literal match
ctx2 := &RouteContext{
UserMessage: "Please summarize this text",
}
dec2, err := routeHeuristic(group, targets, ctx2)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if dec2.SelectedModel != "kimi-k2.6" {
t.Fatalf("expected kimi-k2.6, got %s", dec2.SelectedModel)
}
}
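The two legacy test cases exercise different paths: the first pattern contains `(` and `\b`, so it is tried as a regular expression, while `summarize` goes through the literal substring check. The sketch below condenses that decision into a hypothetical function, `matchLegacy`, which is an illustration and not the code under test:

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// matchLegacy is a hypothetical condensation of the legacy rule check.
// Patterns containing "(" or `\b` are tried as regular expressions first;
// every non-empty pattern also gets a literal substring check.
func matchLegacy(pattern, msg string, caseSensitive bool) bool {
	if pattern == "" {
		return false // an empty pattern would otherwise match every message
	}
	if strings.Contains(pattern, "(") || strings.Contains(pattern, `\b`) {
		expr := pattern
		if !caseSensitive {
			expr = "(?i)" + pattern
		}
		// An invalid expression is ignored and the literal check still runs.
		if re, err := regexp.Compile(expr); err == nil && re.MatchString(msg) {
			return true
		}
	}
	if !caseSensitive {
		pattern, msg = strings.ToLower(pattern), strings.ToLower(msg)
	}
	return strings.Contains(msg, pattern)
}

func main() {
	fmt.Println(matchLegacy(`\b(agent|agents|tool use)\b`, "We need an agent to do tool use", false))
	fmt.Println(matchLegacy("summarize", "Please summarize this text", false))
	fmt.Println(matchLegacy(`\bcat\b`, "concatenate", false))
}
```

The word-boundary pattern in the last call does not match inside `concatenate`, which is the behavior a plain substring check could not provide.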
+139
@@ -0,0 +1,139 @@
package router
import (
"context"
"encoding/json"
"fmt"
"strings"
"gophergate/internal/db"
)
// Decision holds the result of a routing decision.
type Decision struct {
SelectedModel string `json:"selected_model"`
Strategy string `json:"strategy"` // "heuristic", "classifier", or "hierarchical"
Reason string `json:"reason"`
}
// RouteContext holds metadata of the request to evaluate condition rules.
type RouteContext struct {
UserMessage string `json:"user_message"`
InputTokens int `json:"input_tokens"`
HasMultimodalInput bool `json:"has_multimodal_input"`
RequiresToolCalling bool `json:"requires_tool_calling"`
RequiresReasoning bool `json:"requires_reasoning"`
Tags []string `json:"tags"`
}
// ClassifierFunc is the callback for classifier-based routing.
type ClassifierFunc func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error)
// Router resolves model groups to concrete models.
type Router struct {
groups map[string]db.ModelGroup
classify ClassifierFunc
}
// New creates a Router. classify may be nil if no classifier groups exist.
func New(groups []db.ModelGroup, classify ClassifierFunc) *Router {
r := &Router{
groups: make(map[string]db.ModelGroup),
classify: classify,
}
for _, g := range groups {
r.groups[g.ID] = g
}
return r
}
// Groups returns all registered model group IDs.
func (r *Router) Groups() []string {
ids := make([]string, 0, len(r.groups))
for id := range r.groups {
ids = append(ids, id)
}
return ids
}
// IsGroup returns true if the model name is a group ID.
func (r *Router) IsGroup(modelID string) bool {
_, ok := r.groups[modelID]
return ok
}
// Route resolves a group to a concrete model.
func (r *Router) Route(ctx context.Context, groupID string, routeCtx *RouteContext) (*Decision, error) {
group, ok := r.groups[groupID]
if !ok {
return nil, fmt.Errorf("unknown model group: %s", groupID)
}
var targets []string
if err := json.Unmarshal([]byte(group.Targets), &targets); err != nil || len(targets) == 0 {
return nil, fmt.Errorf("invalid or empty targets for group %s", groupID)
}
switch group.Strategy {
case "heuristic":
return routeHeuristic(group, targets, routeCtx)
case "classifier":
if r.classify == nil {
return routeHeuristic(group, targets, routeCtx)
}
return routeClassifier(ctx, r.classify, group, targets, routeCtx)
default:
return nil, fmt.Errorf("unknown strategy: %s", group.Strategy)
}
}
// RouteToConcrete resolves a model name to a concrete model, following group
// chains recursively until a non-group target is reached. Returns the original
// name unchanged if it is not a group.
func (r *Router) RouteToConcrete(ctx context.Context, modelID string, routeCtx *RouteContext) (*Decision, error) {
const maxDepth = 10
visited := make(map[string]bool)
current := modelID
var chain []*Decision
for depth := 0; depth < maxDepth; depth++ {
if !r.IsGroup(current) {
// Build a composite reason showing the chain traversed
reason := "direct"
if len(chain) > 0 {
parts := make([]string, len(chain))
for i, d := range chain {
parts[i] = d.SelectedModel + " (" + d.Reason + ")"
}
reason = strings.Join(parts, " -> ")
}
return &Decision{
SelectedModel: current,
Strategy: "hierarchical",
Reason: reason,
}, nil
}
if visited[current] {
return nil, fmt.Errorf("routing cycle detected: group %s already visited", current)
}
visited[current] = true
decision, err := r.Route(ctx, current, routeCtx)
if err != nil {
return nil, err
}
chain = append(chain, decision)
current = decision.SelectedModel
}
return nil, fmt.Errorf("routing depth exceeded: reached max depth of %d", maxDepth)
}
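`RouteToConcrete` combines two guards: a visited set that catches cycles and a depth limit that bounds long chains. A minimal sketch of the same loop, assuming a hypothetical `resolve` function in which a plain map stands in for the per-group routing decision:

```go
package main

import (
	"errors"
	"fmt"
)

// resolve follows group -> target links until it reaches a name that is not
// a group. next is a hypothetical stand-in for Route: each key is a group ID
// and each value is the target that group would select.
func resolve(start string, next map[string]string, maxDepth int) (string, error) {
	visited := make(map[string]bool)
	current := start
	for depth := 0; depth < maxDepth; depth++ {
		target, isGroup := next[current]
		if !isGroup {
			return current, nil // concrete model reached
		}
		if visited[current] {
			return "", fmt.Errorf("routing cycle detected at %s", current)
		}
		visited[current] = true
		current = target
	}
	return "", errors.New("routing depth exceeded")
}

func main() {
	chain := map[string]string{"auto": "fast", "fast": "model-a"}
	fmt.Println(resolve("auto", chain, 10))

	loop := map[string]string{"a": "b", "b": "a"}
	fmt.Println(resolve("a", loop, 10))
}
```

The depth limit is mostly a backstop: any cycle is caught by the visited set once it closes, so the limit only triggers for acyclic chains longer than `maxDepth`.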
// Reload replaces the group definitions without recreating the router.
func (r *Router) Reload(groups []db.ModelGroup) {
r.groups = make(map[string]db.ModelGroup)
for _, g := range groups {
r.groups[g.ID] = g
}
}
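`Reload` replaces the group map without any synchronization. Whether that is a problem depends on whether routing can run concurrently with admin updates, which this diff does not show conclusively. If it can, one option is to guard the map with a `sync.RWMutex`. The sketch below is an assumption-laden illustration using a hypothetical `safeGroups` type with a simplified value type, not a drop-in patch for `Router`:

```go
package main

import (
	"fmt"
	"sync"
)

// safeGroups is a hypothetical wrapper showing one way to guard a group map
// that is read by request handlers and replaced by admin updates.
type safeGroups struct {
	mu     sync.RWMutex
	groups map[string]string // group ID -> strategy (simplified for the sketch)
}

// Reload builds the new map first, then swaps it in under the write lock so
// readers are blocked only for the pointer assignment.
func (s *safeGroups) Reload(groups map[string]string) {
	fresh := make(map[string]string, len(groups))
	for id, strategy := range groups {
		fresh[id] = strategy
	}
	s.mu.Lock()
	s.groups = fresh
	s.mu.Unlock()
}

// IsGroup takes the read lock, so many lookups can proceed in parallel.
func (s *safeGroups) IsGroup(id string) bool {
	s.mu.RLock()
	defer s.mu.RUnlock()
	_, ok := s.groups[id]
	return ok
}

func main() {
	s := &safeGroups{}
	s.Reload(map[string]string{"auto": "heuristic"})
	fmt.Println(s.IsGroup("auto"), s.IsGroup("missing"))
}
```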
+1
@@ -12,6 +12,7 @@ type RequestLog struct {
ClientID string `json:"client_id"`
Provider string `json:"provider"`
Model string `json:"model"`
ModelGroup string `json:"model_group,omitempty"`
PromptTokens uint32 `json:"prompt_tokens"`
CompletionTokens uint32 `json:"completion_tokens"`
ReasoningTokens uint32 `json:"reasoning_tokens"`
+76
@@ -0,0 +1,76 @@
package server
import (
"net/http"
"gophergate/internal/db"
"github.com/gin-gonic/gin"
)
func (s *Server) handleGetModelGroups(c *gin.Context) {
var groups []db.ModelGroup
if err := s.database.Select(&groups, "SELECT * FROM model_groups ORDER BY id"); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if groups == nil {
groups = []db.ModelGroup{}
}
c.JSON(http.StatusOK, SuccessResponse(groups))
}
func (s *Server) handleCreateModelGroup(c *gin.Context) {
var group db.ModelGroup
if err := c.ShouldBindJSON(&group); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
_, err := s.database.Exec(`
INSERT INTO model_groups (id, strategy, selector_model, targets, complexity_threshold, heuristic_rules, logic_level, primary_use)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
group.ID, group.Strategy, group.SelectorModel, group.Targets,
group.ComplexityThreshold, group.HeuristicRules, group.LogicLevel, group.PrimaryUse)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
s.refreshRouter()
c.JSON(http.StatusCreated, SuccessResponse(group))
}
func (s *Server) handleUpdateModelGroup(c *gin.Context) {
id := c.Param("id")
var group db.ModelGroup
if err := c.ShouldBindJSON(&group); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
_, err := s.database.Exec(`
UPDATE model_groups SET strategy=?, selector_model=?, targets=?, complexity_threshold=?, heuristic_rules=?, logic_level=?, primary_use=?, updated_at=CURRENT_TIMESTAMP
WHERE id=?`,
group.Strategy, group.SelectorModel, group.Targets,
group.ComplexityThreshold, group.HeuristicRules, group.LogicLevel, group.PrimaryUse, id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
s.refreshRouter()
c.JSON(http.StatusOK, SuccessResponse(group))
}
func (s *Server) handleDeleteModelGroup(c *gin.Context) {
id := c.Param("id")
_, err := s.database.Exec("DELETE FROM model_groups WHERE id=?", id)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
s.refreshRouter()
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
}
+1
@@ -19,6 +19,7 @@ func (s *Server) handleGetModels(c *gin.Context) {
"deepseek": "deepseek",
"xai": "grok",
"ollama": "ollama",
"xiaomi": "xiaomi",
}
// Merge registry models with DB overrides
+10 -1
@@ -25,7 +25,7 @@ func (s *Server) handleGetProviders(c *gin.Context) {
dbMap[cfg.ID] = cfg
}
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama"}
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama", "xiaomi"}
var result []gin.H
for _, id := range providerIDs {
@@ -54,6 +54,10 @@ func (s *Server) handleGetProviders(c *gin.Context) {
name = "xAI Grok"
enabled = s.cfg.Providers.Grok.Enabled
baseURL = s.cfg.Providers.Grok.BaseURL
case "xiaomi":
name = "Xiaomi MiMo"
enabled = s.cfg.Providers.Xiaomi.Enabled
baseURL = s.cfg.Providers.Xiaomi.BaseURL
case "ollama":
name = "Ollama"
enabled = s.cfg.Providers.Ollama.Enabled
@@ -109,6 +113,9 @@ func (s *Server) handleGetProviders(c *gin.Context) {
if id == "grok" {
registryID = "xai"
}
if id == "xiaomi" {
registryID = "xiaomi"
}
if pInfo, ok := s.registry.Providers[registryID]; ok {
for mID := range pInfo.Models {
@@ -226,6 +233,8 @@ func (s *Server) handleTestProvider(c *gin.Context) {
testReq.Model = "kimi-k2.5"
} else if name == "grok" {
testReq.Model = "grok-4-1-fast-non-reasoning"
} else if name == "xiaomi" {
testReq.Model = "mimo-v2.5"
}
_, err := provider.ChatCompletion(c.Request.Context(), testReq)
+362 -59
@@ -2,6 +2,7 @@ package server
import (
"encoding/json"
"context"
"fmt"
"io"
"log"
@@ -15,6 +16,7 @@ import (
"gophergate/internal/middleware"
"gophergate/internal/models"
"gophergate/internal/providers"
"gophergate/internal/router"
"gophergate/internal/utils"
"github.com/gin-gonic/gin"
@@ -30,6 +32,7 @@ type Server struct {
logger *RequestLogger
registry *models.ModelRegistry
registryMu sync.RWMutex
modelRouter *router.Router
}
func NewServer(cfg *config.Config, database *db.DB) *Server {
@@ -64,6 +67,9 @@ func NewServer(cfg *config.Config, database *db.DB) *Server {
}
s.setupRoutes()
// Initialize model group router
s.refreshRouter()
return s
}
@@ -79,7 +85,7 @@ func (s *Server) RefreshProviders() error {
dbMap[cfg.ID] = cfg
}
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama"}
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama", "xiaomi"}
for _, id := range providerIDs {
// Default values from config
enabled := false
@@ -107,6 +113,10 @@ func (s *Server) RefreshProviders() error {
enabled = s.cfg.Providers.Grok.Enabled
baseURL = s.cfg.Providers.Grok.BaseURL
apiKey, _ = s.cfg.GetAPIKey("grok")
case "xiaomi":
enabled = s.cfg.Providers.Xiaomi.Enabled
baseURL = s.cfg.Providers.Xiaomi.BaseURL
apiKey, _ = s.cfg.GetAPIKey("xiaomi")
}
// Overrides from DB
@@ -161,6 +171,10 @@ func (s *Server) RefreshProviders() error {
cfg := s.cfg.Providers.Ollama
cfg.BaseURL = baseURL
p = providers.NewOllamaProvider(cfg)
case "xiaomi":
cfg := s.cfg.Providers.Xiaomi
cfg.BaseURL = baseURL
p = providers.NewXiaomiProvider(cfg, apiKey)
}
if p != nil {
@@ -168,9 +182,53 @@ func (s *Server) RefreshProviders() error {
}
}
s.refreshRouter()
return nil
}
func (s *Server) refreshRouter() {
var groups []db.ModelGroup
if err := s.database.Select(&groups, "SELECT * FROM model_groups"); err != nil {
fmt.Printf("Warning: Failed to load model groups: %v\n", err)
groups = nil
}
var classifyFn router.ClassifierFunc
classifyFn = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) {
provider, _, err := s.selectProvider(selectorModel)
if err != nil {
return "", err
}
req := &models.UnifiedRequest{
Model: selectorModel,
Messages: []models.UnifiedMessage{
{Role: "system", Content: []models.UnifiedContentPart{{Type: "text", Text: systemPrompt}}},
{Role: "user", Content: []models.UnifiedContentPart{{Type: "text", Text: userMessage}}},
},
MaxTokens: uint32Ptr(5),
Stream: false,
}
resp, err := provider.ChatCompletion(ctx, req)
if err != nil {
return "", err
}
if len(resp.Choices) == 0 {
return "", fmt.Errorf("no choices in classifier response")
}
content, ok := resp.Choices[0].Message.Content.(string)
if !ok {
return "", fmt.Errorf("classifier response content is not a string")
}
return content, nil
}
if s.modelRouter == nil {
s.modelRouter = router.New(groups, classifyFn)
} else {
s.modelRouter.Reload(groups)
}
}
func (s *Server) setupRoutes() {
// Static files
s.router.StaticFile("/", "./static/index.html")
@@ -228,6 +286,11 @@ func (s *Server) setupRoutes() {
admin.GET("/models", s.handleGetModels)
admin.PUT("/models/:id", s.handleUpdateModel)
admin.GET("/model-groups", s.handleGetModelGroups)
admin.POST("/model-groups", s.handleCreateModelGroup)
admin.PUT("/model-groups/:id", s.handleUpdateModelGroup)
admin.DELETE("/model-groups/:id", s.handleDeleteModelGroup)
admin.GET("/users", s.handleGetUsers)
admin.POST("/users", s.handleCreateUser)
admin.PUT("/users/:id", s.handleUpdateUser)
@@ -254,9 +317,33 @@ func (s *Server) handleResponses(c *gin.Context) {
return
}
// Select provider based on model name
// Strip common prefixes and resolve model groups to concrete models
// (same pattern as handleChatCompletions).
modelGroup := ""
modelID := req.Model
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/", "xiaomi/"}
for _, p := range prefixes {
if strings.HasPrefix(modelID, p) {
modelID = strings.TrimPrefix(modelID, p)
break
}
}
if s.modelRouter != nil {
routeCtx := s.buildRouteContextFromResponses(req)
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
return
}
if decision.SelectedModel != modelID {
modelGroup = modelID
}
modelID = decision.SelectedModel
}
// Select provider based on resolved model name
providerName := "openai" // default for Responses API
modelLower := strings.ToLower(req.Model)
modelLower := strings.ToLower(modelID)
if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
providerName = "gemini"
} else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) {
@@ -284,17 +371,7 @@ func (s *Server) handleResponses(c *gin.Context) {
return
}
// Strip common prefixes from model name
modelID := req.Model
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/"}
for _, p := range prefixes {
if strings.HasPrefix(modelID, p) {
modelID = strings.TrimPrefix(modelID, p)
break
}
}
// Use the stripped model name for the actual API call
// Use resolved model for the actual API call
req.Model = modelID
clientID := "default"
@@ -309,7 +386,7 @@ func (s *Server) handleResponses(c *gin.Context) {
if stream {
ch, err := provider.ResponsesStream(c.Request.Context(), &req)
if err != nil {
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, false)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -324,9 +401,9 @@ func (s *Server) handleResponses(c *gin.Context) {
if !ok {
fmt.Fprintf(w, "data: [DONE]\n\n")
if lastUsage != nil {
s.logRequest(startTime, clientID, providerName, req.Model, lastUsage.ToUsage(), nil, false)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage.ToUsage(), nil, false)
} else {
s.logRequest(startTime, clientID, providerName, req.Model, nil, nil, false)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, nil, false)
}
return false
}
@@ -346,15 +423,15 @@ func (s *Server) handleResponses(c *gin.Context) {
resp, err := provider.Responses(c.Request.Context(), &req)
if err != nil {
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, false)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if resp.Usage != nil {
s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage.ToUsage(), nil, false)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, resp.Usage.ToUsage(), nil, false)
} else {
s.logRequest(startTime, clientID, providerName, req.Model, nil, nil, false)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, nil, false)
}
c.JSON(http.StatusOK, resp)
}
@@ -378,6 +455,7 @@ func (s *Server) handleListModels(c *gin.Context) {
"xai": true, // Models from models.dev use 'xai' ID for Grok
"llmgateway": true, // Catch-all for newer models
"ollama": true,
"xiaomi": true, // Xiaomi MiMo models
}
s.registryMu.RLock()
@@ -414,6 +492,20 @@ func (s *Server) handleListModels(c *gin.Context) {
}
}
// Add model groups so clients can discover them
if s.modelRouter != nil {
for _, gid := range s.modelRouter.Groups() {
if _, exists := modelMap[gid]; !exists {
modelMap[gid] = OpenAIModel{
ID: gid,
Object: "model",
Created: 1700000000,
OwnedBy: "gophergate",
}
}
}
}
var data []OpenAIModel
for _, m := range modelMap {
data = append(data, m)
@@ -425,21 +517,12 @@ func (s *Server) handleListModels(c *gin.Context) {
})
}
func (s *Server) handleChatCompletions(c *gin.Context) {
startTime := time.Now()
var req models.ChatCompletionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Select provider based on model name
func (s *Server) selectProvider(modelID string) (providers.Provider, string, error) {
providerName := "openai" // default
modelLower := strings.ToLower(req.Model)
modelLower := strings.ToLower(modelID)
if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
providerName = "gemini"
} else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) {
// Only use deepseek provider if it's not explicitly tagged for ollama
providerName = "deepseek"
} else if strings.HasPrefix(modelLower, "moonshot/") || strings.Contains(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
providerName = "moonshot"
@@ -456,17 +539,28 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
strings.Contains(modelLower, "codellama") ||
strings.Contains(modelLower, "command-r") {
providerName = "ollama"
} else if strings.HasPrefix(modelLower, "xiaomi/") || strings.Contains(modelLower, "mimo") || strings.Contains(modelLower, "xiaomi") {
providerName = "xiaomi"
}
provider, ok := s.providers[providerName]
p, ok := s.providers[providerName]
if !ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
return nil, "", fmt.Errorf("Provider %s not enabled or supported", providerName)
}
return p, providerName, nil
}
func (s *Server) handleChatCompletions(c *gin.Context) {
startTime := time.Now()
var req models.ChatCompletionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Strip common prefixes
// Strip common prefixes and prepare model ID
modelID := req.Model
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/"}
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/", "xiaomi/"}
for _, p := range prefixes {
if strings.HasPrefix(modelID, p) {
modelID = strings.TrimPrefix(modelID, p)
@@ -474,6 +568,32 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
}
}
// Resolve model groups to concrete models (hierarchical — groups can target groups)
modelGroup := ""
for i, m := range req.Messages {
log.Printf("[DEBUG] Incoming Msg[%d]: role=%s, hasToolCalls=%v, hasContent=%v", i, m.Role, len(m.ToolCalls) > 0, m.Content != nil)
}
if s.modelRouter != nil {
routeCtx := s.buildRouteContextFromChat(req)
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
return
}
if decision.SelectedModel != modelID {
modelGroup = modelID
}
modelID = decision.SelectedModel
log.Printf("[ROUTER] %s (%s: %s)", modelID, decision.Strategy, decision.Reason)
}
// Select provider based on the resolved model name
provider, providerName, err := s.selectProvider(modelID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Convert ChatCompletionRequest to UnifiedRequest
unifiedReq := &models.UnifiedRequest{
Model: modelID,
@@ -490,27 +610,28 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
ToolChoice: req.ToolChoice,
}
// Inject max_tokens from model registry when client doesn't specify one.
// Prevents providers from applying a low default output cap.
// DEBUG: Trace max_tokens through the proxy
clientMaxTokens := "nil"
if unifiedReq.MaxTokens != nil {
clientMaxTokens = fmt.Sprintf("%d", *unifiedReq.MaxTokens)
}
log.Printf("[DEBUG] %s: client max_tokens=%s", modelID, clientMaxTokens)
if unifiedReq.MaxTokens == nil {
// Inject or cap max_tokens from model registry.
s.registryMu.RLock()
meta := s.registry.FindModel(modelID)
s.registryMu.RUnlock()
if meta != nil && meta.Limit != nil && meta.Limit.Output > 0 {
unifiedReq.MaxTokens = &meta.Limit.Output
log.Printf("[DEBUG] %s: injected registry max_tokens=%d", modelID, meta.Limit.Output)
if unifiedReq.MaxTokens == nil {
unifiedReq.MaxTokens = &meta.Limit.Output
log.Printf("[DEBUG] %s: injected registry max_tokens=%d", modelID, meta.Limit.Output)
} else if *unifiedReq.MaxTokens > meta.Limit.Output {
log.Printf("[DEBUG] %s: capping client max_tokens (%d) to registry limit (%d)", modelID, *unifiedReq.MaxTokens, meta.Limit.Output)
unifiedReq.MaxTokens = &meta.Limit.Output
} else {
log.Printf("[DEBUG] %s: using client max_tokens (%d)", modelID, *unifiedReq.MaxTokens)
}
} else {
log.Printf("[DEBUG] %s: no registry limit found, leaving max_tokens nil (provider default)", modelID)
if unifiedReq.MaxTokens == nil {
log.Printf("[DEBUG] %s: no registry limit found, leaving max_tokens nil", modelID)
} else {
log.Printf("[DEBUG] %s: using client max_tokens (%d), no registry limit to cap", modelID, *unifiedReq.MaxTokens)
}
}
} else {
log.Printf("[DEBUG] %s: using client's max_tokens=%d", modelID, *unifiedReq.MaxTokens)
}
// Handle Stop sequences
if req.Stop != nil {
@@ -592,7 +713,7 @@ if unifiedReq.MaxTokens == nil {
if unifiedReq.Stream {
ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq)
if err != nil {
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, unifiedReq.HasImages)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -606,7 +727,7 @@ if unifiedReq.MaxTokens == nil {
chunk, ok := <-ch
if !ok {
fmt.Fprintf(w, "data: [DONE]\n\n")
s.logRequest(startTime, clientID, providerName, req.Model, lastUsage, nil, unifiedReq.HasImages)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage, nil, unifiedReq.HasImages)
return false
}
if chunk.Usage != nil {
@@ -624,15 +745,29 @@ if unifiedReq.MaxTokens == nil {
resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq)
if err != nil {
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, unifiedReq.HasImages)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage, nil, unifiedReq.HasImages)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, resp.Usage, nil, unifiedReq.HasImages)
c.JSON(http.StatusOK, resp)
}
func extractUserMessage(messages []models.ChatMessage) string {
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == "user" {
switch c := messages[i].Content.(type) {
case string:
return c
default:
return ""
}
}
}
return ""
}
func (s *Server) handleImageGenerations(c *gin.Context) {
startTime := time.Now()
var req models.ImageGenerationRequest
@@ -684,7 +819,7 @@ func (s *Server) handleImageGenerations(c *gin.Context) {
resp, err := provider.ImageGeneration(c.Request.Context(), &req)
if err != nil {
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, false)
s.logRequest(startTime, clientID, providerName, req.Model, "", nil, err, false)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -698,7 +833,7 @@ func (s *Server) handleImageGenerations(c *gin.Context) {
// Calculate per-image cost (not per-token like chat)
cost := imageGenCost(providerName, req.Model, req.Size, uint32(len(resp.Data)))
s.logRequest(startTime, clientID, providerName, req.Model, &models.Usage{
s.logRequest(startTime, clientID, providerName, req.Model, "", &models.Usage{
PromptTokens: promptTokens,
CompletionTokens: uint32(len(resp.Data)),
TotalTokens: promptTokens + uint32(len(resp.Data)),
@@ -740,12 +875,13 @@ func imageGenCost(provider, model string, size *string, n uint32) float64 {
return perImage * float64(n)
}
func (s *Server) logRequest(start time.Time, clientID, provider, model string, usage *models.Usage, err error, hasImages bool) {
func (s *Server) logRequest(start time.Time, clientID, provider, model, modelGroup string, usage *models.Usage, err error, hasImages bool) {
entry := RequestLog{
Timestamp: start,
ClientID: clientID,
Provider: provider,
Model: model,
ModelGroup: modelGroup,
Status: "success",
DurationMS: time.Since(start).Milliseconds(),
HasImages: hasImages,
@@ -770,9 +906,14 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model string, u
entry.CacheWriteTokens = *usage.CacheWriteTokens
}
// Calculate cost using registry
// Calculate cost using registry; if the resolved model is unknown,
// fall back to the model group so group requests still get priced.
s.registryMu.RLock()
entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
pricingModel := model
if s.registry != nil && s.registry.FindModel(pricingModel) == nil && modelGroup != "" {
pricingModel = modelGroup
}
entry.Cost = utils.CalculateCost(s.registry, pricingModel, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
s.registryMu.RUnlock()
}
@@ -799,3 +940,165 @@ func (s *Server) Run() error {
addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
return s.router.Run(addr)
}
func uint32Ptr(v uint32) *uint32 { return &v }
func (s *Server) buildRouteContextFromChat(req models.ChatCompletionRequest) *router.RouteContext {
userMessage := extractUserMessage(req.Messages)
requiresToolCalling := len(req.Tools) > 0
hasMultimodal := false
inputTokens := 0
for _, msg := range req.Messages {
if strContent, ok := msg.Content.(string); ok {
inputTokens += len(strContent) / 4
} else if parts, ok := msg.Content.([]interface{}); ok {
for _, part := range parts {
if partMap, ok := part.(map[string]interface{}); ok {
partType, _ := partMap["type"].(string)
if partType == "text" {
text, _ := partMap["text"].(string)
inputTokens += len(text) / 4
} else if partType == "image_url" {
hasMultimodal = true
inputTokens += 1000 // Approximate cost of an image in tokens
}
}
}
}
}
msgLower := strings.ToLower(userMessage)
requiresReasoning := strings.Contains(msgLower, "reason") ||
strings.Contains(msgLower, "think step by step") ||
strings.Contains(msgLower, "mathematics") ||
strings.Contains(msgLower, "architecture") ||
strings.Contains(msgLower, "explain in detail")
routeCtx := &router.RouteContext{
UserMessage: userMessage,
InputTokens: inputTokens,
HasMultimodalInput: hasMultimodal,
RequiresToolCalling: requiresToolCalling,
RequiresReasoning: requiresReasoning,
}
routeCtx.Tags = s.getRouteCtxTags(routeCtx)
return routeCtx
}
func (s *Server) buildRouteContextFromResponses(req models.ResponsesRequest) *router.RouteContext {
var userMessage string
hasMultimodal := false
inputTokens := len(req.Instructions) / 4
requiresToolCalling := len(req.Tools) > 0 && string(req.Tools) != "null" && string(req.Tools) != ""
var strInput string
if err := json.Unmarshal(req.Input, &strInput); err == nil {
userMessage = strInput
inputTokens += len(userMessage) / 4
} else {
var msgs []models.ResponseInputMessage
if err := json.Unmarshal(req.Input, &msgs); err == nil {
for _, m := range msgs {
var contentStr string
if err := json.Unmarshal(m.Content, &contentStr); err == nil {
if m.Role == "user" {
userMessage = contentStr
}
inputTokens += len(contentStr) / 4
} else {
var parts []models.ContentPart
if err := json.Unmarshal(m.Content, &parts); err == nil {
for _, p := range parts {
if p.Type == "text" {
if m.Role == "user" {
userMessage = p.Text
}
inputTokens += len(p.Text) / 4
} else if p.Type == "image_url" {
hasMultimodal = true
inputTokens += 1000
}
}
}
}
}
}
}
msgLower := strings.ToLower(userMessage)
requiresReasoning := strings.Contains(msgLower, "reason") ||
strings.Contains(msgLower, "think step by step") ||
strings.Contains(msgLower, "mathematics") ||
strings.Contains(msgLower, "architecture") ||
strings.Contains(msgLower, "explain in detail")
routeCtx := &router.RouteContext{
UserMessage: userMessage,
InputTokens: inputTokens,
HasMultimodalInput: hasMultimodal,
RequiresToolCalling: requiresToolCalling,
RequiresReasoning: requiresReasoning,
}
routeCtx.Tags = s.getRouteCtxTags(routeCtx)
return routeCtx
}
func (s *Server) getRouteCtxTags(routeCtx *router.RouteContext) []string {
var tags []string
msgLower := strings.ToLower(routeCtx.UserMessage)
// fast-flow keywords
fastFlowKeywords := []string{
"classify", "classification", "label", "tag", "route", "routing", "intent",
"json", "yaml", "csv", "schema", "parse", "extract", "transform", "format", "regex",
"short answer", "brief", "concise", "tl;dr", "one line", "simple",
"fix this", "small bug", "quick fix", "typo", "syntax error",
}
for _, kw := range fastFlowKeywords {
if strings.Contains(msgLower, kw) {
tags = append(tags, "fast-flow", "classification", "json-extraction", "basic-qa")
break
}
}
// standard-pro keywords
standardProKeywords := []string{
"explain", "summarize", "rewrite", "draft", "edit", "polish", "outline",
"long doc", "document", "email", "memo", "proposal", "report", "handout", "notes",
"compare", "choose", "recommend", "tradeoff", "pros and cons", "analysis",
"code review", "debug", "bug", "feature", "api", "endpoint", "implement",
"plan", "planning", "workflow", "integration",
}
for _, kw := range standardProKeywords {
if strings.Contains(msgLower, kw) {
tags = append(tags, "standard-pro", "long-doc")
break
}
}
if routeCtx.HasMultimodalInput {
tags = append(tags, "video-analysis", "multimodal-qa")
}
// heavy-logic keywords
heavyLogicKeywords := []string{
"agent", "agents", "tool use", "function calling", "multi-agent", "orchestrate",
"system design", "scaling", "performance", "architecture review", "distributed",
"hard bug", "race condition", "deadlock", "memory leak", "crash", "production outage",
"long context", "large codebase", "many files", "complex refactor", "migration",
"research", "deep dive", "literature", "paper", "scholarly", "thorough analysis",
"deep reasoning", "think step by step", "reason through", "careful analysis",
}
for _, kw := range heavyLogicKeywords {
if strings.Contains(msgLower, kw) {
tags = append(tags, "heavy-logic", "deep-reasoning", "architecture", "hard-debugging")
break
}
}
if routeCtx.RequiresToolCalling {
tags = append(tags, "tool-heavy", "multi-step-agent", "swe-bench")
}
return tags
}
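Reviewer sketch of the heuristics used by buildRouteContextFromChat and getRouteCtxTags above: input size is approximated as one token per four bytes of text, and tags come from case-insensitive substring matches against keyword lists. The names estimateTokens, tier, and matchTiers are hypothetical and do not exist in the repository; the keyword lists are shortened for illustration.

```go
package main

import (
	"fmt"
	"strings"
)

// estimateTokens approximates a token count as one token per four bytes,
// mirroring the len(text) / 4 heuristic in the diff. It is not a tokenizer.
func estimateTokens(text string) int {
	return len(text) / 4
}

// tier is a hypothetical pairing of a routing tag with its trigger keywords.
type tier struct {
	Tag      string
	Keywords []string
}

// matchTiers returns the tag of every tier that has at least one keyword
// occurring as a substring of the lowercased message, in tier order.
func matchTiers(message string, tiers []tier) []string {
	lower := strings.ToLower(message)
	var tags []string
	for _, t := range tiers {
		for _, kw := range t.Keywords {
			if strings.Contains(lower, kw) {
				tags = append(tags, t.Tag)
				break // one hit per tier is enough
			}
		}
	}
	return tags
}

func main() {
	tiers := []tier{
		{Tag: "fast-flow", Keywords: []string{"classify", "json", "tag"}},
		{Tag: "heavy-logic", Keywords: []string{"deadlock", "race condition"}},
	}
	msg := "Please debug this deadlock and return JSON"
	fmt.Println(estimateTokens(msg), matchTiers(msg, tiers))
}
```

Because the match is a plain substring test, a short keyword such as "tag" also fires on words like "staged". If that looseness matters, a word-boundary match would be one alternative, at the cost of missing inflected forms.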
+25
@@ -45,6 +45,24 @@ func FetchRegistry() (*models.ModelRegistry, error) {
return nil, fmt.Errorf("failed to fetch registry after 3 attempts: %w", lastErr)
}
// promoDiscount describes a temporary pricing discount applied on top of
// the standard (list) price from the model registry.
type promoDiscount struct {
Factor float64 // multiplier applied after standard calculation (0.25 = 75% off)
ExpiresAt time.Time // discount ends at this time (UTC)
}
// promoDiscounts maps model IDs to active promotional discounts.
// Sources:
// - DeepSeek v4 Pro: 75% off list pricing until 2026-05-31
// https://api-docs.deepseek.com/quick_start/pricing
var promoDiscounts = map[string]promoDiscount{
"deepseek-v4-pro": {
Factor: 0.25,
ExpiresAt: time.Date(2026, 5, 31, 23, 59, 59, 0, time.UTC),
},
}
func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens, completionTokens, reasoningTokens, cacheRead, cacheWrite uint32) float64 {
meta := registry.FindModel(modelID)
if meta == nil || meta.Cost == nil {
@@ -72,5 +90,12 @@ func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens,
cost += float64(cacheWrite) * (*meta.Cost.CacheWrite) / 1000000.0
}
// Apply promotional discounts (e.g. DeepSeek 75% off until 2026-05-31).
if discount, ok := promoDiscounts[modelID]; ok {
if time.Now().UTC().Before(discount.ExpiresAt) {
cost *= discount.Factor
}
}
return cost
}
+6 -1
@@ -89,6 +89,10 @@
<i class="fas fa-brain"></i>
<span>Models</span>
</li>
<li class="menu-item" data-page="model-groups">
<i class="fas fa-code-branch"></i>
<span>Model Groups</span>
</li>
</ul>
</div>
@@ -164,7 +168,7 @@
<script src="/js/auth.js?v=7"></script>
<script src="/js/charts.js?v=7"></script>
<script src="/js/websocket.js?v=7"></script>
<script src="/js/dashboard.js?v=7"></script>
<script src="/js/dashboard.js?v=8"></script>
<!-- Page Modules -->
<script src="/js/pages/overview.js?v=7"></script>
@@ -177,5 +181,6 @@
<script src="/js/pages/settings.js?v=7"></script>
<script src="/js/pages/logs.js?v=7"></script>
<script src="/js/pages/users.js?v=7"></script>
<script src="/js/pages/model_groups.js?v=9"></script>
</body>
</html>
+6
@@ -119,6 +119,7 @@ class Dashboard {
'settings': 'Settings',
'logs': 'Logs',
'models': 'Models',
'model-groups': 'Model Groups',
'users': 'User Management'
};
if (titleElement) titleElement.textContent = titles[page] || 'Dashboard';
@@ -130,6 +131,11 @@ class Dashboard {
if (content) {
content.innerHTML = await this.getPageTemplate(page);
await this.initializePageScript(page);
// Model Groups page uses its own render method
if (page === 'model-groups' && typeof modelGroupsPage !== 'undefined') {
await modelGroupsPage.render();
}
}
} catch (error) {
console.error(`Error loading page ${page}:`, error);
+186
@@ -0,0 +1,186 @@
// Model Groups Management Page
class ModelGroupsPage {
constructor() {
this.container = document.getElementById('page-content');
}
async render() {
this.container.innerHTML = `
<div class="page-header">
<h3>Model Groups</h3>
<p class="text-muted">Define auto-routing groups that pick the best model for each request.</p>
<button class="btn btn-primary" onclick="modelGroupsPage.showCreateForm()">
<i class="fas fa-plus"></i> Add Group
</button>
</div>
<div id="model-groups-list" class="table-container"></div>
<div id="model-group-form" class="form-container" style="display:none;"></div>
`;
await this.loadGroups();
}
async loadGroups() {
try {
const groups = await api.get('/model-groups');
const list = document.getElementById('model-groups-list');
if (!groups || groups.length === 0) {
list.innerHTML = '<div class="empty-state">No model groups defined. Create one to enable auto-routing.</div>';
return;
}
let html = '<table class="data-table"><thead><tr>';
html += '<th>Group ID</th><th>Level</th><th>Primary Use</th><th>Strategy</th><th>Targets</th><th>Actions</th>';
html += '</tr></thead><tbody>';
groups.forEach(g => {
html += '<tr>';
html += '<td><code>' + this.esc(g.id) + '</code></td>';
html += '<td>' + (g.logic_level != null ? g.logic_level : '&mdash;') + '</td>';
html += '<td>' + (g.primary_use ? this.esc(g.primary_use) : '&mdash;') + '</td>';
html += '<td><span class="badge">' + this.esc(g.strategy) + '</span></td>';
html += '<td><code>' + this.esc(g.targets) + '</code></td>';
html += '<td>';
html += '<button class="btn btn-sm" onclick="modelGroupsPage.showEditForm(\'' + this.esc(g.id) + '\')">Edit</button> ';
html += '<button class="btn btn-sm btn-danger" onclick="modelGroupsPage.deleteGroup(\'' + this.esc(g.id) + '\')">Delete</button>';
html += '</td></tr>';
});
html += '</tbody></table>';
list.innerHTML = html;
} catch (err) {
document.getElementById('model-groups-list').innerHTML =
'<div class="error-message">Failed to load model groups: ' + this.esc(err.message) + '</div>';
}
}
showCreateForm() {
this.renderForm(null);
}
async showEditForm(id) {
try {
const groups = await api.get('/model-groups');
const group = groups.find(g => g.id === id);
if (group) this.renderForm(group);
} catch (err) {
alert('Failed to load group: ' + err.message);
}
}
renderForm(group) {
const isEdit = !!group;
const form = document.getElementById('model-group-form');
form.style.display = 'block';
form.innerHTML = `
<h4>${isEdit ? 'Edit' : 'Create'} Model Group</h4>
<form onsubmit="modelGroupsPage.saveGroup(event, ${isEdit})">
<div class="form-control">
<label>Group ID</label>
<input type="text" id="mg-id" value="${this.esc(group ? group.id : '')}" ${isEdit ? 'readonly' : 'required'}
placeholder="e.g. deepseek-auto">
<small>Clients use this as the model name.</small>
</div>
<div class="form-control">
<label>Strategy</label>
<select id="mg-strategy">
<option value="heuristic" ${group && group.strategy === 'heuristic' ? 'selected' : ''}>Heuristic (rules-based)</option>
<option value="classifier" ${group && group.strategy === 'classifier' ? 'selected' : ''}>Classifier (LLM judge)</option>
</select>
</div>
<div class="form-control">
<label>Targets (JSON array)</label>
<input type="text" id="mg-targets" value='${this.esc(group ? group.targets : '["cheap-model","smart-model"]')}' required>
<small>First target = cheapest/fastest. Last target = smartest/most expensive.</small>
</div>
<div class="form-control" id="mg-selector-row" style="${group && group.strategy === 'classifier' ? '' : 'display:none'}">
<label>Selector Model</label>
<input type="text" id="mg-selector-model" value="${this.esc(group && group.selector_model ? group.selector_model : 'gpt-4o-mini')}"
placeholder="Model used to judge task complexity">
</div>
<div class="form-control" id="mg-threshold-row" style="${group && group.strategy === 'classifier' ? '' : 'display:none'}">
<label>Complexity Threshold</label>
<input type="number" id="mg-threshold" value="${group && group.complexity_threshold ? group.complexity_threshold : ''}" min="1"
placeholder="Tasks rated >= this go to the smart model">
</div>
<div class="form-control" id="mg-rules-row" style="${group && group.strategy === 'heuristic' ? '' : 'display:none'}">
<label>Heuristic Rules (JSON array)</label>
<textarea id="mg-rules" rows="4" placeholder='[{"pattern":"step by step","target":1}]'>${this.esc(group && group.heuristic_rules ? group.heuristic_rules : '')}</textarea>
<small>Pattern to match in user messages. target = index into targets array.</small>
</div>
<div class="form-control">
<label>Logic Level (1-10)</label>
<input type="number" id="mg-level" value="${group && group.logic_level != null ? group.logic_level : ''}" min="1" max="10"
placeholder="e.g. 8 for heavy logic, 2 for fast/basic">
<small>Rough complexity scale. 1-3: fast/light, 4-7: standard, 8-10: heavy.</small>
</div>
<div class="form-control">
<label>Primary Use</label>
<input type="text" id="mg-primary-use" value="${this.esc(group && group.primary_use ? group.primary_use : '')}"
placeholder="e.g. Complex Coding, Logic, Agents.">
<small>Brief description of what this group is best used for.</small>
</div>
<div class="form-actions">
<button type="submit" class="btn btn-primary">Save</button>
<button type="button" class="btn" onclick="document.getElementById('model-group-form').style.display='none'">Cancel</button>
</div>
</form>
`;
document.getElementById('mg-strategy').onchange = function() {
var isClassifier = this.value === 'classifier';
document.getElementById('mg-selector-row').style.display = isClassifier ? '' : 'none';
document.getElementById('mg-threshold-row').style.display = isClassifier ? '' : 'none';
document.getElementById('mg-rules-row').style.display = isClassifier ? 'none' : '';
};
}
async saveGroup(event, isEdit) {
event.preventDefault();
var id = document.getElementById('mg-id').value.trim();
var strategy = document.getElementById('mg-strategy').value;
var targets = document.getElementById('mg-targets').value;
var selectorModel = document.getElementById('mg-selector-model').value.trim() || null;
var thresholdVal = document.getElementById('mg-threshold').value;
var rules = document.getElementById('mg-rules').value.trim() || null;
var logicLevelVal = document.getElementById('mg-level').value;
var primaryUse = document.getElementById('mg-primary-use').value.trim() || null;
try { if (!Array.isArray(JSON.parse(targets))) throw new Error('not an array'); } catch (e) { alert('Targets must be a valid JSON array'); return; }
if (rules) { try { JSON.parse(rules); } catch (e) { alert('Heuristic rules must be valid JSON'); return; } }
var body = { id: id, strategy: strategy, targets: targets, selector_model: selectorModel, heuristic_rules: rules };
if (thresholdVal) body.complexity_threshold = parseInt(thresholdVal, 10);
if (logicLevelVal) body.logic_level = parseInt(logicLevelVal, 10);
if (primaryUse) body.primary_use = primaryUse;
try {
if (isEdit) {
await api.put('/model-groups/' + encodeURIComponent(id), body);
} else {
await api.post('/model-groups', body);
}
document.getElementById('model-group-form').style.display = 'none';
await this.loadGroups();
} catch (err) {
alert('Failed to save: ' + err.message);
}
}
async deleteGroup(id) {
if (!confirm('Delete model group "' + id + '"? This cannot be undone.')) return;
try {
await api.delete('/model-groups/' + encodeURIComponent(id));
await this.loadGroups();
} catch (err) {
alert('Failed to delete: ' + err.message);
}
}
esc(str) {
if (!str) return '';
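// Also escape single quotes so values stay safe inside single-quoted attributes.
return String(str).replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;').replace(/"/g,'&quot;').replace(/'/g,'&#39;');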
}
}
var modelGroupsPage = new ModelGroupsPage();
+1 -1
@@ -392,7 +392,7 @@ class MonitoringPage {
</div>
<div class="stream-entry-content">
<strong>${request.client_id || 'Unknown'}</strong> →
${request.provider || 'Unknown'} (${request.model || 'Unknown'})
${request.provider || 'Unknown'} (${request.model || 'Unknown'}${request.model_group ? ` via ${request.model_group}` : ''})
<div class="stream-entry-details">
${request.total_tokens || request.tokens || 0} tokens • ${request.duration_ms || request.duration || 0}ms
</div>
+2 -2
@@ -252,7 +252,7 @@ class OverviewPage {
<td>${time}</td>
<td><span class="badge-client">${request.client_id}</span></td>
<td>${request.provider}</td>
<td><code class="code-sm">${request.model}</code></td>
<td><code class="code-sm">${request.model}${request.model_group ? ` (via ${request.model_group})` : ''}</code></td>
<td>${request.tokens.toLocaleString()}</td>
<td>
<span class="status-badge ${statusClass}">
@@ -313,7 +313,7 @@ class OverviewPage {
<td>${time}</td>
<td><span class="badge-client">${request.client_id}</span></td>
<td>${request.provider}</td>
<td><code class="code-sm">${request.model}</code></td>
<td><code class="code-sm">${request.model}${request.model_group ? ` (via ${request.model_group})` : ''}</code></td>
<td>${(request.total_tokens || request.tokens || 0).toLocaleString()}</td>
<td>
<span class="status-badge ${statusClass}">
+2 -2
@@ -309,7 +309,7 @@ class WebSocketManager {
<td>${time}</td>
<td>${request.client_id || 'Unknown'}</td>
<td>${request.provider || 'Unknown'}</td>
<td>${request.model || 'Unknown'}</td>
<td>${request.model || 'Unknown'}${request.model_group ? ` (via ${request.model_group})` : ''}</td>
<td>${(request.total_tokens || request.tokens || 0)}</td>
<td>
<span class="status-badge ${statusClass}">
@@ -358,7 +358,7 @@ class WebSocketManager {
</div>
<div class="stream-entry-content">
<strong>${request.client_id || 'Unknown'}</strong> →
${request.provider || 'Unknown'} (${request.model || 'Unknown'})
${request.provider || 'Unknown'} (${request.model || 'Unknown'}${request.model_group ? ` via ${request.model_group}` : ''})
<div class="stream-entry-details">
${(request.total_tokens || request.tokens || 0)} tokens • ${(request.duration_ms || request.duration || 0)}ms
</div>