fix: correct deepseek pricing, gemini streaming tokens, and group-name logging
- Add promo discount system for deepseek-v4-pro (75% off until 2026-05-31)
- Rewrite StreamGemini to handle both SSE and JSON array response formats, fixing 0-token logging for gemini-3-flash and gemini-3-flash-preview
- Fall back to the model group name for cost lookup when the concrete model isn't in the registry (fixes $0 cost on deepseek-auto entries)
- Move the registry lock before the FindModel call to fix a data race
This commit is contained in:
@@ -621,16 +621,10 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
||||
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), msg)
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamGemini(resp.RawBody(), ch, req.Model)
|
||||
if err != nil {
|
||||
fmt.Printf("Gemini Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ch, err := StreamGemini(resp.RawBody(), req.Model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gemini stream init error: %w", err)
|
||||
}
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
|
||||
+185
-72
@@ -364,89 +364,202 @@ func StreamOpenAI(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error {
|
||||
defer ctx.Close()
|
||||
// geminiStreamChunk is the shared data structure for parsing Gemini streaming responses.
|
||||
type geminiStreamChunk struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Thought string `json:"thought,omitempty"`
|
||||
} `json:"parts"`
|
||||
} `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
} `json:"candidates"`
|
||||
UsageMetadata struct {
|
||||
PromptTokenCount uint32 `json:"promptTokenCount"`
|
||||
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
|
||||
TotalTokenCount uint32 `json:"totalTokenCount"`
|
||||
CachedContentTokenCount uint32 `json:"cachedContentTokenCount"`
|
||||
} `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
dec := json.NewDecoder(ctx)
|
||||
|
||||
t, err := dec.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
// emitGeminiChunk builds a ChatCompletionStreamResponse from a parsed geminiStreamChunk
|
||||
// and sends it to the channel. Returns true if anything was emitted.
|
||||
func emitGeminiChunk(ch chan<- *models.ChatCompletionStreamResponse, chunk *geminiStreamChunk, model string) bool {
|
||||
if len(chunk.Candidates) == 0 && chunk.UsageMetadata.TotalTokenCount == 0 {
|
||||
return false
|
||||
}
|
||||
if delim, ok := t.(json.Delim); ok && delim == '[' {
|
||||
for dec.More() {
|
||||
var geminiChunk struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Thought string `json:"thought,omitempty"`
|
||||
} `json:"parts"`
|
||||
} `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
} `json:"candidates"`
|
||||
UsageMetadata struct {
|
||||
PromptTokenCount uint32 `json:"promptTokenCount"`
|
||||
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
|
||||
TotalTokenCount uint32 `json:"totalTokenCount"`
|
||||
CachedContentTokenCount uint32 `json:"cachedContentTokenCount"`
|
||||
} `json:"usageMetadata"`
|
||||
|
||||
content := ""
|
||||
var reasoning *string
|
||||
var finishReason *string
|
||||
if len(chunk.Candidates) > 0 {
|
||||
for _, p := range chunk.Candidates[0].Content.Parts {
|
||||
if p.Text != "" {
|
||||
content += p.Text
|
||||
}
|
||||
|
||||
if err := dec.Decode(&geminiChunk); err != nil {
|
||||
return err
|
||||
if p.Thought != "" {
|
||||
if reasoning == nil {
|
||||
reasoning = new(string)
|
||||
}
|
||||
*reasoning += p.Thought
|
||||
}
|
||||
}
|
||||
fr := strings.ToLower(chunk.Candidates[0].FinishReason)
|
||||
finishReason = &fr
|
||||
}
|
||||
|
||||
if len(geminiChunk.Candidates) > 0 || geminiChunk.UsageMetadata.TotalTokenCount > 0 {
|
||||
content := ""
|
||||
var reasoning *string
|
||||
if len(geminiChunk.Candidates) > 0 {
|
||||
for _, p := range geminiChunk.Candidates[0].Content.Parts {
|
||||
if p.Text != "" {
|
||||
content += p.Text
|
||||
}
|
||||
if p.Thought != "" {
|
||||
if reasoning == nil {
|
||||
reasoning = new(string)
|
||||
}
|
||||
*reasoning += p.Thought
|
||||
}
|
||||
}
|
||||
}
|
||||
ch <- &models.ChatCompletionStreamResponse{
|
||||
ID: "gemini-stream",
|
||||
Object: "chat.completion.chunk",
|
||||
Created: 0,
|
||||
Model: model,
|
||||
Choices: []models.ChatStreamChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: models.ChatStreamDelta{
|
||||
Content: &content,
|
||||
ReasoningContent: reasoning,
|
||||
},
|
||||
FinishReason: finishReason,
|
||||
},
|
||||
},
|
||||
Usage: &models.Usage{
|
||||
PromptTokens: chunk.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: chunk.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: chunk.UsageMetadata.TotalTokenCount,
|
||||
CacheReadTokens: uint32Ptr(chunk.UsageMetadata.CachedContentTokenCount),
|
||||
},
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var finishReason *string
|
||||
if len(geminiChunk.Candidates) > 0 {
|
||||
fr := strings.ToLower(geminiChunk.Candidates[0].FinishReason)
|
||||
finishReason = &fr
|
||||
}
|
||||
// StreamGemini handles Gemini streaming responses in two formats:
|
||||
// 1. SSE format (newer models): each line is "data: {...}"
|
||||
// 2. JSON array format (older models): response body is [ {...}, {...} ]
|
||||
//
|
||||
// Usage metadata is only present in the final chunk, which we accumulate
|
||||
// and emit so the server can log it on stream end.
|
||||
func StreamGemini(ctx io.ReadCloser, model string) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
ch <- &models.ChatCompletionStreamResponse{
|
||||
ID: "gemini-stream",
|
||||
Object: "chat.completion.chunk",
|
||||
Created: 0,
|
||||
Model: model,
|
||||
Choices: []models.ChatStreamChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: models.ChatStreamDelta{
|
||||
Content: &content,
|
||||
ReasoningContent: reasoning,
|
||||
},
|
||||
FinishReason: finishReason,
|
||||
},
|
||||
},
|
||||
Usage: &models.Usage{
|
||||
PromptTokens: geminiChunk.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiChunk.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: geminiChunk.UsageMetadata.TotalTokenCount,
|
||||
CacheReadTokens: uint32Ptr(geminiChunk.UsageMetadata.CachedContentTokenCount),
|
||||
},
|
||||
}
|
||||
go func() {
|
||||
defer func() {
|
||||
_ = ctx.Close()
|
||||
}()
|
||||
defer close(ch)
|
||||
|
||||
// Peek at the first byte to detect format
|
||||
peek := make([]byte, 6)
|
||||
n, _ := io.ReadAtLeast(ctx, peek, 1)
|
||||
if n == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
first := string(peek[:n])
|
||||
|
||||
if first[0] == '[' {
|
||||
// JSON array format
|
||||
rest, _ := io.ReadAll(ctx)
|
||||
streamGeminiJSONArray(append([]byte(first), rest...), ch, model)
|
||||
return
|
||||
} else if strings.HasPrefix(first, "data:") || strings.HasPrefix(first, "data: ") {
|
||||
// SSE format — pre-pend the peeked bytes then run SSE scanner
|
||||
combined := io.MultiReader(
|
||||
strings.NewReader(string(peek[:n])),
|
||||
ctx,
|
||||
)
|
||||
streamGeminiSSE(combined, ch, model)
|
||||
} else {
|
||||
// Unknown format — might still be SSE starting after a peek char
|
||||
// Pre-pend peeked bytes and try SSE
|
||||
combined := io.MultiReader(
|
||||
strings.NewReader(string(peek[:n])),
|
||||
ctx,
|
||||
)
|
||||
streamGeminiSSE(combined, ch, model)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// readAll reads r until EOF or the first read error and returns the bytes
// obtained so far. The single-value signature keeps call sites simple for the
// JSON array fallback path; a read error is therefore logged instead of being
// returned, and the partial data is still handed back.
func readAll(r io.Reader) []byte {
	b, err := io.ReadAll(r)
	if err != nil {
		fmt.Printf("[Gemini-Stream] read error: %v\n", err)
	}
	return b
}
|
||||
|
||||
func streamGeminiJSONArray(data []byte, ch chan<- *models.ChatCompletionStreamResponse, model string) {
|
||||
var chunks []geminiStreamChunk
|
||||
if err := json.Unmarshal(data, &chunks); err != nil {
|
||||
fmt.Printf("[Gemini-Stream] JSON array parse error: %v\n", err)
|
||||
return
|
||||
}
|
||||
// Track the last chunk with usage for the final emission
|
||||
var lastUsage *geminiStreamChunk
|
||||
for i := range chunks {
|
||||
if chunks[i].UsageMetadata.TotalTokenCount > 0 {
|
||||
lastUsage = &chunks[i]
|
||||
}
|
||||
}
|
||||
if lastUsage != nil {
|
||||
// Emit a synthetic final chunk with usage data
|
||||
if len(lastUsage.Candidates) == 0 && lastUsage.UsageMetadata.TotalTokenCount > 0 {
|
||||
emitGeminiChunk(ch, lastUsage, model)
|
||||
}
|
||||
}
|
||||
// Also emit each content-bearing chunk
|
||||
for i := range chunks {
|
||||
emitGeminiChunk(ch, &chunks[i], model)
|
||||
}
|
||||
}
|
||||
|
||||
func streamGeminiSSE(r io.Reader, ch chan<- *models.ChatCompletionStreamResponse, model string) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
// Track the last seen usage for emission at end of stream
|
||||
var lastUsage geminiStreamChunk
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" || !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
// Emit final usage if we have one
|
||||
if lastUsage.UsageMetadata.TotalTokenCount > 0 {
|
||||
emitGeminiChunk(ch, &lastUsage, model)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
var chunk geminiStreamChunk
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Capture usage from any chunk (Gemini puts it in the final response)
|
||||
if chunk.UsageMetadata.TotalTokenCount > 0 {
|
||||
lastUsage = chunk
|
||||
}
|
||||
|
||||
// Emit content chunks as they arrive
|
||||
if len(chunk.Candidates) > 0 {
|
||||
emitGeminiChunk(ch, &chunk, model)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
// If stream ended without [DONE] marker but we collected usage, emit it
|
||||
if lastUsage.UsageMetadata.TotalTokenCount > 0 {
|
||||
emitGeminiChunk(ch, &lastUsage, model)
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
fmt.Printf("[Gemini-Stream] SSE scan error: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ensureEnglish injects a system message instructing the model to respond in
|
||||
|
||||
+39
-21
@@ -309,9 +309,32 @@ func (s *Server) handleResponses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Select provider based on model name
|
||||
// Strip common prefixes and resolve model groups to concrete models
|
||||
// (same pattern as handleChatCompletions).
|
||||
modelGroup := ""
|
||||
modelID := req.Model
|
||||
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/"}
|
||||
for _, p := range prefixes {
|
||||
if strings.HasPrefix(modelID, p) {
|
||||
modelID = strings.TrimPrefix(modelID, p)
|
||||
break
|
||||
}
|
||||
}
|
||||
if s.modelRouter != nil {
|
||||
decision, err := s.modelRouter.RouteToConcrete(c.Request.Context(), modelID, "")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("model routing failed: %v", err)})
|
||||
return
|
||||
}
|
||||
if decision.SelectedModel != modelID {
|
||||
modelGroup = modelID
|
||||
}
|
||||
modelID = decision.SelectedModel
|
||||
}
|
||||
|
||||
// Select provider based on resolved model name
|
||||
providerName := "openai" // default for Responses API
|
||||
modelLower := strings.ToLower(req.Model)
|
||||
modelLower := strings.ToLower(modelID)
|
||||
if strings.HasPrefix(modelLower, "gemini/") || strings.Contains(modelLower, "gemini") || strings.HasPrefix(modelLower, "google/") {
|
||||
providerName = "gemini"
|
||||
} else if strings.HasPrefix(modelLower, "deepseek/") || (strings.Contains(modelLower, "deepseek") && !strings.Contains(modelLower, "ollama")) {
|
||||
@@ -339,17 +362,7 @@ func (s *Server) handleResponses(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Strip common prefixes from model name
|
||||
modelID := req.Model
|
||||
prefixes := []string{"gemini/", "google/", "openai/", "deepseek/", "moonshot/", "grok/", "ollama/"}
|
||||
for _, p := range prefixes {
|
||||
if strings.HasPrefix(modelID, p) {
|
||||
modelID = strings.TrimPrefix(modelID, p)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Use the stripped model name for the actual API call
|
||||
// Use resolved model for the actual API call
|
||||
req.Model = modelID
|
||||
|
||||
clientID := "default"
|
||||
@@ -364,7 +377,7 @@ func (s *Server) handleResponses(c *gin.Context) {
|
||||
if stream {
|
||||
ch, err := provider.ResponsesStream(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, "", nil, err, false)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -379,9 +392,9 @@ func (s *Server) handleResponses(c *gin.Context) {
|
||||
if !ok {
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
if lastUsage != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, "", lastUsage.ToUsage(), nil, false)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, lastUsage.ToUsage(), nil, false)
|
||||
} else {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, "", nil, nil, false)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, nil, false)
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -401,15 +414,15 @@ func (s *Server) handleResponses(c *gin.Context) {
|
||||
|
||||
resp, err := provider.Responses(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, "", nil, err, false)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, err, false)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Usage != nil {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, "", resp.Usage.ToUsage(), nil, false)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, resp.Usage.ToUsage(), nil, false)
|
||||
} else {
|
||||
s.logRequest(startTime, clientID, providerName, req.Model, "", nil, nil, false)
|
||||
s.logRequest(startTime, clientID, providerName, modelID, modelGroup, nil, nil, false)
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
@@ -881,9 +894,14 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model, modelGro
|
||||
entry.CacheWriteTokens = *usage.CacheWriteTokens
|
||||
}
|
||||
|
||||
// Calculate cost using registry
|
||||
// Calculate cost using registry; if the resolved model is unknown,
|
||||
// fall back to the model group so group requests still get priced.
|
||||
s.registryMu.RLock()
|
||||
entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
|
||||
pricingModel := model
|
||||
if s.registry != nil && s.registry.FindModel(pricingModel) == nil && modelGroup != "" {
|
||||
pricingModel = modelGroup
|
||||
}
|
||||
entry.Cost = utils.CalculateCost(s.registry, pricingModel, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
|
||||
s.registryMu.RUnlock()
|
||||
}
|
||||
|
||||
|
||||
@@ -45,6 +45,24 @@ func FetchRegistry() (*models.ModelRegistry, error) {
|
||||
return nil, fmt.Errorf("failed to fetch registry after 3 attempts: %w", lastErr)
|
||||
}
|
||||
|
||||
// promoDiscount describes a temporary pricing discount applied on top of
// the standard (list) price from the model registry.
type promoDiscount struct {
	// Factor is the multiplier applied to the computed list-price cost
	// (0.25 = 75% off).
	Factor float64
	// ExpiresAt is the exclusive end of the promotion in UTC: the discount
	// applies only to instants strictly before it.
	ExpiresAt time.Time
}

// promoDiscounts maps model IDs to promotional discounts.
// Sources:
//   - DeepSeek v4 Pro: 75% off list pricing until 2026-05-31
//     https://api-docs.deepseek.com/quick_start/pricing
var promoDiscounts = map[string]promoDiscount{
	"deepseek-v4-pro": {
		Factor: 0.25,
		// Exclusive bound at the start of 2026-06-01 UTC, so every instant
		// of 2026-05-31 UTC, including its final second, is still discounted.
		ExpiresAt: time.Date(2026, 6, 1, 0, 0, 0, 0, time.UTC),
	},
}
|
||||
|
||||
func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens, completionTokens, reasoningTokens, cacheRead, cacheWrite uint32) float64 {
|
||||
meta := registry.FindModel(modelID)
|
||||
if meta == nil || meta.Cost == nil {
|
||||
@@ -72,5 +90,12 @@ func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens,
|
||||
cost += float64(cacheWrite) * (*meta.Cost.CacheWrite) / 1000000.0
|
||||
}
|
||||
|
||||
// Apply promotional discounts (e.g. DeepSeek 75% off until 2026-05-31).
|
||||
if discount, ok := promoDiscounts[modelID]; ok {
|
||||
if time.Now().UTC().Before(discount.ExpiresAt) {
|
||||
cost *= discount.Factor
|
||||
}
|
||||
}
|
||||
|
||||
return cost
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user