Files
GopherGate/internal/server/server.go
hobokenchicken 3f76a544e0
Some checks failed
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
fix: improve analytics accuracy and cost calculation
Refined CalculateCost to correctly handle cached token discounts. Added fuzzy matching to model lookup. Made SQL date extraction more robust by using SUBSTR and LIKE for better SQLite compatibility.
2026-03-19 12:58:08 -04:00

404 lines
11 KiB
Go

package server
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"llm-proxy/internal/config"
"llm-proxy/internal/db"
"llm-proxy/internal/middleware"
"llm-proxy/internal/models"
"llm-proxy/internal/providers"
"llm-proxy/internal/utils"
"github.com/gin-gonic/gin"
)
// Server is the top-level application object: it owns the HTTP router, the
// configuration, persistent storage, the set of enabled upstream LLM
// providers, dashboard sessions, the websocket hub, the asynchronous request
// logger, and the cached model/pricing registry.
type Server struct {
	router    *gin.Engine                   // serves static UI assets plus the /v1 and /api route groups
	cfg       *config.Config                // loaded application configuration
	database  *db.DB                        // persistent storage (clients, tokens, usage logs)
	providers map[string]providers.Provider // enabled upstream providers, keyed by name ("openai", "gemini", ...)
	sessions  *SessionManager               // admin dashboard sessions (24h lifetime; see NewServer)
	hub       *Hub                          // websocket hub broadcasting live updates to dashboard clients
	logger    *RequestLogger                // asynchronous request/usage logger backed by the database
	registry  *models.ModelRegistry         // model metadata and pricing; refreshed daily by Run
}
// NewServer assembles a fully-wired Server: it creates the gin engine and
// websocket hub, fetches the initial model registry (falling back to an empty
// one on failure so startup is never blocked), instantiates every enabled
// upstream provider, and registers all routes.
func NewServer(cfg *config.Config, database *db.DB) *Server {
	engine := gin.Default()
	eventHub := NewHub()

	// Fetch registry (non-blocking for startup if it fails, but we'll try once)
	reg, err := utils.FetchRegistry()
	if err != nil {
		fmt.Printf("Warning: Failed to fetch initial model registry: %v\n", err)
		reg = &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)}
	}

	srv := &Server{
		router:    engine,
		cfg:       cfg,
		database:  database,
		providers: make(map[string]providers.Provider),
		sessions:  NewSessionManager(cfg.KeyBytes, 24*time.Hour),
		hub:       eventHub,
		logger:    NewRequestLogger(database, eventHub),
		registry:  reg,
	}

	// register wires one upstream provider into the server when it is enabled
	// in the config. API-key lookup errors are deliberately ignored, matching
	// the existing behavior: a provider may run with an empty key.
	register := func(name string, enabled bool, build func(apiKey string) providers.Provider) {
		if !enabled {
			return
		}
		apiKey, _ := cfg.GetAPIKey(name)
		srv.providers[name] = build(apiKey)
	}

	register("openai", cfg.Providers.OpenAI.Enabled, func(k string) providers.Provider {
		return providers.NewOpenAIProvider(cfg.Providers.OpenAI, k)
	})
	register("gemini", cfg.Providers.Gemini.Enabled, func(k string) providers.Provider {
		return providers.NewGeminiProvider(cfg.Providers.Gemini, k)
	})
	register("deepseek", cfg.Providers.DeepSeek.Enabled, func(k string) providers.Provider {
		return providers.NewDeepSeekProvider(cfg.Providers.DeepSeek, k)
	})
	register("grok", cfg.Providers.Grok.Enabled, func(k string) providers.Provider {
		return providers.NewGrokProvider(cfg.Providers.Grok, k)
	})

	srv.setupRoutes()
	return srv
}
// setupRoutes registers every HTTP route: static dashboard assets, the
// websocket endpoint, the token-authenticated /v1 OpenAI-compatible API,
// the /api dashboard API (login endpoints open, the rest behind the admin
// session middleware), and an unauthenticated health probe.
func (s *Server) setupRoutes() {
	// Static files
	s.router.StaticFile("/", "./static/index.html")
	s.router.StaticFile("/favicon.ico", "./static/favicon.ico")
	s.router.Static("/css", "./static/css")
	s.router.Static("/js", "./static/js")
	s.router.Static("/img", "./static/img")
	// WebSocket
	s.router.GET("/ws", s.handleWebSocket)
	// API V1 (External LLM Access) - Secured with AuthMiddleware
	v1 := s.router.Group("/v1")
	v1.Use(middleware.AuthMiddleware(s.database))
	{
		v1.POST("/chat/completions", s.handleChatCompletions)
		v1.GET("/models", s.handleListModels)
	}
	// Dashboard API Group
	api := s.router.Group("/api")
	{
		// Auth endpoints are intentionally outside the admin middleware so
		// that login/status checks work before a session exists.
		api.POST("/auth/login", s.handleLogin)
		api.GET("/auth/status", s.handleAuthStatus)
		api.POST("/auth/logout", s.handleLogout)
		api.POST("/auth/change-password", s.handleChangePassword)
		// Protected dashboard routes (need admin session)
		admin := api.Group("/")
		admin.Use(s.adminAuthMiddleware())
		{
			// Usage and analytics
			admin.GET("/usage/summary", s.handleUsageSummary)
			admin.GET("/usage/time-series", s.handleTimeSeries)
			admin.GET("/usage/providers", s.handleProvidersUsage)
			admin.GET("/usage/clients", s.handleClientsUsage)
			admin.GET("/usage/detailed", s.handleDetailedUsage)
			admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)
			// Client and API-token management
			admin.GET("/clients", s.handleGetClients)
			admin.POST("/clients", s.handleCreateClient)
			admin.GET("/clients/:id", s.handleGetClient)
			admin.PUT("/clients/:id", s.handleUpdateClient)
			admin.DELETE("/clients/:id", s.handleDeleteClient)
			admin.GET("/clients/:id/tokens", s.handleGetClientTokens)
			admin.POST("/clients/:id/tokens", s.handleCreateClientToken)
			admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)
			// Provider and model configuration
			admin.GET("/providers", s.handleGetProviders)
			admin.PUT("/providers/:name", s.handleUpdateProvider)
			admin.POST("/providers/:name/test", s.handleTestProvider)
			admin.GET("/models", s.handleGetModels)
			admin.PUT("/models/:id", s.handleUpdateModel)
			// Dashboard user management
			admin.GET("/users", s.handleGetUsers)
			admin.POST("/users", s.handleCreateUser)
			admin.PUT("/users/:id", s.handleUpdateUser)
			admin.DELETE("/users/:id", s.handleDeleteUser)
			// System maintenance
			admin.GET("/system/health", s.handleSystemHealth)
			admin.GET("/system/settings", s.handleGetSettings)
			admin.POST("/system/backup", s.handleCreateBackup)
			admin.GET("/system/logs", s.handleGetLogs)
		}
	}
	// Unauthenticated liveness probe.
	s.router.GET("/health", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"status": "ok"})
	})
}
// handleListModels serves GET /v1/models in the OpenAI list format, exposing
// every model from the cached registry whose provider this proxy can route to.
func (s *Server) handleListModels(c *gin.Context) {
	// OpenAIModel mirrors one entry of the OpenAI GET /v1/models response.
	type OpenAIModel struct {
		ID      string `json:"id"`
		Object  string `json:"object"`
		Created int64  `json:"created"`
		OwnedBy string `json:"owned_by"`
	}
	// Use a non-nil empty slice so an empty/unavailable registry serializes
	// as "data": [] rather than "data": null — encoding/json renders a nil
	// slice as null, which breaks OpenAI clients that expect an array.
	data := []OpenAIModel{}
	// Registry provider IDs that map onto providers this proxy supports.
	allowedProviders := map[string]bool{
		"openai":   true,
		"google":   true, // Models from models.dev use 'google' ID for Gemini
		"deepseek": true,
		"xai":      true, // Models from models.dev use 'xai' ID for Grok
	}
	if s.registry != nil {
		for pID, pInfo := range s.registry.Providers {
			if !allowedProviders[pID] {
				continue
			}
			for mID := range pInfo.Models {
				data = append(data, OpenAIModel{
					ID:     mID,
					Object: "model",
					// Registry carries no creation date; fixed placeholder epoch.
					Created: 1700000000,
					OwnedBy: pID,
				})
			}
		}
	}
	// NOTE(review): map iteration makes the listing order random on every
	// request; sorting by ID would give clients a stable ordering.
	c.JSON(http.StatusOK, gin.H{
		"object": "list",
		"data":   data,
	})
}
// handleChatCompletions serves POST /v1/chat/completions with an
// OpenAI-compatible contract: it binds the request body, picks an upstream
// provider from the model name, translates the payload into the internal
// UnifiedRequest form (including multimodal image parts), then either streams
// SSE chunks back to the client or returns a single JSON response. Every
// outcome after binding is recorded through logRequest.
func (s *Server) handleChatCompletions(c *gin.Context) {
	startTime := time.Now()
	var req models.ChatCompletionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	// Select provider based on model name: substring routing, first match
	// wins, and anything unrecognized falls back to OpenAI.
	providerName := "openai" // default
	if strings.Contains(req.Model, "gemini") {
		providerName = "gemini"
	} else if strings.Contains(req.Model, "deepseek") {
		providerName = "deepseek"
	} else if strings.Contains(req.Model, "grok") {
		providerName = "grok"
	}
	provider, ok := s.providers[providerName]
	if !ok {
		// NOTE(review): a disabled/unknown provider is arguably a client
		// error (400/404) rather than a 500 — confirm the intended contract.
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
		return
	}
	// Convert ChatCompletionRequest to UnifiedRequest
	unifiedReq := &models.UnifiedRequest{
		Model:            req.Model,
		Messages:         []models.UnifiedMessage{},
		Temperature:      req.Temperature,
		TopP:             req.TopP,
		TopK:             req.TopK,
		N:                req.N,
		MaxTokens:        req.MaxTokens,
		PresencePenalty:  req.PresencePenalty,
		FrequencyPenalty: req.FrequencyPenalty,
		Stream:           req.Stream != nil && *req.Stream,
		Tools:            req.Tools,
		ToolChoice:       req.ToolChoice,
	}
	// Handle Stop sequences: the OpenAI field may be a JSON array of strings
	// or a single string, so the raw bytes are tried as an array first and
	// then as a single string. Any other shape is silently dropped.
	if req.Stop != nil {
		var stop []string
		if err := json.Unmarshal(req.Stop, &stop); err == nil {
			unifiedReq.Stop = stop
		} else {
			var singleStop string
			if err := json.Unmarshal(req.Stop, &singleStop); err == nil {
				unifiedReq.Stop = []string{singleStop}
			}
		}
	}
	// Convert messages: each message's content may be a plain string or a
	// multimodal part list ([]interface{} after JSON decoding).
	for _, msg := range req.Messages {
		unifiedMsg := models.UnifiedMessage{
			Role:             msg.Role,
			Content:          []models.UnifiedContentPart{},
			ReasoningContent: msg.ReasoningContent,
			ToolCalls:        msg.ToolCalls,
			Name:             msg.Name,
			ToolCallID:       msg.ToolCallID,
		}
		// Handle multimodal content
		if strContent, ok := msg.Content.(string); ok {
			// Plain string content becomes a single text part.
			unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
				Type: "text",
				Text: strContent,
			})
		} else if parts, ok := msg.Content.([]interface{}); ok {
			for _, part := range parts {
				if partMap, ok := part.(map[string]interface{}); ok {
					partType, _ := partMap["type"].(string)
					if partType == "text" {
						text, _ := partMap["text"].(string)
						unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
							Type: "text",
							Text: text,
						})
					} else if partType == "image_url" {
						if imgURLMap, ok := partMap["image_url"].(map[string]interface{}); ok {
							url, _ := imgURLMap["url"].(string)
							imageInput := &models.ImageInput{}
							if strings.HasPrefix(url, "data:") {
								// Inline data URL: decode into base64 payload + mime
								// type. Parse failures leave the ImageInput empty.
								mime, data, err := utils.ParseDataURL(url)
								if err == nil {
									imageInput.Base64 = data
									imageInput.MimeType = mime
								}
							} else {
								// Remote image: pass the URL through untouched.
								imageInput.URL = url
							}
							unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
								Type:  "image",
								Image: imageInput,
							})
							unifiedReq.HasImages = true
						}
					}
				}
			}
		}
		unifiedReq.Messages = append(unifiedReq.Messages, unifiedMsg)
	}
	// Attribute the request to the authenticated client when the auth
	// middleware populated the context.
	clientID := "default"
	if auth, ok := c.Get("auth"); ok {
		if authInfo, ok := auth.(models.AuthInfo); ok {
			unifiedReq.ClientID = authInfo.ClientID
			clientID = authInfo.ClientID
		}
		// NOTE(review): if "auth" is present but not a models.AuthInfo,
		// unifiedReq.ClientID is left empty — the "default" fallback below
		// only applies when the key is absent. Confirm this is intended.
	} else {
		unifiedReq.ClientID = clientID
	}
	if unifiedReq.Stream {
		ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq)
		if err != nil {
			s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}
		// Standard SSE headers for a streamed response.
		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")
		// Track the last usage payload seen so the final log entry carries
		// token counts (providers typically report usage on the final chunk).
		var lastUsage *models.Usage
		c.Stream(func(w io.Writer) bool {
			chunk, ok := <-ch
			if !ok {
				// Channel closed by the provider: terminate the SSE stream
				// and log the request with whatever usage was captured.
				fmt.Fprintf(w, "data: [DONE]\n\n")
				s.logRequest(startTime, clientID, providerName, req.Model, lastUsage, nil, unifiedReq.HasImages)
				return false
			}
			if chunk.Usage != nil {
				lastUsage = chunk.Usage
			}
			data, err := json.Marshal(chunk)
			if err != nil {
				// NOTE(review): a marshal failure aborts the stream without
				// logging the request — presumably rare; verify acceptable.
				return false
			}
			fmt.Fprintf(w, "data: %s\n\n", data)
			return true
		})
		return
	}
	// Non-streaming path: single upstream call, single JSON response.
	resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq)
	if err != nil {
		s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage, nil, unifiedReq.HasImages)
	c.JSON(http.StatusOK, resp)
}
// logRequest builds a RequestLog entry for a completed (or failed) proxy
// request and hands it to the asynchronous request logger.
//
// usage may be nil (error paths, or streams that never reported usage);
// token counts and cost are only recorded when it is present. err marks the
// entry as a failure and captures its message.
//
// Fix: removed a leftover "[DEBUG]" fmt.Printf that wrote to stdout on every
// priced request — debug noise in a production logging path.
func (s *Server) logRequest(start time.Time, clientID, provider, model string, usage *models.Usage, err error, hasImages bool) {
	entry := RequestLog{
		Timestamp:  start,
		ClientID:   clientID,
		Provider:   provider,
		Model:      model,
		Status:     "success",
		DurationMS: time.Since(start).Milliseconds(),
		HasImages:  hasImages,
	}
	if err != nil {
		entry.Status = "error"
		entry.ErrorMessage = err.Error()
	}
	if usage != nil {
		entry.PromptTokens = usage.PromptTokens
		entry.CompletionTokens = usage.CompletionTokens
		entry.TotalTokens = usage.TotalTokens
		// Optional token categories are pointers; copy only when reported.
		if usage.ReasoningTokens != nil {
			entry.ReasoningTokens = *usage.ReasoningTokens
		}
		if usage.CacheReadTokens != nil {
			entry.CacheReadTokens = *usage.CacheReadTokens
		}
		if usage.CacheWriteTokens != nil {
			entry.CacheWriteTokens = *usage.CacheWriteTokens
		}
		// Calculate cost using registry (applies cached-token discounts).
		entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
	}
	s.logger.LogRequest(entry)
}
// Run starts the websocket hub and request logger in the background, spawns
// a daily model-registry refresher, and then blocks serving HTTP on the
// configured host:port.
func (s *Server) Run() error {
	go s.hub.Run()
	s.logger.Start()
	// Start registry refresher
	go func() {
		// NOTE(review): the ticker is never stopped and this goroutine has no
		// shutdown path; fine for a process-lifetime task, but it would
		// outlive any future graceful-shutdown mechanism.
		ticker := time.NewTicker(24 * time.Hour)
		for range ticker.C {
			newRegistry, err := utils.FetchRegistry()
			if err == nil {
				// NOTE(review): s.registry is written here while request
				// handlers read it concurrently with no synchronization —
				// a data race under the Go memory model. Consider guarding
				// with a mutex or storing it in an atomic.Pointer.
				s.registry = newRegistry
			}
		}
	}()
	addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
	return s.router.Run(addr)
}