package providers

import (
	"bufio"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"strings"
	"time"

	"github.com/go-resty/resty/v2"

	"gophergate/internal/config"
	"gophergate/internal/models"
)

const (
	// ollamaDefaultNumCtx is the context window sent for models that do not
	// match any entry in ollamaLargeContextFamilies.
	ollamaDefaultNumCtx = 8192
	// ollamaLargeNumCtx is the context window sent for models whose name
	// matches an entry in ollamaLargeContextFamilies.
	ollamaLargeNumCtx = 32768
	// ollamaDefaultMaxTokens is the generation limit sent when the request
	// does not specify MaxTokens.
	ollamaDefaultMaxTokens = 8192
	// ollamaMaxErrorBodyBytes bounds how much of a failed streaming response
	// is read into the returned error message.
	ollamaMaxErrorBodyBytes = 64 * 1024
)

// ollamaLargeContextFamilies lists the substrings (matched against the
// lower-cased model name) that select ollamaLargeNumCtx.
var ollamaLargeContextFamilies = []string{
	"llama",
	"gemma",
	"mistral",
	"mixtral",
	"qwen",
	"deepseek",
	"command-r",
	"phi",
}

// OllamaProvider sends chat completion requests to an Ollama server through
// its "<BaseURL>/chat/completions" endpoint.
type OllamaProvider struct {
	client *resty.Client
	config config.OllamaConfig
}

// NewOllamaProvider returns a provider backed by a fresh resty client
// configured with a 15 minute timeout and up to 2 retries spaced 1 second
// apart.
func NewOllamaProvider(cfg config.OllamaConfig) *OllamaProvider {
	client := resty.New()

	// Set reasonable timeouts for local Ollama server (longer for larger models).
	// For streaming, we want a very long timeout or none at all to handle generation time.
	client.SetTimeout(15 * time.Minute)
	client.SetRetryCount(2)
	client.SetRetryWaitTime(1 * time.Second)

	return &OllamaProvider{
		client: client,
		config: cfg,
	}
}

// Name returns the provider identifier, "ollama".
func (p *OllamaProvider) Name() string {
	return "ollama"
}

// chatCompletionsURL returns the chat completions endpoint under the
// configured base URL.
func (p *OllamaProvider) chatCompletionsURL() string {
	return fmt.Sprintf("%s/chat/completions", p.config.BaseURL)
}

// ChatCompletion performs a non-streaming chat completion.
//
// It returns an error when the messages cannot be converted, the HTTP request
// fails, the server answers with a non-success status (status code and body
// are included in the error), or the response body is not valid JSON.
// Progress and failures are also printed to standard output for debugging.
func (p *OllamaProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
	messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
	if err != nil {
		return nil, fmt.Errorf("failed to convert messages: %w", err)
	}

	body := BuildOllamaBody(req, messagesJSON, false)
	url := p.chatCompletionsURL()

	// Log request for debugging.
	fmt.Printf("[Ollama] Request to %s with model %s\n", url, req.Model)

	resp, err := p.client.R().
		SetContext(ctx).
		SetBody(body).
		Post(url)
	if err != nil {
		fmt.Printf("[Ollama] Request error: %v\n", err)
		return nil, fmt.Errorf("request failed: %w", err)
	}

	if !resp.IsSuccess() {
		fmt.Printf("[Ollama] API error %d: %s\n", resp.StatusCode(), resp.String())
		return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), resp.String())
	}

	var respJSON map[string]interface{}
	if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
		fmt.Printf("[Ollama] Parse error: %v\n", err)
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	fmt.Printf("[Ollama] Success response for model %s\n", req.Model)
	return ParseOllamaResponse(respJSON, req.Model)
}

