4aa17b4fd2
Logs what max_tokens the client sends, whether gophergate injects one from the registry, and the final value forwarded to the provider. Helps trace output truncation issues.
694 lines
20 KiB
Go
694 lines
20 KiB
Go
package server
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"gophergate/internal/config"
|
|
"gophergate/internal/db"
|
|
"gophergate/internal/middleware"
|
|
"gophergate/internal/models"
|
|
"gophergate/internal/providers"
|
|
"gophergate/internal/utils"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
type Server struct {
|
|
router *gin.Engine
|
|
cfg *config.Config
|
|
database *db.DB
|
|
providers map[string]providers.Provider
|
|
sessions *SessionManager
|
|
hub *Hub
|
|
logger *RequestLogger
|
|
registry *models.ModelRegistry
|
|
registryMu sync.RWMutex
|
|
}
|
|
|
|
func NewServer(cfg *config.Config, database *db.DB) *Server {
|
|
router := gin.Default()
|
|
hub := NewHub()
|
|
|
|
s := &Server{
|
|
router: router,
|
|
cfg: cfg,
|
|
database: database,
|
|
providers: make(map[string]providers.Provider),
|
|
sessions: NewSessionManager(cfg.KeyBytes, 24*time.Hour),
|
|
hub: hub,
|
|
logger: NewRequestLogger(database, hub),
|
|
registry: &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)},
|
|
}
|
|
|
|
s.sessions.StartCleanup()
|
|
// Fetch registry in background
|
|
go func() {
|
|
registry, err := utils.FetchRegistry()
|
|
if err != nil {
|
|
fmt.Printf("Warning: Failed to fetch initial model registry: %v\n", err)
|
|
} else {
|
|
s.registry = registry
|
|
}
|
|
}()
|
|
|
|
// Initialize providers from DB and Config
|
|
if err := s.RefreshProviders(); err != nil {
|
|
fmt.Printf("Warning: Failed to initial refresh providers: %v\n", err)
|
|
}
|
|
|
|
s.setupRoutes()
|
|
return s
|
|
}
|
|
|
|
func (s *Server) RefreshProviders() error {
|
|
var dbConfigs []db.ProviderConfig
|
|
err := s.database.Select(&dbConfigs, "SELECT * FROM provider_configs")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to fetch provider configs from db: %w", err)
|
|
}
|
|
|
|
dbMap := make(map[string]db.ProviderConfig)
|
|
for _, cfg := range dbConfigs {
|
|
dbMap[cfg.ID] = cfg
|
|
}
|
|
|
|
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama"}
|
|
for _, id := range providerIDs {
|
|
// Default values from config
|
|
enabled := false
|
|
baseURL := ""
|
|
apiKey := ""
|
|
|
|
switch id {
|
|
case "openai":
|
|
enabled = s.cfg.Providers.OpenAI.Enabled
|
|
baseURL = s.cfg.Providers.OpenAI.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("openai")
|
|
case "gemini":
|
|
enabled = s.cfg.Providers.Gemini.Enabled
|
|
baseURL = s.cfg.Providers.Gemini.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("gemini")
|
|
case "deepseek":
|
|
enabled = s.cfg.Providers.DeepSeek.Enabled
|
|
baseURL = s.cfg.Providers.DeepSeek.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("deepseek")
|
|
case "moonshot":
|
|
enabled = s.cfg.Providers.Moonshot.Enabled
|
|
baseURL = s.cfg.Providers.Moonshot.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("moonshot")
|
|
case "grok":
|
|
enabled = s.cfg.Providers.Grok.Enabled
|
|
baseURL = s.cfg.Providers.Grok.BaseURL
|
|
apiKey, _ = s.cfg.GetAPIKey("grok")
|
|
}
|
|
|
|
// Overrides from DB
|
|
if dbCfg, ok := dbMap[id]; ok {
|
|
enabled = dbCfg.Enabled
|
|
if dbCfg.BaseURL != nil && *dbCfg.BaseURL != "" {
|
|
baseURL = *dbCfg.BaseURL
|
|
}
|
|
if dbCfg.APIKey != nil && *dbCfg.APIKey != "" {
|
|
key := *dbCfg.APIKey
|
|
if dbCfg.APIKeyEncrypted {
|
|
decrypted, err := utils.Decrypt(key, s.cfg.KeyBytes)
|
|
if err == nil {
|
|
key = decrypted
|
|
} else {
|
|
fmt.Printf("Warning: Failed to decrypt API key for %s: %v\n", id, err)
|
|
}
|
|
}
|
|
apiKey = key
|
|
}
|
|
}
|
|
|
|
if !enabled {
|
|
delete(s.providers, id)
|
|
continue
|
|
}
|
|
|
|
// Initialize provider
|
|
var p providers.Provider
|
|
switch id {
|
|
case "openai":
|
|
cfg := s.cfg.Providers.OpenAI
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewOpenAIProvider(cfg, apiKey)
|
|
case "gemini":
|
|
cfg := s.cfg.Providers.Gemini
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewGeminiProvider(cfg, apiKey)
|
|
case "deepseek":
|
|
cfg := s.cfg.Providers.DeepSeek
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewDeepSeekProvider(cfg, apiKey)
|
|
case "moonshot":
|
|
cfg := s.cfg.Providers.Moonshot
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewMoonshotProvider(cfg, apiKey)
|
|
case "grok":
|
|
cfg := s.cfg.Providers.Grok
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewGrokProvider(cfg, apiKey)
|
|
case "ollama":
|
|
cfg := s.cfg.Providers.Ollama
|
|
cfg.BaseURL = baseURL
|
|
p = providers.NewOllamaProvider(cfg)
|
|
}
|
|
|
|
if p != nil {
|
|
s.providers[id] = providers.NewCircuitBreakerProvider(p)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) setupRoutes() {
|
|
// Static files
|
|
s.router.StaticFile("/", "./static/index.html")
|
|
s.router.StaticFile("/favicon.ico", "./static/favicon.ico")
|
|
s.router.Static("/css", "./static/css")
|
|
s.router.Static("/js", "./static/js")
|
|
s.router.Static("/img", "./static/img")
|
|
|
|
// WebSocket
|
|
s.router.GET("/ws", s.handleWebSocket)
|
|
|
|
// API V1 (External LLM Access) - Secured with AuthMiddleware
|
|
v1 := s.router.Group("/v1")
|
|
v1.Use(middleware.AuthMiddleware(s.database, true))
|
|
{
|
|
v1.POST("/chat/completions", s.handleChatCompletions)
|
|
v1.POST("/images/generations", s.handleImageGenerations)
|
|
v1.GET("/models", s.handleListModels)
|
|
v1.GET("/responses", s.handleListResponses)
|
|
}
|
|
|
|
// Dashboard API Group
|
|
api := s.router.Group("/api")
|
|
{
|
|
api.POST("/auth/login", s.handleLogin)
|
|
api.GET("/auth/status", s.handleAuthStatus)
|
|
api.POST("/auth/logout", s.handleLogout)
|
|
api.POST("/auth/change-password", s.handleChangePassword)
|
|
|
|
// Protected dashboard routes (need admin session)
|
|
admin := api.Group("/")
|
|
admin.Use(s.adminAuthMiddleware())
|
|
{
|
|
admin.GET("/usage/summary", s.handleUsageSummary)
|
|
admin.GET("/usage/time-series", s.handleTimeSeries)
|
|
admin.GET("/usage/providers", s.handleProvidersUsage)
|
|
admin.GET("/usage/clients", s.handleClientsUsage)
|
|
admin.GET("/usage/detailed", s.handleDetailedUsage)
|
|
admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)
|
|
|
|
admin.GET("/clients", s.handleGetClients)
|
|
admin.POST("/clients", s.handleCreateClient)
|
|
admin.GET("/clients/:id", s.handleGetClient)
|
|
admin.PUT("/clients/:id", s.handleUpdateClient)
|
|
admin.DELETE("/clients/:id", s.handleDeleteClient)
|
|
|
|
admin.GET("/clients/:id/tokens", s.handleGetClientTokens)
|
|
admin.POST("/clients/:id/tokens", s.handleCreateClientToken)
|
|
admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)
|
|
|
|
admin.GET("/providers", s.handleGetProviders)
|
|
admin.PUT("/providers/:name", s.handleUpdateProvider)
|
|
admin.POST("/providers/:name/test", s.handleTestProvider)
|
|
|
|
admin.GET("/models", s.handleGetModels)
|
|
admin.PUT("/models/:id", s.handleUpdateModel)
|
|
|
|
admin.GET("/users", s.handleGetUsers)
|
|
admin.POST("/users", s.handleCreateUser)
|
|
admin.PUT("/users/:id", s.handleUpdateUser)
|
|
admin.DELETE("/users/:id", s.handleDeleteUser)
|
|
|
|
admin.GET("/system/health", s.handleSystemHealth)
|
|
admin.GET("/system/metrics", s.handleSystemMetrics)
|
|
admin.GET("/system/settings", s.handleGetSettings)
|
|
admin.POST("/system/backup", s.handleCreateBackup)
|
|
admin.GET("/system/logs", s.handleGetLogs)
|
|
}
|
|
}
|
|
|
|
s.router.GET("/health", func(c *gin.Context) {
|
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
|
})
|
|
}
|
|
|
|
func (s *Server) handleListResponses(c *gin.Context) {
|
|
// This is a placeholder for the /v1/responses endpoint
|
|
c.JSON(http.StatusOK, gin.H{"data": []interface{}{}})
|
|
}
|
|
|
|
func (s *Server) handleListModels(c *gin.Context) {
|
|
type OpenAIModel struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
OwnedBy string `json:"owned_by"`
|
|
}
|
|
|
|
modelMap := make(map[string]OpenAIModel)
|
|
allowedProviders := map[string]bool{
|
|
"openai": true,
|
|
"google": true, // Models from models.dev use 'google' ID for Gemini
|
|
"deepseek": true,
|
|
"moonshot": true,
|
|
"moonshotai": true, // Official moonshotai ID in models.dev
|
|
"moonshotai-cn": true, // Official moonshotai-cn ID in models.dev
|
|
"xai": true, // Models from models.dev use 'xai' ID for Grok
|
|
"llmgateway": true, // Catch-all for newer models
|
|
"ollama": true,
|
|
}
|
|
|
|
s.registryMu.RLock()
|
|
if s.registry != nil {
|
|
for pID, pInfo := range s.registry.Providers {
|
|
if !allowedProviders[pID] {
|
|
continue
|
|
}
|
|
for mID := range pInfo.Models {
|
|
if _, exists := modelMap[mID]; !exists {
|
|
modelMap[mID] = OpenAIModel{
|
|
ID: mID,
|
|
Object: "model",
|
|
Created: 1700000000,
|
|
OwnedBy: pID,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
s.registryMu.RUnlock()
|
|
|
|
// Add configured Ollama models
|
|
if s.cfg.Providers.Ollama.Enabled {
|
|
for _, mID := range s.cfg.Providers.Ollama.Models {
|
|
if _, exists := modelMap[mID]; !exists {
|
|
modelMap[mID] = OpenAIModel{
|
|
ID: mID,
|
|
Object: "model",
|
|
Created: 1700000000,
|
|
OwnedBy: "ollama",
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
var data []OpenAIModel
|
|
for _, m := range modelMap {
|
|
data = append(data, m)
|
|
}
|
|
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"object": "list",
|
|
"data": data,
|
|
})
|
|
}
|
|
|
|
func (s *Server) handleChatCompletions(c *gin.Context) {
|
|
startTime := time.Now()
|
|
var req models.ChatCompletionRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Select provider based on model name
|
|
providerName := "openai" // default
|
|
modelLower := strings.ToLower(req.Model)
|
|
if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
|
|
providerName = "gemini"
|
|
} else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) {
|
|
// Only use deepseek provider if it's not explicitly tagged for ollama
|
|
providerName = "deepseek"
|
|
} else if strings.HasPrefix(modelLower, "moonshot/") || strings.Contains(modelLower, "kimi") || strings.Contains(modelLower, "moonshot") {
|
|
providerName = "moonshot"
|
|
} else if strings.HasPrefix(modelLower, "grok/") || strings.Contains(modelLower, "grok") {
|
|
providerName = "grok"
|
|
} else if strings.HasPrefix(modelLower, "ollama/") ||
|
|
strings.Contains(modelLower, "glm-") ||
|
|
strings.Contains(modelLower, "qwen") ||
|
|
strings.Contains(modelLower, "gemma") ||
|
|
strings.Contains(modelLower, "llama") ||
|
|
strings.Contains(modelLower, "mistral") ||
|
|
strings.Contains(modelLower, "phi") ||
|
|
strings.Contains(modelLower, "yi") ||
|
|
strings.Contains(modelLower, "codellama") ||
|
|
strings.Contains(modelLower, "command-r") {
|
|
providerName = "ollama"
|
|
}
|
|
|
|
provider, ok := s.providers[providerName]
|
|
if !ok {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
|
|
return
|
|
}
|
|
|
|
// Strip common prefixes
|
|
modelID := req.Model
|
|
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/"}
|
|
for _, p := range prefixes {
|
|
if strings.HasPrefix(modelID, p) {
|
|
modelID = strings.TrimPrefix(modelID, p)
|
|
break
|
|
}
|
|
}
|
|
|
|
// Convert ChatCompletionRequest to UnifiedRequest
|
|
unifiedReq := &models.UnifiedRequest{
|
|
Model: modelID,
|
|
Messages: []models.UnifiedMessage{},
|
|
Temperature: req.Temperature,
|
|
TopP: req.TopP,
|
|
TopK: req.TopK,
|
|
N: req.N,
|
|
MaxTokens: req.MaxTokens,
|
|
PresencePenalty: req.PresencePenalty,
|
|
FrequencyPenalty: req.FrequencyPenalty,
|
|
Stream: req.Stream != nil && *req.Stream,
|
|
Tools: req.Tools,
|
|
ToolChoice: req.ToolChoice,
|
|
}
|
|
|
|
// Inject max_tokens from model registry when client doesn't specify one.
|
|
// Prevents providers from applying a low default output cap.
|
|
// DEBUG: Trace max_tokens through the proxy
|
|
clientMaxTokens := "nil"
|
|
if unifiedReq.MaxTokens != nil {
|
|
clientMaxTokens = fmt.Sprintf("%d", *unifiedReq.MaxTokens)
|
|
}
|
|
log.Printf("[DEBUG] %s: client max_tokens=%s", modelID, clientMaxTokens)
|
|
if unifiedReq.MaxTokens == nil {
|
|
s.registryMu.RLock()
|
|
meta := s.registry.FindModel(modelID)
|
|
s.registryMu.RUnlock()
|
|
if meta != nil && meta.Limit != nil && meta.Limit.Output > 0 {
|
|
unifiedReq.MaxTokens = &meta.Limit.Output
|
|
log.Printf("[DEBUG] %s: injected registry max_tokens=%d", modelID, meta.Limit.Output)
|
|
} else {
|
|
log.Printf("[DEBUG] %s: no registry limit found, leaving max_tokens nil (provider default)", modelID)
|
|
}
|
|
} else {
|
|
log.Printf("[DEBUG] %s: using client's max_tokens=%d", modelID, *unifiedReq.MaxTokens)
|
|
}
|
|
|
|
// Handle Stop sequences
|
|
if req.Stop != nil {
|
|
var stop []string
|
|
if err := json.Unmarshal(req.Stop, &stop); err == nil {
|
|
unifiedReq.Stop = stop
|
|
} else {
|
|
var singleStop string
|
|
if err := json.Unmarshal(req.Stop, &singleStop); err == nil {
|
|
unifiedReq.Stop = []string{singleStop}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Convert messages
|
|
for _, msg := range req.Messages {
|
|
unifiedMsg := models.UnifiedMessage{
|
|
Role: msg.Role,
|
|
Content: []models.UnifiedContentPart{},
|
|
ReasoningContent: msg.ReasoningContent,
|
|
ToolCalls: msg.ToolCalls,
|
|
Name: msg.Name,
|
|
ToolCallID: msg.ToolCallID,
|
|
}
|
|
|
|
// Handle multimodal content
|
|
if strContent, ok := msg.Content.(string); ok {
|
|
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
|
Type: "text",
|
|
Text: strContent,
|
|
})
|
|
} else if parts, ok := msg.Content.([]interface{}); ok {
|
|
for _, part := range parts {
|
|
if partMap, ok := part.(map[string]interface{}); ok {
|
|
partType, _ := partMap["type"].(string)
|
|
if partType == "text" {
|
|
text, _ := partMap["text"].(string)
|
|
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
|
Type: "text",
|
|
Text: text,
|
|
})
|
|
} else if partType == "image_url" {
|
|
if imgURLMap, ok := partMap["image_url"].(map[string]interface{}); ok {
|
|
url, _ := imgURLMap["url"].(string)
|
|
imageInput := &models.ImageInput{}
|
|
if strings.HasPrefix(url, "data:") {
|
|
mime, data, err := utils.ParseDataURL(url)
|
|
if err == nil {
|
|
imageInput.Base64 = data
|
|
imageInput.MimeType = mime
|
|
}
|
|
} else {
|
|
imageInput.URL = url
|
|
}
|
|
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
|
|
Type: "image",
|
|
Image: imageInput,
|
|
})
|
|
unifiedReq.HasImages = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
unifiedReq.Messages = append(unifiedReq.Messages, unifiedMsg)
|
|
}
|
|
|
|
clientID := "default"
|
|
if auth, ok := c.Get("auth"); ok {
|
|
if authInfo, ok := auth.(models.AuthInfo); ok {
|
|
unifiedReq.ClientID = authInfo.ClientID
|
|
clientID = authInfo.ClientID
|
|
}
|
|
} else {
|
|
unifiedReq.ClientID = clientID
|
|
}
|
|
|
|
if unifiedReq.Stream {
|
|
ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
c.Header("Content-Type", "text/event-stream")
|
|
c.Header("Cache-Control", "no-cache")
|
|
c.Header("Connection", "keep-alive")
|
|
|
|
var lastUsage *models.Usage
|
|
c.Stream(func(w io.Writer) bool {
|
|
chunk, ok := <-ch
|
|
if !ok {
|
|
fmt.Fprintf(w, "data: [DONE]\n\n")
|
|
s.logRequest(startTime, clientID, providerName, req.Model, lastUsage, nil, unifiedReq.HasImages)
|
|
return false
|
|
}
|
|
if chunk.Usage != nil {
|
|
lastUsage = chunk.Usage
|
|
}
|
|
data, err := json.Marshal(chunk)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
return true
|
|
})
|
|
return
|
|
}
|
|
|
|
resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage, nil, unifiedReq.HasImages)
|
|
c.JSON(http.StatusOK, resp)
|
|
}
|
|
|
|
func (s *Server) handleImageGenerations(c *gin.Context) {
|
|
startTime := time.Now()
|
|
var req models.ImageGenerationRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Determine provider based on model name
|
|
providerName := "openai"
|
|
modelLower := strings.ToLower(req.Model)
|
|
switch {
|
|
case strings.Contains(modelLower, "imagen"), strings.Contains(modelLower, "gemini"):
|
|
providerName = "gemini"
|
|
case strings.Contains(modelLower, "dall"), strings.HasPrefix(modelLower, "openai/"):
|
|
providerName = "openai"
|
|
}
|
|
|
|
// Default model for each provider if not specified
|
|
if req.Model == "" {
|
|
if providerName == "openai" {
|
|
req.Model = "dall-e-3"
|
|
} else {
|
|
req.Model = "imagen-3.0-generate-001"
|
|
}
|
|
}
|
|
|
|
// Strip common prefixes
|
|
prefixes := []string{"openai/", "gemini/", "google/"}
|
|
for _, p := range prefixes {
|
|
if strings.HasPrefix(req.Model, p) {
|
|
req.Model = strings.TrimPrefix(req.Model, p)
|
|
break
|
|
}
|
|
}
|
|
|
|
provider, ok := s.providers[providerName]
|
|
if !ok {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
|
|
return
|
|
}
|
|
|
|
clientID := "default"
|
|
if auth, ok := c.Get("auth"); ok {
|
|
if authInfo, ok := auth.(models.AuthInfo); ok {
|
|
clientID = authInfo.ClientID
|
|
}
|
|
}
|
|
|
|
resp, err := provider.ImageGeneration(c.Request.Context(), &req)
|
|
if err != nil {
|
|
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, false)
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
|
|
// Estimate tokens from prompt text (~4 chars per token)
|
|
promptTokens := uint32(len(req.Prompt) / 4)
|
|
if promptTokens < 1 {
|
|
promptTokens = 1
|
|
}
|
|
|
|
// Calculate per-image cost (not per-token like chat)
|
|
cost := imageGenCost(providerName, req.Model, req.Size, uint32(len(resp.Data)))
|
|
|
|
s.logRequest(startTime, clientID, providerName, req.Model, &models.Usage{
|
|
PromptTokens: promptTokens,
|
|
CompletionTokens: uint32(len(resp.Data)),
|
|
TotalTokens: promptTokens + uint32(len(resp.Data)),
|
|
}, nil, false)
|
|
|
|
// Update cost in DB — image gen is per-image, not per-token
|
|
if cost > 0 {
|
|
s.database.Exec("UPDATE llm_requests SET cost = ? WHERE id = (SELECT MAX(id) FROM llm_requests)", cost)
|
|
}
|
|
|
|
c.JSON(http.StatusOK, resp)
|
|
}
|
|
|
|
// imageGenCost reports the total price for n generated images.
//
// The unit price is picked from the model name (case-insensitive substring
// match): "dall-e-3" is 0.040, or 0.080 for the 1024x1792 / 1792x1024 sizes;
// "dall-e-2" is 0.020; "imagen" is 0.040 (approximate). Unknown models and
// n == 0 yield 0. The provider argument is not consulted.
func imageGenCost(provider, model string, size *string, n uint32) float64 {
	if n == 0 {
		return 0
	}

	name := strings.ToLower(model)
	unitPrice := 0.0

	if strings.Contains(name, "dall-e-3") {
		// Standard square output unless a non-square size was requested.
		unitPrice = 0.040
		if size != nil && (*size == "1024x1792" || *size == "1792x1024") {
			unitPrice = 0.080
		}
	} else if strings.Contains(name, "dall-e-2") {
		unitPrice = 0.020
	} else if strings.Contains(name, "imagen") {
		unitPrice = 0.040 // approximate
	} else {
		return 0
	}

	return unitPrice * float64(n)
}
|
|
|
|
func (s *Server) logRequest(start time.Time, clientID, provider, model string, usage *models.Usage, err error, hasImages bool) {
|
|
entry := RequestLog{
|
|
Timestamp: start,
|
|
ClientID: clientID,
|
|
Provider: provider,
|
|
Model: model,
|
|
Status: "success",
|
|
DurationMS: time.Since(start).Milliseconds(),
|
|
HasImages: hasImages,
|
|
}
|
|
|
|
if err != nil {
|
|
entry.Status = "error"
|
|
entry.ErrorMessage = err.Error()
|
|
}
|
|
|
|
if usage != nil {
|
|
entry.PromptTokens = usage.PromptTokens
|
|
entry.CompletionTokens = usage.CompletionTokens
|
|
entry.TotalTokens = usage.TotalTokens
|
|
if usage.ReasoningTokens != nil {
|
|
entry.ReasoningTokens = *usage.ReasoningTokens
|
|
}
|
|
if usage.CacheReadTokens != nil {
|
|
entry.CacheReadTokens = *usage.CacheReadTokens
|
|
}
|
|
if usage.CacheWriteTokens != nil {
|
|
entry.CacheWriteTokens = *usage.CacheWriteTokens
|
|
}
|
|
|
|
// Calculate cost using registry
|
|
s.registryMu.RLock()
|
|
entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
|
|
s.registryMu.RUnlock()
|
|
}
|
|
|
|
s.logger.LogRequest(entry)
|
|
}
|
|
|
|
func (s *Server) Run() error {
|
|
go s.hub.Run()
|
|
s.logger.Start()
|
|
|
|
// Start registry refresher
|
|
go func() {
|
|
ticker := time.NewTicker(24 * time.Hour)
|
|
for range ticker.C {
|
|
newRegistry, err := utils.FetchRegistry()
|
|
if err == nil {
|
|
s.registryMu.Lock()
|
|
s.registry = newRegistry
|
|
s.registryMu.Unlock()
|
|
}
|
|
}
|
|
}()
|
|
|
|
addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
|
|
return s.router.Run(addr)
|
|
}
|