package providers

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"strings"
	"time"

	"github.com/go-resty/resty/v2"

	"gophergate/internal/config"
	"gophergate/internal/models"
)

const (
	// geminiRequestTimeout is the client-wide timeout applied to every request.
	geminiRequestTimeout = 30 * time.Second

	// geminiMaxErrorBody caps how much of an error response is read into an
	// error message on the streaming path.
	geminiMaxErrorBody = 64 << 10

	// geminiUnknownFunction is the last-resort name for a tool result whose
	// originating function cannot be determined.
	geminiUnknownFunction = "unknown_function"
)

// GeminiProvider sends chat requests to the Gemini generateContent and
// streamGenerateContent endpoints and converts between the gateway's unified
// request/response models and Gemini's wire format.
type GeminiProvider struct {
	client *resty.Client
	config config.GeminiConfig
	apiKey string
}

// NewGeminiProvider returns a provider that uses cfg.BaseURL as the endpoint
// root and authenticates with apiKey. The underlying HTTP client has a
// 30-second timeout.
func NewGeminiProvider(cfg config.GeminiConfig, apiKey string) *GeminiProvider {
	return &GeminiProvider{
		client: resty.New().SetTimeout(geminiRequestTimeout),
		config: cfg,
		apiKey: apiKey,
	}
}

// Name returns the provider identifier, "gemini".
func (p *GeminiProvider) Name() string {
	return "gemini"
}

// GeminiRequest is the JSON body posted to the Gemini content endpoints.
type GeminiRequest struct {
	Contents         []GeminiContent         `json:"contents"`
	Tools            []GeminiTool            `json:"tools,omitempty"`
	GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
}

// GeminiTool groups the function declarations offered to the model.
type GeminiTool struct {
	FunctionDeclarations []models.FunctionDef `json:"functionDeclarations"`
}

// GeminiGenerationConfig carries the optional sampling parameters. Pointer
// fields are omitted from the JSON when unset.
type GeminiGenerationConfig struct {
	Temperature     *float32 `json:"temperature,omitempty"`
	TopP            *float32 `json:"topP,omitempty"`
	TopK            *int     `json:"topK,omitempty"`
	MaxOutputTokens *int     `json:"maxOutputTokens,omitempty"`
	StopSequences   []string `json:"stopSequences,omitempty"`
}

// GeminiContent is one conversation turn: a role plus its parts.
type GeminiContent struct {
	Role  string       `json:"role,omitempty"`
	Parts []GeminiPart `json:"parts"`
}

// GeminiPart is one piece of a turn. This file populates exactly one of the
// fields per part.
type GeminiPart struct {
	Text             string                  `json:"text,omitempty"`
	InlineData       *GeminiInlineData       `json:"inlineData,omitempty"`
	FunctionCall     *GeminiFunctionCall     `json:"functionCall,omitempty"`
	FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
}

// GeminiInlineData is base64-encoded binary content with its MIME type.
type GeminiInlineData struct {
	MimeType string `json:"mimeType"`
	Data     string `json:"data"`
}

// GeminiFunctionCall is a function invocation requested by the model.
type GeminiFunctionCall struct {
	Name string          `json:"name"`
	Args json.RawMessage `json:"args"`
}

// GeminiFunctionResponse is the result of a function call, sent back to the
// model.
type GeminiFunctionResponse struct {
	Name     string          `json:"name"`
	Response json.RawMessage `json:"response"`
}

// geminiRequestError is a transport error whose message has had the API key
// removed. The original error stays reachable through Unwrap so errors.Is and
// errors.As keep working.
type geminiRequestError struct {
	msg string
	err error
}

func (e *geminiRequestError) Error() string { return e.msg }

func (e *geminiRequestError) Unwrap() error { return e.err }

// requestError wraps a transport error as "request failed: ..." with the API
// key redacted. The key travels as a query parameter, and the HTTP client
// embeds the full URL in its error text, so the raw message would leak it.
func (p *GeminiProvider) requestError(err error) error {
	msg := "request failed: " + err.Error()
	if p.apiKey != "" {
		msg = strings.ReplaceAll(msg, p.apiKey, "REDACTED")
	}
	return &geminiRequestError{msg: msg, err: err}
}

