diff --git a/internal/providers/gemini.go b/internal/providers/gemini.go
index 90d265d7..d63b7ade 100644
--- a/internal/providers/gemini.go
+++ b/internal/providers/gemini.go
@@ -31,9 +31,15 @@ func (p *GeminiProvider) Name() string {
 
 type GeminiRequest struct {
 	Contents         []GeminiContent         `json:"contents"`
+	Tools            []GeminiTool            `json:"tools,omitempty"`
 	GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
 }
 
+// GeminiTool is one entry of the request's "tools" array; it carries the function declarations.
+type GeminiTool struct {
+	FunctionDeclarations []models.FunctionDef `json:"functionDeclarations"`
+}
+
 type GeminiGenerationConfig struct {
 	Temperature *float32 `json:"temperature,omitempty"`
 	TopP        *float32 `json:"topP,omitempty"`
@@ -77,7 +83,7 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
 		if msg.Role == "assistant" {
 			role = "model"
 		} else if msg.Role == "tool" {
-			role = "user" // Tool results are user-side in Gemini
+			role = "function" // Function results use 'function' role in Gemini contents
 		}
 
 		var parts []GeminiPart
@@ -95,10 +101,19 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
 				name = *msg.Name
 			}
 
+			// Gemini expects a JSON object here; anything else gets wrapped below
+			var responseObj map[string]interface{}
+			if err := json.Unmarshal([]byte(text), &responseObj); err != nil || responseObj == nil {
+				// Plain text, arrays, scalars and null are not objects: wrap them
+				responseObj = map[string]interface{}{"result": text}
+			}
+
+			respBytes, _ := json.Marshal(responseObj) // cannot fail: map decoded from JSON or a string wrapper
+
 			parts = append(parts, GeminiPart{
 				FunctionResponse: &GeminiFunctionResponse{
 					Name:     name,
-					Response: json.RawMessage(text),
+					Response: json.RawMessage(respBytes),
 				},
 			})
 		} else {
@@ -161,6 +176,17 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
 		GenerationConfig: genConfig,
 	}
 
+	// Map function tools; leave "tools" out entirely when none are functions
+	var decls []models.FunctionDef
+	for _, t := range req.Tools {
+		if t.Type == "function" {
+			decls = append(decls, t.Function)
+		}
+	}
+	if len(decls) > 0 {
+		body.Tools = []GeminiTool{{FunctionDeclarations: decls}}
+	}
+
 	baseURL := p.config.BaseURL
 	lowerModel := strings.ToLower(req.Model)
 	if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
@@ -190,8 +216,10 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
 	var geminiResp struct {
 		Candidates []struct {
 			Content struct {
+				Role  string `json:"role"`
 				Parts []struct {
-					Text string `json:"text"`
+					Text         string              `json:"text"`
+					FunctionCall *GeminiFunctionCall `json:"functionCall"`
 				} `json:"parts"`
 			} `json:"content"`
 			FinishReason string `json:"finishReason"`
@@ -212,23 +240,42 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
 	}
 
 	content := ""
-	for _, p := range geminiResp.Candidates[0].Content.Parts {
-		content += p.Text
+	var toolCalls []models.ToolCall
+	for _, part := range geminiResp.Candidates[0].Content.Parts {
+		if part.Text != "" {
+			content += part.Text
+		}
+		if part.FunctionCall != nil {
+			toolCalls = append(toolCalls, models.ToolCall{
+				ID:   fmt.Sprintf("call_%s_%d", part.FunctionCall.Name, len(toolCalls)), // synthesized ID; the index keeps repeated calls to one function distinct
+				Type: "function",
+				Function: models.FunctionCall{
+					Name:      part.FunctionCall.Name,
+					Arguments: string(part.FunctionCall.Args),
+				},
+			})
+		}
+	}
+
+	finishReason := strings.ToLower(geminiResp.Candidates[0].FinishReason)
+	if len(toolCalls) > 0 { // tool calls win over Gemini's raw reason (e.g. "stop")
+		finishReason = "tool_calls"
 	}
 
 	openAIResp := &models.ChatCompletionResponse{
 		ID:      "gemini-" + req.Model,
 		Object:  "chat.completion",
-		Created: 0, // Should be current timestamp
+		Created: 0, // TODO: populate with the current Unix timestamp
 		Model:   req.Model,
 		Choices: []models.ChatChoice{
 			{
 				Index: 0,
 				Message: models.ChatMessage{
-					Role:    "assistant",
-					Content: content,
+					Role:      "assistant",
+					Content:   content,
+					ToolCalls: toolCalls,
 				},
-				FinishReason: &geminiResp.Candidates[0].FinishReason,
+				FinishReason: &finishReason,
 			},
 		},
 		Usage: &models.Usage{
@@ -248,11 +295,48 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
 		role := "user"
 		if msg.Role == "assistant" {
 			role = "model"
+		} else if msg.Role == "tool" {
+			role = "function"
 		}
 
 		var parts []GeminiPart
-		for _, p := range msg.Content {
-			parts = append(parts, GeminiPart{Text: p.Text})
+
+		if msg.Role == "tool" {
+			text := ""
+			if len(msg.Content) > 0 {
+				text = msg.Content[0].Text
+			}
+			name := "unknown"
+			if msg.Name != nil {
+				name = *msg.Name
+			}
+
+			var responseObj map[string]interface{} // Gemini expects a JSON object; anything else is wrapped
+			if err := json.Unmarshal([]byte(text), &responseObj); err != nil || responseObj == nil {
+				responseObj = map[string]interface{}{"result": text}
+			}
+			respBytes, _ := json.Marshal(responseObj) // cannot fail: map decoded from JSON or a string wrapper
+
+			parts = append(parts, GeminiPart{
+				FunctionResponse: &GeminiFunctionResponse{
+					Name:     name,
+					Response: json.RawMessage(respBytes),
+				},
+			})
+		} else {
+			for _, p := range msg.Content {
+				parts = append(parts, GeminiPart{Text: p.Text})
+			}
+			if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
+				for _, tc := range msg.ToolCalls {
+					parts = append(parts, GeminiPart{
+						FunctionCall: &GeminiFunctionCall{
+							Name: tc.Function.Name,
+							Args: json.RawMessage(tc.Function.Arguments),
+						},
+					})
+				}
+			}
 		}
 
 		contents = append(contents, GeminiContent{
@@ -287,6 +371,16 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
 		GenerationConfig: genConfig,
 	}
 
+	var decls []models.FunctionDef // only "function" tools map to Gemini declarations
+	for _, t := range req.Tools {
+		if t.Type == "function" {
+			decls = append(decls, t.Function)
+		}
+	}
+	if len(decls) > 0 {
+		body.Tools = []GeminiTool{{FunctionDeclarations: decls}}
+	}
+
 	baseURL := p.config.BaseURL
 	lowerModel := strings.ToLower(req.Model)
 	if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {