package providers

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"time"

	"llm-proxy/internal/config"
	"llm-proxy/internal/models"

	"github.com/go-resty/resty/v2"
)

// GeminiProvider adapts the Google Gemini REST API to the proxy's unified
// provider interface.
type GeminiProvider struct {
	client *resty.Client
	config config.GeminiConfig
	apiKey string
}

func NewGeminiProvider(cfg config.GeminiConfig, apiKey string) *GeminiProvider {
	return &GeminiProvider{
		client: resty.New(),
		config: cfg,
		apiKey: apiKey,
	}
}

func (p *GeminiProvider) Name() string {
	return "gemini"
}

// Request types mirroring the Gemini generateContent JSON body.

type GeminiRequest struct {
	Contents []GeminiContent `json:"contents"`
}

type GeminiContent struct {
	Role  string       `json:"role,omitempty"`
	Parts []GeminiPart `json:"parts"`
}

type GeminiPart struct {
	Text             string                  `json:"text,omitempty"`
	InlineData       *GeminiInlineData       `json:"inlineData,omitempty"`
	FunctionCall     *GeminiFunctionCall     `json:"functionCall,omitempty"`
	FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
}

type GeminiInlineData struct {
	MimeType string `json:"mimeType"`
	Data     string `json:"data"`
}

type GeminiFunctionCall struct {
	Name string          `json:"name"`
	Args json.RawMessage `json:"args"`
}

type GeminiFunctionResponse struct {
	Name     string          `json:"name"`
	Response json.RawMessage `json:"response"`
}
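// For orientation, a GeminiRequest built from the types above marshals to the
// Gemini REST wire shape along these lines (a sketch with illustrative
// values, not output from a real call):
//
//	{
//	  "contents": [
//	    {
//	      "role": "user",
//	      "parts": [
//	        {"text": "Describe this image."},
//	        {"inlineData": {"mimeType": "image/png", "data": "<base64>"}}
//	      ]
//	    },
//	    {
//	      "role": "model",
//	      "parts": [{"functionCall": {"name": "get_weather", "args": {"city": "Paris"}}}]
//	    }
//	  ]
//	}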
func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
	// Map OpenAI-style messages onto Gemini contents: Gemini knows only the
	// "user" and "model" roles, and tool results travel user-side as
	// functionResponse parts.
	var contents []GeminiContent
	for _, msg := range req.Messages {
		role := "user"
		if msg.Role == "assistant" {
			role = "model"
		} else if msg.Role == "tool" {
			role = "user" // Tool results are user-side in Gemini
		}

		var parts []GeminiPart

		if msg.Role == "tool" {
			// Tool responses become functionResponse parts.
			text := ""
			if len(msg.Content) > 0 {
				text = msg.Content[0].Text
			}
			name := "unknown_function"
			if msg.Name != nil {
				name = *msg.Name
			}
			// Gemini expects functionResponse.response to be a JSON object;
			// wrap non-JSON tool output so the request body stays valid.
			response := json.RawMessage(text)
			if !json.Valid(response) {
				response, _ = json.Marshal(map[string]string{"result": text})
			}
			parts = append(parts, GeminiPart{
				FunctionResponse: &GeminiFunctionResponse{
					Name:     name,
					Response: response,
				},
			})
		} else {
			for _, cp := range msg.Content {
				if cp.Type == "text" {
					parts = append(parts, GeminiPart{Text: cp.Text})
				} else if cp.Image != nil {
					base64Data, mimeType, err := cp.Image.ToBase64()
					if err != nil {
						return nil, fmt.Errorf("failed to encode image: %w", err)
					}
					parts = append(parts, GeminiPart{
						InlineData: &GeminiInlineData{
							MimeType: mimeType,
							Data:     base64Data,
						},
					})
				}
			}

			// Assistant tool calls become functionCall parts.
			if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
				for _, tc := range msg.ToolCalls {
					parts = append(parts, GeminiPart{
						FunctionCall: &GeminiFunctionCall{
							Name: tc.Function.Name,
							Args: json.RawMessage(tc.Function.Arguments),
						},
					})
				}
			}
		}

		contents = append(contents, GeminiContent{
			Role:  role,
			Parts: parts,
		})
	}

	body := GeminiRequest{Contents: contents}

	url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)

	resp, err := p.client.R().
		SetContext(ctx).
		SetBody(body).
		Post(url)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}
	if !resp.IsSuccess() {
		return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
	}

	// Parse the Gemini response and convert it to the OpenAI format.
	var geminiResp struct {
		Candidates []struct {
			Content struct {
				Parts []struct {
					Text string `json:"text"`
				} `json:"parts"`
			} `json:"content"`
			FinishReason string `json:"finishReason"`
		} `json:"candidates"`
		UsageMetadata struct {
			PromptTokenCount     uint32 `json:"promptTokenCount"`
			CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
			TotalTokenCount      uint32 `json:"totalTokenCount"`
		} `json:"usageMetadata"`
	}
	if err := json.Unmarshal(resp.Body(), &geminiResp); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}
	if len(geminiResp.Candidates) == 0 {
		return nil, fmt.Errorf("no candidates in Gemini response")
	}

	// Concatenate the text parts of the first candidate.
	content := ""
	for _, part := range geminiResp.Candidates[0].Content.Parts {
		content += part.Text
	}

	openAIResp := &models.ChatCompletionResponse{
		ID:      "gemini-" + req.Model,
		Object:  "chat.completion",
		Created: time.Now().Unix(),
		Model:   req.Model,
		Choices: []models.ChatChoice{
			{
				Index: 0,
				Message: models.ChatMessage{
					Role:    "assistant",
					Content: content,
				},
				FinishReason: &geminiResp.Candidates[0].FinishReason,
			},
		},
		Usage: &models.Usage{
			PromptTokens:     geminiResp.UsageMetadata.PromptTokenCount,
			CompletionTokens: geminiResp.UsageMetadata.CandidatesTokenCount,
			TotalTokens:      geminiResp.UsageMetadata.TotalTokenCount,
		},
	}

	return openAIResp, nil
}

func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
	// Simplified mapping for streaming: text parts only (no images or tool
	// calls).
	var contents []GeminiContent
	for _, msg := range req.Messages {
		role := "user"
		if msg.Role == "assistant" {
			role = "model"
		}
		var parts []GeminiPart
		for _, cp := range msg.Content {
			parts = append(parts, GeminiPart{Text: cp.Text})
		}
		contents = append(contents, GeminiContent{
			Role:  role,
			Parts: parts,
		})
	}

	body := GeminiRequest{Contents: contents}

	// Use the streamGenerateContent endpoint for streaming.
	url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)

	resp, err := p.client.R().
		SetContext(ctx).
		SetBody(body).
		SetDoNotParseResponse(true).
		Post(url)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}
	if !resp.IsSuccess() {
		// With SetDoNotParseResponse the body is not buffered by resty, so
		// read the raw body here to include the error payload.
		defer resp.RawBody().Close()
		errBody, _ := io.ReadAll(resp.RawBody())
		return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), errBody)
	}

	ch := make(chan *models.ChatCompletionStreamResponse)
	go func() {
		defer close(ch)
		if err := StreamGemini(resp.RawBody(), ch, req.Model); err != nil {
			fmt.Printf("Gemini stream error: %v\n", err)
		}
	}()

	return ch, nil
}
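// Example wiring (a sketch, not a prescribed API: the GeminiConfig literal,
// environment variable name, and model name below are illustrative
// assumptions; message construction depends on models.UnifiedRequest, which
// is defined elsewhere in this repo):
//
//	cfg := config.GeminiConfig{BaseURL: "https://generativelanguage.googleapis.com/v1beta"}
//	provider := NewGeminiProvider(cfg, os.Getenv("GEMINI_API_KEY"))
//	resp, err := provider.ChatCompletion(ctx, &models.UnifiedRequest{
//		Model:    "gemini-1.5-flash",
//		Messages: messages,
//	})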