// ChatCompletionStream performs a streaming chat completion and returns a
// channel of parsed chunks. The channel is closed when the stream ends, a
// chunk fails to parse, or ctx is cancelled.
//
// Errors that occur after the channel has been returned are printed to
// standard output; they are not delivered to the caller.
func (p *OllamaProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
	messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
	if err != nil {
		return nil, fmt.Errorf("failed to convert messages: %w", err)
	}

	body := BuildOllamaBody(req, messagesJSON, true)

	resp, err := p.client.R().
		SetContext(ctx).
		SetBody(body).
		SetDoNotParseResponse(true).
		Post(p.chatCompletionsURL())
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}

	// The response is requested unparsed, so the body is only available
	// through RawBody and this code is responsible for closing it.
	rawBody := resp.RawBody()

	if !resp.IsSuccess() {
		return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), readOllamaErrorBody(rawBody))
	}
	if rawBody == nil {
		return nil, fmt.Errorf("Ollama API returned no response body")
	}

	ch := make(chan *models.ChatCompletionStreamResponse)
	go func() {
		defer close(ch)
		if err := streamOllama(ctx, rawBody, ch); err != nil {
			fmt.Printf("Stream error: %v\n", err)
		}
	}()

	return ch, nil
}

// readOllamaErrorBody reads at most ollamaMaxErrorBodyBytes from body, closes
// it, and returns the text with surrounding whitespace removed. A nil body
// yields the empty string.
func readOllamaErrorBody(body io.ReadCloser) string {
	if body == nil {
		return ""
	}
	defer body.Close()

	// A read error is deliberately ignored: whatever was read before the
	// failure is still the best available description of the API error.
	data, _ := io.ReadAll(io.LimitReader(body, ollamaMaxErrorBodyBytes))
	return strings.TrimSpace(string(data))
}

// BuildOllamaBody assembles the JSON request body for the chat completions
// endpoint.
//
// Sampling parameters present on the request are written twice: at the top
// level of the body (temperature, max_tokens, top_p, top_k, stop) and inside
// "options" (temperature, num_predict, top_p, top_k, stop). "options" always
// carries "num_ctx", chosen from the model name. When tools are supplied
// without an explicit tool choice, "tool_choice" is set to "auto".
func BuildOllamaBody(request *models.UnifiedRequest, messagesJSON []interface{}, stream bool) map[string]interface{} {
	body := map[string]interface{}{
		"model":    request.Model,
		"messages": messagesJSON,
		"stream":   stream,
	}

	// Context window size (default 8k for all, 32k+ for modern large-context models).
	ctxSize := ollamaDefaultNumCtx
	modelLower := strings.ToLower(request.Model)
	for _, family := range ollamaLargeContextFamilies {
		if strings.Contains(modelLower, family) {
			ctxSize = ollamaLargeNumCtx
			break
		}
	}
	options := map[string]interface{}{"num_ctx": ctxSize}

	if request.Temperature != nil {
		body["temperature"] = *request.Temperature
		options["temperature"] = *request.Temperature
	}

	if request.MaxTokens != nil {
		body["max_tokens"] = *request.MaxTokens
		options["num_predict"] = *request.MaxTokens
	} else {
		// Default to 8192 for all Ollama models if not specified,
		// as Ollama's compatibility layer defaults to 128 if neither
		// max_tokens nor num_predict are provided.
		body["max_tokens"] = ollamaDefaultMaxTokens
		options["num_predict"] = ollamaDefaultMaxTokens
	}

	if request.TopP != nil {
		body["top_p"] = *request.TopP
		options["top_p"] = *request.TopP
	}

	if request.TopK != nil {
		body["top_k"] = *request.TopK
		options["top_k"] = *request.TopK
	}

	if len(request.Stop) > 0 {
		body["stop"] = request.Stop
		options["stop"] = request.Stop
	}

	// options always holds at least num_ctx, so it is always sent.
	body["options"] = options

	if len(request.Tools) > 0 {
		body["tools"] = request.Tools
	}

	if request.ToolChoice != nil {
		// An explicit tool choice is forwarded only if it is valid JSON;
		// otherwise no tool_choice is sent.
		var toolChoice interface{}
		if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
			body["tool_choice"] = toolChoice
		}
	} else if len(request.Tools) > 0 {
		// Explicitly set tool_choice to auto if tools are present but choice is not specified.
		body["tool_choice"] = "auto"
	}

	return body
}

