fix: Phase 1 - security & stability patches
- AuthMiddleware now requires auth on /v1/* routes (returns 401) - WebSocket origin check configurable via WSAllowedOrigin - Removed debug fmt.Printf leaks (config, ollama, server) - Registry access protected by sync.RWMutex (race condition fix) - Session cleanup goroutine runs every 15 min - RevokeSession returns error instead of silent no-op
This commit is contained in:
+33
-26
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gophergate/internal/config"
|
||||
@@ -19,14 +20,15 @@ import (
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
router *gin.Engine
|
||||
cfg *config.Config
|
||||
database *db.DB
|
||||
providers map[string]providers.Provider
|
||||
sessions *SessionManager
|
||||
hub *Hub
|
||||
logger *RequestLogger
|
||||
registry *models.ModelRegistry
|
||||
router *gin.Engine
|
||||
cfg *config.Config
|
||||
database *db.DB
|
||||
providers map[string]providers.Provider
|
||||
sessions *SessionManager
|
||||
hub *Hub
|
||||
logger *RequestLogger
|
||||
registry *models.ModelRegistry
|
||||
registryMu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewServer(cfg *config.Config, database *db.DB) *Server {
|
||||
@@ -44,6 +46,7 @@ func NewServer(cfg *config.Config, database *db.DB) *Server {
|
||||
registry: &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)},
|
||||
}
|
||||
|
||||
s.sessions.StartCleanup()
|
||||
// Fetch registry in background
|
||||
go func() {
|
||||
registry, err := utils.FetchRegistry()
|
||||
@@ -180,7 +183,7 @@ func (s *Server) setupRoutes() {
|
||||
|
||||
// API V1 (External LLM Access) - Secured with AuthMiddleware
|
||||
v1 := s.router.Group("/v1")
|
||||
v1.Use(middleware.AuthMiddleware(s.database))
|
||||
v1.Use(middleware.AuthMiddleware(s.database, true))
|
||||
{
|
||||
v1.POST("/chat/completions", s.handleChatCompletions)
|
||||
v1.GET("/models", s.handleListModels)
|
||||
@@ -194,7 +197,7 @@ func (s *Server) setupRoutes() {
|
||||
api.GET("/auth/status", s.handleAuthStatus)
|
||||
api.POST("/auth/logout", s.handleLogout)
|
||||
api.POST("/auth/change-password", s.handleChangePassword)
|
||||
|
||||
|
||||
// Protected dashboard routes (need admin session)
|
||||
admin := api.Group("/")
|
||||
admin.Use(s.adminAuthMiddleware())
|
||||
@@ -205,13 +208,13 @@ func (s *Server) setupRoutes() {
|
||||
admin.GET("/usage/clients", s.handleClientsUsage)
|
||||
admin.GET("/usage/detailed", s.handleDetailedUsage)
|
||||
admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)
|
||||
|
||||
|
||||
admin.GET("/clients", s.handleGetClients)
|
||||
admin.POST("/clients", s.handleCreateClient)
|
||||
admin.GET("/clients/:id", s.handleGetClient)
|
||||
admin.PUT("/clients/:id", s.handleUpdateClient)
|
||||
admin.DELETE("/clients/:id", s.handleDeleteClient)
|
||||
|
||||
|
||||
admin.GET("/clients/:id/tokens", s.handleGetClientTokens)
|
||||
admin.POST("/clients/:id/tokens", s.handleCreateClientToken)
|
||||
admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)
|
||||
@@ -219,7 +222,7 @@ func (s *Server) setupRoutes() {
|
||||
admin.GET("/providers", s.handleGetProviders)
|
||||
admin.PUT("/providers/:name", s.handleUpdateProvider)
|
||||
admin.POST("/providers/:name/test", s.handleTestProvider)
|
||||
|
||||
|
||||
admin.GET("/models", s.handleGetModels)
|
||||
admin.PUT("/models/:id", s.handleUpdateModel)
|
||||
|
||||
@@ -267,6 +270,7 @@ func (s *Server) handleListModels(c *gin.Context) {
|
||||
"ollama": true,
|
||||
}
|
||||
|
||||
s.registryMu.RLock()
|
||||
if s.registry != nil {
|
||||
for pID, pInfo := range s.registry.Providers {
|
||||
if !allowedProviders[pID] {
|
||||
@@ -284,6 +288,7 @@ func (s *Server) handleListModels(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
s.registryMu.RUnlock()
|
||||
|
||||
// Add configured Ollama models
|
||||
if s.cfg.Providers.Ollama.Enabled {
|
||||
@@ -330,15 +335,15 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
|
||||
providerName = "moonshot"
|
||||
} else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") {
|
||||
providerName = "grok"
|
||||
} else if strings.HasPrefix(modelLower, "ollama/") ||
|
||||
strings.Contains(modelLower, "glm-") ||
|
||||
strings.Contains(modelLower, "qwen") ||
|
||||
strings.Contains(modelLower, "gemma") ||
|
||||
strings.Contains(modelLower, "llama") ||
|
||||
strings.Contains(modelLower, "mistral") ||
|
||||
strings.Contains(modelLower, "phi") ||
|
||||
strings.Contains(modelLower, "yi") ||
|
||||
strings.Contains(modelLower, "codellama") ||
|
||||
} else if strings.HasPrefix(modelLower, "ollama/") ||
|
||||
strings.Contains(modelLower, "glm-") ||
|
||||
strings.Contains(modelLower, "qwen") ||
|
||||
strings.Contains(modelLower, "gemma") ||
|
||||
strings.Contains(modelLower, "llama") ||
|
||||
strings.Contains(modelLower, "mistral") ||
|
||||
strings.Contains(modelLower, "phi") ||
|
||||
strings.Contains(modelLower, "yi") ||
|
||||
strings.Contains(modelLower, "codellama") ||
|
||||
strings.Contains(modelLower, "command-r") {
|
||||
providerName = "ollama"
|
||||
}
|
||||
@@ -525,11 +530,11 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model string, u
|
||||
if usage.CacheWriteTokens != nil {
|
||||
entry.CacheWriteTokens = *usage.CacheWriteTokens
|
||||
}
|
||||
|
||||
|
||||
// Calculate cost using registry
|
||||
s.registryMu.RLock()
|
||||
entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
|
||||
fmt.Printf("[DEBUG] Request logged: model=%s, prompt=%d, completion=%d, reasoning=%d, cache_read=%d, cost=%f\n",
|
||||
model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.Cost)
|
||||
s.registryMu.RUnlock()
|
||||
}
|
||||
|
||||
s.logger.LogRequest(entry)
|
||||
@@ -538,14 +543,16 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model string, u
|
||||
func (s *Server) Run() error {
|
||||
go s.hub.Run()
|
||||
s.logger.Start()
|
||||
|
||||
|
||||
// Start registry refresher
|
||||
go func() {
|
||||
ticker := time.NewTicker(24 * time.Hour)
|
||||
for range ticker.C {
|
||||
newRegistry, err := utils.FetchRegistry()
|
||||
if err == nil {
|
||||
s.registryMu.Lock()
|
||||
s.registry = newRegistry
|
||||
s.registryMu.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
Reference in New Issue
Block a user