diff --git a/internal/providers/deepseek.go b/internal/providers/deepseek.go
index e145cbda..bb34445b 100644
--- a/internal/providers/deepseek.go
+++ b/internal/providers/deepseek.go
@@ -1,9 +1,12 @@
 package providers
 
 import (
+	"bufio"
 	"context"
 	"encoding/json"
 	"fmt"
+	"io"
+	"strings"
 
 	"llm-proxy/internal/config"
 	"llm-proxy/internal/models"
@@ -28,6 +31,32 @@ func (p *DeepSeekProvider) Name() string {
 	return "deepseek"
 }
 
+type deepSeekUsage struct {
+	PromptTokens            uint32 `json:"prompt_tokens"`
+	CompletionTokens        uint32 `json:"completion_tokens"`
+	TotalTokens             uint32 `json:"total_tokens"`
+	PromptCacheHitTokens    uint32 `json:"prompt_cache_hit_tokens"`
+	PromptCacheMissTokens   uint32 `json:"prompt_cache_miss_tokens"`
+	CompletionTokensDetails *struct {
+		ReasoningTokens uint32 `json:"reasoning_tokens"`
+	} `json:"completion_tokens_details"`
+}
+
+func (u *deepSeekUsage) ToUnified() *models.Usage {
+	usage := &models.Usage{
+		PromptTokens:     u.PromptTokens,
+		CompletionTokens: u.CompletionTokens,
+		TotalTokens:      u.TotalTokens,
+	}
+	if u.PromptCacheHitTokens > 0 {
+		usage.CacheReadTokens = &u.PromptCacheHitTokens
+	}
+	if u.CompletionTokensDetails != nil && u.CompletionTokensDetails.ReasoningTokens > 0 {
+		usage.ReasoningTokens = &u.CompletionTokensDetails.ReasoningTokens
+	}
+	return usage
+}
+
 func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
 	messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
 	if err != nil {
@@ -43,7 +72,6 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi
 	delete(body, "presence_penalty")
 	delete(body, "frequency_penalty")
 
-	// Ensure assistant messages have content and reasoning_content
 	if msgs, ok := body["messages"].([]interface{}); ok {
 		for _, m := range msgs {
 			if msg, ok := m.(map[string]interface{}); ok {
@@ -79,7 +107,21 @@ func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.Unifi
 		return nil, fmt.Errorf("failed to parse response: %w", err)
 	}
 
-	return ParseOpenAIResponse(respJSON, req.Model)
+	result, err := ParseOpenAIResponse(respJSON, req.Model)
+	if err != nil {
+		return nil, err
+	}
+
+	// Fix usage for DeepSeek specifically if details were missing in ParseOpenAIResponse
+	if usageData, ok := respJSON["usage"]; ok {
+		var dUsage deepSeekUsage
+		usageBytes, _ := json.Marshal(usageData)
+		if err := json.Unmarshal(usageBytes, &dUsage); err == nil {
+			result.Usage = dUsage.ToUnified()
+		}
+	}
+
+	return result, nil
 }
 
 func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
@@ -97,7 +139,6 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models
 	delete(body, "presence_penalty")
 	delete(body, "frequency_penalty")
 
-	// Ensure assistant messages have content and reasoning_content
 	if msgs, ok := body["messages"].([]interface{}); ok {
 		for _, m := range msgs {
 			if msg, ok := m.(map[string]interface{}); ok {
@@ -133,7 +174,8 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models
 
 	go func() {
 		defer close(ch)
-		err := StreamOpenAI(resp.RawBody(), ch)
+		// Custom scanner loop to handle DeepSeek specific usage in chunks
+		err := StreamDeepSeek(resp.RawBody(), ch)
 		if err != nil {
 			fmt.Printf("DeepSeek Stream error: %v\n", err)
 		}
@@ -141,3 +183,35 @@ func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models
 
 	return ch, nil
 }
+
+func StreamDeepSeek(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse) error {
+	defer ctx.Close()
+	scanner := bufio.NewScanner(ctx)
+	for scanner.Scan() {
+		line := scanner.Text()
+		if line == "" || !strings.HasPrefix(line, "data: ") {
+			continue
+		}
+
+		data := strings.TrimPrefix(line, "data: ")
+		if data == "[DONE]" {
+			break
+		}
+
+		var chunk models.ChatCompletionStreamResponse
+		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
+			continue
+		}
+
+		// Fix DeepSeek specific usage in stream
+		var rawChunk struct {
+			Usage *deepSeekUsage `json:"usage"`
+		}
+		if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
+			chunk.Usage = rawChunk.Usage.ToUnified()
+		}
+
+		ch <- &chunk
+	}
+	return scanner.Err()
+}
diff --git a/internal/providers/helpers.go b/internal/providers/helpers.go
index a66c1537..d009fd19 100644
--- a/internal/providers/helpers.go
+++ b/internal/providers/helpers.go
@@ -122,6 +122,33 @@ func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{},
 	return body
 }
 
+type openAIUsage struct {
+	PromptTokens        uint32 `json:"prompt_tokens"`
+	CompletionTokens    uint32 `json:"completion_tokens"`
+	TotalTokens         uint32 `json:"total_tokens"`
+	PromptTokensDetails *struct {
+		CachedTokens uint32 `json:"cached_tokens"`
+	} `json:"prompt_tokens_details"`
+	CompletionTokensDetails *struct {
+		ReasoningTokens uint32 `json:"reasoning_tokens"`
+	} `json:"completion_tokens_details"`
+}
+
+func (u *openAIUsage) ToUnified() *models.Usage {
+	usage := &models.Usage{
+		PromptTokens:     u.PromptTokens,
+		CompletionTokens: u.CompletionTokens,
+		TotalTokens:      u.TotalTokens,
+	}
+	if u.PromptTokensDetails != nil && u.PromptTokensDetails.CachedTokens > 0 {
+		usage.CacheReadTokens = &u.PromptTokensDetails.CachedTokens
+	}
+	if u.CompletionTokensDetails != nil && u.CompletionTokensDetails.ReasoningTokens > 0 {
+		usage.ReasoningTokens = &u.CompletionTokensDetails.ReasoningTokens
+	}
+	return usage
+}
+
 func ParseOpenAIResponse(respJSON map[string]interface{}, model string) (*models.ChatCompletionResponse, error) {
 	data, err := json.Marshal(respJSON)
 	if err != nil {
@@ -132,6 +159,16 @@ func ParseOpenAIResponse(respJSON map[string]interface{}, model string) (*models
 	if err := json.Unmarshal(data, &resp); err != nil {
 		return nil, err
 	}
+
+	// Manually fix usage because ChatCompletionResponse uses the unified Usage struct
+	// but the provider might have returned more details.
+	if usageData, ok := respJSON["usage"]; ok {
+		var oUsage openAIUsage
+		usageBytes, _ := json.Marshal(usageData)
+		if err := json.Unmarshal(usageBytes, &oUsage); err == nil {
+			resp.Usage = oUsage.ToUnified()
+		}
+	}
 	return &resp, nil
 }
 
@@ -156,6 +193,14 @@
 		return nil, false, fmt.Errorf("failed to unmarshal stream chunk: %w", err)
 	}
 
+	// Handle specialized usage in stream chunks
+	var rawChunk struct {
+		Usage *openAIUsage `json:"usage"`
+	}
+	if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
+		chunk.Usage = rawChunk.Usage.ToUnified()
+	}
+
 	return &chunk, false, nil
 }
 
@@ -210,24 +255,27 @@ func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
 			return err
 		}
 
-		if len(geminiChunk.Candidates) > 0 {
+		if len(geminiChunk.Candidates) > 0 || geminiChunk.UsageMetadata.TotalTokenCount > 0 {
 			content := ""
 			var reasoning *string
-			for _, p := range geminiChunk.Candidates[0].Content.Parts {
-				if p.Text != "" {
-					content += p.Text
-				}
-				if p.Thought != "" {
-					if reasoning == nil {
-						reasoning = new(string)
+			if len(geminiChunk.Candidates) > 0 {
+				for _, p := range geminiChunk.Candidates[0].Content.Parts {
+					if p.Text != "" {
+						content += p.Text
+					}
+					if p.Thought != "" {
+						if reasoning == nil {
+							reasoning = new(string)
+						}
+						*reasoning += p.Thought
 					}
-					*reasoning += p.Thought
 				}
 			}
 
-			finishReason := strings.ToLower(geminiChunk.Candidates[0].FinishReason)
-			if finishReason == "stop" {
-				finishReason = "stop"
+			var finishReason *string
+			if len(geminiChunk.Candidates) > 0 {
+				fr := strings.ToLower(geminiChunk.Candidates[0].FinishReason)
+				finishReason = &fr
 			}
 
 			ch <- &models.ChatCompletionStreamResponse{
@@ -242,7 +290,7 @@ func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamRespo
 					Content:          &content,
 					ReasoningContent: reasoning,
 				},
-				FinishReason: &finishReason,
+				FinishReason: finishReason,
 			},
 		},
 		Usage: &models.Usage{