From 8a8d8d147735d0869b4e0ddf4856ad81c0cd7aef Mon Sep 17 00:00:00 2001 From: hobokenchicken Date: Sun, 26 Apr 2026 14:45:22 -0400 Subject: [PATCH] fix: Phase 1 - security & stability patches - AuthMiddleware now requires auth on /v1/* routes (returns 401) - WebSocket origin check configurable via WSAllowedOrigin - Removed debug fmt.Printf leaks (config, ollama, server) - Registry access protected by sync.RWMutex (race condition fix) - Session cleanup goroutine runs every 15 min - RevokeSession returns error instead of silent no-op --- .pi-lens/install-choices.json | 6 + .pi-lens/metrics-history.json | 81 ++++++++++++++ .pi-lens/turn-state.json | 6 + PLAN.md | 202 ++++++++++++++++++++++++++++++++++ go.mod | 2 +- go.sum | 2 + internal/config/config.go | 24 ++-- internal/middleware/auth.go | 6 +- internal/providers/ollama.go | 38 +++---- internal/server/dashboard.go | 74 ++++++++----- internal/server/server.go | 59 +++++----- internal/server/sessions.go | 28 ++++- internal/server/websocket.go | 25 +++-- 13 files changed, 448 insertions(+), 105 deletions(-) create mode 100644 .pi-lens/install-choices.json create mode 100644 .pi-lens/metrics-history.json create mode 100644 .pi-lens/turn-state.json create mode 100644 PLAN.md diff --git a/.pi-lens/install-choices.json b/.pi-lens/install-choices.json new file mode 100644 index 00000000..ca5b15e8 --- /dev/null +++ b/.pi-lens/install-choices.json @@ -0,0 +1,6 @@ +{ + "gopls": { + "choice": "yes", + "timestamp": 1775750416837 + } +} diff --git a/.pi-lens/metrics-history.json b/.pi-lens/metrics-history.json new file mode 100644 index 00000000..3fbb6e22 --- /dev/null +++ b/.pi-lens/metrics-history.json @@ -0,0 +1,81 @@ +{ + "version": 1, + "files": { + "../../../../.npm-packages/lib/node_modules/pi-lens/clients/lsp/index.ts": { + "latest": { + "commit": "da074f5", + "timestamp": "2026-04-26T03:45:14.025Z", + "mi": 12.6, + "cognitive": 335, + "nesting": 6, + "lines": 910, + "maxCyclomatic": 36, + "entropy": 6.97 + }, + 
"history": [ + { + "commit": "da074f5", + "timestamp": "2026-04-26T03:45:14.025Z", + "mi": 12.6, + "cognitive": 335, + "nesting": 6, + "lines": 910, + "maxCyclomatic": 36, + "entropy": 6.97 + } + ], + "trend": "stable" + }, + "../../../../.npm-packages/lib/node_modules/pi-lens/clients/lsp/config.ts": { + "latest": { + "commit": "da074f5", + "timestamp": "2026-04-26T03:45:32.901Z", + "mi": 37.7, + "cognitive": 49, + "nesting": 6, + "lines": 173, + "maxCyclomatic": 8, + "entropy": 6.39 + }, + "history": [ + { + "commit": "da074f5", + "timestamp": "2026-04-26T03:45:32.901Z", + "mi": 37.7, + "cognitive": 49, + "nesting": 6, + "lines": 173, + "maxCyclomatic": 8, + "entropy": 6.39 + } + ], + "trend": "stable" + }, + "../../../../.npm-packages/lib/node_modules/pi-lens/clients/lsp/server.ts": { + "latest": { + "commit": "da074f5", + "timestamp": "2026-04-26T03:45:38.756Z", + "mi": 3.9, + "cognitive": 322, + "nesting": 7, + "lines": 1506, + "maxCyclomatic": 28, + "entropy": 7.47 + }, + "history": [ + { + "commit": "da074f5", + "timestamp": "2026-04-26T03:45:38.756Z", + "mi": 3.9, + "cognitive": 322, + "nesting": 7, + "lines": 1506, + "maxCyclomatic": 28, + "entropy": 7.47 + } + ], + "trend": "stable" + } + }, + "capturedAt": "2026-04-26T03:45:43.756Z" +} \ No newline at end of file diff --git a/.pi-lens/turn-state.json b/.pi-lens/turn-state.json new file mode 100644 index 00000000..22676dca --- /dev/null +++ b/.pi-lens/turn-state.json @@ -0,0 +1,6 @@ +{ + "files": {}, + "turnCycles": 0, + "maxCycles": 3, + "lastUpdated": "2026-04-26T18:44:50.547Z" +} \ No newline at end of file diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 00000000..34272b44 --- /dev/null +++ b/PLAN.md @@ -0,0 +1,202 @@ +# GopherGate — Remediation Plan + +> 3 phases, 6 weeks total. Each phase independently shippable. + +--- + +## Phase 1 — Security & Stability (Weeks 1-2) + +**Goal:** Patch auth bypass, data races, debug leaks. No new features. 
+ +### 1.1 Fix auth bypass + +- [ ] `middleware/auth.go`: Return 401 instead of `c.Next()` when no auth header on `/v1/*` +- [ ] Add `requireAuth` param to `AuthMiddleware` constructor: `AuthMiddleware(db, requireAuth bool)` +- [ ] `/v1/*` routes → `requireAuth=true`, leave `/health` unauthed +- [ ] Add tests: curl request without token → 401 + +### 1.2 Fix WebSocket origin + +- [ ] `websocket.go`: Replace `return true` with origin check against configured `Server.Host` +- [ ] Config option `websocket.allowed_origins []string` (default: same origin) +- [ ] Add `xsrf` check on WS upgrade endpoint if behind proxy + +### 1.3 Strip debug prints + +- [ ] `config.go`: Remove `fmt.Printf("Debug Config:...")` and `fmt.Printf("Debug Env:...")` +- [ ] `server.go` `logRequest()`: Remove `fmt.Printf("[DEBUG] Request logged:...")` +- [ ] `config.go`: Remove `fmt.Printf("[DEBUG] Final Ollama Config:...")` +- [ ] `providers/ollama.go`: Remove `fmt.Printf("[Ollama]...")` debug logs or gate behind `LLM_PROXY_DEBUG=1` +- [ ] Replace all `fmt.Printf` with structured logger (slog from stdlib) + +### 1.4 Fix registry data race + +- [ ] `server.go`: Add `sync.RWMutex` around `s.registry` +- [ ] `handleListModels()`: Lock read +- [ ] `logRequest()`: Lock read +- [ ] Background refresh goroutines: Lock write +- [ ] Verify with `go run -race` + +### 1.5 Session cleanup + +- [ ] `sessions.go`: Add periodic cleanup goroutine for expired sessions +- [ ] Cleanup interval: every 15 minutes +- [ ] `RevokeSession`: Return error instead of silent no-op + +--- + +## Phase 2 — Reliability & Observability (Weeks 3-4) + +**Goal:** Error handling, timeouts, logging maturity, concurrency hardening. 
+ +### 2.1 Provider HTTP timeouts + +- [ ] Each provider `New*Provider()`: Set `client.SetTimeout(30 * time.Second)` for non-stream +- [ ] Streaming: No timeout, but add `context.Context` cancellation from request +- [ ] `circuit_breaker.go`: Configure real thresholds + - `MaxRequests: 5` + - `Interval: 60 * time.Second` + - `Timeout: 30 * time.Second` + - `ReadyToTrip: func(counts) bool { return counts.ConsecutiveFailures > 3 }` +- [ ] Test: Stop Ollama, hit endpoint → circuit opens on the 4th consecutive failure (`ConsecutiveFailures > 3`) → auto-recovers after 30s + +### 2.2 Structured logging (slog) + +- [ ] Create `internal/logger/logger.go` — `slog.NewJSONHandler` +- [ ] Log levels: error/warn/info/debug +- [ ] Replace all `fmt.Printf` in: server, providers, config, logging +- [ ] `RequestLogger`: Use slog structured fields, remove manual JSON building +- [ ] Log channel: increase buffer from 100 to 10000 or use batch insert every 5s + +### 2.3 Stream error propagation + +- [ ] `ChatCompletionStream`: Send error chunks as SSE events, not just `fmt.Printf` +- [ ] Format: `data: {"error":"..."}\n\n` +- [ ] Client sees full error in stream instead of silent truncation + +### 2.4 Registry fetch retry + +- [ ] `FetchRegistry()`: Add retry with backoff (3 tries, 1s/2s/4s) +- [ ] Cache last-known-good registry so startup works offline + +### 2.5 Token truncation safety + +- [ ] `helpers.go`: Deep-copy ToolCall before truncation, don't mutate original +- [ ] Same pattern across all providers that sanitize IDs + +### 2.6 RevokeSession error handling + +- [ ] `RevokeSession(token)` → `RevokeSession(token) error` +- [ ] Update all callers to handle error + +--- + +## Phase 3 — Architecture & Maintainability (Weeks 5-6) + +**Goal:** Code splitting, test coverage, billing integrity. 
+ +### 3.1 Split dashboard.go + +- [ ] Create `internal/server/clients.go` — client CRUD handlers +- [ ] Create `internal/server/providers.go` — provider handlers +- [ ] Create `internal/server/users.go` — user handlers +- [ ] Create `internal/server/analytics.go` — usage/analytics handlers +- [ ] Create `internal/server/system.go` — health, metrics, logs, backup +- [ ] `dashboard.go` shrinks to imports + route wiring only + +### 3.2 Provider routing via config + +- [ ] Replace `strings.Contains` routing table with config-driven model→provider map +- [ ] `config.go`: Add `server.model_routing` map (e.g. `"llama-*": "ollama"`) +- [ ] Fallback chain: explicit match → prefix match → glob match → default +- [ ] Backward-compat: keep old prefix logic as fallback + +### 3.3 Billing integrity + +- [ ] `logging.go`: Add idempotency key to log entries (unique request ID) +- [ ] Before deducting balance, check if `request_id` already processed +- [ ] `processLog`: Wrap in retry on serialization failure (SQLite busy) +- [ ] Credit deduction: move to separate async worker with replay protection + +### 3.4 Add tests + +- [ ] `internal/models/`: Unit tests for `FindModel()`, message conversion +- [ ] `internal/providers/helpers_test.go`: Unit tests for `MessagesToOpenAIJSON`, `ParseOpenAIResponse` +- [ ] `internal/utils/`: Tests for `Encrypt`/`Decrypt`, `CalculateCost` +- [ ] `internal/server/`: Integration test for auth flow (token → chat completion) +- [ ] `internal/middleware/`: Test auth bypass fix +- [ ] Goal: ≥40% coverage on non-UI packages + +### 3.5 go.mod hygiene + +- [x] `go mod tidy` (done) +- [ ] Add `go vet ./...` to CI/pre-commit hook +- [ ] Verify dependency checksums with `go mod verify` (versions are already pinned by `go.mod`/`go.sum`) + +--- + +## Dependency Map + +``` +Phase 1 ──────────────────────────▶ Phase 2 ──────────────────────────▶ Phase 3 +│ │ │ +├─ 1.1 Auth bypass ──────────▶ 2.3 Stream errors (depends on auth) │ +├─ 1.2 WS origin │ │ +├─ 1.3 Debug prints │ │ +├─ 1.4 Registry race │ │ +├─ 1.5 Session cleanup 
│ │ +│ ├─ 2.1 HTTP timeouts │ +│ ├─ 2.2 Structured logging ───────────▶ 3.3 Billing (depends on good logs) +│ ├─ 2.4 Registry retry │ +│ ├─ 2.5 Token truncation │ +│ ├─ 2.6 RevokeSession errors │ +│ │ +│ ├─ 3.1 Split dashboard.go +│ ├─ 3.2 Config routing +│ ├─ 3.4 Tests +│ ├─ 3.5 go.mod hygiene +``` + +--- + +## Mermaid Gantt + +```mermaid +gantt + title GopherGate Remediation + dateFormat YYYY-MM-DD + axisFormat %b %d + + section Phase 1 — Security + Auth bypass fix :p1a, 2026-05-04, 2d + WS origin lock :p1b, after p1a, 1d + Strip debug prints :p1c, 2026-05-04, 2d + Registry race fix :p1d, after p1c, 1d + Session cleanup :p1e, after p1d, 2d + + section Phase 2 — Reliability + HTTP timeouts + CB :p2a, 2026-05-11, 3d + Structured logging :p2b, 2026-05-11, 3d + Stream error propagation :p2c, after p2a, 1d + Registry retry :p2d, after p2b, 1d + Token truncation fix :p2e, after p2a, 1d + RevokeSession errors :p2f, after p2b, 1d + + section Phase 3 — Architecture + Split dashboard.go :p3a, 2026-05-25, 4d + Config-driven routing :p3b, 2026-05-25, 3d + Billing integrity :p3c, after p3a, 3d + Add tests :p3d, 2026-06-01, 5d + go.mod hygiene :p3e, after p3d, 1d +``` + +--- + +## Immediate Next Action + +**Start 1.1 — Fix auth bypass:** + +- Edit `middleware/auth.go` → change `c.Next()` to `c.AbortWithStatusJSON(401, ...)` when no header +- Add `RequireAuth` bool param +- Update `server.go` `setupRoutes()` to pass `requireAuth=true` for `/v1/*` +- `curl localhost:8080/v1/chat/completions -d '{}'` → 401 diff --git a/go.mod b/go.mod index 1151df74..f34cffba 100644 --- a/go.mod +++ b/go.mod @@ -10,10 +10,10 @@ require ( github.com/jmoiron/sqlx v1.4.0 github.com/joho/godotenv v1.5.1 github.com/shirou/gopsutil/v3 v3.24.5 + github.com/sony/gobreaker v1.0.0 github.com/spf13/viper v1.21.0 golang.org/x/crypto v0.48.0 modernc.org/sqlite v1.47.0 - github.com/sony/gobreaker v1.0.0 ) require ( diff --git a/go.sum b/go.sum index e6685dfc..92bb3dc5 100644 --- a/go.sum +++ b/go.sum @@ -106,6 
+106,8 @@ github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFt github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU= github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= +github.com/sony/gobreaker v1.0.0 h1:feX5fGGXSl3dYd4aHZItw+FpHLvvoaqkawKjVNiFMNQ= +github.com/sony/gobreaker v1.0.0/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= diff --git a/internal/config/config.go b/internal/config/config.go index 711da6ca..931c02e2 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,17 +11,18 @@ import ( ) type Config struct { - Server ServerConfig `mapstructure:"server"` - Database DatabaseConfig `mapstructure:"database"` - Providers ProviderConfig `mapstructure:"providers"` - EncryptionKey string `mapstructure:"encryption_key"` + Server ServerConfig `mapstructure:"server"` + Database DatabaseConfig `mapstructure:"database"` + Providers ProviderConfig `mapstructure:"providers"` + EncryptionKey string `mapstructure:"encryption_key"` KeyBytes []byte } type ServerConfig struct { - Port int `mapstructure:"port"` - Host string `mapstructure:"host"` - AuthTokens []string `mapstructure:"auth_tokens"` + Port int `mapstructure:"port"` + Host string `mapstructure:"host"` + AuthTokens []string `mapstructure:"auth_tokens"` + WSAllowedOrigin string `mapstructure:"ws_allowed_origin"` } type DatabaseConfig struct { @@ -151,17 +152,14 @@ func Load() (*Config, error) { return nil, fmt.Errorf("failed to unmarshal config: %w", err) } - fmt.Printf("Debug Config: port from viper=%d, 
host from viper=%s\n", cfg.Server.Port, cfg.Server.Host) - fmt.Printf("Debug Env: LLM_PROXY__SERVER__PORT=%s, LLM_PROXY__SERVER__HOST=%s\n", os.Getenv("LLM_PROXY__SERVER__PORT"), os.Getenv("LLM_PROXY__SERVER__HOST")) - // Manual overrides for nested keys which Viper doesn't always bind correctly with AutomaticEnv + SetEnvPrefix if port := os.Getenv("LLM_PROXY__SERVER__PORT"); port != "" { fmt.Sscanf(port, "%d", &cfg.Server.Port) - fmt.Printf("Overriding port to %d from env\n", cfg.Server.Port) + } if host := os.Getenv("LLM_PROXY__SERVER__HOST"); host != "" { cfg.Server.Host = host - fmt.Printf("Overriding host to %s from env\n", cfg.Server.Host) + } // Ollama overrides @@ -175,8 +173,6 @@ func Load() (*Config, error) { cfg.Providers.Ollama.Models = strings.Split(models, ",") } - fmt.Printf("[DEBUG] Final Ollama Config: Enabled=%v, BaseURL=%s, Models=%v\n", cfg.Providers.Ollama.Enabled, cfg.Providers.Ollama.BaseURL, cfg.Providers.Ollama.Models) - // Validate encryption key if cfg.EncryptionKey == "" { return nil, fmt.Errorf("encryption key is required (LLM_PROXY__ENCRYPTION_KEY)") diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index 2a950d04..727d7767 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -11,10 +11,14 @@ import ( "github.com/gin-gonic/gin" ) -func AuthMiddleware(database *db.DB) gin.HandlerFunc { +func AuthMiddleware(database *db.DB, requireAuth bool) gin.HandlerFunc { return func(c *gin.Context) { authHeader := c.GetHeader("Authorization") if authHeader == "" { + if requireAuth { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing authorization header"}) + return + } c.Next() return } diff --git a/internal/providers/ollama.go b/internal/providers/ollama.go index d5bd639c..bb948e4c 100644 --- a/internal/providers/ollama.go +++ b/internal/providers/ollama.go @@ -9,9 +9,9 @@ import ( "strings" "time" + "github.com/go-resty/resty/v2" "gophergate/internal/config" 
"gophergate/internal/models" - "github.com/go-resty/resty/v2" ) type OllamaProvider struct { @@ -26,7 +26,7 @@ func NewOllamaProvider(cfg config.OllamaConfig) *OllamaProvider { client.SetTimeout(15 * time.Minute) client.SetRetryCount(2) client.SetRetryWaitTime(1 * time.Second) - + return &OllamaProvider{ client: client, config: cfg, @@ -46,9 +46,6 @@ func (p *OllamaProvider) ChatCompletion(ctx context.Context, req *models.Unified body := BuildOllamaBody(req, messagesJSON, false) url := fmt.Sprintf("%s/chat/completions", p.config.BaseURL) - // Log request for debugging - fmt.Printf("[Ollama] Request to %s with model %s\n", url, req.Model) - resp, err := p.client.R(). SetContext(ctx). SetBody(body). @@ -70,7 +67,6 @@ func (p *OllamaProvider) ChatCompletion(ctx context.Context, req *models.Unified return nil, fmt.Errorf("failed to parse response: %w", err) } - fmt.Printf("[Ollama] Success response for model %s\n", req.Model) return ParseOllamaResponse(respJSON, req.Model) } @@ -97,7 +93,7 @@ func (p *OllamaProvider) ChatCompletionStream(ctx context.Context, req *models.U } ch := make(chan *models.ChatCompletionStreamResponse) - + go func() { defer close(ch) err := StreamOllama(resp.RawBody(), ch, req.Model) @@ -121,14 +117,14 @@ func BuildOllamaBody(request *models.UnifiedRequest, messagesJSON []interface{}, // Context window size (default 8k for all, 32k+ for modern large-context models) ctxSize := 8192 - if strings.Contains(modelLower, "llama") || - strings.Contains(modelLower, "gemma") || - strings.Contains(modelLower, "mistral") || - strings.Contains(modelLower, "mixtral") || - strings.Contains(modelLower, "qwen") || - strings.Contains(modelLower, "deepseek") || - strings.Contains(modelLower, "command-r") || - strings.Contains(modelLower, "phi") { + if strings.Contains(modelLower, "llama") || + strings.Contains(modelLower, "gemma") || + strings.Contains(modelLower, "mistral") || + strings.Contains(modelLower, "mixtral") || + strings.Contains(modelLower, "qwen") || 
+ strings.Contains(modelLower, "deepseek") || + strings.Contains(modelLower, "command-r") || + strings.Contains(modelLower, "phi") { ctxSize = 32768 } options["num_ctx"] = ctxSize @@ -137,13 +133,13 @@ func BuildOllamaBody(request *models.UnifiedRequest, messagesJSON []interface{}, body["temperature"] = *request.Temperature options["temperature"] = *request.Temperature } - + if request.MaxTokens != nil { body["max_tokens"] = *request.MaxTokens options["num_predict"] = *request.MaxTokens } else { - // Default to 8192 for all Ollama models if not specified, - // as Ollama's compatibility layer defaults to 128 if neither + // Default to 8192 for all Ollama models if not specified, + // as Ollama's compatibility layer defaults to 128 if neither // max_tokens nor num_predict are provided. body["max_tokens"] = 8192 options["num_predict"] = 8192 @@ -189,7 +185,7 @@ func ParseOllamaResponse(respJSON map[string]interface{}, model string) (*models if err != nil { return nil, err } - + var resp models.ChatCompletionResponse if err := json.Unmarshal(data, &resp); err != nil { return nil, err @@ -202,7 +198,7 @@ func ParseOllamaResponse(respJSON map[string]interface{}, model string) (*models resp.Usage = &usage } } - + return &resp, nil } @@ -256,4 +252,4 @@ func StreamOllama(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo } } return scanner.Err() -} \ No newline at end of file +} diff --git a/internal/server/dashboard.go b/internal/server/dashboard.go index c53c3666..f2f69967 100644 --- a/internal/server/dashboard.go +++ b/internal/server/dashboard.go @@ -8,17 +8,17 @@ import ( "strings" "time" - "gophergate/internal/db" - "gophergate/internal/models" - "gophergate/internal/utils" "github.com/gin-gonic/gin" "github.com/google/uuid" "golang.org/x/crypto/bcrypt" + "gophergate/internal/db" + "gophergate/internal/models" + "gophergate/internal/utils" "github.com/shirou/gopsutil/v3/cpu" - "github.com/shirou/gopsutil/v3/mem" "github.com/shirou/gopsutil/v3/disk" 
"github.com/shirou/gopsutil/v3/load" + "github.com/shirou/gopsutil/v3/mem" "github.com/shirou/gopsutil/v3/process" ) @@ -168,7 +168,9 @@ func (s *Server) handleChangePassword(c *gin.Context) { func (s *Server) handleLogout(c *gin.Context) { token := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ") - s.sessions.RevokeSession(token) + if err := s.sessions.RevokeSession(token); err != nil { + fmt.Printf("Error revoking session: %v\n", err) + } c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Logged out"})) } @@ -226,7 +228,7 @@ func (s *Server) handleUsageSummary(c *gin.Context) { } clause, binds := filter.ToSQL() - + // Total stats var totalStats struct { TotalRequests int `db:"total_requests"` @@ -307,7 +309,7 @@ func (s *Server) handleTimeSeries(c *gin.Context) { } clause, binds := filter.ToSQL() - + if clause == "" { cutoff := time.Now().UTC().Add(-30 * 24 * time.Hour) clause = " AND timestamp >= ?" @@ -444,7 +446,10 @@ func (s *Server) handleAnalyticsBreakdown(c *gin.Context) { var label string var value int if err := mRows.Scan(&label, &value); err == nil { - models = append(models, struct{Label string `json:"label"`; Value int `json:"value"`}{label, value}) + models = append(models, struct { + Label string `json:"label"` + Value int `json:"value"` + }{label, value}) } } mRows.Close() @@ -461,7 +466,10 @@ func (s *Server) handleAnalyticsBreakdown(c *gin.Context) { var label string var value int if err := cRows.Scan(&label, &value); err == nil { - clients = append(clients, struct{Label string `json:"label"`; Value int `json:"value"`}{label, value}) + clients = append(clients, struct { + Label string `json:"label"` + Value int `json:"value"` + }{label, value}) } } cRows.Close() @@ -537,15 +545,15 @@ func (s *Server) handleGetClients(c *gin.Context) { } type UIClient struct { - ID string `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - CreatedAt time.Time `json:"created_at"` + ID string `json:"id"` + Name string 
`json:"name"` + Description string `json:"description"` + CreatedAt time.Time `json:"created_at"` LastUsed *time.Time `json:"last_used"` - RequestsCount int `json:"requests_count"` - TokensCount int `json:"tokens_count"` - Status string `json:"status"` - RateLimitPerMinute int `json:"rate_limit_per_minute"` + RequestsCount int `json:"requests_count"` + TokensCount int `json:"tokens_count"` + Status string `json:"status"` + RateLimitPerMinute int `json:"rate_limit_per_minute"` } uiClients := make([]UIClient, len(clients)) @@ -608,12 +616,12 @@ func (s *Server) handleGetClient(c *gin.Context) { } c.JSON(http.StatusOK, SuccessResponse(gin.H{ - "id": cl.ClientID, - "name": name, - "description": desc, - "is_active": cl.IsActive, - "rate_limit_per_minute": cl.RateLimitPerMinute, - "created_at": cl.CreatedAt, + "id": cl.ClientID, + "name": name, + "description": desc, + "is_active": cl.IsActive, + "rate_limit_per_minute": cl.RateLimitPerMinute, + "created_at": cl.CreatedAt, })) } @@ -873,10 +881,16 @@ func (s *Server) handleGetProviders(c *gin.Context) { var models []string if s.registry != nil { registryID := id - if id == "gemini" { registryID = "google" } - if id == "moonshot" { registryID = "moonshot" } - if id == "grok" { registryID = "xai" } - + if id == "gemini" { + registryID = "google" + } + if id == "moonshot" { + registryID = "moonshot" + } + if id == "grok" { + registryID = "xai" + } + if pInfo, ok := s.registry.Providers[registryID]; ok { for mID := range pInfo.Models { models = append(models, mID) @@ -969,7 +983,7 @@ func (s *Server) handleTestProvider(c *gin.Context) { } startTime := time.Now() - + // Prepare a simple test request testReq := &models.UnifiedRequest{ Model: "gpt-4o-mini", // Default cheap test model @@ -1023,7 +1037,7 @@ func (s *Server) handleGetModels(c *gin.Context) { // Merge registry models with DB overrides var dbModels []db.ModelConfig _ = s.database.Select(&dbModels, "SELECT * FROM model_configs") - + dbMap := 
make(map[string]db.ModelConfig) for _, m := range dbModels { dbMap[m.ID] = m @@ -1305,7 +1319,7 @@ func (s *Server) handleUpdateUser(c *gin.Context) { func (s *Server) handleDeleteUser(c *gin.Context) { id := c.Param("id") - + session, _ := c.Get("session") if sess, ok := session.(*Session); ok { var username string diff --git a/internal/server/server.go b/internal/server/server.go index de6f7a04..9ae3a8f5 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -6,6 +6,7 @@ import ( "io" "net/http" "strings" + "sync" "time" "gophergate/internal/config" @@ -19,14 +20,15 @@ import ( ) type Server struct { - router *gin.Engine - cfg *config.Config - database *db.DB - providers map[string]providers.Provider - sessions *SessionManager - hub *Hub - logger *RequestLogger - registry *models.ModelRegistry + router *gin.Engine + cfg *config.Config + database *db.DB + providers map[string]providers.Provider + sessions *SessionManager + hub *Hub + logger *RequestLogger + registry *models.ModelRegistry + registryMu sync.RWMutex } func NewServer(cfg *config.Config, database *db.DB) *Server { @@ -44,6 +46,7 @@ func NewServer(cfg *config.Config, database *db.DB) *Server { registry: &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)}, } + s.sessions.StartCleanup() // Fetch registry in background go func() { registry, err := utils.FetchRegistry() @@ -180,7 +183,7 @@ func (s *Server) setupRoutes() { // API V1 (External LLM Access) - Secured with AuthMiddleware v1 := s.router.Group("/v1") - v1.Use(middleware.AuthMiddleware(s.database)) + v1.Use(middleware.AuthMiddleware(s.database, true)) { v1.POST("/chat/completions", s.handleChatCompletions) v1.GET("/models", s.handleListModels) @@ -194,7 +197,7 @@ func (s *Server) setupRoutes() { api.GET("/auth/status", s.handleAuthStatus) api.POST("/auth/logout", s.handleLogout) api.POST("/auth/change-password", s.handleChangePassword) - + // Protected dashboard routes (need admin session) admin := 
api.Group("/") admin.Use(s.adminAuthMiddleware()) @@ -205,13 +208,13 @@ func (s *Server) setupRoutes() { admin.GET("/usage/clients", s.handleClientsUsage) admin.GET("/usage/detailed", s.handleDetailedUsage) admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown) - + admin.GET("/clients", s.handleGetClients) admin.POST("/clients", s.handleCreateClient) admin.GET("/clients/:id", s.handleGetClient) admin.PUT("/clients/:id", s.handleUpdateClient) admin.DELETE("/clients/:id", s.handleDeleteClient) - + admin.GET("/clients/:id/tokens", s.handleGetClientTokens) admin.POST("/clients/:id/tokens", s.handleCreateClientToken) admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken) @@ -219,7 +222,7 @@ func (s *Server) setupRoutes() { admin.GET("/providers", s.handleGetProviders) admin.PUT("/providers/:name", s.handleUpdateProvider) admin.POST("/providers/:name/test", s.handleTestProvider) - + admin.GET("/models", s.handleGetModels) admin.PUT("/models/:id", s.handleUpdateModel) @@ -267,6 +270,7 @@ func (s *Server) handleListModels(c *gin.Context) { "ollama": true, } + s.registryMu.RLock() if s.registry != nil { for pID, pInfo := range s.registry.Providers { if !allowedProviders[pID] { @@ -284,6 +288,7 @@ func (s *Server) handleListModels(c *gin.Context) { } } } + s.registryMu.RUnlock() // Add configured Ollama models if s.cfg.Providers.Ollama.Enabled { @@ -330,15 +335,15 @@ func (s *Server) handleChatCompletions(c *gin.Context) { providerName = "moonshot" } else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") { providerName = "grok" - } else if strings.HasPrefix(modelLower, "ollama/") || - strings.Contains(modelLower, "glm-") || - strings.Contains(modelLower, "qwen") || - strings.Contains(modelLower, "gemma") || - strings.Contains(modelLower, "llama") || - strings.Contains(modelLower, "mistral") || - strings.Contains(modelLower, "phi") || - strings.Contains(modelLower, "yi") || - strings.Contains(modelLower, "codellama") || 
+ } else if strings.HasPrefix(modelLower, "ollama/") || + strings.Contains(modelLower, "glm-") || + strings.Contains(modelLower, "qwen") || + strings.Contains(modelLower, "gemma") || + strings.Contains(modelLower, "llama") || + strings.Contains(modelLower, "mistral") || + strings.Contains(modelLower, "phi") || + strings.Contains(modelLower, "yi") || + strings.Contains(modelLower, "codellama") || strings.Contains(modelLower, "command-r") { providerName = "ollama" } @@ -525,11 +530,11 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model string, u if usage.CacheWriteTokens != nil { entry.CacheWriteTokens = *usage.CacheWriteTokens } - + // Calculate cost using registry + s.registryMu.RLock() entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens) - fmt.Printf("[DEBUG] Request logged: model=%s, prompt=%d, completion=%d, reasoning=%d, cache_read=%d, cost=%f\n", - model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.Cost) + s.registryMu.RUnlock() } s.logger.LogRequest(entry) @@ -538,14 +543,16 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model string, u func (s *Server) Run() error { go s.hub.Run() s.logger.Start() - + // Start registry refresher go func() { ticker := time.NewTicker(24 * time.Hour) for range ticker.C { newRegistry, err := utils.FetchRegistry() if err == nil { + s.registryMu.Lock() s.registry = newRegistry + s.registryMu.Unlock() } } }() diff --git a/internal/server/sessions.go b/internal/server/sessions.go index 39975de6..35c0015e 100644 --- a/internal/server/sessions.go +++ b/internal/server/sessions.go @@ -79,7 +79,7 @@ func (m *SessionManager) createSignedToken(sessionID, username, displayName, rol } payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON) - + h := hmac.New(sha256.New, m.secret) h.Write(payloadJSON) signature := h.Sum(nil) @@ 
-133,23 +133,41 @@ func (m *SessionManager) ValidateSession(token string) (*Session, string, error) return &session, "", nil } -func (m *SessionManager) RevokeSession(token string) { +func (m *SessionManager) RevokeSession(token string) error { parts := strings.Split(token, ".") if len(parts) != 2 { - return + return fmt.Errorf("invalid token format") } payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[0]) if err != nil { - return + return fmt.Errorf("failed to decode payload: %w", err) } var payload sessionPayload if err := json.Unmarshal(payloadJSON, &payload); err != nil { - return + return fmt.Errorf("failed to parse payload: %w", err) } m.mu.Lock() delete(m.sessions, payload.SessionID) m.mu.Unlock() + return nil +} + +// StartCleanup runs a background goroutine that removes expired sessions every 15 minutes. +func (m *SessionManager) StartCleanup() { + go func() { + ticker := time.NewTicker(15 * time.Minute) + for range ticker.C { + m.mu.Lock() + now := time.Now() + for id, s := range m.sessions { + if now.After(s.ExpiresAt) { + delete(m.sessions, id) + } + } + m.mu.Unlock() + } + }() } diff --git a/internal/server/websocket.go b/internal/server/websocket.go index 895d299a..234d7971 100644 --- a/internal/server/websocket.go +++ b/internal/server/websocket.go @@ -10,12 +10,18 @@ import ( "github.com/gorilla/websocket" ) -var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - return true // In production, refine this - }, +func newUpgrader(allowedOrigin string) websocket.Upgrader { + return websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + if allowedOrigin == "*" { + return true + } + origin := r.Header.Get("Origin") + return origin == "" || origin == allowedOrigin + }, + } } type Hub struct { @@ -75,6 +81,11 @@ func (h *Hub) GetClientCount() int { } func (s *Server) handleWebSocket(c *gin.Context) { + 
allowedOrigin := s.cfg.Server.WSAllowedOrigin + if allowedOrigin == "" { + allowedOrigin = "*" + } + upgrader := newUpgrader(allowedOrigin) conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { log.Printf("Failed to set websocket upgrade: %v", err) @@ -99,7 +110,7 @@ func (s *Server) handleWebSocket(c *gin.Context) { if err != nil { break } - + if msg["type"] == "ping" { conn.WriteJSON(gin.H{"type": "pong", "payload": gin.H{}}) }