package server

import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"sort"
	"strings"
	"sync"
	"time"

	"llm-proxy/internal/config"
	"llm-proxy/internal/db"
	"llm-proxy/internal/middleware"
	"llm-proxy/internal/models"
	"llm-proxy/internal/providers"
	"llm-proxy/internal/utils"

	"github.com/gin-gonic/gin"
)

// Server bundles the HTTP router, the enabled LLM providers, the dashboard
// session/logging helpers and the model registry used for model listing and
// cost calculation.
type Server struct {
	router    *gin.Engine
	cfg       *config.Config
	database  *db.DB
	providers map[string]providers.Provider
	sessions  *SessionManager
	hub       *Hub
	logger    *RequestLogger

	// registryMu guards registry: the refresher goroutine started by Run
	// replaces the pointer while request handlers are reading it.
	registryMu sync.RWMutex
	registry   *models.ModelRegistry
}

// NewServer builds a Server from cfg and database, registers every provider
// that is enabled in cfg, and installs all routes.
//
// The model registry is fetched once here; if that fails a warning is printed
// and an empty registry is used so startup can continue.
func NewServer(cfg *config.Config, database *db.DB) *Server {
	router := gin.Default()
	hub := NewHub()

	// Try once at startup; a failure is not fatal.
	registry, err := utils.FetchRegistry()
	if err != nil {
		fmt.Printf("Warning: Failed to fetch initial model registry: %v\n", err)
		registry = &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)}
	}

	s := &Server{
		router:    router,
		cfg:       cfg,
		database:  database,
		providers: make(map[string]providers.Provider),
		sessions:  NewSessionManager(cfg.KeyBytes, 24*time.Hour),
		hub:       hub,
		logger:    NewRequestLogger(database, hub),
		registry:  registry,
	}

	// Initialize providers.
	// NOTE(review): the second return value of GetAPIKey is discarded for
	// every provider; confirm that a missing key is surfaced elsewhere.
	if cfg.Providers.OpenAI.Enabled {
		apiKey, _ := cfg.GetAPIKey("openai")
		s.providers["openai"] = providers.NewOpenAIProvider(cfg.Providers.OpenAI, apiKey)
	}
	if cfg.Providers.Gemini.Enabled {
		apiKey, _ := cfg.GetAPIKey("gemini")
		s.providers["gemini"] = providers.NewGeminiProvider(cfg.Providers.Gemini, apiKey)
	}
	if cfg.Providers.DeepSeek.Enabled {
		apiKey, _ := cfg.GetAPIKey("deepseek")
		s.providers["deepseek"] = providers.NewDeepSeekProvider(cfg.Providers.DeepSeek, apiKey)
	}
	if cfg.Providers.Grok.Enabled {
		apiKey, _ := cfg.GetAPIKey("grok")
		s.providers["grok"] = providers.NewGrokProvider(cfg.Providers.Grok, apiKey)
	}

	s.setupRoutes()
	return s
}

// currentRegistry returns the registry pointer under the read lock.
func (s *Server) currentRegistry() *models.ModelRegistry {
	s.registryMu.RLock()
	defer s.registryMu.RUnlock()
	return s.registry
}

// swapRegistry replaces the registry pointer under the write lock.
func (s *Server) swapRegistry(r *models.ModelRegistry) {
	s.registryMu.Lock()
	defer s.registryMu.Unlock()
	s.registry = r
}

