Updated all naming from LLM Proxy to GopherGate. Implemented new CSS-based branding and updated Go module/binary naming.
255 lines · 6.0 KiB · Go
package providers
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"time"

	"github.com/go-resty/resty/v2"

	"gophergate/internal/config"
	"gophergate/internal/models"
)
|
|
|
|
// GeminiProvider adapts Google's Gemini generateContent API to the
// provider interface used by the proxy (Name, ChatCompletion,
// ChatCompletionStream).
type GeminiProvider struct {
	client *resty.Client       // shared HTTP client for all requests
	config config.GeminiConfig // carries at least BaseURL (see request URL construction)
	apiKey string              // appended as the ?key= query parameter on each call
}
|
|
|
|
func NewGeminiProvider(cfg config.GeminiConfig, apiKey string) *GeminiProvider {
|
|
return &GeminiProvider{
|
|
client: resty.New(),
|
|
config: cfg,
|
|
apiKey: apiKey,
|
|
}
|
|
}
|
|
|
|
func (p *GeminiProvider) Name() string {
|
|
return "gemini"
|
|
}
|
|
|
|
// GeminiRequest is the JSON body sent to the generateContent and
// streamGenerateContent endpoints.
type GeminiRequest struct {
	Contents []GeminiContent `json:"contents"`
}
|
|
|
|
// GeminiContent is one turn of the conversation. Role is "user" or
// "model" (see the role mapping in ChatCompletion).
type GeminiContent struct {
	Role  string       `json:"role,omitempty"`
	Parts []GeminiPart `json:"parts"`
}
|
|
|
|
// GeminiPart is a single piece of content within a turn. The request
// builders in this file set exactly one field per part: text, inline
// media, a model-issued function call, or a client-side function response.
type GeminiPart struct {
	Text             string                  `json:"text,omitempty"`
	InlineData       *GeminiInlineData       `json:"inlineData,omitempty"`
	FunctionCall     *GeminiFunctionCall     `json:"functionCall,omitempty"`
	FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
}
|
|
|
|
// GeminiInlineData carries embedded media as base64 data plus its MIME
// type (populated from Image.ToBase64 in ChatCompletion).
type GeminiInlineData struct {
	MimeType string `json:"mimeType"`
	Data     string `json:"data"`
}
|
|
|
|
// GeminiFunctionCall mirrors an assistant tool call: the function name
// and its arguments as raw JSON (passed through without re-encoding).
type GeminiFunctionCall struct {
	Name string          `json:"name"`
	Args json.RawMessage `json:"args"`
}
|
|
|
|
// GeminiFunctionResponse carries a tool's result back to the model.
// Response is raw JSON; Gemini expects it to be a JSON object.
type GeminiFunctionResponse struct {
	Name     string          `json:"name"`
	Response json.RawMessage `json:"response"`
}
|
|
|
|
func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
|
// Gemini mapping
|
|
var contents []GeminiContent
|
|
for _, msg := range req.Messages {
|
|
role := "user"
|
|
if msg.Role == "assistant" {
|
|
role = "model"
|
|
} else if msg.Role == "tool" {
|
|
role = "user" // Tool results are user-side in Gemini
|
|
}
|
|
|
|
var parts []GeminiPart
|
|
|
|
// Handle tool responses
|
|
if msg.Role == "tool" {
|
|
text := ""
|
|
if len(msg.Content) > 0 {
|
|
text = msg.Content[0].Text
|
|
}
|
|
|
|
// Gemini expects functionResponse to be an object
|
|
name := "unknown_function"
|
|
if msg.Name != nil {
|
|
name = *msg.Name
|
|
}
|
|
|
|
parts = append(parts, GeminiPart{
|
|
FunctionResponse: &GeminiFunctionResponse{
|
|
Name: name,
|
|
Response: json.RawMessage(text),
|
|
},
|
|
})
|
|
} else {
|
|
for _, cp := range msg.Content {
|
|
if cp.Type == "text" {
|
|
parts = append(parts, GeminiPart{Text: cp.Text})
|
|
} else if cp.Image != nil {
|
|
base64Data, mimeType, _ := cp.Image.ToBase64()
|
|
parts = append(parts, GeminiPart{
|
|
InlineData: &GeminiInlineData{
|
|
MimeType: mimeType,
|
|
Data: base64Data,
|
|
},
|
|
})
|
|
}
|
|
}
|
|
|
|
// Handle assistant tool calls
|
|
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
|
for _, tc := range msg.ToolCalls {
|
|
parts = append(parts, GeminiPart{
|
|
FunctionCall: &GeminiFunctionCall{
|
|
Name: tc.Function.Name,
|
|
Args: json.RawMessage(tc.Function.Arguments),
|
|
},
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
contents = append(contents, GeminiContent{
|
|
Role: role,
|
|
Parts: parts,
|
|
})
|
|
}
|
|
|
|
body := GeminiRequest{
|
|
Contents: contents,
|
|
}
|
|
|
|
url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)
|
|
|
|
resp, err := p.client.R().
|
|
SetContext(ctx).
|
|
SetBody(body).
|
|
Post(url)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("request failed: %w", err)
|
|
}
|
|
|
|
if !resp.IsSuccess() {
|
|
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
|
|
}
|
|
|
|
// Parse Gemini response and convert to OpenAI format
|
|
var geminiResp struct {
|
|
Candidates []struct {
|
|
Content struct {
|
|
Parts []struct {
|
|
Text string `json:"text"`
|
|
} `json:"parts"`
|
|
} `json:"content"`
|
|
FinishReason string `json:"finishReason"`
|
|
} `json:"candidates"`
|
|
UsageMetadata struct {
|
|
PromptTokenCount uint32 `json:"promptTokenCount"`
|
|
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
|
|
TotalTokenCount uint32 `json:"totalTokenCount"`
|
|
} `json:"usageMetadata"`
|
|
}
|
|
|
|
if err := json.Unmarshal(resp.Body(), &geminiResp); err != nil {
|
|
return nil, fmt.Errorf("failed to parse response: %w", err)
|
|
}
|
|
|
|
if len(geminiResp.Candidates) == 0 {
|
|
return nil, fmt.Errorf("no candidates in Gemini response")
|
|
}
|
|
|
|
content := ""
|
|
for _, p := range geminiResp.Candidates[0].Content.Parts {
|
|
content += p.Text
|
|
}
|
|
|
|
openAIResp := &models.ChatCompletionResponse{
|
|
ID: "gemini-" + req.Model,
|
|
Object: "chat.completion",
|
|
Created: 0, // Should be current timestamp
|
|
Model: req.Model,
|
|
Choices: []models.ChatChoice{
|
|
{
|
|
Index: 0,
|
|
Message: models.ChatMessage{
|
|
Role: "assistant",
|
|
Content: content,
|
|
},
|
|
FinishReason: &geminiResp.Candidates[0].FinishReason,
|
|
},
|
|
},
|
|
Usage: &models.Usage{
|
|
PromptTokens: geminiResp.UsageMetadata.PromptTokenCount,
|
|
CompletionTokens: geminiResp.UsageMetadata.CandidatesTokenCount,
|
|
TotalTokens: geminiResp.UsageMetadata.TotalTokenCount,
|
|
},
|
|
}
|
|
|
|
return openAIResp, nil
|
|
}
|
|
|
|
func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
|
// Simplified Gemini mapping
|
|
var contents []GeminiContent
|
|
for _, msg := range req.Messages {
|
|
role := "user"
|
|
if msg.Role == "assistant" {
|
|
role = "model"
|
|
}
|
|
|
|
var parts []GeminiPart
|
|
for _, p := range msg.Content {
|
|
parts = append(parts, GeminiPart{Text: p.Text})
|
|
}
|
|
|
|
contents = append(contents, GeminiContent{
|
|
Role: role,
|
|
Parts: parts,
|
|
})
|
|
}
|
|
|
|
body := GeminiRequest{
|
|
Contents: contents,
|
|
}
|
|
|
|
// Use streamGenerateContent for streaming
|
|
url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)
|
|
|
|
resp, err := p.client.R().
|
|
SetContext(ctx).
|
|
SetBody(body).
|
|
SetDoNotParseResponse(true).
|
|
Post(url)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("request failed: %w", err)
|
|
}
|
|
|
|
if !resp.IsSuccess() {
|
|
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
|
|
}
|
|
|
|
ch := make(chan *models.ChatCompletionStreamResponse)
|
|
|
|
go func() {
|
|
defer close(ch)
|
|
err := StreamGemini(resp.RawBody(), ch, req.Model)
|
|
if err != nil {
|
|
fmt.Printf("Gemini Stream error: %v\n", err)
|
|
}
|
|
}()
|
|
|
|
return ch, nil
|
|
}
|