// geminiCallArgs converts an OpenAI-style arguments string into the JSON
// object used for a functionCall's args. A JSON object passes through
// unchanged. Empty or "null" becomes {}. Anything else (invalid JSON or a
// non-object value) is wrapped as {"raw": "<original>"} so the request body
// still marshals instead of failing on an invalid RawMessage.
func geminiCallArgs(arguments string) json.RawMessage {
	trimmed := strings.TrimSpace(arguments)
	if trimmed == "" || trimmed == "null" {
		return json.RawMessage(`{}`)
	}
	if strings.HasPrefix(trimmed, "{") && json.Valid([]byte(trimmed)) {
		return json.RawMessage(trimmed)
	}
	wrapped, err := json.Marshal(map[string]string{"raw": arguments})
	if err != nil {
		// Marshalling a string map is not expected to fail; keep the request valid.
		return json.RawMessage(`{}`)
	}
	return wrapped
}

// geminiFunctionResult converts a tool message's text into the JSON object
// used for a functionResponse. A JSON object passes through unchanged. Any
// other valid JSON value is wrapped as {"result": <value>}, and non-JSON text
// as {"result": "<text>"}.
func geminiFunctionResult(text string) json.RawMessage {
	trimmed := strings.TrimSpace(text)
	valid := trimmed != "" && json.Valid([]byte(trimmed))
	if valid && strings.HasPrefix(trimmed, "{") {
		return json.RawMessage(trimmed)
	}

	var value interface{} = text
	if valid {
		value = json.RawMessage(trimmed)
	}
	wrapped, err := json.Marshal(map[string]interface{}{"result": value})
	if err != nil {
		// Not expected for a string or pre-validated JSON; keep the request valid.
		return json.RawMessage(`{"result":""}`)
	}
	return wrapped
}

// geminiEndpoint builds "<base>/models/<model>:<action>" without the API key.
// Models whose name contains "preview", "3.1", "2.0" or "thinking" are routed
// to v1beta by rewriting the first "/v1" in baseURL, unless baseURL already
// mentions v1beta.
func geminiEndpoint(baseURL, model, action string) string {
	lowerModel := strings.ToLower(model)
	for _, marker := range []string{"preview", "3.1", "2.0", "thinking"} {
		if strings.Contains(lowerModel, marker) {
			if !strings.Contains(baseURL, "v1beta") {
				baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
			}
			break
		}
	}
	return fmt.Sprintf("%s/models/%s:%s", baseURL, model, action)
}

// geminiFinishReason maps a Gemini finishReason to the OpenAI vocabulary.
// Tool calls take precedence, because the non-streaming parser can see
// function calls alongside a plain STOP finish reason. Reasons with no
// mapping are returned lowercased.
func geminiFinishReason(reason string, hasToolCalls bool) string {
	if hasToolCalls {
		return "tool_calls"
	}
	switch strings.ToUpper(reason) {
	case "STOP":
		return "stop"
	case "MAX_TOKENS":
		return "length"
	case "SAFETY", "RECITATION", "BLOCKLIST", "PROHIBITED_CONTENT", "SPII":
		return "content_filter"
	default:
		return strings.ToLower(reason)
	}
}