// setupRoutes registers static assets, the WebSocket endpoint, the /v1 LLM
// API (behind middleware.AuthMiddleware), the /api dashboard endpoints (the
// admin subgroup behind adminAuthMiddleware) and the /health probe.
func (s *Server) setupRoutes() {
	// Global middleware should only be for logging/recovery.
	// Auth is specific to groups.

	// Static files
	s.router.StaticFile("/", "./static/index.html")
	s.router.StaticFile("/favicon.ico", "./static/favicon.ico")
	s.router.Static("/css", "./static/css")
	s.router.Static("/js", "./static/js")
	s.router.Static("/img", "./static/img")

	// WebSocket
	s.router.GET("/ws", s.handleWebSocket)

	// API V1 (External LLM Access)
	v1 := s.router.Group("/v1")
	v1.Use(middleware.AuthMiddleware(s.database))
	{
		v1.POST("/chat/completions", s.handleChatCompletions)
		v1.GET("/models", s.handleListModels)
	}

	// Dashboard API Group
	api := s.router.Group("/api")
	{
		api.POST("/auth/login", s.handleLogin)
		api.GET("/auth/status", s.handleAuthStatus)
		api.POST("/auth/logout", s.handleLogout)
		api.POST("/auth/change-password", s.handleChangePassword)

		// Protected dashboard routes (need admin session)
		admin := api.Group("/")
		admin.Use(s.adminAuthMiddleware())
		{
			admin.GET("/usage/summary", s.handleUsageSummary)
			admin.GET("/usage/time-series", s.handleTimeSeries)
			admin.GET("/usage/providers", s.handleProvidersUsage)
			admin.GET("/usage/detailed", s.handleDetailedUsage)
			admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)

			admin.GET("/clients", s.handleGetClients)
			admin.POST("/clients", s.handleCreateClient)
			admin.GET("/clients/:id", s.handleGetClient)
			admin.PUT("/clients/:id", s.handleUpdateClient)
			admin.DELETE("/clients/:id", s.handleDeleteClient)
			admin.GET("/clients/:id/tokens", s.handleGetClientTokens)
			admin.POST("/clients/:id/tokens", s.handleCreateClientToken)
			admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)

			admin.GET("/providers", s.handleGetProviders)
			admin.PUT("/providers/:name", s.handleUpdateProvider)
			admin.POST("/providers/:name/test", s.handleTestProvider)

			admin.GET("/models", s.handleGetModels)
			admin.PUT("/models/:id", s.handleUpdateModel)

			admin.GET("/users", s.handleGetUsers)
			admin.POST("/users", s.handleCreateUser)
			admin.PUT("/users/:id", s.handleUpdateUser)
			admin.DELETE("/users/:id", s.handleDeleteUser)

			admin.GET("/system/health", s.handleSystemHealth)
			admin.GET("/system/settings", s.handleGetSettings)
			admin.POST("/system/backup", s.handleCreateBackup)
			admin.GET("/system/logs", s.handleGetLogs)
		}
	}

	s.router.GET("/health", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"status": "ok"})
	})
}

// handleListModels responds with an OpenAI-style model list built from the
// registry entries of the allowed providers.
//
// The list is always a JSON array (never null) and is sorted by provider and
// then model ID so repeated calls return the same order.
func (s *Server) handleListModels(c *gin.Context) {
	type OpenAIModel struct {
		ID      string `json:"id"`
		Object  string `json:"object"`
		Created int64  `json:"created"`
		OwnedBy string `json:"owned_by"`
	}

	allowedProviders := map[string]bool{
		"openai":   true,
		"google":   true, // Models from models.dev use 'google' ID for Gemini
		"deepseek": true,
		"xai":      true, // Models from models.dev use 'xai' ID for Grok
	}

	// Start from an empty, non-nil slice: a nil slice would be encoded as
	// JSON null instead of [].
	data := []OpenAIModel{}
	if registry := s.currentRegistry(); registry != nil {
		for pID, pInfo := range registry.Providers {
			if !allowedProviders[pID] {
				continue
			}
			for mID := range pInfo.Models {
				data = append(data, OpenAIModel{
					ID:      mID,
					Object:  "model",
					Created: 1700000000,
					OwnedBy: pID,
				})
			}
		}
	}

	// Map iteration order is random; sort for a stable response.
	sort.Slice(data, func(i, j int) bool {
		if data[i].OwnedBy != data[j].OwnedBy {
			return data[i].OwnedBy < data[j].OwnedBy
		}
		return data[i].ID < data[j].ID
	})

	c.JSON(http.StatusOK, gin.H{
		"object": "list",
		"data":   data,
	})
}

