Improved extraction of reasoning and cached tokens from OpenAI and DeepSeek responses (including streams). Ensured accurate cost calculation using registry metadata.
308 lines
7.6 KiB
Go
package providers
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
|
|
"llm-proxy/internal/models"
|
|
)
|
|
|
|
// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON, including tools and images.
|
|
func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) {
|
|
var result []interface{}
|
|
for _, m := range messages {
|
|
if m.Role == "tool" {
|
|
text := ""
|
|
if len(m.Content) > 0 {
|
|
text = m.Content[0].Text
|
|
}
|
|
msg := map[string]interface{}{
|
|
"role": "tool",
|
|
"content": text,
|
|
}
|
|
if m.ToolCallID != nil {
|
|
id := *m.ToolCallID
|
|
if len(id) > 40 {
|
|
id = id[:40]
|
|
}
|
|
msg["tool_call_id"] = id
|
|
}
|
|
if m.Name != nil {
|
|
msg["name"] = *m.Name
|
|
}
|
|
result = append(result, msg)
|
|
continue
|
|
}
|
|
|
|
var parts []interface{}
|
|
for _, p := range m.Content {
|
|
if p.Type == "text" {
|
|
parts = append(parts, map[string]interface{}{
|
|
"type": "text",
|
|
"text": p.Text,
|
|
})
|
|
} else if p.Image != nil {
|
|
base64Data, mimeType, err := p.Image.ToBase64()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to convert image to base64: %w", err)
|
|
}
|
|
parts = append(parts, map[string]interface{}{
|
|
"type": "image_url",
|
|
"image_url": map[string]interface{}{
|
|
"url": fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data),
|
|
},
|
|
})
|
|
}
|
|
}
|
|
|
|
msg := map[string]interface{}{
|
|
"role": m.Role,
|
|
"content": parts,
|
|
}
|
|
|
|
if m.ReasoningContent != nil {
|
|
msg["reasoning_content"] = *m.ReasoningContent
|
|
}
|
|
|
|
if len(m.ToolCalls) > 0 {
|
|
sanitizedCalls := make([]models.ToolCall, len(m.ToolCalls))
|
|
copy(sanitizedCalls, m.ToolCalls)
|
|
for i := range sanitizedCalls {
|
|
if len(sanitizedCalls[i].ID) > 40 {
|
|
sanitizedCalls[i].ID = sanitizedCalls[i].ID[:40]
|
|
}
|
|
}
|
|
msg["tool_calls"] = sanitizedCalls
|
|
if len(parts) == 0 {
|
|
msg["content"] = ""
|
|
}
|
|
}
|
|
|
|
if m.Name != nil {
|
|
msg["name"] = *m.Name
|
|
}
|
|
|
|
result = append(result, msg)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{}, stream bool) map[string]interface{} {
|
|
body := map[string]interface{}{
|
|
"model": request.Model,
|
|
"messages": messagesJSON,
|
|
"stream": stream,
|
|
}
|
|
|
|
if stream {
|
|
body["stream_options"] = map[string]interface{}{
|
|
"include_usage": true,
|
|
}
|
|
}
|
|
|
|
if request.Temperature != nil {
|
|
body["temperature"] = *request.Temperature
|
|
}
|
|
if request.MaxTokens != nil {
|
|
body["max_tokens"] = *request.MaxTokens
|
|
}
|
|
if len(request.Tools) > 0 {
|
|
body["tools"] = request.Tools
|
|
}
|
|
if request.ToolChoice != nil {
|
|
var toolChoice interface{}
|
|
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
|
|
body["tool_choice"] = toolChoice
|
|
}
|
|
}
|
|
|
|
return body
|
|
}
|
|
|
|
// openAIUsage mirrors the OpenAI-compatible "usage" object, including the
// optional detail blocks that carry cached-prompt and reasoning token counts.
type openAIUsage struct {
	PromptTokens     uint32 `json:"prompt_tokens"`
	CompletionTokens uint32 `json:"completion_tokens"`
	TotalTokens      uint32 `json:"total_tokens"`
	// PromptTokensDetails is nil when the provider reports no prompt detail;
	// CachedTokens is the number of prompt tokens reported as cache hits.
	PromptTokensDetails *struct {
		CachedTokens uint32 `json:"cached_tokens"`
	} `json:"prompt_tokens_details"`
	// CompletionTokensDetails is nil when absent; ReasoningTokens is the
	// reasoning-token count reported by the provider.
	CompletionTokensDetails *struct {
		ReasoningTokens uint32 `json:"reasoning_tokens"`
	} `json:"completion_tokens_details"`
}
|
|
|
|
func (u *openAIUsage) ToUnified() *models.Usage {
|
|
usage := &models.Usage{
|
|
PromptTokens: u.PromptTokens,
|
|
CompletionTokens: u.CompletionTokens,
|
|
TotalTokens: u.TotalTokens,
|
|
}
|
|
if u.PromptTokensDetails != nil && u.PromptTokensDetails.CachedTokens > 0 {
|
|
usage.CacheReadTokens = &u.PromptTokensDetails.CachedTokens
|
|
}
|
|
if u.CompletionTokensDetails != nil && u.CompletionTokensDetails.ReasoningTokens > 0 {
|
|
usage.ReasoningTokens = &u.CompletionTokensDetails.ReasoningTokens
|
|
}
|
|
return usage
|
|
}
|
|
|
|
func ParseOpenAIResponse(respJSON map[string]interface{}, model string) (*models.ChatCompletionResponse, error) {
|
|
data, err := json.Marshal(respJSON)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var resp models.ChatCompletionResponse
|
|
if err := json.Unmarshal(data, &resp); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Manually fix usage because ChatCompletionResponse uses the unified Usage struct
|
|
// but the provider might have returned more details.
|
|
if usageData, ok := respJSON["usage"]; ok {
|
|
var oUsage openAIUsage
|
|
usageBytes, _ := json.Marshal(usageData)
|
|
if err := json.Unmarshal(usageBytes, &oUsage); err == nil {
|
|
resp.Usage = oUsage.ToUnified()
|
|
}
|
|
}
|
|
|
|
return &resp, nil
|
|
}
|
|
|
|
// Streaming support
|
|
|
|
func ParseOpenAIStreamChunk(line string) (*models.ChatCompletionStreamResponse, bool, error) {
|
|
if line == "" {
|
|
return nil, false, nil
|
|
}
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
return nil, false, nil
|
|
}
|
|
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
if data == "[DONE]" {
|
|
return nil, true, nil
|
|
}
|
|
|
|
var chunk models.ChatCompletionStreamResponse
|
|
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
|
return nil, false, fmt.Errorf("failed to unmarshal stream chunk: %w", err)
|
|
}
|
|
|
|
// Handle specialized usage in stream chunks
|
|
var rawChunk struct {
|
|
Usage *openAIUsage `json:"usage"`
|
|
}
|
|
if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
|
|
chunk.Usage = rawChunk.Usage.ToUnified()
|
|
}
|
|
|
|
return &chunk, false, nil
|
|
}
|
|
|
|
func StreamOpenAI(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse) error {
|
|
defer ctx.Close()
|
|
scanner := bufio.NewScanner(ctx)
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
chunk, done, err := ParseOpenAIStreamChunk(line)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if done {
|
|
break
|
|
}
|
|
if chunk != nil {
|
|
ch <- chunk
|
|
}
|
|
}
|
|
return scanner.Err()
|
|
}
|
|
|
|
func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error {
|
|
defer ctx.Close()
|
|
|
|
dec := json.NewDecoder(ctx)
|
|
|
|
t, err := dec.Token()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if delim, ok := t.(json.Delim); ok && delim == '[' {
|
|
for dec.More() {
|
|
var geminiChunk struct {
|
|
Candidates []struct {
|
|
Content struct {
|
|
Parts []struct {
|
|
Text string `json:"text,omitempty"`
|
|
Thought string `json:"thought,omitempty"`
|
|
} `json:"parts"`
|
|
} `json:"content"`
|
|
FinishReason string `json:"finishReason"`
|
|
} `json:"candidates"`
|
|
UsageMetadata struct {
|
|
PromptTokenCount uint32 `json:"promptTokenCount"`
|
|
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
|
|
TotalTokenCount uint32 `json:"totalTokenCount"`
|
|
} `json:"usageMetadata"`
|
|
}
|
|
|
|
if err := dec.Decode(&geminiChunk); err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(geminiChunk.Candidates) > 0 || geminiChunk.UsageMetadata.TotalTokenCount > 0 {
|
|
content := ""
|
|
var reasoning *string
|
|
if len(geminiChunk.Candidates) > 0 {
|
|
for _, p := range geminiChunk.Candidates[0].Content.Parts {
|
|
if p.Text != "" {
|
|
content += p.Text
|
|
}
|
|
if p.Thought != "" {
|
|
if reasoning == nil {
|
|
reasoning = new(string)
|
|
}
|
|
*reasoning += p.Thought
|
|
}
|
|
}
|
|
}
|
|
|
|
var finishReason *string
|
|
if len(geminiChunk.Candidates) > 0 {
|
|
fr := strings.ToLower(geminiChunk.Candidates[0].FinishReason)
|
|
finishReason = &fr
|
|
}
|
|
|
|
ch <- &models.ChatCompletionStreamResponse{
|
|
ID: "gemini-stream",
|
|
Object: "chat.completion.chunk",
|
|
Created: 0,
|
|
Model: model,
|
|
Choices: []models.ChatStreamChoice{
|
|
{
|
|
Index: 0,
|
|
Delta: models.ChatStreamDelta{
|
|
Content: &content,
|
|
ReasoningContent: reasoning,
|
|
},
|
|
FinishReason: finishReason,
|
|
},
|
|
},
|
|
Usage: &models.Usage{
|
|
PromptTokens: geminiChunk.UsageMetadata.PromptTokenCount,
|
|
CompletionTokens: geminiChunk.UsageMetadata.CandidatesTokenCount,
|
|
TotalTokens: geminiChunk.UsageMetadata.TotalTokenCount,
|
|
},
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|