// ParseOllamaResponse converts a decoded JSON response object into a
// ChatCompletionResponse by re-encoding it and decoding into the typed struct.
//
// If the object has a "usage" key that decodes into models.Usage, that value
// replaces resp.Usage. Failures while handling "usage" are ignored so they do
// not discard an otherwise valid response. The model argument is not used.
func ParseOllamaResponse(respJSON map[string]interface{}, model string) (*models.ChatCompletionResponse, error) {
	data, err := json.Marshal(respJSON)
	if err != nil {
		return nil, err
	}

	var resp models.ChatCompletionResponse
	if err := json.Unmarshal(data, &resp); err != nil {
		return nil, err
	}

	if usageData, ok := respJSON["usage"]; ok {
		if usageBytes, err := json.Marshal(usageData); err == nil {
			var usage models.Usage
			if err := json.Unmarshal(usageBytes, &usage); err == nil {
				resp.Usage = &usage
			}
		}
	}

	return &resp, nil
}

// ParseOllamaStreamChunk parses one line of a server-sent-events stream.
//
// It returns:
//   - (nil, false, nil) for lines that do not start with "data:" and for
//     "data:" lines with an empty payload;
//   - (nil, true, nil) when the payload is "[DONE]";
//   - (chunk, false, nil) when the payload decodes as a stream chunk;
//   - a non-nil error when the payload is not valid JSON for a chunk.
//
// Whitespace around the payload is ignored, so both "data: {...}" and
// "data:{...}" are accepted.
func ParseOllamaStreamChunk(line string) (*models.ChatCompletionStreamResponse, bool, error) {
	if !strings.HasPrefix(line, "data:") {
		return nil, false, nil
	}

	data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
	switch data {
	case "":
		return nil, false, nil
	case "[DONE]":
		return nil, true, nil
	}

	var chunk models.ChatCompletionStreamResponse
	if err := json.Unmarshal([]byte(data), &chunk); err != nil {
		return nil, false, fmt.Errorf("failed to unmarshal stream chunk: %w", err)
	}

	// Decode "usage" a second time on its own and, when present, use it for
	// chunk.Usage.
	var rawChunk struct {
		Usage *models.Usage `json:"usage"`
	}
	if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
		chunk.Usage = rawChunk.Usage
	}

	return &chunk, false, nil
}

// StreamOllama reads server-sent-events lines from body, sends every parsed
// chunk to ch, and closes body before returning. It stops at the "[DONE]"
// marker, at end of input, or at the first chunk that fails to parse.
//
// Sends on ch block until a receiver is ready. The model argument is not used.
// StreamOllama does not close ch.
func StreamOllama(body io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error {
	if body == nil {
		return fmt.Errorf("stream body is nil")
	}
	return streamOllama(context.Background(), body, ch)
}

// streamOllama is the context-aware implementation behind StreamOllama. It
// returns ctx.Err() if ctx is cancelled while a chunk is waiting to be
// received, so the calling goroutine cannot block forever on an abandoned
// channel.
func streamOllama(ctx context.Context, body io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse) error {
	defer body.Close()

	// Set a larger buffer for scanning to handle large chunks if they occur.
	const (
		initialCapacity = 64 * 1024
		maxCapacity     = 10 * 1024 * 1024 // 10MB
	)
	scanner := bufio.NewScanner(body)
	scanner.Buffer(make([]byte, initialCapacity), maxCapacity)

	for scanner.Scan() {
		chunk, done, err := ParseOllamaStreamChunk(scanner.Text())
		if err != nil {
			return err
		}
		if done {
			return nil
		}
		if chunk == nil {
			continue
		}

		select {
		case ch <- chunk:
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	return scanner.Err()
}