fix: Phase 1 - security & stability patches
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled

- AuthMiddleware now requires auth on /v1/* routes (returns 401)
- WebSocket origin check configurable via WSAllowedOrigin
- Removed debug fmt.Printf leaks (config, ollama, server)
- Registry access protected by sync.RWMutex (race condition fix)
- Session cleanup goroutine runs every 15 min
- RevokeSession returns error instead of silent no-op
This commit is contained in:
2026-04-26 14:45:22 -04:00
parent da074f52b4
commit 8a8d8d1477
13 changed files with 448 additions and 105 deletions
+6
View File
@@ -0,0 +1,6 @@
{
"gopls": {
"choice": "yes",
"timestamp": 1775750416837
}
}
+81
View File
@@ -0,0 +1,81 @@
{
"version": 1,
"files": {
"../../../../.npm-packages/lib/node_modules/pi-lens/clients/lsp/index.ts": {
"latest": {
"commit": "da074f5",
"timestamp": "2026-04-26T03:45:14.025Z",
"mi": 12.6,
"cognitive": 335,
"nesting": 6,
"lines": 910,
"maxCyclomatic": 36,
"entropy": 6.97
},
"history": [
{
"commit": "da074f5",
"timestamp": "2026-04-26T03:45:14.025Z",
"mi": 12.6,
"cognitive": 335,
"nesting": 6,
"lines": 910,
"maxCyclomatic": 36,
"entropy": 6.97
}
],
"trend": "stable"
},
"../../../../.npm-packages/lib/node_modules/pi-lens/clients/lsp/config.ts": {
"latest": {
"commit": "da074f5",
"timestamp": "2026-04-26T03:45:32.901Z",
"mi": 37.7,
"cognitive": 49,
"nesting": 6,
"lines": 173,
"maxCyclomatic": 8,
"entropy": 6.39
},
"history": [
{
"commit": "da074f5",
"timestamp": "2026-04-26T03:45:32.901Z",
"mi": 37.7,
"cognitive": 49,
"nesting": 6,
"lines": 173,
"maxCyclomatic": 8,
"entropy": 6.39
}
],
"trend": "stable"
},
"../../../../.npm-packages/lib/node_modules/pi-lens/clients/lsp/server.ts": {
"latest": {
"commit": "da074f5",
"timestamp": "2026-04-26T03:45:38.756Z",
"mi": 3.9,
"cognitive": 322,
"nesting": 7,
"lines": 1506,
"maxCyclomatic": 28,
"entropy": 7.47
},
"history": [
{
"commit": "da074f5",
"timestamp": "2026-04-26T03:45:38.756Z",
"mi": 3.9,
"cognitive": 322,
"nesting": 7,
"lines": 1506,
"maxCyclomatic": 28,
"entropy": 7.47
}
],
"trend": "stable"
}
},
"capturedAt": "2026-04-26T03:45:43.756Z"
}
+6
View File
@@ -0,0 +1,6 @@
{
"files": {},
"turnCycles": 0,
"maxCycles": 3,
"lastUpdated": "2026-04-26T18:44:50.547Z"
}
+202
View File
@@ -0,0 +1,202 @@
# GopherGate — Remediation Plan
> 3 phases, 6 weeks total. Each phase independently shippable.
---
## Phase 1 — Security & Stability (Weeks 1-2)
**Goal:** Patch auth bypass, data races, debug leaks. No new features.
### 1.1 Fix auth bypass
- [ ] `middleware/auth.go`: Return 401 instead of `c.Next()` when no auth header on `/v1/*`
- [ ] Add `requireAuth` param to `AuthMiddleware` constructor: `AuthMiddleware(db, requireAuth bool)`
- [ ] `/v1/*` routes → `requireAuth=true`, leave `/health` unauthed
- [ ] Add tests: curl request without token → 401
### 1.2 Fix WebSocket origin
- [ ] `websocket.go`: Replace `return true` with origin check against configured `Server.Host`
- [ ] Config option `websocket.allowed_origins []string` (default: same origin)
- [ ] Add `xsrf` check on WS upgrade endpoint if behind proxy
### 1.3 Strip debug prints
- [ ] `config.go`: Remove `fmt.Printf("Debug Config:...")` and `fmt.Printf("Debug Env:...")`
- [ ] `server.go` `logRequest()`: Remove `fmt.Printf("[DEBUG] Request logged:...")`
- [ ] `config.go`: Remove `fmt.Printf("[DEBUG] Final Ollama Config:...")`
- [ ] `providers/ollama.go`: Remove `fmt.Printf("[Ollama]...")` debug logs or gate behind `LLM_PROXY_DEBUG=1`
- [ ] Replace all `fmt.Printf` with structured logger (slog from stdlib)
### 1.4 Fix registry data race
- [ ] `server.go`: Add `sync.RWMutex` around `s.registry`
- [ ] `handleListModels()`: Lock read
- [ ] `logRequest()`: Lock read
- [ ] Background refresh goroutines: Lock write
- [ ] Verify with `go run -race`
### 1.5 Session cleanup
- [ ] `sessions.go`: Add periodic cleanup goroutine for expired sessions
- [ ] Cleanup interval: every 15 minutes
- [ ] `RevokeSession`: Return error instead of silent no-op
---
## Phase 2 — Reliability & Observability (Weeks 3-4)
**Goal:** Error handling, timeouts, logging maturity, concurrency hardening.
### 2.1 Provider HTTP timeouts
- [ ] Each provider `New*Provider()`: Set `client.SetTimeout(30 * time.Second)` for non-stream
- [ ] Streaming: No timeout, but add `context.Context` cancellation from request
- [ ] `circuit_breaker.go`: Configure real thresholds
- `MaxRequests: 5`
- `Interval: 60 * time.Second`
- `Timeout: 30 * time.Second`
- `ReadyToTrip: func(counts) bool { return counts.ConsecutiveFailures > 3 }`
- [ ] Test: Stop Ollama, hit endpoint → circuit opens once consecutive failures exceed 3 (i.e. on the 4th failure, matching `ReadyToTrip`) → auto-recovers after 30s
### 2.2 Structured logging (slog)
- [ ] Create `internal/logger/logger.go` — `slog.NewJSONHandler`
- [ ] Log levels: error/warn/info/debug
- [ ] Replace all `fmt.Printf` in: server, providers, config, logging
- [ ] `RequestLogger`: Use slog structured fields, remove manual JSON building
- [ ] Log channel: increase buffer from 100 to 10000 or use batch insert every 5s
### 2.3 Stream error propagation
- [ ] `ChatCompletionStream`: Send error chunks as SSE events, not just `fmt.Printf`
- [ ] Format: `data: {"error":"..."}\n\n`
- [ ] Client sees full error in stream instead of silent truncation
### 2.4 Registry fetch retry
- [ ] `FetchRegistry()`: Add retry with backoff (3 tries, 1s/2s/4s)
- [ ] Cache last-known-good registry so startup works offline
### 2.5 Token truncation safety
- [ ] `helpers.go`: Deep-copy ToolCall before truncation, don't mutate original
- [ ] Same pattern across all providers that sanitize IDs
### 2.6 RevokeSession error handling
- [ ] `RevokeSession(token)` → `RevokeSession(token) error`
- [ ] Update all callers to handle error
---
## Phase 3 — Architecture & Maintainability (Weeks 5-6)
**Goal:** Code splitting, test coverage, billing integrity.
### 3.1 Split dashboard.go
- [ ] Create `internal/server/clients.go` — client CRUD handlers
- [ ] Create `internal/server/providers.go` — provider handlers
- [ ] Create `internal/server/users.go` — user handlers
- [ ] Create `internal/server/analytics.go` — usage/analytics handlers
- [ ] Create `internal/server/system.go` — health, metrics, logs, backup
- [ ] `dashboard.go` shrinks to imports + route wiring only
### 3.2 Provider routing via config
- [ ] Replace `strings.Contains` routing table with config-driven model→provider map
- [ ] `config.go`: Add `server.model_routing` map (e.g. `"llama-*": "ollama"`)
- [ ] Fallback chain: explicit match → prefix match → glob match → default
- [ ] Backward-compat: keep old prefix logic as fallback
### 3.3 Billing integrity
- [ ] `logging.go`: Add idempotency key to log entries (unique request ID)
- [ ] Before deducting balance, check if `request_id` already processed
- [ ] `processLog`: Wrap in retry on serialization failure (SQLite busy)
- [ ] Credit deduction: move to separate async worker with replay protection
### 3.4 Add tests
- [ ] `internal/models/`: Unit tests for `FindModel()`, message conversion
- [ ] `internal/providers/helpers_test.go`: Unit tests for `MessagesToOpenAIJSON`, `ParseOpenAIResponse`
- [ ] `internal/utils/`: Tests for `Encrypt`/`Decrypt`, `CalculateCost`
- [ ] `internal/server/`: Integration test for auth flow (token → chat completion)
- [ ] `internal/middleware/`: Test auth bypass fix
- [ ] Goal: ≥40% coverage on non-UI packages
### 3.5 go.mod hygiene
- [ ] `go mod tidy` (done)
- [ ] Add `go vet ./...` to CI/pre-commit hook
- [ ] Verify dependency integrity with `go mod verify` (versions are already pinned by `go.mod`/`go.sum`)
---
## Dependency Map
```
Phase 1 ──────────────────────────▶ Phase 2 ──────────────────────────▶ Phase 3
│ │ │
├─ 1.1 Auth bypass ──────────▶ 2.3 Stream errors (depends on auth) │
├─ 1.2 WS origin │ │
├─ 1.3 Debug prints │ │
├─ 1.4 Registry race │ │
├─ 1.5 Session cleanup │ │
│ ├─ 2.1 HTTP timeouts │
│ ├─ 2.2 Structured logging ───────────▶ 3.3 Billing (depends on good logs)
│ ├─ 2.4 Registry retry │
│ ├─ 2.5 Token truncation │
│ ├─ 2.6 RevokeSession errors │
│ │
│ ├─ 3.1 Split dashboard.go
│ ├─ 3.2 Config routing
│ ├─ 3.4 Tests
│ ├─ 3.5 go.mod hygiene
```
---
## Mermaid Gantt
```mermaid
gantt
title GopherGate Remediation
dateFormat YYYY-MM-DD
axisFormat %b %d
section Phase 1 — Security
Auth bypass fix :p1a, 2026-05-04, 2d
WS origin lock :p1b, after p1a, 1d
Strip debug prints :p1c, 2026-05-04, 2d
Registry race fix :p1d, after p1c, 1d
Session cleanup :p1e, after p1d, 2d
section Phase 2 — Reliability
HTTP timeouts + CB :p2a, 2026-05-11, 3d
Structured logging :p2b, 2026-05-11, 3d
Stream error propagation :p2c, after p2a, 1d
Registry retry :p2d, after p2b, 1d
Token truncation fix :p2e, after p2a, 1d
RevokeSession errors :p2f, after p2b, 1d
section Phase 3 — Architecture
Split dashboard.go :p3a, 2026-05-25, 4d
Config-driven routing :p3b, 2026-05-25, 3d
Billing integrity :p3c, after p3a, 3d
Add tests :p3d, 2026-06-01, 5d
go.mod hygiene :p3e, after p3d, 1d
```
---
## Immediate Next Action
**Start 1.1 — Fix auth bypass:**
- Edit `middleware/auth.go` → change `c.Next()` to `c.AbortWithStatusJSON(401, ...)` when no header
- Add `RequireAuth` bool param
- Update `server.go` `setupRoutes()` to pass `requireAuth=true` for `/v1/*`
- `curl localhost:8080/v1/chat/completions -d '{}'` → 401
+1 -1
View File
@@ -10,10 +10,10 @@ require (
github.com/jmoiron/sqlx v1.4.0 github.com/jmoiron/sqlx v1.4.0
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/shirou/gopsutil/v3 v3.24.5 github.com/shirou/gopsutil/v3 v3.24.5
github.com/sony/gobreaker v1.0.0
github.com/spf13/viper v1.21.0 github.com/spf13/viper v1.21.0
golang.org/x/crypto v0.48.0 golang.org/x/crypto v0.48.0
modernc.org/sqlite v1.47.0 modernc.org/sqlite v1.47.0
github.com/sony/gobreaker v1.0.0
) )
require ( require (
+2
View File
@@ -106,6 +106,8 @@ github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFt
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU= github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
github.com/sony/gobreaker v1.0.0 h1:feX5fGGXSl3dYd4aHZItw+FpHLvvoaqkawKjVNiFMNQ=
github.com/sony/gobreaker v1.0.0/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
+3 -7
View File
@@ -22,6 +22,7 @@ type ServerConfig struct {
Port int `mapstructure:"port"` Port int `mapstructure:"port"`
Host string `mapstructure:"host"` Host string `mapstructure:"host"`
AuthTokens []string `mapstructure:"auth_tokens"` AuthTokens []string `mapstructure:"auth_tokens"`
WSAllowedOrigin string `mapstructure:"ws_allowed_origin"`
} }
type DatabaseConfig struct { type DatabaseConfig struct {
@@ -151,17 +152,14 @@ func Load() (*Config, error) {
return nil, fmt.Errorf("failed to unmarshal config: %w", err) return nil, fmt.Errorf("failed to unmarshal config: %w", err)
} }
fmt.Printf("Debug Config: port from viper=%d, host from viper=%s\n", cfg.Server.Port, cfg.Server.Host)
fmt.Printf("Debug Env: LLM_PROXY__SERVER__PORT=%s, LLM_PROXY__SERVER__HOST=%s\n", os.Getenv("LLM_PROXY__SERVER__PORT"), os.Getenv("LLM_PROXY__SERVER__HOST"))
// Manual overrides for nested keys which Viper doesn't always bind correctly with AutomaticEnv + SetEnvPrefix // Manual overrides for nested keys which Viper doesn't always bind correctly with AutomaticEnv + SetEnvPrefix
if port := os.Getenv("LLM_PROXY__SERVER__PORT"); port != "" { if port := os.Getenv("LLM_PROXY__SERVER__PORT"); port != "" {
fmt.Sscanf(port, "%d", &cfg.Server.Port) fmt.Sscanf(port, "%d", &cfg.Server.Port)
fmt.Printf("Overriding port to %d from env\n", cfg.Server.Port)
} }
if host := os.Getenv("LLM_PROXY__SERVER__HOST"); host != "" { if host := os.Getenv("LLM_PROXY__SERVER__HOST"); host != "" {
cfg.Server.Host = host cfg.Server.Host = host
fmt.Printf("Overriding host to %s from env\n", cfg.Server.Host)
} }
// Ollama overrides // Ollama overrides
@@ -175,8 +173,6 @@ func Load() (*Config, error) {
cfg.Providers.Ollama.Models = strings.Split(models, ",") cfg.Providers.Ollama.Models = strings.Split(models, ",")
} }
fmt.Printf("[DEBUG] Final Ollama Config: Enabled=%v, BaseURL=%s, Models=%v\n", cfg.Providers.Ollama.Enabled, cfg.Providers.Ollama.BaseURL, cfg.Providers.Ollama.Models)
// Validate encryption key // Validate encryption key
if cfg.EncryptionKey == "" { if cfg.EncryptionKey == "" {
return nil, fmt.Errorf("encryption key is required (LLM_PROXY__ENCRYPTION_KEY)") return nil, fmt.Errorf("encryption key is required (LLM_PROXY__ENCRYPTION_KEY)")
+5 -1
View File
@@ -11,10 +11,14 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func AuthMiddleware(database *db.DB) gin.HandlerFunc { func AuthMiddleware(database *db.DB, requireAuth bool) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
if authHeader == "" { if authHeader == "" {
if requireAuth {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing authorization header"})
return
}
c.Next() c.Next()
return return
} }
+1 -5
View File
@@ -9,9 +9,9 @@ import (
"strings" "strings"
"time" "time"
"github.com/go-resty/resty/v2"
"gophergate/internal/config" "gophergate/internal/config"
"gophergate/internal/models" "gophergate/internal/models"
"github.com/go-resty/resty/v2"
) )
type OllamaProvider struct { type OllamaProvider struct {
@@ -46,9 +46,6 @@ func (p *OllamaProvider) ChatCompletion(ctx context.Context, req *models.Unified
body := BuildOllamaBody(req, messagesJSON, false) body := BuildOllamaBody(req, messagesJSON, false)
url := fmt.Sprintf("%s/chat/completions", p.config.BaseURL) url := fmt.Sprintf("%s/chat/completions", p.config.BaseURL)
// Log request for debugging
fmt.Printf("[Ollama] Request to %s with model %s\n", url, req.Model)
resp, err := p.client.R(). resp, err := p.client.R().
SetContext(ctx). SetContext(ctx).
SetBody(body). SetBody(body).
@@ -70,7 +67,6 @@ func (p *OllamaProvider) ChatCompletion(ctx context.Context, req *models.Unified
return nil, fmt.Errorf("failed to parse response: %w", err) return nil, fmt.Errorf("failed to parse response: %w", err)
} }
fmt.Printf("[Ollama] Success response for model %s\n", req.Model)
return ParseOllamaResponse(respJSON, req.Model) return ParseOllamaResponse(respJSON, req.Model)
} }
+24 -10
View File
@@ -8,17 +8,17 @@ import (
"strings" "strings"
"time" "time"
"gophergate/internal/db"
"gophergate/internal/models"
"gophergate/internal/utils"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gophergate/internal/db"
"gophergate/internal/models"
"gophergate/internal/utils"
"github.com/shirou/gopsutil/v3/cpu" "github.com/shirou/gopsutil/v3/cpu"
"github.com/shirou/gopsutil/v3/mem"
"github.com/shirou/gopsutil/v3/disk" "github.com/shirou/gopsutil/v3/disk"
"github.com/shirou/gopsutil/v3/load" "github.com/shirou/gopsutil/v3/load"
"github.com/shirou/gopsutil/v3/mem"
"github.com/shirou/gopsutil/v3/process" "github.com/shirou/gopsutil/v3/process"
) )
@@ -168,7 +168,9 @@ func (s *Server) handleChangePassword(c *gin.Context) {
func (s *Server) handleLogout(c *gin.Context) { func (s *Server) handleLogout(c *gin.Context) {
token := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ") token := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
s.sessions.RevokeSession(token) if err := s.sessions.RevokeSession(token); err != nil {
fmt.Printf("Error revoking session: %v\n", err)
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Logged out"})) c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Logged out"}))
} }
@@ -444,7 +446,10 @@ func (s *Server) handleAnalyticsBreakdown(c *gin.Context) {
var label string var label string
var value int var value int
if err := mRows.Scan(&label, &value); err == nil { if err := mRows.Scan(&label, &value); err == nil {
models = append(models, struct{Label string `json:"label"`; Value int `json:"value"`}{label, value}) models = append(models, struct {
Label string `json:"label"`
Value int `json:"value"`
}{label, value})
} }
} }
mRows.Close() mRows.Close()
@@ -461,7 +466,10 @@ func (s *Server) handleAnalyticsBreakdown(c *gin.Context) {
var label string var label string
var value int var value int
if err := cRows.Scan(&label, &value); err == nil { if err := cRows.Scan(&label, &value); err == nil {
clients = append(clients, struct{Label string `json:"label"`; Value int `json:"value"`}{label, value}) clients = append(clients, struct {
Label string `json:"label"`
Value int `json:"value"`
}{label, value})
} }
} }
cRows.Close() cRows.Close()
@@ -873,9 +881,15 @@ func (s *Server) handleGetProviders(c *gin.Context) {
var models []string var models []string
if s.registry != nil { if s.registry != nil {
registryID := id registryID := id
if id == "gemini" { registryID = "google" } if id == "gemini" {
if id == "moonshot" { registryID = "moonshot" } registryID = "google"
if id == "grok" { registryID = "xai" } }
if id == "moonshot" {
registryID = "moonshot"
}
if id == "grok" {
registryID = "xai"
}
if pInfo, ok := s.registry.Providers[registryID]; ok { if pInfo, ok := s.registry.Providers[registryID]; ok {
for mID := range pInfo.Models { for mID := range pInfo.Models {
+10 -3
View File
@@ -6,6 +6,7 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"sync"
"time" "time"
"gophergate/internal/config" "gophergate/internal/config"
@@ -27,6 +28,7 @@ type Server struct {
hub *Hub hub *Hub
logger *RequestLogger logger *RequestLogger
registry *models.ModelRegistry registry *models.ModelRegistry
registryMu sync.RWMutex
} }
func NewServer(cfg *config.Config, database *db.DB) *Server { func NewServer(cfg *config.Config, database *db.DB) *Server {
@@ -44,6 +46,7 @@ func NewServer(cfg *config.Config, database *db.DB) *Server {
registry: &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)}, registry: &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)},
} }
s.sessions.StartCleanup()
// Fetch registry in background // Fetch registry in background
go func() { go func() {
registry, err := utils.FetchRegistry() registry, err := utils.FetchRegistry()
@@ -180,7 +183,7 @@ func (s *Server) setupRoutes() {
// API V1 (External LLM Access) - Secured with AuthMiddleware // API V1 (External LLM Access) - Secured with AuthMiddleware
v1 := s.router.Group("/v1") v1 := s.router.Group("/v1")
v1.Use(middleware.AuthMiddleware(s.database)) v1.Use(middleware.AuthMiddleware(s.database, true))
{ {
v1.POST("/chat/completions", s.handleChatCompletions) v1.POST("/chat/completions", s.handleChatCompletions)
v1.GET("/models", s.handleListModels) v1.GET("/models", s.handleListModels)
@@ -267,6 +270,7 @@ func (s *Server) handleListModels(c *gin.Context) {
"ollama": true, "ollama": true,
} }
s.registryMu.RLock()
if s.registry != nil { if s.registry != nil {
for pID, pInfo := range s.registry.Providers { for pID, pInfo := range s.registry.Providers {
if !allowedProviders[pID] { if !allowedProviders[pID] {
@@ -284,6 +288,7 @@ func (s *Server) handleListModels(c *gin.Context) {
} }
} }
} }
s.registryMu.RUnlock()
// Add configured Ollama models // Add configured Ollama models
if s.cfg.Providers.Ollama.Enabled { if s.cfg.Providers.Ollama.Enabled {
@@ -527,9 +532,9 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model string, u
} }
// Calculate cost using registry // Calculate cost using registry
s.registryMu.RLock()
entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens) entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
fmt.Printf("[DEBUG] Request logged: model=%s, prompt=%d, completion=%d, reasoning=%d, cache_read=%d, cost=%f\n", s.registryMu.RUnlock()
model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.Cost)
} }
s.logger.LogRequest(entry) s.logger.LogRequest(entry)
@@ -545,7 +550,9 @@ func (s *Server) Run() error {
for range ticker.C { for range ticker.C {
newRegistry, err := utils.FetchRegistry() newRegistry, err := utils.FetchRegistry()
if err == nil { if err == nil {
s.registryMu.Lock()
s.registry = newRegistry s.registry = newRegistry
s.registryMu.Unlock()
} }
} }
}() }()
+22 -4
View File
@@ -133,23 +133,41 @@ func (m *SessionManager) ValidateSession(token string) (*Session, string, error)
return &session, "", nil return &session, "", nil
} }
func (m *SessionManager) RevokeSession(token string) { func (m *SessionManager) RevokeSession(token string) error {
parts := strings.Split(token, ".") parts := strings.Split(token, ".")
if len(parts) != 2 { if len(parts) != 2 {
return return fmt.Errorf("invalid token format")
} }
payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[0]) payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil { if err != nil {
return return fmt.Errorf("failed to decode payload: %w", err)
} }
var payload sessionPayload var payload sessionPayload
if err := json.Unmarshal(payloadJSON, &payload); err != nil { if err := json.Unmarshal(payloadJSON, &payload); err != nil {
return return fmt.Errorf("failed to parse payload: %w", err)
} }
m.mu.Lock() m.mu.Lock()
delete(m.sessions, payload.SessionID) delete(m.sessions, payload.SessionID)
m.mu.Unlock() m.mu.Unlock()
return nil
}
// StartCleanup runs a background goroutine that removes expired sessions every 15 minutes.
func (m *SessionManager) StartCleanup() {
go func() {
ticker := time.NewTicker(15 * time.Minute)
for range ticker.C {
m.mu.Lock()
now := time.Now()
for id, s := range m.sessions {
if now.After(s.ExpiresAt) {
delete(m.sessions, id)
}
}
m.mu.Unlock()
}
}()
} }
+13 -2
View File
@@ -10,13 +10,19 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
var upgrader = websocket.Upgrader{ func newUpgrader(allowedOrigin string) websocket.Upgrader {
return websocket.Upgrader{
ReadBufferSize: 1024, ReadBufferSize: 1024,
WriteBufferSize: 1024, WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool { CheckOrigin: func(r *http.Request) bool {
return true // In production, refine this if allowedOrigin == "*" {
return true
}
origin := r.Header.Get("Origin")
return origin == "" || origin == allowedOrigin
}, },
} }
}
type Hub struct { type Hub struct {
clients map[*websocket.Conn]bool clients map[*websocket.Conn]bool
@@ -75,6 +81,11 @@ func (h *Hub) GetClientCount() int {
} }
func (s *Server) handleWebSocket(c *gin.Context) { func (s *Server) handleWebSocket(c *gin.Context) {
allowedOrigin := s.cfg.Server.WSAllowedOrigin
if allowedOrigin == "" {
allowedOrigin = "*"
}
upgrader := newUpgrader(allowedOrigin)
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil { if err != nil {
log.Printf("Failed to set websocket upgrade: %v", err) log.Printf("Failed to set websocket upgrade: %v", err)