package server import ( "encoding/json" "context" "fmt" "io" "log" "net/http" "strings" "sync" "time" "gophergate/internal/config" "gophergate/internal/db" "gophergate/internal/middleware" "gophergate/internal/models" "gophergate/internal/providers" "gophergate/internal/router" "gophergate/internal/utils" "github.com/gin-gonic/gin" ) type Server struct { router *gin.Engine cfg *config.Config database *db.DB providers map[string]providers.Provider sessions *SessionManager hub *Hub logger *RequestLogger registry *models.ModelRegistry registryMu sync.RWMutex modelRouter *router.Router } func NewServer(cfg *config.Config, database *db.DB) *Server { router := gin.Default() hub := NewHub() s := &Server{ router: router, cfg: cfg, database: database, providers: make(map[string]providers.Provider), sessions: NewSessionManager(cfg.KeyBytes, 24*time.Hour), hub: hub, logger: NewRequestLogger(database, hub), registry: &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)}, } s.sessions.StartCleanup() // Fetch registry in background go func() { registry, err := utils.FetchRegistry() if err != nil { fmt.Printf("Warning: Failed to fetch initial model registry: %v\n", err) } else { s.registry = registry } }() // Initialize providers from DB and Config if err := s.RefreshProviders(); err != nil { fmt.Printf("Warning: Failed to initial refresh providers: %v\n", err) } s.setupRoutes() // Initialize model group router s.refreshRouter() return s } func (s *Server) RefreshProviders() error { var dbConfigs []db.ProviderConfig err := s.database.Select(&dbConfigs, "SELECT * FROM provider_configs") if err != nil { return fmt.Errorf("failed to fetch provider configs from db: %w", err) } dbMap := make(map[string]db.ProviderConfig) for _, cfg := range dbConfigs { dbMap[cfg.ID] = cfg } providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama", "xiaomi"} for _, id := range providerIDs { // Default values from config enabled := false baseURL := "" apiKey := "" switch id { case "openai": enabled = s.cfg.Providers.OpenAI.Enabled baseURL = s.cfg.Providers.OpenAI.BaseURL apiKey, _ = s.cfg.GetAPIKey("openai") case "gemini": enabled = s.cfg.Providers.Gemini.Enabled baseURL = s.cfg.Providers.Gemini.BaseURL apiKey, _ = s.cfg.GetAPIKey("gemini") case "deepseek": enabled = s.cfg.Providers.DeepSeek.Enabled baseURL = s.cfg.Providers.DeepSeek.BaseURL apiKey, _ = s.cfg.GetAPIKey("deepseek") case "moonshot": enabled = s.cfg.Providers.Moonshot.Enabled baseURL = s.cfg.Providers.Moonshot.BaseURL apiKey, _ = s.cfg.GetAPIKey("moonshot") case "grok": enabled = s.cfg.Providers.Grok.Enabled baseURL = s.cfg.Providers.Grok.BaseURL apiKey, _ = s.cfg.GetAPIKey("grok") case "xiaomi": enabled = s.cfg.Providers.Xiaomi.Enabled baseURL = s.cfg.Providers.Xiaomi.BaseURL apiKey, _ = s.cfg.GetAPIKey("xiaomi") } // Overrides from DB if dbCfg, ok := dbMap[id]; ok { enabled = dbCfg.Enabled if dbCfg.BaseURL != nil && *dbCfg.BaseURL != "" { baseURL = *dbCfg.BaseURL } if dbCfg.APIKey != nil && *dbCfg.APIKey != "" { key := *dbCfg.APIKey if dbCfg.APIKeyEncrypted { decrypted, err := utils.Decrypt(key, s.cfg.KeyBytes) if err == nil { key = decrypted } else { fmt.Printf("Warning: Failed to decrypt API key for %s: %v\n", id, err) } } apiKey = key } } if !enabled { delete(s.providers, id) continue } // Initialize provider var p providers.Provider switch id { case "openai": cfg := s.cfg.Providers.OpenAI cfg.BaseURL = baseURL p = providers.NewOpenAIProvider(cfg, apiKey) case "gemini": cfg := s.cfg.Providers.Gemini cfg.BaseURL = baseURL p = providers.NewGeminiProvider(cfg, apiKey) case "deepseek": cfg := s.cfg.Providers.DeepSeek cfg.BaseURL = baseURL p = providers.NewDeepSeekProvider(cfg, apiKey) case "moonshot": cfg := s.cfg.Providers.Moonshot cfg.BaseURL = baseURL p = providers.NewMoonshotProvider(cfg, apiKey) case "grok": cfg := s.cfg.Providers.Grok cfg.BaseURL = baseURL p = providers.NewGrokProvider(cfg, apiKey) case "ollama": cfg := s.cfg.Providers.Ollama cfg.BaseURL = baseURL p = providers.NewOllamaProvider(cfg) case "xiaomi": cfg := s.cfg.Providers.Xiaomi cfg.BaseURL = baseURL p = providers.NewXiaomiProvider(cfg, apiKey) } if p != nil { s.providers[id] = providers.NewCircuitBreakerProvider(p) } } s.refreshRouter() return nil } func (s *Server) refreshRouter() { var groups []db.ModelGroup if err := s.database.Select(&groups, "SELECT * FROM model_groups"); err != nil { fmt.Printf("Warning: Failed to load model groups: %v\n", err) groups = nil } var classifyFn router.ClassifierFunc classifyFn = func(ctx context.Context, selectorModel, systemPrompt, userMessage string) (string, error) { provider, _, err := s.selectProvider(selectorModel) if err != nil { return "", err } req := &models.UnifiedRequest{ Model: selectorModel, Messages: []models.UnifiedMessage{ {Role: "system", Content: []models.UnifiedContentPart{{Type: "text", Text: systemPrompt}}}, {Role: "user", Content: []models.UnifiedContentPart{{Type: "text", Text: userMessage}}}, }, MaxTokens: uint32Ptr(5), Stream: false, } resp, err := provider.ChatCompletion(ctx, req) if err != nil { return "", err } if len(resp.Choices) == 0 { return "", fmt.Errorf("no choices in classifier response") } content, ok := resp.Choices[0].Message.Content.(string) if !ok { return "", fmt.Errorf("classifier response content is not a string") } return content, nil } if s.modelRouter == nil { s.modelRouter = router.New(groups, classifyFn) } else { s.modelRouter.Reload(groups) } } func (s *Server) setupRoutes() { // Static files s.router.StaticFile("/", "./static/index.html") s.router.StaticFile("/favicon.ico", "./static/favicon.ico") s.router.Static("/css", "./static/css") s.router.Static("/js", "./static/js") s.router.Static("/img", "./static/img") // WebSocket s.router.GET("/ws", s.handleWebSocket) // API V1 (External LLM Access) - Secured with AuthMiddleware v1 := s.router.Group("/v1") v1.Use(middleware.AuthMiddleware(s.database, true)) { v1.POST("/chat/completions", s.handleChatCompletions) v1.POST("/images/generations", s.handleImageGenerations) v1.GET("/models", s.handleListModels) v1.POST("/responses", s.handleResponses) } // Dashboard API Group api := s.router.Group("/api") { api.POST("/auth/login", s.handleLogin) api.GET("/auth/status", s.handleAuthStatus) api.POST("/auth/logout", s.handleLogout) api.POST("/auth/change-password", s.handleChangePassword) // Protected dashboard routes (need admin session) admin := api.Group("/") admin.Use(s.adminAuthMiddleware()) { admin.GET("/usage/summary", s.handleUsageSummary) admin.GET("/usage/time-series", s.handleTimeSeries) admin.GET("/usage/providers", s.handleProvidersUsage) admin.GET("/usage/clients", s.handleClientsUsage) admin.GET("/usage/detailed", s.handleDetailedUsage) admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown) admin.GET("/clients", s.handleGetClients) admin.POST("/clients", s.handleCreateClient) admin.GET("/clients/:id", s.handleGetClient) admin.PUT("/clients/:id", s.handleUpdateClient) admin.DELETE("/clients/:id", s.handleDeleteClient) admin.GET("/clients/:id/tokens", s.handleGetClientTokens) admin.POST("/clients/:id/tokens", s.handleCreateClientToken) admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken) admin.GET("/providers", s.handleGetProviders) admin.PUT("/providers/:name", s.handleUpdateProvider) admin.POST("/providers/:name/test", s.handleTestProvider) admin.GET("/models", s.handleGetModels) admin.PUT("/models/:id", s.handleUpdateModel) admin.GET("/model-groups", s.handleGetModelGroups) admin.POST("/model-groups", s.handleCreateModelGroup) admin.PUT("/model-groups/:id", s.handleUpdateModelGroup) admin.DELETE("/model-groups/:id", s.handleDeleteModelGroup) admin.GET("/users", s.handleGetUsers) admin.POST("/users", s.handleCreateUser) admin.PUT("/users/:id", s.handleUpdateUser) admin.DELETE("/users/:id", s.handleDeleteUser) admin.GET("/system/health", s.handleSystemHealth) admin.GET("/system/metrics", s.handleSystemMetrics) admin.GET("/system/settings", s.handleGetSettings) admin.POST("/system/backup", s.handleCreateBackup) admin.GET("/system/logs", s.handleGetLogs) } } s.router.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) }) } func (s *Server) handleResponses(c *gin.Context) { startTime := time.Now() var req models.ResponsesRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } // Strip common prefixes and resolve model groups to concrete models // (same pattern as handleChatCompletions). modelGroup := "" modelID := req.Model prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/", "xiaomi/"} for _, p := range prefixes { if strings.HasPrefix(modelID, p) { modelID = strings.TrimPrefix(modelID, p) break } } if s.modelRouter != nil { routeCtx := s.buildRouteContextFromResponses(req) decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)}) return } if decision.SelectedModel != modelID { modelGroup = modelID } modelID = decision.SelectedModel } // Select provider based on resolved model name providerName := "openai" // default for Responses API modelLower := strings.ToLower(modelID) if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") { providerName = "gemini" } else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) { providerName = "deepseek" } else if strings.HasPrefix(modelLower, "moonshot/") || strings.Contains(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") { providerName = "moonshot" } else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") { providerName = "grok" } else if strings.HasPrefix(modelLower, "ollama/") || strings.Contains(modelLower, "glm-") || strings.Contains(modelLower, "qwen") || strings.Contains(modelLower, "gemma") || strings.Contains(modelLower, "llama") || strings.Contains(modelLower, "mistral") || strings.Contains(modelLower, "phi") || strings.Contains(modelLower, "yi") || strings.Contains(modelLower, "codellama") || strings.Contains(modelLower, "command-r") { providerName = "ollama" } provider, ok := s.providers[providerName] if !ok { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)}) return } // Use resolved model for the actual API call req.Model = modelID clientID := "default" if auth, ok := c.Get("auth"); ok { if authInfo, ok := auth.(models.AuthInfo); ok { clientID = authInfo.ClientID } } stream := req.Stream != nil && *req.Stream if stream { ch, err := provider.ResponsesStream(c.Request.Context(), &req) if err != nil { s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") var lastUsage *models.ResponsesUsage c.Stream(func(w io.Writer) bool { chunk, ok := <-ch if !ok { fmt.Fprintf(w, "data: [DONE]\n\n") if lastUsage != nil { s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage.ToUsage(), nil, false) } else { s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, nil, false) } return false } // Capture usage from the response payload in streaming chunks if chunk.Response != nil && chunk.Response.Usage != nil { lastUsage = chunk.Response.Usage } data, err := json.Marshal(chunk) if err != nil { return false } fmt.Fprintf(w, "data: %s\n\n", data) return true }) return } resp, err := provider.Responses(c.Request.Context(), &req) if err != nil { s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } if resp.Usage != nil { s.logRequest(startTime, clientID, providerName, modelID, modelGroup, resp.Usage.ToUsage(), nil, false) } else { s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, nil, false) } c.JSON(http.StatusOK, resp) } func (s *Server) handleListModels(c *gin.Context) { type OpenAIModel struct { ID string `json:"id"` Object string `json:"object"` Created int64 `json:"created"` OwnedBy string `json:"owned_by"` } modelMap := make(map[string]OpenAIModel) allowedProviders := map[string]bool{ "openai": true, "google": true, // Models from models.dev use 'google' ID for Gemini "deepseek": true, "moonshot": true, "moonshotai": true, // Official moonshotai ID in models.dev "moonshotai-cn": true, // Official moonshotai-cn ID in models.dev "xai": true, // Models from models.dev use 'xai' ID for Grok "llmgateway": true, // Catch-all for newer models "ollama": true, "xiaomi": true, // Xiaomi MiMo models } s.registryMu.RLock() if s.registry != nil { for pID, pInfo := range s.registry.Providers { if !allowedProviders[pID] { continue } for mID := range pInfo.Models { if _, exists := modelMap[mID]; !exists { modelMap[mID] = OpenAIModel{ ID: mID, Object: "model", Created: 1700000000, OwnedBy: pID, } } } } } s.registryMu.RUnlock() // Add configured Ollama models if s.cfg.Providers.Ollama.Enabled { for _, mID := range s.cfg.Providers.Ollama.Models { if _, exists := modelMap[mID]; !exists { modelMap[mID] = OpenAIModel{ ID: mID, Object: "model", Created: 1700000000, OwnedBy: "ollama", } } } } // Add model groups so clients can discover them if s.modelRouter != nil { for _, gid := range s.modelRouter.Groups() { if _, exists := modelMap[gid]; !exists { modelMap[gid] = OpenAIModel{ ID: gid, Object: "model", Created: 1700000000, OwnedBy: "gophergate", } } } } var data []OpenAIModel for _, m := range modelMap { data = append(data, m) } c.JSON(http.StatusOK, gin.H{ "object": "list", "data": data, }) } func (s *Server) selectProvider(modelID string) (providers.Provider, string, error) { providerName := "openai" // default modelLower := strings.ToLower(modelID) if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") { providerName = "gemini" } else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) { providerName = "deepseek" } else if strings.HasPrefix(modelLower, "moonshot/") || strings.Contains(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") { providerName = "moonshot" } else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") { providerName = "grok" } else if strings.HasPrefix(modelLower, "ollama/") || strings.Contains(modelLower, "glm-") || strings.Contains(modelLower, "qwen") || strings.Contains(modelLower, "gemma") || strings.Contains(modelLower, "llama") || strings.Contains(modelLower, "mistral") || strings.Contains(modelLower, "phi") || strings.Contains(modelLower, "yi") || strings.Contains(modelLower, "codellama") || strings.Contains(modelLower, "command-r") { providerName = "ollama" } else if strings.HasPrefix(modelLower, "xiaomi/") || strings.Contains(modelLower, "mimo") || strings.Contains(modelLower, "xiaomi") { providerName = "xiaomi" } p, ok := s.providers[providerName] if !ok { return nil, "", fmt.Errorf("Provider %s not enabled or supported", providerName) } return p, providerName, nil } func (s *Server) handleChatCompletions(c *gin.Context) { startTime := time.Now() var req models.ChatCompletionRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } // Strip common prefixes and prepare model ID modelID := req.Model prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/", "xiaomi/"} for _, p := range prefixes { if strings.HasPrefix(modelID, p) { modelID = strings.TrimPrefix(modelID, p) break } } // Resolve model groups to concrete models (hierarchical — groups can target groups) modelGroup := "" for i, m := range req.Messages { log.Printf("[DEBUG] Incoming Msg[%d]: role=%s, hasToolCalls=%v, hasContent=%v", i, m.Role, len(m.ToolCalls) > 0, m.Content != nil) } if s.modelRouter != nil { routeCtx := s.buildRouteContextFromChat(req) decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, routeCtx) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)}) return } if decision.SelectedModel != modelID { modelGroup = modelID } modelID = decision.SelectedModel log.Printf("[ROUTER] %s (%s: %s)", modelID, decision.Strategy, decision.Reason) } // Select provider based on the resolved model name provider, providerName, err := s.selectProvider(modelID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } // Convert ChatCompletionRequest to UnifiedRequest unifiedReq := &models.UnifiedRequest{ Model: modelID, Messages: []models.UnifiedMessage{}, Temperature: req.Temperature, TopP: req.TopP, TopK: req.TopK, N: req.N, MaxTokens: req.MaxTokens, PresencePenalty: req.PresencePenalty, FrequencyPenalty: req.FrequencyPenalty, Stream: req.Stream != nil && *req.Stream, Tools: req.Tools, ToolChoice: req.ToolChoice, } // Inject or cap max_tokens from model registry. s.registryMu.RLock() meta := s.registry.FindModel(modelID) s.registryMu.RUnlock() if meta != nil && meta.Limit != nil && meta.Limit.Output > 0 { if unifiedReq.MaxTokens == nil { unifiedReq.MaxTokens = &meta.Limit.Output log.Printf("[DEBUG] %s: injected registry max_tokens=%d", modelID, meta.Limit.Output) } else if *unifiedReq.MaxTokens > meta.Limit.Output { log.Printf("[DEBUG] %s: capping client max_tokens (%d) to registry limit (%d)", modelID, *unifiedReq.MaxTokens, meta.Limit.Output) unifiedReq.MaxTokens = &meta.Limit.Output } else { log.Printf("[DEBUG] %s: using client max_tokens (%d)", modelID, *unifiedReq.MaxTokens) } } else { if unifiedReq.MaxTokens == nil { log.Printf("[DEBUG] %s: no registry limit found, leaving max_tokens nil", modelID) } else { log.Printf("[DEBUG] %s: using client max_tokens (%d), no registry limit to cap", modelID, *unifiedReq.MaxTokens) } } // Handle Stop sequences if req.Stop != nil { var stop []string if err := json.Unmarshal(req.Stop, &stop); err == nil { unifiedReq.Stop = stop } else { var singleStop string if err := json.Unmarshal(req.Stop, &singleStop); err == nil { unifiedReq.Stop = []string{singleStop} } } } // Convert messages for _, msg := range req.Messages { unifiedMsg := models.UnifiedMessage{ Role: msg.Role, Content: []models.UnifiedContentPart{}, ReasoningContent: msg.ReasoningContent, ToolCalls: msg.ToolCalls, Name: msg.Name, ToolCallID: msg.ToolCallID, } // Handle multimodal content if strContent, ok := msg.Content.(string); ok { unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{ Type: "text", Text: strContent, }) } else if parts, ok := msg.Content.([]interface{}); ok { for _, part := range parts { if partMap, ok := part.(map[string]interface{}); ok { partType, _ := partMap["type"].(string) if partType == "text" { text, _ := partMap["text"].(string) unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{ Type: "text", Text: text, }) } else if partType == "image_url" { if imgURLMap, ok := partMap["image_url"].(map[string]interface{}); ok { url, _ := imgURLMap["url"].(string) imageInput := &models.ImageInput{} if strings.HasPrefix(url, "data:") { mime, data, err := utils.ParseDataURL(url) if err == nil { imageInput.Base64 = data imageInput.MimeType = mime } } else { imageInput.URL = url } unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{ Type: "image", Image: imageInput, }) unifiedReq.HasImages = true } } } } } unifiedReq.Messages = append(unifiedReq.Messages, unifiedMsg) } clientID := "default" if auth, ok := c.Get("auth"); ok { if authInfo, ok := auth.(models.AuthInfo); ok { unifiedReq.ClientID = authInfo.ClientID clientID = authInfo.ClientID } } else { unifiedReq.ClientID = clientID } if unifiedReq.Stream { ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq) if err != nil { s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, unifiedReq.HasImages) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") var lastUsage *models.Usage c.Stream(func(w io.Writer) bool { chunk, ok := <-ch if !ok { fmt.Fprintf(w, "data: [DONE]\n\n") s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage, nil, unifiedReq.HasImages) return false } if chunk.Usage != nil { lastUsage = chunk.Usage } data, err := json.Marshal(chunk) if err != nil { return false } fmt.Fprintf(w, "data: %s\n\n", data) return true }) return } resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq) if err != nil { s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, unifiedReq.HasImages) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } s.logRequest(startTime, clientID, providerName, modelID, modelGroup, resp.Usage, nil, unifiedReq.HasImages) c.JSON(http.StatusOK, resp) } func extractUserMessage(messages []models.ChatMessage) string { for i := len(messages) - 1; i >= 0; i-- { if messages[i].Role == "user" { switch c := messages[i].Content.(type) { case string: return c default: return "" } } } return "" } func (s *Server) handleImageGenerations(c *gin.Context) { startTime := time.Now() var req models.ImageGenerationRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } // Determine provider based on model name providerName := "openai" modelLower := strings.ToLower(req.Model) switch { case strings.Contains(modelLower, "imagen"), strings.Contains(modelLower, "gemini"): providerName = "gemini" case strings.Contains(modelLower, "dall"), strings.HasPrefix(modelLower, "openai/"): providerName = "openai" } // Default model for each provider if not specified if req.Model == "" { if providerName == "openai" { req.Model = "dall-e-3" } else { req.Model = "imagen-3.0-generate-001" } } // Strip common prefixes prefixes := []string{"openai/", "gemini/", "google/"} for _, p := range prefixes { if strings.HasPrefix(req.Model, p) { req.Model = strings.TrimPrefix(req.Model, p) break } } provider, ok := s.providers[providerName] if !ok { c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)}) return } clientID := "default" if auth, ok := c.Get("auth"); ok { if authInfo, ok := auth.(models.AuthInfo); ok { clientID = authInfo.ClientID } } resp, err := provider.ImageGeneration(c.Request.Context(), &req) if err != nil { s.logRequest(startTime, clientID, providerName, req.Model, "", nil, err, false) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } // Estimate tokens from prompt text (~4 chars per token) promptTokens := uint32(len(req.Prompt) / 4) if promptTokens < 1 { promptTokens = 1 } // Calculate per-image cost (not per-token like chat) cost := imageGenCost(providerName, req.Model, req.Size, uint32(len(resp.Data))) s.logRequest(startTime, clientID, providerName, req.Model, "", &models.Usage{ PromptTokens: promptTokens, CompletionTokens: uint32(len(resp.Data)), TotalTokens: promptTokens + uint32(len(resp.Data)), }, nil, false) // Update cost in DB — image gen is per-image, not per-token if cost > 0 { s.database.Exec("UPDATE llm_requests SET cost = ? WHERE id = (SELECT MAX(id) FROM llm_requests)", cost) } c.JSON(http.StatusOK, resp) } // imageGenCost returns per-image pricing for known image generation models. func imageGenCost(provider, model string, size *string, n uint32) float64 { if n == 0 { return 0 } modelLower := strings.ToLower(model) var perImage float64 switch { case strings.Contains(modelLower, "dall-e-3"): perImage = 0.040 // standard 1024x1024 if size != nil { s := *size if s == "1024x1792" || s == "1792x1024" { perImage = 0.080 } } case strings.Contains(modelLower, "dall-e-2"): perImage = 0.020 case strings.Contains(modelLower, "imagen"): perImage = 0.040 // approximate default: return 0 } return perImage * float64(n) } func (s *Server) logRequest(start time.Time, clientID, provider, model, modelGroup string, usage *models.Usage, err error, hasImages bool) { entry := RequestLog{ Timestamp: start, ClientID: clientID, Provider: provider, Model: model, ModelGroup: modelGroup, Status: "success", DurationMS: time.Since(start).Milliseconds(), HasImages: hasImages, } if err != nil { entry.Status = "error" entry.ErrorMessage = err.Error() } if usage != nil { entry.PromptTokens = usage.PromptTokens entry.CompletionTokens = usage.CompletionTokens entry.TotalTokens = usage.TotalTokens if usage.ReasoningTokens != nil { entry.ReasoningTokens = *usage.ReasoningTokens } if usage.CacheReadTokens != nil { entry.CacheReadTokens = *usage.CacheReadTokens } if usage.CacheWriteTokens != nil { entry.CacheWriteTokens = *usage.CacheWriteTokens } // Calculate cost using registry; if the resolved model is unknown, // fall back to the model group so group requests still get priced. s.registryMu.RLock() pricingModel := model if s.registry != nil && s.registry.FindModel(pricingModel) == nil && modelGroup != "" { pricingModel = modelGroup } entry.Cost = utils.CalculateCost(s.registry, pricingModel, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens) s.registryMu.RUnlock() } s.logger.LogRequest(entry) } func (s *Server) Run() error { go s.hub.Run() s.logger.Start() // Start registry refresher go func() { ticker := time.NewTicker(24 * time.Hour) for range ticker.C { newRegistry, err := utils.FetchRegistry() if err == nil { s.registryMu.Lock() s.registry = newRegistry s.registryMu.Unlock() } } }() addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port) return s.router.Run(addr) } func uint32Ptr(v uint32) *uint32 { return &v } func (s *Server) buildRouteContextFromChat(req models.ChatCompletionRequest) *router.RouteContext { userMessage := extractUserMessage(req.Messages) requiresToolCalling := len(req.Tools) > 0 hasMultimodal := false inputTokens := 0 for _, msg := range req.Messages { if strContent, ok := msg.Content.(string); ok { inputTokens += len(strContent) / 4 } else if parts, ok := msg.Content.([]interface{}); ok { for _, part := range parts { if partMap, ok := part.(map[string]interface{}); ok { partType, _ := partMap["type"].(string) if partType == "text" { text, _ := partMap["text"].(string) inputTokens += len(text) / 4 } else if partType == "image_url" { hasMultimodal = true inputTokens += 1000 // Approximate cost of an image in tokens } } } } } msgLower := strings.ToLower(userMessage) requiresReasoning := strings.Contains(msgLower, "reason") || strings.Contains(msgLower, "think step by step") || strings.Contains(msgLower, "mathematics") || strings.Contains(msgLower, "architecture") || strings.Contains(msgLower, "explain in detail") routeCtx := &router.RouteContext{ UserMessage: userMessage, InputTokens: inputTokens, HasMultimodalInput: hasMultimodal, RequiresToolCalling: requiresToolCalling, RequiresReasoning: requiresReasoning, } routeCtx.Tags = s.getRouteCtxTags(routeCtx) return routeCtx } func (s *Server) buildRouteContextFromResponses(req models.ResponsesRequest) *router.RouteContext { var userMessage string hasMultimodal := false inputTokens := len(req.Instructions) / 4 requiresToolCalling := len(req.Tools) > 0 && string(req.Tools) != "null" && string(req.Tools) != "" var strInput string if err := json.Unmarshal(req.Input, &strInput); err == nil { userMessage = strInput inputTokens += len(userMessage) / 4 } else { var msgs []models.ResponseInputMessage if err := json.Unmarshal(req.Input, &msgs); err == nil { for _, m := range msgs { var contentStr string if err := json.Unmarshal(m.Content, &contentStr); err == nil { if m.Role == "user" { userMessage = contentStr } inputTokens += len(contentStr) / 4 } else { var parts []models.ContentPart if err := json.Unmarshal(m.Content, &parts); err == nil { for _, p := range parts { if p.Type == "text" { if m.Role == "user" { userMessage = p.Text } inputTokens += len(p.Text) / 4 } else if p.Type == "image_url" { hasMultimodal = true inputTokens += 1000 } } } } } } } msgLower := strings.ToLower(userMessage) requiresReasoning := strings.Contains(msgLower, "reason") || strings.Contains(msgLower, "think step by step") || strings.Contains(msgLower, "mathematics") || strings.Contains(msgLower, "architecture") || strings.Contains(msgLower, "explain in detail") routeCtx := &router.RouteContext{ UserMessage: userMessage, InputTokens: inputTokens, HasMultimodalInput: hasMultimodal, RequiresToolCalling: requiresToolCalling, RequiresReasoning: requiresReasoning, } routeCtx.Tags = s.getRouteCtxTags(routeCtx) return routeCtx } func (s *Server) getRouteCtxTags(routeCtx *router.RouteContext) []string { var tags []string msgLower := strings.ToLower(routeCtx.UserMessage) // fast-flow keywords fastFlowKeywords := []string{ "classify", "classification", "label", "tag", "route", "routing", "intent", "json", "yaml", "csv", "schema", "parse", "extract", "transform", "format", "regex", "short answer", "brief", "concise", "tl;dr", "one line", "simple", "fix this", "small bug", "quick fix", "typo", "syntax error", } for _, kw := range fastFlowKeywords { if strings.Contains(msgLower, kw) { tags = append(tags, "fast-flow", "classification", "json-extraction", "basic-qa") break } } // standard-pro keywords standardProKeywords := []string{ "explain", "summarize", "rewrite", "draft", "edit", "polish", "outline", "long doc", "document", "email", "memo", "proposal", "report", "handout", "notes", "compare", "choose", "recommend", "tradeoff", "pros and cons", "analysis", "code review", "debug", "bug", "feature", "api", "endpoint", "implement", "plan", "planning", "workflow", "integration", } for _, kw := range standardProKeywords { if strings.Contains(msgLower, kw) { tags = append(tags, "standard-pro", "long-doc") break } } if routeCtx.HasMultimodalInput { tags = append(tags, "video-analysis", "multimodal-qa") } // heavy-logic keywords heavyLogicKeywords := []string{ "agent", "agents", "tool use", "function calling", "multi-agent", "orchestrate", "system design", "scaling", "performance", "architecture review", "distributed", "hard bug", "race condition", "deadlock", "memory leak", "crash", "production outage", "long context", "large codebase", "many files", "complex refactor", "migration", "research", "deep dive", "literature", "paper", "scholarly", "thorough analysis", "deep reasoning", "think step by step", "reason through", "careful analysis", } for _, kw := range heavyLogicKeywords { if strings.Contains(msgLower, kw) { tags = append(tags, "heavy-logic", "deep-reasoning", "architecture", "hard-debugging") break } } if routeCtx.RequiresToolCalling { tags = append(tags, "tool-heavy", "multi-step-agent", "swe-bench") } return tags }