// providerNameForModel picks the provider key by substring match on the
// model name, falling back to "openai".
func providerNameForModel(model string) string {
	switch {
	case strings.Contains(model, "gemini"):
		return "gemini"
	case strings.Contains(model, "deepseek"):
		return "deepseek"
	case strings.Contains(model, "grok"):
		return "grok"
	default:
		return "openai"
	}
}

// parseStopSequences decodes a raw JSON "stop" value that is either an array
// of strings or a single string. It returns nil when raw is nil or matches
// neither form.
func parseStopSequences(raw []byte) []string {
	if raw == nil {
		return nil
	}
	var many []string
	if err := json.Unmarshal(raw, &many); err == nil {
		return many
	}
	var single string
	if err := json.Unmarshal(raw, &single); err == nil {
		return []string{single}
	}
	return nil
}

// convertMessageContent turns a decoded message "content" value (a plain
// string or a list of typed parts) into unified content parts and reports
// whether any image part was produced.
//
// Parts of an unrecognised shape or type are skipped. A "data:" image URL
// that utils.ParseDataURL rejects yields an error rather than an empty image.
func convertMessageContent(content interface{}) ([]models.UnifiedContentPart, bool, error) {
	parts := []models.UnifiedContentPart{}
	hasImages := false

	switch v := content.(type) {
	case string:
		parts = append(parts, models.UnifiedContentPart{Type: "text", Text: v})

	case []interface{}:
		for _, raw := range v {
			partMap, ok := raw.(map[string]interface{})
			if !ok {
				continue
			}
			partType, _ := partMap["type"].(string)
			switch partType {
			case "text":
				text, _ := partMap["text"].(string)
				parts = append(parts, models.UnifiedContentPart{Type: "text", Text: text})

			case "image_url":
				imgURLMap, ok := partMap["image_url"].(map[string]interface{})
				if !ok {
					continue
				}
				url, _ := imgURLMap["url"].(string)
				image := &models.ImageInput{}
				if strings.HasPrefix(url, "data:") {
					mime, data, err := utils.ParseDataURL(url)
					if err != nil {
						return nil, false, fmt.Errorf("invalid image data URL: %w", err)
					}
					image.Base64 = data
					image.MimeType = mime
				} else {
					image.URL = url
				}
				parts = append(parts, models.UnifiedContentPart{Type: "image", Image: image})
				hasImages = true
			}
		}
	}

	return parts, hasImages, nil
}

// handleChatCompletions serves POST /v1/chat/completions.
//
// It binds the JSON body, selects a provider from the model name, converts
// the request into a models.UnifiedRequest and either streams the provider's
// chunks as server-sent events or returns the complete response as JSON.
// Every request that reaches a provider is logged exactly once through
// logRequest, including streams that end early.
func (s *Server) handleChatCompletions(c *gin.Context) {
	startTime := time.Now()

	var req models.ChatCompletionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Select provider based on model name.
	providerName := providerNameForModel(req.Model)
	provider, ok := s.providers[providerName]
	if !ok {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
		return
	}

	// Convert ChatCompletionRequest to UnifiedRequest.
	unifiedReq := &models.UnifiedRequest{
		Model:            req.Model,
		Messages:         []models.UnifiedMessage{},
		Temperature:      req.Temperature,
		TopP:             req.TopP,
		TopK:             req.TopK,
		N:                req.N,
		MaxTokens:        req.MaxTokens,
		PresencePenalty:  req.PresencePenalty,
		FrequencyPenalty: req.FrequencyPenalty,
		Stream:           req.Stream != nil && *req.Stream,
		Tools:            req.Tools,
		ToolChoice:       req.ToolChoice,
	}
	unifiedReq.Stop = parseStopSequences(req.Stop)

	// Convert messages, including multimodal content.
	for _, msg := range req.Messages {
		content, hasImages, err := convertMessageContent(msg.Content)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
			return
		}
		if hasImages {
			unifiedReq.HasImages = true
		}
		unifiedReq.Messages = append(unifiedReq.Messages, models.UnifiedMessage{
			Role:             msg.Role,
			Content:          content,
			ReasoningContent: msg.ReasoningContent,
			ToolCalls:        msg.ToolCalls,
			Name:             msg.Name,
			ToolCallID:       msg.ToolCallID,
		})
	}

	// Resolve the client ID once and use it for both the upstream request
	// and the log entry, so an "auth" value of an unexpected type falls
	// back to the default instead of leaving the request's ID empty.
	clientID := "default"
	if auth, ok := c.Get("auth"); ok {
		if authInfo, ok := auth.(models.AuthInfo); ok {
			clientID = authInfo.ClientID
		}
	}
	unifiedReq.ClientID = clientID

	if unifiedReq.Stream {
		ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq)
		if err != nil {
			s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}

		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")

		var (
			lastUsage *models.Usage
			streamErr error
		)
		clientGone := c.Stream(func(w io.Writer) bool {
			chunk, ok := <-ch
			if !ok {
				fmt.Fprintf(w, "data: [DONE]\n\n")
				return false
			}
			if chunk.Usage != nil {
				lastUsage = chunk.Usage
			}
			data, err := json.Marshal(chunk)
			if err != nil {
				streamErr = fmt.Errorf("encoding stream chunk: %w", err)
				return false
			}
			fmt.Fprintf(w, "data: %s\n\n", data)
			return true
		})
		if clientGone && streamErr == nil {
			streamErr = fmt.Errorf("client disconnected before stream completed")
		}

		// Log after the stream ends, whatever the reason, so aborted
		// streams are recorded together with any usage seen so far.
		s.logRequest(startTime, clientID, providerName, req.Model, lastUsage, streamErr, unifiedReq.HasImages)
		return
	}

	resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq)
	if err != nil {
		s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage, nil, unifiedReq.HasImages)
	c.JSON(http.StatusOK, resp)
}

