feat: migrate backend from rust to go
This commit replaces the Axum/Rust backend with a Gin/Go implementation. The original Rust code has been archived in the 'rust' branch.
This commit is contained in:
143
internal/providers/deepseek.go
Normal file
143
internal/providers/deepseek.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package providers
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"fmt"
	"io"

	"github.com/go-resty/resty/v2"

	"llm-proxy/internal/config"
	"llm-proxy/internal/models"
)
|
||||
|
||||
// DeepSeekProvider implements the Provider interface for the DeepSeek API,
// which speaks an OpenAI-compatible chat-completions wire format.
type DeepSeekProvider struct {
	client *resty.Client         // shared HTTP client, reused across requests
	config config.DeepSeekConfig // carries BaseURL used to build endpoint URLs
	apiKey string                // attached as a Bearer token on every request
}

// NewDeepSeekProvider constructs a DeepSeekProvider with a fresh resty client.
func NewDeepSeekProvider(cfg config.DeepSeekConfig, apiKey string) *DeepSeekProvider {
	return &DeepSeekProvider{
		client: resty.New(),
		config: cfg,
		apiKey: apiKey,
	}
}

// Name returns the stable provider identifier "deepseek".
func (p *DeepSeekProvider) Name() string {
	return "deepseek"
}
|
||||
|
||||
func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, false)
|
||||
|
||||
// Sanitize for deepseek-reasoner
|
||||
if req.Model == "deepseek-reasoner" {
|
||||
delete(body, "temperature")
|
||||
delete(body, "top_p")
|
||||
delete(body, "presence_penalty")
|
||||
delete(body, "frequency_penalty")
|
||||
|
||||
// Ensure assistant messages have content and reasoning_content
|
||||
if msgs, ok := body["messages"].([]interface{}); ok {
|
||||
for _, m := range msgs {
|
||||
if msg, ok := m.(map[string]interface{}); ok {
|
||||
if msg["role"] == "assistant" {
|
||||
if msg["reasoning_content"] == nil {
|
||||
msg["reasoning_content"] = " "
|
||||
}
|
||||
if msg["content"] == nil || msg["content"] == "" {
|
||||
msg["content"] = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return ParseOpenAIResponse(respJSON, req.Model)
|
||||
}
|
||||
|
||||
func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, true)
|
||||
|
||||
// Sanitize for deepseek-reasoner
|
||||
if req.Model == "deepseek-reasoner" {
|
||||
delete(body, "temperature")
|
||||
delete(body, "top_p")
|
||||
delete(body, "presence_penalty")
|
||||
delete(body, "frequency_penalty")
|
||||
|
||||
// Ensure assistant messages have content and reasoning_content
|
||||
if msgs, ok := body["messages"].([]interface{}); ok {
|
||||
for _, m := range msgs {
|
||||
if msg, ok := m.(map[string]interface{}); ok {
|
||||
if msg["role"] == "assistant" {
|
||||
if msg["reasoning_content"] == nil {
|
||||
msg["reasoning_content"] = " "
|
||||
}
|
||||
if msg["content"] == nil || msg["content"] == "" {
|
||||
msg["content"] = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamOpenAI(resp.RawBody(), ch)
|
||||
if err != nil {
|
||||
fmt.Printf("DeepSeek Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
254
internal/providers/gemini.go
Normal file
254
internal/providers/gemini.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package providers
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"time"

	"github.com/go-resty/resty/v2"

	"llm-proxy/internal/config"
	"llm-proxy/internal/models"
)
|
||||
|
||||
// GeminiProvider implements the Provider interface for the Google Gemini
// generateContent API, translating to and from the unified request model.
type GeminiProvider struct {
	client *resty.Client       // shared HTTP client, reused across requests
	config config.GeminiConfig // carries BaseURL used to build endpoint URLs
	apiKey string              // passed as the "key" query parameter
}

// NewGeminiProvider constructs a GeminiProvider with a fresh resty client.
func NewGeminiProvider(cfg config.GeminiConfig, apiKey string) *GeminiProvider {
	return &GeminiProvider{
		client: resty.New(),
		config: cfg,
		apiKey: apiKey,
	}
}

// Name returns the stable provider identifier "gemini".
func (p *GeminiProvider) Name() string {
	return "gemini"
}
|
||||
|
||||
// GeminiRequest is the JSON body for Gemini generateContent calls.
type GeminiRequest struct {
	Contents []GeminiContent `json:"contents"`
}

// GeminiContent is a single conversation turn (role "user" or "model")
// with its ordered list of parts.
type GeminiContent struct {
	Role  string       `json:"role,omitempty"`
	Parts []GeminiPart `json:"parts"`
}

// GeminiPart is a union type: callers set exactly one of Text, InlineData,
// FunctionCall, or FunctionResponse per part.
type GeminiPart struct {
	Text             string                  `json:"text,omitempty"`
	InlineData       *GeminiInlineData       `json:"inlineData,omitempty"`
	FunctionCall     *GeminiFunctionCall     `json:"functionCall,omitempty"`
	FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
}

// GeminiInlineData carries base64-encoded media alongside its MIME type.
type GeminiInlineData struct {
	MimeType string `json:"mimeType"`
	Data     string `json:"data"`
}

// GeminiFunctionCall mirrors a tool invocation requested by the model.
type GeminiFunctionCall struct {
	Name string          `json:"name"`
	Args json.RawMessage `json:"args"`
}

// GeminiFunctionResponse returns a tool's result to the model.
type GeminiFunctionResponse struct {
	Name     string          `json:"name"`
	Response json.RawMessage `json:"response"`
}
|
||||
|
||||
func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
// Gemini mapping
|
||||
var contents []GeminiContent
|
||||
for _, msg := range req.Messages {
|
||||
role := "user"
|
||||
if msg.Role == "assistant" {
|
||||
role = "model"
|
||||
} else if msg.Role == "tool" {
|
||||
role = "user" // Tool results are user-side in Gemini
|
||||
}
|
||||
|
||||
var parts []GeminiPart
|
||||
|
||||
// Handle tool responses
|
||||
if msg.Role == "tool" {
|
||||
text := ""
|
||||
if len(msg.Content) > 0 {
|
||||
text = msg.Content[0].Text
|
||||
}
|
||||
|
||||
// Gemini expects functionResponse to be an object
|
||||
name := "unknown_function"
|
||||
if msg.Name != nil {
|
||||
name = *msg.Name
|
||||
}
|
||||
|
||||
parts = append(parts, GeminiPart{
|
||||
FunctionResponse: &GeminiFunctionResponse{
|
||||
Name: name,
|
||||
Response: json.RawMessage(text),
|
||||
},
|
||||
})
|
||||
} else {
|
||||
for _, cp := range msg.Content {
|
||||
if cp.Type == "text" {
|
||||
parts = append(parts, GeminiPart{Text: cp.Text})
|
||||
} else if cp.Image != nil {
|
||||
base64Data, mimeType, _ := cp.Image.ToBase64()
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: base64Data,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Handle assistant tool calls
|
||||
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
parts = append(parts, GeminiPart{
|
||||
FunctionCall: &GeminiFunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Args: json.RawMessage(tc.Function.Arguments),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
contents = append(contents, GeminiContent{
|
||||
Role: role,
|
||||
Parts: parts,
|
||||
})
|
||||
}
|
||||
|
||||
body := GeminiRequest{
|
||||
Contents: contents,
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetBody(body).
|
||||
Post(url)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
// Parse Gemini response and convert to OpenAI format
|
||||
var geminiResp struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"parts"`
|
||||
} `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
} `json:"candidates"`
|
||||
UsageMetadata struct {
|
||||
PromptTokenCount uint32 `json:"promptTokenCount"`
|
||||
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
|
||||
TotalTokenCount uint32 `json:"totalTokenCount"`
|
||||
} `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(resp.Body(), &geminiResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(geminiResp.Candidates) == 0 {
|
||||
return nil, fmt.Errorf("no candidates in Gemini response")
|
||||
}
|
||||
|
||||
content := ""
|
||||
for _, p := range geminiResp.Candidates[0].Content.Parts {
|
||||
content += p.Text
|
||||
}
|
||||
|
||||
openAIResp := &models.ChatCompletionResponse{
|
||||
ID: "gemini-" + req.Model,
|
||||
Object: "chat.completion",
|
||||
Created: 0, // Should be current timestamp
|
||||
Model: req.Model,
|
||||
Choices: []models.ChatChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: models.ChatMessage{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: &geminiResp.Candidates[0].FinishReason,
|
||||
},
|
||||
},
|
||||
Usage: &models.Usage{
|
||||
PromptTokens: geminiResp.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResp.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: geminiResp.UsageMetadata.TotalTokenCount,
|
||||
},
|
||||
}
|
||||
|
||||
return openAIResp, nil
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
// Simplified Gemini mapping
|
||||
var contents []GeminiContent
|
||||
for _, msg := range req.Messages {
|
||||
role := "user"
|
||||
if msg.Role == "assistant" {
|
||||
role = "model"
|
||||
}
|
||||
|
||||
var parts []GeminiPart
|
||||
for _, p := range msg.Content {
|
||||
parts = append(parts, GeminiPart{Text: p.Text})
|
||||
}
|
||||
|
||||
contents = append(contents, GeminiContent{
|
||||
Role: role,
|
||||
Parts: parts,
|
||||
})
|
||||
}
|
||||
|
||||
body := GeminiRequest{
|
||||
Contents: contents,
|
||||
}
|
||||
|
||||
// Use streamGenerateContent for streaming
|
||||
url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(url)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamGemini(resp.RawBody(), ch, req.Model)
|
||||
if err != nil {
|
||||
fmt.Printf("Gemini Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
95
internal/providers/grok.go
Normal file
95
internal/providers/grok.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package providers
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"fmt"
	"io"

	"github.com/go-resty/resty/v2"

	"llm-proxy/internal/config"
	"llm-proxy/internal/models"
)
|
||||
|
||||
// GrokProvider implements the Provider interface for the xAI Grok API,
// which speaks an OpenAI-compatible chat-completions wire format.
type GrokProvider struct {
	client *resty.Client      // shared HTTP client, reused across requests
	config config.GrokConfig  // carries BaseURL used to build endpoint URLs
	apiKey string             // attached as a Bearer token on every request
}

// NewGrokProvider constructs a GrokProvider with a fresh resty client.
func NewGrokProvider(cfg config.GrokConfig, apiKey string) *GrokProvider {
	return &GrokProvider{
		client: resty.New(),
		config: cfg,
		apiKey: apiKey,
	}
}

// Name returns the stable provider identifier "grok".
func (p *GrokProvider) Name() string {
	return "grok"
}
|
||||
|
||||
func (p *GrokProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, false)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return ParseOpenAIResponse(respJSON, req.Model)
|
||||
}
|
||||
|
||||
func (p *GrokProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, true)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamOpenAI(resp.RawBody(), ch)
|
||||
if err != nil {
|
||||
fmt.Printf("Grok Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
259
internal/providers/helpers.go
Normal file
259
internal/providers/helpers.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"llm-proxy/internal/models"
|
||||
)
|
||||
|
||||
// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON,
// including tool-call/tool-result messages and image content parts.
//
// Tool messages are emitted flat ({"role":"tool","content":text,...});
// all other messages use the multi-part content array form. Tool-call IDs
// are truncated to 40 characters — presumably to satisfy a provider-side
// length limit; TODO confirm which provider requires this.
func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) {
	var result []interface{}
	for _, m := range messages {
		// Tool results: OpenAI expects a plain string content plus the
		// tool_call_id they answer. Only the first content part is used;
		// assumes tool results carry a single text part — TODO confirm.
		if m.Role == "tool" {
			text := ""
			if len(m.Content) > 0 {
				text = m.Content[0].Text
			}
			msg := map[string]interface{}{
				"role":    "tool",
				"content": text,
			}
			if m.ToolCallID != nil {
				id := *m.ToolCallID
				if len(id) > 40 {
					id = id[:40]
				}
				msg["tool_call_id"] = id
			}
			if m.Name != nil {
				msg["name"] = *m.Name
			}
			result = append(result, msg)
			continue
		}

		// Non-tool messages: build the content-parts array. Parts that are
		// neither text nor image are silently dropped.
		var parts []interface{}
		for _, p := range m.Content {
			if p.Type == "text" {
				parts = append(parts, map[string]interface{}{
					"type": "text",
					"text": p.Text,
				})
			} else if p.Image != nil {
				base64Data, mimeType, err := p.Image.ToBase64()
				if err != nil {
					return nil, fmt.Errorf("failed to convert image to base64: %w", err)
				}
				// Images are embedded as data: URLs in the OpenAI image_url form.
				parts = append(parts, map[string]interface{}{
					"type": "image_url",
					"image_url": map[string]interface{}{
						"url": fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data),
					},
				})
			}
		}

		msg := map[string]interface{}{
			"role":    m.Role,
			"content": parts,
		}

		if m.ReasoningContent != nil {
			msg["reasoning_content"] = *m.ReasoningContent
		}

		if len(m.ToolCalls) > 0 {
			// Copy before mutating so the caller's ToolCalls slice is not
			// modified when IDs are truncated to 40 characters.
			sanitizedCalls := make([]models.ToolCall, len(m.ToolCalls))
			copy(sanitizedCalls, m.ToolCalls)
			for i := range sanitizedCalls {
				if len(sanitizedCalls[i].ID) > 40 {
					sanitizedCalls[i].ID = sanitizedCalls[i].ID[:40]
				}
			}
			msg["tool_calls"] = sanitizedCalls
			// An assistant message with tool calls but no content parts must
			// still carry a content field; use the empty string.
			if len(parts) == 0 {
				msg["content"] = ""
			}
		}

		if m.Name != nil {
			msg["name"] = *m.Name
		}

		result = append(result, msg)
	}
	return result, nil
}
|
||||
|
||||
func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{}, stream bool) map[string]interface{} {
|
||||
body := map[string]interface{}{
|
||||
"model": request.Model,
|
||||
"messages": messagesJSON,
|
||||
"stream": stream,
|
||||
}
|
||||
|
||||
if stream {
|
||||
body["stream_options"] = map[string]interface{}{
|
||||
"include_usage": true,
|
||||
}
|
||||
}
|
||||
|
||||
if request.Temperature != nil {
|
||||
body["temperature"] = *request.Temperature
|
||||
}
|
||||
if request.MaxTokens != nil {
|
||||
body["max_tokens"] = *request.MaxTokens
|
||||
}
|
||||
if len(request.Tools) > 0 {
|
||||
body["tools"] = request.Tools
|
||||
}
|
||||
if request.ToolChoice != nil {
|
||||
var toolChoice interface{}
|
||||
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
|
||||
body["tool_choice"] = toolChoice
|
||||
}
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
func ParseOpenAIResponse(respJSON map[string]interface{}, model string) (*models.ChatCompletionResponse, error) {
|
||||
data, err := json.Marshal(respJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resp models.ChatCompletionResponse
|
||||
if err := json.Unmarshal(data, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// Streaming support
|
||||
|
||||
func ParseOpenAIStreamChunk(line string) (*models.ChatCompletionStreamResponse, bool, error) {
|
||||
if line == "" {
|
||||
return nil, false, nil
|
||||
}
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
var chunk models.ChatCompletionStreamResponse
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
return nil, false, fmt.Errorf("failed to unmarshal stream chunk: %w", err)
|
||||
}
|
||||
|
||||
return &chunk, false, nil
|
||||
}
|
||||
|
||||
func StreamOpenAI(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse) error {
|
||||
defer ctx.Close()
|
||||
scanner := bufio.NewScanner(ctx)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
chunk, done, err := ParseOpenAIStreamChunk(line)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
if chunk != nil {
|
||||
ch <- chunk
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error {
|
||||
defer ctx.Close()
|
||||
|
||||
dec := json.NewDecoder(ctx)
|
||||
|
||||
t, err := dec.Token()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if delim, ok := t.(json.Delim); ok && delim == '[' {
|
||||
for dec.More() {
|
||||
var geminiChunk struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Thought string `json:"thought,omitempty"`
|
||||
} `json:"parts"`
|
||||
} `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
} `json:"candidates"`
|
||||
UsageMetadata struct {
|
||||
PromptTokenCount uint32 `json:"promptTokenCount"`
|
||||
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
|
||||
TotalTokenCount uint32 `json:"totalTokenCount"`
|
||||
} `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
if err := dec.Decode(&geminiChunk); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(geminiChunk.Candidates) > 0 {
|
||||
content := ""
|
||||
var reasoning *string
|
||||
for _, p := range geminiChunk.Candidates[0].Content.Parts {
|
||||
if p.Text != "" {
|
||||
content += p.Text
|
||||
}
|
||||
if p.Thought != "" {
|
||||
if reasoning == nil {
|
||||
reasoning = new(string)
|
||||
}
|
||||
*reasoning += p.Thought
|
||||
}
|
||||
}
|
||||
|
||||
finishReason := strings.ToLower(geminiChunk.Candidates[0].FinishReason)
|
||||
if finishReason == "stop" {
|
||||
finishReason = "stop"
|
||||
}
|
||||
|
||||
ch <- &models.ChatCompletionStreamResponse{
|
||||
ID: "gemini-stream",
|
||||
Object: "chat.completion.chunk",
|
||||
Created: 0,
|
||||
Model: model,
|
||||
Choices: []models.ChatStreamChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Delta: models.ChatStreamDelta{
|
||||
Content: &content,
|
||||
ReasoningContent: reasoning,
|
||||
},
|
||||
FinishReason: &finishReason,
|
||||
},
|
||||
},
|
||||
Usage: &models.Usage{
|
||||
PromptTokens: geminiChunk.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiChunk.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: geminiChunk.UsageMetadata.TotalTokenCount,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
113
internal/providers/openai.go
Normal file
113
internal/providers/openai.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package providers
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"strings"

	"github.com/go-resty/resty/v2"

	"llm-proxy/internal/config"
	"llm-proxy/internal/models"
)
|
||||
|
||||
// OpenAIProvider implements the Provider interface for the OpenAI
// chat-completions API.
type OpenAIProvider struct {
	client *resty.Client        // shared HTTP client, reused across requests
	config config.OpenAIConfig  // carries BaseURL used to build endpoint URLs
	apiKey string               // attached as a Bearer token on every request
}

// NewOpenAIProvider constructs an OpenAIProvider with a fresh resty client.
func NewOpenAIProvider(cfg config.OpenAIConfig, apiKey string) *OpenAIProvider {
	return &OpenAIProvider{
		client: resty.New(),
		config: cfg,
		apiKey: apiKey,
	}
}

// Name returns the stable provider identifier "openai".
func (p *OpenAIProvider) Name() string {
	return "openai"
}
|
||||
|
||||
func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, false)
|
||||
|
||||
// Transition: Newer models require max_completion_tokens
|
||||
if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
|
||||
if maxTokens, ok := body["max_tokens"]; ok {
|
||||
delete(body, "max_tokens")
|
||||
body["max_completion_tokens"] = maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return ParseOpenAIResponse(respJSON, req.Model)
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, true)
|
||||
|
||||
// Transition: Newer models require max_completion_tokens
|
||||
if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
|
||||
if maxTokens, ok := body["max_tokens"]; ok {
|
||||
delete(body, "max_tokens")
|
||||
body["max_completion_tokens"] = maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamOpenAI(resp.RawBody(), ch)
|
||||
if err != nil {
|
||||
// In a real app, you might want to send an error chunk or log it
|
||||
fmt.Printf("Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
13
internal/providers/provider.go
Normal file
13
internal/providers/provider.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"llm-proxy/internal/models"
|
||||
)
|
||||
|
||||
// Provider is the common interface implemented by every upstream LLM backend
// (OpenAI, DeepSeek, Gemini, Grok). Implementations translate the unified
// request into the provider's wire format and normalize replies back into
// the OpenAI-compatible response types.
type Provider interface {
	// Name returns the stable provider identifier (e.g. "openai").
	Name() string
	// ChatCompletion performs a blocking chat-completions request.
	ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error)
	// ChatCompletionStream performs a streaming request; the returned channel
	// is closed when the stream ends.
	ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error)
}
|
||||
Reference in New Issue
Block a user