Updated all naming from LLM Proxy to GopherGate. Implemented new CSS-based branding and updated Go module/binary naming.
218 lines
5.6 KiB
Go
218 lines
5.6 KiB
Go
package providers
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
|
|
"gophergate/internal/config"
|
|
"gophergate/internal/models"
|
|
"github.com/go-resty/resty/v2"
|
|
)
|
|
|
|
type DeepSeekProvider struct {
|
|
client *resty.Client
|
|
config config.DeepSeekConfig
|
|
apiKey string
|
|
}
|
|
|
|
func NewDeepSeekProvider(cfg config.DeepSeekConfig, apiKey string) *DeepSeekProvider {
|
|
return &DeepSeekProvider{
|
|
client: resty.New(),
|
|
config: cfg,
|
|
apiKey: apiKey,
|
|
}
|
|
}
|
|
|
|
func (p *DeepSeekProvider) Name() string {
|
|
return "deepseek"
|
|
}
|
|
|
|
// deepSeekUsage is the decoding target for the "usage" object in DeepSeek
// chat completion responses and stream chunks. It extends the
// OpenAI-compatible counters with DeepSeek's prompt-cache split and an
// optional reasoning-token breakdown.
type deepSeekUsage struct {
	// OpenAI-compatible token counters.
	PromptTokens     uint32 `json:"prompt_tokens"`
	CompletionTokens uint32 `json:"completion_tokens"`
	TotalTokens      uint32 `json:"total_tokens"`

	// Prompt-cache hit/miss split as reported by DeepSeek.
	PromptCacheHitTokens  uint32 `json:"prompt_cache_hit_tokens"`
	PromptCacheMissTokens uint32 `json:"prompt_cache_miss_tokens"`

	// Optional completion breakdown; stays nil when the payload omits it.
	CompletionTokensDetails *struct {
		ReasoningTokens uint32 `json:"reasoning_tokens"`
	} `json:"completion_tokens_details"`
}
|
|
|
|
func (u *deepSeekUsage) ToUnified() *models.Usage {
|
|
usage := &models.Usage{
|
|
PromptTokens: u.PromptTokens,
|
|
CompletionTokens: u.CompletionTokens,
|
|
TotalTokens: u.TotalTokens,
|
|
}
|
|
if u.PromptCacheHitTokens > 0 {
|
|
usage.CacheReadTokens = &u.PromptCacheHitTokens
|
|
}
|
|
if u.CompletionTokensDetails != nil && u.CompletionTokensDetails.ReasoningTokens > 0 {
|
|
usage.ReasoningTokens = &u.CompletionTokensDetails.ReasoningTokens
|
|
}
|
|
return usage
|
|
}
|
|
|
|
func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
|
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
|
}
|
|
|
|
body := BuildOpenAIBody(req, messagesJSON, false)
|
|
|
|
// Sanitize for deepseek-reasoner
|
|
if req.Model == "deepseek-reasoner" {
|
|
delete(body, "temperature")
|
|
delete(body, "top_p")
|
|
delete(body, "presence_penalty")
|
|
delete(body, "frequency_penalty")
|
|
|
|
if msgs, ok := body["messages"].([]interface{}); ok {
|
|
for _, m := range msgs {
|
|
if msg, ok := m.(map[string]interface{}); ok {
|
|
if msg["role"] == "assistant" {
|
|
if msg["reasoning_content"] == nil {
|
|
msg["reasoning_content"] = " "
|
|
}
|
|
if msg["content"] == nil || msg["content"] == "" {
|
|
msg["content"] = ""
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
resp, err := p.client.R().
|
|
SetContext(ctx).
|
|
SetHeader("Authorization", "Bearer "+p.apiKey).
|
|
SetBody(body).
|
|
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("request failed: %w", err)
|
|
}
|
|
|
|
if !resp.IsSuccess() {
|
|
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
|
|
}
|
|
|
|
var respJSON map[string]interface{}
|
|
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
|
|
return nil, fmt.Errorf("failed to parse response: %w", err)
|
|
}
|
|
|
|
result, err := ParseOpenAIResponse(respJSON, req.Model)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Fix usage for DeepSeek specifically if details were missing in ParseOpenAIResponse
|
|
if usageData, ok := respJSON["usage"]; ok {
|
|
var dUsage deepSeekUsage
|
|
usageBytes, _ := json.Marshal(usageData)
|
|
if err := json.Unmarshal(usageBytes, &dUsage); err == nil {
|
|
result.Usage = dUsage.ToUnified()
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
|
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
|
}
|
|
|
|
body := BuildOpenAIBody(req, messagesJSON, true)
|
|
|
|
// Sanitize for deepseek-reasoner
|
|
if req.Model == "deepseek-reasoner" {
|
|
delete(body, "temperature")
|
|
delete(body, "top_p")
|
|
delete(body, "presence_penalty")
|
|
delete(body, "frequency_penalty")
|
|
|
|
if msgs, ok := body["messages"].([]interface{}); ok {
|
|
for _, m := range msgs {
|
|
if msg, ok := m.(map[string]interface{}); ok {
|
|
if msg["role"] == "assistant" {
|
|
if msg["reasoning_content"] == nil {
|
|
msg["reasoning_content"] = " "
|
|
}
|
|
if msg["content"] == nil || msg["content"] == "" {
|
|
msg["content"] = ""
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
resp, err := p.client.R().
|
|
SetContext(ctx).
|
|
SetHeader("Authorization", "Bearer "+p.apiKey).
|
|
SetBody(body).
|
|
SetDoNotParseResponse(true).
|
|
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("request failed: %w", err)
|
|
}
|
|
|
|
if !resp.IsSuccess() {
|
|
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
|
|
}
|
|
|
|
ch := make(chan *models.ChatCompletionStreamResponse)
|
|
|
|
go func() {
|
|
defer close(ch)
|
|
// Custom scanner loop to handle DeepSeek specific usage in chunks
|
|
err := StreamDeepSeek(resp.RawBody(), ch)
|
|
if err != nil {
|
|
fmt.Printf("DeepSeek Stream error: %v\n", err)
|
|
}
|
|
}()
|
|
|
|
return ch, nil
|
|
}
|
|
|
|
func StreamDeepSeek(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse) error {
|
|
defer ctx.Close()
|
|
scanner := bufio.NewScanner(ctx)
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
if line == "" || !strings.HasPrefix(line, "data: ") {
|
|
continue
|
|
}
|
|
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
if data == "[DONE]" {
|
|
break
|
|
}
|
|
|
|
var chunk models.ChatCompletionStreamResponse
|
|
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
|
continue
|
|
}
|
|
|
|
// Fix DeepSeek specific usage in stream
|
|
var rawChunk struct {
|
|
Usage *deepSeekUsage `json:"usage"`
|
|
}
|
|
if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
|
|
chunk.Usage = rawChunk.Usage.ToUnified()
|
|
}
|
|
|
|
ch <- &chunk
|
|
}
|
|
return scanner.Err()
|
|
}
|