feat: migrate backend from rust to go
This commit replaces the Axum/Rust backend with a Gin/Go implementation. The original Rust code has been archived in the 'rust' branch.
This commit is contained in:
326
internal/server/server.go
Normal file
326
internal/server/server.go
Normal file
@@ -0,0 +1,326 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"llm-proxy/internal/config"
|
||||
"llm-proxy/internal/db"
|
||||
"llm-proxy/internal/middleware"
|
||||
"llm-proxy/internal/models"
|
||||
"llm-proxy/internal/providers"
|
||||
"llm-proxy/internal/utils"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
router *gin.Engine
|
||||
cfg *config.Config
|
||||
database *db.DB
|
||||
providers map[string]providers.Provider
|
||||
sessions *SessionManager
|
||||
hub *Hub
|
||||
logger *RequestLogger
|
||||
}
|
||||
|
||||
func NewServer(cfg *config.Config, database *db.DB) *Server {
|
||||
router := gin.Default()
|
||||
hub := NewHub()
|
||||
|
||||
s := &Server{
|
||||
router: router,
|
||||
cfg: cfg,
|
||||
database: database,
|
||||
providers: make(map[string]providers.Provider),
|
||||
sessions: NewSessionManager(cfg.KeyBytes, 24*time.Hour),
|
||||
hub: hub,
|
||||
logger: NewRequestLogger(database, hub),
|
||||
}
|
||||
|
||||
// Initialize providers
|
||||
if cfg.Providers.OpenAI.Enabled {
|
||||
apiKey, _ := cfg.GetAPIKey("openai")
|
||||
s.providers["openai"] = providers.NewOpenAIProvider(cfg.Providers.OpenAI, apiKey)
|
||||
}
|
||||
if cfg.Providers.Gemini.Enabled {
|
||||
apiKey, _ := cfg.GetAPIKey("gemini")
|
||||
s.providers["gemini"] = providers.NewGeminiProvider(cfg.Providers.Gemini, apiKey)
|
||||
}
|
||||
if cfg.Providers.DeepSeek.Enabled {
|
||||
apiKey, _ := cfg.GetAPIKey("deepseek")
|
||||
s.providers["deepseek"] = providers.NewDeepSeekProvider(cfg.Providers.DeepSeek, apiKey)
|
||||
}
|
||||
if cfg.Providers.Grok.Enabled {
|
||||
apiKey, _ := cfg.GetAPIKey("grok")
|
||||
s.providers["grok"] = providers.NewGrokProvider(cfg.Providers.Grok, apiKey)
|
||||
}
|
||||
|
||||
s.setupRoutes()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Server) setupRoutes() {
|
||||
s.router.Use(middleware.AuthMiddleware(s.database))
|
||||
|
||||
// Static files
|
||||
s.router.Static("/static", "./static")
|
||||
s.router.StaticFile("/", "./static/index.html")
|
||||
s.router.StaticFile("/favicon.ico", "./static/favicon.ico")
|
||||
|
||||
// WebSocket
|
||||
s.router.GET("/ws", s.handleWebSocket)
|
||||
|
||||
v1 := s.router.Group("/v1")
|
||||
{
|
||||
v1.POST("/chat/completions", s.handleChatCompletions)
|
||||
}
|
||||
|
||||
// Dashboard API Group
|
||||
api := s.router.Group("/api")
|
||||
{
|
||||
api.POST("/auth/login", s.handleLogin)
|
||||
api.GET("/auth/status", s.handleAuthStatus)
|
||||
api.POST("/auth/logout", s.handleLogout)
|
||||
|
||||
// Protected dashboard routes (need admin session)
|
||||
admin := api.Group("/")
|
||||
admin.Use(s.adminAuthMiddleware())
|
||||
{
|
||||
admin.GET("/usage/summary", s.handleUsageSummary)
|
||||
admin.GET("/usage/time-series", s.handleTimeSeries)
|
||||
admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)
|
||||
|
||||
admin.GET("/clients", s.handleGetClients)
|
||||
admin.POST("/clients", s.handleCreateClient)
|
||||
admin.DELETE("/clients/:id", s.handleDeleteClient)
|
||||
|
||||
admin.GET("/clients/:id/tokens", s.handleGetClientTokens)
|
||||
admin.POST("/clients/:id/tokens", s.handleCreateClientToken)
|
||||
admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)
|
||||
|
||||
admin.GET("/providers", s.handleGetProviders)
|
||||
admin.PUT("/providers/:name", s.handleUpdateProvider)
|
||||
admin.GET("/models", s.handleGetModels)
|
||||
|
||||
admin.GET("/users", s.handleGetUsers)
|
||||
admin.POST("/users", s.handleCreateUser)
|
||||
admin.PUT("/users/:id", s.handleUpdateUser)
|
||||
admin.DELETE("/users/:id", s.handleDeleteUser)
|
||||
|
||||
admin.GET("/system/health", s.handleSystemHealth)
|
||||
}
|
||||
}
|
||||
|
||||
s.router.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleChatCompletions(c *gin.Context) {
|
||||
startTime := time.Now()
|
||||
var req models.ChatCompletionRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Select provider based on model name
|
||||
providerName := "openai" // default
|
||||
if strings.Contains(req.Model, "gemini") {
|
||||
providerName = "gemini"
|
||||
} else if strings.Contains(req.Model, "deepseek") {
|
||||
providerName = "deepseek"
|
||||
} else if strings.Contains(req.Model, "grok") {
|
||||
providerName = "grok"
|
||||
}
|
||||
|
||||
provider, ok := s.providers[providerName]
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
|
||||
return
|
||||
}
|
||||
|
||||
// Convert ChatCompletionRequest to UnifiedRequest
|
||||
unifiedReq := &models.UnifiedRequest{
|
||||
Model: req.Model,
|
||||
Messages: []models.UnifiedMessage{},
|
||||
Temperature: req.Temperature,
|
||||
TopP: req.TopP,
|
||||
TopK: req.TopK,
|
||||
N: req.N,
|
||||
MaxTokens: req.MaxTokens,
|
||||
PresencePenalty: req.PresencePenalty,
|
||||
FrequencyPenalty: req.FrequencyPenalty,
|
||||
Stream: req.Stream != nil && *req.Stream,
|
||||
Tools: req.Tools,
|
||||
ToolChoice: req.ToolChoice,
|
||||
}
|
||||
|
||||
// Handle Stop sequences
|
||||
if req.Stop != nil {
|
||||
var stop []string
|
||||
if err := json.Unmarshal(req.Stop, &stop); err == nil {
|
||||
unifiedReq.Stop = stop
|
||||
} else {
|
||||
var singleStop string
|
||||
if err := json.Unmarshal(req.Stop, &singleStop); err == nil {
|
||||
unifiedReq.Stop = []string{singleStop}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert messages
|
||||
for _, msg := range req.Messages {
|
||||
unifiedMsg := models.UnifiedMessage{
|
||||
Role: msg.Role,
|
||||
Content: []models.UnifiedContentPart{},
|
||||
ReasoningContent: msg.ReasoningContent,
|
||||
ToolCalls: msg.ToolCalls,
|
||||
Name: msg.Name,
|
||||
ToolCallID: msg.ToolCallID,
|
||||
}
|
||||
|
||||
// Handle multimodal content
|
||||
if strContent, ok := msg.Content.(string); ok {
|
||||
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
||||
Type: "text",
|
||||
Text: strContent,
|
||||
})
|
||||
} else if parts, ok := msg.Content.([]interface{}); ok {
|
||||
for _, part := range parts {
|
||||
if partMap, ok := part.(map[string]interface{}); ok {
|
||||
partType, _ := partMap["type"].(string)
|
||||
if partType == "text" {
|
||||
text, _ := partMap["text"].(string)
|
||||
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
||||
Type: "text",
|
||||
Text: text,
|
||||
})
|
||||
} else if partType == "image_url" {
|
||||
if imgURLMap, ok := partMap["image_url"].(map[string]interface{}); ok {
|
||||
url, _ := imgURLMap["url"].(string)
|
||||
imageInput := &models.ImageInput{}
|
||||
if strings.HasPrefix(url, "data:") {
|
||||
mime, data, err := utils.ParseDataURL(url)
|
||||
if err == nil {
|
||||
imageInput.Base64 = data
|
||||
imageInput.MimeType = mime
|
||||
}
|
||||
} else {
|
||||
imageInput.URL = url
|
||||
}
|
||||
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
||||
Type: "image",
|
||||
Image: imageInput,
|
||||
})
|
||||
unifiedReq.HasImages = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unifiedReq.Messages = append(unifiedReq.Messages, unifiedMsg)
|
||||
}
|
||||
|
||||
clientID := "default"
|
||||
if auth, ok := c.Get("auth"); ok {
|
||||
if authInfo, ok := auth.(models.AuthInfo); ok {
|
||||
unifiedReq.ClientID = authInfo.ClientID
|
||||
clientID = authInfo.ClientID
|
||||
}
|
||||
} else {
|
||||
unifiedReq.ClientID = clientID
|
||||
}
|
||||
|
||||
if unifiedReq.Stream {
|
||||
ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq)
|
||||
if err != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
var lastUsage *models.Usage
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
chunk, ok := <-ch
|
||||
if !ok {
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, lastUsage, nil, unifiedReq.HasImages)
|
||||
return false
|
||||
}
|
||||
if chunk.Usage != nil {
|
||||
lastUsage = chunk.Usage
|
||||
}
|
||||
data, err := json.Marshal(chunk)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
return true
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq)
|
||||
if err != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage, nil, unifiedReq.HasImages)
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (s *Server) logRequest(start time.Time, clientID, provider, model string, usage *models.Usage, err error, hasImages bool) {
|
||||
entry := RequestLog{
|
||||
Timestamp: start,
|
||||
ClientID: clientID,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Status: "success",
|
||||
DurationMS: time.Since(start).Milliseconds(),
|
||||
HasImages: hasImages,
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
entry.Status = "error"
|
||||
entry.ErrorMessage = err.Error()
|
||||
}
|
||||
|
||||
if usage != nil {
|
||||
entry.PromptTokens = usage.PromptTokens
|
||||
entry.CompletionTokens = usage.CompletionTokens
|
||||
entry.TotalTokens = usage.TotalTokens
|
||||
if usage.ReasoningTokens != nil {
|
||||
entry.ReasoningTokens = *usage.ReasoningTokens
|
||||
}
|
||||
if usage.CacheReadTokens != nil {
|
||||
entry.CacheReadTokens = *usage.CacheReadTokens
|
||||
}
|
||||
if usage.CacheWriteTokens != nil {
|
||||
entry.CacheWriteTokens = *usage.CacheWriteTokens
|
||||
}
|
||||
// TODO: Calculate cost properly based on pricing
|
||||
entry.Cost = 0.0
|
||||
}
|
||||
|
||||
s.logger.LogRequest(entry)
|
||||
}
|
||||
|
||||
func (s *Server) Run() error {
|
||||
go s.hub.Run()
|
||||
s.logger.Start()
|
||||
addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
|
||||
return s.router.Run(addr)
|
||||
}
|
||||
Reference in New Issue
Block a user