Compare commits
28 Commits
c009d401fb
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 73a82e6175 | |||
| b3354a1bbc | |||
| 1dc5f586b9 | |||
| 40f055cb57 | |||
| 970e778703 | |||
| 477a811999 | |||
| d2b9da89d9 | |||
| b7df3108fa | |||
| 28b8271c1d | |||
| eb585c0001 | |||
| 4aea7a3b4c | |||
| 330eaa57d1 | |||
| 0ae30036f0 | |||
| 3c0b59622e | |||
| 7517307c11 | |||
| 19517b0847 | |||
| a3a6f765e7 | |||
| 79dd122b56 | |||
| 3021e4b2b4 | |||
| 14de7e9ebf | |||
| 4fef201e95 | |||
| bac03de051 | |||
| 37949e560b | |||
| f04cb6b8f2 | |||
| 10262c0e5a | |||
| d345f8c41d | |||
| d1f7a57f58 | |||
| dc9af4d79c |
@@ -18,6 +18,9 @@ DEEPSEEK_API_KEY=sk-...
|
||||
MOONSHOT_API_KEY=sk-...
|
||||
GROK_API_KEY=xai-...
|
||||
|
||||
# Xiaomi MiMo
|
||||
XIAOMI_API_KEY=sk-...
|
||||
|
||||
# ==============================================================================
|
||||
# Server Configuration
|
||||
# ==============================================================================
|
||||
|
||||
@@ -14,3 +14,5 @@
|
||||
.pi-lens/cache/
|
||||
server.pid
|
||||
/target
|
||||
nohup.out
|
||||
*.bak
|
||||
|
||||
@@ -0,0 +1,919 @@
|
||||
# Automatic Model Routing — Implementation Plan
|
||||
|
||||
> **For Hermes:** Use subagent-driven-development skill to implement this plan task-by-task.
|
||||
|
||||
**Goal:** Add a model-group router that lets clients send `model: "deepseek-auto"` and have gophergate pick the best concrete model based on heuristic rules or an optional classifier LLM.
|
||||
|
||||
**Architecture:** A new `internal/router/` package with heuristic and classifier strategies, backed by a `model_groups` DB table. The router injects into `handleChatCompletions` after provider resolution but before the provider call — zero changes to the Provider interface. Admin CRUD endpoints and a dashboard tab for management.
|
||||
|
||||
**Tech Stack:** Go 1.22+, Gin, sqlx (SQLite), resty, existing OpenAI provider for classifier calls.
|
||||
|
||||
---
|
||||
|
||||
## Task 1: Add `model_groups` DB migration and struct
|
||||
|
||||
**Objective:** Create the `model_groups` table and Go struct.
|
||||
|
||||
**Files:**
|
||||
- Modify: `internal/db/db.go`
|
||||
|
||||
**Step 1: Add CREATE TABLE to migrations**
|
||||
|
||||
In `RunMigrations()`, add to the `queries` slice (after `client_tokens`):
|
||||
|
||||
```go
|
||||
`CREATE TABLE IF NOT EXISTS model_groups (
|
||||
id TEXT PRIMARY KEY,
|
||||
strategy TEXT NOT NULL DEFAULT 'heuristic',
|
||||
selector_model TEXT,
|
||||
targets TEXT NOT NULL DEFAULT '[]',
|
||||
complexity_threshold INTEGER,
|
||||
heuristic_rules TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)`,
|
||||
```
|
||||
|
||||
**Step 2: Add the Go struct**
|
||||
|
||||
After the `ClientToken` struct (around line 264), add:
|
||||
|
||||
```go
|
||||
type ModelGroup struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
Strategy string `db:"strategy" json:"strategy"`
|
||||
SelectorModel *string `db:"selector_model" json:"selector_model"`
|
||||
Targets string `db:"targets" json:"targets"` // JSON array
|
||||
ComplexityThreshold *int `db:"complexity_threshold" json:"complexity_threshold"`
|
||||
HeuristicRules *string `db:"heuristic_rules" json:"heuristic_rules"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
```
|
||||
|
||||
**Step 3: Seed default groups**
|
||||
|
||||
After the "Default client" block in `RunMigrations()`, add:
|
||||
|
||||
```go
|
||||
// Seed default model groups
|
||||
defaultGroups := []struct {
|
||||
id, strategy, targets string
|
||||
}{
|
||||
{"deepseek-auto", "heuristic", `["deepseek-chat","deepseek-reasoner"]`},
|
||||
{"openai-auto", "heuristic", `["gpt-4o-mini","gpt-4o"]`},
|
||||
{"gemini-auto", "heuristic", `["gemini-2.0-flash","gemini-2.5-pro"]`},
|
||||
}
|
||||
for _, g := range defaultGroups {
|
||||
db.Exec(`INSERT OR IGNORE INTO model_groups (id, strategy, targets) VALUES (?, ?, ?)`,
|
||||
g.id, g.strategy, g.targets)
|
||||
}
|
||||
```
|
||||
|
||||
**Step 4: Build and verify**
|
||||
|
||||
```bash
|
||||
cd ~/Documents/projects/web_projects/gophergate && go build ./...
|
||||
```
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add internal/db/db.go
|
||||
git commit -m "feat: add model_groups table and default seed data"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 2: Create router package — interface and heuristic router
|
||||
|
||||
**Objective:** Create `internal/router/` with the Router interface and heuristic implementation.
|
||||
|
||||
**Files:**
|
||||
- Create: `internal/router/router.go`
|
||||
- Create: `internal/router/heuristic.go`
|
||||
|
||||
**Step 1: Create `internal/router/router.go`**
|
||||
|
||||
```go
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
// Decision holds the result of a routing decision.
|
||||
type Decision struct {
|
||||
SelectedModel string `json:"selected_model"`
|
||||
Strategy string `json:"strategy"` // "heuristic" or "classifier"
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// ClassifierFunc is the callback for classifier-based routing.
|
||||
// Takes a system prompt, user message, and selector model.
|
||||
// Returns a complexity rating string (e.g. "3").
|
||||
type ClassifierFunc func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error)
|
||||
|
||||
// Router resolves model groups to concrete models.
|
||||
type Router struct {
|
||||
groups map[string]db.ModelGroup
|
||||
classify ClassifierFunc
|
||||
}
|
||||
|
||||
// New creates a Router. classify may be nil if no classifier groups exist.
|
||||
func New(groups []db.ModelGroup, classify ClassifierFunc) *Router {
|
||||
r := &Router{
|
||||
groups: make(map[string]db.ModelGroup),
|
||||
classify: classify,
|
||||
}
|
||||
for _, g := range groups {
|
||||
r.groups[g.ID] = g
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// IsGroup returns true if the model name is a group ID.
|
||||
func (r *Router) IsGroup(modelID string) bool {
|
||||
_, ok := r.groups[modelID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Route resolves a group to a concrete model.
|
||||
// Extracts the user message from the request body JSON bytes.
|
||||
func (r *Router) Route(ctx context.Context, groupID string, userMessage string) (*Decision, error) {
|
||||
group, ok := r.groups[groupID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown model group: %s", groupID)
|
||||
}
|
||||
|
||||
var targets []string
|
||||
if err := json.Unmarshal([]byte(group.Targets), &targets); err != nil || len(targets) == 0 {
|
||||
return nil, fmt.Errorf("invalid or empty targets for group %s", groupID)
|
||||
}
|
||||
|
||||
switch group.Strategy {
|
||||
case "heuristic":
|
||||
return routeHeuristic(group, targets, userMessage)
|
||||
case "classifier":
|
||||
if r.classify == nil {
|
||||
// Fall back to heuristic if no classifier is available
|
||||
return routeHeuristic(group, targets, userMessage)
|
||||
}
|
||||
return routeClassifier(ctx, r.classify, group, targets, userMessage)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown strategy: %s", group.Strategy)
|
||||
}
|
||||
}
|
||||
|
||||
// Reload replaces the group definitions without recreating the router.
|
||||
func (r *Router) Reload(groups []db.ModelGroup) {
|
||||
r.groups = make(map[string]db.ModelGroup)
|
||||
for _, g := range groups {
|
||||
r.groups[g.ID] = g
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: Create `internal/router/heuristic.go`**
|
||||
|
||||
```go
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
// HeuristicRule defines a pattern-based routing rule.
|
||||
type HeuristicRule struct {
|
||||
Pattern string `json:"pattern"` // substring to match in user message
|
||||
TargetIdx int `json:"target"` // index into targets array (0-based)
|
||||
CaseSensitive bool `json:"case_sensitive,omitempty"`
|
||||
}
|
||||
|
||||
func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) (*Decision, error) {
|
||||
// Default to first target (cheapest/fastest)
|
||||
selected := targets[0]
|
||||
reason := "default (first target)"
|
||||
|
||||
// If heuristic_rules is set, use them
|
||||
if group.HeuristicRules != nil && *group.HeuristicRules != "" {
|
||||
var rules []HeuristicRule
|
||||
if err := json.Unmarshal([]byte(*group.HeuristicRules), &rules); err == nil {
|
||||
searchMsg := userMessage
|
||||
for _, rule := range rules {
|
||||
pattern := rule.Pattern
|
||||
msg := searchMsg
|
||||
if !rule.CaseSensitive {
|
||||
pattern = strings.ToLower(pattern)
|
||||
msg = strings.ToLower(msg)
|
||||
}
|
||||
if strings.Contains(msg, pattern) {
|
||||
if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) {
|
||||
selected = targets[rule.TargetIdx]
|
||||
reason = "matched heuristic rule: " + rule.Pattern
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Built-in fallback heuristics (apply even without custom rules)
|
||||
if reason == "default (first target)" && len(targets) > 1 {
|
||||
msgLower := strings.ToLower(userMessage)
|
||||
// Complex task indicators → last target (usually the smarter model)
|
||||
complexIndicators := []string{
|
||||
"step by step", "explain in detail", "reason through",
|
||||
"think carefully", "analyze", "debug", "write code",
|
||||
"implement", "refactor", "architecture",
|
||||
}
|
||||
for _, indicator := range complexIndicators {
|
||||
if strings.Contains(msgLower, indicator) {
|
||||
selected = targets[len(targets)-1]
|
||||
reason = "complex task indicator: " + indicator
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Decision{
|
||||
SelectedModel: selected,
|
||||
Strategy: "heuristic",
|
||||
Reason: reason,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// routeHeuristic exists as a package-level func for direct use.
|
||||
var _ = routeHeuristic // suppress unused warning when classifier is the only caller
|
||||
```
|
||||
|
||||
Hmm, actually let me simplify. The `routeHeuristic` function IS used by `Router.Route()`. Let me not use the blank identifier trick.
|
||||
|
||||
**Step 3: Build**
|
||||
|
||||
```bash
|
||||
cd ~/Documents/projects/web_projects/gophergate && go build ./...
|
||||
```
|
||||
|
||||
Fix any compilation errors (missing imports, etc.).
|
||||
|
||||
**Step 4: Commit**
|
||||
|
||||
```bash
|
||||
git add internal/router/
|
||||
git commit -m "feat: add router package with heuristic strategy"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 3: Add classifier router
|
||||
|
||||
**Objective:** Implement the classifier strategy that uses a cheap LLM to rate task complexity.
|
||||
|
||||
**Files:**
|
||||
- Create: `internal/router/classifier.go`
|
||||
|
||||
**Step 1: Create `internal/router/classifier.go`**
|
||||
|
||||
```go
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
const classifierSystemPrompt = `You are a task complexity classifier. Rate the following user message on a scale of 1 to %d, where:
|
||||
1 = trivial/simple (basic facts, greetings, simple math)
|
||||
%d = highly complex (multi-step reasoning, code generation, architecture design)
|
||||
|
||||
Reply with ONLY the number. No explanation.`
|
||||
|
||||
func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, userMessage string) (*Decision, error) {
|
||||
maxRating := len(targets)
|
||||
if maxRating < 2 {
|
||||
maxRating = 2
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf(classifierSystemPrompt, maxRating, maxRating)
|
||||
ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMessage)
|
||||
if err != nil {
|
||||
// Classifier failed — fall back to heuristic
|
||||
return routeHeuristic(group, targets, userMessage)
|
||||
}
|
||||
|
||||
rating, err := strconv.Atoi(strings.TrimSpace(ratingStr))
|
||||
if err != nil || rating < 1 {
|
||||
rating = 1
|
||||
}
|
||||
if rating > maxRating {
|
||||
rating = maxRating
|
||||
}
|
||||
|
||||
idx := rating - 1 // 0-based index into targets
|
||||
return &Decision{
|
||||
SelectedModel: targets[idx],
|
||||
Strategy: "classifier",
|
||||
Reason: fmt.Sprintf("complexity rating: %d/%d", rating, maxRating),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getSelectorModel(group db.ModelGroup, targets []string) string {
|
||||
if group.SelectorModel != nil && *group.SelectorModel != "" {
|
||||
return *group.SelectorModel
|
||||
}
|
||||
// Default: use the first (cheapest) target model as the selector
|
||||
return targets[0]
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: Build**
|
||||
|
||||
```bash
|
||||
cd ~/Documents/projects/web_projects/gophergate && go build ./...
|
||||
```
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add internal/router/classifier.go
|
||||
git commit -m "feat: add classifier routing strategy with LLM complexity rating"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 4: Wire router into the server
|
||||
|
||||
**Objective:** Add the Router to the Server struct, initialize it, and inject it into `handleChatCompletions`.
|
||||
|
||||
**Files:**
|
||||
- Modify: `internal/server/server.go`
|
||||
|
||||
**Step 1: Add router field to Server struct**
|
||||
|
||||
In the `Server` struct (around line 23), add after the `registryMu` field:
|
||||
|
||||
```go
|
||||
router *router.Router
|
||||
```
|
||||
|
||||
**Step 2: Add import**
|
||||
|
||||
Add to the imports block:
|
||||
|
||||
```go
|
||||
"gophergate/internal/router"
|
||||
```
|
||||
|
||||
**Step 3: Initialize router in NewServer**
|
||||
|
||||
After `s.setupRoutes()` (line 66), add:
|
||||
|
||||
```go
|
||||
// Initialize model group router
|
||||
s.refreshRouter()
|
||||
```
|
||||
|
||||
**Step 4: Add refreshRouter method**
|
||||
|
||||
Add a new method on Server:
|
||||
|
||||
```go
|
||||
func (s *Server) refreshRouter() {
|
||||
var groups []db.ModelGroup
|
||||
if err := s.database.Select(&groups, "SELECT * FROM model_groups"); err != nil {
|
||||
fmt.Printf("Warning: Failed to load model groups: %v\n", err)
|
||||
groups = nil
|
||||
}
|
||||
|
||||
// Build classifier function using the OpenAI provider
|
||||
var classifyFn router.ClassifierFunc
|
||||
if openaiProvider, ok := s.providers["openai"]; ok {
|
||||
classifyFn = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) {
|
||||
req := &models.UnifiedRequest{
|
||||
Model: selectorModel,
|
||||
Messages: []models.UnifiedMessage{
|
||||
{Role: "system", Content: []models.ContentPart{{Type: "text", Text: systemPrompt}}},
|
||||
{Role: "user", Content: []models.ContentPart{{Type: "text", Text: userMessage}}},
|
||||
},
|
||||
MaxTokens: uint32Ptr(5),
|
||||
Stream: false,
|
||||
}
|
||||
resp, err := openaiProvider.ChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return "", fmt.Errorf("no choices in classifier response")
|
||||
}
|
||||
return resp.Choices[0].Message.Content, nil
|
||||
}
|
||||
}
|
||||
|
||||
if s.router == nil {
|
||||
s.router = router.New(groups, classifyFn)
|
||||
} else {
|
||||
s.router.Reload(groups)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Step 5: Add uint32Ptr helper (if not already in the codebase)**
|
||||
|
||||
At the bottom of server.go, add:
|
||||
|
||||
```go
|
||||
func uint32Ptr(v uint32) *uint32 { return &v }
|
||||
```
|
||||
|
||||
**Step 6: Inject router into handleChatCompletions**
|
||||
|
||||
In `handleChatCompletions`, after the model prefix stripping block (after line 475) and before building the UnifiedRequest (line 478), add:
|
||||
|
||||
```go
|
||||
// Check if model is a group and route to a concrete model
|
||||
if s.router != nil && s.router.IsGroup(modelID) {
|
||||
userMessage := extractUserMessage(req.Messages)
|
||||
decision, err := s.router.Route(c.Request.Context(), modelID, userMessage)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
|
||||
return
|
||||
}
|
||||
modelID = decision.SelectedModel
|
||||
log.Printf("[ROUTER] %s → %s (%s: %s)", req.Model, modelID, decision.Strategy, decision.Reason)
|
||||
}
|
||||
```
|
||||
|
||||
**Step 7: Add extractUserMessage helper**
|
||||
|
||||
```go
|
||||
func extractUserMessage(messages []models.ChatCompletionMessage) string {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == "user" {
|
||||
if s, ok := messages[i].Content.(string); ok {
|
||||
return s
|
||||
}
|
||||
// It might be a content array — grab text from first part
|
||||
if parts, ok := messages[i].Content.([]interface{}); ok && len(parts) > 0 {
|
||||
if part, ok := parts[0].(map[string]interface{}); ok {
|
||||
if text, ok := part["text"].(string); ok {
|
||||
return text
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
```
|
||||
|
||||
**Step 8: Add router refresh to RefreshProviders**
|
||||
|
||||
At the end of `RefreshProviders()` (before `return nil` at line 171), add:
|
||||
|
||||
```go
|
||||
s.refreshRouter()
|
||||
```
|
||||
|
||||
**Step 9: Build**
|
||||
|
||||
```bash
|
||||
cd ~/Documents/projects/web_projects/gophergate && go build ./...
|
||||
```
|
||||
|
||||
Expect compilation errors — need to check the `ChatCompletionMessage` type. The handler uses `models.ChatCompletionRequest` which has `Messages []ChatCompletionMessage`. Let me verify the type. If it's `[]models.ChatCompletionMessage` with `Content` as a string field, the helper is simpler. Fix as needed.
|
||||
|
||||
**Step 10: Commit**
|
||||
|
||||
```bash
|
||||
git add internal/server/server.go
|
||||
git commit -m "feat: wire model group router into chat completions handler"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 5: Add admin API endpoints for model groups
|
||||
|
||||
**Objective:** CRUD endpoints at `/api/model-groups` for dashboard management.
|
||||
|
||||
**Files:**
|
||||
- Create: `internal/server/model_groups_admin.go`
|
||||
|
||||
**Step 1: Create `internal/server/model_groups_admin.go`**
|
||||
|
||||
```go
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"gophergate/internal/db"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (s *Server) handleGetModelGroups(c *gin.Context) {
|
||||
var groups []db.ModelGroup
|
||||
if err := s.database.Select(&groups, "SELECT * FROM model_groups ORDER BY id"); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if groups == nil {
|
||||
groups = []db.ModelGroup{}
|
||||
}
|
||||
c.JSON(http.StatusOK, groups)
|
||||
}
|
||||
|
||||
func (s *Server) handleCreateModelGroup(c *gin.Context) {
|
||||
var group db.ModelGroup
|
||||
if err := c.ShouldBindJSON(&group); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
_, err := s.database.Exec(`
|
||||
INSERT INTO model_groups (id, strategy, selector_model, targets, complexity_threshold, heuristic_rules)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
group.ID, group.Strategy, group.SelectorModel, group.Targets,
|
||||
group.ComplexityThreshold, group.HeuristicRules)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
s.refreshRouter()
|
||||
c.JSON(http.StatusCreated, group)
|
||||
}
|
||||
|
||||
func (s *Server) handleUpdateModelGroup(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var group db.ModelGroup
|
||||
if err := c.ShouldBindJSON(&group); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
_, err := s.database.Exec(`
|
||||
UPDATE model_groups SET strategy=?, selector_model=?, targets=?, complexity_threshold=?, heuristic_rules=?, updated_at=CURRENT_TIMESTAMP
|
||||
WHERE id=?`,
|
||||
group.Strategy, group.SelectorModel, group.Targets,
|
||||
group.ComplexityThreshold, group.HeuristicRules, id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
s.refreshRouter()
|
||||
c.JSON(http.StatusOK, group)
|
||||
}
|
||||
|
||||
func (s *Server) handleDeleteModelGroup(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
_, err := s.database.Exec("DELETE FROM model_groups WHERE id=?", id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
s.refreshRouter()
|
||||
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: Register routes in setupRoutes()**
|
||||
|
||||
In `setupRoutes()`, add under the admin group (after the models endpoints around line 229):
|
||||
|
||||
```go
|
||||
admin.GET("/model-groups", s.handleGetModelGroups)
|
||||
admin.POST("/model-groups", s.handleCreateModelGroup)
|
||||
admin.PUT("/model-groups/:id", s.handleUpdateModelGroup)
|
||||
admin.DELETE("/model-groups/:id", s.handleDeleteModelGroup)
|
||||
```
|
||||
|
||||
**Step 3: Build**
|
||||
|
||||
```bash
|
||||
cd ~/Documents/projects/web_projects/gophergate && go build ./...
|
||||
```
|
||||
|
||||
**Step 4: Commit**
|
||||
|
||||
```bash
|
||||
git add internal/server/model_groups_admin.go internal/server/server.go
|
||||
git commit -m "feat: add model groups CRUD admin API endpoints"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 6: Add dashboard UI — sidebar entry and page module
|
||||
|
||||
**Objective:** Add a "Model Groups" tab to the dashboard sidebar and a page module for CRUD management.
|
||||
|
||||
**Files:**
|
||||
- Modify: `static/index.html`
|
||||
- Create: `static/js/pages/model_groups.js`
|
||||
|
||||
**Step 1: Add sidebar menu item in index.html**
|
||||
|
||||
In the MANAGEMENT section (after line 91, before `</ul>`), add:
|
||||
|
||||
```html
|
||||
<li class="menu-item" data-page="model-groups">
|
||||
<i class="fas fa-code-branch"></i>
|
||||
<span>Model Groups</span>
|
||||
</li>
|
||||
```
|
||||
|
||||
**Step 2: Add script tag in index.html**
|
||||
|
||||
After the users.js script (line 179), add:
|
||||
|
||||
```html
|
||||
<script src="/js/pages/model_groups.js?v=8"></script>
|
||||
```
|
||||
|
||||
**Step 3: Create `static/js/pages/model_groups.js`**
|
||||
|
||||
```javascript
|
||||
// Model Groups Management Page
|
||||
|
||||
class ModelGroupsPage {
|
||||
constructor() {
|
||||
this.container = document.getElementById('page-content');
|
||||
}
|
||||
|
||||
async render() {
|
||||
this.container.innerHTML = `
|
||||
<div class="page-header">
|
||||
<h3>Model Groups</h3>
|
||||
<p class="text-muted">Define auto-routing groups that pick the best model for each request.</p>
|
||||
<button class="btn btn-primary" onclick="modelGroupsPage.showCreateForm()">
|
||||
<i class="fas fa-plus"></i> Add Group
|
||||
</button>
|
||||
</div>
|
||||
<div id="model-groups-list" class="table-container"></div>
|
||||
<div id="model-group-form" class="form-container" style="display:none;"></div>
|
||||
`;
|
||||
await this.loadGroups();
|
||||
}
|
||||
|
||||
async loadGroups() {
|
||||
try {
|
||||
const groups = await api.get('/api/model-groups');
|
||||
const list = document.getElementById('model-groups-list');
|
||||
if (!groups || groups.length === 0) {
|
||||
list.innerHTML = '<div class="empty-state">No model groups defined. Create one to enable auto-routing.</div>';
|
||||
return;
|
||||
}
|
||||
|
||||
let targets;
|
||||
try { targets = JSON.parse(g.targets); } catch { targets = []; }
|
||||
const heuristicRules = g.heuristic_rules ? JSON.parse(g.heuristic_rules) : null;
|
||||
|
||||
let html = '<table class="data-table"><thead><tr>';
|
||||
html += '<th>Group ID</th><th>Strategy</th><th>Targets</th><th>Actions</th>';
|
||||
html += '</tr></thead><tbody>';
|
||||
|
||||
groups.forEach(g => {
|
||||
html += `<tr>
|
||||
<td><code>${this.esc(g.id)}</code></td>
|
||||
<td><span class="badge">${this.esc(g.strategy)}</span></td>
|
||||
<td><code>${this.esc(g.targets)}</code></td>
|
||||
<td>
|
||||
<button class="btn btn-sm" onclick="modelGroupsPage.showEditForm('${this.esc(g.id)}')">Edit</button>
|
||||
<button class="btn btn-sm btn-danger" onclick="modelGroupsPage.deleteGroup('${this.esc(g.id)}')">Delete</button>
|
||||
</td>
|
||||
</tr>`;
|
||||
});
|
||||
|
||||
html += '</tbody></table>';
|
||||
list.innerHTML = html;
|
||||
} catch (err) {
|
||||
document.getElementById('model-groups-list').innerHTML =
|
||||
`<div class="error-message">Failed to load model groups: ${this.esc(err.message)}</div>`;
|
||||
}
|
||||
}
|
||||
|
||||
showCreateForm() {
|
||||
this.renderForm(null);
|
||||
}
|
||||
|
||||
async showEditForm(id) {
|
||||
const groups = await api.get('/api/model-groups');
|
||||
const group = groups.find(g => g.id === id);
|
||||
if (group) this.renderForm(group);
|
||||
}
|
||||
|
||||
renderForm(group) {
|
||||
const isEdit = !!group;
|
||||
const form = document.getElementById('model-group-form');
|
||||
form.style.display = 'block';
|
||||
form.innerHTML = `
|
||||
<h4>${isEdit ? 'Edit' : 'Create'} Model Group</h4>
|
||||
<form onsubmit="modelGroupsPage.saveGroup(event, ${isEdit})">
|
||||
<div class="form-control">
|
||||
<label>Group ID</label>
|
||||
<input type="text" id="mg-id" value="${this.esc(group?.id || '')}" ${isEdit ? 'readonly' : 'required'}
|
||||
placeholder="e.g. deepseek-auto">
|
||||
<small>Clients use this as the model name.</small>
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label>Strategy</label>
|
||||
<select id="mg-strategy">
|
||||
<option value="heuristic" ${group?.strategy === 'heuristic' ? 'selected' : ''}>Heuristic (rules-based)</option>
|
||||
<option value="classifier" ${group?.strategy === 'classifier' ? 'selected' : ''}>Classifier (LLM judge)</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label>Targets (JSON array)</label>
|
||||
<input type="text" id="mg-targets" value='${this.esc(group?.targets || '["cheap-model","smart-model"]')}' required>
|
||||
<small>First target = cheapest/fastest. Last target = smartest/most expensive.</small>
|
||||
</div>
|
||||
<div class="form-control" id="mg-selector-row" ${group?.strategy === 'classifier' ? '' : 'style="display:none"'}>
|
||||
<label>Selector Model</label>
|
||||
<input type="text" id="mg-selector-model" value="${this.esc(group?.selector_model || 'gpt-4o-mini')}"
|
||||
placeholder="Model used to judge task complexity">
|
||||
</div>
|
||||
<div class="form-control" id="mg-threshold-row" ${group?.strategy === 'classifier' ? '' : 'style="display:none"'}>
|
||||
<label>Complexity Threshold</label>
|
||||
<input type="number" id="mg-threshold" value="${group?.complexity_threshold || ''}" min="1"
|
||||
placeholder="Tasks rated >= this go to the smart model">
|
||||
</div>
|
||||
<div class="form-control" id="mg-rules-row" ${group?.strategy === 'heuristic' ? '' : 'style="display:none"'}>
|
||||
<label>Heuristic Rules (JSON array)</label>
|
||||
<textarea id="mg-rules" rows="4" placeholder='[{"pattern":"step by step","target":1}]'>${group?.heuristic_rules || ''}</textarea>
|
||||
<small>Pattern to match in user messages. target = index into targets array.</small>
|
||||
</div>
|
||||
<div class="form-actions">
|
||||
<button type="submit" class="btn btn-primary">Save</button>
|
||||
<button type="button" class="btn" onclick="document.getElementById('model-group-form').style.display='none'">Cancel</button>
|
||||
</div>
|
||||
</form>
|
||||
`;
|
||||
|
||||
// Toggle strategy-specific fields
|
||||
document.getElementById('mg-strategy').onchange = function() {
|
||||
const isClassifier = this.value === 'classifier';
|
||||
document.getElementById('mg-selector-row').style.display = isClassifier ? '' : 'none';
|
||||
document.getElementById('mg-threshold-row').style.display = isClassifier ? '' : 'none';
|
||||
document.getElementById('mg-rules-row').style.display = isClassifier ? 'none' : '';
|
||||
};
|
||||
}
|
||||
|
||||
async saveGroup(event, isEdit) {
|
||||
event.preventDefault();
|
||||
const id = document.getElementById('mg-id').value.trim();
|
||||
const strategy = document.getElementById('mg-strategy').value;
|
||||
const targets = document.getElementById('mg-targets').value;
|
||||
const selectorModel = document.getElementById('mg-selector-model').value.trim() || null;
|
||||
const thresholdVal = document.getElementById('mg-threshold').value;
|
||||
const rules = document.getElementById('mg-rules').value.trim() || null;
|
||||
|
||||
// Validate JSON
|
||||
try { JSON.parse(targets); } catch { alert('Targets must be valid JSON array'); return; }
|
||||
if (rules) { try { JSON.parse(rules); } catch { alert('Heuristic rules must be valid JSON'); return; } }
|
||||
|
||||
const body = { id, strategy, targets, selector_model: selectorModel, heuristic_rules: rules };
|
||||
if (thresholdVal) body.complexity_threshold = parseInt(thresholdVal);
|
||||
|
||||
try {
|
||||
if (isEdit) {
|
||||
await api.put(`/api/model-groups/${encodeURIComponent(id)}`, body);
|
||||
} else {
|
||||
await api.post('/api/model-groups', body);
|
||||
}
|
||||
document.getElementById('model-group-form').style.display = 'none';
|
||||
await this.loadGroups();
|
||||
} catch (err) {
|
||||
alert('Failed to save: ' + err.message);
|
||||
}
|
||||
}
|
||||
|
||||
async deleteGroup(id) {
|
||||
if (!confirm(`Delete model group "${id}"?`)) return;
|
||||
try {
|
||||
await api.delete(`/api/model-groups/${encodeURIComponent(id)}`);
|
||||
await this.loadGroups();
|
||||
} catch (err) {
|
||||
alert('Failed to delete: ' + err.message);
|
||||
}
|
||||
}
|
||||
|
||||
esc(str) {
|
||||
if (!str) return '';
|
||||
return String(str).replace(/&/g,'&').replace(/</g,'<').replace(/>/g,'>').replace(/"/g,'"');
|
||||
}
|
||||
}
|
||||
|
||||
const modelGroupsPage = new ModelGroupsPage();
|
||||
```
|
||||
|
||||
**Step 4: Register page in dashboard.js**
|
||||
|
||||
In `static/js/dashboard.js`, find the page loading logic. The `loadPage` method dynamically imports page modules based on `this.currentPage`. The naming convention uses hyphens in `data-page` attributes (e.g., `data-page="model-groups"`). Check how the existing pages are loaded and ensure "model-groups" maps to the new module.
|
||||
|
||||
Looking at the existing pattern, pages are loaded via script tags in index.html and their constructors handle rendering when the page is navigated to. The dashboard.js `loadPage` method calls page-specific init. Let me check if there's a page registry pattern.
|
||||
|
||||
Actually, based on the index.html, pages are loaded as separate script files and the dashboard dispatches to them. The pattern seems to be: each page script defines a class or object, and the dashboard calls a `render()` or `init()` method on it when that page is selected. Let me add the dispatch logic.
|
||||
|
||||
In `dashboard.js`, find the `loadPage` method and ensure it handles "model-groups":
|
||||
|
||||
```javascript
|
||||
// In the loadPage switch/if-else, add:
|
||||
else if (page === 'model-groups') {
|
||||
if (typeof modelGroupsPage !== 'undefined') {
|
||||
modelGroupsPage.render();
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add static/index.html static/js/pages/model_groups.js static/js/dashboard.js
|
||||
git commit -m "feat: add model groups dashboard page with CRUD UI"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 7: Integration test — build, run, verify
|
||||
|
||||
**Objective:** Ensure everything compiles and the routing works end-to-end.
|
||||
|
||||
**Step 1: Full build**
|
||||
|
||||
```bash
|
||||
cd ~/Documents/projects/web_projects/gophergate && go build -o gophergate ./cmd/gophergate
|
||||
```
|
||||
|
||||
**Step 2: Start server and test**
|
||||
|
||||
```bash
|
||||
# In one terminal:
|
||||
./gophergate
|
||||
|
||||
# In another terminal, test that default groups loaded:
|
||||
curl -s -u admin:admin123 http://localhost:8080/api/model-groups | jq
|
||||
|
||||
# Expected: array with deepseek-auto and openai-auto groups
|
||||
```
|
||||
|
||||
**Step 3: Test routing via API**
|
||||
|
||||
```bash
|
||||
# Send a request using a model group
|
||||
curl -s http://localhost:8080/v1/chat/completions \
|
||||
-H "Authorization: Bearer YOUR_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "openai-auto",
|
||||
"messages": [{"role": "user", "content": "What is 2+2?"}]
|
||||
}' | jq
|
||||
|
||||
# Check server logs for [ROUTER] line showing the decision
|
||||
```
|
||||
|
||||
**Step 4: Commit any fixes**
|
||||
|
||||
If any issues found during testing, fix and commit.
|
||||
|
||||
---
|
||||
|
||||
## Architecture Notes
|
||||
|
||||
### Why this approach
|
||||
|
||||
- **No Provider interface changes** — the router is a pre-processing step in the handler, transparent to providers
|
||||
- **Groups stored in DB** — manageable from the dashboard, no config file sprawl
|
||||
- **Classifier is optional** — heuristic mode works with zero added latency or cost
|
||||
- **Fallback chain** — classifier failure falls back to heuristic; missing router falls back to direct passthrough
|
||||
|
||||
### Edge cases handled
|
||||
|
||||
- No groups defined → router never activates, all models pass through as before
|
||||
- Unknown group ID → returns error to client
|
||||
- Empty targets → returns error
|
||||
- Classifier call fails → falls back to heuristic
|
||||
- Classifier returns garbage → clamped to valid range
|
||||
- OpenAI provider disabled → classifier groups fall back to heuristic mode
|
||||
|
||||
### What's NOT in this plan (future work)
|
||||
|
||||
- Streaming classifier support (the ~300ms classifier call happens before streaming begins — acceptable for now)
|
||||
- responses endpoint routing (`handleResponses` could also use the router but needs a different message extraction)
|
||||
- Per-client group overrides
|
||||
- A/B testing / multi-armed bandit routing
|
||||
- Caching classifier decisions for identical messages
|
||||
@@ -46,6 +46,9 @@ Implements HMAC-SHA256 signed tokens for dashboard authentication. Tokens secure
|
||||
### 5. WebSocket Hub (`internal/server/websocket.go`)
|
||||
A centralized hub for managing WebSocket connections, allowing real-time broadcast of system events, system metrics, and request logs to the dashboard.
|
||||
|
||||
### 6. Model Group Router (`internal/router/`)
|
||||
Resolves model groups (e.g., `deepseek-auto`, `dustins_stack`) into concrete models. It supports a Classifier strategy (uses a cheap LLM to rate complexity) and an upgraded Heuristic strategy (evaluates custom condition rules like tags, token counts, multimodal inputs, reasoning, and tool calling flags or legacy keyword patterns).
|
||||
|
||||
## Concurrency Model
|
||||
|
||||
Go's goroutines and channels are used extensively:
|
||||
|
||||
@@ -7,11 +7,11 @@ A unified, high-performance LLM proxy gateway built in Go. It provides OpenAI-co
|
||||
- **Unified API:** OpenAI-compatible `/v1/chat/completions`, `/v1/images/generations`, `/v1/responses`, and `/v1/models` endpoints.
|
||||
- The `/v1/responses` endpoint (OpenAI Responses API) is currently supported for OpenAI models only. Non-OpenAI providers (Gemini, DeepSeek, Moonshot, Grok, Ollama) return a "not supported" response.
|
||||
- **Multi-Provider Support:**
|
||||
- **OpenAI:** GPT-4o, GPT-4o Mini, o1, o3 reasoning models, DALL-E 2/3 image generation.
|
||||
- **Google Gemini:** Gemini 2.0 Flash, Pro, and vision models (with native CoT support), Imagen 3 image generation.
|
||||
- **DeepSeek:** DeepSeek Chat and Reasoner (R1) models.
|
||||
- **Moonshot:** Kimi K2.5 and other Kimi models.
|
||||
- **xAI Grok:** Grok-4 models.
|
||||
- **OpenAI:** GPT-4o, GPT-4o Mini, GPT-5, GPT-5.4, o1/o3/o4 reasoning models, DALL-E 2/3 image generation.
|
||||
- **Google Gemini:** Gemini 2.5 Flash/Pro, Gemini 3 Flash/Pro previews, Imagen 3 image generation.
|
||||
- **DeepSeek:** DeepSeek Chat, Reasoner, V4 Flash, V4 Pro.
|
||||
- **Moonshot:** Kimi K2.5, K2.6 reasoning models.
|
||||
- **xAI Grok:** Grok-3, Grok-4, Grok-4.3 reasoning models.
|
||||
- **Ollama:** Local LLMs running on your network.
|
||||
- **Observability & Tracking:**
|
||||
- **Asynchronous Logging:** Non-blocking request logging to SQLite using background workers.
|
||||
@@ -20,13 +20,24 @@ A unified, high-performance LLM proxy gateway built in Go. It provides OpenAI-co
|
||||
- **Streaming Support:** Full SSE (Server-Sent Events) support for all providers.
|
||||
- **Multimodal (Vision):** Image processing (Base64 and remote URLs) across compatible providers.
|
||||
- **Image Generation:** DALL-E 2/3 (OpenAI) and Imagen 3 (Gemini) via OpenAI-compatible `/v1/images/generations` endpoint.
|
||||
- **Automatic Model Routing:**
|
||||
- **Hierarchical Routing:** Groups can target other groups, cascading through multiple levels until a concrete model is reached. Cycle detection and depth limiting (max 10) prevent infinite loops.
|
||||
- **Heuristic strategy:** Free, zero-latency routing supporting both keyword matching (regex/substrings) and condition-based checks (evaluating tags, token limits, multimodal inputs, reasoning, and tool calling requirements).
|
||||
- **Classifier strategy:** Uses a cheap LLM to rate task complexity on a configurable scale (1-10), then selects the appropriate model. Bucket mapping distributes ratings proportionally across targets.
|
||||
- **Two-Level Dispatch:** A `dispatcher` group (classifier, threshold=10) auto-routes to tier groups by complexity score, which then apply their own internal strategies.
|
||||
- **Metadata:** Groups support `logic_level` (1-10 complexity scale) and `primary_use` (description) fields for organizational clarity.
|
||||
- Pre-seeded with provider groups, tier groups (heavy-logic / standard-pro / fast-flow), and a dispatcher. Model groups are exposed in `/v1/models` so clients can discover them.
|
||||
- **Multi-User Access Control:**
|
||||
- **Admin Role:** Full access to all dashboard features, user management, and system configuration.
|
||||
- **Viewer Role:** Read-only access to usage analytics, costs, and monitoring.
|
||||
- **Client API Keys:** Create and manage multiple client tokens for external integrations.
|
||||
- **Reliability:**
|
||||
- **Circuit Breaking:** Automatically protects when providers are down (coming soon).
|
||||
- **Rate Limiting:** Per-client and global rate limits (coming soon).
|
||||
- **Circuit Breaking:** Protects providers when they are down, auto-recovers after timeout.
|
||||
- **Provider-Aware Classification:** Classifier selector models are routed to the correct provider automatically.
|
||||
|
||||
## DeepSeek Language Note
|
||||
|
||||
DeepSeek models default to Chinese for some prompts. GopherGate automatically injects an English system prompt ("Always respond in English.") when no system message is present. If the client provides its own system prompt, it is left untouched.
|
||||
|
||||
## Security
|
||||
|
||||
@@ -71,7 +82,9 @@ GopherGate is designed with security in mind:
|
||||
# LLM_PROXY__ENCRYPTION_KEY=... (32-byte hex or base64 string)
|
||||
# OPENAI_API_KEY=sk-...
|
||||
# GEMINI_API_KEY=AIza...
|
||||
# DEEPSEEK_API_KEY=sk-...
|
||||
# MOONSHOT_API_KEY=...
|
||||
# GROK_API_KEY=xai-...
|
||||
# For Ollama (optional): Set base URL and enable
|
||||
# LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL=http://localhost:11434/v1
|
||||
# LLM_PROXY__PROVIDERS__OLLAMA__ENABLED=true
|
||||
@@ -83,7 +96,16 @@ GopherGate is designed with security in mind:
|
||||
./gophergate
|
||||
```
|
||||
|
||||
The server starts on `http://0.0.0.0:8080` by default.
|
||||
The server starts on `http://0.0.0.0:8080` by default. Configure `LLM_PROXY__SERVER__PORT` in `.env` to change it.
|
||||
|
||||
### Quick Deploy Script
|
||||
|
||||
A `deploy.sh` script is included for production restarts:
|
||||
|
||||
```bash
|
||||
./deploy.sh
|
||||
# git pull -> go build -> stop old process -> start new process
|
||||
```
|
||||
|
||||
### Deployment (Docker)
|
||||
|
||||
@@ -106,6 +128,8 @@ Access the dashboard at `http://localhost:8080`.
|
||||
- **Usage:** Summary stats, time-series analytics, and provider breakdown.
|
||||
- **Clients:** API key management and per-client usage tracking.
|
||||
- **Providers:** Provider configuration and status monitoring.
|
||||
- **Model Groups:** Define auto-routing groups with heuristic or classifier strategies. Supports logic level and primary use metadata.
|
||||
- **Models:** Model enable/disable and cost configuration.
|
||||
- **Users:** Admin-only user management for dashboard access.
|
||||
- **Monitoring:** Live request stream via WebSocket.
|
||||
|
||||
@@ -125,14 +149,6 @@ You can reset the admin password to default by running:
|
||||
|
||||
The proxy is a drop-in replacement for OpenAI. Configure your client:
|
||||
|
||||
Moonshot models are available through the same OpenAI-compatible endpoint. For
|
||||
example, use `kimi-k2.5` as the model name after setting `MOONSHOT_API_KEY` in
|
||||
your environment.
|
||||
|
||||
Ollama models (like `llama3`, `gemma2`, `mistral`) are also available through the same
|
||||
endpoint after enabling Ollama in configuration and setting the base URL to your
|
||||
Ollama server (default: `http://localhost:11434/v1`).
|
||||
|
||||
### Python
|
||||
|
||||
```python
|
||||
@@ -170,7 +186,60 @@ response = client.responses.create(
|
||||
print(response.output_text)
|
||||
```
|
||||
|
||||
**Note:** The `/v1/responses` endpoint is currently supported for OpenAI models only. Requests routed to Gemini, DeepSeek, Moonshot, Grok, or Ollama models return a "not supported" error.
|
||||
**Note:** The `/v1/responses` endpoint is currently supported for OpenAI models only.
|
||||
|
||||
### Automatic Model Routing
|
||||
|
||||
Use a model group name to let gophergate pick the best model automatically:
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1",
|
||||
api_key="YOUR_CLIENT_API_KEY"
|
||||
)
|
||||
|
||||
# Simple query -- routes to the cheap/fast model
|
||||
response = client.chat.completions.create(
|
||||
model="fast-flow",
|
||||
messages=[{"role": "user", "content": "What is 2+2?"}]
|
||||
)
|
||||
|
||||
# Complex query -- routes to the reasoning model automatically
|
||||
response = client.chat.completions.create(
|
||||
model="heavy-logic",
|
||||
messages=[{"role": "user", "content": "Write a Python red-black tree implementation."}]
|
||||
)
|
||||
```
|
||||
|
||||
### Two-Level Dispatch
|
||||
|
||||
The `dispatcher` group uses a classifier to score prompts 1-10, then routes to the appropriate tier group:
|
||||
|
||||
```python
|
||||
# Automatically routed based on complexity:
|
||||
# 1-3 -> fast-flow (classification, basic Q&A)
|
||||
# 4-7 -> standard-pro (general assistant, long docs)
|
||||
# 8-10 -> heavy-logic (complex coding, logic, agents)
|
||||
response = client.chat.completions.create(
|
||||
model="dispatcher",
|
||||
messages=[{"role": "user", "content": "Debug this race condition in my Go code."}]
|
||||
)
|
||||
# This goes: dispatcher -> heavy-logic -> deepseek-v4-pro
|
||||
```
|
||||
|
||||
Pre-seeded groups:
|
||||
|
||||
| Group | Level | Strategy | Targets | Primary Use |
|
||||
|-------|-------|----------|---------|-------------|
|
||||
| `fast-flow` | 2 | heuristic | deepseek-v4-flash, gpt-5.4-nano | Classification, JSON, Basic Q&A |
|
||||
| `standard-pro` | 5 | heuristic | gpt-5.4-mini, gemini-3-flash-preview | General Assistant, Long Docs |
|
||||
| `heavy-logic` | 9 | heuristic | grok-4.3, kimi-k2.6, deepseek-v4-pro | Complex Coding, Logic, Agents |
|
||||
| `dispatcher` | - | classifier | fast-flow, standard-pro, heavy-logic | Auto-dispatches by complexity |
|
||||
| `deepseek-auto` | - | heuristic | deepseek-chat, deepseek-reasoner | Legacy provider group |
|
||||
| `openai-auto` | - | heuristic | gpt-4o-mini, gpt-4o | Legacy provider group |
|
||||
| `gemini-auto` | - | heuristic | gemini-2.0-flash, gemini-2.5-pro | Legacy provider group |
|
||||
|
||||
### Image Generation (DALL-E / Imagen)
|
||||
|
||||
@@ -191,7 +260,7 @@ resp = client.images.generate(
|
||||
)
|
||||
print(resp.data[0].url)
|
||||
|
||||
# Imagen 3 (Gemini) — uses same endpoint
|
||||
# Imagen 3 (Gemini) -- uses same endpoint
|
||||
resp = client.images.generate(
|
||||
model="imagen-3.0-generate-001",
|
||||
prompt="A gopher coding in Go",
|
||||
|
||||
@@ -15,11 +15,24 @@
|
||||
- [x] Dashboard Analytics & Usage Summary (Fixed SQL robustness)
|
||||
- [x] WebSocket for real-time dashboard updates (Hub with client counting)
|
||||
- [x] Asynchronous Request Logging to SQLite
|
||||
- [x] Update documentation (README, deployment, architecture)
|
||||
- [x] Cost Tracking accuracy (Registry integration with `models.dev`)
|
||||
- [x] Model Listing endpoint (`/v1/models`) with provider filtering
|
||||
- [x] System Metrics endpoint (`/api/system/metrics` using `gopsutil`)
|
||||
- [x] Fixed dashboard 404s and 500s
|
||||
- [x] Model groups with heuristic and classifier routing strategies
|
||||
- [x] Hierarchical routing — groups can target other groups with cycle detection
|
||||
- [x] Classifier bucket mapping via complexity_threshold (1-10 scale -> N targets)
|
||||
- [x] Two-level dispatch — classifier router delegates to tier groups
|
||||
- [x] Model groups exposed in /v1/models endpoint (owned_by: gophergate)
|
||||
- [x] logic_level and primary_use metadata on model groups
|
||||
- [x] Model group CRUD dashboard page
|
||||
- [x] dispatcher, heavy-logic, standard-pro, fast-flow seed groups
|
||||
- [x] Provider selection moved after routing resolution (fixes group routing)
|
||||
- [x] Classifier selector model routed to correct provider (selectProvider)
|
||||
- [x] DeepSeek English system prompt injection (ensureEnglish)
|
||||
- [x] Deploy script (deploy.sh)
|
||||
- [x] Recent Activity pane shows resolved model + group annotation
|
||||
- [x] Model names aligned with models.dev registry
|
||||
|
||||
## Planned Resolutions (High Priority)
|
||||
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Define the service name/path for easy updates
|
||||
BINARY_NAME="gophergate"
|
||||
SOURCE_PATH="./cmd/gophergate/main.go"
|
||||
|
||||
echo "Stopping existing $BINARY_NAME processes..."
|
||||
# Using pkill; || true ensures the script continues even if no process was found
|
||||
pkill -9 "$BINARY_NAME" || echo "No running process found."
|
||||
|
||||
echo "Pulling latest changes from git..."
|
||||
git pull
|
||||
|
||||
echo "Building the application..."
|
||||
if go build -o "$BINARY_NAME" "$SOURCE_PATH"; then
|
||||
echo "Build successful. Starting $BINARY_NAME in the background..."
|
||||
# Launch with nohup and redirect output to a log file
|
||||
nohup "./$BINARY_NAME" > gophergate.log 2>&1 &
|
||||
echo "Service started. PID: $!"
|
||||
else
|
||||
echo "Build failed! Keeping the previous state."
|
||||
exit 1
|
||||
fi
|
||||
@@ -26,6 +26,22 @@ go build -o gophergate ./cmd/gophergate
|
||||
./gophergate
|
||||
```
|
||||
|
||||
### Quick Deploy Script
|
||||
|
||||
A `deploy.sh` script is provided for production restarts:
|
||||
|
||||
```bash
|
||||
./deploy.sh
|
||||
```
|
||||
|
||||
This script will:
|
||||
1. Stop any running gophergate process
|
||||
2. Pull latest changes from git
|
||||
3. Build the application
|
||||
4. Start it in the background (logs to `gophergate.log`)
|
||||
|
||||
If the build fails, the previous binary is left untouched and the script exits.
|
||||
|
||||
## Docker Deployment
|
||||
|
||||
The project includes a multi-stage `Dockerfile` for minimal image size.
|
||||
@@ -50,3 +66,4 @@ docker run -d \
|
||||
- **SSL/TLS:** It is recommended to run the proxy behind a reverse proxy like Nginx or Caddy for SSL termination.
|
||||
- **Backups:** Regularly backup the `data/llm_proxy.db` file.
|
||||
- **Monitoring:** Monitor the `/health` endpoint for system status.
|
||||
- **Logs:** When started with `deploy.sh` or `nohup`, logs are written to `gophergate.log`.
|
||||
|
||||
@@ -37,6 +37,7 @@ type ProviderConfig struct {
|
||||
Moonshot MoonshotConfig `mapstructure:"moonshot"`
|
||||
Grok GrokConfig `mapstructure:"grok"`
|
||||
Ollama OllamaConfig `mapstructure:"ollama"`
|
||||
Xiaomi XiaomiConfig `mapstructure:"xiaomi"`
|
||||
}
|
||||
|
||||
type OpenAIConfig struct {
|
||||
@@ -81,6 +82,13 @@ type OllamaConfig struct {
|
||||
Models []string `mapstructure:"models"`
|
||||
}
|
||||
|
||||
type XiaomiConfig struct {
|
||||
APIKeyEnv string `mapstructure:"api_key_env"`
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
DefaultModel string `mapstructure:"default_model"`
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
v := viper.New()
|
||||
|
||||
@@ -120,6 +128,11 @@ func Load() (*Config, error) {
|
||||
v.SetDefault("providers.ollama.enabled", false)
|
||||
v.SetDefault("providers.ollama.models", []string{})
|
||||
|
||||
v.SetDefault("providers.xiaomi.api_key_env", "XIAOMI_API_KEY")
|
||||
v.SetDefault("providers.xiaomi.base_url", "https://api.xiaomimimo.com/v1")
|
||||
v.SetDefault("providers.xiaomi.default_model", "mimo-v2.5")
|
||||
v.SetDefault("providers.xiaomi.enabled", true)
|
||||
|
||||
// Environment variables
|
||||
v.SetEnvPrefix("LLM_PROXY")
|
||||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "__"))
|
||||
@@ -210,6 +223,8 @@ func (c *Config) GetAPIKey(provider string) (string, error) {
|
||||
case "ollama":
|
||||
// Ollama doesn't require an API key
|
||||
return "", nil
|
||||
case "xiaomi":
|
||||
envVar = c.Providers.Xiaomi.APIKeyEnv
|
||||
default:
|
||||
return "", fmt.Errorf("unknown provider: %s", provider)
|
||||
}
|
||||
|
||||
@@ -32,6 +32,14 @@ func Init(path string) (*DB, error) {
|
||||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
// Enable Write-Ahead Logging (WAL) and set a busy timeout to handle concurrent access
|
||||
if _, err := db.Exec("PRAGMA journal_mode=WAL;"); err != nil {
|
||||
log.Printf("failed to enable WAL mode: %v", err)
|
||||
}
|
||||
if _, err := db.Exec("PRAGMA busy_timeout=5000;"); err != nil {
|
||||
log.Printf("failed to set busy timeout: %v", err)
|
||||
}
|
||||
|
||||
instance := &DB{db}
|
||||
|
||||
// Run migrations
|
||||
@@ -122,6 +130,18 @@ func (db *DB) RunMigrations() error {
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
last_used_at DATETIME,
|
||||
FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE CASCADE
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS model_groups (
|
||||
id TEXT PRIMARY KEY,
|
||||
strategy TEXT NOT NULL DEFAULT 'heuristic',
|
||||
selector_model TEXT,
|
||||
targets TEXT NOT NULL DEFAULT '[]',
|
||||
complexity_threshold INTEGER,
|
||||
heuristic_rules TEXT,
|
||||
logic_level INTEGER,
|
||||
primary_use TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)`,
|
||||
}
|
||||
|
||||
@@ -152,6 +172,10 @@ func (db *DB) RunMigrations() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Add columns to existing model_groups tables (safe — SQLite ignores duplicates on error)
|
||||
db.Exec("ALTER TABLE model_groups ADD COLUMN logic_level INTEGER")
|
||||
db.Exec("ALTER TABLE model_groups ADD COLUMN primary_use TEXT")
|
||||
|
||||
// Default admin user
|
||||
var count int
|
||||
if err := db.Get(&count, "SELECT COUNT(*) FROM users"); err != nil {
|
||||
@@ -177,6 +201,25 @@ func (db *DB) RunMigrations() error {
|
||||
return fmt.Errorf("failed to insert default client: %w", err)
|
||||
}
|
||||
|
||||
// Seed default model groups
|
||||
defaultGroups := []struct {
|
||||
id, strategy, targets, selectorModel string
|
||||
complexityThreshold, logicLevel *int
|
||||
primaryUse *string
|
||||
}{
|
||||
{"deepseek-auto", "heuristic", `["deepseek-chat","deepseek-reasoner"]`, "", nil, nil, nil},
|
||||
{"openai-auto", "heuristic", `["gpt-4o-mini","gpt-4o"]`, "", nil, nil, nil},
|
||||
{"gemini-auto", "heuristic", `["gemini-2.0-flash","gemini-2.5-pro"]`, "", nil, nil, nil},
|
||||
{"heavy-logic", "heuristic", `["grok-4.3","kimi-k2.6","deepseek-v4-pro"]`, "", nil, intPtr(9), strPtr("Complex Coding, Logic, Agents.")},
|
||||
{"standard-pro", "heuristic", `["gpt-5.4-mini","gemini-3-flash-preview"]`, "", nil, intPtr(5), strPtr("General Assistant, Long Docs.")},
|
||||
{"fast-flow", "heuristic", `["deepseek-v4-flash","gpt-5.4-nano"]`, "", nil, intPtr(2), strPtr("Classification, JSON, Basic Q&A.")},
|
||||
{"dispatcher", "classifier", `["fast-flow","standard-pro","heavy-logic"]`, "gpt-5.4-nano", intPtr(10), nil, strPtr("Auto-dispatches to tier groups by complexity.")},
|
||||
}
|
||||
for _, g := range defaultGroups {
|
||||
db.Exec(`INSERT OR IGNORE INTO model_groups (id, strategy, targets, selector_model, complexity_threshold, logic_level, primary_use) VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
g.id, g.strategy, g.targets, nilStr(g.selectorModel), g.complexityThreshold, g.logicLevel, g.primaryUse)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -262,3 +305,27 @@ type ClientToken struct {
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
LastUsedAt *time.Time `db:"last_used_at"`
|
||||
}
|
||||
|
||||
type ModelGroup struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
Strategy string `db:"strategy" json:"strategy"`
|
||||
SelectorModel *string `db:"selector_model" json:"selector_model"`
|
||||
Targets string `db:"targets" json:"targets"` // JSON array
|
||||
ComplexityThreshold *int `db:"complexity_threshold" json:"complexity_threshold"`
|
||||
HeuristicRules *string `db:"heuristic_rules" json:"heuristic_rules"`
|
||||
LogicLevel *int `db:"logic_level" json:"logic_level"`
|
||||
PrimaryUse *string `db:"primary_use" json:"primary_use"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
func intPtr(v int) *int { return &v }
|
||||
func strPtr(v string) *string { return &v }
|
||||
|
||||
// nilStr returns a *string for non-empty strings, nil for empty.
|
||||
func nilStr(v string) *string {
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
return &v
|
||||
}
|
||||
|
||||
@@ -14,9 +14,21 @@ import (
|
||||
func AuthMiddleware(database *db.DB, requireAuth bool) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
// Fallback to checking "Authentication" header in case the client library used the wrong name
|
||||
authHeader = c.GetHeader("Authentication")
|
||||
}
|
||||
|
||||
if authHeader == "" {
|
||||
if requireAuth {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing authorization header"})
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Missing Authorization or Authentication header.",
|
||||
"type": "invalid_request_error",
|
||||
"param": nil,
|
||||
"code": "401",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
@@ -25,23 +37,50 @@ func AuthMiddleware(database *db.DB, requireAuth bool) gin.HandlerFunc {
|
||||
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if token == authHeader { // No "Bearer " prefix
|
||||
if requireAuth {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Invalid authorization header format. Bearer token required.",
|
||||
"type": "invalid_request_error",
|
||||
"param": nil,
|
||||
"code": "401",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Try to resolve client from database
|
||||
// Try to resolve client from database with a read-only SELECT
|
||||
var clientID string
|
||||
err := database.Get(&clientID, "UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = 1 RETURNING client_id", token)
|
||||
err := database.Get(&clientID, "SELECT client_id FROM client_tokens WHERE token = ? AND is_active = 1", token)
|
||||
|
||||
if err == nil {
|
||||
c.Set("auth", models.AuthInfo{
|
||||
Token: token,
|
||||
ClientID: clientID,
|
||||
})
|
||||
|
||||
// Update last_used_at asynchronously so that database locks or write delays
|
||||
// do not block or fail the client's request authentication.
|
||||
go func(t string) {
|
||||
if _, updateErr := database.Exec("UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ?", t); updateErr != nil {
|
||||
log.Printf("Warning: failed to update client token last_used_at: %v", updateErr)
|
||||
}
|
||||
}(token)
|
||||
|
||||
c.Next()
|
||||
} else {
|
||||
log.Printf("Token not found or inactive in DB: %s", token)
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid or inactive token"})
|
||||
log.Printf("Token not found, inactive or error in DB: %s (err: %v)", token, err)
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Invalid or inactive client token.",
|
||||
"type": "invalid_request_error",
|
||||
"param": nil,
|
||||
"code": "401",
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+151
-18
@@ -2,6 +2,25 @@ package models
|
||||
|
||||
import "strings"
|
||||
|
||||
// CanonicalProviders lists the original model creators in priority order.
|
||||
// When a model name exists in multiple providers (e.g. deepseek-v4-pro in
|
||||
// deepseek, ollama-cloud, openrouter, etc.), these providers take precedence
|
||||
// so the proxy uses authoritative metadata (pricing, limits) rather than a
|
||||
// reseller's values.
|
||||
var CanonicalProviders = []string{
|
||||
"openai",
|
||||
"google",
|
||||
"deepseek",
|
||||
"xai",
|
||||
"moonshotai",
|
||||
"moonshotai-cn",
|
||||
"anthropic",
|
||||
"mistral",
|
||||
"cohere",
|
||||
"minimax",
|
||||
"xiaomi",
|
||||
}
|
||||
|
||||
type ModelRegistry struct {
|
||||
Providers map[string]ProviderInfo `json:"-"`
|
||||
}
|
||||
@@ -39,40 +58,154 @@ type ModelModalities struct {
|
||||
Output []string `json:"output"`
|
||||
}
|
||||
|
||||
func (r *ModelRegistry) FindModel(modelID string) *ModelMetadata {
|
||||
// First try exact match in models map
|
||||
for _, provider := range r.Providers {
|
||||
if model, ok := provider.Models[modelID]; ok {
|
||||
return &model
|
||||
// findInCanonical searches the canonical providers in order for an exact model
|
||||
// key match. Returns the metadata and true if found.
|
||||
func (r *ModelRegistry) findInCanonical(modelID string) (*ModelMetadata, bool) {
|
||||
for _, key := range CanonicalProviders {
|
||||
if p, ok := r.Providers[key]; ok {
|
||||
if m, ok := p.Models[modelID]; ok {
|
||||
return &m, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Try searching by ID in metadata
|
||||
for _, provider := range r.Providers {
|
||||
for _, model := range provider.Models {
|
||||
if model.ID == modelID {
|
||||
return &model
|
||||
// findInAll searches all providers (map iteration, random order) for an exact
|
||||
// model key match. Used as fallback when canonical search fails.
|
||||
func (r *ModelRegistry) findInAll(modelID string) (*ModelMetadata, bool) {
|
||||
for _, p := range r.Providers {
|
||||
if m, ok := p.Models[modelID]; ok {
|
||||
return &m, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Try reverse fuzzy matching (e.g. 'gpt-5.4-mini' matching 'gpt-5.4-mini-2026-04-01')
|
||||
for _, provider := range r.Providers {
|
||||
for id, model := range provider.Models {
|
||||
// findInCanonicalByID searches canonical providers for a model whose metadata
|
||||
// ID field matches modelID.
|
||||
func (r *ModelRegistry) findInCanonicalByID(modelID string) (*ModelMetadata, bool) {
|
||||
for _, key := range CanonicalProviders {
|
||||
if p, ok := r.Providers[key]; ok {
|
||||
for _, m := range p.Models {
|
||||
if m.ID == modelID {
|
||||
return &m, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findInAllByID searches all providers for a model whose metadata ID field
|
||||
// matches modelID.
|
||||
func (r *ModelRegistry) findInAllByID(modelID string) (*ModelMetadata, bool) {
|
||||
for _, p := range r.Providers {
|
||||
for _, m := range p.Models {
|
||||
if m.ID == modelID {
|
||||
return &m, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findCanonicalReverseFuzzy searches canonical providers for any model whose
|
||||
// key starts with modelID.
|
||||
func (r *ModelRegistry) findCanonicalReverseFuzzy(modelID string) (*ModelMetadata, bool) {
|
||||
for _, key := range CanonicalProviders {
|
||||
if p, ok := r.Providers[key]; ok {
|
||||
for id, m := range p.Models {
|
||||
if strings.HasPrefix(id, modelID) {
|
||||
return &model
|
||||
return &m, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Try fuzzy matching (e.g. 'gpt-4o-2024-05-13' matching 'gpt-4o')
|
||||
for _, provider := range r.Providers {
|
||||
for id, model := range provider.Models {
|
||||
// findAllReverseFuzzy searches all providers for any model whose key starts
|
||||
// with modelID.
|
||||
func (r *ModelRegistry) findAllReverseFuzzy(modelID string) (*ModelMetadata, bool) {
|
||||
for _, p := range r.Providers {
|
||||
for id, m := range p.Models {
|
||||
if strings.HasPrefix(id, modelID) {
|
||||
return &m, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findCanonicalForwardFuzzy searches canonical providers for any model whose
|
||||
// key is a prefix of modelID.
|
||||
func (r *ModelRegistry) findCanonicalForwardFuzzy(modelID string) (*ModelMetadata, bool) {
|
||||
for _, key := range CanonicalProviders {
|
||||
if p, ok := r.Providers[key]; ok {
|
||||
for id, m := range p.Models {
|
||||
if strings.HasPrefix(modelID, id) {
|
||||
return &model
|
||||
return &m, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findAllForwardFuzzy searches all providers for any model whose key is a
|
||||
// prefix of modelID.
|
||||
func (r *ModelRegistry) findAllForwardFuzzy(modelID string) (*ModelMetadata, bool) {
|
||||
for _, p := range r.Providers {
|
||||
for id, m := range p.Models {
|
||||
if strings.HasPrefix(modelID, id) {
|
||||
return &m, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// FindModel looks up model metadata by ID. It searches canonical providers
|
||||
// first at each strategy level (exact key, metadata ID, reverse fuzzy,
|
||||
// forward fuzzy) and falls back to all providers only when canonical search
|
||||
// yields no result. This prevents reseller entries (ollama-cloud, openrouter,
|
||||
// etc.) from overriding the original provider's authoritative pricing and
|
||||
// limits.
|
||||
func (r *ModelRegistry) FindModel(modelID string) *ModelMetadata {
|
||||
// 1. Exact key match — canonical first, then all
|
||||
if m, ok := r.findInCanonical(modelID); ok {
|
||||
return m
|
||||
}
|
||||
if m, ok := r.findInAll(modelID); ok {
|
||||
return m
|
||||
}
|
||||
|
||||
// 2. Match by metadata ID field — canonical first, then all
|
||||
if m, ok := r.findInCanonicalByID(modelID); ok {
|
||||
return m
|
||||
}
|
||||
if m, ok := r.findInAllByID(modelID); ok {
|
||||
return m
|
||||
}
|
||||
|
||||
// 3. Reverse fuzzy: model key starts with modelID
|
||||
// e.g. 'gpt-5.4-mini' matching 'gpt-5.4-mini-2026-04-01'
|
||||
if m, ok := r.findCanonicalReverseFuzzy(modelID); ok {
|
||||
return m
|
||||
}
|
||||
if m, ok := r.findAllReverseFuzzy(modelID); ok {
|
||||
return m
|
||||
}
|
||||
|
||||
// 4. Forward fuzzy: modelID starts with model key
|
||||
// e.g. 'gpt-4o-2024-05-13' matching 'gpt-4o'
|
||||
if m, ok := r.findCanonicalForwardFuzzy(modelID); ok {
|
||||
return m
|
||||
}
|
||||
if m, ok := r.findAllForwardFuzzy(modelID); ok {
|
||||
return m
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -59,6 +59,35 @@ func TestModelRegistry_FindModel_NotFound(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRegistry_FindModel_CanonicalPriority(t *testing.T) {
|
||||
// Same model name in canonical (deepseek) and reseller (ollama-cloud).
|
||||
// Canonical must win so the proxy uses authoritative limits.
|
||||
r := &ModelRegistry{
|
||||
Providers: map[string]ProviderInfo{
|
||||
"ollama-cloud": {
|
||||
Models: map[string]ModelMetadata{
|
||||
"deepseek-v4-pro": {ID: "deepseek-v4-pro", Name: "DSv4 Pro (Ollama Cloud)", Limit: &ModelLimit{Context: 1048576, Output: 1048576}},
|
||||
},
|
||||
},
|
||||
"deepseek": {
|
||||
Models: map[string]ModelMetadata{
|
||||
"deepseek-v4-pro": {ID: "deepseek-v4-pro", Name: "DeepSeek v4 Pro", Limit: &ModelLimit{Context: 1000000, Output: 384000}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
m := r.FindModel("deepseek-v4-pro")
|
||||
if m == nil {
|
||||
t.Fatal("expected to find deepseek-v4-pro")
|
||||
}
|
||||
if m.Name != "DeepSeek v4 Pro" {
|
||||
t.Fatalf("expected DeepSeek v4 Pro (canonical), got %s", m.Name)
|
||||
}
|
||||
if m.Limit.Output != 384000 {
|
||||
t.Fatalf("expected output limit 384000 (canonical), got %d", m.Limit.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRegistry_FindModel_ReverseFuzzy(t *testing.T) {
|
||||
r := &ModelRegistry{
|
||||
Providers: map[string]ProviderInfo{
|
||||
|
||||
@@ -62,6 +62,9 @@ func (u *deepSeekUsage) ToUnified() *models.Usage {
|
||||
}
|
||||
|
||||
func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
// Ensure English responses — DeepSeek defaults to Chinese for some prompts
|
||||
ensureEnglish(req)
|
||||
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
@@ -69,17 +72,24 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, false)
|
||||
|
||||
// Sanitize for deepseek-reasoner
|
||||
if req.Model == "deepseek-reasoner" {
|
||||
// Sanitize for models that support reasoning/thinking mode
|
||||
isReasoner := strings.Contains(req.Model, "reasoner") || strings.Contains(req.Model, "v4") || strings.Contains(req.Model, "r1")
|
||||
|
||||
if isReasoner {
|
||||
// deepseek-reasoner (R1) does not support these parameters
|
||||
if req.Model == "deepseek-reasoner" || strings.HasPrefix(req.Model, "deepseek-r1") {
|
||||
delete(body, "temperature")
|
||||
delete(body, "top_p")
|
||||
delete(body, "presence_penalty")
|
||||
delete(body, "frequency_penalty")
|
||||
}
|
||||
|
||||
if msgs, ok := body["messages"].([]interface{}); ok {
|
||||
for _, m := range msgs {
|
||||
if msg, ok := m.(map[string]interface{}); ok {
|
||||
if msg["role"] == "assistant" {
|
||||
// DeepSeek requires reasoning_content to be passed back in history
|
||||
// if the model is in thinking mode.
|
||||
if msg["reasoning_content"] == nil {
|
||||
msg["reasoning_content"] = ""
|
||||
}
|
||||
@@ -103,7 +113,15 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
var msg string
|
||||
if resp.RawBody() != nil {
|
||||
bodyBytes, _ := io.ReadAll(resp.RawBody())
|
||||
msg = string(bodyBytes)
|
||||
}
|
||||
if msg == "" {
|
||||
msg = resp.String()
|
||||
}
|
||||
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
@@ -129,6 +147,8 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi
|
||||
}
|
||||
|
||||
func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
ensureEnglish(req)
|
||||
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
@@ -136,17 +156,24 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, true)
|
||||
|
||||
// Sanitize for deepseek-reasoner
|
||||
if req.Model == "deepseek-reasoner" {
|
||||
// Sanitize for models that support reasoning/thinking mode
|
||||
isReasoner := strings.Contains(req.Model, "reasoner") || strings.Contains(req.Model, "v4") || strings.Contains(req.Model, "r1")
|
||||
|
||||
if isReasoner {
|
||||
// deepseek-reasoner (R1) does not support these parameters
|
||||
if req.Model == "deepseek-reasoner" || strings.HasPrefix(req.Model, "deepseek-r1") {
|
||||
delete(body, "temperature")
|
||||
delete(body, "top_p")
|
||||
delete(body, "presence_penalty")
|
||||
delete(body, "frequency_penalty")
|
||||
}
|
||||
|
||||
if msgs, ok := body["messages"].([]interface{}); ok {
|
||||
for _, m := range msgs {
|
||||
if msg, ok := m.(map[string]interface{}); ok {
|
||||
if msg["role"] == "assistant" {
|
||||
// DeepSeek requires reasoning_content to be passed back in history
|
||||
// if the model is in thinking mode.
|
||||
if msg["reasoning_content"] == nil {
|
||||
msg["reasoning_content"] = ""
|
||||
}
|
||||
@@ -171,7 +198,15 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
var msg string
|
||||
if resp.RawBody() != nil {
|
||||
bodyBytes, _ := io.ReadAll(resp.RawBody())
|
||||
msg = string(bodyBytes)
|
||||
}
|
||||
if msg == "" {
|
||||
msg = resp.String()
|
||||
}
|
||||
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -125,7 +126,13 @@ func (p *GeminiProvider) ImageGeneration(ctx context.Context, req *models.ImageG
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Gemini Imagen API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Gemini Imagen API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
// Parse Imagen response
|
||||
@@ -331,6 +338,7 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
}
|
||||
|
||||
// Map Tools
|
||||
hasMappedTools := false
|
||||
if len(req.Tools) > 0 {
|
||||
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
|
||||
for _, t := range req.Tools {
|
||||
@@ -338,13 +346,16 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
|
||||
}
|
||||
}
|
||||
if len(geminiTool.FunctionDeclarations) > 0 {
|
||||
body.Tools = []GeminiTool{geminiTool}
|
||||
hasMappedTools = true
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := p.config.BaseURL
|
||||
lowerModel := strings.ToLower(req.Model)
|
||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
|
||||
// Use v1beta for preview and newer models
|
||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") || hasMappedTools {
|
||||
// Use v1beta for preview, newer models, or when using tools
|
||||
if !strings.Contains(baseURL, "v1beta") {
|
||||
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
|
||||
}
|
||||
@@ -363,11 +374,17 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
fmt.Printf("[Gemini] API Error %d: %s\n", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
fmt.Printf("[Gemini] API Error %d: %s\n", resp.StatusCode(), msg)
|
||||
// Also log the request body for debugging (careful with API keys if logged elsewhere)
|
||||
reqJSON, _ := json.Marshal(body)
|
||||
fmt.Printf("[Gemini] Request Body: %s\n", string(reqJSON))
|
||||
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
// Parse Gemini response and convert to OpenAI format
|
||||
@@ -565,6 +582,7 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
||||
GenerationConfig: genConfig,
|
||||
}
|
||||
|
||||
hasMappedTools := false
|
||||
if len(req.Tools) > 0 {
|
||||
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
|
||||
for _, t := range req.Tools {
|
||||
@@ -572,13 +590,16 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
||||
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
|
||||
}
|
||||
}
|
||||
if len(geminiTool.FunctionDeclarations) > 0 {
|
||||
body.Tools = []GeminiTool{geminiTool}
|
||||
hasMappedTools = true
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := p.config.BaseURL
|
||||
lowerModel := strings.ToLower(req.Model)
|
||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
|
||||
// Use v1beta for preview and newer models
|
||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") || hasMappedTools {
|
||||
// Use v1beta for preview, newer models, or when using tools
|
||||
if !strings.Contains(baseURL, "v1beta") {
|
||||
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
|
||||
}
|
||||
@@ -599,19 +620,19 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamGemini(resp.RawBody(), ch, req.Model)
|
||||
ch, err := StreamGemini(resp.RawBody(), req.Model)
|
||||
if err != nil {
|
||||
fmt.Printf("Gemini Stream error: %v\n", err)
|
||||
return nil, fmt.Errorf("gemini stream init error: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
@@ -48,7 +49,13 @@ func (p *GrokProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRe
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
@@ -79,7 +86,13 @@ func (p *GrokProvider) ChatCompletionStream(ctx context.Context, req *models.Uni
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
+205
-42
@@ -10,11 +10,32 @@ import (
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
func sanitizeFunctionName(name string) string {
|
||||
var sb strings.Builder
|
||||
for _, ch := range name {
|
||||
if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' || ch == '-' {
|
||||
sb.WriteRune(ch)
|
||||
} else {
|
||||
sb.WriteRune('_')
|
||||
}
|
||||
}
|
||||
res := sb.String()
|
||||
if res == "" {
|
||||
return "function"
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON, including tools and images.
|
||||
func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) {
|
||||
var result []interface{}
|
||||
for _, m := range messages {
|
||||
if m.Role == "tool" {
|
||||
role := strings.ToLower(m.Role)
|
||||
if role == "model" {
|
||||
role = "assistant"
|
||||
}
|
||||
|
||||
if role == "tool" || role == "function" {
|
||||
text := ""
|
||||
if len(m.Content) > 0 {
|
||||
text = m.Content[0].Text
|
||||
@@ -23,15 +44,14 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
|
||||
"role": "tool",
|
||||
"content": text,
|
||||
}
|
||||
id := "unknown"
|
||||
if m.ToolCallID != nil {
|
||||
id := *m.ToolCallID
|
||||
if len(id) > 40 {
|
||||
id = id[:40]
|
||||
id = *m.ToolCallID
|
||||
}
|
||||
msg["tool_call_id"] = id
|
||||
}
|
||||
|
||||
if m.Name != nil {
|
||||
msg["name"] = *m.Name
|
||||
msg["name"] = sanitizeFunctionName(*m.Name)
|
||||
}
|
||||
result = append(result, msg)
|
||||
continue
|
||||
@@ -59,7 +79,9 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
|
||||
}
|
||||
|
||||
var finalContent interface{}
|
||||
if len(parts) == 1 {
|
||||
if len(parts) == 0 {
|
||||
finalContent = nil
|
||||
} else if len(parts) == 1 {
|
||||
if p, ok := parts[0].(map[string]interface{}); ok && p["type"] == "text" {
|
||||
finalContent = p["text"]
|
||||
} else {
|
||||
@@ -70,7 +92,7 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
|
||||
}
|
||||
|
||||
msg := map[string]interface{}{
|
||||
"role": m.Role,
|
||||
"role": role,
|
||||
"content": finalContent,
|
||||
}
|
||||
|
||||
@@ -82,20 +104,18 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
|
||||
sanitizedCalls := make([]models.ToolCall, len(m.ToolCalls))
|
||||
copy(sanitizedCalls, m.ToolCalls)
|
||||
for i := range sanitizedCalls {
|
||||
if len(sanitizedCalls[i].ID) > 40 {
|
||||
sanitizedCalls[i].ID = sanitizedCalls[i].ID[:40]
|
||||
if sanitizedCalls[i].Type == "" {
|
||||
sanitizedCalls[i].Type = "function"
|
||||
}
|
||||
sanitizedCalls[i].Function.Name = sanitizeFunctionName(sanitizedCalls[i].Function.Name)
|
||||
}
|
||||
msg["tool_calls"] = sanitizedCalls
|
||||
if len(parts) == 0 {
|
||||
msg["content"] = ""
|
||||
}
|
||||
msg["content"] = "" // OpenAI requirement: content must be string if tool_calls present
|
||||
}
|
||||
|
||||
if m.Name != nil {
|
||||
msg["name"] = *m.Name
|
||||
}
|
||||
|
||||
result = append(result, msg)
|
||||
}
|
||||
return result, nil
|
||||
@@ -121,11 +141,25 @@ func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{},
|
||||
body["max_tokens"] = *request.MaxTokens
|
||||
}
|
||||
if len(request.Tools) > 0 {
|
||||
body["tools"] = request.Tools
|
||||
sanitizedTools := make([]models.Tool, len(request.Tools))
|
||||
copy(sanitizedTools, request.Tools)
|
||||
for i := range sanitizedTools {
|
||||
if sanitizedTools[i].Type == "function" {
|
||||
sanitizedTools[i].Function.Name = sanitizeFunctionName(sanitizedTools[i].Function.Name)
|
||||
}
|
||||
}
|
||||
body["tools"] = sanitizedTools
|
||||
}
|
||||
if request.ToolChoice != nil {
|
||||
var toolChoice interface{}
|
||||
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
|
||||
if tcMap, ok := toolChoice.(map[string]interface{}); ok {
|
||||
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
|
||||
if name, ok := funcMap["name"].(string); ok {
|
||||
funcMap["name"] = sanitizeFunctionName(name)
|
||||
}
|
||||
}
|
||||
}
|
||||
body["tool_choice"] = toolChoice
|
||||
}
|
||||
}
|
||||
@@ -361,18 +395,8 @@ func StreamOpenAI(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error {
|
||||
defer ctx.Close()
|
||||
|
||||
dec := json.NewDecoder(ctx)
|
||||
|
||||
t, err := dec.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if delim, ok := t.(json.Delim); ok && delim == '[' {
|
||||
for dec.More() {
|
||||
var geminiChunk struct {
|
||||
// geminiStreamChunk is the shared data structure for parsing Gemini streaming responses.
|
||||
type geminiStreamChunk struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
@@ -390,15 +414,18 @@ func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
|
||||
} `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
if err := dec.Decode(&geminiChunk); err != nil {
|
||||
return err
|
||||
// emitGeminiChunk builds a ChatCompletionStreamResponse from a parsed geminiStreamChunk
|
||||
// and sends it to the channel. Returns true if anything was emitted.
|
||||
func emitGeminiChunk(ch chan<- *models.ChatCompletionStreamResponse, chunk *geminiStreamChunk, model string) bool {
|
||||
if len(chunk.Candidates) == 0 && chunk.UsageMetadata.TotalTokenCount == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(geminiChunk.Candidates) > 0 || geminiChunk.UsageMetadata.TotalTokenCount > 0 {
|
||||
content := ""
|
||||
var reasoning *string
|
||||
if len(geminiChunk.Candidates) > 0 {
|
||||
for _, p := range geminiChunk.Candidates[0].Content.Parts {
|
||||
var finishReason *string
|
||||
if len(chunk.Candidates) > 0 {
|
||||
for _, p := range chunk.Candidates[0].Content.Parts {
|
||||
if p.Text != "" {
|
||||
content += p.Text
|
||||
}
|
||||
@@ -409,11 +436,7 @@ func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
|
||||
*reasoning += p.Thought
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var finishReason *string
|
||||
if len(geminiChunk.Candidates) > 0 {
|
||||
fr := strings.ToLower(geminiChunk.Candidates[0].FinishReason)
|
||||
fr := strings.ToLower(chunk.Candidates[0].FinishReason)
|
||||
finishReason = &fr
|
||||
}
|
||||
|
||||
@@ -433,15 +456,155 @@ func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
|
||||
},
|
||||
},
|
||||
Usage: &models.Usage{
|
||||
PromptTokens: geminiChunk.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiChunk.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: geminiChunk.UsageMetadata.TotalTokenCount,
|
||||
CacheReadTokens: uint32Ptr(geminiChunk.UsageMetadata.CachedContentTokenCount),
|
||||
PromptTokens: chunk.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: chunk.UsageMetadata.TotalTokenCount,
|
||||
CacheReadTokens: uint32Ptr(chunk.UsageMetadata.CachedContentTokenCount),
|
||||
},
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// StreamGemini handles Gemini streaming responses in two formats:
|
||||
// 1. SSE format (newer models): each line is "data: {...}"
|
||||
// 2. JSON array format (older models): response body is [ {...}, {...} ]
|
||||
//
|
||||
// Usage metadata is only present in the final chunk, which we accumulate
|
||||
// and emit so the server can log it on stream end.
|
||||
func StreamGemini(ctx io.ReadCloser, model string) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
_ = ctx.Close()
|
||||
}()
|
||||
defer close(ch)
|
||||
|
||||
// Peek at the first byte to detect format
|
||||
peek := make([]byte, 6)
|
||||
n, _ := io.ReadAtLeast(ctx, peek, 1)
|
||||
if n == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
first := string(peek[:n])
|
||||
|
||||
if first[0] == '[' {
|
||||
// JSON array format
|
||||
rest, _ := io.ReadAll(ctx)
|
||||
streamGeminiJSONArray(append([]byte(first), rest...), ch, model)
|
||||
return
|
||||
} else if strings.HasPrefix(first, "data:") || strings.HasPrefix(first, "data: ") {
|
||||
// SSE format — pre-pend the peeked bytes then run SSE scanner
|
||||
combined := io.MultiReader(
|
||||
strings.NewReader(string(peek[:n])),
|
||||
ctx,
|
||||
)
|
||||
streamGeminiSSE(combined, ch, model)
|
||||
} else {
|
||||
// Unknown format — might still be SSE starting after a peek char
|
||||
// Pre-pend peeked bytes and try SSE
|
||||
combined := io.MultiReader(
|
||||
strings.NewReader(string(peek[:n])),
|
||||
ctx,
|
||||
)
|
||||
streamGeminiSSE(combined, ch, model)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// readAll reads remaining bytes from a reader (keeps the function signature simple
|
||||
// for the JSON array fallback path).
|
||||
func readAll(r io.Reader) []byte {
|
||||
b, _ := io.ReadAll(r)
|
||||
return b
|
||||
}
|
||||
|
||||
func streamGeminiJSONArray(data []byte, ch chan<- *models.ChatCompletionStreamResponse, model string) {
|
||||
var chunks []geminiStreamChunk
|
||||
if err := json.Unmarshal(data, &chunks); err != nil {
|
||||
fmt.Printf("[Gemini-Stream] JSON array parse error: %v\n", err)
|
||||
return
|
||||
}
|
||||
// Track the last chunk with usage for the final emission
|
||||
var lastUsage *geminiStreamChunk
|
||||
for i := range chunks {
|
||||
if chunks[i].UsageMetadata.TotalTokenCount > 0 {
|
||||
lastUsage = &chunks[i]
|
||||
}
|
||||
}
|
||||
if lastUsage != nil {
|
||||
// Emit a synthetic final chunk with usage data
|
||||
if len(lastUsage.Candidates) == 0 && lastUsage.UsageMetadata.TotalTokenCount > 0 {
|
||||
emitGeminiChunk(ch, lastUsage, model)
|
||||
}
|
||||
}
|
||||
// Also emit each content-bearing chunk
|
||||
for i := range chunks {
|
||||
emitGeminiChunk(ch, &chunks[i], model)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
func streamGeminiSSE(r io.Reader, ch chan<- *models.ChatCompletionStreamResponse, model string) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
// Track the last seen usage for emission at end of stream
|
||||
var lastUsage geminiStreamChunk
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" || !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
// Emit final usage if we have one
|
||||
if lastUsage.UsageMetadata.TotalTokenCount > 0 {
|
||||
emitGeminiChunk(ch, &lastUsage, model)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
var chunk geminiStreamChunk
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Capture usage from any chunk (Gemini puts it in the final response)
|
||||
if chunk.UsageMetadata.TotalTokenCount > 0 {
|
||||
lastUsage = chunk
|
||||
}
|
||||
|
||||
// Emit content chunks as they arrive
|
||||
if len(chunk.Candidates) > 0 {
|
||||
emitGeminiChunk(ch, &chunk, model)
|
||||
}
|
||||
}
|
||||
|
||||
// If stream ended without [DONE] marker but we collected usage, emit it
|
||||
if lastUsage.UsageMetadata.TotalTokenCount > 0 {
|
||||
emitGeminiChunk(ch, &lastUsage, model)
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
fmt.Printf("[Gemini-Stream] SSE scan error: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ensureEnglish injects a system message instructing the model to respond in
|
||||
// English when no system prompt is already present. Some providers (e.g. DeepSeek)
|
||||
// default to Chinese for certain prompts.
|
||||
func ensureEnglish(req *models.UnifiedRequest) {
|
||||
if len(req.Messages) > 0 && req.Messages[0].Role == "system" {
|
||||
return // already has a system prompt, don't interfere
|
||||
}
|
||||
enMsg := models.UnifiedMessage{
|
||||
Role: "system",
|
||||
Content: []models.UnifiedContentPart{
|
||||
{Type: "text", Text: "You are a helpful assistant. Always respond in English."},
|
||||
},
|
||||
}
|
||||
req.Messages = append([]models.UnifiedMessage{enMsg}, req.Messages...)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
func TestSanitizeFunctionName(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"google-search", "google-search"},
|
||||
{"google.search", "google_search"},
|
||||
{"google search", "google_search"},
|
||||
{"web_search(query)", "web_search_query_"},
|
||||
{"", "function"},
|
||||
{"123_abc-XYZ", "123_abc-XYZ"},
|
||||
{"invalid.name.with.dots", "invalid_name_with_dots"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
actual := sanitizeFunctionName(tc.input)
|
||||
if actual != tc.expected {
|
||||
t.Errorf("sanitizeFunctionName(%q) = %q; expected %q", tc.input, actual, tc.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesToOpenAIJSON_SanitizeToolCalls(t *testing.T) {
|
||||
messages := []models.UnifiedMessage{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []models.UnifiedContentPart{
|
||||
{Type: "text", Text: "I will use search."},
|
||||
},
|
||||
ToolCalls: []models.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: models.FunctionCall{
|
||||
Name: "google.search",
|
||||
Arguments: `{"query": "hello"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
Content: []models.UnifiedContentPart{
|
||||
{Type: "text", Text: `{"result": "success"}`},
|
||||
},
|
||||
ToolCallID: stringPtr("call_1"),
|
||||
Name: stringPtr("google.search"),
|
||||
},
|
||||
}
|
||||
|
||||
res, err := MessagesToOpenAIJSON(messages)
|
||||
if err != nil {
|
||||
t.Fatalf("MessagesToOpenAIJSON failed: %v", err)
|
||||
}
|
||||
|
||||
if len(res) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(res))
|
||||
}
|
||||
|
||||
// Verify assistant message
|
||||
msg1 := res[0].(map[string]interface{})
|
||||
if msg1["role"] != "assistant" {
|
||||
t.Errorf("expected role assistant, got %v", msg1["role"])
|
||||
}
|
||||
calls := msg1["tool_calls"].([]models.ToolCall)
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "google_search" {
|
||||
t.Errorf("expected function name google_search, got %q", calls[0].Function.Name)
|
||||
}
|
||||
|
||||
// Verify tool response message
|
||||
msg2 := res[1].(map[string]interface{})
|
||||
if msg2["role"] != "tool" {
|
||||
t.Errorf("expected role tool, got %v", msg2["role"])
|
||||
}
|
||||
if msg2["name"] != "google_search" {
|
||||
t.Errorf("expected tool name google_search, got %v", msg2["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIBody_SanitizeToolsAndChoice(t *testing.T) {
|
||||
req := &models.UnifiedRequest{
|
||||
Model: "gpt-4o",
|
||||
Tools: []models.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: models.FunctionDef{
|
||||
Name: "google.search",
|
||||
},
|
||||
},
|
||||
},
|
||||
ToolChoice: json.RawMessage(`{"type": "function", "function": {"name": "google.search"}}`),
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, nil, false)
|
||||
|
||||
// Verify tools
|
||||
tools := body["tools"].([]models.Tool)
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(tools))
|
||||
}
|
||||
if tools[0].Function.Name != "google_search" {
|
||||
t.Errorf("expected tool function name google_search, got %q", tools[0].Function.Name)
|
||||
}
|
||||
|
||||
// Verify tool_choice
|
||||
toolChoice := body["tool_choice"].(map[string]interface{})
|
||||
funcObj := toolChoice["function"].(map[string]interface{})
|
||||
if funcObj["name"] != "google_search" {
|
||||
t.Errorf("expected tool_choice function name google_search, got %q", funcObj["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -59,7 +60,13 @@ func (p *MoonshotProvider) ChatCompletion(ctx context.Context, req *models.Unifi
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
@@ -100,7 +107,13 @@ func (p *MoonshotProvider) ChatCompletionStream(ctx context.Context, req *models
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
@@ -56,7 +56,13 @@ func (p *OllamaProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
@@ -86,7 +92,13 @@ func (p *OllamaProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -38,6 +40,17 @@ func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, false)
|
||||
|
||||
// Debug message sequence
|
||||
for i, m := range messagesJSON {
|
||||
mMap, _ := m.(map[string]interface{})
|
||||
role, _ := mMap["role"].(string)
|
||||
hasToolCalls := false
|
||||
if tc, ok := mMap["tool_calls"]; ok && tc != nil {
|
||||
hasToolCalls = true
|
||||
}
|
||||
log.Printf("[DEBUG] OpenAI Msg[%d]: role=%s, hasToolCalls=%v", i, role, hasToolCalls)
|
||||
}
|
||||
|
||||
// Transition: Newer models require max_completion_tokens
|
||||
if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
|
||||
if maxTokens, ok := body["max_tokens"]; ok {
|
||||
@@ -57,7 +70,17 @@ func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if b := resp.Body(); len(b) > 0 {
|
||||
msg = string(b)
|
||||
}
|
||||
}
|
||||
// Log the request body for debugging
|
||||
reqJSON, _ := json.Marshal(body)
|
||||
log.Printf("OpenAI API Error (%d): %s", resp.StatusCode(), msg)
|
||||
log.Printf("OpenAI request body: %s", string(reqJSON))
|
||||
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
@@ -104,7 +127,13 @@ func (p *OpenAIProvider) ImageGeneration(ctx context.Context, req *models.ImageG
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("OpenAI image API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("OpenAI image API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
var result models.ImageGenerationResponse
|
||||
@@ -123,6 +152,17 @@ func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, true)
|
||||
|
||||
// Debug message sequence
|
||||
for i, m := range messagesJSON {
|
||||
mMap, _ := m.(map[string]interface{})
|
||||
role, _ := mMap["role"].(string)
|
||||
hasToolCalls := false
|
||||
if tc, ok := mMap["tool_calls"]; ok && tc != nil {
|
||||
hasToolCalls = true
|
||||
}
|
||||
log.Printf("[DEBUG] OpenAI Stream Msg[%d]: role=%s, hasToolCalls=%v", i, role, hasToolCalls)
|
||||
}
|
||||
|
||||
// Transition: Newer models require max_completion_tokens
|
||||
if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
|
||||
if maxTokens, ok := body["max_tokens"]; ok {
|
||||
@@ -143,7 +183,21 @@ func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if b := resp.Body(); len(b) > 0 {
|
||||
msg = string(b)
|
||||
}
|
||||
if msg == "" {
|
||||
if b, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(b)
|
||||
}
|
||||
}
|
||||
}
|
||||
reqJSON, _ := json.Marshal(body)
|
||||
log.Printf("OpenAI API Error (%d): %s", resp.StatusCode(), msg)
|
||||
log.Printf("OpenAI request body: %s", string(reqJSON))
|
||||
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
@@ -26,7 +27,13 @@ func (p *OpenAIProvider) Responses(ctx context.Context, req *models.ResponsesReq
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("OpenAI Responses API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("OpenAI Responses API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
@@ -53,7 +60,13 @@ func (p *OpenAIProvider) ResponsesStream(ctx context.Context, req *models.Respon
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("OpenAI Responses API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("OpenAI Responses API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
ch := make(chan *models.ResponsesStreamChunk)
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
type XiaomiProvider struct {
|
||||
client *resty.Client
|
||||
config config.XiaomiConfig
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func NewXiaomiProvider(cfg config.XiaomiConfig, apiKey string) *XiaomiProvider {
|
||||
return &XiaomiProvider{
|
||||
client: resty.New().SetTimeout(10 * time.Minute),
|
||||
config: cfg,
|
||||
apiKey: strings.TrimSpace(apiKey),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *XiaomiProvider) Name() string {
|
||||
return "xiaomi"
|
||||
}
|
||||
|
||||
func (p *XiaomiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, false)
|
||||
|
||||
baseURL := strings.TrimRight(p.config.BaseURL, "/")
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("Accept", "application/json").
|
||||
SetBody(body).
|
||||
Post(fmt.Sprintf("%s/chat/completions", baseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if b := resp.Body(); len(b) > 0 {
|
||||
msg = string(b)
|
||||
}
|
||||
}
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Xiaomi API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return ParseOpenAIResponse(respJSON, req.Model)
|
||||
}
|
||||
|
||||
func (p *XiaomiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, true)
|
||||
|
||||
baseURL := strings.TrimRight(p.config.BaseURL, "/")
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("Accept", "text/event-stream").
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(fmt.Sprintf("%s/chat/completions", baseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Xiaomi API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
if err := StreamOpenAI(resp.RawBody(), ch); err != nil {
|
||||
fmt.Printf("Xiaomi Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (p *XiaomiProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
|
||||
return nil, fmt.Errorf("xiaomi does not support image generation")
|
||||
}
|
||||
|
||||
func (p *XiaomiProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by xiaomi")
|
||||
}
|
||||
|
||||
func (p *XiaomiProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by xiaomi")
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
const classifierSystemPrompt = `You are a task complexity classifier. Rate the following user message on a scale of 1 to %d, where:
|
||||
1 = trivial/simple (basic facts, greetings, simple math)
|
||||
%d = highly complex (multi-step reasoning, code generation, architecture design)
|
||||
|
||||
Reply with ONLY the number. No explanation.`
|
||||
|
||||
func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, routeCtx *RouteContext) (*Decision, error) {
|
||||
// Determine the rating scale
|
||||
maxRating := len(targets)
|
||||
if maxRating < 2 {
|
||||
maxRating = 2
|
||||
}
|
||||
|
||||
// When complexity_threshold is set, use it as a wider scale (e.g., 1-10)
|
||||
// and map ratings proportionally to target buckets.
|
||||
bucketMode := group.ComplexityThreshold != nil && *group.ComplexityThreshold > 0
|
||||
if bucketMode {
|
||||
maxRating = *group.ComplexityThreshold
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf(classifierSystemPrompt, maxRating, maxRating)
|
||||
userMsg := ""
|
||||
if routeCtx != nil {
|
||||
userMsg = routeCtx.UserMessage
|
||||
}
|
||||
ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMsg)
|
||||
if err != nil {
|
||||
// Classifier failed — fall back to heuristic
|
||||
return routeHeuristic(group, targets, routeCtx)
|
||||
}
|
||||
|
||||
rating, err := strconv.Atoi(strings.TrimSpace(ratingStr))
|
||||
if err != nil || rating < 1 {
|
||||
rating = 1
|
||||
}
|
||||
if rating > maxRating {
|
||||
rating = maxRating
|
||||
}
|
||||
|
||||
var idx int
|
||||
if bucketMode {
|
||||
// Proportional mapping: wider scale → N target buckets
|
||||
// e.g., threshold=10, 3 targets: 1-3→0, 4-7→1, 8-10→2
|
||||
idx = rating * len(targets) / (maxRating + 1)
|
||||
if idx >= len(targets) {
|
||||
idx = len(targets) - 1
|
||||
}
|
||||
} else {
|
||||
idx = rating - 1 // 1:1 mapping
|
||||
}
|
||||
|
||||
return &Decision{
|
||||
SelectedModel: targets[idx],
|
||||
Strategy: "classifier",
|
||||
Reason: fmt.Sprintf("complexity rating: %d/%d", rating, maxRating),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getSelectorModel(group db.ModelGroup, targets []string) string {
|
||||
if group.SelectorModel != nil && *group.SelectorModel != "" {
|
||||
return *group.SelectorModel
|
||||
}
|
||||
// Default: use the first (cheapest) target model as the selector
|
||||
return targets[0]
|
||||
}
|
||||
@@ -0,0 +1,219 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
// HeuristicRule defines a pattern-based routing rule (legacy format).
|
||||
type HeuristicRule struct {
|
||||
Pattern string `json:"pattern"`
|
||||
TargetIdx int `json:"target"`
|
||||
CaseSensitive bool `json:"case_sensitive,omitempty"`
|
||||
}
|
||||
|
||||
// ConditionRule defines a condition-based routing rule (new format).
|
||||
type ConditionRule struct {
|
||||
RuleID string `json:"rule_id"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Conditions Conditions `json:"conditions"`
|
||||
PrimaryModel string `json:"primary_model"`
|
||||
FallbackModel string `json:"fallback_model,omitempty"`
|
||||
}
|
||||
|
||||
// Conditions defines the matching parameters for a rule.
|
||||
type Conditions struct {
|
||||
AnyOfTags []string `json:"any_of_tags,omitempty"`
|
||||
MaxInputTokensLt *int `json:"max_input_tokens_lt,omitempty"`
|
||||
RequiresReasoning *bool `json:"requires_reasoning,omitempty"`
|
||||
RequiresToolCalling *bool `json:"requires_tool_calling,omitempty"`
|
||||
HasMultimodalInput *bool `json:"has_multimodal_input,omitempty"`
|
||||
IsDefaultFallback *bool `json:"is_default_fallback,omitempty"`
|
||||
}
|
||||
|
||||
func routeHeuristic(group db.ModelGroup, targets []string, routeCtx *RouteContext) (*Decision, error) {
|
||||
if routeCtx == nil {
|
||||
routeCtx = &RouteContext{}
|
||||
}
|
||||
|
||||
selected := targets[0]
|
||||
reason := "default (first target)"
|
||||
|
||||
// If heuristic_rules is set, determine format and parse
|
||||
if group.HeuristicRules != nil && *group.HeuristicRules != "" {
|
||||
rulesJSON := *group.HeuristicRules
|
||||
|
||||
if isConditionBasedRules(rulesJSON) {
|
||||
var condRules []ConditionRule
|
||||
if err := json.Unmarshal([]byte(rulesJSON), &condRules); err == nil {
|
||||
for _, rule := range condRules {
|
||||
if matchConditions(rule.Conditions, routeCtx) {
|
||||
// Resolve primary/fallback to concrete models in target list
|
||||
targetModel := ""
|
||||
if rule.PrimaryModel != "" {
|
||||
targetModel = getModelInTargets(rule.PrimaryModel, targets)
|
||||
}
|
||||
if targetModel == "" && rule.FallbackModel != "" {
|
||||
targetModel = getModelInTargets(rule.FallbackModel, targets)
|
||||
}
|
||||
|
||||
if targetModel != "" {
|
||||
selected = targetModel
|
||||
reason = "matched condition rule: " + rule.RuleID
|
||||
if rule.Description != "" {
|
||||
reason += " (" + rule.Description + ")"
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback to legacy pattern-based rules
|
||||
var legacyRules []HeuristicRule
|
||||
if err := json.Unmarshal([]byte(rulesJSON), &legacyRules); err == nil {
|
||||
searchMsg := routeCtx.UserMessage
|
||||
for _, rule := range legacyRules {
|
||||
pattern := rule.Pattern
|
||||
if pattern == "" {
|
||||
continue // Avoid infinite matches with empty patterns
|
||||
}
|
||||
msg := searchMsg
|
||||
if !rule.CaseSensitive {
|
||||
pattern = strings.ToLower(pattern)
|
||||
msg = strings.ToLower(msg)
|
||||
}
|
||||
|
||||
// Support both regex matching (if pattern is valid regex) and literal contains
|
||||
matched := false
|
||||
if strings.Contains(rule.Pattern, "(") || strings.Contains(rule.Pattern, "\\b") {
|
||||
var re *regexp.Regexp
|
||||
var err error
|
||||
if !rule.CaseSensitive {
|
||||
re, err = regexp.Compile("(?i)" + rule.Pattern)
|
||||
} else {
|
||||
re, err = regexp.Compile(rule.Pattern)
|
||||
}
|
||||
if err == nil {
|
||||
matched = re.MatchString(routeCtx.UserMessage)
|
||||
}
|
||||
}
|
||||
|
||||
if !matched && strings.Contains(msg, pattern) {
|
||||
matched = true
|
||||
}
|
||||
|
||||
if matched {
|
||||
if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) {
|
||||
selected = targets[rule.TargetIdx]
|
||||
reason = "matched heuristic rule: " + rule.Pattern
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Built-in fallback heuristics (if no custom rule matched)
|
||||
if reason == "default (first target)" && len(targets) > 1 {
|
||||
msgLower := strings.ToLower(routeCtx.UserMessage)
|
||||
complexIndicators := []string{
|
||||
"step by step", "explain in detail", "reason through",
|
||||
"think carefully", "analyze", "debug", "write code",
|
||||
"implement", "refactor", "architecture",
|
||||
}
|
||||
for _, indicator := range complexIndicators {
|
||||
if strings.Contains(msgLower, indicator) {
|
||||
selected = targets[len(targets)-1]
|
||||
reason = "complex task indicator: " + indicator
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Decision{
|
||||
SelectedModel: selected,
|
||||
Strategy: "heuristic",
|
||||
Reason: reason,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// isConditionBasedRules returns true if the JSON represents condition-based rules.
|
||||
func isConditionBasedRules(rulesJSON string) bool {
|
||||
var rules []ConditionRule
|
||||
if err := json.Unmarshal([]byte(rulesJSON), &rules); err == nil && len(rules) > 0 {
|
||||
// If the rule has either conditions or primary_model/rule_id, treat it as condition-based
|
||||
return rules[0].PrimaryModel != "" || rules[0].RuleID != ""
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchConditions evaluates whether the given conditions match the RouteContext.
|
||||
func matchConditions(cond Conditions, routeCtx *RouteContext) bool {
|
||||
if cond.IsDefaultFallback != nil && *cond.IsDefaultFallback {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check tags: must match any_of_tags if specified
|
||||
if len(cond.AnyOfTags) > 0 {
|
||||
tagMatched := false
|
||||
for _, ruleTag := range cond.AnyOfTags {
|
||||
for _, ctxTag := range routeCtx.Tags {
|
||||
if strings.EqualFold(ruleTag, ctxTag) {
|
||||
tagMatched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if tagMatched {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !tagMatched {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check max input tokens
|
||||
if cond.MaxInputTokensLt != nil {
|
||||
if routeCtx.InputTokens >= *cond.MaxInputTokensLt {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check reasoning flag
|
||||
if cond.RequiresReasoning != nil {
|
||||
if routeCtx.RequiresReasoning != *cond.RequiresReasoning {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check tool calling flag
|
||||
if cond.RequiresToolCalling != nil {
|
||||
if routeCtx.RequiresToolCalling != *cond.RequiresToolCalling {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check multimodal flag
|
||||
if cond.HasMultimodalInput != nil {
|
||||
if routeCtx.HasMultimodalInput != *cond.HasMultimodalInput {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// getModelInTargets returns the model name if it exists in targets, or empty string.
|
||||
func getModelInTargets(modelName string, targets []string) string {
|
||||
for _, t := range targets {
|
||||
if strings.EqualFold(t, modelName) {
|
||||
return t
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
func TestRouteHeuristic_ConditionRules(t *testing.T) {
|
||||
targets := []string{
|
||||
"deepseek-v4-flash", // index 0
|
||||
"gemini-3-flash", // index 1
|
||||
"grok-build-0.1", // index 2
|
||||
"kimi-k2.6", // index 3
|
||||
"mimo-v2.5-pro", // index 4
|
||||
"grok-4.3", // index 5
|
||||
"deepseek-v4-pro", // index 6
|
||||
}
|
||||
|
||||
rulesJSON := `[
|
||||
{
|
||||
"rule_id": "fast_flow_extraction",
|
||||
"conditions": {
|
||||
"any_of_tags": ["fast-flow", "classification"],
|
||||
"max_input_tokens_lt": 8000,
|
||||
"requires_reasoning": false
|
||||
},
|
||||
"primary_model": "deepseek-v4-flash",
|
||||
"fallback_model": "grok-build-0.1"
|
||||
},
|
||||
{
|
||||
"rule_id": "multimodal_long_context",
|
||||
"conditions": {
|
||||
"any_of_tags": ["standard-pro", "long-doc"],
|
||||
"has_multimodal_input": true
|
||||
},
|
||||
"primary_model": "gemini-3-flash",
|
||||
"fallback_model": "mimo-v2.5-pro"
|
||||
},
|
||||
{
|
||||
"rule_id": "regional_fallback_general",
|
||||
"conditions": {
|
||||
"is_default_fallback": true
|
||||
},
|
||||
"primary_model": "kimi-k2.6"
|
||||
}
|
||||
]`
|
||||
|
||||
group := db.ModelGroup{
|
||||
ID: "dustins_stack",
|
||||
Strategy: "heuristic",
|
||||
HeuristicRules: &rulesJSON,
|
||||
}
|
||||
|
||||
// 1. Test Match Fast Flow (condition success)
|
||||
ctx1 := &RouteContext{
|
||||
UserMessage: "classify this JSON",
|
||||
InputTokens: 500,
|
||||
HasMultimodalInput: false,
|
||||
RequiresReasoning: false,
|
||||
Tags: []string{"fast-flow", "classification"},
|
||||
}
|
||||
dec1, err := routeHeuristic(group, targets, ctx1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec1.SelectedModel != "deepseek-v4-flash" {
|
||||
t.Fatalf("expected deepseek-v4-flash, got %s", dec1.SelectedModel)
|
||||
}
|
||||
|
||||
// 2. Test Multimodal Long Context (condition success)
|
||||
ctx2 := &RouteContext{
|
||||
UserMessage: "explain this video",
|
||||
InputTokens: 15000,
|
||||
HasMultimodalInput: true,
|
||||
RequiresReasoning: false,
|
||||
Tags: []string{"standard-pro", "video-analysis"},
|
||||
}
|
||||
dec2, err := routeHeuristic(group, targets, ctx2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec2.SelectedModel != "gemini-3-flash" {
|
||||
t.Fatalf("expected gemini-3-flash, got %s", dec2.SelectedModel)
|
||||
}
|
||||
|
||||
// 3. Test Fallback general rule
|
||||
ctx3 := &RouteContext{
|
||||
UserMessage: "hello there",
|
||||
InputTokens: 100,
|
||||
HasMultimodalInput: false,
|
||||
RequiresReasoning: false,
|
||||
Tags: []string{"general"},
|
||||
}
|
||||
dec3, err := routeHeuristic(group, targets, ctx3)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec3.SelectedModel != "kimi-k2.6" {
|
||||
t.Fatalf("expected kimi-k2.6, got %s", dec3.SelectedModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteHeuristic_LegacyRules(t *testing.T) {
|
||||
targets := []string{"gpt-4o-mini", "deepseek-v4-pro", "kimi-k2.6"}
|
||||
|
||||
// Legacy pattern-based rule with regex
|
||||
rulesJSON := `[
|
||||
{"pattern": "\\b(agent|agents|tool use)\\b", "target": 1},
|
||||
{"pattern": "summarize", "target": 2}
|
||||
]`
|
||||
|
||||
group := db.ModelGroup{
|
||||
ID: "heavy-logic",
|
||||
Strategy: "heuristic",
|
||||
HeuristicRules: &rulesJSON,
|
||||
}
|
||||
|
||||
// 1. Test regex match
|
||||
ctx1 := &RouteContext{
|
||||
UserMessage: "We need an agent to do tool use",
|
||||
}
|
||||
dec1, err := routeHeuristic(group, targets, ctx1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec1.SelectedModel != "deepseek-v4-pro" {
|
||||
t.Fatalf("expected deepseek-v4-pro, got %s", dec1.SelectedModel)
|
||||
}
|
||||
|
||||
// 2. Test literal match
|
||||
ctx2 := &RouteContext{
|
||||
UserMessage: "Please summarize this text",
|
||||
}
|
||||
dec2, err := routeHeuristic(group, targets, ctx2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec2.SelectedModel != "kimi-k2.6" {
|
||||
t.Fatalf("expected kimi-k2.6, got %s", dec2.SelectedModel)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,139 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
// Decision holds the result of a routing decision.
|
||||
type Decision struct {
|
||||
SelectedModel string `json:"selected_model"`
|
||||
Strategy string `json:"strategy"` // "heuristic" or "classifier"
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// RouteContext holds metadata of the request to evaluate condition rules.
|
||||
type RouteContext struct {
|
||||
UserMessage string `json:"user_message"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
HasMultimodalInput bool `json:"has_multimodal_input"`
|
||||
RequiresToolCalling bool `json:"requires_tool_calling"`
|
||||
RequiresReasoning bool `json:"requires_reasoning"`
|
||||
Tags []string `json:"tags"`
|
||||
}
|
||||
|
||||
// ClassifierFunc is the callback for classifier-based routing.
|
||||
type ClassifierFunc func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error)
|
||||
|
||||
// Router resolves model groups to concrete models.
|
||||
type Router struct {
|
||||
groups map[string]db.ModelGroup
|
||||
classify ClassifierFunc
|
||||
}
|
||||
|
||||
// New creates a Router. classify may be nil if no classifier groups exist.
|
||||
func New(groups []db.ModelGroup, classify ClassifierFunc) *Router {
|
||||
r := &Router{
|
||||
groups: make(map[string]db.ModelGroup),
|
||||
classify: classify,
|
||||
}
|
||||
for _, g := range groups {
|
||||
r.groups[g.ID] = g
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Groups returns all registered model group IDs.
|
||||
func (r *Router) Groups() []string {
|
||||
ids := make([]string, 0, len(r.groups))
|
||||
for id := range r.groups {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// IsGroup returns true if the model name is a group ID.
|
||||
func (r *Router) IsGroup(modelID string) bool {
|
||||
_, ok := r.groups[modelID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Route resolves a group to a concrete model.
|
||||
func (r *Router) Route(ctx context.Context, groupID string, routeCtx *RouteContext) (*Decision, error) {
|
||||
group, ok := r.groups[groupID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown model group: %s", groupID)
|
||||
}
|
||||
|
||||
var targets []string
|
||||
if err := json.Unmarshal([]byte(group.Targets), &targets); err != nil || len(targets) == 0 {
|
||||
return nil, fmt.Errorf("invalid or empty targets for group %s", groupID)
|
||||
}
|
||||
|
||||
switch group.Strategy {
|
||||
case "heuristic":
|
||||
return routeHeuristic(group, targets, routeCtx)
|
||||
case "classifier":
|
||||
if r.classify == nil {
|
||||
return routeHeuristic(group, targets, routeCtx)
|
||||
}
|
||||
return routeClassifier(ctx, r.classify, group, targets, routeCtx)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown strategy: %s", group.Strategy)
|
||||
}
|
||||
}
|
||||
|
||||
// RouteToConcrete resolves a model name to a concrete model, following group
|
||||
// chains recursively until a non-group target is reached. Returns the original
|
||||
// name unchanged if it is not a group.
|
||||
func (r *Router) RouteToConcrete(ctx context.Context, modelID string, routeCtx *RouteContext) (*Decision, error) {
|
||||
const maxDepth = 10
|
||||
visited := make(map[string]bool)
|
||||
current := modelID
|
||||
var chain []*Decision
|
||||
|
||||
for depth := 0; depth < maxDepth; depth++ {
|
||||
if !r.IsGroup(current) {
|
||||
// Build a composite reason showing the chain traversed
|
||||
reason := "direct"
|
||||
if len(chain) > 0 {
|
||||
parts := make([]string, len(chain))
|
||||
for i, d := range chain {
|
||||
parts[i] = d.SelectedModel + " (" + d.Reason + ")"
|
||||
}
|
||||
reason = strings.Join(parts, " -> ")
|
||||
}
|
||||
return &Decision{
|
||||
SelectedModel: current,
|
||||
Strategy: "hierarchical",
|
||||
Reason: reason,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if visited[current] {
|
||||
return nil, fmt.Errorf("routing cycle detected: group %s already visited", current)
|
||||
}
|
||||
visited[current] = true
|
||||
|
||||
decision, err := r.Route(ctx, current, routeCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chain = append(chain, decision)
|
||||
current = decision.SelectedModel
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("routing depth exceeded: reached max depth of %d", maxDepth)
|
||||
}
|
||||
|
||||
// Reload replaces the group definitions without recreating the router.
|
||||
func (r *Router) Reload(groups []db.ModelGroup) {
|
||||
r.groups = make(map[string]db.ModelGroup)
|
||||
for _, g := range groups {
|
||||
r.groups[g.ID] = g
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,7 @@ type RequestLog struct {
|
||||
ClientID string `json:"client_id"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
ModelGroup string `json:"model_group,omitempty"`
|
||||
PromptTokens uint32 `json:"prompt_tokens"`
|
||||
CompletionTokens uint32 `json:"completion_tokens"`
|
||||
ReasoningTokens uint32 `json:"reasoning_tokens"`
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"gophergate/internal/db"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (s *Server) handleGetModelGroups(c *gin.Context) {
|
||||
var groups []db.ModelGroup
|
||||
if err := s.database.Select(&groups, "SELECT * FROM model_groups ORDER BY id"); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if groups == nil {
|
||||
groups = []db.ModelGroup{}
|
||||
}
|
||||
c.JSON(http.StatusOK, SuccessResponse(groups))
|
||||
}
|
||||
|
||||
func (s *Server) handleCreateModelGroup(c *gin.Context) {
|
||||
var group db.ModelGroup
|
||||
if err := c.ShouldBindJSON(&group); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
_, err := s.database.Exec(`
|
||||
INSERT INTO model_groups (id, strategy, selector_model, targets, complexity_threshold, heuristic_rules, logic_level, primary_use)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
group.ID, group.Strategy, group.SelectorModel, group.Targets,
|
||||
group.ComplexityThreshold, group.HeuristicRules, group.LogicLevel, group.PrimaryUse)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
s.refreshRouter()
|
||||
c.JSON(http.StatusCreated, SuccessResponse(group))
|
||||
}
|
||||
|
||||
func (s *Server) handleUpdateModelGroup(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var group db.ModelGroup
|
||||
if err := c.ShouldBindJSON(&group); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
_, err := s.database.Exec(`
|
||||
UPDATE model_groups SET strategy=?, selector_model=?, targets=?, complexity_threshold=?, heuristic_rules=?, logic_level=?, primary_use=?, updated_at=CURRENT_TIMESTAMP
|
||||
WHERE id=?`,
|
||||
group.Strategy, group.SelectorModel, group.Targets,
|
||||
group.ComplexityThreshold, group.HeuristicRules, group.LogicLevel, group.PrimaryUse, id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
s.refreshRouter()
|
||||
c.JSON(http.StatusOK, SuccessResponse(group))
|
||||
}
|
||||
|
||||
func (s *Server) handleDeleteModelGroup(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
_, err := s.database.Exec("DELETE FROM model_groups WHERE id=?", id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
s.refreshRouter()
|
||||
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
|
||||
}
|
||||
@@ -19,6 +19,7 @@ func (s *Server) handleGetModels(c *gin.Context) {
|
||||
"deepseek": "deepseek",
|
||||
"xai": "grok",
|
||||
"ollama": "ollama",
|
||||
"xiaomi": "xiaomi",
|
||||
}
|
||||
|
||||
// Merge registry models with DB overrides
|
||||
|
||||
@@ -25,7 +25,7 @@ func (s *Server) handleGetProviders(c *gin.Context) {
|
||||
dbMap[cfg.ID] = cfg
|
||||
}
|
||||
|
||||
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama"}
|
||||
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama", "xiaomi"}
|
||||
var result []gin.H
|
||||
|
||||
for _, id := range providerIDs {
|
||||
@@ -54,6 +54,10 @@ func (s *Server) handleGetProviders(c *gin.Context) {
|
||||
name = "xAI Grok"
|
||||
enabled = s.cfg.Providers.Grok.Enabled
|
||||
baseURL = s.cfg.Providers.Grok.BaseURL
|
||||
case "xiaomi":
|
||||
name = "Xiaomi MiMo"
|
||||
enabled = s.cfg.Providers.Xiaomi.Enabled
|
||||
baseURL = s.cfg.Providers.Xiaomi.BaseURL
|
||||
case "ollama":
|
||||
name = "Ollama"
|
||||
enabled = s.cfg.Providers.Ollama.Enabled
|
||||
@@ -109,6 +113,9 @@ func (s *Server) handleGetProviders(c *gin.Context) {
|
||||
if id == "grok" {
|
||||
registryID = "xai"
|
||||
}
|
||||
if id == "xiaomi" {
|
||||
registryID = "xiaomi"
|
||||
}
|
||||
|
||||
if pInfo, ok := s.registry.Providers[registryID]; ok {
|
||||
for mID := range pInfo.Models {
|
||||
@@ -226,6 +233,8 @@ func (s *Server) handleTestProvider(c *gin.Context) {
|
||||
testReq.Model = "kimi-k2.5"
|
||||
} else if name == "grok" {
|
||||
testReq.Model = "grok-4-1-fast-non-reasoning"
|
||||
} else if name == "xiaomi" {
|
||||
testReq.Model = "mimo-v2.5"
|
||||
}
|
||||
|
||||
_, err := provider.ChatCompletion(c.Request.Context(), testReq)
|
||||
|
||||
+358
-55
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"gophergate/internal/middleware"
|
||||
"gophergate/internal/models"
|
||||
"gophergate/internal/providers"
|
||||
"gophergate/internal/router"
|
||||
"gophergate/internal/utils"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -30,6 +32,7 @@ type Server struct {
|
||||
logger *RequestLogger
|
||||
registry *models.ModelRegistry
|
||||
registryMu sync.RWMutex
|
||||
modelRouter *router.Router
|
||||
}
|
||||
|
||||
func NewServer(cfg *config.Config, database *db.DB) *Server {
|
||||
@@ -64,6 +67,9 @@ func NewServer(cfg *config.Config, database *db.DB) *Server {
|
||||
}
|
||||
|
||||
s.setupRoutes()
|
||||
|
||||
// Initialize model group router
|
||||
s.refreshRouter()
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -79,7 +85,7 @@ func (s *Server) RefreshProviders() error {
|
||||
dbMap[cfg.ID] = cfg
|
||||
}
|
||||
|
||||
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama"}
|
||||
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama", "xiaomi"}
|
||||
for _, id := range providerIDs {
|
||||
// Default values from config
|
||||
enabled := false
|
||||
@@ -107,6 +113,10 @@ func (s *Server) RefreshProviders() error {
|
||||
enabled = s.cfg.Providers.Grok.Enabled
|
||||
baseURL = s.cfg.Providers.Grok.BaseURL
|
||||
apiKey, _ = s.cfg.GetAPIKey("grok")
|
||||
case "xiaomi":
|
||||
enabled = s.cfg.Providers.Xiaomi.Enabled
|
||||
baseURL = s.cfg.Providers.Xiaomi.BaseURL
|
||||
apiKey, _ = s.cfg.GetAPIKey("xiaomi")
|
||||
}
|
||||
|
||||
// Overrides from DB
|
||||
@@ -161,6 +171,10 @@ func (s *Server) RefreshProviders() error {
|
||||
cfg := s.cfg.Providers.Ollama
|
||||
cfg.BaseURL = baseURL
|
||||
p = providers.NewOllamaProvider(cfg)
|
||||
case "xiaomi":
|
||||
cfg := s.cfg.Providers.Xiaomi
|
||||
cfg.BaseURL = baseURL
|
||||
p = providers.NewXiaomiProvider(cfg, apiKey)
|
||||
}
|
||||
|
||||
if p != nil {
|
||||
@@ -168,9 +182,53 @@ func (s *Server) RefreshProviders() error {
|
||||
}
|
||||
}
|
||||
|
||||
s.refreshRouter()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) refreshRouter() {
|
||||
var groups []db.ModelGroup
|
||||
if err := s.database.Select(&groups, "SELECT * FROM model_groups"); err != nil {
|
||||
fmt.Printf("Warning: Failed to load model groups: %v\n", err)
|
||||
groups = nil
|
||||
}
|
||||
|
||||
var classifyFn router.ClassifierFunc
|
||||
classifyFn = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) {
|
||||
provider, _, err := s.selectProvider(selectorModel)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req := &models.UnifiedRequest{
|
||||
Model: selectorModel,
|
||||
Messages: []models.UnifiedMessage{
|
||||
{Role: "system", Content: []models.UnifiedContentPart{{Type: "text", Text: systemPrompt}}},
|
||||
{Role: "user", Content: []models.UnifiedContentPart{{Type: "text", Text: userMessage}}},
|
||||
},
|
||||
MaxTokens: uint32Ptr(5),
|
||||
Stream: false,
|
||||
}
|
||||
resp, err := provider.ChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return "", fmt.Errorf("no choices in classifier response")
|
||||
}
|
||||
content, ok := resp.Choices[0].Message.Content.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("classifier response content is not a string")
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
if s.modelRouter == nil {
|
||||
s.modelRouter = router.New(groups, classifyFn)
|
||||
} else {
|
||||
s.modelRouter.Reload(groups)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) setupRoutes() {
|
||||
// Static files
|
||||
s.router.StaticFile("/", "./static/index.html")
|
||||
@@ -228,6 +286,11 @@ func (s *Server) setupRoutes() {
|
||||
admin.GET("/models", s.handleGetModels)
|
||||
admin.PUT("/models/:id", s.handleUpdateModel)
|
||||
|
||||
admin.GET("/model-groups", s.handleGetModelGroups)
|
||||
admin.POST("/model-groups", s.handleCreateModelGroup)
|
||||
admin.PUT("/model-groups/:id", s.handleUpdateModelGroup)
|
||||
admin.DELETE("/model-groups/:id", s.handleDeleteModelGroup)
|
||||
|
||||
admin.GET("/users", s.handleGetUsers)
|
||||
admin.POST("/users", s.handleCreateUser)
|
||||
admin.PUT("/users/:id", s.handleUpdateUser)
|
||||
@@ -254,9 +317,33 @@ func (s *Server) handleResponses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Select provider based on model name
|
||||
// Strip common prefixes and resolve model groups to concrete models
|
||||
// (same pattern as handleChatCompletions).
|
||||
modelGroup := ""
|
||||
modelID := req.Model
|
||||
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/", "xiaomi/"}
|
||||
for _, p := range prefixes {
|
||||
if strings.HasPrefix(modelID, p) {
|
||||
modelID = strings.TrimPrefix(modelID, p)
|
||||
break
|
||||
}
|
||||
}
|
||||
if s.modelRouter != nil {
|
||||
routeCtx := s.buildRouteContextFromResponses(req)
|
||||
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
|
||||
return
|
||||
}
|
||||
if decision.SelectedModel != modelID {
|
||||
modelGroup = modelID
|
||||
}
|
||||
modelID = decision.SelectedModel
|
||||
}
|
||||
|
||||
// Select provider based on resolved model name
|
||||
providerName := "openai" // default for Responses API
|
||||
modelLower := strings.ToLower(req.Model)
|
||||
modelLower := strings.ToLower(modelID)
|
||||
if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
|
||||
providerName = "gemini"
|
||||
} else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) {
|
||||
@@ -284,17 +371,7 @@ func (s *Server) handleResponses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Strip common prefixes from model name
|
||||
modelID := req.Model
|
||||
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/"}
|
||||
for _, p := range prefixes {
|
||||
if strings.HasPrefix(modelID, p) {
|
||||
modelID = strings.TrimPrefix(modelID, p)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Use the stripped model name for the actual API call
|
||||
// Use resolved model for the actual API call
|
||||
req.Model = modelID
|
||||
|
||||
clientID := "default"
|
||||
@@ -309,7 +386,7 @@ func (s *Server) handleResponses(c *gin.Context) {
|
||||
if stream {
|
||||
ch, err := provider.ResponsesStream(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, false)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -324,9 +401,9 @@ func (s *Server) handleResponses(c *gin.Context) {
|
||||
if !ok {
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
if lastUsage != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, lastUsage.ToUsage(), nil, false)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage.ToUsage(), nil, false)
|
||||
} else {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, nil, nil, false)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, nil, false)
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -346,15 +423,15 @@ func (s *Server) handleResponses(c *gin.Context) {
|
||||
|
||||
resp, err := provider.Responses(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, false)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Usage != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage.ToUsage(), nil, false)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, resp.Usage.ToUsage(), nil, false)
|
||||
} else {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, nil, nil, false)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, nil, false)
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
@@ -378,6 +455,7 @@ func (s *Server) handleListModels(c *gin.Context) {
|
||||
"xai": true, // Models from models.dev use 'xai' ID for Grok
|
||||
"llmgateway": true, // Catch-all for newer models
|
||||
"ollama": true,
|
||||
"xiaomi": true, // Xiaomi MiMo models
|
||||
}
|
||||
|
||||
s.registryMu.RLock()
|
||||
@@ -414,6 +492,20 @@ func (s *Server) handleListModels(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Add model groups so clients can discover them
|
||||
if s.modelRouter != nil {
|
||||
for _, gid := range s.modelRouter.Groups() {
|
||||
if _, exists := modelMap[gid]; !exists {
|
||||
modelMap[gid] = OpenAIModel{
|
||||
ID: gid,
|
||||
Object: "model",
|
||||
Created: 1700000000,
|
||||
OwnedBy: "gophergate",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var data []OpenAIModel
|
||||
for _, m := range modelMap {
|
||||
data = append(data, m)
|
||||
@@ -425,21 +517,12 @@ func (s *Server) handleListModels(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleChatCompletions(c *gin.Context) {
|
||||
startTime := time.Now()
|
||||
var req models.ChatCompletionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Select provider based on model name
|
||||
func (s *Server) selectProvider(modelID string) (providers.Provider, string, error) {
|
||||
providerName := "openai" // default
|
||||
modelLower := strings.ToLower(req.Model)
|
||||
modelLower := strings.ToLower(modelID)
|
||||
if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
|
||||
providerName = "gemini"
|
||||
} else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) {
|
||||
// Only use deepseek provider if it's not explicitly tagged for ollama
|
||||
providerName = "deepseek"
|
||||
} else if strings.HasPrefix(modelLower, "moonshot/") || strings.Contains(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
|
||||
providerName = "moonshot"
|
||||
@@ -456,17 +539,28 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
|
||||
strings.Contains(modelLower, "codellama") ||
|
||||
strings.Contains(modelLower, "command-r") {
|
||||
providerName = "ollama"
|
||||
} else if strings.HasPrefix(modelLower, "xiaomi/") || strings.Contains(modelLower, "mimo") || strings.Contains(modelLower, "xiaomi") {
|
||||
providerName = "xiaomi"
|
||||
}
|
||||
|
||||
provider, ok := s.providers[providerName]
|
||||
p, ok := s.providers[providerName]
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
|
||||
return nil, "", fmt.Errorf("Provider %s not enabled or supported", providerName)
|
||||
}
|
||||
return p, providerName, nil
|
||||
}
|
||||
|
||||
func (s *Server) handleChatCompletions(c *gin.Context) {
|
||||
startTime := time.Now()
|
||||
var req models.ChatCompletionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Strip common prefixes
|
||||
// Strip common prefixes and prepare model ID
|
||||
modelID := req.Model
|
||||
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/"}
|
||||
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/", "xiaomi/"}
|
||||
for _, p := range prefixes {
|
||||
if strings.HasPrefix(modelID, p) {
|
||||
modelID = strings.TrimPrefix(modelID, p)
|
||||
@@ -474,6 +568,32 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve model groups to concrete models (hierarchical — groups can target groups)
|
||||
modelGroup := ""
|
||||
for i, m := range req.Messages {
|
||||
log.Printf("[DEBUG] Incoming Msg[%d]: role=%s, hasToolCalls=%v, hasContent=%v", i, m.Role, len(m.ToolCalls) > 0, m.Content != nil)
|
||||
}
|
||||
if s.modelRouter != nil {
|
||||
routeCtx := s.buildRouteContextFromChat(req)
|
||||
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
|
||||
return
|
||||
}
|
||||
if decision.SelectedModel != modelID {
|
||||
modelGroup = modelID
|
||||
}
|
||||
modelID = decision.SelectedModel
|
||||
log.Printf("[ROUTER] %s (%s: %s)", modelID, decision.Strategy, decision.Reason)
|
||||
}
|
||||
|
||||
// Select provider based on the resolved model name
|
||||
provider, providerName, err := s.selectProvider(modelID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Convert ChatCompletionRequest to UnifiedRequest
|
||||
unifiedReq := &models.UnifiedRequest{
|
||||
Model: modelID,
|
||||
@@ -490,26 +610,27 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
|
||||
ToolChoice: req.ToolChoice,
|
||||
}
|
||||
|
||||
// Inject max_tokens from model registry when client doesn't specify one.
|
||||
// Prevents providers from applying a low default output cap.
|
||||
// DEBUG: Trace max_tokens through the proxy
|
||||
clientMaxTokens := "nil"
|
||||
if unifiedReq.MaxTokens != nil {
|
||||
clientMaxTokens = fmt.Sprintf("%d", *unifiedReq.MaxTokens)
|
||||
}
|
||||
log.Printf("[DEBUG] %s: client max_tokens=%s", modelID, clientMaxTokens)
|
||||
if unifiedReq.MaxTokens == nil {
|
||||
// Inject or cap max_tokens from model registry.
|
||||
s.registryMu.RLock()
|
||||
meta := s.registry.FindModel(modelID)
|
||||
s.registryMu.RUnlock()
|
||||
|
||||
if meta != nil && meta.Limit != nil && meta.Limit.Output > 0 {
|
||||
if unifiedReq.MaxTokens == nil {
|
||||
unifiedReq.MaxTokens = &meta.Limit.Output
|
||||
log.Printf("[DEBUG] %s: injected registry max_tokens=%d", modelID, meta.Limit.Output)
|
||||
} else if *unifiedReq.MaxTokens > meta.Limit.Output {
|
||||
log.Printf("[DEBUG] %s: capping client max_tokens (%d) to registry limit (%d)", modelID, *unifiedReq.MaxTokens, meta.Limit.Output)
|
||||
unifiedReq.MaxTokens = &meta.Limit.Output
|
||||
} else {
|
||||
log.Printf("[DEBUG] %s: no registry limit found, leaving max_tokens nil (provider default)", modelID)
|
||||
log.Printf("[DEBUG] %s: using client max_tokens (%d)", modelID, *unifiedReq.MaxTokens)
|
||||
}
|
||||
} else {
|
||||
log.Printf("[DEBUG] %s: using client's max_tokens=%d", modelID, *unifiedReq.MaxTokens)
|
||||
if unifiedReq.MaxTokens == nil {
|
||||
log.Printf("[DEBUG] %s: no registry limit found, leaving max_tokens nil", modelID)
|
||||
} else {
|
||||
log.Printf("[DEBUG] %s: using client max_tokens (%d), no registry limit to cap", modelID, *unifiedReq.MaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle Stop sequences
|
||||
@@ -592,7 +713,7 @@ if unifiedReq.MaxTokens == nil {
|
||||
if unifiedReq.Stream {
|
||||
ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq)
|
||||
if err != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, unifiedReq.HasImages)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -606,7 +727,7 @@ if unifiedReq.MaxTokens == nil {
|
||||
chunk, ok := <-ch
|
||||
if !ok {
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, lastUsage, nil, unifiedReq.HasImages)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage, nil, unifiedReq.HasImages)
|
||||
return false
|
||||
}
|
||||
if chunk.Usage != nil {
|
||||
@@ -624,15 +745,29 @@ if unifiedReq.MaxTokens == nil {
|
||||
|
||||
resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq)
|
||||
if err != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, unifiedReq.HasImages)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage, nil, unifiedReq.HasImages)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, resp.Usage, nil, unifiedReq.HasImages)
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func extractUserMessage(messages []models.ChatMessage) string {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == "user" {
|
||||
switch c := messages[i].Content.(type) {
|
||||
case string:
|
||||
return c
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *Server) handleImageGenerations(c *gin.Context) {
|
||||
startTime := time.Now()
|
||||
var req models.ImageGenerationRequest
|
||||
@@ -684,7 +819,7 @@ func (s *Server) handleImageGenerations(c *gin.Context) {
|
||||
|
||||
resp, err := provider.ImageGeneration(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, false)
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, "", nil, err, false)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -698,7 +833,7 @@ func (s *Server) handleImageGenerations(c *gin.Context) {
|
||||
// Calculate per-image cost (not per-token like chat)
|
||||
cost := imageGenCost(providerName, req.Model, req.Size, uint32(len(resp.Data)))
|
||||
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, &models.Usage{
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, "", &models.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: uint32(len(resp.Data)),
|
||||
TotalTokens: promptTokens + uint32(len(resp.Data)),
|
||||
@@ -740,12 +875,13 @@ func imageGenCost(provider, model string, size *string, n uint32) float64 {
|
||||
return perImage * float64(n)
|
||||
}
|
||||
|
||||
func (s *Server) logRequest(start time.Time, clientID, provider, model string, usage *models.Usage, err error, hasImages bool) {
|
||||
func (s *Server) logRequest(start time.Time, clientID, provider, model, modelGroup string, usage *models.Usage, err error, hasImages bool) {
|
||||
entry := RequestLog{
|
||||
Timestamp: start,
|
||||
ClientID: clientID,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
ModelGroup: modelGroup,
|
||||
Status: "success",
|
||||
DurationMS: time.Since(start).Milliseconds(),
|
||||
HasImages: hasImages,
|
||||
@@ -770,9 +906,14 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model string, u
|
||||
entry.CacheWriteTokens = *usage.CacheWriteTokens
|
||||
}
|
||||
|
||||
// Calculate cost using registry
|
||||
// Calculate cost using registry; if the resolved model is unknown,
|
||||
// fall back to the model group so group requests still get priced.
|
||||
s.registryMu.RLock()
|
||||
entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
|
||||
pricingModel := model
|
||||
if s.registry != nil && s.registry.FindModel(pricingModel) == nil && modelGroup != "" {
|
||||
pricingModel = modelGroup
|
||||
}
|
||||
entry.Cost = utils.CalculateCost(s.registry, pricingModel, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
|
||||
s.registryMu.RUnlock()
|
||||
}
|
||||
|
||||
@@ -799,3 +940,165 @@ func (s *Server) Run() error {
|
||||
addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
|
||||
return s.router.Run(addr)
|
||||
}
|
||||
|
||||
func uint32Ptr(v uint32) *uint32 { return &v }
|
||||
|
||||
func (s *Server) buildRouteContextFromChat(req models.ChatCompletionRequest) *router.RouteContext {
|
||||
userMessage := extractUserMessage(req.Messages)
|
||||
requiresToolCalling := len(req.Tools) > 0
|
||||
hasMultimodal := false
|
||||
inputTokens := 0
|
||||
|
||||
for _, msg := range req.Messages {
|
||||
if strContent, ok := msg.Content.(string); ok {
|
||||
inputTokens += len(strContent) / 4
|
||||
} else if parts, ok := msg.Content.([]interface{}); ok {
|
||||
for _, part := range parts {
|
||||
if partMap, ok := part.(map[string]interface{}); ok {
|
||||
partType, _ := partMap["type"].(string)
|
||||
if partType == "text" {
|
||||
text, _ := partMap["text"].(string)
|
||||
inputTokens += len(text) / 4
|
||||
} else if partType == "image_url" {
|
||||
hasMultimodal = true
|
||||
inputTokens += 1000 // Approximate cost of an image in tokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msgLower := strings.ToLower(userMessage)
|
||||
requiresReasoning := strings.Contains(msgLower, "reason") ||
|
||||
strings.Contains(msgLower, "think step by step") ||
|
||||
strings.Contains(msgLower, "mathematics") ||
|
||||
strings.Contains(msgLower, "architecture") ||
|
||||
strings.Contains(msgLower, "explain in detail")
|
||||
|
||||
routeCtx := &router.RouteContext{
|
||||
UserMessage: userMessage,
|
||||
InputTokens: inputTokens,
|
||||
HasMultimodalInput: hasMultimodal,
|
||||
RequiresToolCalling: requiresToolCalling,
|
||||
RequiresReasoning: requiresReasoning,
|
||||
}
|
||||
routeCtx.Tags = s.getRouteCtxTags(routeCtx)
|
||||
return routeCtx
|
||||
}
|
||||
|
||||
func (s *Server) buildRouteContextFromResponses(req models.ResponsesRequest) *router.RouteContext {
|
||||
var userMessage string
|
||||
hasMultimodal := false
|
||||
inputTokens := len(req.Instructions) / 4
|
||||
requiresToolCalling := len(req.Tools) > 0 && string(req.Tools) != "null" && string(req.Tools) != ""
|
||||
|
||||
var strInput string
|
||||
if err := json.Unmarshal(req.Input, &strInput); err == nil {
|
||||
userMessage = strInput
|
||||
inputTokens += len(userMessage) / 4
|
||||
} else {
|
||||
var msgs []models.ResponseInputMessage
|
||||
if err := json.Unmarshal(req.Input, &msgs); err == nil {
|
||||
for _, m := range msgs {
|
||||
var contentStr string
|
||||
if err := json.Unmarshal(m.Content, &contentStr); err == nil {
|
||||
if m.Role == "user" {
|
||||
userMessage = contentStr
|
||||
}
|
||||
inputTokens += len(contentStr) / 4
|
||||
} else {
|
||||
var parts []models.ContentPart
|
||||
if err := json.Unmarshal(m.Content, &parts); err == nil {
|
||||
for _, p := range parts {
|
||||
if p.Type == "text" {
|
||||
if m.Role == "user" {
|
||||
userMessage = p.Text
|
||||
}
|
||||
inputTokens += len(p.Text) / 4
|
||||
} else if p.Type == "image_url" {
|
||||
hasMultimodal = true
|
||||
inputTokens += 1000
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msgLower := strings.ToLower(userMessage)
|
||||
requiresReasoning := strings.Contains(msgLower, "reason") ||
|
||||
strings.Contains(msgLower, "think step by step") ||
|
||||
strings.Contains(msgLower, "mathematics") ||
|
||||
strings.Contains(msgLower, "architecture") ||
|
||||
strings.Contains(msgLower, "explain in detail")
|
||||
|
||||
routeCtx := &router.RouteContext{
|
||||
UserMessage: userMessage,
|
||||
InputTokens: inputTokens,
|
||||
HasMultimodalInput: hasMultimodal,
|
||||
RequiresToolCalling: requiresToolCalling,
|
||||
RequiresReasoning: requiresReasoning,
|
||||
}
|
||||
routeCtx.Tags = s.getRouteCtxTags(routeCtx)
|
||||
return routeCtx
|
||||
}
|
||||
|
||||
func (s *Server) getRouteCtxTags(routeCtx *router.RouteContext) []string {
|
||||
var tags []string
|
||||
msgLower := strings.ToLower(routeCtx.UserMessage)
|
||||
|
||||
// fast-flow keywords
|
||||
fastFlowKeywords := []string{
|
||||
"classify", "classification", "label", "tag", "route", "routing", "intent",
|
||||
"json", "yaml", "csv", "schema", "parse", "extract", "transform", "format", "regex",
|
||||
"short answer", "brief", "concise", "tl;dr", "one line", "simple",
|
||||
"fix this", "small bug", "quick fix", "typo", "syntax error",
|
||||
}
|
||||
for _, kw := range fastFlowKeywords {
|
||||
if strings.Contains(msgLower, kw) {
|
||||
tags = append(tags, "fast-flow", "classification", "json-extraction", "basic-qa")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// standard-pro keywords
|
||||
standardProKeywords := []string{
|
||||
"explain", "summarize", "rewrite", "draft", "edit", "polish", "outline",
|
||||
"long doc", "document", "email", "memo", "proposal", "report", "handout", "notes",
|
||||
"compare", "choose", "recommend", "tradeoff", "pros and cons", "analysis",
|
||||
"code review", "debug", "bug", "feature", "api", "endpoint", "implement",
|
||||
"plan", "planning", "workflow", "integration",
|
||||
}
|
||||
for _, kw := range standardProKeywords {
|
||||
if strings.Contains(msgLower, kw) {
|
||||
tags = append(tags, "standard-pro", "long-doc")
|
||||
break
|
||||
}
|
||||
}
|
||||
if routeCtx.HasMultimodalInput {
|
||||
tags = append(tags, "video-analysis", "multimodal-qa")
|
||||
}
|
||||
|
||||
// heavy-logic keywords
|
||||
heavyLogicKeywords := []string{
|
||||
"agent", "agents", "tool use", "function calling", "multi-agent", "orchestrate",
|
||||
"system design", "scaling", "performance", "architecture review", "distributed",
|
||||
"hard bug", "race condition", "deadlock", "memory leak", "crash", "production outage",
|
||||
"long context", "large codebase", "many files", "complex refactor", "migration",
|
||||
"research", "deep dive", "literature", "paper", "scholarly", "thorough analysis",
|
||||
"deep reasoning", "think step by step", "reason through", "careful analysis",
|
||||
}
|
||||
for _, kw := range heavyLogicKeywords {
|
||||
if strings.Contains(msgLower, kw) {
|
||||
tags = append(tags, "heavy-logic", "deep-reasoning", "architecture", "hard-debugging")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if routeCtx.RequiresToolCalling {
|
||||
tags = append(tags, "tool-heavy", "multi-step-agent", "swe-bench")
|
||||
}
|
||||
|
||||
return tags
|
||||
}
|
||||
|
||||
@@ -45,6 +45,24 @@ func FetchRegistry() (*models.ModelRegistry, error) {
|
||||
return nil, fmt.Errorf("failed to fetch registry after 3 attempts: %w", lastErr)
|
||||
}
|
||||
|
||||
// promoDiscount describes a temporary pricing discount applied on top of
|
||||
// the standard (list) price from the model registry.
|
||||
type promoDiscount struct {
|
||||
Factor float64 // multiplier applied after standard calculation (0.25 = 75% off)
|
||||
ExpiresAt time.Time // discount ends at this time (UTC)
|
||||
}
|
||||
|
||||
// promoDiscounts maps model IDs to active promotional discounts.
|
||||
// Sources:
|
||||
// - DeepSeek v4 Pro: 75% off list pricing until 2026-05-31
|
||||
// https://api-docs.deepseek.com/quick_start/pricing
|
||||
var promoDiscounts = map[string]promoDiscount{
|
||||
"deepseek-v4-pro": {
|
||||
Factor: 0.25,
|
||||
ExpiresAt: time.Date(2026, 5, 31, 23, 59, 59, 0, time.UTC),
|
||||
},
|
||||
}
|
||||
|
||||
func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens, completionTokens, reasoningTokens, cacheRead, cacheWrite uint32) float64 {
|
||||
meta := registry.FindModel(modelID)
|
||||
if meta == nil || meta.Cost == nil {
|
||||
@@ -72,5 +90,12 @@ func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens,
|
||||
cost += float64(cacheWrite) * (*meta.Cost.CacheWrite) / 1000000.0
|
||||
}
|
||||
|
||||
// Apply promotional discounts (e.g. DeepSeek 75% off until 2026-05-31).
|
||||
if discount, ok := promoDiscounts[modelID]; ok {
|
||||
if time.Now().UTC().Before(discount.ExpiresAt) {
|
||||
cost *= discount.Factor
|
||||
}
|
||||
}
|
||||
|
||||
return cost
|
||||
}
|
||||
|
||||
+6
-1
@@ -89,6 +89,10 @@
|
||||
<i class="fas fa-brain"></i>
|
||||
<span>Models</span>
|
||||
</li>
|
||||
<li class="menu-item" data-page="model-groups">
|
||||
<i class="fas fa-code-branch"></i>
|
||||
<span>Model Groups</span>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
@@ -164,7 +168,7 @@
|
||||
<script src="/js/auth.js?v=7"></script>
|
||||
<script src="/js/charts.js?v=7"></script>
|
||||
<script src="/js/websocket.js?v=7"></script>
|
||||
<script src="/js/dashboard.js?v=7"></script>
|
||||
<script src="/js/dashboard.js?v=8"></script>
|
||||
|
||||
<!-- Page Modules -->
|
||||
<script src="/js/pages/overview.js?v=7"></script>
|
||||
@@ -177,5 +181,6 @@
|
||||
<script src="/js/pages/settings.js?v=7"></script>
|
||||
<script src="/js/pages/logs.js?v=7"></script>
|
||||
<script src="/js/pages/users.js?v=7"></script>
|
||||
<script src="/js/pages/model_groups.js?v=9"></script>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
@@ -119,6 +119,7 @@ class Dashboard {
|
||||
'settings': 'Settings',
|
||||
'logs': 'Logs',
|
||||
'models': 'Models',
|
||||
'model-groups': 'Model Groups',
|
||||
'users': 'User Management'
|
||||
};
|
||||
if (titleElement) titleElement.textContent = titles[page] || 'Dashboard';
|
||||
@@ -130,6 +131,11 @@ class Dashboard {
|
||||
if (content) {
|
||||
content.innerHTML = await this.getPageTemplate(page);
|
||||
await this.initializePageScript(page);
|
||||
|
||||
// Model Groups page uses its own render method
|
||||
if (page === 'model-groups' && typeof modelGroupsPage !== 'undefined') {
|
||||
await modelGroupsPage.render();
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`Error loading page ${page}:`, error);
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
// Model Groups Management Page
|
||||
|
||||
class ModelGroupsPage {
|
||||
constructor() {
|
||||
this.container = document.getElementById('page-content');
|
||||
}
|
||||
|
||||
async render() {
|
||||
this.container.innerHTML = `
|
||||
<div class="page-header">
|
||||
<h3>Model Groups</h3>
|
||||
<p class="text-muted">Define auto-routing groups that pick the best model for each request.</p>
|
||||
<button class="btn btn-primary" onclick="modelGroupsPage.showCreateForm()">
|
||||
<i class="fas fa-plus"></i> Add Group
|
||||
</button>
|
||||
</div>
|
||||
<div id="model-groups-list" class="table-container"></div>
|
||||
<div id="model-group-form" class="form-container" style="display:none;"></div>
|
||||
`;
|
||||
await this.loadGroups();
|
||||
}
|
||||
|
||||
async loadGroups() {
|
||||
try {
|
||||
const groups = await api.get('/model-groups');
|
||||
const list = document.getElementById('model-groups-list');
|
||||
if (!groups || groups.length === 0) {
|
||||
list.innerHTML = '<div class="empty-state">No model groups defined. Create one to enable auto-routing.</div>';
|
||||
return;
|
||||
}
|
||||
|
||||
let html = '<table class="data-table"><thead><tr>';
|
||||
html += '<th>Group ID</th><th>Level</th><th>Primary Use</th><th>Strategy</th><th>Targets</th><th>Actions</th>';
|
||||
html += '</tr></thead><tbody>';
|
||||
|
||||
groups.forEach(g => {
|
||||
html += '<tr>';
|
||||
html += '<td><code>' + this.esc(g.id) + '</code></td>';
|
||||
html += '<td>' + (g.logic_level != null ? g.logic_level : '—') + '</td>';
|
||||
html += '<td>' + this.esc(g.primary_use || '—') + '</td>';
|
||||
html += '<td><span class="badge">' + this.esc(g.strategy) + '</span></td>';
|
||||
html += '<td><code>' + this.esc(g.targets) + '</code></td>';
|
||||
html += '<td>';
|
||||
html += '<button class="btn btn-sm" onclick="modelGroupsPage.showEditForm(\'' + this.esc(g.id) + '\')">Edit</button> ';
|
||||
html += '<button class="btn btn-sm btn-danger" onclick="modelGroupsPage.deleteGroup(\'' + this.esc(g.id) + '\')">Delete</button>';
|
||||
html += '</td></tr>';
|
||||
});
|
||||
|
||||
html += '</tbody></table>';
|
||||
list.innerHTML = html;
|
||||
} catch (err) {
|
||||
document.getElementById('model-groups-list').innerHTML =
|
||||
'<div class="error-message">Failed to load model groups: ' + this.esc(err.message) + '</div>';
|
||||
}
|
||||
}
|
||||
|
||||
showCreateForm() {
|
||||
this.renderForm(null);
|
||||
}
|
||||
|
||||
async showEditForm(id) {
|
||||
try {
|
||||
const groups = await api.get('/model-groups');
|
||||
const group = groups.find(g => g.id === id);
|
||||
if (group) this.renderForm(group);
|
||||
} catch (err) {
|
||||
alert('Failed to load group: ' + err.message);
|
||||
}
|
||||
}
|
||||
|
||||
renderForm(group) {
|
||||
const isEdit = !!group;
|
||||
const form = document.getElementById('model-group-form');
|
||||
form.style.display = 'block';
|
||||
form.innerHTML = `
|
||||
<h4>${isEdit ? 'Edit' : 'Create'} Model Group</h4>
|
||||
<form onsubmit="modelGroupsPage.saveGroup(event, ${isEdit})">
|
||||
<div class="form-control">
|
||||
<label>Group ID</label>
|
||||
<input type="text" id="mg-id" value="${this.esc(group ? group.id : '')}" ${isEdit ? 'readonly' : 'required'}
|
||||
placeholder="e.g. deepseek-auto">
|
||||
<small>Clients use this as the model name.</small>
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label>Strategy</label>
|
||||
<select id="mg-strategy">
|
||||
<option value="heuristic" ${group && group.strategy === 'heuristic' ? 'selected' : ''}>Heuristic (rules-based)</option>
|
||||
<option value="classifier" ${group && group.strategy === 'classifier' ? 'selected' : ''}>Classifier (LLM judge)</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label>Targets (JSON array)</label>
|
||||
<input type="text" id="mg-targets" value='${this.esc(group ? group.targets : '["cheap-model","smart-model"]')}' required>
|
||||
<small>First target = cheapest/fastest. Last target = smartest/most expensive.</small>
|
||||
</div>
|
||||
<div class="form-control" id="mg-selector-row" style="${group && group.strategy === 'classifier' ? '' : 'display:none'}">
|
||||
<label>Selector Model</label>
|
||||
<input type="text" id="mg-selector-model" value="${this.esc(group && group.selector_model ? group.selector_model : 'gpt-4o-mini')}"
|
||||
placeholder="Model used to judge task complexity">
|
||||
</div>
|
||||
<div class="form-control" id="mg-threshold-row" style="${group && group.strategy === 'classifier' ? '' : 'display:none'}">
|
||||
<label>Complexity Threshold</label>
|
||||
<input type="number" id="mg-threshold" value="${group && group.complexity_threshold ? group.complexity_threshold : ''}" min="1"
|
||||
placeholder="Tasks rated >= this go to the smart model">
|
||||
</div>
|
||||
<div class="form-control" id="mg-rules-row" style="${group && group.strategy === 'heuristic' ? '' : 'display:none'}">
|
||||
<label>Heuristic Rules (JSON array)</label>
|
||||
<textarea id="mg-rules" rows="4" placeholder='[{"pattern":"step by step","target":1}]'>${group && group.heuristic_rules ? group.heuristic_rules : ''}</textarea>
|
||||
<small>Pattern to match in user messages. target = index into targets array.</small>
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label>Logic Level (1-10)</label>
|
||||
<input type="number" id="mg-level" value="${group && group.logic_level != null ? group.logic_level : ''}" min="1" max="10"
|
||||
placeholder="e.g. 8 for heavy logic, 2 for fast/basic">
|
||||
<small>Rough complexity scale. 1-3: fast/light, 4-7: standard, 8-10: heavy.</small>
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label>Primary Use</label>
|
||||
<input type="text" id="mg-primary-use" value="${this.esc(group && group.primary_use ? group.primary_use : '')}"
|
||||
placeholder="e.g. Complex Coding, Logic, Agents.">
|
||||
<small>Brief description of what this group is best used for.</small>
|
||||
</div>
|
||||
<div class="form-actions">
|
||||
<button type="submit" class="btn btn-primary">Save</button>
|
||||
<button type="button" class="btn" onclick="document.getElementById('model-group-form').style.display='none'">Cancel</button>
|
||||
</div>
|
||||
</form>
|
||||
`;
|
||||
|
||||
document.getElementById('mg-strategy').onchange = function() {
|
||||
var isClassifier = this.value === 'classifier';
|
||||
document.getElementById('mg-selector-row').style.display = isClassifier ? '' : 'none';
|
||||
document.getElementById('mg-threshold-row').style.display = isClassifier ? '' : 'none';
|
||||
document.getElementById('mg-rules-row').style.display = isClassifier ? 'none' : '';
|
||||
};
|
||||
}
|
||||
|
||||
async saveGroup(event, isEdit) {
|
||||
event.preventDefault();
|
||||
var id = document.getElementById('mg-id').value.trim();
|
||||
var strategy = document.getElementById('mg-strategy').value;
|
||||
var targets = document.getElementById('mg-targets').value;
|
||||
var selectorModel = document.getElementById('mg-selector-model').value.trim() || null;
|
||||
var thresholdVal = document.getElementById('mg-threshold').value;
|
||||
var rules = document.getElementById('mg-rules').value.trim() || null;
|
||||
var logicLevelVal = document.getElementById('mg-level').value;
|
||||
var primaryUse = document.getElementById('mg-primary-use').value.trim() || null;
|
||||
|
||||
try { JSON.parse(targets); } catch (e) { alert('Targets must be valid JSON array'); return; }
|
||||
if (rules) { try { JSON.parse(rules); } catch (e) { alert('Heuristic rules must be valid JSON'); return; } }
|
||||
|
||||
var body = { id: id, strategy: strategy, targets: targets, selector_model: selectorModel, heuristic_rules: rules };
|
||||
if (thresholdVal) body.complexity_threshold = parseInt(thresholdVal);
|
||||
if (logicLevelVal) body.logic_level = parseInt(logicLevelVal);
|
||||
if (primaryUse) body.primary_use = primaryUse;
|
||||
|
||||
try {
|
||||
if (isEdit) {
|
||||
await api.put('/model-groups/' + encodeURIComponent(id), body);
|
||||
} else {
|
||||
await api.post('/model-groups', body);
|
||||
}
|
||||
document.getElementById('model-group-form').style.display = 'none';
|
||||
await this.loadGroups();
|
||||
} catch (err) {
|
||||
alert('Failed to save: ' + err.message);
|
||||
}
|
||||
}
|
||||
|
||||
async deleteGroup(id) {
|
||||
if (!confirm('Delete model group "' + id + '"? This cannot be undone.')) return;
|
||||
try {
|
||||
await api.delete('/model-groups/' + encodeURIComponent(id));
|
||||
await this.loadGroups();
|
||||
} catch (err) {
|
||||
alert('Failed to delete: ' + err.message);
|
||||
}
|
||||
}
|
||||
|
||||
esc(str) {
|
||||
if (!str) return '';
|
||||
return String(str).replace(/&/g,'&').replace(/</g,'<').replace(/>/g,'>').replace(/"/g,'"');
|
||||
}
|
||||
}
|
||||
|
||||
var modelGroupsPage = new ModelGroupsPage();
|
||||
@@ -392,7 +392,7 @@ class MonitoringPage {
|
||||
</div>
|
||||
<div class="stream-entry-content">
|
||||
<strong>${request.client_id || 'Unknown'}</strong> →
|
||||
${request.provider || 'Unknown'} (${request.model || 'Unknown'})
|
||||
${request.provider || 'Unknown'} (${request.model || 'Unknown'}${request.model_group ? ` via ${request.model_group}` : ''})
|
||||
<div class="stream-entry-details">
|
||||
${request.total_tokens || request.tokens || 0} tokens • ${request.duration_ms || request.duration || 0}ms
|
||||
</div>
|
||||
|
||||
@@ -252,7 +252,7 @@ class OverviewPage {
|
||||
<td>${time}</td>
|
||||
<td><span class="badge-client">${request.client_id}</span></td>
|
||||
<td>${request.provider}</td>
|
||||
<td><code class="code-sm">${request.model}</code></td>
|
||||
<td><code class="code-sm">${request.model}${request.model_group ? ` (via ${request.model_group})` : ''}</code></td>
|
||||
<td>${request.tokens.toLocaleString()}</td>
|
||||
<td>
|
||||
<span class="status-badge ${statusClass}">
|
||||
@@ -313,7 +313,7 @@ class OverviewPage {
|
||||
<td>${time}</td>
|
||||
<td><span class="badge-client">${request.client_id}</span></td>
|
||||
<td>${request.provider}</td>
|
||||
<td><code class="code-sm">${request.model}</code></td>
|
||||
<td><code class="code-sm">${request.model}${request.model_group ? ` (via ${request.model_group})` : ''}</code></td>
|
||||
<td>${(request.total_tokens || request.tokens || 0).toLocaleString()}</td>
|
||||
<td>
|
||||
<span class="status-badge ${statusClass}">
|
||||
|
||||
@@ -309,7 +309,7 @@ class WebSocketManager {
|
||||
<td>${time}</td>
|
||||
<td>${request.client_id || 'Unknown'}</td>
|
||||
<td>${request.provider || 'Unknown'}</td>
|
||||
<td>${request.model || 'Unknown'}</td>
|
||||
<td>${request.model || 'Unknown'}${request.model_group ? ` (via ${request.model_group})` : ''}</td>
|
||||
<td>${(request.total_tokens || request.tokens || 0)}</td>
|
||||
<td>
|
||||
<span class="status-badge ${statusClass}">
|
||||
@@ -358,7 +358,7 @@ class WebSocketManager {
|
||||
</div>
|
||||
<div class="stream-entry-content">
|
||||
<strong>${request.client_id || 'Unknown'}</strong> →
|
||||
${request.provider || 'Unknown'} (${request.model || 'Unknown'})
|
||||
${request.provider || 'Unknown'} (${request.model || 'Unknown'}${request.model_group ? ` via ${request.model_group}` : ''})
|
||||
<div class="stream-entry-details">
|
||||
${(request.total_tokens || request.tokens || 0)} tokens • ${(request.duration_ms || request.duration || 0)}ms
|
||||
</div>
|
||||
|
||||
Reference in New Issue
Block a user