feat: enhance usage and cost tracking accuracy
Improved extraction of reasoning and cached tokens from OpenAI and DeepSeek responses (including streams). Ensured accurate cost calculation using registry metadata.
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"llm-proxy/internal/config"
|
||||
"llm-proxy/internal/models"
|
||||
@@ -28,6 +31,32 @@ func (p *DeepSeekProvider) Name() string {
|
||||
return "deepseek"
|
||||
}
|
||||
|
||||
// deepSeekUsage mirrors the "usage" object in DeepSeek API responses.
// Note DeepSeek reports prompt-cache hits/misses as top-level fields,
// unlike OpenAI which nests cached tokens under prompt_tokens_details.
type deepSeekUsage struct {
	PromptTokens     uint32 `json:"prompt_tokens"`
	CompletionTokens uint32 `json:"completion_tokens"`
	TotalTokens      uint32 `json:"total_tokens"`
	// PromptCacheHitTokens is the number of prompt tokens served from
	// DeepSeek's prompt cache (billed at the cheaper cache-hit rate).
	PromptCacheHitTokens uint32 `json:"prompt_cache_hit_tokens"`
	// PromptCacheMissTokens is parsed for completeness but not currently
	// propagated into the unified usage by ToUnified.
	PromptCacheMissTokens uint32 `json:"prompt_cache_miss_tokens"`
	// CompletionTokensDetails is optional; it is nil when the API omits
	// the completion_tokens_details object entirely.
	CompletionTokensDetails *struct {
		ReasoningTokens uint32 `json:"reasoning_tokens"`
	} `json:"completion_tokens_details"`
}
|
||||
|
||||
func (u *deepSeekUsage) ToUnified() *models.Usage {
|
||||
usage := &models.Usage{
|
||||
PromptTokens: u.PromptTokens,
|
||||
CompletionTokens: u.CompletionTokens,
|
||||
TotalTokens: u.TotalTokens,
|
||||
}
|
||||
if u.PromptCacheHitTokens > 0 {
|
||||
usage.CacheReadTokens = &u.PromptCacheHitTokens
|
||||
}
|
||||
if u.CompletionTokensDetails != nil && u.CompletionTokensDetails.ReasoningTokens > 0 {
|
||||
usage.ReasoningTokens = &u.CompletionTokensDetails.ReasoningTokens
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
@@ -43,7 +72,6 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi
|
||||
delete(body, "presence_penalty")
|
||||
delete(body, "frequency_penalty")
|
||||
|
||||
// Ensure assistant messages have content and reasoning_content
|
||||
if msgs, ok := body["messages"].([]interface{}); ok {
|
||||
for _, m := range msgs {
|
||||
if msg, ok := m.(map[string]interface{}); ok {
|
||||
@@ -79,7 +107,21 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return ParseOpenAIResponse(respJSON, req.Model)
|
||||
result, err := ParseOpenAIResponse(respJSON, req.Model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fix usage for DeepSeek specifically if details were missing in ParseOpenAIResponse
|
||||
if usageData, ok := respJSON["usage"]; ok {
|
||||
var dUsage deepSeekUsage
|
||||
usageBytes, _ := json.Marshal(usageData)
|
||||
if err := json.Unmarshal(usageBytes, &dUsage); err == nil {
|
||||
result.Usage = dUsage.ToUnified()
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
@@ -97,7 +139,6 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models
|
||||
delete(body, "presence_penalty")
|
||||
delete(body, "frequency_penalty")
|
||||
|
||||
// Ensure assistant messages have content and reasoning_content
|
||||
if msgs, ok := body["messages"].([]interface{}); ok {
|
||||
for _, m := range msgs {
|
||||
if msg, ok := m.(map[string]interface{}); ok {
|
||||
@@ -133,7 +174,8 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamOpenAI(resp.RawBody(), ch)
|
||||
// Custom scanner loop to handle DeepSeek specific usage in chunks
|
||||
err := StreamDeepSeek(resp.RawBody(), ch)
|
||||
if err != nil {
|
||||
fmt.Printf("DeepSeek Stream error: %v\n", err)
|
||||
}
|
||||
@@ -141,3 +183,35 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func StreamDeepSeek(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse) error {
|
||||
defer ctx.Close()
|
||||
scanner := bufio.NewScanner(ctx)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" || !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
var chunk models.ChatCompletionStreamResponse
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Fix DeepSeek specific usage in stream
|
||||
var rawChunk struct {
|
||||
Usage *deepSeekUsage `json:"usage"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
|
||||
chunk.Usage = rawChunk.Usage.ToUnified()
|
||||
}
|
||||
|
||||
ch <- &chunk
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
@@ -122,6 +122,33 @@ func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{},
|
||||
return body
|
||||
}
|
||||
|
||||
// openAIUsage mirrors the "usage" object in OpenAI-compatible responses,
// including the optional nested detail objects that the unified
// models.Usage struct would otherwise drop during generic decoding.
type openAIUsage struct {
	PromptTokens     uint32 `json:"prompt_tokens"`
	CompletionTokens uint32 `json:"completion_tokens"`
	TotalTokens      uint32 `json:"total_tokens"`
	// PromptTokensDetails is optional; nil when the API omits it.
	// CachedTokens counts prompt tokens served from OpenAI's prompt cache.
	PromptTokensDetails *struct {
		CachedTokens uint32 `json:"cached_tokens"`
	} `json:"prompt_tokens_details"`
	// CompletionTokensDetails is optional; nil when the API omits it.
	CompletionTokensDetails *struct {
		ReasoningTokens uint32 `json:"reasoning_tokens"`
	} `json:"completion_tokens_details"`
}
|
||||
|
||||
func (u *openAIUsage) ToUnified() *models.Usage {
|
||||
usage := &models.Usage{
|
||||
PromptTokens: u.PromptTokens,
|
||||
CompletionTokens: u.CompletionTokens,
|
||||
TotalTokens: u.TotalTokens,
|
||||
}
|
||||
if u.PromptTokensDetails != nil && u.PromptTokensDetails.CachedTokens > 0 {
|
||||
usage.CacheReadTokens = &u.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
if u.CompletionTokensDetails != nil && u.CompletionTokensDetails.ReasoningTokens > 0 {
|
||||
usage.ReasoningTokens = &u.CompletionTokensDetails.ReasoningTokens
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
func ParseOpenAIResponse(respJSON map[string]interface{}, model string) (*models.ChatCompletionResponse, error) {
|
||||
data, err := json.Marshal(respJSON)
|
||||
if err != nil {
|
||||
@@ -133,6 +160,16 @@ func ParseOpenAIResponse(respJSON map[string]interface{}, model string) (*models
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Manually fix usage because ChatCompletionResponse uses the unified Usage struct
|
||||
// but the provider might have returned more details.
|
||||
if usageData, ok := respJSON["usage"]; ok {
|
||||
var oUsage openAIUsage
|
||||
usageBytes, _ := json.Marshal(usageData)
|
||||
if err := json.Unmarshal(usageBytes, &oUsage); err == nil {
|
||||
resp.Usage = oUsage.ToUnified()
|
||||
}
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
@@ -156,6 +193,14 @@ func ParseOpenAIStreamChunk(line string) (*models.ChatCompletionStreamResponse,
|
||||
return nil, false, fmt.Errorf("failed to unmarshal stream chunk: %w", err)
|
||||
}
|
||||
|
||||
// Handle specialized usage in stream chunks
|
||||
var rawChunk struct {
|
||||
Usage *openAIUsage `json:"usage"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
|
||||
chunk.Usage = rawChunk.Usage.ToUnified()
|
||||
}
|
||||
|
||||
return &chunk, false, nil
|
||||
}
|
||||
|
||||
@@ -210,24 +255,27 @@ func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
|
||||
return err
|
||||
}
|
||||
|
||||
if len(geminiChunk.Candidates) > 0 {
|
||||
if len(geminiChunk.Candidates) > 0 || geminiChunk.UsageMetadata.TotalTokenCount > 0 {
|
||||
content := ""
|
||||
var reasoning *string
|
||||
for _, p := range geminiChunk.Candidates[0].Content.Parts {
|
||||
if p.Text != "" {
|
||||
content += p.Text
|
||||
}
|
||||
if p.Thought != "" {
|
||||
if reasoning == nil {
|
||||
reasoning = new(string)
|
||||
if len(geminiChunk.Candidates) > 0 {
|
||||
for _, p := range geminiChunk.Candidates[0].Content.Parts {
|
||||
if p.Text != "" {
|
||||
content += p.Text
|
||||
}
|
||||
if p.Thought != "" {
|
||||
if reasoning == nil {
|
||||
reasoning = new(string)
|
||||
}
|
||||
*reasoning += p.Thought
|
||||
}
|
||||
*reasoning += p.Thought
|
||||
}
|
||||
}
|
||||
|
||||
finishReason := strings.ToLower(geminiChunk.Candidates[0].FinishReason)
|
||||
if finishReason == "stop" {
|
||||
finishReason = "stop"
|
||||
var finishReason *string
|
||||
if len(geminiChunk.Candidates) > 0 {
|
||||
fr := strings.ToLower(geminiChunk.Candidates[0].FinishReason)
|
||||
finishReason = &fr
|
||||
}
|
||||
|
||||
ch <- &models.ChatCompletionStreamResponse{
|
||||
@@ -242,7 +290,7 @@ func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
|
||||
Content: &content,
|
||||
ReasoningContent: reasoning,
|
||||
},
|
||||
FinishReason: &finishReason,
|
||||
FinishReason: finishReason,
|
||||
},
|
||||
},
|
||||
Usage: &models.Usage{
|
||||
|
||||
Reference in New Issue
Block a user