477a811999
The 40-character truncation of tool call IDs in helper.go caused collisions when models (like deepseek-v4-flash) generated longer IDs, leading to "Duplicate value for 'tool_call_id'" errors. Removed the limit to allow full unique IDs. DeepSeek: updated reasoning_content injection to use an empty string instead of a space, better matching provider expectations for history. Improved API error reporting across all providers by capturing raw body content when response parsing fails or returns empty strings.
276 lines
7.1 KiB
Go
276 lines
7.1 KiB
Go
package providers
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-resty/resty/v2"
|
|
"gophergate/internal/config"
|
|
"gophergate/internal/models"
|
|
)
|
|
|
|
type OllamaProvider struct {
|
|
client *resty.Client
|
|
config config.OllamaConfig
|
|
}
|
|
|
|
func NewOllamaProvider(cfg config.OllamaConfig) *OllamaProvider {
|
|
client := resty.New()
|
|
// Set reasonable timeouts for local Ollama server (longer for larger models)
|
|
// For streaming, we want a very long timeout or none at all to handle generation time
|
|
client.SetTimeout(15 * time.Minute)
|
|
client.SetRetryCount(2)
|
|
client.SetRetryWaitTime(1 * time.Second)
|
|
|
|
return &OllamaProvider{
|
|
client: client,
|
|
config: cfg,
|
|
}
|
|
}
|
|
|
|
func (p *OllamaProvider) Name() string {
|
|
return "ollama"
|
|
}
|
|
|
|
func (p *OllamaProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
|
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
|
}
|
|
|
|
body := BuildOllamaBody(req, messagesJSON, false)
|
|
url := fmt.Sprintf("%s/chat/completions", p.config.BaseURL)
|
|
|
|
resp, err := p.client.R().
|
|
SetContext(ctx).
|
|
SetBody(body).
|
|
Post(url)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("request failed: %w", err)
|
|
}
|
|
|
|
if !resp.IsSuccess() {
|
|
msg := resp.String()
|
|
if msg == "" {
|
|
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
|
msg = string(body)
|
|
}
|
|
}
|
|
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), msg)
|
|
}
|
|
|
|
var respJSON map[string]interface{}
|
|
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
|
|
return nil, fmt.Errorf("failed to parse response: %w", err)
|
|
}
|
|
|
|
return ParseOllamaResponse(respJSON, req.Model)
|
|
}
|
|
|
|
func (p *OllamaProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
|
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
|
}
|
|
|
|
body := BuildOllamaBody(req, messagesJSON, true)
|
|
|
|
resp, err := p.client.R().
|
|
SetContext(ctx).
|
|
SetBody(body).
|
|
SetDoNotParseResponse(true).
|
|
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("request failed: %w", err)
|
|
}
|
|
|
|
if !resp.IsSuccess() {
|
|
msg := resp.String()
|
|
if msg == "" {
|
|
if body, err := io.ReadAll(resp.RawBody()); err == nil {
|
|
msg = string(body)
|
|
}
|
|
}
|
|
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), msg)
|
|
}
|
|
|
|
ch := make(chan *models.ChatCompletionStreamResponse)
|
|
|
|
go func() {
|
|
defer close(ch)
|
|
err := StreamOllama(resp.RawBody(), ch, req.Model)
|
|
if err != nil {
|
|
}
|
|
}()
|
|
|
|
return ch, nil
|
|
}
|
|
|
|
func BuildOllamaBody(request *models.UnifiedRequest, messagesJSON []interface{}, stream bool) map[string]interface{} {
|
|
body := map[string]interface{}{
|
|
"model": request.Model,
|
|
"messages": messagesJSON,
|
|
"stream": stream,
|
|
}
|
|
|
|
options := make(map[string]interface{})
|
|
modelLower := strings.ToLower(request.Model)
|
|
|
|
// Context window size (default 8k for all, 32k+ for modern large-context models)
|
|
ctxSize := 8192
|
|
if strings.Contains(modelLower, "llama") ||
|
|
strings.Contains(modelLower, "gemma") ||
|
|
strings.Contains(modelLower, "mistral") ||
|
|
strings.Contains(modelLower, "mixtral") ||
|
|
strings.Contains(modelLower, "qwen") ||
|
|
strings.Contains(modelLower, "deepseek") ||
|
|
strings.Contains(modelLower, "command-r") ||
|
|
strings.Contains(modelLower, "phi") {
|
|
ctxSize = 32768
|
|
}
|
|
options["num_ctx"] = ctxSize
|
|
|
|
if request.Temperature != nil {
|
|
body["temperature"] = *request.Temperature
|
|
options["temperature"] = *request.Temperature
|
|
}
|
|
|
|
if request.MaxTokens != nil {
|
|
body["max_tokens"] = *request.MaxTokens
|
|
options["num_predict"] = *request.MaxTokens
|
|
} else {
|
|
// Default to 8192 for all Ollama models if not specified,
|
|
// as Ollama's compatibility layer defaults to 128 if neither
|
|
// max_tokens nor num_predict are provided.
|
|
body["max_tokens"] = 8192
|
|
options["num_predict"] = 8192
|
|
}
|
|
|
|
if request.TopP != nil {
|
|
body["top_p"] = *request.TopP
|
|
options["top_p"] = *request.TopP
|
|
}
|
|
if request.TopK != nil {
|
|
body["top_k"] = *request.TopK
|
|
options["top_k"] = *request.TopK
|
|
}
|
|
|
|
if len(request.Stop) > 0 {
|
|
body["stop"] = request.Stop
|
|
options["stop"] = request.Stop
|
|
}
|
|
|
|
if len(options) > 0 {
|
|
body["options"] = options
|
|
}
|
|
|
|
if len(request.Tools) > 0 {
|
|
body["tools"] = request.Tools
|
|
// Explicitly set tool_choice to auto if tools are present but choice is not specified
|
|
if request.ToolChoice == nil {
|
|
body["tool_choice"] = "auto"
|
|
}
|
|
}
|
|
if request.ToolChoice != nil {
|
|
var toolChoice interface{}
|
|
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
|
|
body["tool_choice"] = toolChoice
|
|
}
|
|
}
|
|
|
|
return body
|
|
}
|
|
|
|
func ParseOllamaResponse(respJSON map[string]interface{}, model string) (*models.ChatCompletionResponse, error) {
|
|
data, err := json.Marshal(respJSON)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var resp models.ChatCompletionResponse
|
|
if err := json.Unmarshal(data, &resp); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if usageData, ok := respJSON["usage"]; ok {
|
|
var usage models.Usage
|
|
usageBytes, _ := json.Marshal(usageData)
|
|
if err := json.Unmarshal(usageBytes, &usage); err == nil {
|
|
resp.Usage = &usage
|
|
}
|
|
}
|
|
|
|
return &resp, nil
|
|
}
|
|
|
|
func ParseOllamaStreamChunk(line string) (*models.ChatCompletionStreamResponse, bool, error) {
|
|
if line == "" {
|
|
return nil, false, nil
|
|
}
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
return nil, false, nil
|
|
}
|
|
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
if data == "[DONE]" {
|
|
return nil, true, nil
|
|
}
|
|
|
|
var chunk models.ChatCompletionStreamResponse
|
|
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
|
return nil, false, fmt.Errorf("failed to unmarshal stream chunk: %w", err)
|
|
}
|
|
|
|
var rawChunk struct {
|
|
Usage *models.Usage `json:"usage"`
|
|
}
|
|
if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
|
|
chunk.Usage = rawChunk.Usage
|
|
}
|
|
|
|
return &chunk, false, nil
|
|
}
|
|
|
|
func StreamOllama(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error {
|
|
defer ctx.Close()
|
|
scanner := bufio.NewScanner(ctx)
|
|
// Set a larger buffer for scanning to handle large chunks if they occur
|
|
const maxCapacity = 10 * 1024 * 1024 // 10MB
|
|
buf := make([]byte, 64*1024)
|
|
scanner.Buffer(buf, maxCapacity)
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
chunk, done, err := ParseOllamaStreamChunk(line)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if done {
|
|
break
|
|
}
|
|
if chunk != nil {
|
|
ch <- chunk
|
|
}
|
|
}
|
|
return scanner.Err()
|
|
}
|
|
|
|
func (p *OllamaProvider) ImageGeneration(ctx context.Context, req *models.ImageGenerationRequest) (*models.ImageGenerationResponse, error) {
|
|
return nil, fmt.Errorf("ollama does not support image generation")
|
|
}
|
|
|
|
func (p *OllamaProvider) Responses(ctx context.Context, req *models.ResponsesRequest) (*models.ResponsesResponse, error) {
|
|
return nil, fmt.Errorf("responses API not supported by ollama")
|
|
}
|
|
|
|
func (p *OllamaProvider) ResponsesStream(ctx context.Context, req *models.ResponsesRequest) (<-chan *models.ResponsesStreamChunk, error) {
|
|
return nil, fmt.Errorf("responses API not supported by ollama")
|
|
}
|