fix: Phase 1 - security & stability patches
- AuthMiddleware now requires auth on /v1/* routes (returns 401) - WebSocket origin check configurable via WSAllowedOrigin - Removed debug fmt.Printf leaks (config, ollama, server) - Registry access protected by sync.RWMutex (race condition fix) - Session cleanup goroutine runs every 15 min - RevokeSession returns error instead of silent no-op
This commit is contained in:
@@ -8,17 +8,17 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gophergate/internal/db"
|
||||
"gophergate/internal/models"
|
||||
"gophergate/internal/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gophergate/internal/db"
|
||||
"gophergate/internal/models"
|
||||
"gophergate/internal/utils"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/cpu"
|
||||
"github.com/shirou/gopsutil/v3/mem"
|
||||
"github.com/shirou/gopsutil/v3/disk"
|
||||
"github.com/shirou/gopsutil/v3/load"
|
||||
"github.com/shirou/gopsutil/v3/mem"
|
||||
"github.com/shirou/gopsutil/v3/process"
|
||||
)
|
||||
|
||||
@@ -168,7 +168,9 @@ func (s *Server) handleChangePassword(c *gin.Context) {
|
||||
|
||||
func (s *Server) handleLogout(c *gin.Context) {
|
||||
token := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
|
||||
s.sessions.RevokeSession(token)
|
||||
if err := s.sessions.RevokeSession(token); err != nil {
|
||||
fmt.Printf("Error revoking session: %v\n", err)
|
||||
}
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Logged out"}))
|
||||
}
|
||||
|
||||
@@ -226,7 +228,7 @@ func (s *Server) handleUsageSummary(c *gin.Context) {
|
||||
}
|
||||
|
||||
clause, binds := filter.ToSQL()
|
||||
|
||||
|
||||
// Total stats
|
||||
var totalStats struct {
|
||||
TotalRequests int `db:"total_requests"`
|
||||
@@ -307,7 +309,7 @@ func (s *Server) handleTimeSeries(c *gin.Context) {
|
||||
}
|
||||
|
||||
clause, binds := filter.ToSQL()
|
||||
|
||||
|
||||
if clause == "" {
|
||||
cutoff := time.Now().UTC().Add(-30 * 24 * time.Hour)
|
||||
clause = " AND timestamp >= ?"
|
||||
@@ -444,7 +446,10 @@ func (s *Server) handleAnalyticsBreakdown(c *gin.Context) {
|
||||
var label string
|
||||
var value int
|
||||
if err := mRows.Scan(&label, &value); err == nil {
|
||||
models = append(models, struct{Label string `json:"label"`; Value int `json:"value"`}{label, value})
|
||||
models = append(models, struct {
|
||||
Label string `json:"label"`
|
||||
Value int `json:"value"`
|
||||
}{label, value})
|
||||
}
|
||||
}
|
||||
mRows.Close()
|
||||
@@ -461,7 +466,10 @@ func (s *Server) handleAnalyticsBreakdown(c *gin.Context) {
|
||||
var label string
|
||||
var value int
|
||||
if err := cRows.Scan(&label, &value); err == nil {
|
||||
clients = append(clients, struct{Label string `json:"label"`; Value int `json:"value"`}{label, value})
|
||||
clients = append(clients, struct {
|
||||
Label string `json:"label"`
|
||||
Value int `json:"value"`
|
||||
}{label, value})
|
||||
}
|
||||
}
|
||||
cRows.Close()
|
||||
@@ -537,15 +545,15 @@ func (s *Server) handleGetClients(c *gin.Context) {
|
||||
}
|
||||
|
||||
type UIClient struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
LastUsed *time.Time `json:"last_used"`
|
||||
RequestsCount int `json:"requests_count"`
|
||||
TokensCount int `json:"tokens_count"`
|
||||
Status string `json:"status"`
|
||||
RateLimitPerMinute int `json:"rate_limit_per_minute"`
|
||||
RequestsCount int `json:"requests_count"`
|
||||
TokensCount int `json:"tokens_count"`
|
||||
Status string `json:"status"`
|
||||
RateLimitPerMinute int `json:"rate_limit_per_minute"`
|
||||
}
|
||||
|
||||
uiClients := make([]UIClient, len(clients))
|
||||
@@ -608,12 +616,12 @@ func (s *Server) handleGetClient(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
||||
"id": cl.ClientID,
|
||||
"name": name,
|
||||
"description": desc,
|
||||
"is_active": cl.IsActive,
|
||||
"rate_limit_per_minute": cl.RateLimitPerMinute,
|
||||
"created_at": cl.CreatedAt,
|
||||
"id": cl.ClientID,
|
||||
"name": name,
|
||||
"description": desc,
|
||||
"is_active": cl.IsActive,
|
||||
"rate_limit_per_minute": cl.RateLimitPerMinute,
|
||||
"created_at": cl.CreatedAt,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -873,10 +881,16 @@ func (s *Server) handleGetProviders(c *gin.Context) {
|
||||
var models []string
|
||||
if s.registry != nil {
|
||||
registryID := id
|
||||
if id == "gemini" { registryID = "google" }
|
||||
if id == "moonshot" { registryID = "moonshot" }
|
||||
if id == "grok" { registryID = "xai" }
|
||||
|
||||
if id == "gemini" {
|
||||
registryID = "google"
|
||||
}
|
||||
if id == "moonshot" {
|
||||
registryID = "moonshot"
|
||||
}
|
||||
if id == "grok" {
|
||||
registryID = "xai"
|
||||
}
|
||||
|
||||
if pInfo, ok := s.registry.Providers[registryID]; ok {
|
||||
for mID := range pInfo.Models {
|
||||
models = append(models, mID)
|
||||
@@ -969,7 +983,7 @@ func (s *Server) handleTestProvider(c *gin.Context) {
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
|
||||
// Prepare a simple test request
|
||||
testReq := &models.UnifiedRequest{
|
||||
Model: "gpt-4o-mini", // Default cheap test model
|
||||
@@ -1023,7 +1037,7 @@ func (s *Server) handleGetModels(c *gin.Context) {
|
||||
// Merge registry models with DB overrides
|
||||
var dbModels []db.ModelConfig
|
||||
_ = s.database.Select(&dbModels, "SELECT * FROM model_configs")
|
||||
|
||||
|
||||
dbMap := make(map[string]db.ModelConfig)
|
||||
for _, m := range dbModels {
|
||||
dbMap[m.ID] = m
|
||||
@@ -1305,7 +1319,7 @@ func (s *Server) handleUpdateUser(c *gin.Context) {
|
||||
|
||||
func (s *Server) handleDeleteUser(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
|
||||
session, _ := c.Get("session")
|
||||
if sess, ok := session.(*Session); ok {
|
||||
var username string
|
||||
|
||||
+33
-26
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gophergate/internal/config"
|
||||
@@ -19,14 +20,15 @@ import (
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
router *gin.Engine
|
||||
cfg *config.Config
|
||||
database *db.DB
|
||||
providers map[string]providers.Provider
|
||||
sessions *SessionManager
|
||||
hub *Hub
|
||||
logger *RequestLogger
|
||||
registry *models.ModelRegistry
|
||||
router *gin.Engine
|
||||
cfg *config.Config
|
||||
database *db.DB
|
||||
providers map[string]providers.Provider
|
||||
sessions *SessionManager
|
||||
hub *Hub
|
||||
logger *RequestLogger
|
||||
registry *models.ModelRegistry
|
||||
registryMu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewServer(cfg *config.Config, database *db.DB) *Server {
|
||||
@@ -44,6 +46,7 @@ func NewServer(cfg *config.Config, database *db.DB) *Server {
|
||||
registry: &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)},
|
||||
}
|
||||
|
||||
s.sessions.StartCleanup()
|
||||
// Fetch registry in background
|
||||
go func() {
|
||||
registry, err := utils.FetchRegistry()
|
||||
@@ -180,7 +183,7 @@ func (s *Server) setupRoutes() {
|
||||
|
||||
// API V1 (External LLM Access) - Secured with AuthMiddleware
|
||||
v1 := s.router.Group("/v1")
|
||||
v1.Use(middleware.AuthMiddleware(s.database))
|
||||
v1.Use(middleware.AuthMiddleware(s.database, true))
|
||||
{
|
||||
v1.POST("/chat/completions", s.handleChatCompletions)
|
||||
v1.GET("/models", s.handleListModels)
|
||||
@@ -194,7 +197,7 @@ func (s *Server) setupRoutes() {
|
||||
api.GET("/auth/status", s.handleAuthStatus)
|
||||
api.POST("/auth/logout", s.handleLogout)
|
||||
api.POST("/auth/change-password", s.handleChangePassword)
|
||||
|
||||
|
||||
// Protected dashboard routes (need admin session)
|
||||
admin := api.Group("/")
|
||||
admin.Use(s.adminAuthMiddleware())
|
||||
@@ -205,13 +208,13 @@ func (s *Server) setupRoutes() {
|
||||
admin.GET("/usage/clients", s.handleClientsUsage)
|
||||
admin.GET("/usage/detailed", s.handleDetailedUsage)
|
||||
admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)
|
||||
|
||||
|
||||
admin.GET("/clients", s.handleGetClients)
|
||||
admin.POST("/clients", s.handleCreateClient)
|
||||
admin.GET("/clients/:id", s.handleGetClient)
|
||||
admin.PUT("/clients/:id", s.handleUpdateClient)
|
||||
admin.DELETE("/clients/:id", s.handleDeleteClient)
|
||||
|
||||
|
||||
admin.GET("/clients/:id/tokens", s.handleGetClientTokens)
|
||||
admin.POST("/clients/:id/tokens", s.handleCreateClientToken)
|
||||
admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)
|
||||
@@ -219,7 +222,7 @@ func (s *Server) setupRoutes() {
|
||||
admin.GET("/providers", s.handleGetProviders)
|
||||
admin.PUT("/providers/:name", s.handleUpdateProvider)
|
||||
admin.POST("/providers/:name/test", s.handleTestProvider)
|
||||
|
||||
|
||||
admin.GET("/models", s.handleGetModels)
|
||||
admin.PUT("/models/:id", s.handleUpdateModel)
|
||||
|
||||
@@ -267,6 +270,7 @@ func (s *Server) handleListModels(c *gin.Context) {
|
||||
"ollama": true,
|
||||
}
|
||||
|
||||
s.registryMu.RLock()
|
||||
if s.registry != nil {
|
||||
for pID, pInfo := range s.registry.Providers {
|
||||
if !allowedProviders[pID] {
|
||||
@@ -284,6 +288,7 @@ func (s *Server) handleListModels(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
s.registryMu.RUnlock()
|
||||
|
||||
// Add configured Ollama models
|
||||
if s.cfg.Providers.Ollama.Enabled {
|
||||
@@ -330,15 +335,15 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
|
||||
providerName = "moonshot"
|
||||
} else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") {
|
||||
providerName = "grok"
|
||||
} else if strings.HasPrefix(modelLower, "ollama/") ||
|
||||
strings.Contains(modelLower, "glm-") ||
|
||||
strings.Contains(modelLower, "qwen") ||
|
||||
strings.Contains(modelLower, "gemma") ||
|
||||
strings.Contains(modelLower, "llama") ||
|
||||
strings.Contains(modelLower, "mistral") ||
|
||||
strings.Contains(modelLower, "phi") ||
|
||||
strings.Contains(modelLower, "yi") ||
|
||||
strings.Contains(modelLower, "codellama") ||
|
||||
} else if strings.HasPrefix(modelLower, "ollama/") ||
|
||||
strings.Contains(modelLower, "glm-") ||
|
||||
strings.Contains(modelLower, "qwen") ||
|
||||
strings.Contains(modelLower, "gemma") ||
|
||||
strings.Contains(modelLower, "llama") ||
|
||||
strings.Contains(modelLower, "mistral") ||
|
||||
strings.Contains(modelLower, "phi") ||
|
||||
strings.Contains(modelLower, "yi") ||
|
||||
strings.Contains(modelLower, "codellama") ||
|
||||
strings.Contains(modelLower, "command-r") {
|
||||
providerName = "ollama"
|
||||
}
|
||||
@@ -525,11 +530,11 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model string, u
|
||||
if usage.CacheWriteTokens != nil {
|
||||
entry.CacheWriteTokens = *usage.CacheWriteTokens
|
||||
}
|
||||
|
||||
|
||||
// Calculate cost using registry
|
||||
s.registryMu.RLock()
|
||||
entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
|
||||
fmt.Printf("[DEBUG] Request logged: model=%s, prompt=%d, completion=%d, reasoning=%d, cache_read=%d, cost=%f\n",
|
||||
model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.Cost)
|
||||
s.registryMu.RUnlock()
|
||||
}
|
||||
|
||||
s.logger.LogRequest(entry)
|
||||
@@ -538,14 +543,16 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model string, u
|
||||
func (s *Server) Run() error {
|
||||
go s.hub.Run()
|
||||
s.logger.Start()
|
||||
|
||||
|
||||
// Start registry refresher
|
||||
go func() {
|
||||
ticker := time.NewTicker(24 * time.Hour)
|
||||
for range ticker.C {
|
||||
newRegistry, err := utils.FetchRegistry()
|
||||
if err == nil {
|
||||
s.registryMu.Lock()
|
||||
s.registry = newRegistry
|
||||
s.registryMu.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -79,7 +79,7 @@ func (m *SessionManager) createSignedToken(sessionID, username, displayName, rol
|
||||
}
|
||||
|
||||
payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON)
|
||||
|
||||
|
||||
h := hmac.New(sha256.New, m.secret)
|
||||
h.Write(payloadJSON)
|
||||
signature := h.Sum(nil)
|
||||
@@ -133,23 +133,41 @@ func (m *SessionManager) ValidateSession(token string) (*Session, string, error)
|
||||
return &session, "", nil
|
||||
}
|
||||
|
||||
func (m *SessionManager) RevokeSession(token string) {
|
||||
func (m *SessionManager) RevokeSession(token string) error {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 2 {
|
||||
return
|
||||
return fmt.Errorf("invalid token format")
|
||||
}
|
||||
|
||||
payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return
|
||||
return fmt.Errorf("failed to decode payload: %w", err)
|
||||
}
|
||||
|
||||
var payload sessionPayload
|
||||
if err := json.Unmarshal(payloadJSON, &payload); err != nil {
|
||||
return
|
||||
return fmt.Errorf("failed to parse payload: %w", err)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
delete(m.sessions, payload.SessionID)
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartCleanup runs a background goroutine that removes expired sessions every 15 minutes.
|
||||
func (m *SessionManager) StartCleanup() {
|
||||
go func() {
|
||||
ticker := time.NewTicker(15 * time.Minute)
|
||||
for range ticker.C {
|
||||
m.mu.Lock()
|
||||
now := time.Now()
|
||||
for id, s := range m.sessions {
|
||||
if now.After(s.ExpiresAt) {
|
||||
delete(m.sessions, id)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -10,12 +10,18 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true // In production, refine this
|
||||
},
|
||||
func newUpgrader(allowedOrigin string) websocket.Upgrader {
|
||||
return websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
if allowedOrigin == "*" {
|
||||
return true
|
||||
}
|
||||
origin := r.Header.Get("Origin")
|
||||
return origin == "" || origin == allowedOrigin
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type Hub struct {
|
||||
@@ -75,6 +81,11 @@ func (h *Hub) GetClientCount() int {
|
||||
}
|
||||
|
||||
func (s *Server) handleWebSocket(c *gin.Context) {
|
||||
allowedOrigin := s.cfg.Server.WSAllowedOrigin
|
||||
if allowedOrigin == "" {
|
||||
allowedOrigin = "*"
|
||||
}
|
||||
upgrader := newUpgrader(allowedOrigin)
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Printf("Failed to set websocket upgrade: %v", err)
|
||||
@@ -99,7 +110,7 @@ func (s *Server) handleWebSocket(c *gin.Context) {
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
|
||||
if msg["type"] == "ping" {
|
||||
conn.WriteJSON(gin.H{"type": "pong", "payload": gin.H{}})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user