From 40f055cb5722ca076374a3090800e387c6874601 Mon Sep 17 00:00:00 2001 From: newkirk Date: Sun, 17 May 2026 19:48:47 -0400 Subject: [PATCH] fix: correct deepseek pricing, gemini streaming tokens, and group-name logging - Add promo discount system for deepseek-v4-pro (75% off until 2026-05-31) - Rewrite StreamGemini to handle both SSE and JSON array response formats, fixing 0-token logging for gemini-3-flash and gemini-3-flash-preview - Fall back to model group name for cost lookup when concrete model isn't in the registry (fixes $0 cost on deepseek-auto entries) - Move registry lock before FindModel call to fix data race --- internal/providers/gemini.go | 14 +- internal/providers/helpers.go | 257 ++++++++++++++++++++++++---------- internal/server/server.go | 60 +++++--- internal/utils/registry.go | 25 ++++ 4 files changed, 253 insertions(+), 103 deletions(-) diff --git a/internal/providers/gemini.go b/internal/providers/gemini.go index a2f9ab5b..1f77c3f0 100644 --- a/internal/providers/gemini.go +++ b/internal/providers/gemini.go @@ -621,16 +621,10 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), msg) } - ch := make(chan *models.ChatCompletionStreamResponse) - - go func() { - defer close(ch) - err := StreamGemini(resp.RawBody(), ch, req.Model) - if err != nil { - fmt.Printf("Gemini Stream error: %v\n", err) - } - }() - + ch, err := StreamGemini(resp.RawBody(), req.Model) + if err != nil { + return nil, fmt.Errorf("gemini stream init error: %w", err) + } return ch, nil } diff --git a/internal/providers/helpers.go b/internal/providers/helpers.go index d209347e..0bd988ea 100644 --- a/internal/providers/helpers.go +++ b/internal/providers/helpers.go @@ -364,89 +364,202 @@ func StreamOpenAI(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo return scanner.Err() } -func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model 
string) error { - defer ctx.Close() +// geminiStreamChunk is the shared data structure for parsing Gemini streaming responses. +type geminiStreamChunk struct { + Candidates []struct { + Content struct { + Parts []struct { + Text string `json:"text,omitempty"` + Thought string `json:"thought,omitempty"` + } `json:"parts"` + } `json:"content"` + FinishReason string `json:"finishReason"` + } `json:"candidates"` + UsageMetadata struct { + PromptTokenCount uint32 `json:"promptTokenCount"` + CandidatesTokenCount uint32 `json:"candidatesTokenCount"` + TotalTokenCount uint32 `json:"totalTokenCount"` + CachedContentTokenCount uint32 `json:"cachedContentTokenCount"` + } `json:"usageMetadata"` +} - dec := json.NewDecoder(ctx) - - t, err := dec.Token() - if err != nil { - return err +// emitGeminiChunk builds a ChatCompletionStreamResponse from a parsed geminiStreamChunk +// and sends it to the channel. Returns true if anything was emitted. +func emitGeminiChunk(ch chan<- *models.ChatCompletionStreamResponse, chunk *geminiStreamChunk, model string) bool { + if len(chunk.Candidates) == 0 && chunk.UsageMetadata.TotalTokenCount == 0 { + return false } - if delim, ok := t.(json.Delim); ok && delim == '[' { - for dec.More() { - var geminiChunk struct { - Candidates []struct { - Content struct { - Parts []struct { - Text string `json:"text,omitempty"` - Thought string `json:"thought,omitempty"` - } `json:"parts"` - } `json:"content"` - FinishReason string `json:"finishReason"` - } `json:"candidates"` - UsageMetadata struct { - PromptTokenCount uint32 `json:"promptTokenCount"` - CandidatesTokenCount uint32 `json:"candidatesTokenCount"` - TotalTokenCount uint32 `json:"totalTokenCount"` - CachedContentTokenCount uint32 `json:"cachedContentTokenCount"` - } `json:"usageMetadata"` + + content := "" + var reasoning *string + var finishReason *string + if len(chunk.Candidates) > 0 { + for _, p := range chunk.Candidates[0].Content.Parts { + if p.Text != "" { + content += p.Text } - - if err 
:= dec.Decode(&geminiChunk); err != nil { - return err + if p.Thought != "" { + if reasoning == nil { + reasoning = new(string) + } + *reasoning += p.Thought } + } + fr := strings.ToLower(chunk.Candidates[0].FinishReason) + finishReason = &fr + } - if len(geminiChunk.Candidates) > 0 || geminiChunk.UsageMetadata.TotalTokenCount > 0 { - content := "" - var reasoning *string - if len(geminiChunk.Candidates) > 0 { - for _, p := range geminiChunk.Candidates[0].Content.Parts { - if p.Text != "" { - content += p.Text - } - if p.Thought != "" { - if reasoning == nil { - reasoning = new(string) - } - *reasoning += p.Thought - } - } - } + ch <- &models.ChatCompletionStreamResponse{ + ID: "gemini-stream", + Object: "chat.completion.chunk", + Created: 0, + Model: model, + Choices: []models.ChatStreamChoice{ + { + Index: 0, + Delta: models.ChatStreamDelta{ + Content: &content, + ReasoningContent: reasoning, + }, + FinishReason: finishReason, + }, + }, + Usage: &models.Usage{ + PromptTokens: chunk.UsageMetadata.PromptTokenCount, + CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount, + TotalTokens: chunk.UsageMetadata.TotalTokenCount, + CacheReadTokens: uint32Ptr(chunk.UsageMetadata.CachedContentTokenCount), + }, + } + return true +} - var finishReason *string - if len(geminiChunk.Candidates) > 0 { - fr := strings.ToLower(geminiChunk.Candidates[0].FinishReason) - finishReason = &fr - } +// StreamGemini handles Gemini streaming responses in two formats: +// 1. SSE format (newer models): each line is "data: {...}" +// 2. JSON array format (older models): response body is [ {...}, {...} ] +// +// Usage metadata is only present in the final chunk, which we accumulate +// and emit so the server can log it on stream end. 
+func StreamGemini(ctx io.ReadCloser, model string) (<-chan *models.ChatCompletionStreamResponse, error) { + ch := make(chan *models.ChatCompletionStreamResponse) - ch <- &models.ChatCompletionStreamResponse{ - ID: "gemini-stream", - Object: "chat.completion.chunk", - Created: 0, - Model: model, - Choices: []models.ChatStreamChoice{ - { - Index: 0, - Delta: models.ChatStreamDelta{ - Content: &content, - ReasoningContent: reasoning, - }, - FinishReason: finishReason, - }, - }, - Usage: &models.Usage{ - PromptTokens: geminiChunk.UsageMetadata.PromptTokenCount, - CompletionTokens: geminiChunk.UsageMetadata.CandidatesTokenCount, - TotalTokens: geminiChunk.UsageMetadata.TotalTokenCount, - CacheReadTokens: uint32Ptr(geminiChunk.UsageMetadata.CachedContentTokenCount), - }, - } + go func() { + defer func() { + _ = ctx.Close() + }() + defer close(ch) + + // Peek at the first byte to detect format + peek := make([]byte, 6) + n, _ := io.ReadAtLeast(ctx, peek, 1) + if n == 0 { + return + } + + first := string(peek[:n]) + + if first[0] == '[' { + // JSON array format + rest, _ := io.ReadAll(ctx) + streamGeminiJSONArray(append([]byte(first), rest...), ch, model) + return + } else if strings.HasPrefix(first, "data:") || strings.HasPrefix(first, "data: ") { + // SSE format — pre-pend the peeked bytes then run SSE scanner + combined := io.MultiReader( + strings.NewReader(string(peek[:n])), + ctx, + ) + streamGeminiSSE(combined, ch, model) + } else { + // Unknown format — might still be SSE starting after a peek char + // Pre-pend peeked bytes and try SSE + combined := io.MultiReader( + strings.NewReader(string(peek[:n])), + ctx, + ) + streamGeminiSSE(combined, ch, model) + } + }() + + return ch, nil +} + +// readAll reads remaining bytes from a reader (keeps the function signature simple +// for the JSON array fallback path). 
+func readAll(r io.Reader) []byte { + b, _ := io.ReadAll(r) + return b +} + +func streamGeminiJSONArray(data []byte, ch chan<- *models.ChatCompletionStreamResponse, model string) { + var chunks []geminiStreamChunk + if err := json.Unmarshal(data, &chunks); err != nil { + fmt.Printf("[Gemini-Stream] JSON array parse error: %v\n", err) + return + } + // Track the last chunk with usage for the final emission + var lastUsage *geminiStreamChunk + for i := range chunks { + if chunks[i].UsageMetadata.TotalTokenCount > 0 { + lastUsage = &chunks[i] + } + } + if lastUsage != nil { + // Emit a synthetic final chunk with usage data + if len(lastUsage.Candidates) == 0 && lastUsage.UsageMetadata.TotalTokenCount > 0 { + emitGeminiChunk(ch, lastUsage, model) + } + } + // Also emit each content-bearing chunk + for i := range chunks { + emitGeminiChunk(ch, &chunks[i], model) + } +} + +func streamGeminiSSE(r io.Reader, ch chan<- *models.ChatCompletionStreamResponse, model string) { + scanner := bufio.NewScanner(r) + // Track the last seen usage for emission at end of stream + var lastUsage geminiStreamChunk + + for scanner.Scan() { + line := scanner.Text() + if line == "" || !strings.HasPrefix(line, "data: ") { + continue + } + + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" { + // Emit final usage if we have one + if lastUsage.UsageMetadata.TotalTokenCount > 0 { + emitGeminiChunk(ch, &lastUsage, model) } + break + } + + var chunk geminiStreamChunk + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + continue + } + + // Capture usage from any chunk (Gemini puts it in the final response) + if chunk.UsageMetadata.TotalTokenCount > 0 { + lastUsage = chunk + } + + // Emit content chunks as they arrive + if len(chunk.Candidates) > 0 { + emitGeminiChunk(ch, &chunk, model) } } - return nil + // If stream ended without [DONE] marker but we collected usage, emit it + if lastUsage.UsageMetadata.TotalTokenCount > 0 { + emitGeminiChunk(ch, &lastUsage, model) + } + 
+ if err := scanner.Err(); err != nil { + fmt.Printf("[Gemini-Stream] SSE scan error: %v\n", err) + } } // ensureEnglish injects a system message instructing the model to respond in diff --git a/internal/server/server.go b/internal/server/server.go index 652cc0cd..5e135835 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -309,9 +309,32 @@ func (s *Server) handleResponses(c *gin.Context) { return } - // Select provider based on model name + // Strip common prefixes and resolve model groups to concrete models + // (same pattern as handleChatCompletions). + modelGroup := "" + modelID := req.Model + prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/"} + for _, p := range prefixes { + if strings.HasPrefix(modelID, p) { + modelID = strings.TrimPrefix(modelID, p) + break + } + } + if s.modelRouter != nil { + decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, "") + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)}) + return + } + if decision.SelectedModel != modelID { + modelGroup = modelID + } + modelID = decision.SelectedModel + } + + // Select provider based on resolved model name providerName := "openai" // default for Responses API - modelLower := strings.ToLower(req.Model) + modelLower := strings.ToLower(modelID) if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") { providerName = "gemini" } else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) { @@ -339,17 +362,7 @@ func (s *Server) handleResponses(c *gin.Context) { return } - // Strip common prefixes from model name - modelID := req.Model - prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/"} - for _, p := range prefixes { - if 
strings.HasPrefix(modelID, p) { - modelID = strings.TrimPrefix(modelID, p) - break - } - } - - // Use the stripped model name for the actual API call + // Use resolved model for the actual API call req.Model = modelID clientID := "default" @@ -364,7 +377,7 @@ func (s *Server) handleResponses(c *gin.Context) { if stream { ch, err := provider.ResponsesStream(c.Request.Context(), &req) if err != nil { - s.logRequest(startTime, clientID, providerName, req.Model, "", nil, err, false) + s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -379,9 +392,9 @@ func (s *Server) handleResponses(c *gin.Context) { if !ok { fmt.Fprintf(w, "data: [DONE]\n\n") if lastUsage != nil { - s.logRequest(startTime, clientID, providerName, req.Model, "", lastUsage.ToUsage(), nil, false) + s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage.ToUsage(), nil, false) } else { - s.logRequest(startTime, clientID, providerName, req.Model, "", nil, nil, false) + s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, nil, false) } return false } @@ -401,15 +414,15 @@ func (s *Server) handleResponses(c *gin.Context) { resp, err := provider.Responses(c.Request.Context(), &req) if err != nil { - s.logRequest(startTime, clientID, providerName, req.Model, "", nil, err, false) + s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } if resp.Usage != nil { - s.logRequest(startTime, clientID, providerName, req.Model, "", resp.Usage.ToUsage(), nil, false) + s.logRequest(startTime, clientID, providerName, modelID, modelGroup, resp.Usage.ToUsage(), nil, false) } else { - s.logRequest(startTime, clientID, providerName, req.Model, "", nil, nil, false) + s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, nil, false) } 
c.JSON(http.StatusOK, resp) } @@ -881,9 +894,14 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model, modelGro entry.CacheWriteTokens = *usage.CacheWriteTokens } - // Calculate cost using registry + // Calculate cost using registry; if the resolved model is unknown, + // fall back to the model group so group requests still get priced. s.registryMu.RLock() - entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens) + pricingModel := model + if s.registry != nil && s.registry.FindModel(pricingModel) == nil && modelGroup != "" { + pricingModel = modelGroup + } + entry.Cost = utils.CalculateCost(s.registry, pricingModel, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens) s.registryMu.RUnlock() } diff --git a/internal/utils/registry.go b/internal/utils/registry.go index b82201bd..e20d67e0 100644 --- a/internal/utils/registry.go +++ b/internal/utils/registry.go @@ -45,6 +45,24 @@ func FetchRegistry() (*models.ModelRegistry, error) { return nil, fmt.Errorf("failed to fetch registry after 3 attempts: %w", lastErr) } +// promoDiscount describes a temporary pricing discount applied on top of +// the standard (list) price from the model registry. +type promoDiscount struct { + Factor float64 // multiplier applied after standard calculation (0.25 = 75% off) + ExpiresAt time.Time // discount ends at this time (UTC) +} + +// promoDiscounts maps model IDs to active promotional discounts. 
+// Sources: +// - DeepSeek v4 Pro: 75% off list pricing until 2026-05-31 +// https://api-docs.deepseek.com/quick_start/pricing +var promoDiscounts = map[string]promoDiscount{ + "deepseek-v4-pro": { + Factor: 0.25, + ExpiresAt: time.Date(2026, 5, 31, 23, 59, 59, 0, time.UTC), + }, +} + func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens, completionTokens, reasoningTokens, cacheRead, cacheWrite uint32) float64 { meta := registry.FindModel(modelID) if meta == nil || meta.Cost == nil { @@ -72,5 +90,12 @@ func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens, cost += float64(cacheWrite) * (*meta.Cost.CacheWrite) / 1000000.0 } + // Apply promotional discounts (e.g. DeepSeek 75% off until 2026-05-31). + if discount, ok := promoDiscounts[modelID]; ok { + if time.Now().UTC().Before(discount.ExpiresAt) { + cost *= discount.Factor + } + } + return cost }