feat: enhance usage and cost tracking accuracy

Improved extraction of reasoning and cached tokens from OpenAI and DeepSeek responses (including streams). Ensured accurate cost calculation using registry metadata.
2026-03-19 11:56:26 -04:00
parent 66a1643bca
commit 0f3c5b6eb4
2 changed files with 139 additions and 17 deletions
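The diff below covers the token-extraction side of this change; the cost-calculation side is not shown in this view. As a minimal sketch of how per-request cost could be derived from the unified usage plus per-model pricing metadata from the registry (the package, the ModelPricing type, and its field names are invented for illustration, not the repo's actual API):

package pricing

import "llm-proxy/internal/models"

// ModelPricing is a hypothetical registry entry: USD prices per 1M tokens.
// Field names are illustrative, not the project's real metadata schema.
type ModelPricing struct {
	InputPer1M  float64 // uncached prompt tokens
	CachedPer1M float64 // prompt tokens served from cache
	OutputPer1M float64 // completion tokens (reasoning assumed included)
}

// EstimateCost prices cached and uncached prompt tokens separately and bills
// completion tokens (which already include reasoning tokens) at the output rate.
func EstimateCost(u *models.Usage, p ModelPricing) float64 {
	if u == nil {
		return 0
	}
	var cached uint32
	if u.CacheReadTokens != nil {
		cached = *u.CacheReadTokens
	}
	if cached > u.PromptTokens {
		cached = u.PromptTokens // guard against inconsistent provider counts
	}
	uncached := u.PromptTokens - cached
	cost := float64(uncached) * p.InputPer1M / 1e6
	cost += float64(cached) * p.CachedPer1M / 1e6
	cost += float64(u.CompletionTokens) * p.OutputPer1M / 1e6
	return cost
}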

View File

@@ -1,9 +1,12 @@
 package providers

 import (
+	"bufio"
 	"context"
 	"encoding/json"
 	"fmt"
+	"io"
+	"strings"

 	"llm-proxy/internal/config"
 	"llm-proxy/internal/models"
@@ -28,6 +31,32 @@ func (p *DeepSeekProvider) Name() string {
 	return "deepseek"
 }

+type deepSeekUsage struct {
+	PromptTokens            uint32 `json:"prompt_tokens"`
+	CompletionTokens        uint32 `json:"completion_tokens"`
+	TotalTokens             uint32 `json:"total_tokens"`
+	PromptCacheHitTokens    uint32 `json:"prompt_cache_hit_tokens"`
+	PromptCacheMissTokens   uint32 `json:"prompt_cache_miss_tokens"`
+	CompletionTokensDetails *struct {
+		ReasoningTokens uint32 `json:"reasoning_tokens"`
+	} `json:"completion_tokens_details"`
+}
+
+func (u *deepSeekUsage) ToUnified() *models.Usage {
+	usage := &models.Usage{
+		PromptTokens:     u.PromptTokens,
+		CompletionTokens: u.CompletionTokens,
+		TotalTokens:      u.TotalTokens,
+	}
+	if u.PromptCacheHitTokens > 0 {
+		usage.CacheReadTokens = &u.PromptCacheHitTokens
+	}
+	if u.CompletionTokensDetails != nil && u.CompletionTokensDetails.ReasoningTokens > 0 {
+		usage.ReasoningTokens = &u.CompletionTokensDetails.ReasoningTokens
+	}
+	return usage
+}
+
 func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
 	messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
 	if err != nil {
@@ -43,7 +72,6 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi
 	delete(body, "presence_penalty")
 	delete(body, "frequency_penalty")

-	// Ensure assistant messages have content and reasoning_content
 	if msgs, ok := body["messages"].([]interface{}); ok {
 		for _, m := range msgs {
 			if msg, ok := m.(map[string]interface{}); ok {
@@ -79,7 +107,21 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi
 		return nil, fmt.Errorf("failed to parse response: %w", err)
 	}

-	return ParseOpenAIResponse(respJSON, req.Model)
+	result, err := ParseOpenAIResponse(respJSON, req.Model)
+	if err != nil {
+		return nil, err
+	}
+
+	// Fix usage for DeepSeek specifically if details were missing in ParseOpenAIResponse
+	if usageData, ok := respJSON["usage"]; ok {
+		var dUsage deepSeekUsage
+		usageBytes, _ := json.Marshal(usageData)
+		if err := json.Unmarshal(usageBytes, &dUsage); err == nil {
+			result.Usage = dUsage.ToUnified()
+		}
+	}
+
+	return result, nil
 }

 func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
@@ -97,7 +139,6 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models
 	delete(body, "presence_penalty")
 	delete(body, "frequency_penalty")

-	// Ensure assistant messages have content and reasoning_content
 	if msgs, ok := body["messages"].([]interface{}); ok {
 		for _, m := range msgs {
 			if msg, ok := m.(map[string]interface{}); ok {
@@ -133,7 +174,8 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models
 	go func() {
 		defer close(ch)
-		err := StreamOpenAI(resp.RawBody(), ch)
+		// Custom scanner loop to handle DeepSeek specific usage in chunks
+		err := StreamDeepSeek(resp.RawBody(), ch)
 		if err != nil {
 			fmt.Printf("DeepSeek Stream error: %v\n", err)
 		}
@@ -141,3 +183,35 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models
 	return ch, nil
 }
+
+func StreamDeepSeek(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse) error {
+	defer ctx.Close()
+	scanner := bufio.NewScanner(ctx)
+	for scanner.Scan() {
+		line := scanner.Text()
+		if line == "" || !strings.HasPrefix(line, "data: ") {
+			continue
+		}
+		data := strings.TrimPrefix(line, "data: ")
+		if data == "[DONE]" {
+			break
+		}
+
+		var chunk models.ChatCompletionStreamResponse
+		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
+			continue
+		}
+
+		// Fix DeepSeek specific usage in stream
+		var rawChunk struct {
+			Usage *deepSeekUsage `json:"usage"`
+		}
+		if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
+			chunk.Usage = rawChunk.Usage.ToUnified()
+		}
+
+		ch <- &chunk
+	}
+	return scanner.Err()
+}
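
To make the DeepSeek mapping concrete, here is a minimal sketch of what the new ToUnified path does with a typical usage payload. It assumes it lives in the providers package next to deepSeekUsage; the sample numbers are invented, not from a real response.

package providers

import (
	"encoding/json"
	"fmt"
)

// Illustrative only: folds a DeepSeek-style usage payload into the unified
// Usage struct via deepSeekUsage.ToUnified. Sample values are made up.
func demoDeepSeekUsageMapping() {
	raw := []byte(`{
		"prompt_tokens": 1200,
		"completion_tokens": 350,
		"total_tokens": 1550,
		"prompt_cache_hit_tokens": 900,
		"prompt_cache_miss_tokens": 300,
		"completion_tokens_details": {"reasoning_tokens": 120}
	}`)

	var u deepSeekUsage
	if err := json.Unmarshal(raw, &u); err != nil {
		panic(err)
	}
	unified := u.ToUnified()

	// prompt_cache_hit_tokens -> CacheReadTokens, reasoning_tokens -> ReasoningTokens
	fmt.Println(unified.PromptTokens, unified.CompletionTokens,
		*unified.CacheReadTokens, *unified.ReasoningTokens)
	// Prints: 1200 350 900 120
}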

View File

@@ -122,6 +122,33 @@ func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{},
 	return body
 }

+type openAIUsage struct {
+	PromptTokens        uint32 `json:"prompt_tokens"`
+	CompletionTokens    uint32 `json:"completion_tokens"`
+	TotalTokens         uint32 `json:"total_tokens"`
+	PromptTokensDetails *struct {
+		CachedTokens uint32 `json:"cached_tokens"`
+	} `json:"prompt_tokens_details"`
+	CompletionTokensDetails *struct {
+		ReasoningTokens uint32 `json:"reasoning_tokens"`
+	} `json:"completion_tokens_details"`
+}
+
+func (u *openAIUsage) ToUnified() *models.Usage {
+	usage := &models.Usage{
+		PromptTokens:     u.PromptTokens,
+		CompletionTokens: u.CompletionTokens,
+		TotalTokens:      u.TotalTokens,
+	}
+	if u.PromptTokensDetails != nil && u.PromptTokensDetails.CachedTokens > 0 {
+		usage.CacheReadTokens = &u.PromptTokensDetails.CachedTokens
+	}
+	if u.CompletionTokensDetails != nil && u.CompletionTokensDetails.ReasoningTokens > 0 {
+		usage.ReasoningTokens = &u.CompletionTokensDetails.ReasoningTokens
+	}
+	return usage
+}
+
 func ParseOpenAIResponse(respJSON map[string]interface{}, model string) (*models.ChatCompletionResponse, error) {
 	data, err := json.Marshal(respJSON)
 	if err != nil {
@@ -133,6 +160,16 @@ func ParseOpenAIResponse(respJSON map[string]interface{}, model string) (*models
 		return nil, err
 	}

+	// Manually fix usage because ChatCompletionResponse uses the unified Usage struct
+	// but the provider might have returned more details.
+	if usageData, ok := respJSON["usage"]; ok {
+		var oUsage openAIUsage
+		usageBytes, _ := json.Marshal(usageData)
+		if err := json.Unmarshal(usageBytes, &oUsage); err == nil {
+			resp.Usage = oUsage.ToUnified()
+		}
+	}
+
 	return &resp, nil
 }
@@ -156,6 +193,14 @@ func ParseOpenAIStreamChunk(line string) (*models.ChatCompletionStreamResponse,
 		return nil, false, fmt.Errorf("failed to unmarshal stream chunk: %w", err)
 	}

+	// Handle specialized usage in stream chunks
+	var rawChunk struct {
+		Usage *openAIUsage `json:"usage"`
+	}
+	if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
+		chunk.Usage = rawChunk.Usage.ToUnified()
+	}
+
 	return &chunk, false, nil
 }
@@ -210,9 +255,10 @@ func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
 			return err
 		}

-		if len(geminiChunk.Candidates) > 0 {
+		if len(geminiChunk.Candidates) > 0 || geminiChunk.UsageMetadata.TotalTokenCount > 0 {
 			content := ""
 			var reasoning *string
+			if len(geminiChunk.Candidates) > 0 {
 			for _, p := range geminiChunk.Candidates[0].Content.Parts {
 				if p.Text != "" {
 					content += p.Text
@@ -224,10 +270,12 @@ func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
 					*reasoning += p.Thought
 				}
 			}
+			}

-			finishReason := strings.ToLower(geminiChunk.Candidates[0].FinishReason)
-			if finishReason == "stop" {
-				finishReason = "stop"
+			var finishReason *string
+			if len(geminiChunk.Candidates) > 0 {
+				fr := strings.ToLower(geminiChunk.Candidates[0].FinishReason)
+				finishReason = &fr
 			}

 			ch <- &models.ChatCompletionStreamResponse{
@@ -242,7 +290,7 @@ func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
 						Content:          &content,
 						ReasoningContent: reasoning,
 					},
-					FinishReason: &finishReason,
+					FinishReason: finishReason,
 				},
 			},
 			Usage: &models.Usage{