Moved AuthMiddleware to /v1 group only. Added COALESCE and empty result handling to analytics SQL queries to prevent 500 errors on empty databases.
404 lines
11 KiB
Go
404 lines
11 KiB
Go
package server
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"llm-proxy/internal/config"
|
|
"llm-proxy/internal/db"
|
|
"llm-proxy/internal/middleware"
|
|
"llm-proxy/internal/models"
|
|
"llm-proxy/internal/providers"
|
|
"llm-proxy/internal/utils"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
type Server struct {
|
|
router *gin.Engine
|
|
cfg *config.Config
|
|
database *db.DB
|
|
providers map[string]providers.Provider
|
|
sessions *SessionManager
|
|
hub *Hub
|
|
logger *RequestLogger
|
|
registry *models.ModelRegistry
|
|
}
|
|
|
|
func NewServer(cfg *config.Config, database *db.DB) *Server {
|
|
router := gin.Default()
|
|
hub := NewHub()
|
|
|
|
// Fetch registry (non-blocking for startup if it fails, but we'll try once)
|
|
registry, err := utils.FetchRegistry()
|
|
if err != nil {
|
|
fmt.Printf("Warning: Failed to fetch initial model registry: %v\n", err)
|
|
registry = &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)}
|
|
}
|
|
|
|
s := &Server{
|
|
router: router,
|
|
cfg: cfg,
|
|
database: database,
|
|
providers: make(map[string]providers.Provider),
|
|
sessions: NewSessionManager(cfg.KeyBytes, 24*time.Hour),
|
|
hub: hub,
|
|
logger: NewRequestLogger(database, hub),
|
|
registry: registry,
|
|
}
|
|
|
|
// Initialize providers
|
|
if cfg.Providers.OpenAI.Enabled {
|
|
apiKey, _ := cfg.GetAPIKey("openai")
|
|
s.providers["openai"] = providers.NewOpenAIProvider(cfg.Providers.OpenAI, apiKey)
|
|
}
|
|
if cfg.Providers.Gemini.Enabled {
|
|
apiKey, _ := cfg.GetAPIKey("gemini")
|
|
s.providers["gemini"] = providers.NewGeminiProvider(cfg.Providers.Gemini, apiKey)
|
|
}
|
|
if cfg.Providers.DeepSeek.Enabled {
|
|
apiKey, _ := cfg.GetAPIKey("deepseek")
|
|
s.providers["deepseek"] = providers.NewDeepSeekProvider(cfg.Providers.DeepSeek, apiKey)
|
|
}
|
|
if cfg.Providers.Grok.Enabled {
|
|
apiKey, _ := cfg.GetAPIKey("grok")
|
|
s.providers["grok"] = providers.NewGrokProvider(cfg.Providers.Grok, apiKey)
|
|
}
|
|
|
|
s.setupRoutes()
|
|
return s
|
|
}
|
|
|
|
func (s *Server) setupRoutes() {
|
|
// Global middleware should only be for logging/recovery
|
|
// Auth is specific to groups
|
|
|
|
// Static files
|
|
s.router.StaticFile("/", "./static/index.html")
|
|
s.router.StaticFile("/favicon.ico", "./static/favicon.ico")
|
|
s.router.Static("/css", "./static/css")
|
|
s.router.Static("/js", "./static/js")
|
|
s.router.Static("/img", "./static/img")
|
|
|
|
// WebSocket
|
|
s.router.GET("/ws", s.handleWebSocket)
|
|
|
|
// API V1 (External LLM Access)
|
|
v1 := s.router.Group("/v1")
|
|
v1.Use(middleware.AuthMiddleware(s.database))
|
|
{
|
|
v1.POST("/chat/completions", s.handleChatCompletions)
|
|
v1.GET("/models", s.handleListModels)
|
|
}
|
|
|
|
// Dashboard API Group
|
|
api := s.router.Group("/api")
|
|
{
|
|
api.POST("/auth/login", s.handleLogin)
|
|
api.GET("/auth/status", s.handleAuthStatus)
|
|
api.POST("/auth/logout", s.handleLogout)
|
|
api.POST("/auth/change-password", s.handleChangePassword)
|
|
|
|
// Protected dashboard routes (need admin session)
|
|
admin := api.Group("/")
|
|
admin.Use(s.adminAuthMiddleware())
|
|
{
|
|
admin.GET("/usage/summary", s.handleUsageSummary)
|
|
admin.GET("/usage/time-series", s.handleTimeSeries)
|
|
admin.GET("/usage/providers", s.handleProvidersUsage)
|
|
admin.GET("/usage/detailed", s.handleDetailedUsage)
|
|
admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)
|
|
|
|
admin.GET("/clients", s.handleGetClients)
|
|
admin.POST("/clients", s.handleCreateClient)
|
|
admin.GET("/clients/:id", s.handleGetClient)
|
|
admin.PUT("/clients/:id", s.handleUpdateClient)
|
|
admin.DELETE("/clients/:id", s.handleDeleteClient)
|
|
|
|
admin.GET("/clients/:id/tokens", s.handleGetClientTokens)
|
|
admin.POST("/clients/:id/tokens", s.handleCreateClientToken)
|
|
admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)
|
|
|
|
admin.GET("/providers", s.handleGetProviders)
|
|
admin.PUT("/providers/:name", s.handleUpdateProvider)
|
|
admin.POST("/providers/:name/test", s.handleTestProvider)
|
|
|
|
admin.GET("/models", s.handleGetModels)
|
|
admin.PUT("/models/:id", s.handleUpdateModel)
|
|
|
|
admin.GET("/users", s.handleGetUsers)
|
|
admin.POST("/users", s.handleCreateUser)
|
|
admin.PUT("/users/:id", s.handleUpdateUser)
|
|
admin.DELETE("/users/:id", s.handleDeleteUser)
|
|
|
|
admin.GET("/system/health", s.handleSystemHealth)
|
|
admin.GET("/system/settings", s.handleGetSettings)
|
|
admin.POST("/system/backup", s.handleCreateBackup)
|
|
admin.GET("/system/logs", s.handleGetLogs)
|
|
}
|
|
}
|
|
|
|
s.router.GET("/health", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
|
})
|
|
}
|
|
|
|
func (s *Server) handleListModels(c *gin.Context) {
|
|
type OpenAIModel struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
OwnedBy string `json:"owned_by"`
|
|
}
|
|
|
|
var data []OpenAIModel
|
|
allowedProviders := map[string]bool{
|
|
"openai": true,
|
|
"google": true, // Models from models.dev use 'google' ID for Gemini
|
|
"deepseek": true,
|
|
"xai": true, // Models from models.dev use 'xai' ID for Grok
|
|
}
|
|
|
|
if s.registry != nil {
|
|
for pID, pInfo := range s.registry.Providers {
|
|
if !allowedProviders[pID] {
|
|
continue
|
|
}
|
|
for mID := range pInfo.Models {
|
|
data = append(data, OpenAIModel{
|
|
ID: mID,
|
|
Object: "model",
|
|
Created: 1700000000,
|
|
OwnedBy: pID,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"object": "list",
|
|
"data": data,
|
|
})
|
|
}
|
|
|
|
func (s *Server) handleChatCompletions(c *gin.Context) {
|
|
startTime := time.Now()
|
|
var req models.ChatCompletionRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Select provider based on model name
|
|
providerName := "openai" // default
|
|
if strings.Contains(req.Model, "gemini") {
|
|
providerName = "gemini"
|
|
} else if strings.Contains(req.Model, "deepseek") {
|
|
providerName = "deepseek"
|
|
} else if strings.Contains(req.Model, "grok") {
|
|
providerName = "grok"
|
|
}
|
|
|
|
provider, ok := s.providers[providerName]
|
|
if !ok {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
|
|
return
|
|
}
|
|
|
|
// Convert ChatCompletionRequest to UnifiedRequest
|
|
unifiedReq := &models.UnifiedRequest{
|
|
Model: req.Model,
|
|
Messages: []models.UnifiedMessage{},
|
|
Temperature: req.Temperature,
|
|
TopP: req.TopP,
|
|
TopK: req.TopK,
|
|
N: req.N,
|
|
MaxTokens: req.MaxTokens,
|
|
PresencePenalty: req.PresencePenalty,
|
|
FrequencyPenalty: req.FrequencyPenalty,
|
|
Stream: req.Stream != nil && *req.Stream,
|
|
Tools: req.Tools,
|
|
ToolChoice: req.ToolChoice,
|
|
}
|
|
|
|
// Handle Stop sequences
|
|
if req.Stop != nil {
|
|
var stop []string
|
|
if err := json.Unmarshal(req.Stop, &stop); err == nil {
|
|
unifiedReq.Stop = stop
|
|
} else {
|
|
var singleStop string
|
|
if err := json.Unmarshal(req.Stop, &singleStop); err == nil {
|
|
unifiedReq.Stop = []string{singleStop}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Convert messages
|
|
for _, msg := range req.Messages {
|
|
unifiedMsg := models.UnifiedMessage{
|
|
Role: msg.Role,
|
|
Content: []models.UnifiedContentPart{},
|
|
ReasoningContent: msg.ReasoningContent,
|
|
ToolCalls: msg.ToolCalls,
|
|
Name: msg.Name,
|
|
ToolCallID: msg.ToolCallID,
|
|
}
|
|
|
|
// Handle multimodal content
|
|
if strContent, ok := msg.Content.(string); ok {
|
|
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
|
Type: "text",
|
|
Text: strContent,
|
|
})
|
|
} else if parts, ok := msg.Content.([]interface{}); ok {
|
|
for _, part := range parts {
|
|
if partMap, ok := part.(map[string]interface{}); ok {
|
|
partType, _ := partMap["type"].(string)
|
|
if partType == "text" {
|
|
text, _ := partMap["text"].(string)
|
|
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
|
Type: "text",
|
|
Text: text,
|
|
})
|
|
} else if partType == "image_url" {
|
|
if imgURLMap, ok := partMap["image_url"].(map[string]interface{}); ok {
|
|
url, _ := imgURLMap["url"].(string)
|
|
imageInput := &models.ImageInput{}
|
|
if strings.HasPrefix(url, "data:") {
|
|
mime, data, err := utils.ParseDataURL(url)
|
|
if err == nil {
|
|
imageInput.Base64 = data
|
|
imageInput.MimeType = mime
|
|
}
|
|
} else {
|
|
imageInput.URL = url
|
|
}
|
|
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
|
Type: "image",
|
|
Image: imageInput,
|
|
})
|
|
unifiedReq.HasImages = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
unifiedReq.Messages = append(unifiedReq.Messages, unifiedMsg)
|
|
}
|
|
|
|
clientID := "default"
|
|
if auth, ok := c.Get("auth"); ok {
|
|
if authInfo, ok := auth.(models.AuthInfo); ok {
|
|
unifiedReq.ClientID = authInfo.ClientID
|
|
clientID = authInfo.ClientID
|
|
}
|
|
} else {
|
|
unifiedReq.ClientID = clientID
|
|
}
|
|
|
|
if unifiedReq.Stream {
|
|
ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
c.Header("Content-Type", "text/event-stream")
|
|
c.Header("Cache-Control", "no-cache")
|
|
c.Header("Connection", "keep-alive")
|
|
|
|
var lastUsage *models.Usage
|
|
c.Stream(func(w io.Writer) bool {
|
|
chunk, ok := <-ch
|
|
if !ok {
|
|
fmt.Fprintf(w, "data: [DONE]\n\n")
|
|
s.logRequest(startTime, clientID, providerName, req.Model, lastUsage, nil, unifiedReq.HasImages)
|
|
return false
|
|
}
|
|
if chunk.Usage != nil {
|
|
lastUsage = chunk.Usage
|
|
}
|
|
data, err := json.Marshal(chunk)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
return true
|
|
})
|
|
return
|
|
}
|
|
|
|
resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage, nil, unifiedReq.HasImages)
|
|
c.JSON(http.StatusOK, resp)
|
|
}
|
|
|
|
func (s *Server) logRequest(start time.Time, clientID, provider, model string, usage *models.Usage, err error, hasImages bool) {
|
|
entry := RequestLog{
|
|
Timestamp: start,
|
|
ClientID: clientID,
|
|
Provider: provider,
|
|
Model: model,
|
|
Status: "success",
|
|
DurationMS: time.Since(start).Milliseconds(),
|
|
HasImages: hasImages,
|
|
}
|
|
|
|
if err != nil {
|
|
entry.Status = "error"
|
|
entry.ErrorMessage = err.Error()
|
|
}
|
|
|
|
if usage != nil {
|
|
entry.PromptTokens = usage.PromptTokens
|
|
entry.CompletionTokens = usage.CompletionTokens
|
|
entry.TotalTokens = usage.TotalTokens
|
|
if usage.ReasoningTokens != nil {
|
|
entry.ReasoningTokens = *usage.ReasoningTokens
|
|
}
|
|
if usage.CacheReadTokens != nil {
|
|
entry.CacheReadTokens = *usage.CacheReadTokens
|
|
}
|
|
if usage.CacheWriteTokens != nil {
|
|
entry.CacheWriteTokens = *usage.CacheWriteTokens
|
|
}
|
|
|
|
// Calculate cost using registry
|
|
entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
|
|
}
|
|
|
|
s.logger.LogRequest(entry)
|
|
}
|
|
|
|
func (s *Server) Run() error {
|
|
go s.hub.Run()
|
|
s.logger.Start()
|
|
|
|
// Start registry refresher
|
|
go func() {
|
|
ticker := time.NewTicker(24 * time.Hour)
|
|
for range ticker.C {
|
|
newRegistry, err := utils.FetchRegistry()
|
|
if err == nil {
|
|
s.registry = newRegistry
|
|
}
|
|
}
|
|
}()
|
|
|
|
addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
|
|
return s.router.Run(addr)
|
|
}
|