Compare commits
63 Commits
2f6b7deb2c
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 73a82e6175 | |||
| b3354a1bbc | |||
| 1dc5f586b9 | |||
| 40f055cb57 | |||
| 970e778703 | |||
| 477a811999 | |||
| d2b9da89d9 | |||
| b7df3108fa | |||
| 28b8271c1d | |||
| eb585c0001 | |||
| 4aea7a3b4c | |||
| 330eaa57d1 | |||
| 0ae30036f0 | |||
| 3c0b59622e | |||
| 7517307c11 | |||
| 19517b0847 | |||
| a3a6f765e7 | |||
| 79dd122b56 | |||
| 3021e4b2b4 | |||
| 14de7e9ebf | |||
| 4fef201e95 | |||
| bac03de051 | |||
| 37949e560b | |||
| f04cb6b8f2 | |||
| 10262c0e5a | |||
| d345f8c41d | |||
| d1f7a57f58 | |||
| dc9af4d79c | |||
| c009d401fb | |||
| e5ef39f327 | |||
| eb67287b56 | |||
| 4aa17b4fd2 | |||
| 79571c6bdc | |||
| d46a333249 | |||
| 7446f3463d | |||
| b1a72f5a10 | |||
| 5ee539d95c | |||
| 14e26a4323 | |||
| 1c3b1c6fe9 | |||
| 5e0c10db01 | |||
| e598150d90 | |||
| 2fa6f0df62 | |||
| db76858072 | |||
| af2c5b95f7 | |||
| 1f574d8134 | |||
| 8a8d8d1477 | |||
| da074f52b4 | |||
| 9b0aa4dbe8 | |||
| 212ac14a1b | |||
| 2929f51556 | |||
| e12418cc4c | |||
| be4ec3482a | |||
| e67aafdac1 | |||
| 21e5204abd | |||
| 4095c68822 | |||
| ef37dc5af0 | |||
| fdbb068a6c | |||
| dbbf48cb14 | |||
| 1e13b0376b | |||
| 1b5cd2815e | |||
| ba4c4af2f8 | |||
| e56a284415 | |||
| cbc9eeb453 |
@@ -18,6 +18,9 @@ DEEPSEEK_API_KEY=sk-...
|
||||
MOONSHOT_API_KEY=sk-...
|
||||
GROK_API_KEY=xai-...
|
||||
|
||||
# Xiaomi MiMo
|
||||
XIAOMI_API_KEY=sk-...
|
||||
|
||||
# ==============================================================================
|
||||
# Server Configuration
|
||||
# ==============================================================================
|
||||
|
||||
+12
-7
@@ -1,13 +1,18 @@
|
||||
.env
|
||||
.env.*
|
||||
!.env.example
|
||||
/target
|
||||
/llm-proxy
|
||||
/llm-proxy-go
|
||||
/gophergate
|
||||
/data/
|
||||
*.db
|
||||
*.db-shm
|
||||
*.db-wal
|
||||
.env
|
||||
.env.*
|
||||
!.env.example
|
||||
/gophergate
|
||||
/llm-proxy
|
||||
/llm-proxy-go
|
||||
*.log
|
||||
.opencode/
|
||||
.pi-lens/
|
||||
.pi-lens/cache/
|
||||
server.pid
|
||||
/target
|
||||
nohup.out
|
||||
*.bak
|
||||
|
||||
@@ -0,0 +1,919 @@
|
||||
# Automatic Model Routing — Implementation Plan
|
||||
|
||||
> **For Hermes:** Use subagent-driven-development skill to implement this plan task-by-task.
|
||||
|
||||
**Goal:** Add a model-group router that lets clients send `model: "deepseek-auto"` and have gophergate pick the best concrete model based on heuristic rules or an optional classifier LLM.
|
||||
|
||||
**Architecture:** A new `internal/router/` package with heuristic and classifier strategies, backed by a `model_groups` DB table. The router injects into `handleChatCompletions` after provider resolution but before the provider call — zero changes to the Provider interface. Admin CRUD endpoints and a dashboard tab for management.
|
||||
|
||||
**Tech Stack:** Go 1.22+, Gin, sqlx (SQLite), resty, existing OpenAI provider for classifier calls.
|
||||
|
||||
---
|
||||
|
||||
## Task 1: Add `model_groups` DB migration and struct
|
||||
|
||||
**Objective:** Create the `model_groups` table and Go struct.
|
||||
|
||||
**Files:**
|
||||
- Modify: `internal/db/db.go`
|
||||
|
||||
**Step 1: Add CREATE TABLE to migrations**
|
||||
|
||||
In `RunMigrations()`, add to the `queries` slice (after `client_tokens`):
|
||||
|
||||
```go
|
||||
`CREATE TABLE IF NOT EXISTS model_groups (
|
||||
id TEXT PRIMARY KEY,
|
||||
strategy TEXT NOT NULL DEFAULT 'heuristic',
|
||||
selector_model TEXT,
|
||||
targets TEXT NOT NULL DEFAULT '[]',
|
||||
complexity_threshold INTEGER,
|
||||
heuristic_rules TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)`,
|
||||
```
|
||||
|
||||
**Step 2: Add the Go struct**
|
||||
|
||||
After the `ClientToken` struct (around line 264), add:
|
||||
|
||||
```go
|
||||
type ModelGroup struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
Strategy string `db:"strategy" json:"strategy"`
|
||||
SelectorModel *string `db:"selector_model" json:"selector_model"`
|
||||
Targets string `db:"targets" json:"targets"` // JSON array
|
||||
ComplexityThreshold *int `db:"complexity_threshold" json:"complexity_threshold"`
|
||||
HeuristicRules *string `db:"heuristic_rules" json:"heuristic_rules"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
```
|
||||
|
||||
**Step 3: Seed default groups**
|
||||
|
||||
After the "Default client" block in `RunMigrations()`, add:
|
||||
|
||||
```go
|
||||
// Seed default model groups
|
||||
defaultGroups := []struct {
|
||||
id, strategy, targets string
|
||||
}{
|
||||
{"deepseek-auto", "heuristic", `["deepseek-chat","deepseek-reasoner"]`},
|
||||
{"openai-auto", "heuristic", `["gpt-4o-mini","gpt-4o"]`},
|
||||
{"gemini-auto", "heuristic", `["gemini-2.0-flash","gemini-2.5-pro"]`},
|
||||
}
|
||||
for _, g := range defaultGroups {
|
||||
db.Exec(`INSERT OR IGNORE INTO model_groups (id, strategy, targets) VALUES (?, ?, ?)`,
|
||||
g.id, g.strategy, g.targets)
|
||||
}
|
||||
```
|
||||
|
||||
**Step 4: Build and verify**
|
||||
|
||||
```bash
|
||||
cd ~/Documents/projects/web_projects/gophergate && go build ./...
|
||||
```
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add internal/db/db.go
|
||||
git commit -m "feat: add model_groups table and default seed data"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 2: Create router package — interface and heuristic router
|
||||
|
||||
**Objective:** Create `internal/router/` with the Router interface and heuristic implementation.
|
||||
|
||||
**Files:**
|
||||
- Create: `internal/router/router.go`
|
||||
- Create: `internal/router/heuristic.go`
|
||||
|
||||
**Step 1: Create `internal/router/router.go`**
|
||||
|
||||
```go
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
// Decision holds the result of a routing decision.
|
||||
type Decision struct {
|
||||
SelectedModel string `json:"selected_model"`
|
||||
Strategy string `json:"strategy"` // "heuristic" or "classifier"
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// ClassifierFunc is the callback for classifier-based routing.
|
||||
// Takes a system prompt, user message, and selector model.
|
||||
// Returns a complexity rating string (e.g. "3").
|
||||
type ClassifierFunc func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error)
|
||||
|
||||
// Router resolves model groups to concrete models.
|
||||
type Router struct {
|
||||
groups map[string]db.ModelGroup
|
||||
classify ClassifierFunc
|
||||
}
|
||||
|
||||
// New creates a Router. classify may be nil if no classifier groups exist.
|
||||
func New(groups []db.ModelGroup, classify ClassifierFunc) *Router {
|
||||
r := &Router{
|
||||
groups: make(map[string]db.ModelGroup),
|
||||
classify: classify,
|
||||
}
|
||||
for _, g := range groups {
|
||||
r.groups[g.ID] = g
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// IsGroup returns true if the model name is a group ID.
|
||||
func (r *Router) IsGroup(modelID string) bool {
|
||||
_, ok := r.groups[modelID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Route resolves a group to a concrete model.
|
||||
// Extracts the user message from the request body JSON bytes.
|
||||
func (r *Router) Route(ctx context.Context, groupID string, userMessage string) (*Decision, error) {
|
||||
group, ok := r.groups[groupID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown model group: %s", groupID)
|
||||
}
|
||||
|
||||
var targets []string
|
||||
if err := json.Unmarshal([]byte(group.Targets), &targets); err != nil || len(targets) == 0 {
|
||||
return nil, fmt.Errorf("invalid or empty targets for group %s", groupID)
|
||||
}
|
||||
|
||||
switch group.Strategy {
|
||||
case "heuristic":
|
||||
return routeHeuristic(group, targets, userMessage)
|
||||
case "classifier":
|
||||
if r.classify == nil {
|
||||
// Fall back to heuristic if no classifier is available
|
||||
return routeHeuristic(group, targets, userMessage)
|
||||
}
|
||||
return routeClassifier(ctx, r.classify, group, targets, userMessage)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown strategy: %s", group.Strategy)
|
||||
}
|
||||
}
|
||||
|
||||
// Reload replaces the group definitions without recreating the router.
|
||||
func (r *Router) Reload(groups []db.ModelGroup) {
|
||||
r.groups = make(map[string]db.ModelGroup)
|
||||
for _, g := range groups {
|
||||
r.groups[g.ID] = g
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: Create `internal/router/heuristic.go`**
|
||||
|
||||
```go
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
// HeuristicRule defines a pattern-based routing rule.
|
||||
type HeuristicRule struct {
|
||||
Pattern string `json:"pattern"` // substring to match in user message
|
||||
TargetIdx int `json:"target"` // index into targets array (0-based)
|
||||
CaseSensitive bool `json:"case_sensitive,omitempty"`
|
||||
}
|
||||
|
||||
func routeHeuristic(group db.ModelGroup, targets []string, userMessage string) (*Decision, error) {
|
||||
// Default to first target (cheapest/fastest)
|
||||
selected := targets[0]
|
||||
reason := "default (first target)"
|
||||
|
||||
// If heuristic_rules is set, use them
|
||||
if group.HeuristicRules != nil && *group.HeuristicRules != "" {
|
||||
var rules []HeuristicRule
|
||||
if err := json.Unmarshal([]byte(*group.HeuristicRules), &rules); err == nil {
|
||||
searchMsg := userMessage
|
||||
for _, rule := range rules {
|
||||
pattern := rule.Pattern
|
||||
msg := searchMsg
|
||||
if !rule.CaseSensitive {
|
||||
pattern = strings.ToLower(pattern)
|
||||
msg = strings.ToLower(msg)
|
||||
}
|
||||
if strings.Contains(msg, pattern) {
|
||||
if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) {
|
||||
selected = targets[rule.TargetIdx]
|
||||
reason = "matched heuristic rule: " + rule.Pattern
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Built-in fallback heuristics (apply even without custom rules)
|
||||
if reason == "default (first target)" && len(targets) > 1 {
|
||||
msgLower := strings.ToLower(userMessage)
|
||||
// Complex task indicators → last target (usually the smarter model)
|
||||
complexIndicators := []string{
|
||||
"step by step", "explain in detail", "reason through",
|
||||
"think carefully", "analyze", "debug", "write code",
|
||||
"implement", "refactor", "architecture",
|
||||
}
|
||||
for _, indicator := range complexIndicators {
|
||||
if strings.Contains(msgLower, indicator) {
|
||||
selected = targets[len(targets)-1]
|
||||
reason = "complex task indicator: " + indicator
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Decision{
|
||||
SelectedModel: selected,
|
||||
Strategy: "heuristic",
|
||||
Reason: reason,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// routeHeuristic exists as a package-level func for direct use.
|
||||
var _ = routeHeuristic // suppress unused warning when classifier is the only caller
|
||||
```
|
||||
|
||||
Hmm, actually let me simplify. The `routeHeuristic` function IS used by `Router.Route()`. Let me not use the blank identifier trick.
|
||||
|
||||
**Step 3: Build**
|
||||
|
||||
```bash
|
||||
cd ~/Documents/projects/web_projects/gophergate && go build ./...
|
||||
```
|
||||
|
||||
Fix any compilation errors (missing imports, etc.).
|
||||
|
||||
**Step 4: Commit**
|
||||
|
||||
```bash
|
||||
git add internal/router/
|
||||
git commit -m "feat: add router package with heuristic strategy"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 3: Add classifier router
|
||||
|
||||
**Objective:** Implement the classifier strategy that uses a cheap LLM to rate task complexity.
|
||||
|
||||
**Files:**
|
||||
- Create: `internal/router/classifier.go`
|
||||
|
||||
**Step 1: Create `internal/router/classifier.go`**
|
||||
|
||||
```go
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
const classifierSystemPrompt = `You are a task complexity classifier. Rate the following user message on a scale of 1 to %d, where:
|
||||
1 = trivial/simple (basic facts, greetings, simple math)
|
||||
%d = highly complex (multi-step reasoning, code generation, architecture design)
|
||||
|
||||
Reply with ONLY the number. No explanation.`
|
||||
|
||||
func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, userMessage string) (*Decision, error) {
|
||||
maxRating := len(targets)
|
||||
if maxRating < 2 {
|
||||
maxRating = 2
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf(classifierSystemPrompt, maxRating, maxRating)
|
||||
ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMessage)
|
||||
if err != nil {
|
||||
// Classifier failed — fall back to heuristic
|
||||
return routeHeuristic(group, targets, userMessage)
|
||||
}
|
||||
|
||||
rating, err := strconv.Atoi(strings.TrimSpace(ratingStr))
|
||||
if err != nil || rating < 1 {
|
||||
rating = 1
|
||||
}
|
||||
if rating > maxRating {
|
||||
rating = maxRating
|
||||
}
|
||||
|
||||
idx := rating - 1 // 0-based index into targets
|
||||
return &Decision{
|
||||
SelectedModel: targets[idx],
|
||||
Strategy: "classifier",
|
||||
Reason: fmt.Sprintf("complexity rating: %d/%d", rating, maxRating),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getSelectorModel(group db.ModelGroup, targets []string) string {
|
||||
if group.SelectorModel != nil && *group.SelectorModel != "" {
|
||||
return *group.SelectorModel
|
||||
}
|
||||
// Default: use the first (cheapest) target model as the selector
|
||||
return targets[0]
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: Build**
|
||||
|
||||
```bash
|
||||
cd ~/Documents/projects/web_projects/gophergate && go build ./...
|
||||
```
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
git add internal/router/classifier.go
|
||||
git commit -m "feat: add classifier routing strategy with LLM complexity rating"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 4: Wire router into the server
|
||||
|
||||
**Objective:** Add the Router to the Server struct, initialize it, and inject it into `handleChatCompletions`.
|
||||
|
||||
**Files:**
|
||||
- Modify: `internal/server/server.go`
|
||||
|
||||
**Step 1: Add router field to Server struct**
|
||||
|
||||
In the `Server` struct (around line 23), add after the `registryMu` field:
|
||||
|
||||
```go
|
||||
router *router.Router
|
||||
```
|
||||
|
||||
**Step 2: Add import**
|
||||
|
||||
Add to the imports block:
|
||||
|
||||
```go
|
||||
"gophergate/internal/router"
|
||||
```
|
||||
|
||||
**Step 3: Initialize router in NewServer**
|
||||
|
||||
After `s.setupRoutes()` (line 66), add:
|
||||
|
||||
```go
|
||||
// Initialize model group router
|
||||
s.refreshRouter()
|
||||
```
|
||||
|
||||
**Step 4: Add refreshRouter method**
|
||||
|
||||
Add a new method on Server:
|
||||
|
||||
```go
|
||||
func (s *Server) refreshRouter() {
|
||||
var groups []db.ModelGroup
|
||||
if err := s.database.Select(&groups, "SELECT * FROM model_groups"); err != nil {
|
||||
fmt.Printf("Warning: Failed to load model groups: %v\n", err)
|
||||
groups = nil
|
||||
}
|
||||
|
||||
// Build classifier function using the OpenAI provider
|
||||
var classifyFn router.ClassifierFunc
|
||||
if openaiProvider, ok := s.providers["openai"]; ok {
|
||||
classifyFn = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) {
|
||||
req := &models.UnifiedRequest{
|
||||
Model: selectorModel,
|
||||
Messages: []models.UnifiedMessage{
|
||||
{Role: "system", Content: []models.ContentPart{{Type: "text", Text: systemPrompt}}},
|
||||
{Role: "user", Content: []models.ContentPart{{Type: "text", Text: userMessage}}},
|
||||
},
|
||||
MaxTokens: uint32Ptr(5),
|
||||
Stream: false,
|
||||
}
|
||||
resp, err := openaiProvider.ChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return "", fmt.Errorf("no choices in classifier response")
|
||||
}
|
||||
return resp.Choices[0].Message.Content, nil
|
||||
}
|
||||
}
|
||||
|
||||
if s.router == nil {
|
||||
s.router = router.New(groups, classifyFn)
|
||||
} else {
|
||||
s.router.Reload(groups)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Step 5: Add uint32Ptr helper (if not already in the codebase)**
|
||||
|
||||
At the bottom of server.go, add:
|
||||
|
||||
```go
|
||||
func uint32Ptr(v uint32) *uint32 { return &v }
|
||||
```
|
||||
|
||||
**Step 6: Inject router into handleChatCompletions**
|
||||
|
||||
In `handleChatCompletions`, after the model prefix stripping block (after line 475) and before building the UnifiedRequest (line 478), add:
|
||||
|
||||
```go
|
||||
// Check if model is a group and route to a concrete model
|
||||
if s.router != nil && s.router.IsGroup(modelID) {
|
||||
userMessage := extractUserMessage(req.Messages)
|
||||
decision, err := s.router.Route(c.Request.Context(), modelID, userMessage)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
|
||||
return
|
||||
}
|
||||
modelID = decision.SelectedModel
|
||||
log.Printf("[ROUTER] %s → %s (%s: %s)", req.Model, modelID, decision.Strategy, decision.Reason)
|
||||
}
|
||||
```
|
||||
|
||||
**Step 7: Add extractUserMessage helper**
|
||||
|
||||
```go
|
||||
func extractUserMessage(messages []models.ChatCompletionMessage) string {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == "user" {
|
||||
if s, ok := messages[i].Content.(string); ok {
|
||||
return s
|
||||
}
|
||||
// It might be a content array — grab text from first part
|
||||
if parts, ok := messages[i].Content.([]interface{}); ok && len(parts) > 0 {
|
||||
if part, ok := parts[0].(map[string]interface{}); ok {
|
||||
if text, ok := part["text"].(string); ok {
|
||||
return text
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
```
|
||||
|
||||
**Step 8: Add router refresh to RefreshProviders**
|
||||
|
||||
At the end of `RefreshProviders()` (before `return nil` at line 171), add:
|
||||
|
||||
```go
|
||||
s.refreshRouter()
|
||||
```
|
||||
|
||||
**Step 9: Build**
|
||||
|
||||
```bash
|
||||
cd ~/Documents/projects/web_projects/gophergate && go build ./...
|
||||
```
|
||||
|
||||
Expect compilation errors — need to check the `ChatCompletionMessage` type. The handler uses `models.ChatCompletionRequest` which has `Messages []ChatCompletionMessage`. Let me verify the type. If it's `[]models.ChatCompletionMessage` with `Content` as a string field, the helper is simpler. Fix as needed.
|
||||
|
||||
**Step 10: Commit**
|
||||
|
||||
```bash
|
||||
git add internal/server/server.go
|
||||
git commit -m "feat: wire model group router into chat completions handler"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 5: Add admin API endpoints for model groups
|
||||
|
||||
**Objective:** CRUD endpoints at `/api/model-groups` for dashboard management.
|
||||
|
||||
**Files:**
|
||||
- Create: `internal/server/model_groups_admin.go`
|
||||
|
||||
**Step 1: Create `internal/server/model_groups_admin.go`**
|
||||
|
||||
```go
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"gophergate/internal/db"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (s *Server) handleGetModelGroups(c *gin.Context) {
|
||||
var groups []db.ModelGroup
|
||||
if err := s.database.Select(&groups, "SELECT * FROM model_groups ORDER BY id"); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if groups == nil {
|
||||
groups = []db.ModelGroup{}
|
||||
}
|
||||
c.JSON(http.StatusOK, groups)
|
||||
}
|
||||
|
||||
func (s *Server) handleCreateModelGroup(c *gin.Context) {
|
||||
var group db.ModelGroup
|
||||
if err := c.ShouldBindJSON(&group); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
_, err := s.database.Exec(`
|
||||
INSERT INTO model_groups (id, strategy, selector_model, targets, complexity_threshold, heuristic_rules)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
group.ID, group.Strategy, group.SelectorModel, group.Targets,
|
||||
group.ComplexityThreshold, group.HeuristicRules)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
s.refreshRouter()
|
||||
c.JSON(http.StatusCreated, group)
|
||||
}
|
||||
|
||||
func (s *Server) handleUpdateModelGroup(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var group db.ModelGroup
|
||||
if err := c.ShouldBindJSON(&group); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
_, err := s.database.Exec(`
|
||||
UPDATE model_groups SET strategy=?, selector_model=?, targets=?, complexity_threshold=?, heuristic_rules=?, updated_at=CURRENT_TIMESTAMP
|
||||
WHERE id=?`,
|
||||
group.Strategy, group.SelectorModel, group.Targets,
|
||||
group.ComplexityThreshold, group.HeuristicRules, id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
s.refreshRouter()
|
||||
c.JSON(http.StatusOK, group)
|
||||
}
|
||||
|
||||
func (s *Server) handleDeleteModelGroup(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
_, err := s.database.Exec("DELETE FROM model_groups WHERE id=?", id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
s.refreshRouter()
|
||||
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
|
||||
}
|
||||
```
|
||||
|
||||
**Step 2: Register routes in setupRoutes()**
|
||||
|
||||
In `setupRoutes()`, add under the admin group (after the models endpoints around line 229):
|
||||
|
||||
```go
|
||||
admin.GET("/model-groups", s.handleGetModelGroups)
|
||||
admin.POST("/model-groups", s.handleCreateModelGroup)
|
||||
admin.PUT("/model-groups/:id", s.handleUpdateModelGroup)
|
||||
admin.DELETE("/model-groups/:id", s.handleDeleteModelGroup)
|
||||
```
|
||||
|
||||
**Step 3: Build**
|
||||
|
||||
```bash
|
||||
cd ~/Documents/projects/web_projects/gophergate && go build ./...
|
||||
```
|
||||
|
||||
**Step 4: Commit**
|
||||
|
||||
```bash
|
||||
git add internal/server/model_groups_admin.go internal/server/server.go
|
||||
git commit -m "feat: add model groups CRUD admin API endpoints"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 6: Add dashboard UI — sidebar entry and page module
|
||||
|
||||
**Objective:** Add a "Model Groups" tab to the dashboard sidebar and a page module for CRUD management.
|
||||
|
||||
**Files:**
|
||||
- Modify: `static/index.html`
|
||||
- Create: `static/js/pages/model_groups.js`
|
||||
|
||||
**Step 1: Add sidebar menu item in index.html**
|
||||
|
||||
In the MANAGEMENT section (after line 91, before `</ul>`), add:
|
||||
|
||||
```html
|
||||
<li class="menu-item" data-page="model-groups">
|
||||
<i class="fas fa-code-branch"></i>
|
||||
<span>Model Groups</span>
|
||||
</li>
|
||||
```
|
||||
|
||||
**Step 2: Add script tag in index.html**
|
||||
|
||||
After the users.js script (line 179), add:
|
||||
|
||||
```html
|
||||
<script src="/js/pages/model_groups.js?v=8"></script>
|
||||
```
|
||||
|
||||
**Step 3: Create `static/js/pages/model_groups.js`**
|
||||
|
||||
```javascript
|
||||
// Model Groups Management Page
|
||||
|
||||
class ModelGroupsPage {
|
||||
constructor() {
|
||||
this.container = document.getElementById('page-content');
|
||||
}
|
||||
|
||||
async render() {
|
||||
this.container.innerHTML = `
|
||||
<div class="page-header">
|
||||
<h3>Model Groups</h3>
|
||||
<p class="text-muted">Define auto-routing groups that pick the best model for each request.</p>
|
||||
<button class="btn btn-primary" onclick="modelGroupsPage.showCreateForm()">
|
||||
<i class="fas fa-plus"></i> Add Group
|
||||
</button>
|
||||
</div>
|
||||
<div id="model-groups-list" class="table-container"></div>
|
||||
<div id="model-group-form" class="form-container" style="display:none;"></div>
|
||||
`;
|
||||
await this.loadGroups();
|
||||
}
|
||||
|
||||
async loadGroups() {
|
||||
try {
|
||||
const groups = await api.get('/api/model-groups');
|
||||
const list = document.getElementById('model-groups-list');
|
||||
if (!groups || groups.length === 0) {
|
||||
list.innerHTML = '<div class="empty-state">No model groups defined. Create one to enable auto-routing.</div>';
|
||||
return;
|
||||
}
|
||||
|
||||
let targets;
|
||||
try { targets = JSON.parse(g.targets); } catch { targets = []; }
|
||||
const heuristicRules = g.heuristic_rules ? JSON.parse(g.heuristic_rules) : null;
|
||||
|
||||
let html = '<table class="data-table"><thead><tr>';
|
||||
html += '<th>Group ID</th><th>Strategy</th><th>Targets</th><th>Actions</th>';
|
||||
html += '</tr></thead><tbody>';
|
||||
|
||||
groups.forEach(g => {
|
||||
html += `<tr>
|
||||
<td><code>${this.esc(g.id)}</code></td>
|
||||
<td><span class="badge">${this.esc(g.strategy)}</span></td>
|
||||
<td><code>${this.esc(g.targets)}</code></td>
|
||||
<td>
|
||||
<button class="btn btn-sm" onclick="modelGroupsPage.showEditForm('${this.esc(g.id)}')">Edit</button>
|
||||
<button class="btn btn-sm btn-danger" onclick="modelGroupsPage.deleteGroup('${this.esc(g.id)}')">Delete</button>
|
||||
</td>
|
||||
</tr>`;
|
||||
});
|
||||
|
||||
html += '</tbody></table>';
|
||||
list.innerHTML = html;
|
||||
} catch (err) {
|
||||
document.getElementById('model-groups-list').innerHTML =
|
||||
`<div class="error-message">Failed to load model groups: ${this.esc(err.message)}</div>`;
|
||||
}
|
||||
}
|
||||
|
||||
showCreateForm() {
|
||||
this.renderForm(null);
|
||||
}
|
||||
|
||||
async showEditForm(id) {
|
||||
const groups = await api.get('/api/model-groups');
|
||||
const group = groups.find(g => g.id === id);
|
||||
if (group) this.renderForm(group);
|
||||
}
|
||||
|
||||
renderForm(group) {
|
||||
const isEdit = !!group;
|
||||
const form = document.getElementById('model-group-form');
|
||||
form.style.display = 'block';
|
||||
form.innerHTML = `
|
||||
<h4>${isEdit ? 'Edit' : 'Create'} Model Group</h4>
|
||||
<form onsubmit="modelGroupsPage.saveGroup(event, ${isEdit})">
|
||||
<div class="form-control">
|
||||
<label>Group ID</label>
|
||||
<input type="text" id="mg-id" value="${this.esc(group?.id || '')}" ${isEdit ? 'readonly' : 'required'}
|
||||
placeholder="e.g. deepseek-auto">
|
||||
<small>Clients use this as the model name.</small>
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label>Strategy</label>
|
||||
<select id="mg-strategy">
|
||||
<option value="heuristic" ${group?.strategy === 'heuristic' ? 'selected' : ''}>Heuristic (rules-based)</option>
|
||||
<option value="classifier" ${group?.strategy === 'classifier' ? 'selected' : ''}>Classifier (LLM judge)</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label>Targets (JSON array)</label>
|
||||
<input type="text" id="mg-targets" value='${this.esc(group?.targets || '["cheap-model","smart-model"]')}' required>
|
||||
<small>First target = cheapest/fastest. Last target = smartest/most expensive.</small>
|
||||
</div>
|
||||
<div class="form-control" id="mg-selector-row" ${group?.strategy === 'classifier' ? '' : 'style="display:none"'}>
|
||||
<label>Selector Model</label>
|
||||
<input type="text" id="mg-selector-model" value="${this.esc(group?.selector_model || 'gpt-4o-mini')}"
|
||||
placeholder="Model used to judge task complexity">
|
||||
</div>
|
||||
<div class="form-control" id="mg-threshold-row" ${group?.strategy === 'classifier' ? '' : 'style="display:none"'}>
|
||||
<label>Complexity Threshold</label>
|
||||
<input type="number" id="mg-threshold" value="${group?.complexity_threshold || ''}" min="1"
|
||||
placeholder="Tasks rated >= this go to the smart model">
|
||||
</div>
|
||||
<div class="form-control" id="mg-rules-row" ${group?.strategy === 'heuristic' ? '' : 'style="display:none"'}>
|
||||
<label>Heuristic Rules (JSON array)</label>
|
||||
<textarea id="mg-rules" rows="4" placeholder='[{"pattern":"step by step","target":1}]'>${group?.heuristic_rules || ''}</textarea>
|
||||
<small>Pattern to match in user messages. target = index into targets array.</small>
|
||||
</div>
|
||||
<div class="form-actions">
|
||||
<button type="submit" class="btn btn-primary">Save</button>
|
||||
<button type="button" class="btn" onclick="document.getElementById('model-group-form').style.display='none'">Cancel</button>
|
||||
</div>
|
||||
</form>
|
||||
`;
|
||||
|
||||
// Toggle strategy-specific fields
|
||||
document.getElementById('mg-strategy').onchange = function() {
|
||||
const isClassifier = this.value === 'classifier';
|
||||
document.getElementById('mg-selector-row').style.display = isClassifier ? '' : 'none';
|
||||
document.getElementById('mg-threshold-row').style.display = isClassifier ? '' : 'none';
|
||||
document.getElementById('mg-rules-row').style.display = isClassifier ? 'none' : '';
|
||||
};
|
||||
}
|
||||
|
||||
async saveGroup(event, isEdit) {
|
||||
event.preventDefault();
|
||||
const id = document.getElementById('mg-id').value.trim();
|
||||
const strategy = document.getElementById('mg-strategy').value;
|
||||
const targets = document.getElementById('mg-targets').value;
|
||||
const selectorModel = document.getElementById('mg-selector-model').value.trim() || null;
|
||||
const thresholdVal = document.getElementById('mg-threshold').value;
|
||||
const rules = document.getElementById('mg-rules').value.trim() || null;
|
||||
|
||||
// Validate JSON
|
||||
try { JSON.parse(targets); } catch { alert('Targets must be valid JSON array'); return; }
|
||||
if (rules) { try { JSON.parse(rules); } catch { alert('Heuristic rules must be valid JSON'); return; } }
|
||||
|
||||
const body = { id, strategy, targets, selector_model: selectorModel, heuristic_rules: rules };
|
||||
if (thresholdVal) body.complexity_threshold = parseInt(thresholdVal);
|
||||
|
||||
try {
|
||||
if (isEdit) {
|
||||
await api.put(`/api/model-groups/${encodeURIComponent(id)}`, body);
|
||||
} else {
|
||||
await api.post('/api/model-groups', body);
|
||||
}
|
||||
document.getElementById('model-group-form').style.display = 'none';
|
||||
await this.loadGroups();
|
||||
} catch (err) {
|
||||
alert('Failed to save: ' + err.message);
|
||||
}
|
||||
}
|
||||
|
||||
async deleteGroup(id) {
|
||||
if (!confirm(`Delete model group "${id}"?`)) return;
|
||||
try {
|
||||
await api.delete(`/api/model-groups/${encodeURIComponent(id)}`);
|
||||
await this.loadGroups();
|
||||
} catch (err) {
|
||||
alert('Failed to delete: ' + err.message);
|
||||
}
|
||||
}
|
||||
|
||||
esc(str) {
|
||||
if (!str) return '';
|
||||
return String(str).replace(/&/g,'&').replace(/</g,'<').replace(/>/g,'>').replace(/"/g,'"');
|
||||
}
|
||||
}
|
||||
|
||||
const modelGroupsPage = new ModelGroupsPage();
|
||||
```
|
||||
|
||||
**Step 4: Register page in dashboard.js**
|
||||
|
||||
In `static/js/dashboard.js`, find the page loading logic. The `loadPage` method dynamically imports page modules based on `this.currentPage`. The naming convention uses hyphens in `data-page` attributes (e.g., `data-page="model-groups"`). Check how the existing pages are loaded and ensure "model-groups" maps to the new module.
|
||||
|
||||
Looking at the existing pattern, pages are loaded via script tags in index.html and their constructors handle rendering when the page is navigated to. The dashboard.js `loadPage` method calls page-specific init. Let me check if there's a page registry pattern.
|
||||
|
||||
Actually, based on the index.html, pages are loaded as separate script files and the dashboard dispatches to them. The pattern seems to be: each page script defines a class or object, and the dashboard calls a `render()` or `init()` method on it when that page is selected. Let me add the dispatch logic.
|
||||
|
||||
In `dashboard.js`, find the `loadPage` method and ensure it handles "model-groups":
|
||||
|
||||
```javascript
|
||||
// In the loadPage switch/if-else, add:
|
||||
else if (page === 'model-groups') {
|
||||
if (typeof modelGroupsPage !== 'undefined') {
|
||||
modelGroupsPage.render();
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add static/index.html static/js/pages/model_groups.js static/js/dashboard.js
|
||||
git commit -m "feat: add model groups dashboard page with CRUD UI"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 7: Integration test — build, run, verify
|
||||
|
||||
**Objective:** Ensure everything compiles and the routing works end-to-end.
|
||||
|
||||
**Step 1: Full build**
|
||||
|
||||
```bash
|
||||
cd ~/Documents/projects/web_projects/gophergate && go build -o gophergate ./cmd/gophergate
|
||||
```
|
||||
|
||||
**Step 2: Start server and test**
|
||||
|
||||
```bash
|
||||
# In one terminal:
|
||||
./gophergate
|
||||
|
||||
# In another terminal, test that default groups loaded:
|
||||
curl -s -u admin:admin123 http://localhost:8080/api/model-groups | jq
|
||||
|
||||
# Expected: array with deepseek-auto and openai-auto groups
|
||||
```
|
||||
|
||||
**Step 3: Test routing via API**
|
||||
|
||||
```bash
|
||||
# Send a request using a model group
|
||||
curl -s http://localhost:8080/v1/chat/completions \
|
||||
-H "Authorization: Bearer YOUR_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "openai-auto",
|
||||
"messages": [{"role": "user", "content": "What is 2+2?"}]
|
||||
}' | jq
|
||||
|
||||
# Check server logs for [ROUTER] line showing the decision
|
||||
```
|
||||
|
||||
**Step 4: Commit any fixes**
|
||||
|
||||
If any issues found during testing, fix and commit.
|
||||
|
||||
---
|
||||
|
||||
## Architecture Notes
|
||||
|
||||
### Why this approach
|
||||
|
||||
- **No Provider interface changes** — the router is a pre-processing step in the handler, transparent to providers
|
||||
- **Groups stored in DB** — manageable from the dashboard, no config file sprawl
|
||||
- **Classifier is optional** — heuristic mode works with zero added latency or cost
|
||||
- **Fallback chain** — classifier failure falls back to heuristic; missing router falls back to direct passthrough
|
||||
|
||||
### Edge cases handled
|
||||
|
||||
- No groups defined → router never activates, all models pass through as before
|
||||
- Unknown group ID → returns error to client
|
||||
- Empty targets → returns error
|
||||
- Classifier call fails → falls back to heuristic
|
||||
- Classifier returns garbage → clamped to valid range
|
||||
- OpenAI provider disabled → classifier groups fall back to heuristic mode
|
||||
|
||||
### What's NOT in this plan (future work)
|
||||
|
||||
- Streaming classifier support (the ~300ms classifier call happens before streaming begins — acceptable for now)
|
||||
- responses endpoint routing (`handleResponses` could also use the router but needs a different message extraction)
|
||||
- Per-client group overrides
|
||||
- A/B testing / multi-armed bandit routing
|
||||
- Caching classifier decisions for identical messages
|
||||
@@ -1,566 +0,0 @@
|
||||
# LLM Proxy - Comprehensive Fix Plan
|
||||
|
||||
## Project Overview
|
||||
Rust-based unified LLM proxy gateway (Axum + SQLite + Tokio) exposing an OpenAI-compatible API that routes to OpenAI, Gemini, DeepSeek, Grok, and Ollama. Includes dashboard with WebSocket monitoring. ~4,354 lines of Rust across 25 source files.
|
||||
|
||||
## Design Decisions
|
||||
- **Session management**: In-memory HashMap with expiry (no new dependencies)
|
||||
- **Provider deduplication**: Shared helper functions approach
|
||||
- **Dashboard refactor**: Full split into sub-modules (auth, usage, clients, providers, system, websocket)
|
||||
|
||||
---
|
||||
|
||||
## Phase 1: Fix Compilation & Test Issues
|
||||
|
||||
### 1.1 Fix config_path type mismatch
|
||||
**Files**: `src/config/mod.rs:98`, `src/lib.rs:99`
|
||||
|
||||
The `AppConfig.config_path` field is `PathBuf` but `test_utils::create_test_state` sets it to `None`.
|
||||
|
||||
**Fix**: Change `src/config/mod.rs:98` from `pub config_path: PathBuf` to `pub config_path: Option<PathBuf>`. Update `src/config/mod.rs:177` to wrap in `Some()`:
|
||||
```rust
|
||||
config_path: Some(config_path),
|
||||
```
|
||||
|
||||
### 1.2 Fix streaming test compilation errors
|
||||
**File**: `src/utils/streaming.rs:195-201`
|
||||
|
||||
Three issues in the test:
|
||||
1. Line 195-196: `ProviderStreamChunk` missing `reasoning_content` field
|
||||
2. Line 201: `RequestLogger::new()` called with 1 arg but needs 2 (pool + dashboard_tx)
|
||||
|
||||
**Fix**:
|
||||
```rust
|
||||
// Line 195-196: Add reasoning_content field
|
||||
Ok(ProviderStreamChunk { content: "Hello".to_string(), reasoning_content: None, finish_reason: None, model: "test".to_string() }),
|
||||
Ok(ProviderStreamChunk { content: " World".to_string(), reasoning_content: None, finish_reason: Some("stop".to_string()), model: "test".to_string() }),
|
||||
|
||||
// Line 200-201: Add dashboard_tx argument
|
||||
let (dashboard_tx, _) = tokio::sync::broadcast::channel(16);
|
||||
let logger = Arc::new(RequestLogger::new(pool.clone(), dashboard_tx));
|
||||
```
|
||||
|
||||
### 1.3 Fix multimodal test assertion
|
||||
**File**: `src/multimodal/mod.rs:283`
|
||||
|
||||
Line 283 asserts `!model_supports_multimodal("gemini-pro")` but the function at line 187-189 returns `true` for ALL models starting with "gemini".
|
||||
|
||||
**Fix**: Either:
|
||||
- (a) Update the function to exclude non-vision Gemini models (more correct):
|
||||
```rust
|
||||
if model.starts_with("gemini") {
|
||||
// gemini-pro (text-only) doesn't support multimodal, but gemini-pro-vision and gemini-1.5+ do
|
||||
return model.contains("vision") || model.contains("1.5") || model.contains("2.0") || model.contains("flash") || model.contains("ultra");
|
||||
}
|
||||
```
|
||||
- (b) Or remove the failing assertion if all Gemini models actually support vision now.
|
||||
|
||||
**Recommendation**: Option (b) - remove line 283, since modern Gemini models all support multimodal. Replace with a non-multimodal model test like `assert!(!ImageConverter::model_supports_multimodal("claude-3-opus"))`.
|
||||
|
||||
### 1.4 Clean up empty/stale test files
|
||||
**Files**: `tests/streaming_test.rs`, `tests/integration_tests.rs.bak`
|
||||
|
||||
**Fix**:
|
||||
- Delete `tests/streaming_test.rs` (empty file)
|
||||
- Delete `tests/integration_tests.rs.bak` (stale backup referencing old APIs)
|
||||
|
||||
---
|
||||
|
||||
## Phase 2: Fix Critical Bugs
|
||||
|
||||
### 2.1 Replace `futures::executor::block_on` with async
|
||||
**Files**:
|
||||
- `src/providers/openai.rs:63,151`
|
||||
- `src/providers/deepseek.rs:65`
|
||||
- `src/providers/grok.rs:63,151`
|
||||
- `src/providers/ollama.rs:58`
|
||||
|
||||
`block_on()` inside a Tokio async context will deadlock. The issue is that `image_input.to_base64()` is async but it's called inside a sync `.map()` closure within `serde_json::json!{}`.
|
||||
|
||||
**Fix**: Pre-process messages before building the JSON body. Create a helper function in a new file `src/providers/helpers.rs`:
|
||||
|
||||
```rust
|
||||
use crate::models::{ChatMessage, ContentPart};
|
||||
use crate::errors::AppError;
|
||||
|
||||
/// Convert messages to OpenAI-compatible JSON, resolving images asynchronously
|
||||
pub async fn messages_to_openai_json(messages: &[ChatMessage]) -> Result<Vec<serde_json::Value>, AppError> {
|
||||
let mut result = Vec::new();
|
||||
for m in messages {
|
||||
let mut parts = Vec::new();
|
||||
for p in &m.content {
|
||||
match p {
|
||||
ContentPart::Text { text } => {
|
||||
parts.push(serde_json::json!({ "type": "text", "text": text }));
|
||||
}
|
||||
ContentPart::Image(image_input) => {
|
||||
let (base64_data, mime_type) = image_input.to_base64().await
|
||||
.map_err(|e| AppError::MultimodalError(e.to_string()))?;
|
||||
parts.push(serde_json::json!({
|
||||
"type": "image_url",
|
||||
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
result.push(serde_json::json!({
|
||||
"role": m.role,
|
||||
"content": parts
|
||||
}));
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
```
|
||||
|
||||
Then update each provider's `chat_completion` and `chat_completion_stream` to call:
|
||||
```rust
|
||||
let messages_json = crate::providers::helpers::messages_to_openai_json(&request.messages).await?;
|
||||
let mut body = serde_json::json!({
|
||||
"model": request.model,
|
||||
"messages": messages_json,
|
||||
"stream": false,
|
||||
});
|
||||
```
|
||||
|
||||
Remove all `futures::executor::block_on` calls.
|
||||
|
||||
### 2.2 Fix broken update_client query builder
|
||||
**File**: `src/client/mod.rs:129-163`
|
||||
|
||||
The `updates` vec collects column name strings like `"name = "` but they are **never used** in the actual query. The `query_builder` receives `.push_bind()` values without corresponding column names, producing malformed SQL.
|
||||
|
||||
**Fix**: Replace the broken pattern with proper QueryBuilder usage:
|
||||
```rust
|
||||
let mut query_builder = sqlx::QueryBuilder::new("UPDATE clients SET ");
|
||||
let mut has_updates = false;
|
||||
|
||||
if let Some(name) = &request.name {
|
||||
if has_updates { query_builder.push(", "); }
|
||||
query_builder.push("name = ");
|
||||
query_builder.push_bind(name);
|
||||
has_updates = true;
|
||||
}
|
||||
|
||||
if let Some(description) = &request.description {
|
||||
if has_updates { query_builder.push(", "); }
|
||||
query_builder.push("description = ");
|
||||
query_builder.push_bind(description);
|
||||
has_updates = true;
|
||||
}
|
||||
|
||||
if let Some(is_active) = request.is_active {
|
||||
if has_updates { query_builder.push(", "); }
|
||||
query_builder.push("is_active = ");
|
||||
query_builder.push_bind(is_active);
|
||||
has_updates = true;
|
||||
}
|
||||
|
||||
if let Some(rate_limit) = request.rate_limit_per_minute {
|
||||
if has_updates { query_builder.push(", "); }
|
||||
query_builder.push("rate_limit_per_minute = ");
|
||||
query_builder.push_bind(rate_limit);
|
||||
has_updates = true;
|
||||
}
|
||||
```
|
||||
|
||||
Remove the `updates` vec entirely - it serves no purpose.
|
||||
|
||||
---
|
||||
|
||||
## Phase 3: Security Hardening
|
||||
|
||||
### 3.1 Implement in-memory session management
|
||||
**New file**: `src/dashboard/sessions.rs`
|
||||
|
||||
```rust
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use chrono::{DateTime, Utc, Duration};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Session {
|
||||
pub username: String,
|
||||
pub role: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SessionManager {
|
||||
sessions: Arc<RwLock<HashMap<String, Session>>>,
|
||||
ttl_hours: i64,
|
||||
}
|
||||
|
||||
impl SessionManager {
|
||||
pub fn new(ttl_hours: i64) -> Self {
|
||||
Self {
|
||||
sessions: Arc::new(RwLock::new(HashMap::new())),
|
||||
ttl_hours,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_session(&self, username: String, role: String) -> String {
|
||||
let token = format!("session-{}", uuid::Uuid::new_v4());
|
||||
let now = Utc::now();
|
||||
let session = Session {
|
||||
username,
|
||||
role,
|
||||
created_at: now,
|
||||
expires_at: now + Duration::hours(self.ttl_hours),
|
||||
};
|
||||
self.sessions.write().await.insert(token.clone(), session);
|
||||
token
|
||||
}
|
||||
|
||||
pub async fn validate_session(&self, token: &str) -> Option<Session> {
|
||||
let sessions = self.sessions.read().await;
|
||||
sessions.get(token).and_then(|s| {
|
||||
if s.expires_at > Utc::now() {
|
||||
Some(s.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn revoke_session(&self, token: &str) {
|
||||
self.sessions.write().await.remove(token);
|
||||
}
|
||||
|
||||
pub async fn cleanup_expired(&self) {
|
||||
let now = Utc::now();
|
||||
self.sessions.write().await.retain(|_, s| s.expires_at > now);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Add `SessionManager` to `DashboardState`. Add it to `AppState` or initialize it in dashboard `router()`.
|
||||
|
||||
### 3.2 Fix handle_auth_status to validate sessions
|
||||
**File**: `src/dashboard/mod.rs:191-199`
|
||||
|
||||
Extract the session token from the `Authorization` header and validate it:
|
||||
|
||||
```rust
|
||||
async fn handle_auth_status(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let token = headers.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
if let Some(token) = token {
|
||||
if let Some(session) = state.session_manager.validate_session(token).await {
|
||||
return Json(ApiResponse::success(serde_json::json!({
|
||||
"authenticated": true,
|
||||
"user": {
|
||||
"username": session.username,
|
||||
"name": "Administrator",
|
||||
"role": session.role
|
||||
}
|
||||
})));
|
||||
}
|
||||
}
|
||||
|
||||
Json(ApiResponse::error("Not authenticated".to_string()))
|
||||
}
|
||||
```
|
||||
|
||||
### 3.3 Add middleware to protect dashboard API routes
|
||||
Create an Axum middleware that validates session tokens on all `/api/` routes except `/api/auth/login`.
|
||||
|
||||
### 3.4 Force password change for default admin
|
||||
**File**: `src/database/mod.rs:138-148`
|
||||
|
||||
Add a `must_change_password` column to the `users` table. Set it to `true` for the default admin. Return `must_change_password: true` in the login response so the frontend can prompt.
|
||||
|
||||
### 3.5 Mask auth tokens in settings API response
|
||||
**File**: `src/dashboard/mod.rs:1048`
|
||||
|
||||
Use the existing `mask_token` function (currently `#[allow(dead_code)]` at line 1066):
|
||||
```rust
|
||||
"auth_tokens": state.app_state.auth_tokens.iter().map(|t| mask_token(t)).collect::<Vec<_>>(),
|
||||
```
|
||||
Remove the `#[allow(dead_code)]` attribute.
|
||||
|
||||
### 3.6 Move Gemini API key from URL to header
|
||||
**File**: `src/providers/gemini.rs:172-176,301-305`
|
||||
|
||||
Change from:
|
||||
```rust
|
||||
let url = format!("{}/models/{}:generateContent?key={}", self.config.base_url, request.model, self.api_key);
|
||||
```
|
||||
To:
|
||||
```rust
|
||||
let url = format!("{}/models/{}:generateContent", self.config.base_url, request.model);
|
||||
// ...
|
||||
let response = self.client.post(&url)
|
||||
.header("x-goog-api-key", &self.api_key)
|
||||
.json(&gemini_request)
|
||||
.send()
|
||||
.await
|
||||
```
|
||||
|
||||
Same for the streaming URL at line 301-305.
|
||||
|
||||
---
|
||||
|
||||
## Phase 4: Implement Stubs & Missing Features
|
||||
|
||||
### 4.1 Implement handle_test_provider
|
||||
**File**: `src/dashboard/mod.rs:840-849`
|
||||
|
||||
Actually test the provider by sending a minimal chat completion:
|
||||
```rust
|
||||
async fn handle_test_provider(
|
||||
State(state): State<DashboardState>,
|
||||
axum::extract::Path(name): axum::extract::Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
if let Some(provider) = state.app_state.provider_manager.get_provider(&name).await {
|
||||
let test_request = UnifiedRequest {
|
||||
model: "test".to_string(), // Provider will use default
|
||||
messages: vec![ChatMessage { role: "user".to_string(), content: vec![ContentPart::Text { text: "Hi".to_string() }] }],
|
||||
temperature: None,
|
||||
max_tokens: Some(5),
|
||||
stream: false,
|
||||
};
|
||||
|
||||
match provider.chat_completion(test_request).await {
|
||||
Ok(_) => {
|
||||
let latency = start.elapsed().as_millis();
|
||||
Json(ApiResponse::success(json!({ "success": true, "latency": latency, "message": "Connection test successful" })))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::error(format!("Provider test failed: {}", e)))
|
||||
}
|
||||
} else {
|
||||
Json(ApiResponse::error(format!("Provider '{}' not found or not enabled", name)))
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 4.2 Implement real system health metrics
|
||||
**File**: `src/dashboard/mod.rs:969-978`
|
||||
|
||||
Read from `/proc/self/status` for memory, calculate from pool stats:
|
||||
```rust
|
||||
// Memory: read RSS from /proc/self/status
|
||||
let memory_kb = std::fs::read_to_string("/proc/self/status")
|
||||
.ok()
|
||||
.and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string()))
|
||||
.and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::<f64>().ok()))
|
||||
.unwrap_or(0.0);
|
||||
let memory_mb = memory_kb / 1024.0;
|
||||
```
|
||||
|
||||
### 4.3 Implement handle_get_client
|
||||
**File**: `src/dashboard/mod.rs:647-651`
|
||||
|
||||
Query client by ID from the `clients` table and return full details.
|
||||
|
||||
### 4.4 Implement handle_client_usage
|
||||
**File**: `src/dashboard/mod.rs:676-680`
|
||||
|
||||
Query `llm_requests` aggregated by the given client_id.
|
||||
|
||||
### 4.5 Implement handle_get_provider
|
||||
**File**: `src/dashboard/mod.rs:776-780`
|
||||
|
||||
Return individual provider details (reuse logic from `handle_get_providers`).
|
||||
|
||||
### 4.6 Implement handle_system_backup
|
||||
**File**: `src/dashboard/mod.rs:1033-1039`
|
||||
|
||||
Use SQLite's backup API via raw SQL:
|
||||
```rust
|
||||
let backup_path = format!("data/backup-{}.db", chrono::Utc::now().timestamp());
|
||||
sqlx::query(&format!("VACUUM INTO '{}'", backup_path))
|
||||
.execute(pool)
|
||||
.await?;
|
||||
```
|
||||
|
||||
### 4.7 Address TODO items
|
||||
- `src/server/mod.rs:211` - Check if request messages contain `ContentPart::Image` to set `has_images: true`
|
||||
- `src/logging/mod.rs:80-81` - Add optional request/response body storage (can remain None for now, just note in code)
|
||||
|
||||
---
|
||||
|
||||
## Phase 5: Code Quality
|
||||
|
||||
### 5.1 Extract shared provider logic
|
||||
**New file**: `src/providers/helpers.rs`
|
||||
|
||||
Create shared helper functions:
|
||||
- `messages_to_openai_json()` (from Phase 2)
|
||||
- `build_openai_compatible_body()` - builds the full JSON body with model, messages, stream, temperature, max_tokens
|
||||
- `parse_openai_response()` - extracts content, reasoning_content, usage from response JSON
|
||||
- `create_openai_stream()` - creates SSE stream with standard parsing
|
||||
- `calculate_cost_with_registry()` - shared cost calculation logic
|
||||
|
||||
Update `openai.rs`, `deepseek.rs`, `grok.rs`, `ollama.rs` to call these helpers. Each provider file should shrink from ~210 lines to ~50-80 lines.
|
||||
|
||||
Add `pub mod helpers;` to `src/providers/mod.rs`.
|
||||
|
||||
### 5.2 Replace wildcard re-exports
|
||||
**File**: `src/lib.rs:22-30`
|
||||
|
||||
Replace:
|
||||
```rust
|
||||
pub use auth::*;
|
||||
pub use client::*;
|
||||
// etc.
|
||||
```
|
||||
With explicit re-exports:
|
||||
```rust
|
||||
pub use auth::AuthenticatedClient;
|
||||
pub use client::ClientManager;
|
||||
pub use config::AppConfig;
|
||||
// etc.
|
||||
```
|
||||
|
||||
### 5.3 Fix all Clippy warnings (19 total)
|
||||
|
||||
1. `src/auth/mod.rs:19` - `manual_async_fn`: Use `async fn` instead of returning a future manually
|
||||
2. `src/database/mod.rs:12` - `collapsible_if`: Merge nested if statements
|
||||
3. `src/dashboard/mod.rs:139` - `collapsible_if`: Merge nested if
|
||||
4. `src/dashboard/mod.rs:616` - `to_string_in_format_args`: Remove redundant `.to_string()`
|
||||
5. `src/multimodal/mod.rs:211,220` - `collapsible_if` x2
|
||||
6. `src/providers/openai.rs:123`, `gemini.rs:225`, `deepseek.rs:125`, `grok.rs:123`, `ollama.rs:117` - `collapsible_if` x5 in calculate_cost (will be fixed by deduplication)
|
||||
7. `src/providers/mod.rs:80` - `new_without_default`: Add `impl Default for ProviderManager`
|
||||
8. `src/providers/mod.rs:193,200` - `redundant_closure` x2: Use `Arc::clone` directly instead of `|p| Arc::clone(p)`
|
||||
9. `src/rate_limiting/mod.rs:180,333,334` - `collapsible_if` x3
|
||||
10. `src/rate_limiting/mod.rs:336` - `manual_strip`: Use `.strip_prefix()` pattern
|
||||
11. `src/utils/streaming.rs:33` - `too_many_arguments`: Wrap params in a config struct
|
||||
|
||||
### 5.4 Replace unwrap() in production paths
|
||||
|
||||
1. `src/database/mod.rs:140` - `bcrypt::hash("admin", 12).unwrap()` → Use `?` with proper error propagation
|
||||
2. `src/dashboard/mod.rs:116` - `serde_json::to_string(&event).unwrap()` → Use `unwrap_or_default()` or log error
|
||||
3. `src/server/mod.rs:168` - `.json_data(response).unwrap()` → Handle error with fallback
|
||||
4. `src/config/mod.rs:139` - `std::env::current_dir().unwrap()` → Use `?` or provide a sensible default
|
||||
|
||||
### 5.5 Remove unused dependencies
|
||||
**File**: `Cargo.toml`
|
||||
|
||||
Remove or comment out:
|
||||
- `governor = "0.6"` - Custom TokenBucket is used instead
|
||||
- `async-openai` - Raw reqwest is used for all providers
|
||||
- `once_cell = "1.19"` - Redundant with Rust 2024 edition's `std::sync::LazyLock`
|
||||
|
||||
Verify each is actually unused by checking imports with `rg 'use governor' src/` etc. before removing.
|
||||
|
||||
### 5.6 Split dashboard/mod.rs into sub-modules
|
||||
**Current**: 1077-line monolith at `src/dashboard/mod.rs`
|
||||
|
||||
**Target structure**:
|
||||
```
|
||||
src/dashboard/
|
||||
├── mod.rs (~80 lines) - Module declarations, router(), DashboardState, ApiResponse
|
||||
├── sessions.rs (~80 lines) - SessionManager (new from Phase 3)
|
||||
├── auth.rs (~80 lines) - handle_login, handle_auth_status, handle_change_password
|
||||
├── usage.rs (~200 lines) - handle_usage_summary, handle_time_series, handle_clients_usage, handle_providers_usage, handle_detailed_usage, handle_analytics_breakdown
|
||||
├── clients.rs (~100 lines) - handle_get_clients, handle_create_client, handle_get_client, handle_delete_client, handle_client_usage
|
||||
├── providers.rs (~150 lines) - handle_get_providers, handle_get_provider, handle_update_provider, handle_test_provider
|
||||
├── models.rs (~100 lines) - handle_get_models, handle_update_model
|
||||
├── system.rs (~120 lines) - handle_system_health, handle_system_logs, handle_system_backup, handle_get_settings, handle_update_settings
|
||||
└── websocket.rs (~60 lines) - handle_websocket, handle_websocket_connection, handle_websocket_message
|
||||
```
|
||||
|
||||
The `mod.rs` will declare sub-modules and re-export the `router()` function. All handlers use `DashboardState` which stays in `mod.rs`.
|
||||
|
||||
---
|
||||
|
||||
## Phase 6: Infrastructure
|
||||
|
||||
### 6.1 Add rustfmt.toml
|
||||
```toml
|
||||
max_width = 120
|
||||
tab_spaces = 4
|
||||
edition = "2024"
|
||||
```
|
||||
|
||||
### 6.2 Add clippy.toml
|
||||
```toml
|
||||
too-many-arguments-threshold = 10
|
||||
```
|
||||
|
||||
### 6.3 Add GitHub Actions CI workflow
|
||||
**New file**: `.github/workflows/ci.yml`
|
||||
|
||||
```yaml
|
||||
name: CI
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: cargo fmt --check
|
||||
- run: cargo clippy -- -D warnings
|
||||
- run: cargo test
|
||||
- run: cargo build --release
|
||||
```
|
||||
|
||||
### 6.4 Fix test_dashboard.sh
|
||||
**File**: `test_dashboard.sh:33`
|
||||
|
||||
Change `"admin123"` to `"admin"` to match the actual default password.
|
||||
|
||||
### 6.5 Add Dockerfile
|
||||
**New file**: `Dockerfile`
|
||||
|
||||
Multi-stage build for minimal image size:
|
||||
```dockerfile
|
||||
FROM rust:1.85-bookworm AS builder
|
||||
WORKDIR /app
|
||||
COPY Cargo.toml Cargo.lock ./
|
||||
RUN mkdir src && echo "fn main() {}" > src/main.rs && cargo build --release && rm -rf src
|
||||
COPY . .
|
||||
RUN cargo build --release
|
||||
|
||||
FROM debian:bookworm-slim
|
||||
RUN apt-get update && apt-get install -y ca-certificates && rm -rf /var/lib/apt/lists/*
|
||||
COPY --from=builder /app/target/release/llm-proxy /usr/local/bin/
|
||||
COPY --from=builder /app/static /app/static
|
||||
WORKDIR /app
|
||||
EXPOSE 8080
|
||||
CMD ["llm-proxy"]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Verification
|
||||
|
||||
After all phases, run:
|
||||
```bash
|
||||
cargo fmt --check
|
||||
cargo clippy -- -D warnings
|
||||
cargo test
|
||||
cargo build --release
|
||||
```
|
||||
|
||||
All must pass with zero warnings and zero errors.
|
||||
|
||||
---
|
||||
|
||||
## Issue Summary
|
||||
|
||||
| Severity | Count | Phase |
|
||||
|----------|-------|-------|
|
||||
| Critical | 7 | 1-3 |
|
||||
| High | 5 | 2-3 |
|
||||
| Medium | 14 | 4-5 |
|
||||
| Low | 4 | 6 |
|
||||
| **Total** | **30** | |
|
||||
|
||||
Estimated effort: ~4-6 hours of focused implementation.
|
||||
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"gopls": {
|
||||
"choice": "yes",
|
||||
"timestamp": 1775750416837
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
{
|
||||
"version": 1,
|
||||
"files": {
|
||||
"../../../../.npm-packages/lib/node_modules/pi-lens/clients/lsp/index.ts": {
|
||||
"latest": {
|
||||
"commit": "da074f5",
|
||||
"timestamp": "2026-04-26T03:45:14.025Z",
|
||||
"mi": 12.6,
|
||||
"cognitive": 335,
|
||||
"nesting": 6,
|
||||
"lines": 910,
|
||||
"maxCyclomatic": 36,
|
||||
"entropy": 6.97
|
||||
},
|
||||
"history": [
|
||||
{
|
||||
"commit": "da074f5",
|
||||
"timestamp": "2026-04-26T03:45:14.025Z",
|
||||
"mi": 12.6,
|
||||
"cognitive": 335,
|
||||
"nesting": 6,
|
||||
"lines": 910,
|
||||
"maxCyclomatic": 36,
|
||||
"entropy": 6.97
|
||||
}
|
||||
],
|
||||
"trend": "stable"
|
||||
},
|
||||
"../../../../.npm-packages/lib/node_modules/pi-lens/clients/lsp/config.ts": {
|
||||
"latest": {
|
||||
"commit": "da074f5",
|
||||
"timestamp": "2026-04-26T03:45:32.901Z",
|
||||
"mi": 37.7,
|
||||
"cognitive": 49,
|
||||
"nesting": 6,
|
||||
"lines": 173,
|
||||
"maxCyclomatic": 8,
|
||||
"entropy": 6.39
|
||||
},
|
||||
"history": [
|
||||
{
|
||||
"commit": "da074f5",
|
||||
"timestamp": "2026-04-26T03:45:32.901Z",
|
||||
"mi": 37.7,
|
||||
"cognitive": 49,
|
||||
"nesting": 6,
|
||||
"lines": 173,
|
||||
"maxCyclomatic": 8,
|
||||
"entropy": 6.39
|
||||
}
|
||||
],
|
||||
"trend": "stable"
|
||||
},
|
||||
"../../../../.npm-packages/lib/node_modules/pi-lens/clients/lsp/server.ts": {
|
||||
"latest": {
|
||||
"commit": "da074f5",
|
||||
"timestamp": "2026-04-26T03:45:38.756Z",
|
||||
"mi": 3.9,
|
||||
"cognitive": 322,
|
||||
"nesting": 7,
|
||||
"lines": 1506,
|
||||
"maxCyclomatic": 28,
|
||||
"entropy": 7.47
|
||||
},
|
||||
"history": [
|
||||
{
|
||||
"commit": "da074f5",
|
||||
"timestamp": "2026-04-26T03:45:38.756Z",
|
||||
"mi": 3.9,
|
||||
"cognitive": 322,
|
||||
"nesting": 7,
|
||||
"lines": 1506,
|
||||
"maxCyclomatic": 28,
|
||||
"entropy": 7.47
|
||||
}
|
||||
],
|
||||
"trend": "stable"
|
||||
}
|
||||
},
|
||||
"capturedAt": "2026-04-26T03:45:43.756Z"
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"files": {},
|
||||
"turnCycles": 0,
|
||||
"maxCycles": 3,
|
||||
"lastUpdated": "2026-04-27T14:41:46.671Z"
|
||||
}
|
||||
@@ -30,7 +30,7 @@ The GopherGate backend is implemented in Go, focusing on high performance, clear
|
||||
## Key Components
|
||||
|
||||
### 1. Provider Interface (`internal/providers/provider.go`)
|
||||
Standardized interface for all LLM backends. Implementations handle mapping between the unified format and provider-specific APIs (OpenAI, Gemini, DeepSeek, Grok).
|
||||
Standardized interface for all LLM backends. Implementations handle mapping between the unified format and provider-specific APIs (OpenAI, Gemini, DeepSeek, Grok, Moonshot, Ollama).
|
||||
|
||||
### 2. Model Registry & Pricing (`internal/utils/registry.go`)
|
||||
Integrates with `models.dev/api.json` to provide real-time model metadata and pricing.
|
||||
@@ -46,6 +46,9 @@ Implements HMAC-SHA256 signed tokens for dashboard authentication. Tokens secure
|
||||
### 5. WebSocket Hub (`internal/server/websocket.go`)
|
||||
A centralized hub for managing WebSocket connections, allowing real-time broadcast of system events, system metrics, and request logs to the dashboard.
|
||||
|
||||
### 6. Model Group Router (`internal/router/`)
|
||||
Resolves model groups (e.g., `deepseek-auto`, `dustins_stack`) into concrete models. It supports a Classifier strategy (uses a cheap LLM to rate complexity) and an upgraded Heuristic strategy (evaluates custom condition rules like tags, token counts, multimodal inputs, reasoning, and tool calling flags or legacy keyword patterns).
|
||||
|
||||
## Concurrency Model
|
||||
|
||||
Go's goroutines and channels are used extensively:
|
||||
|
||||
@@ -0,0 +1,202 @@
|
||||
# GopherGate — Remediation Plan
|
||||
|
||||
> 3 phases, 6 weeks total. Each phase independently shippable.
|
||||
|
||||
---
|
||||
|
||||
## Phase 1 — Security & Stability (Weeks 1-2)
|
||||
|
||||
**Goal:** Patch auth bypass, data races, debug leaks. No new features.
|
||||
|
||||
### 1.1 Fix auth bypass
|
||||
|
||||
- [ ] `middleware/auth.go`: Return 401 instead of `c.Next()` when no auth header on `/v1/*`
|
||||
- [ ] Add `requireAuth` param to `AuthMiddleware` constructor: `AuthMiddleware(db, requireAuth bool)`
|
||||
- [ ] `/v1/*` routes → `requireAuth=true`, leave `/health` unauthed
|
||||
- [ ] Add tests: curl request without token → 401
|
||||
|
||||
### 1.2 Fix WebSocket origin
|
||||
|
||||
- [ ] `websocket.go`: Replace `return true` with origin check against configured `Server.Host`
|
||||
- [ ] Config option `websocket.allowed_origins []string` (default: same origin)
|
||||
- [ ] Add `xsrf` check on WS upgrade endpoint if behind proxy
|
||||
|
||||
### 1.3 Strip debug prints
|
||||
|
||||
- [ ] `config.go`: Remove `fmt.Printf("Debug Config:...")` and `fmt.Printf("Debug Env:...")`
|
||||
- [ ] `server.go` `logRequest()`: Remove `fmt.Printf("[DEBUG] Request logged:...")`
|
||||
- [ ] `config.go`: Remove `fmt.Printf("[DEBUG] Final Ollama Config:...")`
|
||||
- [ ] `providers/ollama.go`: Remove `fmt.Printf("[Ollama]...")` debug logs or gate behind `LLM_PROXY_DEBUG=1`
|
||||
- [ ] Replace all `fmt.Printf` with structured logger (slog from stdlib)
|
||||
|
||||
### 1.4 Fix registry data race
|
||||
|
||||
- [ ] `server.go`: Add `sync.RWMutex` around `s.registry`
|
||||
- [ ] `handleListModels()`: Lock read
|
||||
- [ ] `logRequest()`: Lock read
|
||||
- [ ] Background refresh goroutines: Lock write
|
||||
- [ ] Verify with `go run -race`
|
||||
|
||||
### 1.5 Session cleanup
|
||||
|
||||
- [ ] `sessions.go`: Add periodic cleanup goroutine for expired sessions
|
||||
- [ ] Cleanup interval: every 15 minutes
|
||||
- [ ] `RevokeSession`: Return error instead of silent no-op
|
||||
|
||||
---
|
||||
|
||||
## Phase 2 — Reliability & Observability (Weeks 3-4)
|
||||
|
||||
**Goal:** Error handling, timeouts, logging maturity, concurrency hardening.
|
||||
|
||||
### 2.1 Provider HTTP timeouts
|
||||
|
||||
- [ ] Each provider `New*Provider()`: Set `client.SetTimeout(30 * time.Second)` for non-stream
|
||||
- [ ] Streaming: No timeout, but add `context.Context` cancellation from request
|
||||
- [ ] `circuit_breaker.go`: Configure real thresholds
|
||||
- `MaxRequests: 5`
|
||||
- `Interval: 60 * time.Second`
|
||||
- `Timeout: 30 * time.Second`
|
||||
- `ReadyToTrip: func(counts) bool { return counts.ConsecutiveFailures > 3 }`
|
||||
- [ ] Test: Stop Ollama, hit endpoint → circuit opens after 3 failures → auto-recovers after 30s
|
||||
|
||||
### 2.2 Structured logging (slog)
|
||||
|
||||
- [ ] Create `internal/logger/logger.go` — `slog.NewJSONHandler`
|
||||
- [ ] Log levels: error/warn/info/debug
|
||||
- [ ] Replace all `fmt.Printf` in: server, providers, config, logging
|
||||
- [ ] `RequestLogger`: Use slog structured fields, remove manual JSON building
|
||||
- [ ] Log channel: increase buffer from 100 to 10000 or use batch insert every 5s
|
||||
|
||||
### 2.3 Stream error propagation
|
||||
|
||||
- [ ] `ChatCompletionStream`: Send error chunks as SSE events, not just `fmt.Printf`
|
||||
- [ ] Format: `data: {"error":"..."}\n\n`
|
||||
- [ ] Client sees full error in stream instead of silent truncation
|
||||
|
||||
### 2.4 Registry fetch retry
|
||||
|
||||
- [ ] `FetchRegistry()`: Add retry with backoff (3 tries, 1s/2s/4s)
|
||||
- [ ] Cache last-known-good registry so startup works offline
|
||||
|
||||
### 2.5 Token truncation safety
|
||||
|
||||
- [ ] `helpers.go`: Deep-copy ToolCall before truncation, don't mutate original
|
||||
- [ ] Same pattern across all providers that sanitize IDs
|
||||
|
||||
### 2.6 RevokeSession error handling
|
||||
|
||||
- [ ] `RevokeSession(token)` → `RevokeSession(token) error`
|
||||
- [ ] Update all callers to handle error
|
||||
|
||||
---
|
||||
|
||||
## Phase 3 — Architecture & Maintainability (Weeks 5-6)
|
||||
|
||||
**Goal:** Code splitting, test coverage, billing integrity.
|
||||
|
||||
### 3.1 Split dashboard.go
|
||||
|
||||
- [ ] Create `internal/server/clients.go` — client CRUD handlers
|
||||
- [ ] Create `internal/server/providers.go` — provider handlers
|
||||
- [ ] Create `internal/server/users.go` — user handlers
|
||||
- [ ] Create `internal/server/analytics.go` — usage/analytics handlers
|
||||
- [ ] Create `internal/server/system.go` — health, metrics, logs, backup
|
||||
- [ ] `dashboard.go` shrinks to imports + route wiring only
|
||||
|
||||
### 3.2 Provider routing via config
|
||||
|
||||
- [ ] Replace `strings.Contains` routing table with config-driven model→provider map
|
||||
- [ ] `config.go`: Add `server.model_routing` map (e.g. `"llama-*": "ollama"`)
|
||||
- [ ] Fallback chain: explicit match → prefix match → glob match → default
|
||||
- [ ] Backward-compat: keep old prefix logic as fallback
|
||||
|
||||
### 3.3 Billing integrity
|
||||
|
||||
- [ ] `logging.go`: Add idempotency key to log entries (unique request ID)
|
||||
- [ ] Before deducting balance, check if `request_id` already processed
|
||||
- [ ] `processLog`: Wrap in retry on serialization failure (SQLite busy)
|
||||
- [ ] Credit deduction: move to separate async worker with replay protection
|
||||
|
||||
### 3.4 Add tests
|
||||
|
||||
- [ ] `internal/models/`: Unit tests for `FindModel()`, message conversion
|
||||
- [ ] `internal/providers/helpers_test.go`: Unit tests for `MessagesToOpenAIJSON`, `ParseOpenAIResponse`
|
||||
- [ ] `internal/utils/`: Tests for `Encrypt`/`Decrypt`, `CalculateCost`
|
||||
- [ ] `internal/server/`: Integration test for auth flow (token → chat completion)
|
||||
- [ ] `internal/middleware/`: Test auth bypass fix
|
||||
- [ ] Goal: ≥40% coverage on non-UI packages
|
||||
|
||||
### 3.5 go.mod hygiene
|
||||
|
||||
- [ ] `go mod tidy` (done)
|
||||
- [ ] Add `go vet ./...` to CI/pre-commit hook
|
||||
- [ ] Pin dependencies with `go mod verify`
|
||||
|
||||
---
|
||||
|
||||
## Dependency Map
|
||||
|
||||
```
|
||||
Phase 1 ──────────────────────────▶ Phase 2 ──────────────────────────▶ Phase 3
|
||||
│ │ │
|
||||
├─ 1.1 Auth bypass ──────────▶ 2.3 Stream errors (depends on auth) │
|
||||
├─ 1.2 WS origin │ │
|
||||
├─ 1.3 Debug prints │ │
|
||||
├─ 1.4 Registry race │ │
|
||||
├─ 1.5 Session cleanup │ │
|
||||
│ ├─ 2.1 HTTP timeouts │
|
||||
│ ├─ 2.2 Structured logging ───────────▶ 3.3 Billing (depends on good logs)
|
||||
│ ├─ 2.4 Registry retry │
|
||||
│ ├─ 2.5 Token truncation │
|
||||
│ ├─ 2.6 RevokeSession errors │
|
||||
│ │
|
||||
│ ├─ 3.1 Split dashboard.go
|
||||
│ ├─ 3.2 Config routing
|
||||
│ ├─ 3.4 Tests
|
||||
│ ├─ 3.5 go.mod hygiene
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Mermaid Gantt
|
||||
|
||||
```mermaid
|
||||
gantt
|
||||
title GopherGate Remediation
|
||||
dateFormat YYYY-MM-DD
|
||||
axisFormat %b %d
|
||||
|
||||
section Phase 1 — Security
|
||||
Auth bypass fix :p1a, 2026-05-04, 2d
|
||||
WS origin lock :p1b, after p1a, 1d
|
||||
Strip debug prints :p1c, 2026-05-04, 2d
|
||||
Registry race fix :p1d, after p1c, 1d
|
||||
Session cleanup :p1e, after p1d, 2d
|
||||
|
||||
section Phase 2 — Reliability
|
||||
HTTP timeouts + CB :p2a, 2026-05-11, 3d
|
||||
Structured logging :p2b, 2026-05-11, 3d
|
||||
Stream error propagation :p2c, after p2a, 1d
|
||||
Registry retry :p2d, after p2b, 1d
|
||||
Token truncation fix :p2e, after p2a, 1d
|
||||
RevokeSession errors :p2f, after p2b, 1d
|
||||
|
||||
section Phase 3 — Architecture
|
||||
Split dashboard.go :p3a, 2026-05-25, 4d
|
||||
Config-driven routing :p3b, 2026-05-25, 3d
|
||||
Billing integrity :p3c, after p3a, 3d
|
||||
Add tests :p3d, 2026-06-01, 5d
|
||||
go.mod hygiene :p3e, after p3d, 1d
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Immediate Next Action
|
||||
|
||||
**Start 1.1 — Fix auth bypass:**
|
||||
|
||||
- Edit `middleware/auth.go` → change `c.Next()` to `c.AbortWithStatusJSON(401, ...)` when no header
|
||||
- Add `RequireAuth` bool param
|
||||
- Update `server.go` `setupRoutes()` to pass `requireAuth=true` for `/v1/*`
|
||||
- `curl localhost:8080/v1/chat/completions -d '{}'` → 401
|
||||
@@ -1,16 +1,17 @@
|
||||
# GopherGate
|
||||
|
||||
A unified, high-performance LLM proxy gateway built in Go. It provides a single OpenAI-compatible API to access multiple providers (OpenAI, Gemini, DeepSeek, Moonshot, Grok, Ollama) with built-in token tracking, real-time cost calculation, multi-user authentication, and a management dashboard.
|
||||
A unified, high-performance LLM proxy gateway built in Go. It provides OpenAI-compatible `/v1/chat/completions`, `/v1/images/generations`, `/v1/responses`, and `/v1/models` endpoints to access multiple providers (OpenAI, Gemini, DeepSeek, Moonshot, Grok, Ollama) with built-in token tracking, real-time cost calculation, multi-user authentication, and a management dashboard.
|
||||
|
||||
## Features
|
||||
|
||||
- **Unified API:** OpenAI-compatible `/v1/chat/completions` and `/v1/models` endpoints.
|
||||
- **Unified API:** OpenAI-compatible `/v1/chat/completions`, `/v1/images/generations`, `/v1/responses`, and `/v1/models` endpoints.
|
||||
- The `/v1/responses` endpoint (OpenAI Responses API) is currently supported for OpenAI models only. Non-OpenAI providers (Gemini, DeepSeek, Moonshot, Grok, Ollama) return a "not supported" response.
|
||||
- **Multi-Provider Support:**
|
||||
- **OpenAI:** GPT-4o, GPT-4o Mini, o1, o3 reasoning models.
|
||||
- **Google Gemini:** Gemini 2.0 Flash, Pro, and vision models (with native CoT support).
|
||||
- **DeepSeek:** DeepSeek Chat and Reasoner (R1) models.
|
||||
- **Moonshot:** Kimi K2.5 and other Kimi models.
|
||||
- **xAI Grok:** Grok-4 models.
|
||||
- **OpenAI:** GPT-4o, GPT-4o Mini, GPT-5, GPT-5.4, o1/o3/o4 reasoning models, DALL-E 2/3 image generation.
|
||||
- **Google Gemini:** Gemini 2.5 Flash/Pro, Gemini 3 Flash/Pro previews, Imagen 3 image generation.
|
||||
- **DeepSeek:** DeepSeek Chat, Reasoner, V4 Flash, V4 Pro.
|
||||
- **Moonshot:** Kimi K2.5, K2.6 reasoning models.
|
||||
- **xAI Grok:** Grok-3, Grok-4, Grok-4.3 reasoning models.
|
||||
- **Ollama:** Local LLMs running on your network.
|
||||
- **Observability & Tracking:**
|
||||
- **Asynchronous Logging:** Non-blocking request logging to SQLite using background workers.
|
||||
@@ -18,13 +19,25 @@ A unified, high-performance LLM proxy gateway built in Go. It provides a single
|
||||
- **Database Persistence:** Every request logged to SQLite for historical analysis and dashboard analytics.
|
||||
- **Streaming Support:** Full SSE (Server-Sent Events) support for all providers.
|
||||
- **Multimodal (Vision):** Image processing (Base64 and remote URLs) across compatible providers.
|
||||
- **Image Generation:** DALL-E 2/3 (OpenAI) and Imagen 3 (Gemini) via OpenAI-compatible `/v1/images/generations` endpoint.
|
||||
- **Automatic Model Routing:**
|
||||
- **Hierarchical Routing:** Groups can target other groups, cascading through multiple levels until a concrete model is reached. Cycle detection and depth limiting (max 10) prevent infinite loops.
|
||||
- **Heuristic strategy:** Free, zero-latency routing supporting both keyword matching (regex/substrings) and condition-based checks (evaluating tags, token limits, multimodal inputs, reasoning, and tool calling requirements).
|
||||
- **Classifier strategy:** Uses a cheap LLM to rate task complexity on a configurable scale (1-10), then selects the appropriate model. Bucket mapping distributes ratings proportionally across targets.
|
||||
- **Two-Level Dispatch:** A `dispatcher` group (classifier, threshold=10) auto-routes to tier groups by complexity score, which then apply their own internal strategies.
|
||||
- **Metadata:** Groups support `logic_level` (1-10 complexity scale) and `primary_use` (description) fields for organizational clarity.
|
||||
- Pre-seeded with provider groups, tier groups (heavy-logic / standard-pro / fast-flow), and a dispatcher. Model groups are exposed in `/v1/models` so clients can discover them.
|
||||
- **Multi-User Access Control:**
|
||||
- **Admin Role:** Full access to all dashboard features, user management, and system configuration.
|
||||
- **Viewer Role:** Read-only access to usage analytics, costs, and monitoring.
|
||||
- **Client API Keys:** Create and manage multiple client tokens for external integrations.
|
||||
- **Reliability:**
|
||||
- **Circuit Breaking:** Automatically protects when providers are down (coming soon).
|
||||
- **Rate Limiting:** Per-client and global rate limits (coming soon).
|
||||
- **Circuit Breaking:** Protects providers when they are down, auto-recovers after timeout.
|
||||
- **Provider-Aware Classification:** Classifier selector models are routed to the correct provider automatically.
|
||||
|
||||
## DeepSeek Language Note
|
||||
|
||||
DeepSeek models default to Chinese for some prompts. GopherGate automatically injects an English system prompt ("Always respond in English.") when no system message is present. If the client provides its own system prompt, it is left untouched.
|
||||
|
||||
## Security
|
||||
|
||||
@@ -54,6 +67,7 @@ GopherGate is designed with security in mind:
|
||||
### Quick Start
|
||||
|
||||
1. Clone and build:
|
||||
|
||||
```bash
|
||||
git clone <repository-url>
|
||||
cd gophergate
|
||||
@@ -61,13 +75,20 @@ GopherGate is designed with security in mind:
|
||||
```
|
||||
|
||||
2. Configure environment:
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
# Edit .env and add your configuration:
|
||||
# LLM_PROXY__ENCRYPTION_KEY=... (32-byte hex or base64 string)
|
||||
# OPENAI_API_KEY=sk-...
|
||||
# GEMINI_API_KEY=AIza...
|
||||
# DEEPSEEK_API_KEY=sk-...
|
||||
# MOONSHOT_API_KEY=...
|
||||
# GROK_API_KEY=xai-...
|
||||
# For Ollama (optional): Set base URL and enable
|
||||
# LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL=http://localhost:11434/v1
|
||||
# LLM_PROXY__PROVIDERS__OLLAMA__ENABLED=true
|
||||
# LLM_PROXY__PROVIDERS__OLLAMA__MODELS=llama3,gemma2,mistral
|
||||
```
|
||||
|
||||
3. Run the proxy:
|
||||
@@ -75,7 +96,16 @@ GopherGate is designed with security in mind:
|
||||
./gophergate
|
||||
```
|
||||
|
||||
The server starts on `http://0.0.0.0:8080` by default.
|
||||
The server starts on `http://0.0.0.0:8080` by default. Configure `LLM_PROXY__SERVER__PORT` in `.env` to change it.
|
||||
|
||||
### Quick Deploy Script
|
||||
|
||||
A `deploy.sh` script is included for production restarts:
|
||||
|
||||
```bash
|
||||
./deploy.sh
|
||||
# git pull -> go build -> stop old process -> start new process
|
||||
```
|
||||
|
||||
### Deployment (Docker)
|
||||
|
||||
@@ -98,6 +128,8 @@ Access the dashboard at `http://localhost:8080`.
|
||||
- **Usage:** Summary stats, time-series analytics, and provider breakdown.
|
||||
- **Clients:** API key management and per-client usage tracking.
|
||||
- **Providers:** Provider configuration and status monitoring.
|
||||
- **Model Groups:** Define auto-routing groups with heuristic or classifier strategies. Supports logic level and primary use metadata.
|
||||
- **Models:** Model enable/disable and cost configuration.
|
||||
- **Users:** Admin-only user management for dashboard access.
|
||||
- **Monitoring:** Live request stream via WebSocket.
|
||||
|
||||
@@ -108,6 +140,7 @@ Access the dashboard at `http://localhost:8080`.
|
||||
|
||||
**Forgot Password?**
|
||||
You can reset the admin password to default by running:
|
||||
|
||||
```bash
|
||||
./gophergate -reset-admin
|
||||
```
|
||||
@@ -116,11 +149,8 @@ You can reset the admin password to default by running:
|
||||
|
||||
The proxy is a drop-in replacement for OpenAI. Configure your client:
|
||||
|
||||
Moonshot models are available through the same OpenAI-compatible endpoint. For
|
||||
example, use `kimi-k2.5` as the model name after setting `MOONSHOT_API_KEY` in
|
||||
your environment.
|
||||
|
||||
### Python
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
@@ -135,6 +165,111 @@ response = client.chat.completions.create(
|
||||
)
|
||||
```
|
||||
|
||||
### Responses API
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1",
|
||||
api_key="YOUR_CLIENT_API_KEY"
|
||||
)
|
||||
|
||||
# OpenAI Responses API (supported for OpenAI models only)
|
||||
response = client.responses.create(
|
||||
model="gpt-4o",
|
||||
input="Explain quantum computing in one paragraph.",
|
||||
instructions="You are a helpful assistant.",
|
||||
temperature=0.7,
|
||||
max_output_tokens=500
|
||||
)
|
||||
print(response.output_text)
|
||||
```
|
||||
|
||||
**Note:** The `/v1/responses` endpoint is currently supported for OpenAI models only.
|
||||
|
||||
### Automatic Model Routing
|
||||
|
||||
Use a model group name to let gophergate pick the best model automatically:
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1",
|
||||
api_key="YOUR_CLIENT_API_KEY"
|
||||
)
|
||||
|
||||
# Simple query -- routes to the cheap/fast model
|
||||
response = client.chat.completions.create(
|
||||
model="fast-flow",
|
||||
messages=[{"role": "user", "content": "What is 2+2?"}]
|
||||
)
|
||||
|
||||
# Complex query -- routes to the reasoning model automatically
|
||||
response = client.chat.completions.create(
|
||||
model="heavy-logic",
|
||||
messages=[{"role": "user", "content": "Write a Python red-black tree implementation."}]
|
||||
)
|
||||
```
|
||||
|
||||
### Two-Level Dispatch
|
||||
|
||||
The `dispatcher` group uses a classifier to score prompts 1-10, then routes to the appropriate tier group:
|
||||
|
||||
```python
|
||||
# Automatically routed based on complexity:
|
||||
# 1-3 -> fast-flow (classification, basic Q&A)
|
||||
# 4-7 -> standard-pro (general assistant, long docs)
|
||||
# 8-10 -> heavy-logic (complex coding, logic, agents)
|
||||
response = client.chat.completions.create(
|
||||
model="dispatcher",
|
||||
messages=[{"role": "user", "content": "Debug this race condition in my Go code."}]
|
||||
)
|
||||
# This goes: dispatcher -> heavy-logic -> deepseek-v4-pro
|
||||
```
|
||||
|
||||
Pre-seeded groups:
|
||||
|
||||
| Group | Level | Strategy | Targets | Primary Use |
|
||||
|-------|-------|----------|---------|-------------|
|
||||
| `fast-flow` | 2 | heuristic | deepseek-v4-flash, gpt-5.4-nano | Classification, JSON, Basic Q&A |
|
||||
| `standard-pro` | 5 | heuristic | gpt-5.4-mini, gemini-3-flash-preview | General Assistant, Long Docs |
|
||||
| `heavy-logic` | 9 | heuristic | grok-4.3, kimi-k2.6, deepseek-v4-pro | Complex Coding, Logic, Agents |
|
||||
| `dispatcher` | - | classifier | fast-flow, standard-pro, heavy-logic | Auto-dispatches by complexity |
|
||||
| `deepseek-auto` | - | heuristic | deepseek-chat, deepseek-reasoner | Legacy provider group |
|
||||
| `openai-auto` | - | heuristic | gpt-4o-mini, gpt-4o | Legacy provider group |
|
||||
| `gemini-auto` | - | heuristic | gemini-2.0-flash, gemini-2.5-pro | Legacy provider group |
|
||||
|
||||
### Image Generation (DALL-E / Imagen)
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1",
|
||||
api_key="YOUR_CLIENT_API_KEY"
|
||||
)
|
||||
|
||||
# DALL-E 3 (OpenAI)
|
||||
resp = client.images.generate(
|
||||
model="dall-e-3",
|
||||
prompt="A cute gopher wearing a top hat",
|
||||
n=1,
|
||||
size="1024x1024"
|
||||
)
|
||||
print(resp.data[0].url)
|
||||
|
||||
# Imagen 3 (Gemini) -- uses same endpoint
|
||||
resp = client.images.generate(
|
||||
model="imagen-3.0-generate-001",
|
||||
prompt="A gopher coding in Go",
|
||||
n=1,
|
||||
size="1024x1024"
|
||||
)
|
||||
print(resp.data[0].url) # Returns data URI (Gemini returns base64)
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
- [x] Database schema & migrations (hardcoded in `db.go`)
|
||||
- [x] Configuration loader (Viper)
|
||||
- [x] Auth Middleware (scoped to `/v1`)
|
||||
- [x] Basic Provider implementations (OpenAI, Gemini, DeepSeek, Grok)
|
||||
- [x] Basic Provider implementations (OpenAI, Gemini, DeepSeek, Grok, Ollama)
|
||||
- [x] Streaming Support (SSE & Gemini custom streaming)
|
||||
- [x] Archive Rust files to `rust` branch
|
||||
- [x] Clean root and set Go version as `main`
|
||||
@@ -15,20 +15,41 @@
|
||||
- [x] Dashboard Analytics & Usage Summary (Fixed SQL robustness)
|
||||
- [x] WebSocket for real-time dashboard updates (Hub with client counting)
|
||||
- [x] Asynchronous Request Logging to SQLite
|
||||
- [x] Update documentation (README, deployment, architecture)
|
||||
- [x] Cost Tracking accuracy (Registry integration with `models.dev`)
|
||||
- [x] Model Listing endpoint (`/v1/models`) with provider filtering
|
||||
- [x] System Metrics endpoint (`/api/system/metrics` using `gopsutil`)
|
||||
- [x] Fixed dashboard 404s and 500s
|
||||
- [x] Model groups with heuristic and classifier routing strategies
|
||||
- [x] Hierarchical routing — groups can target other groups with cycle detection
|
||||
- [x] Classifier bucket mapping via complexity_threshold (1-10 scale -> N targets)
|
||||
- [x] Two-level dispatch — classifier router delegates to tier groups
|
||||
- [x] Model groups exposed in /v1/models endpoint (owned_by: gophergate)
|
||||
- [x] logic_level and primary_use metadata on model groups
|
||||
- [x] Model group CRUD dashboard page
|
||||
- [x] dispatcher, heavy-logic, standard-pro, fast-flow seed groups
|
||||
- [x] Provider selection moved after routing resolution (fixes group routing)
|
||||
- [x] Classifier selector model routed to correct provider (selectProvider)
|
||||
- [x] DeepSeek English system prompt injection (ensureEnglish)
|
||||
- [x] Deploy script (deploy.sh)
|
||||
- [x] Recent Activity pane shows resolved model + group annotation
|
||||
- [x] Model names aligned with models.dev registry
|
||||
|
||||
## Feature Parity Checklist (High Priority)
|
||||
## Planned Resolutions (High Priority)
|
||||
|
||||
### Security Fixes
|
||||
- [x] **Critical:** Fix `AuthMiddleware` to reject invalid tokens instead of falling back to insecure prefix derivation.
|
||||
|
||||
### Feature Parity Checklist (High Priority)
|
||||
|
||||
### OpenAI Provider
|
||||
- [x] Tool Calling
|
||||
- [x] Multimodal (Images) support
|
||||
- [x] Accurate usage parsing (cached & reasoning tokens)
|
||||
- [ ] Reasoning Content (CoT) support for `o1`, `o3` (need to ensure it's parsed in responses)
|
||||
- [ ] Support for `/v1/responses` API (required for some gpt-5/o1 models)
|
||||
### Feature Parity: OpenAI Provider Enhancements
|
||||
- [x] **Reasoning Content (CoT) Support (`o1`/`o3`):**
|
||||
- [x] Infrastructure verified. `reasoning_content` is mapped in request/response structures.
|
||||
- [x] **Support for `/v1/responses` API:**
|
||||
- [x] Implemented new route in `internal/server/server.go`.
|
||||
|
||||
### Gemini Provider
|
||||
- [x] Tool Calling (mapping to Gemini format)
|
||||
@@ -47,9 +68,15 @@
|
||||
- [x] Multimodal support
|
||||
- [x] Accurate usage parsing (via OpenAI helper)
|
||||
|
||||
### Ollama Provider
|
||||
- [x] OpenAI-compatible API integration
|
||||
- [x] Streaming support
|
||||
- [x] Model pattern detection for routing
|
||||
- [x] Zero cost calculation (local/free models)
|
||||
|
||||
## Infrastructure & Middleware
|
||||
- [ ] Implement Rate Limiting (`golang.org/x/time/rate`)
|
||||
- [ ] Implement Circuit Breaker (`github.com/sony/gobreaker`)
|
||||
- [x] Implement Circuit Breaker (`github.com/sony/gobreaker`)
|
||||
|
||||
## Verification
|
||||
- [ ] Unit tests for feature-specific mapping (CoT, Tools, Images)
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
type MyNullTime struct {
|
||||
Time interface{}
|
||||
Type string
|
||||
}
|
||||
|
||||
func (n *MyNullTime) Scan(value interface{}) error {
|
||||
n.Time = value
|
||||
n.Type = fmt.Sprintf("%T", value)
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
db, err := sqlx.Connect("sqlite", "/home/newkirk/Documents/projects/web_projects/gophergate/data/backups/llm_proxy.db.20260303T205057Z")
|
||||
if err != nil {
|
||||
fmt.Println("connect err:", err)
|
||||
return
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Test 1: Direct column scan type
|
||||
var d MyNullTime
|
||||
db.Get(&d, "SELECT last_used_at FROM client_tokens WHERE client_id = ? LIMIT 1", "sk-opencode")
|
||||
fmt.Printf("direct SELECT: GoType=%s value=%v\n", d.Type, d.Time)
|
||||
|
||||
// Test 2: MAX aggregate scan type
|
||||
var m MyNullTime
|
||||
db.Get(&m, "SELECT MAX(last_used_at) FROM client_tokens WHERE client_id = ?", "sk-opencode")
|
||||
fmt.Printf("MAX SELECT: GoType=%s value=%v\n", m.Type, m.Time)
|
||||
|
||||
// Test 3: peek at the raw driver types
|
||||
row := db.QueryRow("SELECT last_used_at, MAX(last_used_at) FROM client_tokens WHERE client_id = ? LIMIT 1", "sk-opencode")
|
||||
var a, b interface{}
|
||||
row.Scan(&a, &b)
|
||||
fmt.Printf("\nRaw Scan:\n")
|
||||
fmt.Printf(" last_used_at: type=%T val=%v\n", a, a)
|
||||
fmt.Printf(" MAX(last_used_at): type=%T val=%v\n", b, b)
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Define the service name/path for easy updates
|
||||
BINARY_NAME="gophergate"
|
||||
SOURCE_PATH="./cmd/gophergate/main.go"
|
||||
|
||||
echo "Stopping existing $BINARY_NAME processes..."
|
||||
# Using pkill; || true ensures the script continues even if no process was found
|
||||
pkill -9 "$BINARY_NAME" || echo "No running process found."
|
||||
|
||||
echo "Pulling latest changes from git..."
|
||||
git pull
|
||||
|
||||
echo "Building the application..."
|
||||
if go build -o "$BINARY_NAME" "$SOURCE_PATH"; then
|
||||
echo "Build successful. Starting $BINARY_NAME in the background..."
|
||||
# Launch with nohup and redirect output to a log file
|
||||
nohup "./$BINARY_NAME" > gophergate.log 2>&1 &
|
||||
echo "Service started. PID: $!"
|
||||
else
|
||||
echo "Build failed! Keeping the previous state."
|
||||
exit 1
|
||||
fi
|
||||
@@ -26,6 +26,22 @@ go build -o gophergate ./cmd/gophergate
|
||||
./gophergate
|
||||
```
|
||||
|
||||
### Quick Deploy Script
|
||||
|
||||
A `deploy.sh` script is provided for production restarts:
|
||||
|
||||
```bash
|
||||
./deploy.sh
|
||||
```
|
||||
|
||||
This script will:
|
||||
1. Stop any running gophergate process
|
||||
2. Pull latest changes from git
|
||||
3. Build the application
|
||||
4. Start it in the background (logs to `gophergate.log`)
|
||||
|
||||
If the build fails, the previous binary is left untouched and the script exits.
|
||||
|
||||
## Docker Deployment
|
||||
|
||||
The project includes a multi-stage `Dockerfile` for minimal image size.
|
||||
@@ -50,3 +66,4 @@ docker run -d \
|
||||
- **SSL/TLS:** It is recommended to run the proxy behind a reverse proxy like Nginx or Caddy for SSL termination.
|
||||
- **Backups:** Regularly backup the `data/llm_proxy.db` file.
|
||||
- **Monitoring:** Monitor the `/health` endpoint for system status.
|
||||
- **Logs:** When started with `deploy.sh` or `nohup`, logs are written to `gophergate.log`.
|
||||
|
||||
@@ -10,6 +10,7 @@ require (
|
||||
github.com/jmoiron/sqlx v1.4.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/shirou/gopsutil/v3 v3.24.5
|
||||
github.com/sony/gobreaker v1.0.0
|
||||
github.com/spf13/viper v1.21.0
|
||||
golang.org/x/crypto v0.48.0
|
||||
modernc.org/sqlite v1.47.0
|
||||
|
||||
@@ -106,6 +106,8 @@ github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFt
|
||||
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
|
||||
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
|
||||
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
|
||||
github.com/sony/gobreaker v1.0.0 h1:feX5fGGXSl3dYd4aHZItw+FpHLvvoaqkawKjVNiFMNQ=
|
||||
github.com/sony/gobreaker v1.0.0/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
|
||||
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||
|
||||
+39
-12
@@ -11,17 +11,18 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Providers ProviderConfig `mapstructure:"providers"`
|
||||
EncryptionKey string `mapstructure:"encryption_key"`
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Providers ProviderConfig `mapstructure:"providers"`
|
||||
EncryptionKey string `mapstructure:"encryption_key"`
|
||||
KeyBytes []byte
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Port int `mapstructure:"port"`
|
||||
Host string `mapstructure:"host"`
|
||||
AuthTokens []string `mapstructure:"auth_tokens"`
|
||||
Port int `mapstructure:"port"`
|
||||
Host string `mapstructure:"host"`
|
||||
AuthTokens []string `mapstructure:"auth_tokens"`
|
||||
WSAllowedOrigin string `mapstructure:"ws_allowed_origin"`
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
@@ -36,6 +37,7 @@ type ProviderConfig struct {
|
||||
Moonshot MoonshotConfig `mapstructure:"moonshot"`
|
||||
Grok GrokConfig `mapstructure:"grok"`
|
||||
Ollama OllamaConfig `mapstructure:"ollama"`
|
||||
Xiaomi XiaomiConfig `mapstructure:"xiaomi"`
|
||||
}
|
||||
|
||||
type OpenAIConfig struct {
|
||||
@@ -80,6 +82,13 @@ type OllamaConfig struct {
|
||||
Models []string `mapstructure:"models"`
|
||||
}
|
||||
|
||||
type XiaomiConfig struct {
|
||||
APIKeyEnv string `mapstructure:"api_key_env"`
|
||||
BaseURL string `mapstructure:"base_url"`
|
||||
DefaultModel string `mapstructure:"default_model"`
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
v := viper.New()
|
||||
|
||||
@@ -119,6 +128,11 @@ func Load() (*Config, error) {
|
||||
v.SetDefault("providers.ollama.enabled", false)
|
||||
v.SetDefault("providers.ollama.models", []string{})
|
||||
|
||||
v.SetDefault("providers.xiaomi.api_key_env", "XIAOMI_API_KEY")
|
||||
v.SetDefault("providers.xiaomi.base_url", "https://api.xiaomimimo.com/v1")
|
||||
v.SetDefault("providers.xiaomi.default_model", "mimo-v2.5")
|
||||
v.SetDefault("providers.xiaomi.enabled", true)
|
||||
|
||||
// Environment variables
|
||||
v.SetEnvPrefix("LLM_PROXY")
|
||||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "__"))
|
||||
@@ -128,6 +142,9 @@ func Load() (*Config, error) {
|
||||
v.BindEnv("encryption_key", "LLM_PROXY__ENCRYPTION_KEY")
|
||||
v.BindEnv("server.port", "LLM_PROXY__SERVER__PORT")
|
||||
v.BindEnv("server.host", "LLM_PROXY__SERVER__HOST")
|
||||
v.BindEnv("providers.ollama.enabled", "LLM_PROXY__PROVIDERS__OLLAMA__ENABLED")
|
||||
v.BindEnv("providers.ollama.base_url", "LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL")
|
||||
v.BindEnv("providers.ollama.models", "LLM_PROXY__PROVIDERS__OLLAMA__MODELS")
|
||||
|
||||
// Config file
|
||||
v.SetConfigName("config")
|
||||
@@ -148,17 +165,25 @@ func Load() (*Config, error) {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Debug Config: port from viper=%d, host from viper=%s\n", cfg.Server.Port, cfg.Server.Host)
|
||||
fmt.Printf("Debug Env: LLM_PROXY__SERVER__PORT=%s, LLM_PROXY__SERVER__HOST=%s\n", os.Getenv("LLM_PROXY__SERVER__PORT"), os.Getenv("LLM_PROXY__SERVER__HOST"))
|
||||
|
||||
// Manual overrides for nested keys which Viper doesn't always bind correctly with AutomaticEnv + SetEnvPrefix
|
||||
if port := os.Getenv("LLM_PROXY__SERVER__PORT"); port != "" {
|
||||
fmt.Sscanf(port, "%d", &cfg.Server.Port)
|
||||
fmt.Printf("Overriding port to %d from env\n", cfg.Server.Port)
|
||||
|
||||
}
|
||||
if host := os.Getenv("LLM_PROXY__SERVER__HOST"); host != "" {
|
||||
cfg.Server.Host = host
|
||||
fmt.Printf("Overriding host to %s from env\n", cfg.Server.Host)
|
||||
|
||||
}
|
||||
|
||||
// Ollama overrides
|
||||
if enabled := os.Getenv("LLM_PROXY__PROVIDERS__OLLAMA__ENABLED"); enabled != "" {
|
||||
cfg.Providers.Ollama.Enabled = enabled == "true"
|
||||
}
|
||||
if baseURL := os.Getenv("LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL"); baseURL != "" {
|
||||
cfg.Providers.Ollama.BaseURL = baseURL
|
||||
}
|
||||
if models := os.Getenv("LLM_PROXY__PROVIDERS__OLLAMA__MODELS"); models != "" {
|
||||
cfg.Providers.Ollama.Models = strings.Split(models, ",")
|
||||
}
|
||||
|
||||
// Validate encryption key
|
||||
@@ -198,6 +223,8 @@ func (c *Config) GetAPIKey(provider string) (string, error) {
|
||||
case "ollama":
|
||||
// Ollama doesn't require an API key
|
||||
return "", nil
|
||||
case "xiaomi":
|
||||
envVar = c.Providers.Xiaomi.APIKeyEnv
|
||||
default:
|
||||
return "", fmt.Errorf("unknown provider: %s", provider)
|
||||
}
|
||||
|
||||
@@ -32,6 +32,14 @@ func Init(path string) (*DB, error) {
|
||||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
// Enable Write-Ahead Logging (WAL) and set a busy timeout to handle concurrent access
|
||||
if _, err := db.Exec("PRAGMA journal_mode=WAL;"); err != nil {
|
||||
log.Printf("failed to enable WAL mode: %v", err)
|
||||
}
|
||||
if _, err := db.Exec("PRAGMA busy_timeout=5000;"); err != nil {
|
||||
log.Printf("failed to set busy timeout: %v", err)
|
||||
}
|
||||
|
||||
instance := &DB{db}
|
||||
|
||||
// Run migrations
|
||||
@@ -122,6 +130,18 @@ func (db *DB) RunMigrations() error {
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
last_used_at DATETIME,
|
||||
FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE CASCADE
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS model_groups (
|
||||
id TEXT PRIMARY KEY,
|
||||
strategy TEXT NOT NULL DEFAULT 'heuristic',
|
||||
selector_model TEXT,
|
||||
targets TEXT NOT NULL DEFAULT '[]',
|
||||
complexity_threshold INTEGER,
|
||||
heuristic_rules TEXT,
|
||||
logic_level INTEGER,
|
||||
primary_use TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)`,
|
||||
}
|
||||
|
||||
@@ -152,6 +172,10 @@ func (db *DB) RunMigrations() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Add columns to existing model_groups tables (safe — SQLite ignores duplicates on error)
|
||||
db.Exec("ALTER TABLE model_groups ADD COLUMN logic_level INTEGER")
|
||||
db.Exec("ALTER TABLE model_groups ADD COLUMN primary_use TEXT")
|
||||
|
||||
// Default admin user
|
||||
var count int
|
||||
if err := db.Get(&count, "SELECT COUNT(*) FROM users"); err != nil {
|
||||
@@ -177,6 +201,25 @@ func (db *DB) RunMigrations() error {
|
||||
return fmt.Errorf("failed to insert default client: %w", err)
|
||||
}
|
||||
|
||||
// Seed default model groups
|
||||
defaultGroups := []struct {
|
||||
id, strategy, targets, selectorModel string
|
||||
complexityThreshold, logicLevel *int
|
||||
primaryUse *string
|
||||
}{
|
||||
{"deepseek-auto", "heuristic", `["deepseek-chat","deepseek-reasoner"]`, "", nil, nil, nil},
|
||||
{"openai-auto", "heuristic", `["gpt-4o-mini","gpt-4o"]`, "", nil, nil, nil},
|
||||
{"gemini-auto", "heuristic", `["gemini-2.0-flash","gemini-2.5-pro"]`, "", nil, nil, nil},
|
||||
{"heavy-logic", "heuristic", `["grok-4.3","kimi-k2.6","deepseek-v4-pro"]`, "", nil, intPtr(9), strPtr("Complex Coding, Logic, Agents.")},
|
||||
{"standard-pro", "heuristic", `["gpt-5.4-mini","gemini-3-flash-preview"]`, "", nil, intPtr(5), strPtr("General Assistant, Long Docs.")},
|
||||
{"fast-flow", "heuristic", `["deepseek-v4-flash","gpt-5.4-nano"]`, "", nil, intPtr(2), strPtr("Classification, JSON, Basic Q&A.")},
|
||||
{"dispatcher", "classifier", `["fast-flow","standard-pro","heavy-logic"]`, "gpt-5.4-nano", intPtr(10), nil, strPtr("Auto-dispatches to tier groups by complexity.")},
|
||||
}
|
||||
for _, g := range defaultGroups {
|
||||
db.Exec(`INSERT OR IGNORE INTO model_groups (id, strategy, targets, selector_model, complexity_threshold, logic_level, primary_use) VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
||||
g.id, g.strategy, g.targets, nilStr(g.selectorModel), g.complexityThreshold, g.logicLevel, g.primaryUse)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -262,3 +305,27 @@ type ClientToken struct {
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
LastUsedAt *time.Time `db:"last_used_at"`
|
||||
}
|
||||
|
||||
type ModelGroup struct {
|
||||
ID string `db:"id" json:"id"`
|
||||
Strategy string `db:"strategy" json:"strategy"`
|
||||
SelectorModel *string `db:"selector_model" json:"selector_model"`
|
||||
Targets string `db:"targets" json:"targets"` // JSON array
|
||||
ComplexityThreshold *int `db:"complexity_threshold" json:"complexity_threshold"`
|
||||
HeuristicRules *string `db:"heuristic_rules" json:"heuristic_rules"`
|
||||
LogicLevel *int `db:"logic_level" json:"logic_level"`
|
||||
PrimaryUse *string `db:"primary_use" json:"primary_use"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
func intPtr(v int) *int { return &v }
|
||||
func strPtr(v string) *string { return &v }
|
||||
|
||||
// nilStr returns a *string for non-empty strings, nil for empty.
|
||||
func nilStr(v string) *string {
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
return &v
|
||||
}
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var level = slog.LevelInfo
|
||||
|
||||
func init() {
|
||||
env := os.Getenv("LLM_PROXY_LOG_LEVEL")
|
||||
switch strings.ToLower(env) {
|
||||
case "debug":
|
||||
level = slog.LevelDebug
|
||||
case "warn":
|
||||
level = slog.LevelWarn
|
||||
case "error":
|
||||
level = slog.LevelError
|
||||
}
|
||||
|
||||
h := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
})
|
||||
slog.SetDefault(slog.New(h))
|
||||
}
|
||||
|
||||
// Warn is a helper to emit structured warnings.
|
||||
func Warn(msg string, args ...any) {
|
||||
slog.Warn(msg, args...)
|
||||
}
|
||||
|
||||
// Error is a helper to emit structured errors.
|
||||
func Error(msg string, args ...any) {
|
||||
slog.Error(msg, args...)
|
||||
}
|
||||
|
||||
// Debug is a helper to emit structured debug messages.
|
||||
func Debug(msg string, args ...any) {
|
||||
slog.Debug(msg, args...)
|
||||
}
|
||||
|
||||
// Ctx wraps slog with context.
|
||||
func Ctx(ctx context.Context) *slog.Logger {
|
||||
return slog.Default()
|
||||
}
|
||||
+52
-18
@@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/db"
|
||||
@@ -10,43 +11,76 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func AuthMiddleware(database *db.DB) gin.HandlerFunc {
|
||||
func AuthMiddleware(database *db.DB, requireAuth bool) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
// Fallback to checking "Authentication" header in case the client library used the wrong name
|
||||
authHeader = c.GetHeader("Authentication")
|
||||
}
|
||||
|
||||
if authHeader == "" {
|
||||
if requireAuth {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Missing Authorization or Authentication header.",
|
||||
"type": "invalid_request_error",
|
||||
"param": nil,
|
||||
"code": "401",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if token == authHeader { // No "Bearer " prefix
|
||||
if requireAuth {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Invalid authorization header format. Bearer token required.",
|
||||
"type": "invalid_request_error",
|
||||
"param": nil,
|
||||
"code": "401",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Try to resolve client from database
|
||||
// Try to resolve client from database with a read-only SELECT
|
||||
var clientID string
|
||||
err := database.Get(&clientID, "UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = 1 RETURNING client_id", token)
|
||||
|
||||
err := database.Get(&clientID, "SELECT client_id FROM client_tokens WHERE token = ? AND is_active = 1", token)
|
||||
|
||||
if err == nil {
|
||||
c.Set("auth", models.AuthInfo{
|
||||
Token: token,
|
||||
ClientID: clientID,
|
||||
})
|
||||
} else {
|
||||
// Fallback to token-prefix derivation (matches Rust behavior)
|
||||
prefixLen := len(token)
|
||||
if prefixLen > 8 {
|
||||
prefixLen = 8
|
||||
}
|
||||
clientID = "client_" + token[:prefixLen]
|
||||
c.Set("auth", models.AuthInfo{
|
||||
Token: token,
|
||||
ClientID: clientID,
|
||||
})
|
||||
log.Printf("Token not found in DB, using fallback client ID: %s", clientID)
|
||||
}
|
||||
|
||||
// Update last_used_at asynchronously so that database locks or write delays
|
||||
// do not block or fail the client's request authentication.
|
||||
go func(t string) {
|
||||
if _, updateErr := database.Exec("UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ?", t); updateErr != nil {
|
||||
log.Printf("Warning: failed to update client token last_used_at: %v", updateErr)
|
||||
}
|
||||
}(token)
|
||||
|
||||
c.Next()
|
||||
c.Next()
|
||||
} else {
|
||||
log.Printf("Token not found, inactive or error in DB: %s (err: %v)", token, err)
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": gin.H{
|
||||
"message": "Invalid or inactive client token.",
|
||||
"type": "invalid_request_error",
|
||||
"param": nil,
|
||||
"code": "401",
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,12 +26,12 @@ type ChatCompletionRequest struct {
|
||||
}
|
||||
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"` // "system", "user", "assistant", "tool"
|
||||
Content interface{} `json:"content"`
|
||||
ReasoningContent *string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
ToolCallID *string `json:"tool_call_id,omitempty"`
|
||||
Role string `json:"role"` // "system", "user", "assistant", "tool"
|
||||
Content interface{} `json:"content"`
|
||||
ReasoningContent *string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
ToolCallID *string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
type ContentPart struct {
|
||||
@@ -53,9 +53,9 @@ type Tool struct {
|
||||
}
|
||||
|
||||
type FunctionDef struct {
|
||||
Name string `json:"name"`
|
||||
Name string `json:"name"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
@@ -116,6 +116,7 @@ type ChatCompletionStreamResponse struct {
|
||||
Model string `json:"model"`
|
||||
Choices []ChatStreamChoice `json:"choices"`
|
||||
Usage *Usage `json:"usage,omitempty"`
|
||||
Error *string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type ChatStreamChoice struct {
|
||||
@@ -209,6 +210,30 @@ func (i *ImageInput) ToBase64() (string, string, error) {
|
||||
return "", "", fmt.Errorf("empty image input")
|
||||
}
|
||||
|
||||
// Image Generation (DALL-E, Imagen)
|
||||
|
||||
type ImageGenerationRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N *uint32 `json:"n,omitempty"`
|
||||
Quality *string `json:"quality,omitempty"`
|
||||
ResponseFormat *string `json:"response_format,omitempty"`
|
||||
Size *string `json:"size,omitempty"`
|
||||
Style *string `json:"style,omitempty"`
|
||||
User *string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
type ImageGenerationResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []ImageData `json:"data"`
|
||||
}
|
||||
|
||||
type ImageData struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
B64JSON string `json:"b64_json,omitempty"`
|
||||
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||
}
|
||||
|
||||
// AuthInfo for context
|
||||
type AuthInfo struct {
|
||||
Token string
|
||||
|
||||
+159
-17
@@ -2,6 +2,25 @@ package models
|
||||
|
||||
import "strings"
|
||||
|
||||
// CanonicalProviders lists the original model creators in priority order.
|
||||
// When a model name exists in multiple providers (e.g. deepseek-v4-pro in
|
||||
// deepseek, ollama-cloud, openrouter, etc.), these providers take precedence
|
||||
// so the proxy uses authoritative metadata (pricing, limits) rather than a
|
||||
// reseller's values.
|
||||
var CanonicalProviders = []string{
|
||||
"openai",
|
||||
"google",
|
||||
"deepseek",
|
||||
"xai",
|
||||
"moonshotai",
|
||||
"moonshotai-cn",
|
||||
"anthropic",
|
||||
"mistral",
|
||||
"cohere",
|
||||
"minimax",
|
||||
"xiaomi",
|
||||
}
|
||||
|
||||
type ModelRegistry struct {
|
||||
Providers map[string]ProviderInfo `json:"-"`
|
||||
}
|
||||
@@ -39,31 +58,154 @@ type ModelModalities struct {
|
||||
Output []string `json:"output"`
|
||||
}
|
||||
|
||||
func (r *ModelRegistry) FindModel(modelID string) *ModelMetadata {
|
||||
// First try exact match in models map
|
||||
for _, provider := range r.Providers {
|
||||
if model, ok := provider.Models[modelID]; ok {
|
||||
return &model
|
||||
}
|
||||
}
|
||||
|
||||
// Try searching by ID in metadata
|
||||
for _, provider := range r.Providers {
|
||||
for _, model := range provider.Models {
|
||||
if model.ID == modelID {
|
||||
return &model
|
||||
// findInCanonical searches the canonical providers in order for an exact model
|
||||
// key match. Returns the metadata and true if found.
|
||||
func (r *ModelRegistry) findInCanonical(modelID string) (*ModelMetadata, bool) {
|
||||
for _, key := range CanonicalProviders {
|
||||
if p, ok := r.Providers[key]; ok {
|
||||
if m, ok := p.Models[modelID]; ok {
|
||||
return &m, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Try fuzzy matching (e.g. gpt-4o-2024-05-13 matching gpt-4o)
|
||||
for _, provider := range r.Providers {
|
||||
for id, model := range provider.Models {
|
||||
// findInAll searches all providers (map iteration, random order) for an exact
|
||||
// model key match. Used as fallback when canonical search fails.
|
||||
func (r *ModelRegistry) findInAll(modelID string) (*ModelMetadata, bool) {
|
||||
for _, p := range r.Providers {
|
||||
if m, ok := p.Models[modelID]; ok {
|
||||
return &m, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findInCanonicalByID searches canonical providers for a model whose metadata
|
||||
// ID field matches modelID.
|
||||
func (r *ModelRegistry) findInCanonicalByID(modelID string) (*ModelMetadata, bool) {
|
||||
for _, key := range CanonicalProviders {
|
||||
if p, ok := r.Providers[key]; ok {
|
||||
for _, m := range p.Models {
|
||||
if m.ID == modelID {
|
||||
return &m, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findInAllByID searches all providers for a model whose metadata ID field
|
||||
// matches modelID.
|
||||
func (r *ModelRegistry) findInAllByID(modelID string) (*ModelMetadata, bool) {
|
||||
for _, p := range r.Providers {
|
||||
for _, m := range p.Models {
|
||||
if m.ID == modelID {
|
||||
return &m, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findCanonicalReverseFuzzy searches canonical providers for any model whose
|
||||
// key starts with modelID.
|
||||
func (r *ModelRegistry) findCanonicalReverseFuzzy(modelID string) (*ModelMetadata, bool) {
|
||||
for _, key := range CanonicalProviders {
|
||||
if p, ok := r.Providers[key]; ok {
|
||||
for id, m := range p.Models {
|
||||
if strings.HasPrefix(id, modelID) {
|
||||
return &m, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findAllReverseFuzzy searches all providers for any model whose key starts
|
||||
// with modelID.
|
||||
func (r *ModelRegistry) findAllReverseFuzzy(modelID string) (*ModelMetadata, bool) {
|
||||
for _, p := range r.Providers {
|
||||
for id, m := range p.Models {
|
||||
if strings.HasPrefix(id, modelID) {
|
||||
return &m, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findCanonicalForwardFuzzy searches canonical providers for any model whose
|
||||
// key is a prefix of modelID.
|
||||
func (r *ModelRegistry) findCanonicalForwardFuzzy(modelID string) (*ModelMetadata, bool) {
|
||||
for _, key := range CanonicalProviders {
|
||||
if p, ok := r.Providers[key]; ok {
|
||||
for id, m := range p.Models {
|
||||
if strings.HasPrefix(modelID, id) {
|
||||
return &m, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findAllForwardFuzzy searches all providers for any model whose key is a
|
||||
// prefix of modelID.
|
||||
func (r *ModelRegistry) findAllForwardFuzzy(modelID string) (*ModelMetadata, bool) {
|
||||
for _, p := range r.Providers {
|
||||
for id, m := range p.Models {
|
||||
if strings.HasPrefix(modelID, id) {
|
||||
return &model
|
||||
return &m, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// FindModel looks up model metadata by ID. It searches canonical providers
|
||||
// first at each strategy level (exact key, metadata ID, reverse fuzzy,
|
||||
// forward fuzzy) and falls back to all providers only when canonical search
|
||||
// yields no result. This prevents reseller entries (ollama-cloud, openrouter,
|
||||
// etc.) from overriding the original provider's authoritative pricing and
|
||||
// limits.
|
||||
func (r *ModelRegistry) FindModel(modelID string) *ModelMetadata {
|
||||
// 1. Exact key match — canonical first, then all
|
||||
if m, ok := r.findInCanonical(modelID); ok {
|
||||
return m
|
||||
}
|
||||
if m, ok := r.findInAll(modelID); ok {
|
||||
return m
|
||||
}
|
||||
|
||||
// 2. Match by metadata ID field — canonical first, then all
|
||||
if m, ok := r.findInCanonicalByID(modelID); ok {
|
||||
return m
|
||||
}
|
||||
if m, ok := r.findInAllByID(modelID); ok {
|
||||
return m
|
||||
}
|
||||
|
||||
// 3. Reverse fuzzy: model key starts with modelID
|
||||
// e.g. 'gpt-5.4-mini' matching 'gpt-5.4-mini-2026-04-01'
|
||||
if m, ok := r.findCanonicalReverseFuzzy(modelID); ok {
|
||||
return m
|
||||
}
|
||||
if m, ok := r.findAllReverseFuzzy(modelID); ok {
|
||||
return m
|
||||
}
|
||||
|
||||
// 4. Forward fuzzy: modelID starts with model key
|
||||
// e.g. 'gpt-4o-2024-05-13' matching 'gpt-4o'
|
||||
if m, ok := r.findCanonicalForwardFuzzy(modelID); ok {
|
||||
return m
|
||||
}
|
||||
if m, ok := r.findAllForwardFuzzy(modelID); ok {
|
||||
return m
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestModelRegistry_FindModel_Exact(t *testing.T) {
|
||||
r := &ModelRegistry{
|
||||
Providers: map[string]ProviderInfo{
|
||||
"openai": {
|
||||
Models: map[string]ModelMetadata{
|
||||
"gpt-4o": {ID: "gpt-4o", Name: "GPT-4o"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
m := r.FindModel("gpt-4o")
|
||||
if m == nil {
|
||||
t.Fatal("expected to find gpt-4o")
|
||||
}
|
||||
if m.Name != "GPT-4o" {
|
||||
t.Fatalf("expected GPT-4o, got %s", m.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRegistry_FindModel_Fuzzy(t *testing.T) {
|
||||
r := &ModelRegistry{
|
||||
Providers: map[string]ProviderInfo{
|
||||
"openai": {
|
||||
Models: map[string]ModelMetadata{
|
||||
"gpt-4o": {ID: "gpt-4o", Name: "GPT-4o"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
// Fuzzy: "gpt-4o-2024-05-13" should match "gpt-4o"
|
||||
m := r.FindModel("gpt-4o-2024-05-13")
|
||||
if m == nil {
|
||||
t.Fatal("expected fuzzy match")
|
||||
}
|
||||
if m.Name != "GPT-4o" {
|
||||
t.Fatalf("expected GPT-4o, got %s", m.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRegistry_FindModel_NotFound(t *testing.T) {
|
||||
r := &ModelRegistry{
|
||||
Providers: map[string]ProviderInfo{
|
||||
"openai": {
|
||||
Models: map[string]ModelMetadata{
|
||||
"gpt-4o": {ID: "gpt-4o", Name: "GPT-4o"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
m := r.FindModel("nonexistent-model")
|
||||
if m != nil {
|
||||
t.Fatal("expected nil for nonexistent model")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRegistry_FindModel_CanonicalPriority(t *testing.T) {
|
||||
// Same model name in canonical (deepseek) and reseller (ollama-cloud).
|
||||
// Canonical must win so the proxy uses authoritative limits.
|
||||
r := &ModelRegistry{
|
||||
Providers: map[string]ProviderInfo{
|
||||
"ollama-cloud": {
|
||||
Models: map[string]ModelMetadata{
|
||||
"deepseek-v4-pro": {ID: "deepseek-v4-pro", Name: "DSv4 Pro (Ollama Cloud)", Limit: &ModelLimit{Context: 1048576, Output: 1048576}},
|
||||
},
|
||||
},
|
||||
"deepseek": {
|
||||
Models: map[string]ModelMetadata{
|
||||
"deepseek-v4-pro": {ID: "deepseek-v4-pro", Name: "DeepSeek v4 Pro", Limit: &ModelLimit{Context: 1000000, Output: 384000}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
m := r.FindModel("deepseek-v4-pro")
|
||||
if m == nil {
|
||||
t.Fatal("expected to find deepseek-v4-pro")
|
||||
}
|
||||
if m.Name != "DeepSeek v4 Pro" {
|
||||
t.Fatalf("expected DeepSeek v4 Pro (canonical), got %s", m.Name)
|
||||
}
|
||||
if m.Limit.Output != 384000 {
|
||||
t.Fatalf("expected output limit 384000 (canonical), got %d", m.Limit.Output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRegistry_FindModel_ReverseFuzzy(t *testing.T) {
|
||||
r := &ModelRegistry{
|
||||
Providers: map[string]ProviderInfo{
|
||||
"openai": {
|
||||
Models: map[string]ModelMetadata{
|
||||
"gpt-5.4-mini-2026-04-01": {ID: "gpt-5.4-mini-2026-04-01", Name: "GPT-5.4 Mini"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
// Reverse fuzzy: "gpt-5.4-mini" should match "gpt-5.4-mini-2026-04-01"
|
||||
m := r.FindModel("gpt-5.4-mini")
|
||||
if m == nil {
|
||||
t.Fatal("expected reverse fuzzy match")
|
||||
}
|
||||
if m.Name != "GPT-5.4 Mini" {
|
||||
t.Fatalf("expected GPT-5.4 Mini, got %s", m.Name)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
package models
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// Responses API request types
|
||||
|
||||
// ResponsesRequest maps to POST /v1/responses body (OpenAI Responses API format).
|
||||
// The `input` field can be a string or an array of message objects.
|
||||
type ResponsesRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input json.RawMessage `json:"input"` // string or []ResponseInputMessage
|
||||
Instructions string `json:"instructions,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
MaxOutputTokens *uint32 `json:"max_output_tokens,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Store *bool `json:"store,omitempty"`
|
||||
}
|
||||
|
||||
// ResponseInputMessage represents a single message in the input array.
|
||||
type ResponseInputMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"` // string or []ContentPart
|
||||
}
|
||||
|
||||
// Responses API response types
|
||||
|
||||
// ResponsesResponse maps to OpenAI /v1/responses response.
|
||||
type ResponsesResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Model string `json:"model"`
|
||||
Output []ResponsesOutputItem `json:"output"`
|
||||
Usage *ResponsesUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesOutputItem represents an item in the output array.
|
||||
// For messages: type="message", role, content[].
|
||||
// For function calls: type="function_call", id, name, arguments, status.
|
||||
type ResponsesOutputItem struct {
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content []ResponsesOutputContent `json:"content,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesOutputContent represents content parts within an output message.
|
||||
type ResponsesOutputContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Annotations []json.RawMessage `json:"annotations,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesUsage maps to the usage block in Responses API.
|
||||
type ResponsesUsage struct {
|
||||
InputTokens uint32 `json:"input_tokens"`
|
||||
OutputTokens uint32 `json:"output_tokens"`
|
||||
TotalTokens uint32 `json:"total_tokens"`
|
||||
InputTokensDetails *ResponsesInputTokensDetails `json:"input_tokens_details,omitempty"`
|
||||
OutputTokensDetails *ResponsesOutputTokensDetails `json:"output_tokens_details,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesInputTokensDetails maps input token details.
|
||||
type ResponsesInputTokensDetails struct {
|
||||
CachedTokens uint32 `json:"cached_tokens"`
|
||||
}
|
||||
|
||||
// ResponsesOutputTokensDetails maps output token details.
|
||||
type ResponsesOutputTokensDetails struct {
|
||||
ReasoningTokens uint32 `json:"reasoning_tokens"`
|
||||
}
|
||||
|
||||
// ToUsage converts ResponsesUsage to the unified Usage model.
|
||||
func (u *ResponsesUsage) ToUsage() *Usage {
|
||||
usage := &Usage{
|
||||
PromptTokens: u.InputTokens,
|
||||
CompletionTokens: u.OutputTokens,
|
||||
TotalTokens: u.TotalTokens,
|
||||
}
|
||||
if u.InputTokensDetails != nil && u.InputTokensDetails.CachedTokens > 0 {
|
||||
usage.CacheReadTokens = &u.InputTokensDetails.CachedTokens
|
||||
}
|
||||
if u.OutputTokensDetails != nil && u.OutputTokensDetails.ReasoningTokens > 0 {
|
||||
usage.ReasoningTokens = &u.OutputTokensDetails.ReasoningTokens
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
// ResponsesStreamChunk represents an SSE chunk from the Responses streaming endpoint.
|
||||
type ResponsesStreamChunk struct {
|
||||
Type string `json:"type"`
|
||||
Response *ResponsesStreamPayload `json:"response,omitempty"`
|
||||
Item *ResponsesStreamPayloadItem `json:"item,omitempty"`
|
||||
Delta *ResponsesStreamDelta `json:"delta,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesStreamPayload represents the "response" field in some SSE chunks.
|
||||
type ResponsesStreamPayload struct {
|
||||
Object string `json:"object"`
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Usage *ResponsesUsage `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesStreamPayloadItem represents the "item" field in SSE chunks.
|
||||
type ResponsesStreamPayloadItem struct {
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content []ResponsesOutputContent `json:"content,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
// ResponsesStreamDelta represents a content delta in streaming.
|
||||
type ResponsesStreamDelta struct {
|
||||
ContentIndex int `json:"content_index"`
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
// UnifiedResponsesRequest is the internal unified format for Responses API.
|
||||
type UnifiedResponsesRequest struct {
|
||||
ClientID string
|
||||
Model string
|
||||
Input string // normalized input text
|
||||
InputMessages []ResponseInputMessage // structured input messages (if provided as array)
|
||||
Instructions string
|
||||
Temperature *float64
|
||||
MaxOutputTokens *uint32
|
||||
TopP *float64
|
||||
Stream bool
|
||||
Tools json.RawMessage
|
||||
ToolChoice json.RawMessage
|
||||
Store bool
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/sony/gobreaker"
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
type CircuitBreakerProvider struct {
|
||||
provider Provider
|
||||
cb *gobreaker.CircuitBreaker
|
||||
}
|
||||
|
||||
func NewCircuitBreakerProvider(p Provider) Provider {
|
||||
name := p.Name()
|
||||
var maxRequests uint32 = 5
|
||||
var interval = 60 * time.Second
|
||||
var timeout = 5 * time.Minute
|
||||
|
||||
settings := gobreaker.Settings{
|
||||
Name: name,
|
||||
MaxRequests: maxRequests,
|
||||
Interval: interval,
|
||||
Timeout: timeout,
|
||||
ReadyToTrip: func(counts gobreaker.Counts) bool {
|
||||
// Trip after 3 consecutive failures
|
||||
return counts.ConsecutiveFailures > 3
|
||||
},
|
||||
}
|
||||
return &CircuitBreakerProvider{
|
||||
provider: p,
|
||||
cb: gobreaker.NewCircuitBreaker(settings),
|
||||
}
|
||||
}
|
||||
|
||||
func (cbp *CircuitBreakerProvider) Name() string {
|
||||
return cbp.provider.Name()
|
||||
}
|
||||
|
||||
func (cbp *CircuitBreakerProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
result, err := cbp.cb.Execute(func() (interface{}, error) {
|
||||
return cbp.provider.ChatCompletion(ctx, req)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result.(*models.ChatCompletionResponse), nil
|
||||
}
|
||||
|
||||
func (cbp *CircuitBreakerProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
// Circuit breaker for streaming is tricky. We'll just call the provider directly.
|
||||
// Future: Implement a way to track stream failures in the circuit breaker.
|
||||
return cbp.provider.ChatCompletionStream(ctx, req)
|
||||
}
|
||||
|
||||
func (cbp *CircuitBreakerProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
|
||||
result, err := cbp.cb.Execute(func() (interface{}, error) {
|
||||
return cbp.provider.ImageGeneration(ctx, req)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result.(*models.ImageGenerationResponse), nil
|
||||
}
|
||||
|
||||
func (cbp *CircuitBreakerProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
|
||||
result, err := cbp.cb.Execute(func() (interface{}, error) {
|
||||
return cbp.provider.Responses(ctx, req)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result.(*models.ResponsesResponse), nil
|
||||
}
|
||||
|
||||
func (cbp *CircuitBreakerProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
|
||||
// Circuit breaker passthrough for streaming (same pattern as ChatCompletionStream)
|
||||
return cbp.provider.ResponsesStream(ctx, req)
|
||||
}
|
||||
@@ -7,10 +7,11 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/models"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
type DeepSeekProvider struct {
|
||||
@@ -21,7 +22,7 @@ type DeepSeekProvider struct {
|
||||
|
||||
func NewDeepSeekProvider(cfg config.DeepSeekConfig, apiKey string) *DeepSeekProvider {
|
||||
return &DeepSeekProvider{
|
||||
client: resty.New(),
|
||||
client: resty.New().SetTimeout(10 * time.Minute),
|
||||
config: cfg,
|
||||
apiKey: apiKey,
|
||||
}
|
||||
@@ -32,11 +33,11 @@ func (p *DeepSeekProvider) Name() string {
|
||||
}
|
||||
|
||||
type deepSeekUsage struct {
|
||||
PromptTokens uint32 `json:"prompt_tokens"`
|
||||
CompletionTokens uint32 `json:"completion_tokens"`
|
||||
TotalTokens uint32 `json:"total_tokens"`
|
||||
PromptCacheHitTokens uint32 `json:"prompt_cache_hit_tokens"`
|
||||
PromptCacheMissTokens uint32 `json:"prompt_cache_miss_tokens"`
|
||||
PromptTokens uint32 `json:"prompt_tokens"`
|
||||
CompletionTokens uint32 `json:"completion_tokens"`
|
||||
TotalTokens uint32 `json:"total_tokens"`
|
||||
PromptCacheHitTokens uint32 `json:"prompt_cache_hit_tokens"`
|
||||
PromptCacheMissTokens uint32 `json:"prompt_cache_miss_tokens"`
|
||||
CompletionTokensDetails *struct {
|
||||
ReasoningTokens uint32 `json:"reasoning_tokens"`
|
||||
} `json:"completion_tokens_details"`
|
||||
@@ -61,6 +62,9 @@ func (u *deepSeekUsage) ToUnified() *models.Usage {
|
||||
}
|
||||
|
||||
func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
// Ensure English responses — DeepSeek defaults to Chinese for some prompts
|
||||
ensureEnglish(req)
|
||||
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
@@ -68,19 +72,26 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, false)
|
||||
|
||||
// Sanitize for deepseek-reasoner
|
||||
if req.Model == "deepseek-reasoner" {
|
||||
delete(body, "temperature")
|
||||
delete(body, "top_p")
|
||||
delete(body, "presence_penalty")
|
||||
delete(body, "frequency_penalty")
|
||||
|
||||
// Sanitize for models that support reasoning/thinking mode
|
||||
isReasoner := strings.Contains(req.Model, "reasoner") || strings.Contains(req.Model, "v4") || strings.Contains(req.Model, "r1")
|
||||
|
||||
if isReasoner {
|
||||
// deepseek-reasoner (R1) does not support these parameters
|
||||
if req.Model == "deepseek-reasoner" || strings.HasPrefix(req.Model, "deepseek-r1") {
|
||||
delete(body, "temperature")
|
||||
delete(body, "top_p")
|
||||
delete(body, "presence_penalty")
|
||||
delete(body, "frequency_penalty")
|
||||
}
|
||||
|
||||
if msgs, ok := body["messages"].([]interface{}); ok {
|
||||
for _, m := range msgs {
|
||||
if msg, ok := m.(map[string]interface{}); ok {
|
||||
if msg["role"] == "assistant" {
|
||||
// DeepSeek requires reasoning_content to be passed back in history
|
||||
// if the model is in thinking mode.
|
||||
if msg["reasoning_content"] == nil {
|
||||
msg["reasoning_content"] = " "
|
||||
msg["reasoning_content"] = ""
|
||||
}
|
||||
if msg["content"] == nil || msg["content"] == "" {
|
||||
msg["content"] = ""
|
||||
@@ -102,7 +113,15 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
var msg string
|
||||
if resp.RawBody() != nil {
|
||||
bodyBytes, _ := io.ReadAll(resp.RawBody())
|
||||
msg = string(bodyBytes)
|
||||
}
|
||||
if msg == "" {
|
||||
msg = resp.String()
|
||||
}
|
||||
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
@@ -128,6 +147,8 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi
|
||||
}
|
||||
|
||||
func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
ensureEnglish(req)
|
||||
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
@@ -135,19 +156,26 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, true)
|
||||
|
||||
// Sanitize for deepseek-reasoner
|
||||
if req.Model == "deepseek-reasoner" {
|
||||
delete(body, "temperature")
|
||||
delete(body, "top_p")
|
||||
delete(body, "presence_penalty")
|
||||
delete(body, "frequency_penalty")
|
||||
|
||||
// Sanitize for models that support reasoning/thinking mode
|
||||
isReasoner := strings.Contains(req.Model, "reasoner") || strings.Contains(req.Model, "v4") || strings.Contains(req.Model, "r1")
|
||||
|
||||
if isReasoner {
|
||||
// deepseek-reasoner (R1) does not support these parameters
|
||||
if req.Model == "deepseek-reasoner" || strings.HasPrefix(req.Model, "deepseek-r1") {
|
||||
delete(body, "temperature")
|
||||
delete(body, "top_p")
|
||||
delete(body, "presence_penalty")
|
||||
delete(body, "frequency_penalty")
|
||||
}
|
||||
|
||||
if msgs, ok := body["messages"].([]interface{}); ok {
|
||||
for _, m := range msgs {
|
||||
if msg, ok := m.(map[string]interface{}); ok {
|
||||
if msg["role"] == "assistant" {
|
||||
// DeepSeek requires reasoning_content to be passed back in history
|
||||
// if the model is in thinking mode.
|
||||
if msg["reasoning_content"] == nil {
|
||||
msg["reasoning_content"] = " "
|
||||
msg["reasoning_content"] = ""
|
||||
}
|
||||
if msg["content"] == nil || msg["content"] == "" {
|
||||
msg["content"] = ""
|
||||
@@ -170,11 +198,19 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
var msg string
|
||||
if resp.RawBody() != nil {
|
||||
bodyBytes, _ := io.ReadAll(resp.RawBody())
|
||||
msg = string(bodyBytes)
|
||||
}
|
||||
if msg == "" {
|
||||
msg = resp.String()
|
||||
}
|
||||
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
// Custom scanner loop to handle DeepSeek specific usage in chunks
|
||||
@@ -218,3 +254,15 @@ func StreamDeepSeek(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRes
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func (p *DeepSeekProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
|
||||
return nil, fmt.Errorf("deepseek does not support image generation")
|
||||
}
|
||||
|
||||
func (p *DeepSeekProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by deepseek")
|
||||
}
|
||||
|
||||
func (p *DeepSeekProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by deepseek")
|
||||
}
|
||||
|
||||
+485
-95
@@ -4,10 +4,13 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/models"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
type GeminiProvider struct {
|
||||
@@ -18,7 +21,7 @@ type GeminiProvider struct {
|
||||
|
||||
func NewGeminiProvider(cfg config.GeminiConfig, apiKey string) *GeminiProvider {
|
||||
return &GeminiProvider{
|
||||
client: resty.New(),
|
||||
client: resty.New().SetTimeout(10 * time.Minute),
|
||||
config: cfg,
|
||||
apiKey: apiKey,
|
||||
}
|
||||
@@ -29,7 +32,21 @@ func (p *GeminiProvider) Name() string {
|
||||
}
|
||||
|
||||
type GeminiRequest struct {
|
||||
Contents []GeminiContent `json:"contents"`
|
||||
Contents []GeminiContent `json:"contents"`
|
||||
Tools []GeminiTool `json:"tools,omitempty"`
|
||||
GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiTool struct {
|
||||
FunctionDeclarations []models.FunctionDef `json:"functionDeclarations"`
|
||||
}
|
||||
|
||||
type GeminiGenerationConfig struct {
|
||||
Temperature *float32 `json:"temperature,omitempty"`
|
||||
TopP *float32 `json:"topP,omitempty"`
|
||||
TopK *int `json:"topK,omitempty"`
|
||||
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiContent struct {
|
||||
@@ -38,10 +55,10 @@ type GeminiContent struct {
|
||||
}
|
||||
|
||||
type GeminiPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiInlineData struct {
|
||||
@@ -59,77 +76,293 @@ type GeminiFunctionResponse struct {
|
||||
Response json.RawMessage `json:"response"`
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
|
||||
// Gemini Imagen API: POST https://generativelanguage.googleapis.com/v1beta/models/{model}:predict
|
||||
// Map OpenAI-style params to Gemini Imagen params
|
||||
|
||||
n := uint32(1)
|
||||
if req.N != nil && *req.N > 0 {
|
||||
n = *req.N
|
||||
}
|
||||
|
||||
aspectRatio := "1:1"
|
||||
if req.Size != nil {
|
||||
aspectRatio = sizeToGeminiAspectRatio(*req.Size)
|
||||
}
|
||||
|
||||
// Build Imagen request
|
||||
imagenReq := map[string]interface{}{
|
||||
"instances": []map[string]interface{}{
|
||||
{"prompt": req.Prompt},
|
||||
},
|
||||
"parameters": map[string]interface{}{
|
||||
"sampleCount": n,
|
||||
"aspectRatio": aspectRatio,
|
||||
},
|
||||
}
|
||||
|
||||
// Model defaults to imagen-3.0-generate-001 if empty
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = "imagen-3.0-generate-001"
|
||||
}
|
||||
|
||||
// Use v1beta for Imagen
|
||||
baseURL := p.config.BaseURL
|
||||
if !strings.Contains(baseURL, "v1beta") {
|
||||
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/models/%s:predict?key=%s", baseURL, model, p.apiKey)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(imagenReq).
|
||||
Post(url)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gemini imagen request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Gemini Imagen API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
// Parse Imagen response
|
||||
var imagenResp struct {
|
||||
Predictions []struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||
} `json:"predictions"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(resp.Body(), &imagenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse Imagen response: %w", err)
|
||||
}
|
||||
|
||||
respFormat := "url"
|
||||
if req.ResponseFormat != nil && *req.ResponseFormat == "b64_json" {
|
||||
respFormat = "b64_json"
|
||||
}
|
||||
|
||||
var data []models.ImageData
|
||||
for _, pred := range imagenResp.Predictions {
|
||||
imgData := models.ImageData{}
|
||||
if respFormat == "b64_json" {
|
||||
imgData.B64JSON = pred.BytesBase64Encoded
|
||||
} else {
|
||||
// Build a data URI since Gemini returns base64, not a URL
|
||||
mime := pred.MimeType
|
||||
if mime == "" {
|
||||
mime = "image/png"
|
||||
}
|
||||
imgData.URL = fmt.Sprintf("data:%s;base64,%s", mime, pred.BytesBase64Encoded)
|
||||
}
|
||||
data = append(data, imgData)
|
||||
}
|
||||
|
||||
result := &models.ImageGenerationResponse{
|
||||
Created: time.Now().Unix(),
|
||||
Data: data,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// sizeToGeminiAspectRatio converts OpenAI size format (e.g. "1024x1024") to Gemini aspect ratio (e.g. "1:1")
|
||||
func sizeToGeminiAspectRatio(size string) string {
|
||||
switch size {
|
||||
case "1024x1024":
|
||||
return "1:1"
|
||||
case "1024x1792":
|
||||
return "9:16"
|
||||
case "1792x1024":
|
||||
return "16:9"
|
||||
case "256x256", "512x512":
|
||||
return "1:1"
|
||||
default:
|
||||
return "1:1"
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by gemini")
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by gemini")
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
// Gemini mapping
|
||||
var contents []GeminiContent
|
||||
for _, msg := range req.Messages {
|
||||
|
||||
for i := 0; i < len(req.Messages); i++ {
|
||||
msg := req.Messages[i]
|
||||
|
||||
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
||||
// 1. Add the assistant (model) message with tool calls
|
||||
parts := []GeminiPart{}
|
||||
for _, cp := range msg.Content {
|
||||
if cp.Type == "text" && cp.Text != "" {
|
||||
parts = append(parts, GeminiPart{Text: cp.Text})
|
||||
}
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
parts = append(parts, GeminiPart{
|
||||
FunctionCall: &GeminiFunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Args: json.RawMessage(tc.Function.Arguments),
|
||||
},
|
||||
})
|
||||
}
|
||||
contents = append(contents, GeminiContent{Role: "model", Parts: parts})
|
||||
|
||||
// 2. The VERY NEXT message MUST be the "function" results for THESE EXACT calls.
|
||||
// Look ahead for tool messages.
|
||||
var functionParts []GeminiPart
|
||||
toolCallIDs := make(map[string]bool)
|
||||
for _, tc := range msg.ToolCalls {
|
||||
toolCallIDs[tc.ID] = true
|
||||
}
|
||||
|
||||
// We need to find tool messages that correspond to these calls.
|
||||
// In many patterns, they follow immediately.
|
||||
j := i + 1
|
||||
foundAny := false
|
||||
for j < len(req.Messages) && req.Messages[j].Role == "tool" {
|
||||
m := req.Messages[j]
|
||||
|
||||
// Try to match by ID or just take them in order if IDs are missing/mismatched
|
||||
// Gemini is strict: you must respond to EVERY call in the previous message.
|
||||
text := ""
|
||||
if len(m.Content) > 0 {
|
||||
text = m.Content[0].Text
|
||||
}
|
||||
name := "unknown_function"
|
||||
if m.Name != nil {
|
||||
name = *m.Name
|
||||
}
|
||||
|
||||
var responseObj interface{}
|
||||
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
|
||||
responseObj = map[string]interface{}{"result": text}
|
||||
}
|
||||
respBytes, _ := json.Marshal(responseObj)
|
||||
|
||||
functionParts = append(functionParts, GeminiPart{
|
||||
FunctionResponse: &GeminiFunctionResponse{
|
||||
Name: name,
|
||||
Response: json.RawMessage(respBytes),
|
||||
},
|
||||
})
|
||||
foundAny = true
|
||||
j++
|
||||
}
|
||||
|
||||
if foundAny {
|
||||
contents = append(contents, GeminiContent{Role: "function", Parts: functionParts})
|
||||
i = j - 1 // Advance outer loop past the tool messages we consumed
|
||||
} else {
|
||||
// If no tool results found but assistant made calls, Gemini WILL error.
|
||||
// We should probably skip the calls or provide dummy results,
|
||||
// but usually this means the conversation is incomplete.
|
||||
// For now, don't add a "function" message if none found.
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Standard message handling (System/User/Assistant without tools)
|
||||
role := "user"
|
||||
if msg.Role == "assistant" {
|
||||
role = "model"
|
||||
} else if msg.Role == "system" {
|
||||
role = "user" // Gemini uses 'user' for system prompts in some versions, or handles it via systemInstruction
|
||||
} else if msg.Role == "tool" {
|
||||
role = "user" // Tool results are user-side in Gemini
|
||||
// Orphaned tool message (not following an assistant call) - Gemini doesn't like this.
|
||||
// Skip or map to user? Skipping is safer for API stability.
|
||||
continue
|
||||
}
|
||||
|
||||
var parts []GeminiPart
|
||||
|
||||
// Handle tool responses
|
||||
if msg.Role == "tool" {
|
||||
text := ""
|
||||
if len(msg.Content) > 0 {
|
||||
text = msg.Content[0].Text
|
||||
}
|
||||
|
||||
// Gemini expects functionResponse to be an object
|
||||
name := "unknown_function"
|
||||
if msg.Name != nil {
|
||||
name = *msg.Name
|
||||
}
|
||||
|
||||
parts = append(parts, GeminiPart{
|
||||
FunctionResponse: &GeminiFunctionResponse{
|
||||
Name: name,
|
||||
Response: json.RawMessage(text),
|
||||
},
|
||||
})
|
||||
} else {
|
||||
for _, cp := range msg.Content {
|
||||
if cp.Type == "text" {
|
||||
parts = append(parts, GeminiPart{Text: cp.Text})
|
||||
} else if cp.Image != nil {
|
||||
base64Data, mimeType, _ := cp.Image.ToBase64()
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: base64Data,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Handle assistant tool calls
|
||||
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
parts = append(parts, GeminiPart{
|
||||
FunctionCall: &GeminiFunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Args: json.RawMessage(tc.Function.Arguments),
|
||||
},
|
||||
})
|
||||
}
|
||||
var parts []GeminiPart
|
||||
for _, cp := range msg.Content {
|
||||
if cp.Type == "text" && cp.Text != "" {
|
||||
parts = append(parts, GeminiPart{Text: cp.Text})
|
||||
} else if cp.Image != nil {
|
||||
base64Data, mimeType, _ := cp.Image.ToBase64()
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: base64Data,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
contents = append(contents, GeminiContent{
|
||||
Role: role,
|
||||
Parts: parts,
|
||||
})
|
||||
|
||||
if len(parts) > 0 {
|
||||
contents = append(contents, GeminiContent{Role: role, Parts: parts})
|
||||
}
|
||||
}
|
||||
|
||||
genConfig := &GeminiGenerationConfig{}
|
||||
if req.Temperature != nil {
|
||||
t := float32(*req.Temperature)
|
||||
genConfig.Temperature = &t
|
||||
}
|
||||
if req.TopP != nil {
|
||||
tp := float32(*req.TopP)
|
||||
genConfig.TopP = &tp
|
||||
}
|
||||
if req.TopK != nil {
|
||||
tk := int(*req.TopK)
|
||||
genConfig.TopK = &tk
|
||||
}
|
||||
if req.MaxTokens != nil {
|
||||
mt := int(*req.MaxTokens)
|
||||
genConfig.MaxOutputTokens = &mt
|
||||
}
|
||||
if len(req.Stop) > 0 {
|
||||
genConfig.StopSequences = req.Stop
|
||||
}
|
||||
|
||||
body := GeminiRequest{
|
||||
Contents: contents,
|
||||
Contents: contents,
|
||||
GenerationConfig: genConfig,
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)
|
||||
// Map Tools
|
||||
hasMappedTools := false
|
||||
if len(req.Tools) > 0 {
|
||||
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
|
||||
for _, t := range req.Tools {
|
||||
if t.Type == "function" {
|
||||
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
|
||||
}
|
||||
}
|
||||
if len(geminiTool.FunctionDeclarations) > 0 {
|
||||
body.Tools = []GeminiTool{geminiTool}
|
||||
hasMappedTools = true
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := p.config.BaseURL
|
||||
lowerModel := strings.ToLower(req.Model)
|
||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") || hasMappedTools {
|
||||
// Use v1beta for preview, newer models, or when using tools
|
||||
if !strings.Contains(baseURL, "v1beta") {
|
||||
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
|
||||
}
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", baseURL, req.Model, p.apiKey)
|
||||
fmt.Printf("[Gemini] POST %s\n", url)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
@@ -141,23 +374,36 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
fmt.Printf("[Gemini] API Error %d: %s\n", resp.StatusCode(), msg)
|
||||
// Also log the request body for debugging (careful with API keys if logged elsewhere)
|
||||
reqJSON, _ := json.Marshal(body)
|
||||
fmt.Printf("[Gemini] Request Body: %s\n", string(reqJSON))
|
||||
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
// Parse Gemini response and convert to OpenAI format
|
||||
var geminiResp struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Role string `json:"role"`
|
||||
Parts []struct {
|
||||
Text string `json:"text"`
|
||||
Text string `json:"text"`
|
||||
FunctionCall *GeminiFunctionCall `json:"functionCall"`
|
||||
} `json:"parts"`
|
||||
} `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
} `json:"candidates"`
|
||||
UsageMetadata struct {
|
||||
PromptTokenCount uint32 `json:"promptTokenCount"`
|
||||
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
|
||||
TotalTokenCount uint32 `json:"totalTokenCount"`
|
||||
PromptTokenCount uint32 `json:"promptTokenCount"`
|
||||
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
|
||||
TotalTokenCount uint32 `json:"totalTokenCount"`
|
||||
CachedContentTokenCount uint32 `json:"cachedContentTokenCount"`
|
||||
} `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
@@ -170,29 +416,51 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
}
|
||||
|
||||
content := ""
|
||||
for _, p := range geminiResp.Candidates[0].Content.Parts {
|
||||
content += p.Text
|
||||
var toolCalls []models.ToolCall
|
||||
for _, part := range geminiResp.Candidates[0].Content.Parts {
|
||||
if part.Text != "" {
|
||||
content += part.Text
|
||||
}
|
||||
if part.FunctionCall != nil {
|
||||
toolCalls = append(toolCalls, models.ToolCall{
|
||||
ID: fmt.Sprintf("call_%s", part.FunctionCall.Name), // Gemini doesn't have call IDs
|
||||
Type: "function",
|
||||
Function: models.FunctionCall{
|
||||
Name: part.FunctionCall.Name,
|
||||
Arguments: string(part.FunctionCall.Args),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
finishReason := strings.ToLower(geminiResp.Candidates[0].FinishReason)
|
||||
if finishReason == "stop" {
|
||||
finishReason = "stop"
|
||||
} else if len(toolCalls) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
}
|
||||
|
||||
openAIResp := &models.ChatCompletionResponse{
|
||||
ID: "gemini-" + req.Model,
|
||||
Object: "chat.completion",
|
||||
Created: 0, // Should be current timestamp
|
||||
Created: 0,
|
||||
Model: req.Model,
|
||||
Choices: []models.ChatChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: models.ChatMessage{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
},
|
||||
FinishReason: &geminiResp.Candidates[0].FinishReason,
|
||||
FinishReason: &finishReason,
|
||||
},
|
||||
},
|
||||
Usage: &models.Usage{
|
||||
PromptTokens: geminiResp.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResp.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: geminiResp.UsageMetadata.TotalTokenCount,
|
||||
CacheReadTokens: uint32Ptr(geminiResp.UsageMetadata.CachedContentTokenCount),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -202,29 +470,144 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
// Simplified Gemini mapping
|
||||
var contents []GeminiContent
|
||||
for _, msg := range req.Messages {
|
||||
for i := 0; i < len(req.Messages); i++ {
|
||||
msg := req.Messages[i]
|
||||
|
||||
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
||||
parts := []GeminiPart{}
|
||||
for _, cp := range msg.Content {
|
||||
if cp.Type == "text" && cp.Text != "" {
|
||||
parts = append(parts, GeminiPart{Text: cp.Text})
|
||||
}
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
parts = append(parts, GeminiPart{
|
||||
FunctionCall: &GeminiFunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Args: json.RawMessage(tc.Function.Arguments),
|
||||
},
|
||||
})
|
||||
}
|
||||
contents = append(contents, GeminiContent{Role: "model", Parts: parts})
|
||||
|
||||
var functionParts []GeminiPart
|
||||
j := i + 1
|
||||
foundAny := false
|
||||
for j < len(req.Messages) && req.Messages[j].Role == "tool" {
|
||||
m := req.Messages[j]
|
||||
text := ""
|
||||
if len(m.Content) > 0 {
|
||||
text = m.Content[0].Text
|
||||
}
|
||||
name := "unknown_function"
|
||||
if m.Name != nil {
|
||||
name = *m.Name
|
||||
}
|
||||
|
||||
var responseObj interface{}
|
||||
if err := json.Unmarshal([]byte(text), &responseObj); err != nil {
|
||||
responseObj = map[string]interface{}{"result": text}
|
||||
}
|
||||
respBytes, _ := json.Marshal(responseObj)
|
||||
|
||||
functionParts = append(functionParts, GeminiPart{
|
||||
FunctionResponse: &GeminiFunctionResponse{
|
||||
Name: name,
|
||||
Response: json.RawMessage(respBytes),
|
||||
},
|
||||
})
|
||||
foundAny = true
|
||||
j++
|
||||
}
|
||||
|
||||
if foundAny {
|
||||
contents = append(contents, GeminiContent{Role: "function", Parts: functionParts})
|
||||
i = j - 1
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
role := "user"
|
||||
if msg.Role == "assistant" {
|
||||
role = "model"
|
||||
} else if msg.Role == "system" {
|
||||
role = "user"
|
||||
} else if msg.Role == "tool" {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
var parts []GeminiPart
|
||||
for _, p := range msg.Content {
|
||||
parts = append(parts, GeminiPart{Text: p.Text})
|
||||
for _, cp := range msg.Content {
|
||||
if cp.Type == "text" && cp.Text != "" {
|
||||
parts = append(parts, GeminiPart{Text: cp.Text})
|
||||
} else if cp.Image != nil {
|
||||
base64Data, mimeType, _ := cp.Image.ToBase64()
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: base64Data,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
contents = append(contents, GeminiContent{
|
||||
Role: role,
|
||||
Parts: parts,
|
||||
})
|
||||
|
||||
if len(parts) > 0 {
|
||||
contents = append(contents, GeminiContent{Role: role, Parts: parts})
|
||||
}
|
||||
}
|
||||
|
||||
genConfig := &GeminiGenerationConfig{}
|
||||
if req.Temperature != nil {
|
||||
t := float32(*req.Temperature)
|
||||
genConfig.Temperature = &t
|
||||
}
|
||||
if req.TopP != nil {
|
||||
tp := float32(*req.TopP)
|
||||
genConfig.TopP = &tp
|
||||
}
|
||||
if req.TopK != nil {
|
||||
tk := int(*req.TopK)
|
||||
genConfig.TopK = &tk
|
||||
}
|
||||
if req.MaxTokens != nil {
|
||||
mt := int(*req.MaxTokens)
|
||||
genConfig.MaxOutputTokens = &mt
|
||||
}
|
||||
if len(req.Stop) > 0 {
|
||||
genConfig.StopSequences = req.Stop
|
||||
}
|
||||
|
||||
body := GeminiRequest{
|
||||
Contents: contents,
|
||||
Contents: contents,
|
||||
GenerationConfig: genConfig,
|
||||
}
|
||||
|
||||
hasMappedTools := false
|
||||
if len(req.Tools) > 0 {
|
||||
geminiTool := GeminiTool{FunctionDeclarations: []models.FunctionDef{}}
|
||||
for _, t := range req.Tools {
|
||||
if t.Type == "function" {
|
||||
geminiTool.FunctionDeclarations = append(geminiTool.FunctionDeclarations, t.Function)
|
||||
}
|
||||
}
|
||||
if len(geminiTool.FunctionDeclarations) > 0 {
|
||||
body.Tools = []GeminiTool{geminiTool}
|
||||
hasMappedTools = true
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := p.config.BaseURL
|
||||
lowerModel := strings.ToLower(req.Model)
|
||||
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") || hasMappedTools {
|
||||
// Use v1beta for preview, newer models, or when using tools
|
||||
if !strings.Contains(baseURL, "v1beta") {
|
||||
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
|
||||
}
|
||||
}
|
||||
|
||||
// Use streamGenerateContent for streaming
|
||||
url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)
|
||||
url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s", baseURL, req.Model, p.apiKey)
|
||||
fmt.Printf("[Gemini-Stream] POST %s\n", url)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
@@ -237,18 +620,25 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamGemini(resp.RawBody(), ch, req.Model)
|
||||
if err != nil {
|
||||
fmt.Printf("Gemini Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ch, err := StreamGemini(resp.RawBody(), req.Model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gemini stream init error: %w", err)
|
||||
}
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func uint32Ptr(v uint32) *uint32 {
|
||||
if v > 0 {
|
||||
return &v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,10 +4,12 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/models"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
type GrokProvider struct {
|
||||
@@ -18,7 +20,7 @@ type GrokProvider struct {
|
||||
|
||||
func NewGrokProvider(cfg config.GrokConfig, apiKey string) *GrokProvider {
|
||||
return &GrokProvider{
|
||||
client: resty.New(),
|
||||
client: resty.New().SetTimeout(10 * time.Minute),
|
||||
config: cfg,
|
||||
apiKey: apiKey,
|
||||
}
|
||||
@@ -47,7 +49,13 @@ func (p *GrokProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRe
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
@@ -78,11 +86,17 @@ func (p *GrokProvider) ChatCompletionStream(ctx context.Context, req *models.Uni
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamOpenAI(resp.RawBody(), ch)
|
||||
@@ -93,3 +107,15 @@ func (p *GrokProvider) ChatCompletionStream(ctx context.Context, req *models.Uni
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (p *GrokProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
|
||||
return nil, fmt.Errorf("grok does not support image generation")
|
||||
}
|
||||
|
||||
func (p *GrokProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by grok")
|
||||
}
|
||||
|
||||
func (p *GrokProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by grok")
|
||||
}
|
||||
|
||||
+390
-98
@@ -10,11 +10,32 @@ import (
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
func sanitizeFunctionName(name string) string {
|
||||
var sb strings.Builder
|
||||
for _, ch := range name {
|
||||
if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' || ch == '-' {
|
||||
sb.WriteRune(ch)
|
||||
} else {
|
||||
sb.WriteRune('_')
|
||||
}
|
||||
}
|
||||
res := sb.String()
|
||||
if res == "" {
|
||||
return "function"
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON, including tools and images.
|
||||
func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) {
|
||||
var result []interface{}
|
||||
for _, m := range messages {
|
||||
if m.Role == "tool" {
|
||||
role := strings.ToLower(m.Role)
|
||||
if role == "model" {
|
||||
role = "assistant"
|
||||
}
|
||||
|
||||
if role == "tool" || role == "function" {
|
||||
text := ""
|
||||
if len(m.Content) > 0 {
|
||||
text = m.Content[0].Text
|
||||
@@ -23,15 +44,14 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
|
||||
"role": "tool",
|
||||
"content": text,
|
||||
}
|
||||
id := "unknown"
|
||||
if m.ToolCallID != nil {
|
||||
id := *m.ToolCallID
|
||||
if len(id) > 40 {
|
||||
id = id[:40]
|
||||
}
|
||||
msg["tool_call_id"] = id
|
||||
id = *m.ToolCallID
|
||||
}
|
||||
msg["tool_call_id"] = id
|
||||
|
||||
if m.Name != nil {
|
||||
msg["name"] = *m.Name
|
||||
msg["name"] = sanitizeFunctionName(*m.Name)
|
||||
}
|
||||
result = append(result, msg)
|
||||
continue
|
||||
@@ -59,7 +79,9 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
|
||||
}
|
||||
|
||||
var finalContent interface{}
|
||||
if len(parts) == 1 {
|
||||
if len(parts) == 0 {
|
||||
finalContent = nil
|
||||
} else if len(parts) == 1 {
|
||||
if p, ok := parts[0].(map[string]interface{}); ok && p["type"] == "text" {
|
||||
finalContent = p["text"]
|
||||
} else {
|
||||
@@ -70,7 +92,7 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
|
||||
}
|
||||
|
||||
msg := map[string]interface{}{
|
||||
"role": m.Role,
|
||||
"role": role,
|
||||
"content": finalContent,
|
||||
}
|
||||
|
||||
@@ -82,20 +104,18 @@ func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, erro
|
||||
sanitizedCalls := make([]models.ToolCall, len(m.ToolCalls))
|
||||
copy(sanitizedCalls, m.ToolCalls)
|
||||
for i := range sanitizedCalls {
|
||||
if len(sanitizedCalls[i].ID) > 40 {
|
||||
sanitizedCalls[i].ID = sanitizedCalls[i].ID[:40]
|
||||
if sanitizedCalls[i].Type == "" {
|
||||
sanitizedCalls[i].Type = "function"
|
||||
}
|
||||
sanitizedCalls[i].Function.Name = sanitizeFunctionName(sanitizedCalls[i].Function.Name)
|
||||
}
|
||||
msg["tool_calls"] = sanitizedCalls
|
||||
if len(parts) == 0 {
|
||||
msg["content"] = ""
|
||||
}
|
||||
msg["content"] = "" // OpenAI requirement: content must be string if tool_calls present
|
||||
}
|
||||
|
||||
if m.Name != nil {
|
||||
msg["name"] = *m.Name
|
||||
}
|
||||
|
||||
result = append(result, msg)
|
||||
}
|
||||
return result, nil
|
||||
@@ -121,11 +141,25 @@ func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{},
|
||||
body["max_tokens"] = *request.MaxTokens
|
||||
}
|
||||
if len(request.Tools) > 0 {
|
||||
body["tools"] = request.Tools
|
||||
sanitizedTools := make([]models.Tool, len(request.Tools))
|
||||
copy(sanitizedTools, request.Tools)
|
||||
for i := range sanitizedTools {
|
||||
if sanitizedTools[i].Type == "function" {
|
||||
sanitizedTools[i].Function.Name = sanitizeFunctionName(sanitizedTools[i].Function.Name)
|
||||
}
|
||||
}
|
||||
body["tools"] = sanitizedTools
|
||||
}
|
||||
if request.ToolChoice != nil {
|
||||
var toolChoice interface{}
|
||||
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
|
||||
if tcMap, ok := toolChoice.(map[string]interface{}); ok {
|
||||
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
|
||||
if name, ok := funcMap["name"].(string); ok {
|
||||
funcMap["name"] = sanitizeFunctionName(name)
|
||||
}
|
||||
}
|
||||
}
|
||||
body["tool_choice"] = toolChoice
|
||||
}
|
||||
}
|
||||
@@ -133,11 +167,138 @@ func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{},
|
||||
return body
|
||||
}
|
||||
|
||||
// BuildOpenAIResponsesBody builds the request body for the Responses API endpoint.
|
||||
func BuildOpenAIResponsesBody(req *models.ResponsesRequest, stream bool) map[string]interface{} {
|
||||
body := map[string]interface{}{
|
||||
"model": req.Model,
|
||||
"stream": stream,
|
||||
}
|
||||
|
||||
// The input field can be a string or a structured array.
|
||||
// Try to preserve the original format.
|
||||
if req.Input != nil {
|
||||
// Try as string first
|
||||
var inputStr string
|
||||
if err := json.Unmarshal(req.Input, &inputStr); err == nil {
|
||||
body["input"] = inputStr
|
||||
} else {
|
||||
// Try as array of messages
|
||||
var inputArr []interface{}
|
||||
if err := json.Unmarshal(req.Input, &inputArr); err == nil {
|
||||
body["input"] = inputArr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if req.Instructions != "" {
|
||||
body["instructions"] = req.Instructions
|
||||
}
|
||||
if req.Temperature != nil {
|
||||
body["temperature"] = *req.Temperature
|
||||
}
|
||||
if req.MaxOutputTokens != nil {
|
||||
body["max_output_tokens"] = *req.MaxOutputTokens
|
||||
}
|
||||
if req.TopP != nil {
|
||||
body["top_p"] = *req.TopP
|
||||
}
|
||||
if req.Tools != nil {
|
||||
var tools interface{}
|
||||
if err := json.Unmarshal(req.Tools, &tools); err == nil {
|
||||
body["tools"] = tools
|
||||
}
|
||||
}
|
||||
if req.ToolChoice != nil {
|
||||
var toolChoice interface{}
|
||||
if err := json.Unmarshal(req.ToolChoice, &toolChoice); err == nil {
|
||||
body["tool_choice"] = toolChoice
|
||||
}
|
||||
}
|
||||
if req.Store != nil {
|
||||
body["store"] = *req.Store
|
||||
}
|
||||
|
||||
if stream {
|
||||
body["stream_options"] = map[string]interface{}{
|
||||
"include_usage": true,
|
||||
}
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// ParseOpenAIResponsesResponse parses a raw JSON map into a ResponsesResponse.
|
||||
func ParseOpenAIResponsesResponse(respJSON map[string]interface{}, model string) (*models.ResponsesResponse, error) {
|
||||
data, err := json.Marshal(respJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resp models.ResponsesResponse
|
||||
if err := json.Unmarshal(data, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Re-parse usage with the detailed tokens
|
||||
if usageData, ok := respJSON["usage"]; ok {
|
||||
var responsesUsage models.ResponsesUsage
|
||||
usageBytes, _ := json.Marshal(usageData)
|
||||
if err := json.Unmarshal(usageBytes, &responsesUsage); err == nil {
|
||||
resp.Usage = &responsesUsage
|
||||
}
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ParseOpenAIResponsesStreamChunk parses a single SSE line into a ResponsesStreamChunk.
|
||||
// Returns the chunk, whether this is the [DONE] signal, and any error.
|
||||
func ParseOpenAIResponsesStreamChunk(line string) (*models.ResponsesStreamChunk, bool, error) {
|
||||
if line == "" {
|
||||
return nil, false, nil
|
||||
}
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
var chunk models.ResponsesStreamChunk
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
return nil, false, fmt.Errorf("failed to unmarshal responses stream chunk: %w", err)
|
||||
}
|
||||
|
||||
return &chunk, false, nil
|
||||
}
|
||||
|
||||
// StreamOpenAIResponses reads SSE chunks from the body and sends them to the channel.
|
||||
func StreamOpenAIResponses(ctx io.ReadCloser, ch chan<- *models.ResponsesStreamChunk) error {
|
||||
defer ctx.Close()
|
||||
scanner := bufio.NewScanner(ctx)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
chunk, done, err := ParseOpenAIResponsesStreamChunk(line)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
if chunk != nil {
|
||||
ch <- chunk
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
type openAIUsage struct {
|
||||
PromptTokens uint32 `json:"prompt_tokens"`
|
||||
CompletionTokens uint32 `json:"completion_tokens"`
|
||||
TotalTokens uint32 `json:"total_tokens"`
|
||||
PromptTokensDetails *struct {
|
||||
PromptTokens uint32 `json:"prompt_tokens"`
|
||||
CompletionTokens uint32 `json:"completion_tokens"`
|
||||
TotalTokens uint32 `json:"total_tokens"`
|
||||
PromptTokensDetails *struct {
|
||||
CachedTokens uint32 `json:"cached_tokens"`
|
||||
} `json:"prompt_tokens_details"`
|
||||
CompletionTokensDetails *struct {
|
||||
@@ -165,7 +326,7 @@ func ParseOpenAIResponse(respJSON map[string]interface{}, model string) (*models
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
var resp models.ChatCompletionResponse
|
||||
if err := json.Unmarshal(data, &resp); err != nil {
|
||||
return nil, err
|
||||
@@ -180,7 +341,7 @@ func ParseOpenAIResponse(respJSON map[string]interface{}, model string) (*models
|
||||
resp.Usage = oUsage.ToUnified()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
@@ -234,85 +395,216 @@ func StreamOpenAI(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error {
|
||||
defer ctx.Close()
|
||||
|
||||
dec := json.NewDecoder(ctx)
|
||||
|
||||
t, err := dec.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if delim, ok := t.(json.Delim); ok && delim == '[' {
|
||||
for dec.More() {
|
||||
var geminiChunk struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Thought string `json:"thought,omitempty"`
|
||||
} `json:"parts"`
|
||||
} `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
} `json:"candidates"`
|
||||
UsageMetadata struct {
|
||||
PromptTokenCount uint32 `json:"promptTokenCount"`
|
||||
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
|
||||
TotalTokenCount uint32 `json:"totalTokenCount"`
|
||||
} `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
if err := dec.Decode(&geminiChunk); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(geminiChunk.Candidates) > 0 || geminiChunk.UsageMetadata.TotalTokenCount > 0 {
|
||||
content := ""
|
||||
var reasoning *string
|
||||
if len(geminiChunk.Candidates) > 0 {
|
||||
for _, p := range geminiChunk.Candidates[0].Content.Parts {
|
||||
if p.Text != "" {
|
||||
content += p.Text
|
||||
}
|
||||
if p.Thought != "" {
|
||||
if reasoning == nil {
|
||||
reasoning = new(string)
|
||||
}
|
||||
*reasoning += p.Thought
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var finishReason *string
|
||||
if len(geminiChunk.Candidates) > 0 {
|
||||
fr := strings.ToLower(geminiChunk.Candidates[0].FinishReason)
|
||||
finishReason = &fr
|
||||
}
|
||||
// geminiStreamChunk is the shared data structure for parsing Gemini streaming responses.
|
||||
type geminiStreamChunk struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Thought string `json:"thought,omitempty"`
|
||||
} `json:"parts"`
|
||||
} `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
} `json:"candidates"`
|
||||
UsageMetadata struct {
|
||||
PromptTokenCount uint32 `json:"promptTokenCount"`
|
||||
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
|
||||
TotalTokenCount uint32 `json:"totalTokenCount"`
|
||||
CachedContentTokenCount uint32 `json:"cachedContentTokenCount"`
|
||||
} `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
ch <- &models.ChatCompletionStreamResponse{
|
||||
ID: "gemini-stream",
|
||||
Object: "chat.completion.chunk",
|
||||
Created: 0,
|
||||
Model: model,
|
||||
Choices: []models.ChatStreamChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: models.ChatStreamDelta{
|
||||
Content: &content,
|
||||
ReasoningContent: reasoning,
|
||||
},
|
||||
FinishReason: finishReason,
|
||||
},
|
||||
},
|
||||
Usage: &models.Usage{
|
||||
PromptTokens: geminiChunk.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiChunk.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: geminiChunk.UsageMetadata.TotalTokenCount,
|
||||
},
|
||||
// emitGeminiChunk builds a ChatCompletionStreamResponse from a parsed geminiStreamChunk
|
||||
// and sends it to the channel. Returns true if anything was emitted.
|
||||
func emitGeminiChunk(ch chan<- *models.ChatCompletionStreamResponse, chunk *geminiStreamChunk, model string) bool {
|
||||
if len(chunk.Candidates) == 0 && chunk.UsageMetadata.TotalTokenCount == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
content := ""
|
||||
var reasoning *string
|
||||
var finishReason *string
|
||||
if len(chunk.Candidates) > 0 {
|
||||
for _, p := range chunk.Candidates[0].Content.Parts {
|
||||
if p.Text != "" {
|
||||
content += p.Text
|
||||
}
|
||||
if p.Thought != "" {
|
||||
if reasoning == nil {
|
||||
reasoning = new(string)
|
||||
}
|
||||
*reasoning += p.Thought
|
||||
}
|
||||
}
|
||||
fr := strings.ToLower(chunk.Candidates[0].FinishReason)
|
||||
finishReason = &fr
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
ch <- &models.ChatCompletionStreamResponse{
|
||||
ID: "gemini-stream",
|
||||
Object: "chat.completion.chunk",
|
||||
Created: 0,
|
||||
Model: model,
|
||||
Choices: []models.ChatStreamChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: models.ChatStreamDelta{
|
||||
Content: &content,
|
||||
ReasoningContent: reasoning,
|
||||
},
|
||||
FinishReason: finishReason,
|
||||
},
|
||||
},
|
||||
Usage: &models.Usage{
|
||||
PromptTokens: chunk.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: chunk.UsageMetadata.TotalTokenCount,
|
||||
CacheReadTokens: uint32Ptr(chunk.UsageMetadata.CachedContentTokenCount),
|
||||
},
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// StreamGemini handles Gemini streaming responses in two formats:
|
||||
// 1. SSE format (newer models): each line is "data: {...}"
|
||||
// 2. JSON array format (older models): response body is [ {...}, {...} ]
|
||||
//
|
||||
// Usage metadata is only present in the final chunk, which we accumulate
|
||||
// and emit so the server can log it on stream end.
|
||||
func StreamGemini(ctx io.ReadCloser, model string) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
_ = ctx.Close()
|
||||
}()
|
||||
defer close(ch)
|
||||
|
||||
// Peek at the first byte to detect format
|
||||
peek := make([]byte, 6)
|
||||
n, _ := io.ReadAtLeast(ctx, peek, 1)
|
||||
if n == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
first := string(peek[:n])
|
||||
|
||||
if first[0] == '[' {
|
||||
// JSON array format
|
||||
rest, _ := io.ReadAll(ctx)
|
||||
streamGeminiJSONArray(append([]byte(first), rest...), ch, model)
|
||||
return
|
||||
} else if strings.HasPrefix(first, "data:") || strings.HasPrefix(first, "data: ") {
|
||||
// SSE format — pre-pend the peeked bytes then run SSE scanner
|
||||
combined := io.MultiReader(
|
||||
strings.NewReader(string(peek[:n])),
|
||||
ctx,
|
||||
)
|
||||
streamGeminiSSE(combined, ch, model)
|
||||
} else {
|
||||
// Unknown format — might still be SSE starting after a peek char
|
||||
// Pre-pend peeked bytes and try SSE
|
||||
combined := io.MultiReader(
|
||||
strings.NewReader(string(peek[:n])),
|
||||
ctx,
|
||||
)
|
||||
streamGeminiSSE(combined, ch, model)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// readAll reads remaining bytes from a reader (keeps the function signature simple
|
||||
// for the JSON array fallback path).
|
||||
func readAll(r io.Reader) []byte {
|
||||
b, _ := io.ReadAll(r)
|
||||
return b
|
||||
}
|
||||
|
||||
func streamGeminiJSONArray(data []byte, ch chan<- *models.ChatCompletionStreamResponse, model string) {
|
||||
var chunks []geminiStreamChunk
|
||||
if err := json.Unmarshal(data, &chunks); err != nil {
|
||||
fmt.Printf("[Gemini-Stream] JSON array parse error: %v\n", err)
|
||||
return
|
||||
}
|
||||
// Track the last chunk with usage for the final emission
|
||||
var lastUsage *geminiStreamChunk
|
||||
for i := range chunks {
|
||||
if chunks[i].UsageMetadata.TotalTokenCount > 0 {
|
||||
lastUsage = &chunks[i]
|
||||
}
|
||||
}
|
||||
if lastUsage != nil {
|
||||
// Emit a synthetic final chunk with usage data
|
||||
if len(lastUsage.Candidates) == 0 && lastUsage.UsageMetadata.TotalTokenCount > 0 {
|
||||
emitGeminiChunk(ch, lastUsage, model)
|
||||
}
|
||||
}
|
||||
// Also emit each content-bearing chunk
|
||||
for i := range chunks {
|
||||
emitGeminiChunk(ch, &chunks[i], model)
|
||||
}
|
||||
}
|
||||
|
||||
func streamGeminiSSE(r io.Reader, ch chan<- *models.ChatCompletionStreamResponse, model string) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
// Track the last seen usage for emission at end of stream
|
||||
var lastUsage geminiStreamChunk
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" || !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
// Emit final usage if we have one
|
||||
if lastUsage.UsageMetadata.TotalTokenCount > 0 {
|
||||
emitGeminiChunk(ch, &lastUsage, model)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
var chunk geminiStreamChunk
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Capture usage from any chunk (Gemini puts it in the final response)
|
||||
if chunk.UsageMetadata.TotalTokenCount > 0 {
|
||||
lastUsage = chunk
|
||||
}
|
||||
|
||||
// Emit content chunks as they arrive
|
||||
if len(chunk.Candidates) > 0 {
|
||||
emitGeminiChunk(ch, &chunk, model)
|
||||
}
|
||||
}
|
||||
|
||||
// If stream ended without [DONE] marker but we collected usage, emit it
|
||||
if lastUsage.UsageMetadata.TotalTokenCount > 0 {
|
||||
emitGeminiChunk(ch, &lastUsage, model)
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
fmt.Printf("[Gemini-Stream] SSE scan error: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ensureEnglish injects a system message instructing the model to respond in
|
||||
// English when no system prompt is already present. Some providers (e.g. DeepSeek)
|
||||
// default to Chinese for certain prompts.
|
||||
func ensureEnglish(req *models.UnifiedRequest) {
|
||||
if len(req.Messages) > 0 && req.Messages[0].Role == "system" {
|
||||
return // already has a system prompt, don't interfere
|
||||
}
|
||||
enMsg := models.UnifiedMessage{
|
||||
Role: "system",
|
||||
Content: []models.UnifiedContentPart{
|
||||
{Type: "text", Text: "You are a helpful assistant. Always respond in English."},
|
||||
},
|
||||
}
|
||||
req.Messages = append([]models.UnifiedMessage{enMsg}, req.Messages...)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
func TestSanitizeFunctionName(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"google-search", "google-search"},
|
||||
{"google.search", "google_search"},
|
||||
{"google search", "google_search"},
|
||||
{"web_search(query)", "web_search_query_"},
|
||||
{"", "function"},
|
||||
{"123_abc-XYZ", "123_abc-XYZ"},
|
||||
{"invalid.name.with.dots", "invalid_name_with_dots"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
actual := sanitizeFunctionName(tc.input)
|
||||
if actual != tc.expected {
|
||||
t.Errorf("sanitizeFunctionName(%q) = %q; expected %q", tc.input, actual, tc.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesToOpenAIJSON_SanitizeToolCalls(t *testing.T) {
|
||||
messages := []models.UnifiedMessage{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: []models.UnifiedContentPart{
|
||||
{Type: "text", Text: "I will use search."},
|
||||
},
|
||||
ToolCalls: []models.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: models.FunctionCall{
|
||||
Name: "google.search",
|
||||
Arguments: `{"query": "hello"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
Content: []models.UnifiedContentPart{
|
||||
{Type: "text", Text: `{"result": "success"}`},
|
||||
},
|
||||
ToolCallID: stringPtr("call_1"),
|
||||
Name: stringPtr("google.search"),
|
||||
},
|
||||
}
|
||||
|
||||
res, err := MessagesToOpenAIJSON(messages)
|
||||
if err != nil {
|
||||
t.Fatalf("MessagesToOpenAIJSON failed: %v", err)
|
||||
}
|
||||
|
||||
if len(res) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(res))
|
||||
}
|
||||
|
||||
// Verify assistant message
|
||||
msg1 := res[0].(map[string]interface{})
|
||||
if msg1["role"] != "assistant" {
|
||||
t.Errorf("expected role assistant, got %v", msg1["role"])
|
||||
}
|
||||
calls := msg1["tool_calls"].([]models.ToolCall)
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(calls))
|
||||
}
|
||||
if calls[0].Function.Name != "google_search" {
|
||||
t.Errorf("expected function name google_search, got %q", calls[0].Function.Name)
|
||||
}
|
||||
|
||||
// Verify tool response message
|
||||
msg2 := res[1].(map[string]interface{})
|
||||
if msg2["role"] != "tool" {
|
||||
t.Errorf("expected role tool, got %v", msg2["role"])
|
||||
}
|
||||
if msg2["name"] != "google_search" {
|
||||
t.Errorf("expected tool name google_search, got %v", msg2["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIBody_SanitizeToolsAndChoice(t *testing.T) {
|
||||
req := &models.UnifiedRequest{
|
||||
Model: "gpt-4o",
|
||||
Tools: []models.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: models.FunctionDef{
|
||||
Name: "google.search",
|
||||
},
|
||||
},
|
||||
},
|
||||
ToolChoice: json.RawMessage(`{"type": "function", "function": {"name": "google.search"}}`),
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, nil, false)
|
||||
|
||||
// Verify tools
|
||||
tools := body["tools"].([]models.Tool)
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(tools))
|
||||
}
|
||||
if tools[0].Function.Name != "google_search" {
|
||||
t.Errorf("expected tool function name google_search, got %q", tools[0].Function.Name)
|
||||
}
|
||||
|
||||
// Verify tool_choice
|
||||
toolChoice := body["tool_choice"].(map[string]interface{})
|
||||
funcObj := toolChoice["function"].(map[string]interface{})
|
||||
if funcObj["name"] != "google_search" {
|
||||
t.Errorf("expected tool_choice function name google_search, got %q", funcObj["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
@@ -4,11 +4,13 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/models"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
type MoonshotProvider struct {
|
||||
@@ -19,7 +21,7 @@ type MoonshotProvider struct {
|
||||
|
||||
func NewMoonshotProvider(cfg config.MoonshotConfig, apiKey string) *MoonshotProvider {
|
||||
return &MoonshotProvider{
|
||||
client: resty.New(),
|
||||
client: resty.New().SetTimeout(10 * time.Minute),
|
||||
config: cfg,
|
||||
apiKey: strings.TrimSpace(apiKey),
|
||||
}
|
||||
@@ -58,7 +60,13 @@ func (p *MoonshotProvider) ChatCompletion(ctx context.Context, req *models.Unifi
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
@@ -99,7 +107,13 @@ func (p *MoonshotProvider) ChatCompletionStream(ctx context.Context, req *models
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
@@ -112,3 +126,15 @@ func (p *MoonshotProvider) ChatCompletionStream(ctx context.Context, req *models
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (p *MoonshotProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
|
||||
return nil, fmt.Errorf("moonshot does not support image generation")
|
||||
}
|
||||
|
||||
func (p *MoonshotProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by moonshot")
|
||||
}
|
||||
|
||||
func (p *MoonshotProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by moonshot")
|
||||
}
|
||||
|
||||
@@ -7,10 +7,11 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/models"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
type OllamaProvider struct {
|
||||
@@ -19,8 +20,15 @@ type OllamaProvider struct {
|
||||
}
|
||||
|
||||
func NewOllamaProvider(cfg config.OllamaConfig) *OllamaProvider {
|
||||
client := resty.New()
|
||||
// Set reasonable timeouts for local Ollama server (longer for larger models)
|
||||
// For streaming, we want a very long timeout or none at all to handle generation time
|
||||
client.SetTimeout(15 * time.Minute)
|
||||
client.SetRetryCount(2)
|
||||
client.SetRetryWaitTime(1 * time.Second)
|
||||
|
||||
return &OllamaProvider{
|
||||
client: resty.New(),
|
||||
client: client,
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
@@ -36,18 +44,25 @@ func (p *OllamaProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
}
|
||||
|
||||
body := BuildOllamaBody(req, messagesJSON, false)
|
||||
url := fmt.Sprintf("%s/chat/completions", p.config.BaseURL)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetBody(body).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
Post(url)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
@@ -77,16 +92,21 @@ func (p *OllamaProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamOllama(resp.RawBody(), ch, req.Model)
|
||||
if err != nil {
|
||||
fmt.Printf("Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -100,23 +120,63 @@ func BuildOllamaBody(request *models.UnifiedRequest, messagesJSON []interface{},
|
||||
"stream": stream,
|
||||
}
|
||||
|
||||
options := make(map[string]interface{})
|
||||
modelLower := strings.ToLower(request.Model)
|
||||
|
||||
// Context window size (default 8k for all, 32k+ for modern large-context models)
|
||||
ctxSize := 8192
|
||||
if strings.Contains(modelLower, "llama") ||
|
||||
strings.Contains(modelLower, "gemma") ||
|
||||
strings.Contains(modelLower, "mistral") ||
|
||||
strings.Contains(modelLower, "mixtral") ||
|
||||
strings.Contains(modelLower, "qwen") ||
|
||||
strings.Contains(modelLower, "deepseek") ||
|
||||
strings.Contains(modelLower, "command-r") ||
|
||||
strings.Contains(modelLower, "phi") {
|
||||
ctxSize = 32768
|
||||
}
|
||||
options["num_ctx"] = ctxSize
|
||||
|
||||
if request.Temperature != nil {
|
||||
body["temperature"] = *request.Temperature
|
||||
options["temperature"] = *request.Temperature
|
||||
}
|
||||
|
||||
if request.MaxTokens != nil {
|
||||
body["max_tokens"] = *request.MaxTokens
|
||||
options["num_predict"] = *request.MaxTokens
|
||||
} else {
|
||||
// Default to 8192 for all Ollama models if not specified,
|
||||
// as Ollama's compatibility layer defaults to 128 if neither
|
||||
// max_tokens nor num_predict are provided.
|
||||
body["max_tokens"] = 8192
|
||||
options["num_predict"] = 8192
|
||||
}
|
||||
|
||||
if request.TopP != nil {
|
||||
body["top_p"] = *request.TopP
|
||||
options["top_p"] = *request.TopP
|
||||
}
|
||||
if request.TopK != nil {
|
||||
body["top_k"] = *request.TopK
|
||||
options["top_k"] = *request.TopK
|
||||
}
|
||||
|
||||
if len(request.Stop) > 0 {
|
||||
body["stop"] = request.Stop
|
||||
options["stop"] = request.Stop
|
||||
}
|
||||
|
||||
if len(options) > 0 {
|
||||
body["options"] = options
|
||||
}
|
||||
|
||||
if len(request.Tools) > 0 {
|
||||
body["tools"] = request.Tools
|
||||
// Explicitly set tool_choice to auto if tools are present but choice is not specified
|
||||
if request.ToolChoice == nil {
|
||||
body["tool_choice"] = "auto"
|
||||
}
|
||||
}
|
||||
if request.ToolChoice != nil {
|
||||
var toolChoice interface{}
|
||||
@@ -133,7 +193,7 @@ func ParseOllamaResponse(respJSON map[string]interface{}, model string) (*models
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
var resp models.ChatCompletionResponse
|
||||
if err := json.Unmarshal(data, &resp); err != nil {
|
||||
return nil, err
|
||||
@@ -146,7 +206,7 @@ func ParseOllamaResponse(respJSON map[string]interface{}, model string) (*models
|
||||
resp.Usage = &usage
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
@@ -181,6 +241,11 @@ func ParseOllamaStreamChunk(line string) (*models.ChatCompletionStreamResponse,
|
||||
func StreamOllama(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error {
|
||||
defer ctx.Close()
|
||||
scanner := bufio.NewScanner(ctx)
|
||||
// Set a larger buffer for scanning to handle large chunks if they occur
|
||||
const maxCapacity = 10 * 1024 * 1024 // 10MB
|
||||
buf := make([]byte, 64*1024)
|
||||
scanner.Buffer(buf, maxCapacity)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
chunk, done, err := ParseOllamaStreamChunk(line)
|
||||
@@ -195,4 +260,16 @@ func StreamOllama(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OllamaProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
|
||||
return nil, fmt.Errorf("ollama does not support image generation")
|
||||
}
|
||||
|
||||
func (p *OllamaProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by ollama")
|
||||
}
|
||||
|
||||
func (p *OllamaProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by ollama")
|
||||
}
|
||||
|
||||
@@ -4,23 +4,26 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/models"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
type OpenAIProvider struct {
|
||||
client *resty.Client
|
||||
config config.OpenAIConfig
|
||||
apiKey string
|
||||
client *resty.Client
|
||||
config config.OpenAIConfig
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func NewOpenAIProvider(cfg config.OpenAIConfig, apiKey string) *OpenAIProvider {
|
||||
return &OpenAIProvider{
|
||||
client: resty.New(),
|
||||
config: cfg,
|
||||
client: resty.New().SetTimeout(10 * time.Minute),
|
||||
config: cfg,
|
||||
apiKey: apiKey,
|
||||
}
|
||||
}
|
||||
@@ -37,6 +40,17 @@ func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, false)
|
||||
|
||||
// Debug message sequence
|
||||
for i, m := range messagesJSON {
|
||||
mMap, _ := m.(map[string]interface{})
|
||||
role, _ := mMap["role"].(string)
|
||||
hasToolCalls := false
|
||||
if tc, ok := mMap["tool_calls"]; ok && tc != nil {
|
||||
hasToolCalls = true
|
||||
}
|
||||
log.Printf("[DEBUG] OpenAI Msg[%d]: role=%s, hasToolCalls=%v", i, role, hasToolCalls)
|
||||
}
|
||||
|
||||
// Transition: Newer models require max_completion_tokens
|
||||
if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
|
||||
if maxTokens, ok := body["max_tokens"]; ok {
|
||||
@@ -56,7 +70,17 @@ func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if b := resp.Body(); len(b) > 0 {
|
||||
msg = string(b)
|
||||
}
|
||||
}
|
||||
// Log the request body for debugging
|
||||
reqJSON, _ := json.Marshal(body)
|
||||
log.Printf("OpenAI API Error (%d): %s", resp.StatusCode(), msg)
|
||||
log.Printf("OpenAI request body: %s", string(reqJSON))
|
||||
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
@@ -67,6 +91,59 @@ func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
||||
return ParseOpenAIResponse(respJSON, req.Model)
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
|
||||
body := map[string]interface{}{
|
||||
"prompt": req.Prompt,
|
||||
"model": req.Model,
|
||||
}
|
||||
|
||||
if req.N != nil {
|
||||
body["n"] = *req.N
|
||||
}
|
||||
if req.Quality != nil {
|
||||
body["quality"] = *req.Quality
|
||||
}
|
||||
if req.ResponseFormat != nil {
|
||||
body["response_format"] = *req.ResponseFormat
|
||||
}
|
||||
if req.Size != nil {
|
||||
body["size"] = *req.Size
|
||||
}
|
||||
if req.Style != nil {
|
||||
body["style"] = *req.Style
|
||||
}
|
||||
if req.User != nil {
|
||||
body["user"] = *req.User
|
||||
}
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
Post(fmt.Sprintf("%s/images/generations", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("OpenAI image API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
var result models.ImageGenerationResponse
|
||||
if err := json.Unmarshal(resp.Body(), &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
@@ -75,6 +152,17 @@ func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, true)
|
||||
|
||||
// Debug message sequence
|
||||
for i, m := range messagesJSON {
|
||||
mMap, _ := m.(map[string]interface{})
|
||||
role, _ := mMap["role"].(string)
|
||||
hasToolCalls := false
|
||||
if tc, ok := mMap["tool_calls"]; ok && tc != nil {
|
||||
hasToolCalls = true
|
||||
}
|
||||
log.Printf("[DEBUG] OpenAI Stream Msg[%d]: role=%s, hasToolCalls=%v", i, role, hasToolCalls)
|
||||
}
|
||||
|
||||
// Transition: Newer models require max_completion_tokens
|
||||
if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
|
||||
if maxTokens, ok := body["max_tokens"]; ok {
|
||||
@@ -95,11 +183,25 @@ func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if b := resp.Body(); len(b) > 0 {
|
||||
msg = string(b)
|
||||
}
|
||||
if msg == "" {
|
||||
if b, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(b)
|
||||
}
|
||||
}
|
||||
}
|
||||
reqJSON, _ := json.Marshal(body)
|
||||
log.Printf("OpenAI API Error (%d): %s", resp.StatusCode(), msg)
|
||||
log.Printf("OpenAI request body: %s", string(reqJSON))
|
||||
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamOpenAI(resp.RawBody(), ch)
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
// Responses sends a non-streaming request to OpenAI's /v1/responses endpoint.
|
||||
func (p *OpenAIProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
|
||||
// Determine if streaming was requested
|
||||
stream := req.Stream != nil && *req.Stream
|
||||
|
||||
body := BuildOpenAIResponsesBody(req, stream)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
Post(fmt.Sprintf("%s/responses", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("responses request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("OpenAI Responses API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse responses response: %w", err)
|
||||
}
|
||||
|
||||
return ParseOpenAIResponsesResponse(respJSON, req.Model)
|
||||
}
|
||||
|
||||
// ResponsesStream sends a streaming request to OpenAI's /v1/responses endpoint.
|
||||
func (p *OpenAIProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
|
||||
body := BuildOpenAIResponsesBody(req, true)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(fmt.Sprintf("%s/responses", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("responses stream request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("OpenAI Responses API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
ch := make(chan *models.ResponsesStreamChunk)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamOpenAIResponses(resp.RawBody(), ch)
|
||||
if err != nil {
|
||||
fmt.Printf("Responses stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
@@ -10,4 +10,7 @@ type Provider interface {
|
||||
Name() string
|
||||
ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error)
|
||||
ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error)
|
||||
ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error)
|
||||
Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error)
|
||||
ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-resty/resty/v2"
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
type XiaomiProvider struct {
|
||||
client *resty.Client
|
||||
config config.XiaomiConfig
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func NewXiaomiProvider(cfg config.XiaomiConfig, apiKey string) *XiaomiProvider {
|
||||
return &XiaomiProvider{
|
||||
client: resty.New().SetTimeout(10 * time.Minute),
|
||||
config: cfg,
|
||||
apiKey: strings.TrimSpace(apiKey),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *XiaomiProvider) Name() string {
|
||||
return "xiaomi"
|
||||
}
|
||||
|
||||
func (p *XiaomiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, false)
|
||||
|
||||
baseURL := strings.TrimRight(p.config.BaseURL, "/")
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("Accept", "application/json").
|
||||
SetBody(body).
|
||||
Post(fmt.Sprintf("%s/chat/completions", baseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if b := resp.Body(); len(b) > 0 {
|
||||
msg = string(b)
|
||||
}
|
||||
}
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Xiaomi API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return ParseOpenAIResponse(respJSON, req.Model)
|
||||
}
|
||||
|
||||
func (p *XiaomiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, true)
|
||||
|
||||
baseURL := strings.TrimRight(p.config.BaseURL, "/")
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("Accept", "text/event-stream").
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(fmt.Sprintf("%s/chat/completions", baseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
msg := resp.String()
|
||||
if msg == "" {
|
||||
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
||||
msg = string(body)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Xiaomi API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
if err := StreamOpenAI(resp.RawBody(), ch); err != nil {
|
||||
fmt.Printf("Xiaomi Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (p *XiaomiProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
|
||||
return nil, fmt.Errorf("xiaomi does not support image generation")
|
||||
}
|
||||
|
||||
func (p *XiaomiProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by xiaomi")
|
||||
}
|
||||
|
||||
func (p *XiaomiProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
|
||||
return nil, fmt.Errorf("responses API not supported by xiaomi")
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
const classifierSystemPrompt = `You are a task complexity classifier. Rate the following user message on a scale of 1 to %d, where:
|
||||
1 = trivial/simple (basic facts, greetings, simple math)
|
||||
%d = highly complex (multi-step reasoning, code generation, architecture design)
|
||||
|
||||
Reply with ONLY the number. No explanation.`
|
||||
|
||||
func routeClassifier(ctx context.Context, classify ClassifierFunc, group db.ModelGroup, targets []string, routeCtx *RouteContext) (*Decision, error) {
|
||||
// Determine the rating scale
|
||||
maxRating := len(targets)
|
||||
if maxRating < 2 {
|
||||
maxRating = 2
|
||||
}
|
||||
|
||||
// When complexity_threshold is set, use it as a wider scale (e.g., 1-10)
|
||||
// and map ratings proportionally to target buckets.
|
||||
bucketMode := group.ComplexityThreshold != nil && *group.ComplexityThreshold > 0
|
||||
if bucketMode {
|
||||
maxRating = *group.ComplexityThreshold
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf(classifierSystemPrompt, maxRating, maxRating)
|
||||
userMsg := ""
|
||||
if routeCtx != nil {
|
||||
userMsg = routeCtx.UserMessage
|
||||
}
|
||||
ratingStr, err := classify(ctx, getSelectorModel(group, targets), prompt, userMsg)
|
||||
if err != nil {
|
||||
// Classifier failed — fall back to heuristic
|
||||
return routeHeuristic(group, targets, routeCtx)
|
||||
}
|
||||
|
||||
rating, err := strconv.Atoi(strings.TrimSpace(ratingStr))
|
||||
if err != nil || rating < 1 {
|
||||
rating = 1
|
||||
}
|
||||
if rating > maxRating {
|
||||
rating = maxRating
|
||||
}
|
||||
|
||||
var idx int
|
||||
if bucketMode {
|
||||
// Proportional mapping: wider scale → N target buckets
|
||||
// e.g., threshold=10, 3 targets: 1-3→0, 4-7→1, 8-10→2
|
||||
idx = rating * len(targets) / (maxRating + 1)
|
||||
if idx >= len(targets) {
|
||||
idx = len(targets) - 1
|
||||
}
|
||||
} else {
|
||||
idx = rating - 1 // 1:1 mapping
|
||||
}
|
||||
|
||||
return &Decision{
|
||||
SelectedModel: targets[idx],
|
||||
Strategy: "classifier",
|
||||
Reason: fmt.Sprintf("complexity rating: %d/%d", rating, maxRating),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func getSelectorModel(group db.ModelGroup, targets []string) string {
|
||||
if group.SelectorModel != nil && *group.SelectorModel != "" {
|
||||
return *group.SelectorModel
|
||||
}
|
||||
// Default: use the first (cheapest) target model as the selector
|
||||
return targets[0]
|
||||
}
|
||||
@@ -0,0 +1,219 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
// HeuristicRule defines a pattern-based routing rule (legacy format).
|
||||
type HeuristicRule struct {
|
||||
Pattern string `json:"pattern"`
|
||||
TargetIdx int `json:"target"`
|
||||
CaseSensitive bool `json:"case_sensitive,omitempty"`
|
||||
}
|
||||
|
||||
// ConditionRule defines a condition-based routing rule (new format).
|
||||
type ConditionRule struct {
|
||||
RuleID string `json:"rule_id"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Conditions Conditions `json:"conditions"`
|
||||
PrimaryModel string `json:"primary_model"`
|
||||
FallbackModel string `json:"fallback_model,omitempty"`
|
||||
}
|
||||
|
||||
// Conditions defines the matching parameters for a rule.
|
||||
type Conditions struct {
|
||||
AnyOfTags []string `json:"any_of_tags,omitempty"`
|
||||
MaxInputTokensLt *int `json:"max_input_tokens_lt,omitempty"`
|
||||
RequiresReasoning *bool `json:"requires_reasoning,omitempty"`
|
||||
RequiresToolCalling *bool `json:"requires_tool_calling,omitempty"`
|
||||
HasMultimodalInput *bool `json:"has_multimodal_input,omitempty"`
|
||||
IsDefaultFallback *bool `json:"is_default_fallback,omitempty"`
|
||||
}
|
||||
|
||||
func routeHeuristic(group db.ModelGroup, targets []string, routeCtx *RouteContext) (*Decision, error) {
|
||||
if routeCtx == nil {
|
||||
routeCtx = &RouteContext{}
|
||||
}
|
||||
|
||||
selected := targets[0]
|
||||
reason := "default (first target)"
|
||||
|
||||
// If heuristic_rules is set, determine format and parse
|
||||
if group.HeuristicRules != nil && *group.HeuristicRules != "" {
|
||||
rulesJSON := *group.HeuristicRules
|
||||
|
||||
if isConditionBasedRules(rulesJSON) {
|
||||
var condRules []ConditionRule
|
||||
if err := json.Unmarshal([]byte(rulesJSON), &condRules); err == nil {
|
||||
for _, rule := range condRules {
|
||||
if matchConditions(rule.Conditions, routeCtx) {
|
||||
// Resolve primary/fallback to concrete models in target list
|
||||
targetModel := ""
|
||||
if rule.PrimaryModel != "" {
|
||||
targetModel = getModelInTargets(rule.PrimaryModel, targets)
|
||||
}
|
||||
if targetModel == "" && rule.FallbackModel != "" {
|
||||
targetModel = getModelInTargets(rule.FallbackModel, targets)
|
||||
}
|
||||
|
||||
if targetModel != "" {
|
||||
selected = targetModel
|
||||
reason = "matched condition rule: " + rule.RuleID
|
||||
if rule.Description != "" {
|
||||
reason += " (" + rule.Description + ")"
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback to legacy pattern-based rules
|
||||
var legacyRules []HeuristicRule
|
||||
if err := json.Unmarshal([]byte(rulesJSON), &legacyRules); err == nil {
|
||||
searchMsg := routeCtx.UserMessage
|
||||
for _, rule := range legacyRules {
|
||||
pattern := rule.Pattern
|
||||
if pattern == "" {
|
||||
continue // Avoid infinite matches with empty patterns
|
||||
}
|
||||
msg := searchMsg
|
||||
if !rule.CaseSensitive {
|
||||
pattern = strings.ToLower(pattern)
|
||||
msg = strings.ToLower(msg)
|
||||
}
|
||||
|
||||
// Support both regex matching (if pattern is valid regex) and literal contains
|
||||
matched := false
|
||||
if strings.Contains(rule.Pattern, "(") || strings.Contains(rule.Pattern, "\\b") {
|
||||
var re *regexp.Regexp
|
||||
var err error
|
||||
if !rule.CaseSensitive {
|
||||
re, err = regexp.Compile("(?i)" + rule.Pattern)
|
||||
} else {
|
||||
re, err = regexp.Compile(rule.Pattern)
|
||||
}
|
||||
if err == nil {
|
||||
matched = re.MatchString(routeCtx.UserMessage)
|
||||
}
|
||||
}
|
||||
|
||||
if !matched && strings.Contains(msg, pattern) {
|
||||
matched = true
|
||||
}
|
||||
|
||||
if matched {
|
||||
if rule.TargetIdx >= 0 && rule.TargetIdx < len(targets) {
|
||||
selected = targets[rule.TargetIdx]
|
||||
reason = "matched heuristic rule: " + rule.Pattern
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Built-in fallback heuristics (if no custom rule matched)
|
||||
if reason == "default (first target)" && len(targets) > 1 {
|
||||
msgLower := strings.ToLower(routeCtx.UserMessage)
|
||||
complexIndicators := []string{
|
||||
"step by step", "explain in detail", "reason through",
|
||||
"think carefully", "analyze", "debug", "write code",
|
||||
"implement", "refactor", "architecture",
|
||||
}
|
||||
for _, indicator := range complexIndicators {
|
||||
if strings.Contains(msgLower, indicator) {
|
||||
selected = targets[len(targets)-1]
|
||||
reason = "complex task indicator: " + indicator
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Decision{
|
||||
SelectedModel: selected,
|
||||
Strategy: "heuristic",
|
||||
Reason: reason,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// isConditionBasedRules returns true if the JSON represents condition-based rules.
|
||||
func isConditionBasedRules(rulesJSON string) bool {
|
||||
var rules []ConditionRule
|
||||
if err := json.Unmarshal([]byte(rulesJSON), &rules); err == nil && len(rules) > 0 {
|
||||
// If the rule has either conditions or primary_model/rule_id, treat it as condition-based
|
||||
return rules[0].PrimaryModel != "" || rules[0].RuleID != ""
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchConditions evaluates whether the given conditions match the RouteContext.
|
||||
func matchConditions(cond Conditions, routeCtx *RouteContext) bool {
|
||||
if cond.IsDefaultFallback != nil && *cond.IsDefaultFallback {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check tags: must match any_of_tags if specified
|
||||
if len(cond.AnyOfTags) > 0 {
|
||||
tagMatched := false
|
||||
for _, ruleTag := range cond.AnyOfTags {
|
||||
for _, ctxTag := range routeCtx.Tags {
|
||||
if strings.EqualFold(ruleTag, ctxTag) {
|
||||
tagMatched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if tagMatched {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !tagMatched {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check max input tokens
|
||||
if cond.MaxInputTokensLt != nil {
|
||||
if routeCtx.InputTokens >= *cond.MaxInputTokensLt {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check reasoning flag
|
||||
if cond.RequiresReasoning != nil {
|
||||
if routeCtx.RequiresReasoning != *cond.RequiresReasoning {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check tool calling flag
|
||||
if cond.RequiresToolCalling != nil {
|
||||
if routeCtx.RequiresToolCalling != *cond.RequiresToolCalling {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check multimodal flag
|
||||
if cond.HasMultimodalInput != nil {
|
||||
if routeCtx.HasMultimodalInput != *cond.HasMultimodalInput {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// getModelInTargets returns the model name if it exists in targets, or empty string.
|
||||
func getModelInTargets(modelName string, targets []string) string {
|
||||
for _, t := range targets {
|
||||
if strings.EqualFold(t, modelName) {
|
||||
return t
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
func TestRouteHeuristic_ConditionRules(t *testing.T) {
|
||||
targets := []string{
|
||||
"deepseek-v4-flash", // index 0
|
||||
"gemini-3-flash", // index 1
|
||||
"grok-build-0.1", // index 2
|
||||
"kimi-k2.6", // index 3
|
||||
"mimo-v2.5-pro", // index 4
|
||||
"grok-4.3", // index 5
|
||||
"deepseek-v4-pro", // index 6
|
||||
}
|
||||
|
||||
rulesJSON := `[
|
||||
{
|
||||
"rule_id": "fast_flow_extraction",
|
||||
"conditions": {
|
||||
"any_of_tags": ["fast-flow", "classification"],
|
||||
"max_input_tokens_lt": 8000,
|
||||
"requires_reasoning": false
|
||||
},
|
||||
"primary_model": "deepseek-v4-flash",
|
||||
"fallback_model": "grok-build-0.1"
|
||||
},
|
||||
{
|
||||
"rule_id": "multimodal_long_context",
|
||||
"conditions": {
|
||||
"any_of_tags": ["standard-pro", "long-doc"],
|
||||
"has_multimodal_input": true
|
||||
},
|
||||
"primary_model": "gemini-3-flash",
|
||||
"fallback_model": "mimo-v2.5-pro"
|
||||
},
|
||||
{
|
||||
"rule_id": "regional_fallback_general",
|
||||
"conditions": {
|
||||
"is_default_fallback": true
|
||||
},
|
||||
"primary_model": "kimi-k2.6"
|
||||
}
|
||||
]`
|
||||
|
||||
group := db.ModelGroup{
|
||||
ID: "dustins_stack",
|
||||
Strategy: "heuristic",
|
||||
HeuristicRules: &rulesJSON,
|
||||
}
|
||||
|
||||
// 1. Test Match Fast Flow (condition success)
|
||||
ctx1 := &RouteContext{
|
||||
UserMessage: "classify this JSON",
|
||||
InputTokens: 500,
|
||||
HasMultimodalInput: false,
|
||||
RequiresReasoning: false,
|
||||
Tags: []string{"fast-flow", "classification"},
|
||||
}
|
||||
dec1, err := routeHeuristic(group, targets, ctx1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec1.SelectedModel != "deepseek-v4-flash" {
|
||||
t.Fatalf("expected deepseek-v4-flash, got %s", dec1.SelectedModel)
|
||||
}
|
||||
|
||||
// 2. Test Multimodal Long Context (condition success)
|
||||
ctx2 := &RouteContext{
|
||||
UserMessage: "explain this video",
|
||||
InputTokens: 15000,
|
||||
HasMultimodalInput: true,
|
||||
RequiresReasoning: false,
|
||||
Tags: []string{"standard-pro", "video-analysis"},
|
||||
}
|
||||
dec2, err := routeHeuristic(group, targets, ctx2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec2.SelectedModel != "gemini-3-flash" {
|
||||
t.Fatalf("expected gemini-3-flash, got %s", dec2.SelectedModel)
|
||||
}
|
||||
|
||||
// 3. Test Fallback general rule
|
||||
ctx3 := &RouteContext{
|
||||
UserMessage: "hello there",
|
||||
InputTokens: 100,
|
||||
HasMultimodalInput: false,
|
||||
RequiresReasoning: false,
|
||||
Tags: []string{"general"},
|
||||
}
|
||||
dec3, err := routeHeuristic(group, targets, ctx3)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec3.SelectedModel != "kimi-k2.6" {
|
||||
t.Fatalf("expected kimi-k2.6, got %s", dec3.SelectedModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteHeuristic_LegacyRules(t *testing.T) {
|
||||
targets := []string{"gpt-4o-mini", "deepseek-v4-pro", "kimi-k2.6"}
|
||||
|
||||
// Legacy pattern-based rule with regex
|
||||
rulesJSON := `[
|
||||
{"pattern": "\\b(agent|agents|tool use)\\b", "target": 1},
|
||||
{"pattern": "summarize", "target": 2}
|
||||
]`
|
||||
|
||||
group := db.ModelGroup{
|
||||
ID: "heavy-logic",
|
||||
Strategy: "heuristic",
|
||||
HeuristicRules: &rulesJSON,
|
||||
}
|
||||
|
||||
// 1. Test regex match
|
||||
ctx1 := &RouteContext{
|
||||
UserMessage: "We need an agent to do tool use",
|
||||
}
|
||||
dec1, err := routeHeuristic(group, targets, ctx1)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec1.SelectedModel != "deepseek-v4-pro" {
|
||||
t.Fatalf("expected deepseek-v4-pro, got %s", dec1.SelectedModel)
|
||||
}
|
||||
|
||||
// 2. Test literal match
|
||||
ctx2 := &RouteContext{
|
||||
UserMessage: "Please summarize this text",
|
||||
}
|
||||
dec2, err := routeHeuristic(group, targets, ctx2)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if dec2.SelectedModel != "kimi-k2.6" {
|
||||
t.Fatalf("expected kimi-k2.6, got %s", dec2.SelectedModel)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,139 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
// Decision holds the result of a routing decision.
|
||||
type Decision struct {
|
||||
SelectedModel string `json:"selected_model"`
|
||||
Strategy string `json:"strategy"` // "heuristic" or "classifier"
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// RouteContext holds metadata of the request to evaluate condition rules.
|
||||
type RouteContext struct {
|
||||
UserMessage string `json:"user_message"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
HasMultimodalInput bool `json:"has_multimodal_input"`
|
||||
RequiresToolCalling bool `json:"requires_tool_calling"`
|
||||
RequiresReasoning bool `json:"requires_reasoning"`
|
||||
Tags []string `json:"tags"`
|
||||
}
|
||||
|
||||
// ClassifierFunc is the callback for classifier-based routing.
|
||||
type ClassifierFunc func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error)
|
||||
|
||||
// Router resolves model groups to concrete models.
|
||||
type Router struct {
|
||||
groups map[string]db.ModelGroup
|
||||
classify ClassifierFunc
|
||||
}
|
||||
|
||||
// New creates a Router. classify may be nil if no classifier groups exist.
|
||||
func New(groups []db.ModelGroup, classify ClassifierFunc) *Router {
|
||||
r := &Router{
|
||||
groups: make(map[string]db.ModelGroup),
|
||||
classify: classify,
|
||||
}
|
||||
for _, g := range groups {
|
||||
r.groups[g.ID] = g
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Groups returns all registered model group IDs.
|
||||
func (r *Router) Groups() []string {
|
||||
ids := make([]string, 0, len(r.groups))
|
||||
for id := range r.groups {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// IsGroup returns true if the model name is a group ID.
|
||||
func (r *Router) IsGroup(modelID string) bool {
|
||||
_, ok := r.groups[modelID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// Route resolves a group to a concrete model.
|
||||
func (r *Router) Route(ctx context.Context, groupID string, routeCtx *RouteContext) (*Decision, error) {
|
||||
group, ok := r.groups[groupID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown model group: %s", groupID)
|
||||
}
|
||||
|
||||
var targets []string
|
||||
if err := json.Unmarshal([]byte(group.Targets), &targets); err != nil || len(targets) == 0 {
|
||||
return nil, fmt.Errorf("invalid or empty targets for group %s", groupID)
|
||||
}
|
||||
|
||||
switch group.Strategy {
|
||||
case "heuristic":
|
||||
return routeHeuristic(group, targets, routeCtx)
|
||||
case "classifier":
|
||||
if r.classify == nil {
|
||||
return routeHeuristic(group, targets, routeCtx)
|
||||
}
|
||||
return routeClassifier(ctx, r.classify, group, targets, routeCtx)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown strategy: %s", group.Strategy)
|
||||
}
|
||||
}
|
||||
|
||||
// RouteToConcrete resolves a model name to a concrete model, following group
|
||||
// chains recursively until a non-group target is reached. Returns the original
|
||||
// name unchanged if it is not a group.
|
||||
func (r *Router) RouteToConcrete(ctx context.Context, modelID string, routeCtx *RouteContext) (*Decision, error) {
|
||||
const maxDepth = 10
|
||||
visited := make(map[string]bool)
|
||||
current := modelID
|
||||
var chain []*Decision
|
||||
|
||||
for depth := 0; depth < maxDepth; depth++ {
|
||||
if !r.IsGroup(current) {
|
||||
// Build a composite reason showing the chain traversed
|
||||
reason := "direct"
|
||||
if len(chain) > 0 {
|
||||
parts := make([]string, len(chain))
|
||||
for i, d := range chain {
|
||||
parts[i] = d.SelectedModel + " (" + d.Reason + ")"
|
||||
}
|
||||
reason = strings.Join(parts, " -> ")
|
||||
}
|
||||
return &Decision{
|
||||
SelectedModel: current,
|
||||
Strategy: "hierarchical",
|
||||
Reason: reason,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if visited[current] {
|
||||
return nil, fmt.Errorf("routing cycle detected: group %s already visited", current)
|
||||
}
|
||||
visited[current] = true
|
||||
|
||||
decision, err := r.Route(ctx, current, routeCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chain = append(chain, decision)
|
||||
current = decision.SelectedModel
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("routing depth exceeded: reached max depth of %d", maxDepth)
|
||||
}
|
||||
|
||||
// Reload replaces the group definitions without recreating the router.
|
||||
func (r *Router) Reload(groups []db.ModelGroup) {
|
||||
r.groups = make(map[string]db.ModelGroup)
|
||||
for _, g := range groups {
|
||||
r.groups[g.ID] = g
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,372 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type UsagePeriodFilter struct {
|
||||
Period string `form:"period"`
|
||||
From string `form:"from"`
|
||||
To string `form:"to"`
|
||||
}
|
||||
|
||||
func (f *UsagePeriodFilter) ToSQL() (string, []interface{}) {
|
||||
period := f.Period
|
||||
if period == "" {
|
||||
period = "all"
|
||||
}
|
||||
|
||||
if period == "custom" {
|
||||
var clauses []string
|
||||
var binds []interface{}
|
||||
if f.From != "" {
|
||||
clauses = append(clauses, "timestamp >= ?")
|
||||
binds = append(binds, f.From)
|
||||
}
|
||||
if f.To != "" {
|
||||
clauses = append(clauses, "timestamp <= ?")
|
||||
binds = append(binds, f.To)
|
||||
}
|
||||
if len(clauses) > 0 {
|
||||
return " AND " + strings.Join(clauses, " AND "), binds
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
var cutoff time.Time
|
||||
switch period {
|
||||
case "today":
|
||||
cutoff = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)
|
||||
case "24h":
|
||||
cutoff = now.Add(-24 * time.Hour)
|
||||
case "7d":
|
||||
cutoff = now.Add(-7 * 24 * time.Hour)
|
||||
case "30d":
|
||||
cutoff = now.Add(-30 * 24 * time.Hour)
|
||||
default:
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return " AND timestamp >= ?", []interface{}{cutoff.Format(time.RFC3339)}
|
||||
}
|
||||
|
||||
func (s *Server) handleUsageSummary(c *gin.Context) {
|
||||
var filter UsagePeriodFilter
|
||||
if err := c.ShouldBindQuery(&filter); err != nil {
|
||||
// ignore
|
||||
}
|
||||
|
||||
clause, binds := filter.ToSQL()
|
||||
|
||||
// Total stats
|
||||
var totalStats struct {
|
||||
TotalRequests int `db:"total_requests"`
|
||||
TotalTokens int `db:"total_tokens"`
|
||||
CacheReadTokens int `db:"total_cache_read_tokens"`
|
||||
CacheWriteTokens int `db:"total_cache_write_tokens"`
|
||||
TotalCost float64 `db:"total_cost"`
|
||||
ActiveClients int `db:"active_clients"`
|
||||
}
|
||||
err := s.database.Get(&totalStats, fmt.Sprintf(`
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(total_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||||
COALESCE(SUM(cache_write_tokens), 0) as total_cache_write_tokens,
|
||||
COALESCE(SUM(cost), 0.0) as total_cost,
|
||||
COUNT(DISTINCT client_id) as active_clients
|
||||
FROM llm_requests
|
||||
WHERE 1=1 %s
|
||||
`, clause), binds...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Today stats
|
||||
var todayStats struct {
|
||||
TodayRequests int `db:"today_requests"`
|
||||
TodayCost float64 `db:"today_cost"`
|
||||
}
|
||||
today := time.Now().UTC().Format("2006-01-02")
|
||||
err = s.database.Get(&todayStats, `
|
||||
SELECT
|
||||
COUNT(*) as today_requests,
|
||||
COALESCE(SUM(cost), 0.0) as today_cost
|
||||
FROM llm_requests
|
||||
WHERE timestamp LIKE ?
|
||||
`, today+"%")
|
||||
if err != nil {
|
||||
todayStats.TodayRequests = 0
|
||||
todayStats.TodayCost = 0.0
|
||||
}
|
||||
|
||||
// Error rate & Avg response time
|
||||
var miscStats struct {
|
||||
ErrorRate float64 `db:"error_rate"`
|
||||
AvgResponseTime float64 `db:"avg_response_time"`
|
||||
}
|
||||
err = s.database.Get(&miscStats, `
|
||||
SELECT
|
||||
CASE WHEN COUNT(*) = 0 THEN 0.0 ELSE (CAST(SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) AS FLOAT) / COUNT(*)) * 100.0 END as error_rate,
|
||||
COALESCE(AVG(duration_ms), 0.0) as avg_response_time
|
||||
FROM llm_requests
|
||||
`)
|
||||
if err != nil {
|
||||
miscStats.ErrorRate = 0.0
|
||||
miscStats.AvgResponseTime = 0.0
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"total_requests": totalStats.TotalRequests,
|
||||
"total_tokens": totalStats.TotalTokens,
|
||||
"total_cache_read_tokens": totalStats.CacheReadTokens,
|
||||
"total_cache_write_tokens": totalStats.CacheWriteTokens,
|
||||
"total_cost": totalStats.TotalCost,
|
||||
"active_clients": totalStats.ActiveClients,
|
||||
"today_requests": todayStats.TodayRequests,
|
||||
"today_cost": todayStats.TodayCost,
|
||||
"error_rate": miscStats.ErrorRate,
|
||||
"avg_response_time": miscStats.AvgResponseTime,
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) handleTimeSeries(c *gin.Context) {
|
||||
var filter UsagePeriodFilter
|
||||
if err := c.ShouldBindQuery(&filter); err != nil {
|
||||
// ignore
|
||||
}
|
||||
|
||||
clause, binds := filter.ToSQL()
|
||||
|
||||
if clause == "" {
|
||||
cutoff := time.Now().UTC().Add(-30 * 24 * time.Hour)
|
||||
clause = " AND timestamp >= ?"
|
||||
binds = []interface{}{cutoff.Format(time.RFC3339)}
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
COALESCE(SUBSTR(timestamp, 1, 10), 'unknown') as bucket,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(total_tokens), 0) as tokens,
|
||||
COALESCE(SUM(cost), 0.0) as cost
|
||||
FROM llm_requests
|
||||
WHERE 1=1 %s
|
||||
GROUP BY bucket
|
||||
ORDER BY bucket
|
||||
`, clause)
|
||||
|
||||
rows, err := s.database.Queryx(query, binds...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var series []gin.H
|
||||
for rows.Next() {
|
||||
var bucket string
|
||||
var requests int
|
||||
var tokens int
|
||||
var cost float64
|
||||
if err := rows.Scan(&bucket, &requests, &tokens, &cost); err != nil {
|
||||
continue
|
||||
}
|
||||
series = append(series, gin.H{
|
||||
"time": bucket,
|
||||
"requests": requests,
|
||||
"tokens": tokens,
|
||||
"cost": cost,
|
||||
})
|
||||
}
|
||||
|
||||
granularity := "day"
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"series": series,
|
||||
"granularity": granularity,
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) handleProvidersUsage(c *gin.Context) {
|
||||
var filter UsagePeriodFilter
|
||||
if err := c.ShouldBindQuery(&filter); err != nil {
|
||||
// ignore
|
||||
}
|
||||
|
||||
clause, binds := filter.ToSQL()
|
||||
|
||||
rows, err := s.database.Queryx(fmt.Sprintf(`
|
||||
SELECT
|
||||
COALESCE(provider, 'unknown') as provider,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(cost), 0.0) as cost
|
||||
FROM llm_requests
|
||||
WHERE 1=1 %s
|
||||
GROUP BY provider
|
||||
`, clause), binds...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, SuccessResponse([]interface{}{}))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []gin.H
|
||||
for rows.Next() {
|
||||
var provider string
|
||||
var requests int
|
||||
var cost float64
|
||||
if err := rows.Scan(&provider, &requests, &cost); err == nil {
|
||||
results = append(results, gin.H{"provider": provider, "requests": requests, "cost": cost})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(results))
|
||||
}
|
||||
|
||||
func (s *Server) handleClientsUsage(c *gin.Context) {
|
||||
var filter UsagePeriodFilter
|
||||
if err := c.ShouldBindQuery(&filter); err != nil {
|
||||
// ignore
|
||||
}
|
||||
|
||||
clause, binds := filter.ToSQL()
|
||||
|
||||
rows, err := s.database.Queryx(fmt.Sprintf(`
|
||||
SELECT COALESCE(client_id, 'unknown') as client_id, COUNT(*) as requests
|
||||
FROM llm_requests
|
||||
WHERE 1=1 %s
|
||||
GROUP BY client_id
|
||||
`, clause), binds...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, SuccessResponse([]interface{}{}))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []gin.H
|
||||
for rows.Next() {
|
||||
var clientID string
|
||||
var requests int
|
||||
if err := rows.Scan(&clientID, &requests); err == nil {
|
||||
results = append(results, gin.H{"client_id": clientID, "requests": requests})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(results))
|
||||
}
|
||||
|
||||
func (s *Server) handleAnalyticsBreakdown(c *gin.Context) {
|
||||
var filter UsagePeriodFilter
|
||||
if err := c.ShouldBindQuery(&filter); err != nil {
|
||||
// ignore
|
||||
}
|
||||
|
||||
clause, binds := filter.ToSQL()
|
||||
|
||||
// Models breakdown
|
||||
var models []struct {
|
||||
Label string `json:"label"`
|
||||
Value int `json:"value"`
|
||||
}
|
||||
mRows, err := s.database.Queryx(fmt.Sprintf("SELECT COALESCE(model, 'unknown') as label, COUNT(*) as value FROM llm_requests WHERE 1=1 %s GROUP BY model ORDER BY value DESC", clause), binds...)
|
||||
if err == nil {
|
||||
for mRows.Next() {
|
||||
var label string
|
||||
var value int
|
||||
if err := mRows.Scan(&label, &value); err == nil {
|
||||
models = append(models, struct {
|
||||
Label string `json:"label"`
|
||||
Value int `json:"value"`
|
||||
}{label, value})
|
||||
}
|
||||
}
|
||||
mRows.Close()
|
||||
}
|
||||
|
||||
// Clients breakdown
|
||||
var clients []struct {
|
||||
Label string `json:"label"`
|
||||
Value int `json:"value"`
|
||||
}
|
||||
cRows, err := s.database.Queryx(fmt.Sprintf("SELECT COALESCE(client_id, 'unknown') as label, COUNT(*) as value FROM llm_requests WHERE 1=1 %s GROUP BY client_id ORDER BY value DESC", clause), binds...)
|
||||
if err == nil {
|
||||
for cRows.Next() {
|
||||
var label string
|
||||
var value int
|
||||
if err := cRows.Scan(&label, &value); err == nil {
|
||||
clients = append(clients, struct {
|
||||
Label string `json:"label"`
|
||||
Value int `json:"value"`
|
||||
}{label, value})
|
||||
}
|
||||
}
|
||||
cRows.Close()
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"models": models,
|
||||
"clients": clients,
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) handleDetailedUsage(c *gin.Context) {
|
||||
var filter UsagePeriodFilter
|
||||
if err := c.ShouldBindQuery(&filter); err != nil {
|
||||
// ignore
|
||||
}
|
||||
|
||||
clause, binds := filter.ToSQL()
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
COALESCE(SUBSTR(timestamp, 1, 10), 'unknown') as date,
|
||||
COALESCE(client_id, 'unknown') as client,
|
||||
COALESCE(provider, 'unknown') as provider,
|
||||
COALESCE(model, 'unknown') as model,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(total_tokens), 0) as tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens,
|
||||
COALESCE(SUM(cache_write_tokens), 0) as cache_write_tokens,
|
||||
COALESCE(SUM(cost), 0.0) as cost
|
||||
FROM llm_requests
|
||||
WHERE 1=1 %s
|
||||
GROUP BY date, client, provider, model
|
||||
ORDER BY date DESC, cost DESC
|
||||
`, clause)
|
||||
|
||||
rows, err := s.database.Queryx(query, binds...)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, SuccessResponse([]interface{}{}))
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []gin.H
|
||||
for rows.Next() {
|
||||
var date, client, provider, model string
|
||||
var requests, tokens, cacheRead, cacheWrite int
|
||||
var cost float64
|
||||
if err := rows.Scan(&date, &client, &provider, &model, &requests, &tokens, &cacheRead, &cacheWrite, &cost); err == nil {
|
||||
results = append(results, gin.H{
|
||||
"date": date,
|
||||
"client": client,
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
"requests": requests,
|
||||
"tokens": tokens,
|
||||
"cache_read_tokens": cacheRead,
|
||||
"cache_write_tokens": cacheWrite,
|
||||
"cost": cost,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(results))
|
||||
}
|
||||
@@ -0,0 +1,273 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"gophergate/internal/db"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func (s *Server) handleGetClients(c *gin.Context) {
|
||||
var clients []db.Client
|
||||
err := s.database.Select(&clients, "SELECT * FROM clients ORDER BY created_at DESC")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
type UIClient struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
LastUsed *time.Time `json:"last_used"`
|
||||
RequestsCount int `json:"requests_count"`
|
||||
TokensCount int `json:"tokens_count"`
|
||||
Status string `json:"status"`
|
||||
RateLimitPerMinute int `json:"rate_limit_per_minute"`
|
||||
}
|
||||
|
||||
uiClients := make([]UIClient, len(clients))
|
||||
for i, cl := range clients {
|
||||
status := "active"
|
||||
if !cl.IsActive {
|
||||
status = "disabled"
|
||||
}
|
||||
|
||||
name := ""
|
||||
if cl.Name != nil {
|
||||
name = *cl.Name
|
||||
}
|
||||
desc := ""
|
||||
if cl.Description != nil {
|
||||
desc = *cl.Description
|
||||
}
|
||||
|
||||
var lastUsedStr string
|
||||
_ = s.database.Get(&lastUsedStr, "SELECT MAX(last_used_at) FROM client_tokens WHERE client_id = ?", cl.ClientID)
|
||||
|
||||
var lastUsed *time.Time
|
||||
if lastUsedStr != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", lastUsedStr); err == nil {
|
||||
lastUsed = &t
|
||||
}
|
||||
}
|
||||
|
||||
uiClients[i] = UIClient{
|
||||
ID: cl.ClientID,
|
||||
Name: name,
|
||||
Description: desc,
|
||||
CreatedAt: cl.CreatedAt,
|
||||
LastUsed: lastUsed,
|
||||
RequestsCount: cl.TotalRequests,
|
||||
TokensCount: cl.TotalTokens,
|
||||
Status: status,
|
||||
RateLimitPerMinute: cl.RateLimitPerMinute,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(uiClients))
|
||||
}
|
||||
|
||||
func (s *Server) handleGetClient(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var cl db.Client
|
||||
err := s.database.Get(&cl, "SELECT * FROM clients WHERE client_id = ?", id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, ErrorResponse("Client not found"))
|
||||
return
|
||||
}
|
||||
|
||||
name := ""
|
||||
if cl.Name != nil {
|
||||
name = *cl.Name
|
||||
}
|
||||
desc := ""
|
||||
if cl.Description != nil {
|
||||
desc = *cl.Description
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"id": cl.ClientID,
|
||||
"name": name,
|
||||
"description": desc,
|
||||
"is_active": cl.IsActive,
|
||||
"rate_limit_per_minute": cl.RateLimitPerMinute,
|
||||
"created_at": cl.CreatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
type UpdateClientRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
IsActive bool `json:"is_active"`
|
||||
RateLimitPerMinute *int `json:"rate_limit_per_minute"`
|
||||
}
|
||||
|
||||
func (s *Server) handleUpdateClient(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var req UpdateClientRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
_, err := s.database.Exec(`
|
||||
UPDATE clients SET
|
||||
name = ?,
|
||||
description = ?,
|
||||
is_active = ?,
|
||||
rate_limit_per_minute = COALESCE(?, rate_limit_per_minute),
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE client_id = ?
|
||||
`, req.Name, req.Description, req.IsActive, req.RateLimitPerMinute, id)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Client updated"}))
|
||||
}
|
||||
|
||||
type CreateClientRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
ClientID *string `json:"client_id"`
|
||||
}
|
||||
|
||||
func (s *Server) handleCreateClient(c *gin.Context) {
|
||||
var req CreateClientRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
clientID := ""
|
||||
if req.ClientID != nil {
|
||||
clientID = *req.ClientID
|
||||
} else {
|
||||
clientID = "client-" + uuid.New().String()[:8]
|
||||
}
|
||||
|
||||
_, err := s.database.Exec("INSERT INTO clients (client_id, name, is_active) VALUES (?, ?, 1)", clientID, req.Name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
token := "sk-" + uuid.New().String() + uuid.New().String()
|
||||
token = token[:51]
|
||||
|
||||
_, err = s.database.Exec("INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, 'default')", clientID, token)
|
||||
if err != nil {
|
||||
// Log error
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"id": clientID,
|
||||
"name": req.Name,
|
||||
"status": "active",
|
||||
"token": token,
|
||||
"created_at": time.Now(),
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) handleDeleteClient(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "default" {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse("Cannot delete default client"))
|
||||
return
|
||||
}
|
||||
|
||||
_, err := s.database.Exec("DELETE FROM clients WHERE client_id = ?", id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Client deleted"}))
|
||||
}
|
||||
|
||||
func (s *Server) handleGetClientTokens(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var tokens []db.ClientToken
|
||||
err := s.database.Select(&tokens, "SELECT * FROM client_tokens WHERE client_id = ? ORDER BY created_at DESC", id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
type MaskedToken struct {
|
||||
ID int `json:"id"`
|
||||
TokenMasked string `json:"token_masked"`
|
||||
Name string `json:"name"`
|
||||
IsActive bool `json:"is_active"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
}
|
||||
|
||||
masked := make([]MaskedToken, len(tokens))
|
||||
for i, t := range tokens {
|
||||
maskedToken := "••••"
|
||||
if len(t.Token) > 8 {
|
||||
maskedToken = t.Token[:3] + "••••" + t.Token[len(t.Token)-8:]
|
||||
}
|
||||
masked[i] = MaskedToken{
|
||||
ID: t.ID,
|
||||
TokenMasked: maskedToken,
|
||||
Name: t.Name,
|
||||
IsActive: t.IsActive,
|
||||
CreatedAt: t.CreatedAt,
|
||||
LastUsedAt: t.LastUsedAt,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(masked))
|
||||
}
|
||||
|
||||
type CreateTokenRequest struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (s *Server) handleCreateClientToken(c *gin.Context) {
|
||||
clientID := c.Param("id")
|
||||
var req CreateTokenRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// optional name
|
||||
}
|
||||
|
||||
name := "default"
|
||||
if req.Name != "" {
|
||||
name = req.Name
|
||||
}
|
||||
|
||||
token := "sk-" + uuid.New().String() + uuid.New().String()
|
||||
token = token[:51]
|
||||
|
||||
_, err := s.database.Exec("INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, ?)", clientID, token, name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"token": token,
|
||||
"name": name,
|
||||
"created_at": time.Now(),
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) handleDeleteClientToken(c *gin.Context) {
|
||||
tokenID := c.Param("token_id")
|
||||
|
||||
_, err := s.database.Exec("DELETE FROM client_tokens WHERE id = ?", tokenID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Token revoked"}))
|
||||
}
|
||||
+5
-1234
File diff suppressed because it is too large
Load Diff
@@ -12,6 +12,7 @@ type RequestLog struct {
|
||||
ClientID string `json:"client_id"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
ModelGroup string `json:"model_group,omitempty"`
|
||||
PromptTokens uint32 `json:"prompt_tokens"`
|
||||
CompletionTokens uint32 `json:"completion_tokens"`
|
||||
ReasoningTokens uint32 `json:"reasoning_tokens"`
|
||||
@@ -72,7 +73,7 @@ func (l *RequestLogger) processLog(entry RequestLog) {
|
||||
defer tx.Rollback()
|
||||
|
||||
// Ensure client exists
|
||||
_, _ = tx.Exec("INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')",
|
||||
_, _ = tx.Exec("INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')",
|
||||
entry.ClientID, entry.ClientID)
|
||||
|
||||
// Insert log
|
||||
@@ -80,9 +81,9 @@ func (l *RequestLogger) processLog(entry RequestLog) {
|
||||
INSERT INTO llm_requests
|
||||
(timestamp, client_id, provider, model, prompt_tokens, completion_tokens, reasoning_tokens, total_tokens, cache_read_tokens, cache_write_tokens, cost, has_images, status, error_message, duration_ms)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`, entry.Timestamp, entry.ClientID, entry.Provider, entry.Model,
|
||||
entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.TotalTokens,
|
||||
entry.CacheReadTokens, entry.CacheWriteTokens, entry.Cost, entry.HasImages,
|
||||
`, entry.Timestamp, entry.ClientID, entry.Provider, entry.Model,
|
||||
entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.TotalTokens,
|
||||
entry.CacheReadTokens, entry.CacheWriteTokens, entry.Cost, entry.HasImages,
|
||||
entry.Status, entry.ErrorMessage, entry.DurationMS)
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"gophergate/internal/db"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (s *Server) handleGetModelGroups(c *gin.Context) {
|
||||
var groups []db.ModelGroup
|
||||
if err := s.database.Select(&groups, "SELECT * FROM model_groups ORDER BY id"); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if groups == nil {
|
||||
groups = []db.ModelGroup{}
|
||||
}
|
||||
c.JSON(http.StatusOK, SuccessResponse(groups))
|
||||
}
|
||||
|
||||
func (s *Server) handleCreateModelGroup(c *gin.Context) {
|
||||
var group db.ModelGroup
|
||||
if err := c.ShouldBindJSON(&group); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
_, err := s.database.Exec(`
|
||||
INSERT INTO model_groups (id, strategy, selector_model, targets, complexity_threshold, heuristic_rules, logic_level, primary_use)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
group.ID, group.Strategy, group.SelectorModel, group.Targets,
|
||||
group.ComplexityThreshold, group.HeuristicRules, group.LogicLevel, group.PrimaryUse)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
s.refreshRouter()
|
||||
c.JSON(http.StatusCreated, SuccessResponse(group))
|
||||
}
|
||||
|
||||
func (s *Server) handleUpdateModelGroup(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var group db.ModelGroup
|
||||
if err := c.ShouldBindJSON(&group); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
_, err := s.database.Exec(`
|
||||
UPDATE model_groups SET strategy=?, selector_model=?, targets=?, complexity_threshold=?, heuristic_rules=?, logic_level=?, primary_use=?, updated_at=CURRENT_TIMESTAMP
|
||||
WHERE id=?`,
|
||||
group.Strategy, group.SelectorModel, group.Targets,
|
||||
group.ComplexityThreshold, group.HeuristicRules, group.LogicLevel, group.PrimaryUse, id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
s.refreshRouter()
|
||||
c.JSON(http.StatusOK, SuccessResponse(group))
|
||||
}
|
||||
|
||||
func (s *Server) handleDeleteModelGroup(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
_, err := s.database.Exec("DELETE FROM model_groups WHERE id=?", id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
s.refreshRouter()
|
||||
c.JSON(http.StatusOK, gin.H{"status": "deleted"})
|
||||
}
|
||||
@@ -0,0 +1,231 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"gophergate/internal/db"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (s *Server) handleGetModels(c *gin.Context) {
|
||||
usedOnly := c.Query("used_only") == "true"
|
||||
|
||||
// Registry provider normalized name -> Proxy-internal provider ID
|
||||
allowedRegistryProviders := map[string]string{
|
||||
"openai": "openai",
|
||||
"google": "gemini",
|
||||
"deepseek": "deepseek",
|
||||
"xai": "grok",
|
||||
"ollama": "ollama",
|
||||
"xiaomi": "xiaomi",
|
||||
}
|
||||
|
||||
// Merge registry models with DB overrides
|
||||
var dbModels []db.ModelConfig
|
||||
_ = s.database.Select(&dbModels, "SELECT * FROM model_configs")
|
||||
|
||||
dbMap := make(map[string]db.ModelConfig)
|
||||
for _, m := range dbModels {
|
||||
dbMap[m.ID] = m
|
||||
}
|
||||
|
||||
// Fetch specific (model, provider) combinations that have been used
|
||||
type modelProvider struct {
|
||||
Model string `db:"model"`
|
||||
Provider string `db:"provider"`
|
||||
}
|
||||
usedPairs := make(map[string]bool)
|
||||
if usedOnly {
|
||||
var pairs []modelProvider
|
||||
err := s.database.Select(&pairs, "SELECT DISTINCT model, provider FROM llm_requests WHERE status = 'success'")
|
||||
if err == nil {
|
||||
for _, p := range pairs {
|
||||
usedPairs[fmt.Sprintf("%s:%s", p.Model, p.Provider)] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var result []gin.H
|
||||
s.registryMu.RLock()
|
||||
if s.registry != nil {
|
||||
for pID, pInfo := range s.registry.Providers {
|
||||
proxyProvider, allowed := allowedRegistryProviders[pID]
|
||||
if !allowed {
|
||||
continue
|
||||
}
|
||||
|
||||
for mID, mMeta := range pInfo.Models {
|
||||
if usedOnly && !usedPairs[fmt.Sprintf("%s:%s", mID, proxyProvider)] {
|
||||
continue
|
||||
}
|
||||
|
||||
enabled := true
|
||||
promptCost := 0.0
|
||||
completionCost := 0.0
|
||||
var cacheReadCost *float64
|
||||
var cacheWriteCost *float64
|
||||
var mapping *string
|
||||
contextLimit := uint32(0)
|
||||
|
||||
if mMeta.Cost != nil {
|
||||
promptCost = mMeta.Cost.Input
|
||||
completionCost = mMeta.Cost.Output
|
||||
cacheReadCost = mMeta.Cost.CacheRead
|
||||
cacheWriteCost = mMeta.Cost.CacheWrite
|
||||
}
|
||||
if mMeta.Limit != nil {
|
||||
contextLimit = mMeta.Limit.Context
|
||||
}
|
||||
|
||||
// Override from DB
|
||||
if dbCfg, ok := dbMap[mID]; ok {
|
||||
enabled = dbCfg.Enabled
|
||||
if dbCfg.PromptCostPerM != nil {
|
||||
promptCost = *dbCfg.PromptCostPerM
|
||||
}
|
||||
if dbCfg.CompletionCostPerM != nil {
|
||||
completionCost = *dbCfg.CompletionCostPerM
|
||||
}
|
||||
if dbCfg.CacheReadCostPerM != nil {
|
||||
cacheReadCost = dbCfg.CacheReadCostPerM
|
||||
}
|
||||
if dbCfg.CacheWriteCostPerM != nil {
|
||||
cacheWriteCost = dbCfg.CacheWriteCostPerM
|
||||
}
|
||||
mapping = dbCfg.Mapping
|
||||
}
|
||||
|
||||
result = append(result, gin.H{
|
||||
"id": mID,
|
||||
"name": mMeta.Name,
|
||||
"provider": proxyProvider,
|
||||
"enabled": enabled,
|
||||
"prompt_cost": promptCost,
|
||||
"completion_cost": completionCost,
|
||||
"cache_read_cost": cacheReadCost,
|
||||
"cache_write_cost": cacheWriteCost,
|
||||
"context_limit": contextLimit,
|
||||
"mapping": mapping,
|
||||
"tool_call": mMeta.ToolCall != nil && *mMeta.ToolCall,
|
||||
"reasoning": mMeta.Reasoning != nil && *mMeta.Reasoning,
|
||||
"modalities": mMeta.Modalities,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add configured Ollama models if they aren't in registry
|
||||
if s.cfg.Providers.Ollama.Enabled {
|
||||
for _, mID := range s.cfg.Providers.Ollama.Models {
|
||||
// Check if already added from registry
|
||||
exists := false
|
||||
for _, r := range result {
|
||||
if r["id"] == mID {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if exists {
|
||||
continue
|
||||
}
|
||||
|
||||
if usedOnly && !usedPairs[fmt.Sprintf("%s:ollama", mID)] {
|
||||
continue
|
||||
}
|
||||
|
||||
enabled := true
|
||||
promptCost := 0.0
|
||||
completionCost := 0.0
|
||||
var cacheReadCost *float64
|
||||
var cacheWriteCost *float64
|
||||
var mapping *string
|
||||
contextLimit := uint32(0)
|
||||
|
||||
// Override from DB
|
||||
if dbCfg, ok := dbMap[mID]; ok {
|
||||
enabled = dbCfg.Enabled
|
||||
if dbCfg.PromptCostPerM != nil {
|
||||
promptCost = *dbCfg.PromptCostPerM
|
||||
}
|
||||
if dbCfg.CompletionCostPerM != nil {
|
||||
completionCost = *dbCfg.CompletionCostPerM
|
||||
}
|
||||
if dbCfg.CacheReadCostPerM != nil {
|
||||
cacheReadCost = dbCfg.CacheReadCostPerM
|
||||
}
|
||||
if dbCfg.CacheWriteCostPerM != nil {
|
||||
cacheWriteCost = dbCfg.CacheWriteCostPerM
|
||||
}
|
||||
mapping = dbCfg.Mapping
|
||||
}
|
||||
|
||||
result = append(result, gin.H{
|
||||
"id": mID,
|
||||
"name": mID,
|
||||
"provider": "ollama",
|
||||
"enabled": enabled,
|
||||
"prompt_cost": promptCost,
|
||||
"completion_cost": completionCost,
|
||||
"cache_read_cost": cacheReadCost,
|
||||
"cache_write_cost": cacheWriteCost,
|
||||
"context_limit": contextLimit,
|
||||
"modalities": gin.H{"input": []string{"text"}, "output": []string{"text"}},
|
||||
"tool_call": false,
|
||||
"reasoning": false,
|
||||
"mapping": mapping,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(result))
|
||||
}
|
||||
|
||||
func (s *Server) handleUpdateModel(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var req struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
PromptCost float64 `json:"prompt_cost"`
|
||||
CompletionCost float64 `json:"completion_cost"`
|
||||
CacheReadCost *float64 `json:"cache_read_cost"`
|
||||
CacheWriteCost *float64 `json:"cache_write_cost"`
|
||||
Mapping *string `json:"mapping"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
// Find provider for this model
|
||||
providerID := "unknown"
|
||||
s.registryMu.RLock()
|
||||
if s.registry != nil {
|
||||
for pID, pInfo := range s.registry.Providers {
|
||||
if _, ok := pInfo.Models[id]; ok {
|
||||
providerID = pID
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, err := s.database.Exec(`
|
||||
INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, cache_read_cost_per_m, cache_write_cost_per_m, mapping)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
enabled = excluded.enabled,
|
||||
prompt_cost_per_m = excluded.prompt_cost_per_m,
|
||||
completion_cost_per_m = excluded.completion_cost_per_m,
|
||||
cache_read_cost_per_m = excluded.cache_read_cost_per_m,
|
||||
cache_write_cost_per_m = excluded.cache_write_cost_per_m,
|
||||
mapping = excluded.mapping,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
`, id, providerID, req.Enabled, req.PromptCost, req.CompletionCost, req.CacheReadCost, req.CacheWriteCost, req.Mapping)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Model updated"}))
|
||||
}
|
||||
@@ -0,0 +1,252 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gophergate/internal/db"
|
||||
"gophergate/internal/models"
|
||||
"gophergate/internal/utils"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (s *Server) handleGetProviders(c *gin.Context) {
|
||||
var dbConfigs []db.ProviderConfig
|
||||
err := s.database.Select(&dbConfigs, "SELECT id, enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs")
|
||||
if err != nil {
|
||||
// Log error
|
||||
}
|
||||
|
||||
dbMap := make(map[string]db.ProviderConfig)
|
||||
for _, cfg := range dbConfigs {
|
||||
dbMap[cfg.ID] = cfg
|
||||
}
|
||||
|
||||
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama", "xiaomi"}
|
||||
var result []gin.H
|
||||
|
||||
for _, id := range providerIDs {
|
||||
var name string
|
||||
var enabled bool
|
||||
var baseURL string
|
||||
|
||||
switch id {
|
||||
case "openai":
|
||||
name = "OpenAI"
|
||||
enabled = s.cfg.Providers.OpenAI.Enabled
|
||||
baseURL = s.cfg.Providers.OpenAI.BaseURL
|
||||
case "gemini":
|
||||
name = "Google Gemini"
|
||||
enabled = s.cfg.Providers.Gemini.Enabled
|
||||
baseURL = s.cfg.Providers.Gemini.BaseURL
|
||||
case "deepseek":
|
||||
name = "DeepSeek"
|
||||
enabled = s.cfg.Providers.DeepSeek.Enabled
|
||||
baseURL = s.cfg.Providers.DeepSeek.BaseURL
|
||||
case "moonshot":
|
||||
name = "Moonshot"
|
||||
enabled = s.cfg.Providers.Moonshot.Enabled
|
||||
baseURL = s.cfg.Providers.Moonshot.BaseURL
|
||||
case "grok":
|
||||
name = "xAI Grok"
|
||||
enabled = s.cfg.Providers.Grok.Enabled
|
||||
baseURL = s.cfg.Providers.Grok.BaseURL
|
||||
case "xiaomi":
|
||||
name = "Xiaomi MiMo"
|
||||
enabled = s.cfg.Providers.Xiaomi.Enabled
|
||||
baseURL = s.cfg.Providers.Xiaomi.BaseURL
|
||||
case "ollama":
|
||||
name = "Ollama"
|
||||
enabled = s.cfg.Providers.Ollama.Enabled
|
||||
baseURL = s.cfg.Providers.Ollama.BaseURL
|
||||
}
|
||||
|
||||
var balance float64
|
||||
var threshold float64 = 5.0
|
||||
var billingMode string
|
||||
|
||||
if dbCfg, ok := dbMap[id]; ok {
|
||||
enabled = dbCfg.Enabled
|
||||
if dbCfg.BaseURL != nil {
|
||||
baseURL = *dbCfg.BaseURL
|
||||
}
|
||||
balance = dbCfg.CreditBalance
|
||||
threshold = dbCfg.LowCreditThreshold
|
||||
if dbCfg.BillingMode != nil {
|
||||
billingMode = *dbCfg.BillingMode
|
||||
}
|
||||
}
|
||||
|
||||
status := "disabled"
|
||||
if enabled {
|
||||
if _, ok := s.providers[id]; ok {
|
||||
status = "online"
|
||||
} else {
|
||||
status = "error"
|
||||
}
|
||||
}
|
||||
|
||||
// Get last used for this provider
|
||||
var lastUsedStr string
|
||||
_ = s.database.Get(&lastUsedStr, "SELECT MAX(timestamp) FROM llm_requests WHERE provider = ?", id)
|
||||
var lastUsed interface{}
|
||||
if lastUsedStr != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", lastUsedStr); err == nil {
|
||||
lastUsed = t
|
||||
}
|
||||
}
|
||||
|
||||
// Get models for this provider from registry
|
||||
var models []string
|
||||
s.registryMu.RLock()
|
||||
if s.registry != nil {
|
||||
registryID := id
|
||||
if id == "gemini" {
|
||||
registryID = "google"
|
||||
}
|
||||
if id == "moonshot" {
|
||||
registryID = "moonshot"
|
||||
}
|
||||
if id == "grok" {
|
||||
registryID = "xai"
|
||||
}
|
||||
if id == "xiaomi" {
|
||||
registryID = "xiaomi"
|
||||
}
|
||||
|
||||
if pInfo, ok := s.registry.Providers[registryID]; ok {
|
||||
for mID := range pInfo.Models {
|
||||
models = append(models, mID)
|
||||
}
|
||||
}
|
||||
}
|
||||
s.registryMu.RUnlock()
|
||||
|
||||
// If it's ollama, also include models from config
|
||||
if id == "ollama" {
|
||||
models = append(models, s.cfg.Providers.Ollama.Models...)
|
||||
}
|
||||
|
||||
result = append(result, gin.H{
|
||||
"id": id,
|
||||
"name": name,
|
||||
"enabled": enabled,
|
||||
"status": status,
|
||||
"base_url": baseURL,
|
||||
"credit_balance": balance,
|
||||
"low_credit_threshold": threshold,
|
||||
"billing_mode": billingMode,
|
||||
"last_used": lastUsed,
|
||||
"models": models,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(result))
|
||||
}
|
||||
|
||||
type UpdateProviderRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
BaseURL *string `json:"base_url"`
|
||||
APIKey *string `json:"api_key"`
|
||||
CreditBalance *float64 `json:"credit_balance"`
|
||||
LowCreditThreshold *float64 `json:"low_credit_threshold"`
|
||||
BillingMode *string `json:"billing_mode"`
|
||||
}
|
||||
|
||||
func (s *Server) handleUpdateProvider(c *gin.Context) {
|
||||
name := c.Param("name")
|
||||
var req UpdateProviderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
apiKeyEncrypted := false
|
||||
var apiKey *string = req.APIKey
|
||||
if req.APIKey != nil && *req.APIKey != "" {
|
||||
encrypted, err := utils.Encrypt(*req.APIKey, s.cfg.KeyBytes)
|
||||
if err == nil {
|
||||
apiKey = &encrypted
|
||||
apiKeyEncrypted = true
|
||||
}
|
||||
}
|
||||
|
||||
_, err := s.database.Exec(`
|
||||
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode, api_key_encrypted)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
enabled = excluded.enabled,
|
||||
base_url = COALESCE(excluded.base_url, provider_configs.base_url),
|
||||
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
|
||||
api_key_encrypted = excluded.api_key_encrypted,
|
||||
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
|
||||
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
|
||||
billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
`, name, strings.ToUpper(name), req.Enabled, req.BaseURL, apiKey, req.CreditBalance, req.LowCreditThreshold, req.BillingMode, apiKeyEncrypted)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Refresh in-memory providers
|
||||
if err := s.RefreshProviders(); err != nil {
|
||||
fmt.Printf("Error refreshing providers: %v\n", err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Provider updated"}))
|
||||
}
|
||||
|
||||
func (s *Server) handleTestProvider(c *gin.Context) {
|
||||
name := c.Param("name")
|
||||
provider, ok := s.providers[name]
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, ErrorResponse(fmt.Sprintf("Provider %s not found or not enabled", name)))
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
// Prepare a simple test request
|
||||
testReq := &models.UnifiedRequest{
|
||||
Model: "gpt-4o-mini", // Default cheap test model
|
||||
Messages: []models.UnifiedMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []models.UnifiedContentPart{{Type: "text", Text: "Hi"}},
|
||||
},
|
||||
},
|
||||
MaxTokens: new(uint32),
|
||||
}
|
||||
*testReq.MaxTokens = 5
|
||||
|
||||
// Adjust model for non-openai providers
|
||||
if name == "gemini" {
|
||||
testReq.Model = "gemini-2.0-flash"
|
||||
} else if name == "deepseek" {
|
||||
testReq.Model = "deepseek-chat"
|
||||
} else if name == "moonshot" {
|
||||
testReq.Model = "kimi-k2.5"
|
||||
} else if name == "grok" {
|
||||
testReq.Model = "grok-4-1-fast-non-reasoning"
|
||||
} else if name == "xiaomi" {
|
||||
testReq.Model = "mimo-v2.5"
|
||||
}
|
||||
|
||||
_, err := provider.ChatCompletion(c.Request.Context(), testReq)
|
||||
latency := time.Since(startTime).Milliseconds()
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, ErrorResponse(fmt.Sprintf("Provider test failed: %v", err)))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"message": "Connection test successful",
|
||||
"latency": latency,
|
||||
}))
|
||||
}
|
||||
+668
-62
@@ -2,10 +2,13 @@ package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gophergate/internal/config"
|
||||
@@ -13,20 +16,23 @@ import (
|
||||
"gophergate/internal/middleware"
|
||||
"gophergate/internal/models"
|
||||
"gophergate/internal/providers"
|
||||
"gophergate/internal/router"
|
||||
"gophergate/internal/utils"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
router *gin.Engine
|
||||
cfg *config.Config
|
||||
database *db.DB
|
||||
providers map[string]providers.Provider
|
||||
sessions *SessionManager
|
||||
hub *Hub
|
||||
logger *RequestLogger
|
||||
registry *models.ModelRegistry
|
||||
router *gin.Engine
|
||||
cfg *config.Config
|
||||
database *db.DB
|
||||
providers map[string]providers.Provider
|
||||
sessions *SessionManager
|
||||
hub *Hub
|
||||
logger *RequestLogger
|
||||
registry *models.ModelRegistry
|
||||
registryMu sync.RWMutex
|
||||
modelRouter *router.Router
|
||||
}
|
||||
|
||||
func NewServer(cfg *config.Config, database *db.DB) *Server {
|
||||
@@ -44,6 +50,7 @@ func NewServer(cfg *config.Config, database *db.DB) *Server {
|
||||
registry: &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)},
|
||||
}
|
||||
|
||||
s.sessions.StartCleanup()
|
||||
// Fetch registry in background
|
||||
go func() {
|
||||
registry, err := utils.FetchRegistry()
|
||||
@@ -60,6 +67,9 @@ func NewServer(cfg *config.Config, database *db.DB) *Server {
|
||||
}
|
||||
|
||||
s.setupRoutes()
|
||||
|
||||
// Initialize model group router
|
||||
s.refreshRouter()
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -75,7 +85,7 @@ func (s *Server) RefreshProviders() error {
|
||||
dbMap[cfg.ID] = cfg
|
||||
}
|
||||
|
||||
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama"}
|
||||
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama", "xiaomi"}
|
||||
for _, id := range providerIDs {
|
||||
// Default values from config
|
||||
enabled := false
|
||||
@@ -103,6 +113,10 @@ func (s *Server) RefreshProviders() error {
|
||||
enabled = s.cfg.Providers.Grok.Enabled
|
||||
baseURL = s.cfg.Providers.Grok.BaseURL
|
||||
apiKey, _ = s.cfg.GetAPIKey("grok")
|
||||
case "xiaomi":
|
||||
enabled = s.cfg.Providers.Xiaomi.Enabled
|
||||
baseURL = s.cfg.Providers.Xiaomi.BaseURL
|
||||
apiKey, _ = s.cfg.GetAPIKey("xiaomi")
|
||||
}
|
||||
|
||||
// Overrides from DB
|
||||
@@ -131,40 +145,91 @@ func (s *Server) RefreshProviders() error {
|
||||
}
|
||||
|
||||
// Initialize provider
|
||||
var p providers.Provider
|
||||
switch id {
|
||||
case "openai":
|
||||
cfg := s.cfg.Providers.OpenAI
|
||||
cfg.BaseURL = baseURL
|
||||
s.providers["openai"] = providers.NewOpenAIProvider(cfg, apiKey)
|
||||
p = providers.NewOpenAIProvider(cfg, apiKey)
|
||||
case "gemini":
|
||||
cfg := s.cfg.Providers.Gemini
|
||||
cfg.BaseURL = baseURL
|
||||
s.providers["gemini"] = providers.NewGeminiProvider(cfg, apiKey)
|
||||
p = providers.NewGeminiProvider(cfg, apiKey)
|
||||
case "deepseek":
|
||||
cfg := s.cfg.Providers.DeepSeek
|
||||
cfg.BaseURL = baseURL
|
||||
s.providers["deepseek"] = providers.NewDeepSeekProvider(cfg, apiKey)
|
||||
p = providers.NewDeepSeekProvider(cfg, apiKey)
|
||||
case "moonshot":
|
||||
cfg := s.cfg.Providers.Moonshot
|
||||
cfg.BaseURL = baseURL
|
||||
s.providers["moonshot"] = providers.NewMoonshotProvider(cfg, apiKey)
|
||||
p = providers.NewMoonshotProvider(cfg, apiKey)
|
||||
case "grok":
|
||||
cfg := s.cfg.Providers.Grok
|
||||
cfg.BaseURL = baseURL
|
||||
s.providers["grok"] = providers.NewGrokProvider(cfg, apiKey)
|
||||
p = providers.NewGrokProvider(cfg, apiKey)
|
||||
case "ollama":
|
||||
cfg := s.cfg.Providers.Ollama
|
||||
cfg.BaseURL = baseURL
|
||||
s.providers["ollama"] = providers.NewOllamaProvider(cfg)
|
||||
p = providers.NewOllamaProvider(cfg)
|
||||
case "xiaomi":
|
||||
cfg := s.cfg.Providers.Xiaomi
|
||||
cfg.BaseURL = baseURL
|
||||
p = providers.NewXiaomiProvider(cfg, apiKey)
|
||||
}
|
||||
|
||||
if p != nil {
|
||||
s.providers[id] = providers.NewCircuitBreakerProvider(p)
|
||||
}
|
||||
}
|
||||
|
||||
s.refreshRouter()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) setupRoutes() {
|
||||
s.router.Use(middleware.AuthMiddleware(s.database))
|
||||
func (s *Server) refreshRouter() {
|
||||
var groups []db.ModelGroup
|
||||
if err := s.database.Select(&groups, "SELECT * FROM model_groups"); err != nil {
|
||||
fmt.Printf("Warning: Failed to load model groups: %v\n", err)
|
||||
groups = nil
|
||||
}
|
||||
|
||||
var classifyFn router.ClassifierFunc
|
||||
classifyFn = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) {
|
||||
provider, _, err := s.selectProvider(selectorModel)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req := &models.UnifiedRequest{
|
||||
Model: selectorModel,
|
||||
Messages: []models.UnifiedMessage{
|
||||
{Role: "system", Content: []models.UnifiedContentPart{{Type: "text", Text: systemPrompt}}},
|
||||
{Role: "user", Content: []models.UnifiedContentPart{{Type: "text", Text: userMessage}}},
|
||||
},
|
||||
MaxTokens: uint32Ptr(5),
|
||||
Stream: false,
|
||||
}
|
||||
resp, err := provider.ChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return "", fmt.Errorf("no choices in classifier response")
|
||||
}
|
||||
content, ok := resp.Choices[0].Message.Content.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("classifier response content is not a string")
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
if s.modelRouter == nil {
|
||||
s.modelRouter = router.New(groups, classifyFn)
|
||||
} else {
|
||||
s.modelRouter.Reload(groups)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) setupRoutes() {
|
||||
// Static files
|
||||
s.router.StaticFile("/", "./static/index.html")
|
||||
s.router.StaticFile("/favicon.ico", "./static/favicon.ico")
|
||||
@@ -177,10 +242,12 @@ func (s *Server) setupRoutes() {
|
||||
|
||||
// API V1 (External LLM Access) - Secured with AuthMiddleware
|
||||
v1 := s.router.Group("/v1")
|
||||
v1.Use(middleware.AuthMiddleware(s.database))
|
||||
v1.Use(middleware.AuthMiddleware(s.database, true))
|
||||
{
|
||||
v1.POST("/chat/completions", s.handleChatCompletions)
|
||||
v1.POST("/images/generations", s.handleImageGenerations)
|
||||
v1.GET("/models", s.handleListModels)
|
||||
v1.POST("/responses", s.handleResponses)
|
||||
}
|
||||
|
||||
// Dashboard API Group
|
||||
@@ -190,7 +257,7 @@ func (s *Server) setupRoutes() {
|
||||
api.GET("/auth/status", s.handleAuthStatus)
|
||||
api.POST("/auth/logout", s.handleLogout)
|
||||
api.POST("/auth/change-password", s.handleChangePassword)
|
||||
|
||||
|
||||
// Protected dashboard routes (need admin session)
|
||||
admin := api.Group("/")
|
||||
admin.Use(s.adminAuthMiddleware())
|
||||
@@ -201,13 +268,13 @@ func (s *Server) setupRoutes() {
|
||||
admin.GET("/usage/clients", s.handleClientsUsage)
|
||||
admin.GET("/usage/detailed", s.handleDetailedUsage)
|
||||
admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)
|
||||
|
||||
|
||||
admin.GET("/clients", s.handleGetClients)
|
||||
admin.POST("/clients", s.handleCreateClient)
|
||||
admin.GET("/clients/:id", s.handleGetClient)
|
||||
admin.PUT("/clients/:id", s.handleUpdateClient)
|
||||
admin.DELETE("/clients/:id", s.handleDeleteClient)
|
||||
|
||||
|
||||
admin.GET("/clients/:id/tokens", s.handleGetClientTokens)
|
||||
admin.POST("/clients/:id/tokens", s.handleCreateClientToken)
|
||||
admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)
|
||||
@@ -215,10 +282,15 @@ func (s *Server) setupRoutes() {
|
||||
admin.GET("/providers", s.handleGetProviders)
|
||||
admin.PUT("/providers/:name", s.handleUpdateProvider)
|
||||
admin.POST("/providers/:name/test", s.handleTestProvider)
|
||||
|
||||
|
||||
admin.GET("/models", s.handleGetModels)
|
||||
admin.PUT("/models/:id", s.handleUpdateModel)
|
||||
|
||||
admin.GET("/model-groups", s.handleGetModelGroups)
|
||||
admin.POST("/model-groups", s.handleCreateModelGroup)
|
||||
admin.PUT("/model-groups/:id", s.handleUpdateModelGroup)
|
||||
admin.DELETE("/model-groups/:id", s.handleDeleteModelGroup)
|
||||
|
||||
admin.GET("/users", s.handleGetUsers)
|
||||
admin.POST("/users", s.handleCreateUser)
|
||||
admin.PUT("/users/:id", s.handleUpdateUser)
|
||||
@@ -237,6 +309,133 @@ func (s *Server) setupRoutes() {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleResponses(c *gin.Context) {
|
||||
startTime := time.Now()
|
||||
var req models.ResponsesRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Strip common prefixes and resolve model groups to concrete models
|
||||
// (same pattern as handleChatCompletions).
|
||||
modelGroup := ""
|
||||
modelID := req.Model
|
||||
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/", "xiaomi/"}
|
||||
for _, p := range prefixes {
|
||||
if strings.HasPrefix(modelID, p) {
|
||||
modelID = strings.TrimPrefix(modelID, p)
|
||||
break
|
||||
}
|
||||
}
|
||||
if s.modelRouter != nil {
|
||||
routeCtx := s.buildRouteContextFromResponses(req)
|
||||
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
|
||||
return
|
||||
}
|
||||
if decision.SelectedModel != modelID {
|
||||
modelGroup = modelID
|
||||
}
|
||||
modelID = decision.SelectedModel
|
||||
}
|
||||
|
||||
// Select provider based on resolved model name
|
||||
providerName := "openai" // default for Responses API
|
||||
modelLower := strings.ToLower(modelID)
|
||||
if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
|
||||
providerName = "gemini"
|
||||
} else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) {
|
||||
providerName = "deepseek"
|
||||
} else if strings.HasPrefix(modelLower, "moonshot/") || strings.Contains(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
|
||||
providerName = "moonshot"
|
||||
} else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") {
|
||||
providerName = "grok"
|
||||
} else if strings.HasPrefix(modelLower, "ollama/") ||
|
||||
strings.Contains(modelLower, "glm-") ||
|
||||
strings.Contains(modelLower, "qwen") ||
|
||||
strings.Contains(modelLower, "gemma") ||
|
||||
strings.Contains(modelLower, "llama") ||
|
||||
strings.Contains(modelLower, "mistral") ||
|
||||
strings.Contains(modelLower, "phi") ||
|
||||
strings.Contains(modelLower, "yi") ||
|
||||
strings.Contains(modelLower, "codellama") ||
|
||||
strings.Contains(modelLower, "command-r") {
|
||||
providerName = "ollama"
|
||||
}
|
||||
|
||||
provider, ok := s.providers[providerName]
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
|
||||
return
|
||||
}
|
||||
|
||||
// Use resolved model for the actual API call
|
||||
req.Model = modelID
|
||||
|
||||
clientID := "default"
|
||||
if auth, ok := c.Get("auth"); ok {
|
||||
if authInfo, ok := auth.(models.AuthInfo); ok {
|
||||
clientID = authInfo.ClientID
|
||||
}
|
||||
}
|
||||
|
||||
stream := req.Stream != nil && *req.Stream
|
||||
|
||||
if stream {
|
||||
ch, err := provider.ResponsesStream(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
var lastUsage *models.ResponsesUsage
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
chunk, ok := <-ch
|
||||
if !ok {
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
if lastUsage != nil {
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage.ToUsage(), nil, false)
|
||||
} else {
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, nil, false)
|
||||
}
|
||||
return false
|
||||
}
|
||||
// Capture usage from the response payload in streaming chunks
|
||||
if chunk.Response != nil && chunk.Response.Usage != nil {
|
||||
lastUsage = chunk.Response.Usage
|
||||
}
|
||||
data, err := json.Marshal(chunk)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := provider.Responses(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Usage != nil {
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, resp.Usage.ToUsage(), nil, false)
|
||||
} else {
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, nil, false)
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (s *Server) handleListModels(c *gin.Context) {
|
||||
type OpenAIModel struct {
|
||||
ID string `json:"id"`
|
||||
@@ -245,37 +444,112 @@ func (s *Server) handleListModels(c *gin.Context) {
|
||||
OwnedBy string `json:"owned_by"`
|
||||
}
|
||||
|
||||
var data []OpenAIModel
|
||||
modelMap := make(map[string]OpenAIModel)
|
||||
allowedProviders := map[string]bool{
|
||||
"openai": true,
|
||||
"google": true, // Models from models.dev use 'google' ID for Gemini
|
||||
"deepseek": true,
|
||||
"moonshot": true,
|
||||
"xai": true, // Models from models.dev use 'xai' ID for Grok
|
||||
"openai": true,
|
||||
"google": true, // Models from models.dev use 'google' ID for Gemini
|
||||
"deepseek": true,
|
||||
"moonshot": true,
|
||||
"moonshotai": true, // Official moonshotai ID in models.dev
|
||||
"moonshotai-cn": true, // Official moonshotai-cn ID in models.dev
|
||||
"xai": true, // Models from models.dev use 'xai' ID for Grok
|
||||
"llmgateway": true, // Catch-all for newer models
|
||||
"ollama": true,
|
||||
"xiaomi": true, // Xiaomi MiMo models
|
||||
}
|
||||
|
||||
s.registryMu.RLock()
|
||||
if s.registry != nil {
|
||||
for pID, pInfo := range s.registry.Providers {
|
||||
if !allowedProviders[pID] {
|
||||
continue
|
||||
}
|
||||
for mID := range pInfo.Models {
|
||||
data = append(data, OpenAIModel{
|
||||
if _, exists := modelMap[mID]; !exists {
|
||||
modelMap[mID] = OpenAIModel{
|
||||
ID: mID,
|
||||
Object: "model",
|
||||
Created: 1700000000,
|
||||
OwnedBy: pID,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
s.registryMu.RUnlock()
|
||||
|
||||
// Add configured Ollama models
|
||||
if s.cfg.Providers.Ollama.Enabled {
|
||||
for _, mID := range s.cfg.Providers.Ollama.Models {
|
||||
if _, exists := modelMap[mID]; !exists {
|
||||
modelMap[mID] = OpenAIModel{
|
||||
ID: mID,
|
||||
Object: "model",
|
||||
Created: 1700000000,
|
||||
OwnedBy: pID,
|
||||
})
|
||||
OwnedBy: "ollama",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add model groups so clients can discover them
|
||||
if s.modelRouter != nil {
|
||||
for _, gid := range s.modelRouter.Groups() {
|
||||
if _, exists := modelMap[gid]; !exists {
|
||||
modelMap[gid] = OpenAIModel{
|
||||
ID: gid,
|
||||
Object: "model",
|
||||
Created: 1700000000,
|
||||
OwnedBy: "gophergate",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var data []OpenAIModel
|
||||
for _, m := range modelMap {
|
||||
data = append(data, m)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) selectProvider(modelID string) (providers.Provider, string, error) {
|
||||
providerName := "openai" // default
|
||||
modelLower := strings.ToLower(modelID)
|
||||
if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
|
||||
providerName = "gemini"
|
||||
} else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) {
|
||||
providerName = "deepseek"
|
||||
} else if strings.HasPrefix(modelLower, "moonshot/") || strings.Contains(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
|
||||
providerName = "moonshot"
|
||||
} else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") {
|
||||
providerName = "grok"
|
||||
} else if strings.HasPrefix(modelLower, "ollama/") ||
|
||||
strings.Contains(modelLower, "glm-") ||
|
||||
strings.Contains(modelLower, "qwen") ||
|
||||
strings.Contains(modelLower, "gemma") ||
|
||||
strings.Contains(modelLower, "llama") ||
|
||||
strings.Contains(modelLower, "mistral") ||
|
||||
strings.Contains(modelLower, "phi") ||
|
||||
strings.Contains(modelLower, "yi") ||
|
||||
strings.Contains(modelLower, "codellama") ||
|
||||
strings.Contains(modelLower, "command-r") {
|
||||
providerName = "ollama"
|
||||
} else if strings.HasPrefix(modelLower, "xiaomi/") || strings.Contains(modelLower, "mimo") || strings.Contains(modelLower, "xiaomi") {
|
||||
providerName = "xiaomi"
|
||||
}
|
||||
|
||||
p, ok := s.providers[providerName]
|
||||
if !ok {
|
||||
return nil, "", fmt.Errorf("Provider %s not enabled or supported", providerName)
|
||||
}
|
||||
return p, providerName, nil
|
||||
}
|
||||
|
||||
func (s *Server) handleChatCompletions(c *gin.Context) {
|
||||
startTime := time.Now()
|
||||
var req models.ChatCompletionRequest
|
||||
@@ -284,38 +558,79 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Select provider based on model name
|
||||
providerName := "openai" // default
|
||||
if strings.Contains(req.Model, "gemini") {
|
||||
providerName = "gemini"
|
||||
} else if strings.Contains(req.Model, "deepseek") {
|
||||
providerName = "deepseek"
|
||||
} else if strings.Contains(req.Model, "kimi") || strings.Contains(req.Model, "moonshot") {
|
||||
providerName = "moonshot"
|
||||
} else if strings.Contains(req.Model, "grok") {
|
||||
providerName = "grok"
|
||||
// Strip common prefixes and prepare model ID
|
||||
modelID := req.Model
|
||||
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/", "xiaomi/"}
|
||||
for _, p := range prefixes {
|
||||
if strings.HasPrefix(modelID, p) {
|
||||
modelID = strings.TrimPrefix(modelID, p)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
provider, ok := s.providers[providerName]
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
|
||||
// Resolve model groups to concrete models (hierarchical — groups can target groups)
|
||||
modelGroup := ""
|
||||
for i, m := range req.Messages {
|
||||
log.Printf("[DEBUG] Incoming Msg[%d]: role=%s, hasToolCalls=%v, hasContent=%v", i, m.Role, len(m.ToolCalls) > 0, m.Content != nil)
|
||||
}
|
||||
if s.modelRouter != nil {
|
||||
routeCtx := s.buildRouteContextFromChat(req)
|
||||
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
|
||||
return
|
||||
}
|
||||
if decision.SelectedModel != modelID {
|
||||
modelGroup = modelID
|
||||
}
|
||||
modelID = decision.SelectedModel
|
||||
log.Printf("[ROUTER] %s (%s: %s)", modelID, decision.Strategy, decision.Reason)
|
||||
}
|
||||
|
||||
// Select provider based on the resolved model name
|
||||
provider, providerName, err := s.selectProvider(modelID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Convert ChatCompletionRequest to UnifiedRequest
|
||||
unifiedReq := &models.UnifiedRequest{
|
||||
Model: req.Model,
|
||||
Model: modelID,
|
||||
Messages: []models.UnifiedMessage{},
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
TopK: req.TopK,
|
||||
N: req.N,
|
||||
MaxTokens: req.MaxTokens,
|
||||
PresencePenalty: req.PresencePenalty,
|
||||
FrequencyPenalty: req.FrequencyPenalty,
|
||||
Stream: req.Stream != nil && *req.Stream,
|
||||
Tools: req.Tools,
|
||||
ToolChoice: req.ToolChoice,
|
||||
MaxTokens: req.MaxTokens,
|
||||
PresencePenalty: req.PresencePenalty,
|
||||
FrequencyPenalty: req.FrequencyPenalty,
|
||||
Stream: req.Stream != nil && *req.Stream,
|
||||
Tools: req.Tools,
|
||||
ToolChoice: req.ToolChoice,
|
||||
}
|
||||
|
||||
// Inject or cap max_tokens from model registry.
|
||||
s.registryMu.RLock()
|
||||
meta := s.registry.FindModel(modelID)
|
||||
s.registryMu.RUnlock()
|
||||
|
||||
if meta != nil && meta.Limit != nil && meta.Limit.Output > 0 {
|
||||
if unifiedReq.MaxTokens == nil {
|
||||
unifiedReq.MaxTokens = &meta.Limit.Output
|
||||
log.Printf("[DEBUG] %s: injected registry max_tokens=%d", modelID, meta.Limit.Output)
|
||||
} else if *unifiedReq.MaxTokens > meta.Limit.Output {
|
||||
log.Printf("[DEBUG] %s: capping client max_tokens (%d) to registry limit (%d)", modelID, *unifiedReq.MaxTokens, meta.Limit.Output)
|
||||
unifiedReq.MaxTokens = &meta.Limit.Output
|
||||
} else {
|
||||
log.Printf("[DEBUG] %s: using client max_tokens (%d)", modelID, *unifiedReq.MaxTokens)
|
||||
}
|
||||
} else {
|
||||
if unifiedReq.MaxTokens == nil {
|
||||
log.Printf("[DEBUG] %s: no registry limit found, leaving max_tokens nil", modelID)
|
||||
} else {
|
||||
log.Printf("[DEBUG] %s: using client max_tokens (%d), no registry limit to cap", modelID, *unifiedReq.MaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle Stop sequences
|
||||
@@ -398,7 +713,7 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
|
||||
if unifiedReq.Stream {
|
||||
ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq)
|
||||
if err != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, unifiedReq.HasImages)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -412,7 +727,7 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
|
||||
chunk, ok := <-ch
|
||||
if !ok {
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, lastUsage, nil, unifiedReq.HasImages)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage, nil, unifiedReq.HasImages)
|
||||
return false
|
||||
}
|
||||
if chunk.Usage != nil {
|
||||
@@ -430,21 +745,143 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
|
||||
|
||||
resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq)
|
||||
if err != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, unifiedReq.HasImages)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage, nil, unifiedReq.HasImages)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, resp.Usage, nil, unifiedReq.HasImages)
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (s *Server) logRequest(start time.Time, clientID, provider, model string, usage *models.Usage, err error, hasImages bool) {
|
||||
func extractUserMessage(messages []models.ChatMessage) string {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == "user" {
|
||||
switch c := messages[i].Content.(type) {
|
||||
case string:
|
||||
return c
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *Server) handleImageGenerations(c *gin.Context) {
|
||||
startTime := time.Now()
|
||||
var req models.ImageGenerationRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Determine provider based on model name
|
||||
providerName := "openai"
|
||||
modelLower := strings.ToLower(req.Model)
|
||||
switch {
|
||||
case strings.Contains(modelLower, "imagen"), strings.Contains(modelLower, "gemini"):
|
||||
providerName = "gemini"
|
||||
case strings.Contains(modelLower, "dall"), strings.HasPrefix(modelLower, "openai/"):
|
||||
providerName = "openai"
|
||||
}
|
||||
|
||||
// Default model for each provider if not specified
|
||||
if req.Model == "" {
|
||||
if providerName == "openai" {
|
||||
req.Model = "dall-e-3"
|
||||
} else {
|
||||
req.Model = "imagen-3.0-generate-001"
|
||||
}
|
||||
}
|
||||
|
||||
// Strip common prefixes
|
||||
prefixes := []string{"openai/", "gemini/", "google/"}
|
||||
for _, p := range prefixes {
|
||||
if strings.HasPrefix(req.Model, p) {
|
||||
req.Model = strings.TrimPrefix(req.Model, p)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
provider, ok := s.providers[providerName]
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
|
||||
return
|
||||
}
|
||||
|
||||
clientID := "default"
|
||||
if auth, ok := c.Get("auth"); ok {
|
||||
if authInfo, ok := auth.(models.AuthInfo); ok {
|
||||
clientID = authInfo.ClientID
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := provider.ImageGeneration(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, "", nil, err, false)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Estimate tokens from prompt text (~4 chars per token)
|
||||
promptTokens := uint32(len(req.Prompt) / 4)
|
||||
if promptTokens < 1 {
|
||||
promptTokens = 1
|
||||
}
|
||||
|
||||
// Calculate per-image cost (not per-token like chat)
|
||||
cost := imageGenCost(providerName, req.Model, req.Size, uint32(len(resp.Data)))
|
||||
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, "", &models.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: uint32(len(resp.Data)),
|
||||
TotalTokens: promptTokens + uint32(len(resp.Data)),
|
||||
}, nil, false)
|
||||
|
||||
// Update cost in DB — image gen is per-image, not per-token
|
||||
if cost > 0 {
|
||||
s.database.Exec("UPDATE llm_requests SET cost = ? WHERE id = (SELECT MAX(id) FROM llm_requests)", cost)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// imageGenCost returns per-image pricing for known image generation models.
|
||||
func imageGenCost(provider, model string, size *string, n uint32) float64 {
|
||||
if n == 0 {
|
||||
return 0
|
||||
}
|
||||
modelLower := strings.ToLower(model)
|
||||
var perImage float64
|
||||
|
||||
switch {
|
||||
case strings.Contains(modelLower, "dall-e-3"):
|
||||
perImage = 0.040 // standard 1024x1024
|
||||
if size != nil {
|
||||
s := *size
|
||||
if s == "1024x1792" || s == "1792x1024" {
|
||||
perImage = 0.080
|
||||
}
|
||||
}
|
||||
case strings.Contains(modelLower, "dall-e-2"):
|
||||
perImage = 0.020
|
||||
case strings.Contains(modelLower, "imagen"):
|
||||
perImage = 0.040 // approximate
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
|
||||
return perImage * float64(n)
|
||||
}
|
||||
|
||||
func (s *Server) logRequest(start time.Time, clientID, provider, model, modelGroup string, usage *models.Usage, err error, hasImages bool) {
|
||||
entry := RequestLog{
|
||||
Timestamp: start,
|
||||
ClientID: clientID,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
ModelGroup: modelGroup,
|
||||
Status: "success",
|
||||
DurationMS: time.Since(start).Milliseconds(),
|
||||
HasImages: hasImages,
|
||||
@@ -468,11 +905,16 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model string, u
|
||||
if usage.CacheWriteTokens != nil {
|
||||
entry.CacheWriteTokens = *usage.CacheWriteTokens
|
||||
}
|
||||
|
||||
// Calculate cost using registry
|
||||
entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
|
||||
fmt.Printf("[DEBUG] Request logged: model=%s, prompt=%d, completion=%d, reasoning=%d, cache_read=%d, cost=%f\n",
|
||||
model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.Cost)
|
||||
|
||||
// Calculate cost using registry; if the resolved model is unknown,
|
||||
// fall back to the model group so group requests still get priced.
|
||||
s.registryMu.RLock()
|
||||
pricingModel := model
|
||||
if s.registry != nil && s.registry.FindModel(pricingModel) == nil && modelGroup != "" {
|
||||
pricingModel = modelGroup
|
||||
}
|
||||
entry.Cost = utils.CalculateCost(s.registry, pricingModel, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
|
||||
s.registryMu.RUnlock()
|
||||
}
|
||||
|
||||
s.logger.LogRequest(entry)
|
||||
@@ -481,14 +923,16 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model string, u
|
||||
func (s *Server) Run() error {
|
||||
go s.hub.Run()
|
||||
s.logger.Start()
|
||||
|
||||
|
||||
// Start registry refresher
|
||||
go func() {
|
||||
ticker := time.NewTicker(24 * time.Hour)
|
||||
for range ticker.C {
|
||||
newRegistry, err := utils.FetchRegistry()
|
||||
if err == nil {
|
||||
s.registryMu.Lock()
|
||||
s.registry = newRegistry
|
||||
s.registryMu.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -496,3 +940,165 @@ func (s *Server) Run() error {
|
||||
addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
|
||||
return s.router.Run(addr)
|
||||
}
|
||||
|
||||
func uint32Ptr(v uint32) *uint32 { return &v }
|
||||
|
||||
func (s *Server) buildRouteContextFromChat(req models.ChatCompletionRequest) *router.RouteContext {
|
||||
userMessage := extractUserMessage(req.Messages)
|
||||
requiresToolCalling := len(req.Tools) > 0
|
||||
hasMultimodal := false
|
||||
inputTokens := 0
|
||||
|
||||
for _, msg := range req.Messages {
|
||||
if strContent, ok := msg.Content.(string); ok {
|
||||
inputTokens += len(strContent) / 4
|
||||
} else if parts, ok := msg.Content.([]interface{}); ok {
|
||||
for _, part := range parts {
|
||||
if partMap, ok := part.(map[string]interface{}); ok {
|
||||
partType, _ := partMap["type"].(string)
|
||||
if partType == "text" {
|
||||
text, _ := partMap["text"].(string)
|
||||
inputTokens += len(text) / 4
|
||||
} else if partType == "image_url" {
|
||||
hasMultimodal = true
|
||||
inputTokens += 1000 // Approximate cost of an image in tokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msgLower := strings.ToLower(userMessage)
|
||||
requiresReasoning := strings.Contains(msgLower, "reason") ||
|
||||
strings.Contains(msgLower, "think step by step") ||
|
||||
strings.Contains(msgLower, "mathematics") ||
|
||||
strings.Contains(msgLower, "architecture") ||
|
||||
strings.Contains(msgLower, "explain in detail")
|
||||
|
||||
routeCtx := &router.RouteContext{
|
||||
UserMessage: userMessage,
|
||||
InputTokens: inputTokens,
|
||||
HasMultimodalInput: hasMultimodal,
|
||||
RequiresToolCalling: requiresToolCalling,
|
||||
RequiresReasoning: requiresReasoning,
|
||||
}
|
||||
routeCtx.Tags = s.getRouteCtxTags(routeCtx)
|
||||
return routeCtx
|
||||
}
|
||||
|
||||
func (s *Server) buildRouteContextFromResponses(req models.ResponsesRequest) *router.RouteContext {
|
||||
var userMessage string
|
||||
hasMultimodal := false
|
||||
inputTokens := len(req.Instructions) / 4
|
||||
requiresToolCalling := len(req.Tools) > 0 && string(req.Tools) != "null" && string(req.Tools) != ""
|
||||
|
||||
var strInput string
|
||||
if err := json.Unmarshal(req.Input, &strInput); err == nil {
|
||||
userMessage = strInput
|
||||
inputTokens += len(userMessage) / 4
|
||||
} else {
|
||||
var msgs []models.ResponseInputMessage
|
||||
if err := json.Unmarshal(req.Input, &msgs); err == nil {
|
||||
for _, m := range msgs {
|
||||
var contentStr string
|
||||
if err := json.Unmarshal(m.Content, &contentStr); err == nil {
|
||||
if m.Role == "user" {
|
||||
userMessage = contentStr
|
||||
}
|
||||
inputTokens += len(contentStr) / 4
|
||||
} else {
|
||||
var parts []models.ContentPart
|
||||
if err := json.Unmarshal(m.Content, &parts); err == nil {
|
||||
for _, p := range parts {
|
||||
if p.Type == "text" {
|
||||
if m.Role == "user" {
|
||||
userMessage = p.Text
|
||||
}
|
||||
inputTokens += len(p.Text) / 4
|
||||
} else if p.Type == "image_url" {
|
||||
hasMultimodal = true
|
||||
inputTokens += 1000
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msgLower := strings.ToLower(userMessage)
|
||||
requiresReasoning := strings.Contains(msgLower, "reason") ||
|
||||
strings.Contains(msgLower, "think step by step") ||
|
||||
strings.Contains(msgLower, "mathematics") ||
|
||||
strings.Contains(msgLower, "architecture") ||
|
||||
strings.Contains(msgLower, "explain in detail")
|
||||
|
||||
routeCtx := &router.RouteContext{
|
||||
UserMessage: userMessage,
|
||||
InputTokens: inputTokens,
|
||||
HasMultimodalInput: hasMultimodal,
|
||||
RequiresToolCalling: requiresToolCalling,
|
||||
RequiresReasoning: requiresReasoning,
|
||||
}
|
||||
routeCtx.Tags = s.getRouteCtxTags(routeCtx)
|
||||
return routeCtx
|
||||
}
|
||||
|
||||
func (s *Server) getRouteCtxTags(routeCtx *router.RouteContext) []string {
|
||||
var tags []string
|
||||
msgLower := strings.ToLower(routeCtx.UserMessage)
|
||||
|
||||
// fast-flow keywords
|
||||
fastFlowKeywords := []string{
|
||||
"classify", "classification", "label", "tag", "route", "routing", "intent",
|
||||
"json", "yaml", "csv", "schema", "parse", "extract", "transform", "format", "regex",
|
||||
"short answer", "brief", "concise", "tl;dr", "one line", "simple",
|
||||
"fix this", "small bug", "quick fix", "typo", "syntax error",
|
||||
}
|
||||
for _, kw := range fastFlowKeywords {
|
||||
if strings.Contains(msgLower, kw) {
|
||||
tags = append(tags, "fast-flow", "classification", "json-extraction", "basic-qa")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// standard-pro keywords
|
||||
standardProKeywords := []string{
|
||||
"explain", "summarize", "rewrite", "draft", "edit", "polish", "outline",
|
||||
"long doc", "document", "email", "memo", "proposal", "report", "handout", "notes",
|
||||
"compare", "choose", "recommend", "tradeoff", "pros and cons", "analysis",
|
||||
"code review", "debug", "bug", "feature", "api", "endpoint", "implement",
|
||||
"plan", "planning", "workflow", "integration",
|
||||
}
|
||||
for _, kw := range standardProKeywords {
|
||||
if strings.Contains(msgLower, kw) {
|
||||
tags = append(tags, "standard-pro", "long-doc")
|
||||
break
|
||||
}
|
||||
}
|
||||
if routeCtx.HasMultimodalInput {
|
||||
tags = append(tags, "video-analysis", "multimodal-qa")
|
||||
}
|
||||
|
||||
// heavy-logic keywords
|
||||
heavyLogicKeywords := []string{
|
||||
"agent", "agents", "tool use", "function calling", "multi-agent", "orchestrate",
|
||||
"system design", "scaling", "performance", "architecture review", "distributed",
|
||||
"hard bug", "race condition", "deadlock", "memory leak", "crash", "production outage",
|
||||
"long context", "large codebase", "many files", "complex refactor", "migration",
|
||||
"research", "deep dive", "literature", "paper", "scholarly", "thorough analysis",
|
||||
"deep reasoning", "think step by step", "reason through", "careful analysis",
|
||||
}
|
||||
for _, kw := range heavyLogicKeywords {
|
||||
if strings.Contains(msgLower, kw) {
|
||||
tags = append(tags, "heavy-logic", "deep-reasoning", "architecture", "hard-debugging")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if routeCtx.RequiresToolCalling {
|
||||
tags = append(tags, "tool-heavy", "multi-step-agent", "swe-bench")
|
||||
}
|
||||
|
||||
return tags
|
||||
}
|
||||
|
||||
@@ -79,7 +79,7 @@ func (m *SessionManager) createSignedToken(sessionID, username, displayName, rol
|
||||
}
|
||||
|
||||
payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON)
|
||||
|
||||
|
||||
h := hmac.New(sha256.New, m.secret)
|
||||
h.Write(payloadJSON)
|
||||
signature := h.Sum(nil)
|
||||
@@ -133,23 +133,41 @@ func (m *SessionManager) ValidateSession(token string) (*Session, string, error)
|
||||
return &session, "", nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) RevokeSession(token string) {
|
||||
func (m *SessionManager) RevokeSession(token string) error {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 2 {
|
||||
return
|
||||
return fmt.Errorf("invalid token format")
|
||||
}
|
||||
|
||||
payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return
|
||||
return fmt.Errorf("failed to decode payload: %w", err)
|
||||
}
|
||||
|
||||
var payload sessionPayload
|
||||
if err := json.Unmarshal(payloadJSON, &payload); err != nil {
|
||||
return
|
||||
return fmt.Errorf("failed to parse payload: %w", err)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
delete(m.sessions, payload.SessionID)
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartCleanup runs a background goroutine that removes expired sessions every 15 minutes.
|
||||
func (m *SessionManager) StartCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(15 * time.Minute)
|
||||
for range ticker.C {
|
||||
m.mu.Lock()
|
||||
now := time.Now()
|
||||
for id, s := range m.sessions {
|
||||
if now.After(s.ExpiresAt) {
|
||||
delete(m.sessions, id)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"gophergate/internal/db"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/cpu"
|
||||
"github.com/shirou/gopsutil/v3/disk"
|
||||
"github.com/shirou/gopsutil/v3/load"
|
||||
"github.com/shirou/gopsutil/v3/mem"
|
||||
"github.com/shirou/gopsutil/v3/process"
|
||||
)
|
||||
|
||||
func (s *Server) handleSystemHealth(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"status": "ok",
|
||||
"components": gin.H{
|
||||
"database": "online",
|
||||
"proxy": "online",
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) handleSystemMetrics(c *gin.Context) {
|
||||
v, _ := mem.VirtualMemory()
|
||||
c_usage, _ := cpu.Percent(time.Second, false)
|
||||
d, _ := disk.Usage("/")
|
||||
l, _ := load.Avg()
|
||||
p, _ := process.NewProcess(int32(os.Getpid()))
|
||||
rss, _ := p.MemoryInfo()
|
||||
|
||||
cpuPercent := 0.0
|
||||
if len(c_usage) > 0 {
|
||||
cpuPercent = c_usage[0]
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"cpu": gin.H{
|
||||
"usage_percent": fmt.Sprintf("%.1f", cpuPercent),
|
||||
"load_average": []float64{l.Load1, l.Load5, l.Load15},
|
||||
},
|
||||
"memory": gin.H{
|
||||
"used_mb": v.Used / 1024 / 1024,
|
||||
"total_mb": v.Total / 1024 / 1024,
|
||||
"usage_percent": fmt.Sprintf("%.1f", v.UsedPercent),
|
||||
"process_rss_mb": rss.RSS / 1024 / 1024,
|
||||
},
|
||||
"disk": gin.H{
|
||||
"used_gb": float64(d.Used) / 1024 / 1024 / 1024,
|
||||
"total_gb": float64(d.Total) / 1024 / 1024 / 1024,
|
||||
"usage_percent": fmt.Sprintf("%.1f", d.UsedPercent),
|
||||
},
|
||||
"connections": gin.H{
|
||||
"db_active": s.database.Stats().OpenConnections,
|
||||
"websocket_listeners": s.hub.GetClientCount(),
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) handleGetSettings(c *gin.Context) {
|
||||
providerCount := 0
|
||||
modelCount := 0
|
||||
s.registryMu.RLock()
|
||||
if s.registry != nil {
|
||||
providerCount = len(s.registry.Providers)
|
||||
for _, p := range s.registry.Providers {
|
||||
modelCount += len(p.Models)
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"server": gin.H{
|
||||
"version": "1.0.0-go",
|
||||
"auth_tokens": s.cfg.Server.AuthTokens,
|
||||
},
|
||||
"database": gin.H{
|
||||
"type": "sqlite",
|
||||
"path": s.cfg.Database.Path,
|
||||
},
|
||||
"registry": gin.H{
|
||||
"provider_count": providerCount,
|
||||
"model_count": modelCount,
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) handleCreateBackup(c *gin.Context) {
|
||||
// Simplified backup response
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"backup_id": fmt.Sprintf("backup-%d.db", time.Now().Unix()),
|
||||
"status": "created",
|
||||
}))
|
||||
}
|
||||
|
||||
func (s *Server) handleGetLogs(c *gin.Context) {
|
||||
var logs []db.LLMRequest
|
||||
err := s.database.Select(&logs, "SELECT * FROM llm_requests ORDER BY timestamp DESC LIMIT 100")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Format for UI
|
||||
type UILog struct {
|
||||
Timestamp string `json:"timestamp"`
|
||||
ClientID string `json:"client_id"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Tokens int `json:"tokens"`
|
||||
Status string `json:"status"`
|
||||
Duration int `json:"duration"`
|
||||
}
|
||||
|
||||
uiLogs := make([]UILog, len(logs))
|
||||
for i, l := range logs {
|
||||
clientID := "unknown"
|
||||
if l.ClientID != nil {
|
||||
clientID = *l.ClientID
|
||||
}
|
||||
provider := "unknown"
|
||||
if l.Provider != nil {
|
||||
provider = *l.Provider
|
||||
}
|
||||
model := "unknown"
|
||||
if l.Model != nil {
|
||||
model = *l.Model
|
||||
}
|
||||
tokens := 0
|
||||
if l.TotalTokens != nil {
|
||||
tokens = *l.TotalTokens
|
||||
}
|
||||
duration := 0
|
||||
if l.DurationMS != nil {
|
||||
duration = *l.DurationMS
|
||||
}
|
||||
|
||||
uiLogs[i] = UILog{
|
||||
Timestamp: l.Timestamp.Format(time.RFC3339),
|
||||
ClientID: clientID,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Tokens: tokens,
|
||||
Status: l.Status,
|
||||
Duration: duration,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(uiLogs))
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"gophergate/internal/db"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func (s *Server) handleGetUsers(c *gin.Context) {
|
||||
var users []db.User
|
||||
err := s.database.Select(&users, "SELECT id, username, display_name, role, must_change_password, created_at FROM users")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, SuccessResponse(users))
|
||||
}
|
||||
|
||||
type CreateUserRequest struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
DisplayName *string `json:"display_name"`
|
||||
Role *string `json:"role"`
|
||||
}
|
||||
|
||||
func (s *Server) handleCreateUser(c *gin.Context) {
|
||||
var req CreateUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), 12)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse("Failed to hash password"))
|
||||
return
|
||||
}
|
||||
|
||||
role := "viewer"
|
||||
if req.Role != nil {
|
||||
role = *req.Role
|
||||
}
|
||||
|
||||
_, err = s.database.Exec("INSERT INTO users (username, password_hash, display_name, role, must_change_password) VALUES (?, ?, ?, ?, 1)",
|
||||
req.Username, string(hash), req.DisplayName, role)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User created"}))
|
||||
}
|
||||
|
||||
type UpdateUserRequest struct {
|
||||
DisplayName *string `json:"display_name"`
|
||||
Role *string `json:"role"`
|
||||
Password *string `json:"password"`
|
||||
MustChangePassword *bool `json:"must_change_password"`
|
||||
}
|
||||
|
||||
func (s *Server) handleUpdateUser(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var req UpdateUserRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
if req.DisplayName != nil {
|
||||
s.database.Exec("UPDATE users SET display_name = ? WHERE id = ?", req.DisplayName, id)
|
||||
}
|
||||
if req.Role != nil {
|
||||
s.database.Exec("UPDATE users SET role = ? WHERE id = ?", req.Role, id)
|
||||
}
|
||||
if req.MustChangePassword != nil {
|
||||
s.database.Exec("UPDATE users SET must_change_password = ? WHERE id = ?", req.MustChangePassword, id)
|
||||
}
|
||||
if req.Password != nil {
|
||||
hash, _ := bcrypt.GenerateFromPassword([]byte(*req.Password), 12)
|
||||
s.database.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hash), id)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User updated"}))
|
||||
}
|
||||
|
||||
func (s *Server) handleDeleteUser(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
session, _ := c.Get("session")
|
||||
if sess, ok := session.(*Session); ok {
|
||||
var username string
|
||||
s.database.Get(&username, "SELECT username FROM users WHERE id = ?", id)
|
||||
if username == sess.Username {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse("Cannot delete your own account"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
_, err := s.database.Exec("DELETE FROM users WHERE id = ?", id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User deleted"}))
|
||||
}
|
||||
@@ -10,12 +10,18 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // In production, refine this
|
||||
},
|
||||
func newUpgrader(allowedOrigin string) websocket.Upgrader {
|
||||
return websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
if allowedOrigin == "*" {
|
||||
return true
|
||||
}
|
||||
origin := r.Header.Get("Origin")
|
||||
return origin == "" || origin == allowedOrigin
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type Hub struct {
|
||||
@@ -75,6 +81,11 @@ func (h *Hub) GetClientCount() int {
|
||||
}
|
||||
|
||||
func (s *Server) handleWebSocket(c *gin.Context) {
|
||||
allowedOrigin := s.cfg.Server.WSAllowedOrigin
|
||||
if allowedOrigin == "" {
|
||||
allowedOrigin = "*"
|
||||
}
|
||||
upgrader := newUpgrader(allowedOrigin)
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Printf("Failed to set websocket upgrade: %v", err)
|
||||
@@ -99,7 +110,7 @@ func (s *Server) handleWebSocket(c *gin.Context) {
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
|
||||
if msg["type"] == "ping" {
|
||||
conn.WriteJSON(gin.H{"type": "pong", "payload": gin.H{}})
|
||||
}
|
||||
|
||||
+50
-18
@@ -6,38 +6,66 @@ import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"gophergate/internal/models"
|
||||
"github.com/go-resty/resty/v2"
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
const ModelsDevURL = "https://models.dev/api.json"
|
||||
|
||||
func FetchRegistry() (*models.ModelRegistry, error) {
|
||||
log.Printf("Fetching model registry from %s", ModelsDevURL)
|
||||
|
||||
client := resty.New().SetTimeout(10 * time.Second)
|
||||
resp, err := client.R().Get(ModelsDevURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch registry: %w", err)
|
||||
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < 3; attempt++ {
|
||||
if attempt > 0 {
|
||||
backoff := time.Duration(1<<attempt) * time.Second
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
|
||||
resp, err := client.R().Get(ModelsDevURL)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("attempt %d: %w", attempt+1, err)
|
||||
continue
|
||||
}
|
||||
if !resp.IsSuccess() {
|
||||
lastErr = fmt.Errorf("attempt %d: HTTP %d", attempt+1, resp.StatusCode())
|
||||
continue
|
||||
}
|
||||
|
||||
var providers map[string]models.ProviderInfo
|
||||
if err := json.Unmarshal(resp.Body(), &providers); err != nil {
|
||||
lastErr = fmt.Errorf("attempt %d: unmarshal: %w", attempt+1, err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Println("Successfully loaded model registry")
|
||||
return &models.ModelRegistry{Providers: providers}, nil
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("failed to fetch registry: HTTP %d", resp.StatusCode())
|
||||
}
|
||||
return nil, fmt.Errorf("failed to fetch registry after 3 attempts: %w", lastErr)
|
||||
}
|
||||
|
||||
var providers map[string]models.ProviderInfo
|
||||
if err := json.Unmarshal(resp.Body(), &providers); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal registry: %w", err)
|
||||
}
|
||||
// promoDiscount describes a temporary pricing discount applied on top of
|
||||
// the standard (list) price from the model registry.
|
||||
type promoDiscount struct {
|
||||
Factor float64 // multiplier applied after standard calculation (0.25 = 75% off)
|
||||
ExpiresAt time.Time // discount ends at this time (UTC)
|
||||
}
|
||||
|
||||
log.Println("Successfully loaded model registry")
|
||||
return &models.ModelRegistry{Providers: providers}, nil
|
||||
// promoDiscounts maps model IDs to active promotional discounts.
|
||||
// Sources:
|
||||
// - DeepSeek v4 Pro: 75% off list pricing until 2026-05-31
|
||||
// https://api-docs.deepseek.com/quick_start/pricing
|
||||
var promoDiscounts = map[string]promoDiscount{
|
||||
"deepseek-v4-pro": {
|
||||
Factor: 0.25,
|
||||
ExpiresAt: time.Date(2026, 5, 31, 23, 59, 59, 0, time.UTC),
|
||||
},
|
||||
}
|
||||
|
||||
func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens, completionTokens, reasoningTokens, cacheRead, cacheWrite uint32) float64 {
|
||||
meta := registry.FindModel(modelID)
|
||||
if meta == nil || meta.Cost == nil {
|
||||
log.Printf("[DEBUG] CalculateCost: model %s not found or has no cost metadata", modelID)
|
||||
return 0.0
|
||||
}
|
||||
|
||||
@@ -62,8 +90,12 @@ func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens,
|
||||
cost += float64(cacheWrite) * (*meta.Cost.CacheWrite) / 1000000.0
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG] CalculateCost: model=%s, uncached=%d, completion=%d, reasoning=%d, cache_read=%d, cache_write=%d, cost=%f (input_rate=%f, output_rate=%f)",
|
||||
modelID, uncachedTokens, completionTokens, reasoningTokens, cacheRead, cacheWrite, cost, meta.Cost.Input, meta.Cost.Output)
|
||||
// Apply promotional discounts (e.g. DeepSeek 75% off until 2026-05-31).
|
||||
if discount, ok := promoDiscounts[modelID]; ok {
|
||||
if time.Now().UTC().Before(discount.ExpiresAt) {
|
||||
cost *= discount.Factor
|
||||
}
|
||||
}
|
||||
|
||||
return cost
|
||||
}
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
func TestCalculateCost_NotFound(t *testing.T) {
|
||||
r := &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)}
|
||||
cost := CalculateCost(r, "unknown-model", 100, 50, 0, 0, 0)
|
||||
if cost != 0.0 {
|
||||
t.Fatalf("expected 0 cost for unknown model, got %f", cost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateCost_KnownModel(t *testing.T) {
|
||||
inputCost := 2.5 // $2.50 per 1M tokens
|
||||
outputCost := 10.0 // $10.00 per 1M tokens
|
||||
r := &models.ModelRegistry{
|
||||
Providers: map[string]models.ProviderInfo{
|
||||
"openai": {
|
||||
Models: map[string]models.ModelMetadata{
|
||||
"gpt-4o": {
|
||||
Cost: &models.ModelCost{
|
||||
Input: inputCost,
|
||||
Output: outputCost,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cost := CalculateCost(r, "gpt-4o", 1000, 500, 0, 0, 0)
|
||||
expected := (1000 * inputCost / 1000000.0) + (500 * outputCost / 1000000.0)
|
||||
if cost != expected {
|
||||
t.Fatalf("expected %f, got %f", expected, cost)
|
||||
}
|
||||
}
|
||||
+6
-1
@@ -89,6 +89,10 @@
|
||||
<i class="fas fa-brain"></i>
|
||||
<span>Models</span>
|
||||
</li>
|
||||
<li class="menu-item" data-page="model-groups">
|
||||
<i class="fas fa-code-branch"></i>
|
||||
<span>Model Groups</span>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
@@ -164,7 +168,7 @@
|
||||
<script src="/js/auth.js?v=7"></script>
|
||||
<script src="/js/charts.js?v=7"></script>
|
||||
<script src="/js/websocket.js?v=7"></script>
|
||||
<script src="/js/dashboard.js?v=7"></script>
|
||||
<script src="/js/dashboard.js?v=8"></script>
|
||||
|
||||
<!-- Page Modules -->
|
||||
<script src="/js/pages/overview.js?v=7"></script>
|
||||
@@ -177,5 +181,6 @@
|
||||
<script src="/js/pages/settings.js?v=7"></script>
|
||||
<script src="/js/pages/logs.js?v=7"></script>
|
||||
<script src="/js/pages/users.js?v=7"></script>
|
||||
<script src="/js/pages/model_groups.js?v=9"></script>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
@@ -119,6 +119,7 @@ class Dashboard {
|
||||
'settings': 'Settings',
|
||||
'logs': 'Logs',
|
||||
'models': 'Models',
|
||||
'model-groups': 'Model Groups',
|
||||
'users': 'User Management'
|
||||
};
|
||||
if (titleElement) titleElement.textContent = titles[page] || 'Dashboard';
|
||||
@@ -130,6 +131,11 @@ class Dashboard {
|
||||
if (content) {
|
||||
content.innerHTML = await this.getPageTemplate(page);
|
||||
await this.initializePageScript(page);
|
||||
|
||||
// Model Groups page uses its own render method
|
||||
if (page === 'model-groups' && typeof modelGroupsPage !== 'undefined') {
|
||||
await modelGroupsPage.render();
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`Error loading page ${page}:`, error);
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
// Model Groups Management Page
|
||||
|
||||
class ModelGroupsPage {
|
||||
constructor() {
|
||||
this.container = document.getElementById('page-content');
|
||||
}
|
||||
|
||||
async render() {
|
||||
this.container.innerHTML = `
|
||||
<div class="page-header">
|
||||
<h3>Model Groups</h3>
|
||||
<p class="text-muted">Define auto-routing groups that pick the best model for each request.</p>
|
||||
<button class="btn btn-primary" onclick="modelGroupsPage.showCreateForm()">
|
||||
<i class="fas fa-plus"></i> Add Group
|
||||
</button>
|
||||
</div>
|
||||
<div id="model-groups-list" class="table-container"></div>
|
||||
<div id="model-group-form" class="form-container" style="display:none;"></div>
|
||||
`;
|
||||
await this.loadGroups();
|
||||
}
|
||||
|
||||
async loadGroups() {
|
||||
try {
|
||||
const groups = await api.get('/model-groups');
|
||||
const list = document.getElementById('model-groups-list');
|
||||
if (!groups || groups.length === 0) {
|
||||
list.innerHTML = '<div class="empty-state">No model groups defined. Create one to enable auto-routing.</div>';
|
||||
return;
|
||||
}
|
||||
|
||||
let html = '<table class="data-table"><thead><tr>';
|
||||
html += '<th>Group ID</th><th>Level</th><th>Primary Use</th><th>Strategy</th><th>Targets</th><th>Actions</th>';
|
||||
html += '</tr></thead><tbody>';
|
||||
|
||||
groups.forEach(g => {
|
||||
html += '<tr>';
|
||||
html += '<td><code>' + this.esc(g.id) + '</code></td>';
|
||||
html += '<td>' + (g.logic_level != null ? g.logic_level : '—') + '</td>';
|
||||
html += '<td>' + this.esc(g.primary_use || '—') + '</td>';
|
||||
html += '<td><span class="badge">' + this.esc(g.strategy) + '</span></td>';
|
||||
html += '<td><code>' + this.esc(g.targets) + '</code></td>';
|
||||
html += '<td>';
|
||||
html += '<button class="btn btn-sm" onclick="modelGroupsPage.showEditForm(\'' + this.esc(g.id) + '\')">Edit</button> ';
|
||||
html += '<button class="btn btn-sm btn-danger" onclick="modelGroupsPage.deleteGroup(\'' + this.esc(g.id) + '\')">Delete</button>';
|
||||
html += '</td></tr>';
|
||||
});
|
||||
|
||||
html += '</tbody></table>';
|
||||
list.innerHTML = html;
|
||||
} catch (err) {
|
||||
document.getElementById('model-groups-list').innerHTML =
|
||||
'<div class="error-message">Failed to load model groups: ' + this.esc(err.message) + '</div>';
|
||||
}
|
||||
}
|
||||
|
||||
showCreateForm() {
|
||||
this.renderForm(null);
|
||||
}
|
||||
|
||||
async showEditForm(id) {
|
||||
try {
|
||||
const groups = await api.get('/model-groups');
|
||||
const group = groups.find(g => g.id === id);
|
||||
if (group) this.renderForm(group);
|
||||
} catch (err) {
|
||||
alert('Failed to load group: ' + err.message);
|
||||
}
|
||||
}
|
||||
|
||||
renderForm(group) {
|
||||
const isEdit = !!group;
|
||||
const form = document.getElementById('model-group-form');
|
||||
form.style.display = 'block';
|
||||
form.innerHTML = `
|
||||
<h4>${isEdit ? 'Edit' : 'Create'} Model Group</h4>
|
||||
<form onsubmit="modelGroupsPage.saveGroup(event, ${isEdit})">
|
||||
<div class="form-control">
|
||||
<label>Group ID</label>
|
||||
<input type="text" id="mg-id" value="${this.esc(group ? group.id : '')}" ${isEdit ? 'readonly' : 'required'}
|
||||
placeholder="e.g. deepseek-auto">
|
||||
<small>Clients use this as the model name.</small>
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label>Strategy</label>
|
||||
<select id="mg-strategy">
|
||||
<option value="heuristic" ${group && group.strategy === 'heuristic' ? 'selected' : ''}>Heuristic (rules-based)</option>
|
||||
<option value="classifier" ${group && group.strategy === 'classifier' ? 'selected' : ''}>Classifier (LLM judge)</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label>Targets (JSON array)</label>
|
||||
<input type="text" id="mg-targets" value='${this.esc(group ? group.targets : '["cheap-model","smart-model"]')}' required>
|
||||
<small>First target = cheapest/fastest. Last target = smartest/most expensive.</small>
|
||||
</div>
|
||||
<div class="form-control" id="mg-selector-row" style="${group && group.strategy === 'classifier' ? '' : 'display:none'}">
|
||||
<label>Selector Model</label>
|
||||
<input type="text" id="mg-selector-model" value="${this.esc(group && group.selector_model ? group.selector_model : 'gpt-4o-mini')}"
|
||||
placeholder="Model used to judge task complexity">
|
||||
</div>
|
||||
<div class="form-control" id="mg-threshold-row" style="${group && group.strategy === 'classifier' ? '' : 'display:none'}">
|
||||
<label>Complexity Threshold</label>
|
||||
<input type="number" id="mg-threshold" value="${group && group.complexity_threshold ? group.complexity_threshold : ''}" min="1"
|
||||
placeholder="Tasks rated >= this go to the smart model">
|
||||
</div>
|
||||
<div class="form-control" id="mg-rules-row" style="${group && group.strategy === 'heuristic' ? '' : 'display:none'}">
|
||||
<label>Heuristic Rules (JSON array)</label>
|
||||
<textarea id="mg-rules" rows="4" placeholder='[{"pattern":"step by step","target":1}]'>${group && group.heuristic_rules ? group.heuristic_rules : ''}</textarea>
|
||||
<small>Pattern to match in user messages. target = index into targets array.</small>
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label>Logic Level (1-10)</label>
|
||||
<input type="number" id="mg-level" value="${group && group.logic_level != null ? group.logic_level : ''}" min="1" max="10"
|
||||
placeholder="e.g. 8 for heavy logic, 2 for fast/basic">
|
||||
<small>Rough complexity scale. 1-3: fast/light, 4-7: standard, 8-10: heavy.</small>
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label>Primary Use</label>
|
||||
<input type="text" id="mg-primary-use" value="${this.esc(group && group.primary_use ? group.primary_use : '')}"
|
||||
placeholder="e.g. Complex Coding, Logic, Agents.">
|
||||
<small>Brief description of what this group is best used for.</small>
|
||||
</div>
|
||||
<div class="form-actions">
|
||||
<button type="submit" class="btn btn-primary">Save</button>
|
||||
<button type="button" class="btn" onclick="document.getElementById('model-group-form').style.display='none'">Cancel</button>
|
||||
</div>
|
||||
</form>
|
||||
`;
|
||||
|
||||
document.getElementById('mg-strategy').onchange = function() {
|
||||
var isClassifier = this.value === 'classifier';
|
||||
document.getElementById('mg-selector-row').style.display = isClassifier ? '' : 'none';
|
||||
document.getElementById('mg-threshold-row').style.display = isClassifier ? '' : 'none';
|
||||
document.getElementById('mg-rules-row').style.display = isClassifier ? 'none' : '';
|
||||
};
|
||||
}
|
||||
|
||||
async saveGroup(event, isEdit) {
|
||||
event.preventDefault();
|
||||
var id = document.getElementById('mg-id').value.trim();
|
||||
var strategy = document.getElementById('mg-strategy').value;
|
||||
var targets = document.getElementById('mg-targets').value;
|
||||
var selectorModel = document.getElementById('mg-selector-model').value.trim() || null;
|
||||
var thresholdVal = document.getElementById('mg-threshold').value;
|
||||
var rules = document.getElementById('mg-rules').value.trim() || null;
|
||||
var logicLevelVal = document.getElementById('mg-level').value;
|
||||
var primaryUse = document.getElementById('mg-primary-use').value.trim() || null;
|
||||
|
||||
try { JSON.parse(targets); } catch (e) { alert('Targets must be valid JSON array'); return; }
|
||||
if (rules) { try { JSON.parse(rules); } catch (e) { alert('Heuristic rules must be valid JSON'); return; } }
|
||||
|
||||
var body = { id: id, strategy: strategy, targets: targets, selector_model: selectorModel, heuristic_rules: rules };
|
||||
if (thresholdVal) body.complexity_threshold = parseInt(thresholdVal);
|
||||
if (logicLevelVal) body.logic_level = parseInt(logicLevelVal);
|
||||
if (primaryUse) body.primary_use = primaryUse;
|
||||
|
||||
try {
|
||||
if (isEdit) {
|
||||
await api.put('/model-groups/' + encodeURIComponent(id), body);
|
||||
} else {
|
||||
await api.post('/model-groups', body);
|
||||
}
|
||||
document.getElementById('model-group-form').style.display = 'none';
|
||||
await this.loadGroups();
|
||||
} catch (err) {
|
||||
alert('Failed to save: ' + err.message);
|
||||
}
|
||||
}
|
||||
|
||||
async deleteGroup(id) {
|
||||
if (!confirm('Delete model group "' + id + '"? This cannot be undone.')) return;
|
||||
try {
|
||||
await api.delete('/model-groups/' + encodeURIComponent(id));
|
||||
await this.loadGroups();
|
||||
} catch (err) {
|
||||
alert('Failed to delete: ' + err.message);
|
||||
}
|
||||
}
|
||||
|
||||
esc(str) {
|
||||
if (!str) return '';
|
||||
return String(str).replace(/&/g,'&').replace(/</g,'<').replace(/>/g,'>').replace(/"/g,'"');
|
||||
}
|
||||
}
|
||||
|
||||
var modelGroupsPage = new ModelGroupsPage();
|
||||
@@ -392,7 +392,7 @@ class MonitoringPage {
|
||||
</div>
|
||||
<div class="stream-entry-content">
|
||||
<strong>${request.client_id || 'Unknown'}</strong> →
|
||||
${request.provider || 'Unknown'} (${request.model || 'Unknown'})
|
||||
${request.provider || 'Unknown'} (${request.model || 'Unknown'}${request.model_group ? ` via ${request.model_group}` : ''})
|
||||
<div class="stream-entry-details">
|
||||
${request.total_tokens || request.tokens || 0} tokens • ${request.duration_ms || request.duration || 0}ms
|
||||
</div>
|
||||
|
||||
@@ -252,7 +252,7 @@ class OverviewPage {
|
||||
<td>${time}</td>
|
||||
<td><span class="badge-client">${request.client_id}</span></td>
|
||||
<td>${request.provider}</td>
|
||||
<td><code class="code-sm">${request.model}</code></td>
|
||||
<td><code class="code-sm">${request.model}${request.model_group ? ` (via ${request.model_group})` : ''}</code></td>
|
||||
<td>${request.tokens.toLocaleString()}</td>
|
||||
<td>
|
||||
<span class="status-badge ${statusClass}">
|
||||
@@ -313,7 +313,7 @@ class OverviewPage {
|
||||
<td>${time}</td>
|
||||
<td><span class="badge-client">${request.client_id}</span></td>
|
||||
<td>${request.provider}</td>
|
||||
<td><code class="code-sm">${request.model}</code></td>
|
||||
<td><code class="code-sm">${request.model}${request.model_group ? ` (via ${request.model_group})` : ''}</code></td>
|
||||
<td>${(request.total_tokens || request.tokens || 0).toLocaleString()}</td>
|
||||
<td>
|
||||
<span class="status-badge ${statusClass}">
|
||||
|
||||
@@ -309,7 +309,7 @@ class WebSocketManager {
|
||||
<td>${time}</td>
|
||||
<td>${request.client_id || 'Unknown'}</td>
|
||||
<td>${request.provider || 'Unknown'}</td>
|
||||
<td>${request.model || 'Unknown'}</td>
|
||||
<td>${request.model || 'Unknown'}${request.model_group ? ` (via ${request.model_group})` : ''}</td>
|
||||
<td>${(request.total_tokens || request.tokens || 0)}</td>
|
||||
<td>
|
||||
<span class="status-badge ${statusClass}">
|
||||
@@ -358,7 +358,7 @@ class WebSocketManager {
|
||||
</div>
|
||||
<div class="stream-entry-content">
|
||||
<strong>${request.client_id || 'Unknown'}</strong> →
|
||||
${request.provider || 'Unknown'} (${request.model || 'Unknown'})
|
||||
${request.provider || 'Unknown'} (${request.model || 'Unknown'}${request.model_group ? ` via ${request.model_group}` : ''})
|
||||
<div class="stream-entry-details">
|
||||
${(request.total_tokens || request.tokens || 0)} tokens • ${(request.duration_ms || request.duration || 0)}ms
|
||||
</div>
|
||||
|
||||
Reference in New Issue
Block a user