// buildGeminiRequest converts a unified request into a Gemini request body.
//
// Mapping rules:
//   - An assistant message with tool calls becomes a "model" turn. The run of
//     "tool" messages immediately after it is folded into a single "function"
//     turn.
//   - System and user messages become "user" turns, assistant messages
//     "model" turns.
//   - A tool message that does not follow an assistant tool call is dropped.
//
// It returns an error when an image cannot be encoded or when no message
// yields any content.
func buildGeminiRequest(req *models.UnifiedRequest) (GeminiRequest, error) {
	var contents []GeminiContent

	for i := 0; i < len(req.Messages); i++ {
		msg := req.Messages[i]

		if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
			parts := make([]GeminiPart, 0, len(msg.Content)+len(msg.ToolCalls))
			for _, cp := range msg.Content {
				if cp.Type == "text" && cp.Text != "" {
					parts = append(parts, GeminiPart{Text: cp.Text})
				}
			}
			for _, tc := range msg.ToolCalls {
				parts = append(parts, GeminiPart{
					FunctionCall: &GeminiFunctionCall{
						Name: tc.Function.Name,
						Args: geminiCallArgs(tc.Function.Arguments),
					},
				})
			}
			contents = append(contents, GeminiContent{Role: "model", Parts: parts})

			// Collect the tool results that directly follow this turn.
			var results []GeminiPart
			j := i + 1
			for ; j < len(req.Messages) && req.Messages[j].Role == "tool"; j++ {
				toolMsg := req.Messages[j]

				text := ""
				if len(toolMsg.Content) > 0 {
					text = toolMsg.Content[0].Text
				}

				// Prefer the explicit name. Otherwise assume results arrive in
				// the same order as the calls (heuristic: the k-th result
				// answers the k-th call).
				name := geminiUnknownFunction
				k := j - i - 1
				switch {
				case toolMsg.Name != nil && *toolMsg.Name != "":
					name = *toolMsg.Name
				case k < len(msg.ToolCalls) && msg.ToolCalls[k].Function.Name != "":
					name = msg.ToolCalls[k].Function.Name
				}

				results = append(results, GeminiPart{
					FunctionResponse: &GeminiFunctionResponse{
						Name:     name,
						Response: geminiFunctionResult(text),
					},
				})
			}
			if len(results) > 0 {
				contents = append(contents, GeminiContent{Role: "function", Parts: results})
			}
			// When no results follow, nothing is synthesised. The conversation
			// is incomplete and the upstream API is expected to reject it.
			i = j - 1 // skip the tool messages consumed above (no-op if none)
			continue
		}

		role := "user" // also used for system prompts; no systemInstruction is sent
		switch msg.Role {
		case "assistant":
			role = "model"
		case "tool":
			// Orphaned tool result with no preceding assistant call: drop it.
			continue
		}

		var parts []GeminiPart
		for _, cp := range msg.Content {
			if cp.Type == "text" && cp.Text != "" {
				parts = append(parts, GeminiPart{Text: cp.Text})
			} else if cp.Image != nil {
				// ToBase64's third result is treated as an error. Fail here
				// rather than sending an empty inlineData part.
				data, mimeType, err := cp.Image.ToBase64()
				if err != nil {
					return GeminiRequest{}, fmt.Errorf("encoding image in message %d: %w", i, err)
				}
				parts = append(parts, GeminiPart{
					InlineData: &GeminiInlineData{MimeType: mimeType, Data: data},
				})
			}
		}
		if len(parts) > 0 {
			contents = append(contents, GeminiContent{Role: role, Parts: parts})
		}
	}

	if len(contents) == 0 {
		return GeminiRequest{}, fmt.Errorf("gemini: request contains no convertible messages")
	}

	genConfig := &GeminiGenerationConfig{}
	if req.Temperature != nil {
		t := float32(*req.Temperature)
		genConfig.Temperature = &t
	}
	if req.TopP != nil {
		tp := float32(*req.TopP)
		genConfig.TopP = &tp
	}
	if req.TopK != nil {
		tk := int(*req.TopK)
		genConfig.TopK = &tk
	}
	if req.MaxTokens != nil {
		mt := int(*req.MaxTokens)
		genConfig.MaxOutputTokens = &mt
	}
	if len(req.Stop) > 0 {
		genConfig.StopSequences = req.Stop
	}

	body := GeminiRequest{
		Contents:         contents,
		GenerationConfig: genConfig,
	}

	// Only function-type tools are forwarded. Send no tools entry at all when
	// none qualify, rather than an empty declaration list.
	decls := make([]models.FunctionDef, 0, len(req.Tools))
	for _, t := range req.Tools {
		if t.Type == "function" {
			decls = append(decls, t.Function)
		}
	}
	if len(decls) > 0 {
		body.Tools = []GeminiTool{{FunctionDeclarations: decls}}
	}

	return body, nil
}

