8a8d8d1477
- AuthMiddleware now requires auth on /v1/* routes (returns 401) - WebSocket origin check configurable via WSAllowedOrigin - Removed debug fmt.Printf leaks (config, ollama, server) - Registry access protected by sync.RWMutex (race condition fix) - Session cleanup goroutine runs every 15 min - RevokeSession returns error instead of silent no-op
256 lines
6.5 KiB
Go
256 lines
6.5 KiB
Go
package providers
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-resty/resty/v2"
|
|
"gophergate/internal/config"
|
|
"gophergate/internal/models"
|
|
)
|
|
|
|
type OllamaProvider struct {
|
|
client *resty.Client
|
|
config config.OllamaConfig
|
|
}
|
|
|
|
func NewOllamaProvider(cfg config.OllamaConfig) *OllamaProvider {
|
|
client := resty.New()
|
|
// Set reasonable timeouts for local Ollama server (longer for larger models)
|
|
// For streaming, we want a very long timeout or none at all to handle generation time
|
|
client.SetTimeout(15 * time.Minute)
|
|
client.SetRetryCount(2)
|
|
client.SetRetryWaitTime(1 * time.Second)
|
|
|
|
return &OllamaProvider{
|
|
client: client,
|
|
config: cfg,
|
|
}
|
|
}
|
|
|
|
func (p *OllamaProvider) Name() string {
|
|
return "ollama"
|
|
}
|
|
|
|
func (p *OllamaProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
|
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
|
}
|
|
|
|
body := BuildOllamaBody(req, messagesJSON, false)
|
|
url := fmt.Sprintf("%s/chat/completions", p.config.BaseURL)
|
|
|
|
resp, err := p.client.R().
|
|
SetContext(ctx).
|
|
SetBody(body).
|
|
Post(url)
|
|
|
|
if err != nil {
|
|
fmt.Printf("[Ollama] Request error: %v\n", err)
|
|
return nil, fmt.Errorf("request failed: %w", err)
|
|
}
|
|
|
|
if !resp.IsSuccess() {
|
|
fmt.Printf("[Ollama] API error %d: %s\n", resp.StatusCode(), resp.String())
|
|
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), resp.String())
|
|
}
|
|
|
|
var respJSON map[string]interface{}
|
|
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
|
|
fmt.Printf("[Ollama] Parse error: %v\n", err)
|
|
return nil, fmt.Errorf("failed to parse response: %w", err)
|
|
}
|
|
|
|
return ParseOllamaResponse(respJSON, req.Model)
|
|
}
|
|
|
|
func (p *OllamaProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
|
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
|
}
|
|
|
|
body := BuildOllamaBody(req, messagesJSON, true)
|
|
|
|
resp, err := p.client.R().
|
|
SetContext(ctx).
|
|
SetBody(body).
|
|
SetDoNotParseResponse(true).
|
|
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("request failed: %w", err)
|
|
}
|
|
|
|
if !resp.IsSuccess() {
|
|
return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), resp.String())
|
|
}
|
|
|
|
ch := make(chan *models.ChatCompletionStreamResponse)
|
|
|
|
go func() {
|
|
defer close(ch)
|
|
err := StreamOllama(resp.RawBody(), ch, req.Model)
|
|
if err != nil {
|
|
fmt.Printf("Stream error: %v\n", err)
|
|
}
|
|
}()
|
|
|
|
return ch, nil
|
|
}
|
|
|
|
func BuildOllamaBody(request *models.UnifiedRequest, messagesJSON []interface{}, stream bool) map[string]interface{} {
|
|
body := map[string]interface{}{
|
|
"model": request.Model,
|
|
"messages": messagesJSON,
|
|
"stream": stream,
|
|
}
|
|
|
|
options := make(map[string]interface{})
|
|
modelLower := strings.ToLower(request.Model)
|
|
|
|
// Context window size (default 8k for all, 32k+ for modern large-context models)
|
|
ctxSize := 8192
|
|
if strings.Contains(modelLower, "llama") ||
|
|
strings.Contains(modelLower, "gemma") ||
|
|
strings.Contains(modelLower, "mistral") ||
|
|
strings.Contains(modelLower, "mixtral") ||
|
|
strings.Contains(modelLower, "qwen") ||
|
|
strings.Contains(modelLower, "deepseek") ||
|
|
strings.Contains(modelLower, "command-r") ||
|
|
strings.Contains(modelLower, "phi") {
|
|
ctxSize = 32768
|
|
}
|
|
options["num_ctx"] = ctxSize
|
|
|
|
if request.Temperature != nil {
|
|
body["temperature"] = *request.Temperature
|
|
options["temperature"] = *request.Temperature
|
|
}
|
|
|
|
if request.MaxTokens != nil {
|
|
body["max_tokens"] = *request.MaxTokens
|
|
options["num_predict"] = *request.MaxTokens
|
|
} else {
|
|
// Default to 8192 for all Ollama models if not specified,
|
|
// as Ollama's compatibility layer defaults to 128 if neither
|
|
// max_tokens nor num_predict are provided.
|
|
body["max_tokens"] = 8192
|
|
options["num_predict"] = 8192
|
|
}
|
|
|
|
if request.TopP != nil {
|
|
body["top_p"] = *request.TopP
|
|
options["top_p"] = *request.TopP
|
|
}
|
|
if request.TopK != nil {
|
|
body["top_k"] = *request.TopK
|
|
options["top_k"] = *request.TopK
|
|
}
|
|
|
|
if len(request.Stop) > 0 {
|
|
body["stop"] = request.Stop
|
|
options["stop"] = request.Stop
|
|
}
|
|
|
|
if len(options) > 0 {
|
|
body["options"] = options
|
|
}
|
|
|
|
if len(request.Tools) > 0 {
|
|
body["tools"] = request.Tools
|
|
// Explicitly set tool_choice to auto if tools are present but choice is not specified
|
|
if request.ToolChoice == nil {
|
|
body["tool_choice"] = "auto"
|
|
}
|
|
}
|
|
if request.ToolChoice != nil {
|
|
var toolChoice interface{}
|
|
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
|
|
body["tool_choice"] = toolChoice
|
|
}
|
|
}
|
|
|
|
return body
|
|
}
|
|
|
|
func ParseOllamaResponse(respJSON map[string]interface{}, model string) (*models.ChatCompletionResponse, error) {
|
|
data, err := json.Marshal(respJSON)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var resp models.ChatCompletionResponse
|
|
if err := json.Unmarshal(data, &resp); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if usageData, ok := respJSON["usage"]; ok {
|
|
var usage models.Usage
|
|
usageBytes, _ := json.Marshal(usageData)
|
|
if err := json.Unmarshal(usageBytes, &usage); err == nil {
|
|
resp.Usage = &usage
|
|
}
|
|
}
|
|
|
|
return &resp, nil
|
|
}
|
|
|
|
func ParseOllamaStreamChunk(line string) (*models.ChatCompletionStreamResponse, bool, error) {
|
|
if line == "" {
|
|
return nil, false, nil
|
|
}
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
return nil, false, nil
|
|
}
|
|
|
|
data := strings.TrimPrefix(line, "data: ")
|
|
if data == "[DONE]" {
|
|
return nil, true, nil
|
|
}
|
|
|
|
var chunk models.ChatCompletionStreamResponse
|
|
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
|
return nil, false, fmt.Errorf("failed to unmarshal stream chunk: %w", err)
|
|
}
|
|
|
|
var rawChunk struct {
|
|
Usage *models.Usage `json:"usage"`
|
|
}
|
|
if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
|
|
chunk.Usage = rawChunk.Usage
|
|
}
|
|
|
|
return &chunk, false, nil
|
|
}
|
|
|
|
func StreamOllama(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error {
|
|
defer ctx.Close()
|
|
scanner := bufio.NewScanner(ctx)
|
|
// Set a larger buffer for scanning to handle large chunks if they occur
|
|
const maxCapacity = 10 * 1024 * 1024 // 10MB
|
|
buf := make([]byte, 64*1024)
|
|
scanner.Buffer(buf, maxCapacity)
|
|
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
chunk, done, err := ParseOllamaStreamChunk(line)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if done {
|
|
break
|
|
}
|
|
if chunk != nil {
|
|
ch <- chunk
|
|
}
|
|
}
|
|
return scanner.Err()
|
|
}
|