fix: correct deepseek pricing, gemini streaming tokens, and group-name logging
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled

- Add promo discount system for deepseek-v4-pro (75% off until 2026-05-31)
- Rewrite StreamGemini to handle both SSE and JSON array response formats,
  fixing 0-token logging for gemini-3-flash and gemini-3-flash-preview
- Fall back to model group name for cost lookup when concrete model
  isnt in the registry (fixes $0 cost on deepseek-auto entries)
- Move registry lock before FindModel call to fix data race
This commit is contained in:
newkirk
2026-05-17 19:48:47 -04:00
parent 970e778703
commit 40f055cb57
4 changed files with 253 additions and 103 deletions
+2 -8
View File
@@ -621,16 +621,10 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), msg)
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
err := StreamGemini(resp.RawBody(), ch, req.Model)
ch, err := StreamGemini(resp.RawBody(), req.Model)
if err != nil {
fmt.Printf("Gemini Stream error: %v\n", err)
return nil, fmt.Errorf("gemini stream init error: %w", err)
}
}()
return ch, nil
}
+140 -27
View File
@@ -364,18 +364,8 @@ func StreamOpenAI(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
return scanner.Err()
}
func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error {
defer ctx.Close()
dec := json.NewDecoder(ctx)
t, err := dec.Token()
if err != nil {
return err
}
if delim, ok := t.(json.Delim); ok && delim == '[' {
for dec.More() {
var geminiChunk struct {
// geminiStreamChunk is the shared data structure for parsing Gemini streaming responses.
type geminiStreamChunk struct {
Candidates []struct {
Content struct {
Parts []struct {
@@ -393,15 +383,18 @@ func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
} `json:"usageMetadata"`
}
if err := dec.Decode(&geminiChunk); err != nil {
return err
// emitGeminiChunk builds a ChatCompletionStreamResponse from a parsed geminiStreamChunk
// and sends it to the channel. Returns true if anything was emitted.
func emitGeminiChunk(ch chan<- *models.ChatCompletionStreamResponse, chunk *geminiStreamChunk, model string) bool {
if len(chunk.Candidates) == 0 && chunk.UsageMetadata.TotalTokenCount == 0 {
return false
}
if len(geminiChunk.Candidates) > 0 || geminiChunk.UsageMetadata.TotalTokenCount > 0 {
content := ""
var reasoning *string
if len(geminiChunk.Candidates) > 0 {
for _, p := range geminiChunk.Candidates[0].Content.Parts {
var finishReason *string
if len(chunk.Candidates) > 0 {
for _, p := range chunk.Candidates[0].Content.Parts {
if p.Text != "" {
content += p.Text
}
@@ -412,11 +405,7 @@ func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
*reasoning += p.Thought
}
}
}
var finishReason *string
if len(geminiChunk.Candidates) > 0 {
fr := strings.ToLower(geminiChunk.Candidates[0].FinishReason)
fr := strings.ToLower(chunk.Candidates[0].FinishReason)
finishReason = &fr
}
@@ -436,17 +425,141 @@ func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
},
},
Usage: &models.Usage{
PromptTokens: geminiChunk.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiChunk.UsageMetadata.CandidatesTokenCount,
TotalTokens: geminiChunk.UsageMetadata.TotalTokenCount,
CacheReadTokens: uint32Ptr(geminiChunk.UsageMetadata.CachedContentTokenCount),
PromptTokens: chunk.UsageMetadata.PromptTokenCount,
CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount,
TotalTokens: chunk.UsageMetadata.TotalTokenCount,
CacheReadTokens: uint32Ptr(chunk.UsageMetadata.CachedContentTokenCount),
},
}
return true
}
// StreamGemini handles Gemini streaming responses in two formats:
// 1. SSE format (newer models): each line is "data: {...}"
// 2. JSON array format (older models): response body is [ {...}, {...} ]
//
// Usage metadata is only present in the final chunk, which we accumulate
// and emit so the server can log it on stream end.
func StreamGemini(ctx io.ReadCloser, model string) (<-chan *models.ChatCompletionStreamResponse, error) {
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer func() {
_ = ctx.Close()
}()
defer close(ch)
// Peek at the first byte to detect format
peek := make([]byte, 6)
n, _ := io.ReadAtLeast(ctx, peek, 1)
if n == 0 {
return
}
first := string(peek[:n])
if first[0] == '[' {
// JSON array format
rest, _ := io.ReadAll(ctx)
streamGeminiJSONArray(append([]byte(first), rest...), ch, model)
return
} else if strings.HasPrefix(first, "data:") || strings.HasPrefix(first, "data: ") {
// SSE format — pre-pend the peeked bytes then run SSE scanner
combined := io.MultiReader(
strings.NewReader(string(peek[:n])),
ctx,
)
streamGeminiSSE(combined, ch, model)
} else {
// Unknown format — might still be SSE starting after a peek char
// Pre-pend peeked bytes and try SSE
combined := io.MultiReader(
strings.NewReader(string(peek[:n])),
ctx,
)
streamGeminiSSE(combined, ch, model)
}
}()
return ch, nil
}
// readAll reads remaining bytes from a reader (keeps the function signature simple
// for the JSON array fallback path).
func readAll(r io.Reader) []byte {
b, _ := io.ReadAll(r)
return b
}
func streamGeminiJSONArray(data []byte, ch chan<- *models.ChatCompletionStreamResponse, model string) {
var chunks []geminiStreamChunk
if err := json.Unmarshal(data, &chunks); err != nil {
fmt.Printf("[Gemini-Stream] JSON array parse error: %v\n", err)
return
}
// Track the last chunk with usage for the final emission
var lastUsage *geminiStreamChunk
for i := range chunks {
if chunks[i].UsageMetadata.TotalTokenCount > 0 {
lastUsage = &chunks[i]
}
}
if lastUsage != nil {
// Emit a synthetic final chunk with usage data
if len(lastUsage.Candidates) == 0 && lastUsage.UsageMetadata.TotalTokenCount > 0 {
emitGeminiChunk(ch, lastUsage, model)
}
}
// Also emit each content-bearing chunk
for i := range chunks {
emitGeminiChunk(ch, &chunks[i], model)
}
}
return nil
func streamGeminiSSE(r io.Reader, ch chan<- *models.ChatCompletionStreamResponse, model string) {
scanner := bufio.NewScanner(r)
// Track the last seen usage for emission at end of stream
var lastUsage geminiStreamChunk
for scanner.Scan() {
line := scanner.Text()
if line == "" || !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
// Emit final usage if we have one
if lastUsage.UsageMetadata.TotalTokenCount > 0 {
emitGeminiChunk(ch, &lastUsage, model)
}
break
}
var chunk geminiStreamChunk
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
continue
}
// Capture usage from any chunk (Gemini puts it in the final response)
if chunk.UsageMetadata.TotalTokenCount > 0 {
lastUsage = chunk
}
// Emit content chunks as they arrive
if len(chunk.Candidates) > 0 {
emitGeminiChunk(ch, &chunk, model)
}
}
// If stream ended without [DONE] marker but we collected usage, emit it
if lastUsage.UsageMetadata.TotalTokenCount > 0 {
emitGeminiChunk(ch, &lastUsage, model)
}
if err := scanner.Err(); err != nil {
fmt.Printf("[Gemini-Stream] SSE scan error: %v\n", err)
}
}
// ensureEnglish injects a system message instructing the model to respond in
+39 -21
View File
@@ -309,9 +309,32 @@ func (s *Server) handleResponses(c *gin.Context) {
return
}
// Select provider based on model name
// Strip common prefixes and resolve model groups to concrete models
// (same pattern as handleChatCompletions).
modelGroup := ""
modelID := req.Model
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/"}
for _, p := range prefixes {
if strings.HasPrefix(modelID, p) {
modelID = strings.TrimPrefix(modelID, p)
break
}
}
if s.modelRouter != nil {
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, "")
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
return
}
if decision.SelectedModel != modelID {
modelGroup = modelID
}
modelID = decision.SelectedModel
}
// Select provider based on resolved model name
providerName := "openai" // default for Responses API
modelLower := strings.ToLower(req.Model)
modelLower := strings.ToLower(modelID)
if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
providerName = "gemini"
} else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) {
@@ -339,17 +362,7 @@ func (s *Server) handleResponses(c *gin.Context) {
return
}
// Strip common prefixes from model name
modelID := req.Model
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/"}
for _, p := range prefixes {
if strings.HasPrefix(modelID, p) {
modelID = strings.TrimPrefix(modelID, p)
break
}
}
// Use the stripped model name for the actual API call
// Use resolved model for the actual API call
req.Model = modelID
clientID := "default"
@@ -364,7 +377,7 @@ func (s *Server) handleResponses(c *gin.Context) {
if stream {
ch, err := provider.ResponsesStream(c.Request.Context(), &req)
if err != nil {
s.logRequest(startTime, clientID, providerName, req.Model, "", nil, err, false)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -379,9 +392,9 @@ func (s *Server) handleResponses(c *gin.Context) {
if !ok {
fmt.Fprintf(w, "data: [DONE]\n\n")
if lastUsage != nil {
s.logRequest(startTime, clientID, providerName, req.Model, "", lastUsage.ToUsage(), nil, false)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage.ToUsage(), nil, false)
} else {
s.logRequest(startTime, clientID, providerName, req.Model, "", nil, nil, false)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, nil, false)
}
return false
}
@@ -401,15 +414,15 @@ func (s *Server) handleResponses(c *gin.Context) {
resp, err := provider.Responses(c.Request.Context(), &req)
if err != nil {
s.logRequest(startTime, clientID, providerName, req.Model, "", nil, err, false)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if resp.Usage != nil {
s.logRequest(startTime, clientID, providerName, req.Model, "", resp.Usage.ToUsage(), nil, false)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, resp.Usage.ToUsage(), nil, false)
} else {
s.logRequest(startTime, clientID, providerName, req.Model, "", nil, nil, false)
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, nil, false)
}
c.JSON(http.StatusOK, resp)
}
@@ -881,9 +894,14 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model, modelGro
entry.CacheWriteTokens = *usage.CacheWriteTokens
}
// Calculate cost using registry
// Calculate cost using registry; if the resolved model is unknown,
// fall back to the model group so group requests still get priced.
s.registryMu.RLock()
entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
pricingModel := model
if s.registry != nil && s.registry.FindModel(pricingModel) == nil && modelGroup != "" {
pricingModel = modelGroup
}
entry.Cost = utils.CalculateCost(s.registry, pricingModel, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
s.registryMu.RUnlock()
}
+25
View File
@@ -45,6 +45,24 @@ func FetchRegistry() (*models.ModelRegistry, error) {
return nil, fmt.Errorf("failed to fetch registry after 3 attempts: %w", lastErr)
}
// promoDiscount describes a temporary pricing discount applied on top of
// the standard (list) price from the model registry.
type promoDiscount struct {
Factor float64 // multiplier applied after standard calculation (0.25 = 75% off)
ExpiresAt time.Time // discount ends at this time (UTC)
}
// promoDiscounts maps model IDs to active promotional discounts.
// Sources:
// - DeepSeek v4 Pro: 75% off list pricing until 2026-05-31
// https://api-docs.deepseek.com/quick_start/pricing
var promoDiscounts = map[string]promoDiscount{
"deepseek-v4-pro": {
Factor: 0.25,
ExpiresAt: time.Date(2026, 5, 31, 23, 59, 59, 0, time.UTC),
},
}
func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens, completionTokens, reasoningTokens, cacheRead, cacheWrite uint32) float64 {
meta := registry.FindModel(modelID)
if meta == nil || meta.Cost == nil {
@@ -72,5 +90,12 @@ func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens,
cost += float64(cacheWrite) * (*meta.Cost.CacheWrite) / 1000000.0
}
// Apply promotional discounts (e.g. DeepSeek 75% off until 2026-05-31).
if discount, ok := promoDiscounts[modelID]; ok {
if time.Now().UTC().Before(discount.ExpiresAt) {
cost *= discount.Factor
}
}
return cost
}