fix: Phase 1 - security & stability patches
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled

- AuthMiddleware now requires auth on /v1/* routes (returns 401)
- WebSocket origin check configurable via WSAllowedOrigin
- Removed debug fmt.Printf leaks (config, ollama, server)
- Registry access protected by sync.RWMutex (race condition fix)
- Session cleanup goroutine runs every 15 min
- RevokeSession returns error instead of silent no-op
This commit is contained in:
2026-04-26 14:45:22 -04:00
parent da074f52b4
commit 8a8d8d1477
13 changed files with 448 additions and 105 deletions
+44 -30
View File
@@ -8,17 +8,17 @@ import (
"strings"
"time"
"gophergate/internal/db"
"gophergate/internal/models"
"gophergate/internal/utils"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"gophergate/internal/db"
"gophergate/internal/models"
"gophergate/internal/utils"
"github.com/shirou/gopsutil/v3/cpu"
"github.com/shirou/gopsutil/v3/mem"
"github.com/shirou/gopsutil/v3/disk"
"github.com/shirou/gopsutil/v3/load"
"github.com/shirou/gopsutil/v3/mem"
"github.com/shirou/gopsutil/v3/process"
)
@@ -168,7 +168,9 @@ func (s *Server) handleChangePassword(c *gin.Context) {
func (s *Server) handleLogout(c *gin.Context) {
token := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
s.sessions.RevokeSession(token)
if err := s.sessions.RevokeSession(token); err != nil {
fmt.Printf("Error revoking session: %v\n", err)
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Logged out"}))
}
@@ -226,7 +228,7 @@ func (s *Server) handleUsageSummary(c *gin.Context) {
}
clause, binds := filter.ToSQL()
// Total stats
var totalStats struct {
TotalRequests int `db:"total_requests"`
@@ -307,7 +309,7 @@ func (s *Server) handleTimeSeries(c *gin.Context) {
}
clause, binds := filter.ToSQL()
if clause == "" {
cutoff := time.Now().UTC().Add(-30 * 24 * time.Hour)
clause = " AND timestamp >= ?"
@@ -444,7 +446,10 @@ func (s *Server) handleAnalyticsBreakdown(c *gin.Context) {
var label string
var value int
if err := mRows.Scan(&label, &value); err == nil {
models = append(models, struct{Label string `json:"label"`; Value int `json:"value"`}{label, value})
models = append(models, struct {
Label string `json:"label"`
Value int `json:"value"`
}{label, value})
}
}
mRows.Close()
@@ -461,7 +466,10 @@ func (s *Server) handleAnalyticsBreakdown(c *gin.Context) {
var label string
var value int
if err := cRows.Scan(&label, &value); err == nil {
clients = append(clients, struct{Label string `json:"label"`; Value int `json:"value"`}{label, value})
clients = append(clients, struct {
Label string `json:"label"`
Value int `json:"value"`
}{label, value})
}
}
cRows.Close()
@@ -537,15 +545,15 @@ func (s *Server) handleGetClients(c *gin.Context) {
}
type UIClient struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
CreatedAt time.Time `json:"created_at"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
CreatedAt time.Time `json:"created_at"`
LastUsed *time.Time `json:"last_used"`
RequestsCount int `json:"requests_count"`
TokensCount int `json:"tokens_count"`
Status string `json:"status"`
RateLimitPerMinute int `json:"rate_limit_per_minute"`
RequestsCount int `json:"requests_count"`
TokensCount int `json:"tokens_count"`
Status string `json:"status"`
RateLimitPerMinute int `json:"rate_limit_per_minute"`
}
uiClients := make([]UIClient, len(clients))
@@ -608,12 +616,12 @@ func (s *Server) handleGetClient(c *gin.Context) {
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{
"id": cl.ClientID,
"name": name,
"description": desc,
"is_active": cl.IsActive,
"rate_limit_per_minute": cl.RateLimitPerMinute,
"created_at": cl.CreatedAt,
"id": cl.ClientID,
"name": name,
"description": desc,
"is_active": cl.IsActive,
"rate_limit_per_minute": cl.RateLimitPerMinute,
"created_at": cl.CreatedAt,
}))
}
@@ -873,10 +881,16 @@ func (s *Server) handleGetProviders(c *gin.Context) {
var models []string
if s.registry != nil {
registryID := id
if id == "gemini" { registryID = "google" }
if id == "moonshot" { registryID = "moonshot" }
if id == "grok" { registryID = "xai" }
if id == "gemini" {
registryID = "google"
}
if id == "moonshot" {
registryID = "moonshot"
}
if id == "grok" {
registryID = "xai"
}
if pInfo, ok := s.registry.Providers[registryID]; ok {
for mID := range pInfo.Models {
models = append(models, mID)
@@ -969,7 +983,7 @@ func (s *Server) handleTestProvider(c *gin.Context) {
}
startTime := time.Now()
// Prepare a simple test request
testReq := &models.UnifiedRequest{
Model: "gpt-4o-mini", // Default cheap test model
@@ -1023,7 +1037,7 @@ func (s *Server) handleGetModels(c *gin.Context) {
// Merge registry models with DB overrides
var dbModels []db.ModelConfig
_ = s.database.Select(&dbModels, "SELECT * FROM model_configs")
dbMap := make(map[string]db.ModelConfig)
for _, m := range dbModels {
dbMap[m.ID] = m
@@ -1305,7 +1319,7 @@ func (s *Server) handleUpdateUser(c *gin.Context) {
func (s *Server) handleDeleteUser(c *gin.Context) {
id := c.Param("id")
session, _ := c.Get("session")
if sess, ok := session.(*Session); ok {
var username string
+33 -26
View File
@@ -6,6 +6,7 @@ import (
"io"
"net/http"
"strings"
"sync"
"time"
"gophergate/internal/config"
@@ -19,14 +20,15 @@ import (
)
type Server struct {
router *gin.Engine
cfg *config.Config
database *db.DB
providers map[string]providers.Provider
sessions *SessionManager
hub *Hub
logger *RequestLogger
registry *models.ModelRegistry
router *gin.Engine
cfg *config.Config
database *db.DB
providers map[string]providers.Provider
sessions *SessionManager
hub *Hub
logger *RequestLogger
registry *models.ModelRegistry
registryMu sync.RWMutex
}
func NewServer(cfg *config.Config, database *db.DB) *Server {
@@ -44,6 +46,7 @@ func NewServer(cfg *config.Config, database *db.DB) *Server {
registry: &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)},
}
s.sessions.StartCleanup()
// Fetch registry in background
go func() {
registry, err := utils.FetchRegistry()
@@ -180,7 +183,7 @@ func (s *Server) setupRoutes() {
// API V1 (External LLM Access) - Secured with AuthMiddleware
v1 := s.router.Group("/v1")
v1.Use(middleware.AuthMiddleware(s.database))
v1.Use(middleware.AuthMiddleware(s.database, true))
{
v1.POST("/chat/completions", s.handleChatCompletions)
v1.GET("/models", s.handleListModels)
@@ -194,7 +197,7 @@ func (s *Server) setupRoutes() {
api.GET("/auth/status", s.handleAuthStatus)
api.POST("/auth/logout", s.handleLogout)
api.POST("/auth/change-password", s.handleChangePassword)
// Protected dashboard routes (need admin session)
admin := api.Group("/")
admin.Use(s.adminAuthMiddleware())
@@ -205,13 +208,13 @@ func (s *Server) setupRoutes() {
admin.GET("/usage/clients", s.handleClientsUsage)
admin.GET("/usage/detailed", s.handleDetailedUsage)
admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)
admin.GET("/clients", s.handleGetClients)
admin.POST("/clients", s.handleCreateClient)
admin.GET("/clients/:id", s.handleGetClient)
admin.PUT("/clients/:id", s.handleUpdateClient)
admin.DELETE("/clients/:id", s.handleDeleteClient)
admin.GET("/clients/:id/tokens", s.handleGetClientTokens)
admin.POST("/clients/:id/tokens", s.handleCreateClientToken)
admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)
@@ -219,7 +222,7 @@ func (s *Server) setupRoutes() {
admin.GET("/providers", s.handleGetProviders)
admin.PUT("/providers/:name", s.handleUpdateProvider)
admin.POST("/providers/:name/test", s.handleTestProvider)
admin.GET("/models", s.handleGetModels)
admin.PUT("/models/:id", s.handleUpdateModel)
@@ -267,6 +270,7 @@ func (s *Server) handleListModels(c *gin.Context) {
"ollama": true,
}
s.registryMu.RLock()
if s.registry != nil {
for pID, pInfo := range s.registry.Providers {
if !allowedProviders[pID] {
@@ -284,6 +288,7 @@ func (s *Server) handleListModels(c *gin.Context) {
}
}
}
s.registryMu.RUnlock()
// Add configured Ollama models
if s.cfg.Providers.Ollama.Enabled {
@@ -330,15 +335,15 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
providerName = "moonshot"
} else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") {
providerName = "grok"
} else if strings.HasPrefix(modelLower, "ollama/") ||
strings.Contains(modelLower, "glm-") ||
strings.Contains(modelLower, "qwen") ||
strings.Contains(modelLower, "gemma") ||
strings.Contains(modelLower, "llama") ||
strings.Contains(modelLower, "mistral") ||
strings.Contains(modelLower, "phi") ||
strings.Contains(modelLower, "yi") ||
strings.Contains(modelLower, "codellama") ||
} else if strings.HasPrefix(modelLower, "ollama/") ||
strings.Contains(modelLower, "glm-") ||
strings.Contains(modelLower, "qwen") ||
strings.Contains(modelLower, "gemma") ||
strings.Contains(modelLower, "llama") ||
strings.Contains(modelLower, "mistral") ||
strings.Contains(modelLower, "phi") ||
strings.Contains(modelLower, "yi") ||
strings.Contains(modelLower, "codellama") ||
strings.Contains(modelLower, "command-r") {
providerName = "ollama"
}
@@ -525,11 +530,11 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model string, u
if usage.CacheWriteTokens != nil {
entry.CacheWriteTokens = *usage.CacheWriteTokens
}
// Calculate cost using registry
s.registryMu.RLock()
entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
fmt.Printf("[DEBUG] Request logged: model=%s, prompt=%d, completion=%d, reasoning=%d, cache_read=%d, cost=%f\n",
model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.Cost)
s.registryMu.RUnlock()
}
s.logger.LogRequest(entry)
@@ -538,14 +543,16 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model string, u
func (s *Server) Run() error {
go s.hub.Run()
s.logger.Start()
// Start registry refresher
go func() {
ticker := time.NewTicker(24 * time.Hour)
for range ticker.C {
newRegistry, err := utils.FetchRegistry()
if err == nil {
s.registryMu.Lock()
s.registry = newRegistry
s.registryMu.Unlock()
}
}
}()
+23 -5
View File
@@ -79,7 +79,7 @@ func (m *SessionManager) createSignedToken(sessionID, username, displayName, rol
}
payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON)
h := hmac.New(sha256.New, m.secret)
h.Write(payloadJSON)
signature := h.Sum(nil)
@@ -133,23 +133,41 @@ func (m *SessionManager) ValidateSession(token string) (*Session, string, error)
return &session, "", nil
}
func (m *SessionManager) RevokeSession(token string) {
func (m *SessionManager) RevokeSession(token string) error {
parts := strings.Split(token, ".")
if len(parts) != 2 {
return
return fmt.Errorf("invalid token format")
}
payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return
return fmt.Errorf("failed to decode payload: %w", err)
}
var payload sessionPayload
if err := json.Unmarshal(payloadJSON, &payload); err != nil {
return
return fmt.Errorf("failed to parse payload: %w", err)
}
m.mu.Lock()
delete(m.sessions, payload.SessionID)
m.mu.Unlock()
return nil
}
// StartCleanup runs a background goroutine that removes expired sessions every 15 minutes.
func (m *SessionManager) StartCleanup() {
go func() {
ticker := time.NewTicker(15 * time.Minute)
for range ticker.C {
m.mu.Lock()
now := time.Now()
for id, s := range m.sessions {
if now.After(s.ExpiresAt) {
delete(m.sessions, id)
}
}
m.mu.Unlock()
}
}()
}
+18 -7
View File
@@ -10,12 +10,18 @@ import (
"github.com/gorilla/websocket"
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true // In production, refine this
},
func newUpgrader(allowedOrigin string) websocket.Upgrader {
return websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
if allowedOrigin == "*" {
return true
}
origin := r.Header.Get("Origin")
return origin == "" || origin == allowedOrigin
},
}
}
type Hub struct {
@@ -75,6 +81,11 @@ func (h *Hub) GetClientCount() int {
}
func (s *Server) handleWebSocket(c *gin.Context) {
allowedOrigin := s.cfg.Server.WSAllowedOrigin
if allowedOrigin == "" {
allowedOrigin = "*"
}
upgrader := newUpgrader(allowedOrigin)
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
log.Printf("Failed to set websocket upgrade: %v", err)
@@ -99,7 +110,7 @@ func (s *Server) handleWebSocket(c *gin.Context) {
if err != nil {
break
}
if msg["type"] == "ping" {
conn.WriteJSON(gin.H{"type": "pong", "payload": gin.H{}})
}