// logRequest builds a RequestLog for one proxied call and hands it to the
// request logger.
//
// Status is "error" with the error text when err is non-nil, otherwise
// "success". Token counts and cost are filled in only when usage is non-nil.
func (s *Server) logRequest(start time.Time, clientID, provider, model string, usage *models.Usage, err error, hasImages bool) {
	entry := RequestLog{
		Timestamp:  start,
		ClientID:   clientID,
		Provider:   provider,
		Model:      model,
		Status:     "success",
		DurationMS: time.Since(start).Milliseconds(),
		HasImages:  hasImages,
	}

	if err != nil {
		entry.Status = "error"
		entry.ErrorMessage = err.Error()
	}

	if usage != nil {
		entry.PromptTokens = usage.PromptTokens
		entry.CompletionTokens = usage.CompletionTokens
		entry.TotalTokens = usage.TotalTokens
		if usage.ReasoningTokens != nil {
			entry.ReasoningTokens = *usage.ReasoningTokens
		}
		if usage.CacheReadTokens != nil {
			entry.CacheReadTokens = *usage.CacheReadTokens
		}
		if usage.CacheWriteTokens != nil {
			entry.CacheWriteTokens = *usage.CacheWriteTokens
		}

		// Calculate cost using the current registry snapshot.
		entry.Cost = utils.CalculateCost(s.currentRegistry(), model, entry.PromptTokens, entry.CompletionTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
	}

	s.logger.LogRequest(entry)
}

// refreshRegistry re-fetches the model registry every interval until done is
// closed. A failed fetch prints a warning and keeps the previous registry.
func (s *Server) refreshRegistry(interval time.Duration, done <-chan struct{}) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-done:
			return
		case <-ticker.C:
			newRegistry, err := utils.FetchRegistry()
			if err != nil {
				fmt.Printf("Warning: Failed to refresh model registry: %v\n", err)
				continue
			}
			s.swapRegistry(newRegistry)
		}
	}
}

// Run starts the hub, the request logger and the 24-hour registry refresher,
// then serves HTTP on the configured host and port. It returns the error
// from the router's Run call; the refresher is stopped when Run returns.
func (s *Server) Run() error {
	go s.hub.Run()
	s.logger.Start()

	// Start registry refresher; closing done on return ends its goroutine.
	done := make(chan struct{})
	defer close(done)
	go s.refreshRegistry(24*time.Hour, done)

	addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
	return s.router.Run(addr)
}