feat: migrate backend from rust to go
This commit replaces the Axum/Rust backend with a Gin/Go implementation. The original Rust code has been archived in the 'rust' branch.
This commit is contained in:
254
internal/providers/gemini.go
Normal file
254
internal/providers/gemini.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package providers
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"time"

	"github.com/go-resty/resty/v2"

	"llm-proxy/internal/config"
	"llm-proxy/internal/models"
)
|
||||
|
||||
// GeminiProvider implements the provider interface for Google's Gemini
// generateContent API, translating unified requests into Gemini's wire format.
type GeminiProvider struct {
	client *resty.Client       // shared HTTP client, created once in NewGeminiProvider
	config config.GeminiConfig // endpoint configuration; BaseURL is used to build request URLs
	apiKey string              // Gemini API key, appended to request URLs as a query parameter
}
|
||||
|
||||
func NewGeminiProvider(cfg config.GeminiConfig, apiKey string) *GeminiProvider {
|
||||
return &GeminiProvider{
|
||||
client: resty.New(),
|
||||
config: cfg,
|
||||
apiKey: apiKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) Name() string {
|
||||
return "gemini"
|
||||
}
|
||||
|
||||
// GeminiRequest is the top-level JSON body for Gemini's generateContent and
// streamGenerateContent endpoints.
type GeminiRequest struct {
	Contents []GeminiContent `json:"contents"`
}

// GeminiContent is a single conversation turn. Gemini uses "user" and "model"
// roles; callers in this file map OpenAI's "assistant" role to "model".
type GeminiContent struct {
	Role  string       `json:"role,omitempty"`
	Parts []GeminiPart `json:"parts"`
}
|
||||
|
||||
// GeminiPart is one piece of a turn's content. Callers set exactly one field
// per part; the zero-valued fields are omitted from the JSON encoding.
type GeminiPart struct {
	Text             string                  `json:"text,omitempty"`
	InlineData       *GeminiInlineData       `json:"inlineData,omitempty"`
	FunctionCall     *GeminiFunctionCall     `json:"functionCall,omitempty"`
	FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
}
|
||||
|
||||
// GeminiInlineData carries base64-encoded binary content (e.g. images)
// together with its MIME type.
type GeminiInlineData struct {
	MimeType string `json:"mimeType"`
	Data     string `json:"data"`
}

// GeminiFunctionCall is a tool invocation emitted by the model. Args is the
// raw JSON arguments object, kept undecoded via json.RawMessage.
type GeminiFunctionCall struct {
	Name string          `json:"name"`
	Args json.RawMessage `json:"args"`
}

// GeminiFunctionResponse returns a tool's result to the model. Response is
// raw JSON; Gemini expects it to be a JSON object.
type GeminiFunctionResponse struct {
	Name     string          `json:"name"`
	Response json.RawMessage `json:"response"`
}
|
||||
|
||||
func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
// Gemini mapping
|
||||
var contents []GeminiContent
|
||||
for _, msg := range req.Messages {
|
||||
role := "user"
|
||||
if msg.Role == "assistant" {
|
||||
role = "model"
|
||||
} else if msg.Role == "tool" {
|
||||
role = "user" // Tool results are user-side in Gemini
|
||||
}
|
||||
|
||||
var parts []GeminiPart
|
||||
|
||||
// Handle tool responses
|
||||
if msg.Role == "tool" {
|
||||
text := ""
|
||||
if len(msg.Content) > 0 {
|
||||
text = msg.Content[0].Text
|
||||
}
|
||||
|
||||
// Gemini expects functionResponse to be an object
|
||||
name := "unknown_function"
|
||||
if msg.Name != nil {
|
||||
name = *msg.Name
|
||||
}
|
||||
|
||||
parts = append(parts, GeminiPart{
|
||||
FunctionResponse: &GeminiFunctionResponse{
|
||||
Name: name,
|
||||
Response: json.RawMessage(text),
|
||||
},
|
||||
})
|
||||
} else {
|
||||
for _, cp := range msg.Content {
|
||||
if cp.Type == "text" {
|
||||
parts = append(parts, GeminiPart{Text: cp.Text})
|
||||
} else if cp.Image != nil {
|
||||
base64Data, mimeType, _ := cp.Image.ToBase64()
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: base64Data,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Handle assistant tool calls
|
||||
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
parts = append(parts, GeminiPart{
|
||||
FunctionCall: &GeminiFunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Args: json.RawMessage(tc.Function.Arguments),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
contents = append(contents, GeminiContent{
|
||||
Role: role,
|
||||
Parts: parts,
|
||||
})
|
||||
}
|
||||
|
||||
body := GeminiRequest{
|
||||
Contents: contents,
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetBody(body).
|
||||
Post(url)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
// Parse Gemini response and convert to OpenAI format
|
||||
var geminiResp struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"parts"`
|
||||
} `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
} `json:"candidates"`
|
||||
UsageMetadata struct {
|
||||
PromptTokenCount uint32 `json:"promptTokenCount"`
|
||||
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
|
||||
TotalTokenCount uint32 `json:"totalTokenCount"`
|
||||
} `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(resp.Body(), &geminiResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(geminiResp.Candidates) == 0 {
|
||||
return nil, fmt.Errorf("no candidates in Gemini response")
|
||||
}
|
||||
|
||||
content := ""
|
||||
for _, p := range geminiResp.Candidates[0].Content.Parts {
|
||||
content += p.Text
|
||||
}
|
||||
|
||||
openAIResp := &models.ChatCompletionResponse{
|
||||
ID: "gemini-" + req.Model,
|
||||
Object: "chat.completion",
|
||||
Created: 0, // Should be current timestamp
|
||||
Model: req.Model,
|
||||
Choices: []models.ChatChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: models.ChatMessage{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: &geminiResp.Candidates[0].FinishReason,
|
||||
},
|
||||
},
|
||||
Usage: &models.Usage{
|
||||
PromptTokens: geminiResp.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResp.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: geminiResp.UsageMetadata.TotalTokenCount,
|
||||
},
|
||||
}
|
||||
|
||||
return openAIResp, nil
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
// Simplified Gemini mapping
|
||||
var contents []GeminiContent
|
||||
for _, msg := range req.Messages {
|
||||
role := "user"
|
||||
if msg.Role == "assistant" {
|
||||
role = "model"
|
||||
}
|
||||
|
||||
var parts []GeminiPart
|
||||
for _, p := range msg.Content {
|
||||
parts = append(parts, GeminiPart{Text: p.Text})
|
||||
}
|
||||
|
||||
contents = append(contents, GeminiContent{
|
||||
Role: role,
|
||||
Parts: parts,
|
||||
})
|
||||
}
|
||||
|
||||
body := GeminiRequest{
|
||||
Contents: contents,
|
||||
}
|
||||
|
||||
// Use streamGenerateContent for streaming
|
||||
url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(url)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamGemini(resp.RawBody(), ch, req.Model)
|
||||
if err != nil {
|
||||
fmt.Printf("Gemini Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
Reference in New Issue
Block a user