// ChatCompletion performs a non-streaming generateContent call and converts
// the first candidate into an OpenAI-style chat completion response.
//
// It returns an error when the request cannot be built, the transport fails,
// the API answers with a non-success status, the body cannot be parsed, or
// the response has no candidates. The API key is sent as the "key" query
// parameter and is kept out of log lines and transport error messages.
func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
	body, err := buildGeminiRequest(req)
	if err != nil {
		return nil, err
	}

	endpoint := geminiEndpoint(p.config.BaseURL, req.Model, "generateContent")
	fmt.Printf("[Gemini] POST %s\n", endpoint)

	resp, err := p.client.R().
		SetContext(ctx).
		SetQueryParam("key", p.apiKey).
		SetBody(body).
		Post(endpoint)
	if err != nil {
		return nil, p.requestError(err)
	}

	if !resp.IsSuccess() {
		fmt.Printf("[Gemini] API Error %d: %s\n", resp.StatusCode(), resp.String())
		// Log the request body as a debugging aid. It never contains the API key.
		if reqJSON, mErr := json.Marshal(body); mErr == nil {
			fmt.Printf("[Gemini] Request Body: %s\n", string(reqJSON))
		}
		return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
	}

	var geminiResp struct {
		Candidates []struct {
			Content struct {
				Role  string `json:"role"`
				Parts []struct {
					Text         string              `json:"text"`
					FunctionCall *GeminiFunctionCall `json:"functionCall"`
				} `json:"parts"`
			} `json:"content"`
			FinishReason string `json:"finishReason"`
		} `json:"candidates"`
		UsageMetadata struct {
			PromptTokenCount        uint32 `json:"promptTokenCount"`
			CandidatesTokenCount    uint32 `json:"candidatesTokenCount"`
			TotalTokenCount         uint32 `json:"totalTokenCount"`
			CachedContentTokenCount uint32 `json:"cachedContentTokenCount"`
		} `json:"usageMetadata"`
	}

	if err := json.Unmarshal(resp.Body(), &geminiResp); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}
	if len(geminiResp.Candidates) == 0 {
		return nil, fmt.Errorf("no candidates in Gemini response")
	}
	candidate := geminiResp.Candidates[0]

	var text strings.Builder
	var toolCalls []models.ToolCall
	seen := make(map[string]int) // function name -> number of calls so far
	for _, part := range candidate.Content.Parts {
		text.WriteString(part.Text)

		if part.FunctionCall == nil {
			continue
		}
		name := part.FunctionCall.Name

		// These responses are parsed without call IDs, so one is synthesised.
		// The first call keeps the historical "call_<name>" form; repeated
		// calls get a numeric suffix so IDs stay unique within the turn.
		id := "call_" + name
		if n := seen[name]; n > 0 {
			id = fmt.Sprintf("call_%s_%d", name, n)
		}
		seen[name]++

		// OpenAI clients expect arguments to be a JSON document, never empty.
		args := strings.TrimSpace(string(part.FunctionCall.Args))
		if args == "" || args == "null" {
			args = "{}"
		}

		toolCalls = append(toolCalls, models.ToolCall{
			ID:   id,
			Type: "function",
			Function: models.FunctionCall{
				Name:      name,
				Arguments: args,
			},
		})
	}

	finishReason := geminiFinishReason(candidate.FinishReason, len(toolCalls) > 0)

	return &models.ChatCompletionResponse{
		ID:      "gemini-" + req.Model,
		Object:  "chat.completion",
		Created: 0,
		Model:   req.Model,
		Choices: []models.ChatChoice{
			{
				Index: 0,
				Message: models.ChatMessage{
					Role:      "assistant",
					Content:   text.String(),
					ToolCalls: toolCalls,
				},
				FinishReason: &finishReason,
			},
		},
		Usage: &models.Usage{
			PromptTokens:     geminiResp.UsageMetadata.PromptTokenCount,
			CompletionTokens: geminiResp.UsageMetadata.CandidatesTokenCount,
			TotalTokens:      geminiResp.UsageMetadata.TotalTokenCount,
			CacheReadTokens:  uint32Ptr(geminiResp.UsageMetadata.CachedContentTokenCount),
		},
	}, nil
}

// ChatCompletionStream performs a streamGenerateContent call and returns a
// channel of stream chunks produced by StreamGemini. The channel is closed
// when the stream ends or fails; stream errors are logged, not returned.
//
// The request body is built by the same mapping as ChatCompletion. The
// response body is always closed: after reading the error detail on a
// non-success status, or when the streaming goroutine finishes.
//
// NOTE(review): the goroutine sends on an unbuffered channel, so a consumer
// that stops reading may leave it blocked inside StreamGemini. Confirm that
// callers drain the channel.
func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
	body, err := buildGeminiRequest(req)
	if err != nil {
		return nil, err
	}

	endpoint := geminiEndpoint(p.config.BaseURL, req.Model, "streamGenerateContent")
	fmt.Printf("[Gemini-Stream] POST %s\n", endpoint)

	resp, err := p.client.R().
		SetContext(ctx).
		SetQueryParam("key", p.apiKey).
		SetBody(body).
		SetDoNotParseResponse(true).
		Post(endpoint)
	if err != nil {
		return nil, p.requestError(err)
	}

	stream := resp.RawBody()

	if !resp.IsSuccess() {
		// With response parsing disabled, resp.String() is empty, so the
		// error detail has to be read from the raw body.
		detail := ""
		if stream != nil {
			// Best effort: a partial or failed read still yields a usable error.
			raw, _ := io.ReadAll(io.LimitReader(stream, geminiMaxErrorBody))
			stream.Close()
			detail = string(raw)
		}
		return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), detail)
	}
	if stream == nil {
		return nil, fmt.Errorf("gemini: streaming response has no body")
	}

	ch := make(chan *models.ChatCompletionStreamResponse)
	go func() {
		defer close(ch)
		defer stream.Close()
		if err := StreamGemini(stream, ch, req.Model); err != nil {
			fmt.Printf("Gemini Stream error: %v\n", err)
		}
	}()

	return ch, nil
}

// uint32Ptr returns a pointer to v, or nil when v is zero, so that zero
// counts are left unset in optional response fields.
func uint32Ptr(v uint32) *uint32 {
	if v > 0 {
		return &v
	}
	return nil
}