diff --git a/internal/config/config.go b/internal/config/config.go index 42101266..48356cd2 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -195,6 +195,9 @@ func (c *Config) GetAPIKey(provider string) (string, error) { envVar = c.Providers.Moonshot.APIKeyEnv case "grok": envVar = c.Providers.Grok.APIKeyEnv + case "ollama": + // Ollama doesn't require an API key + return "", nil default: return "", fmt.Errorf("unknown provider: %s", provider) } diff --git a/internal/providers/ollama.go b/internal/providers/ollama.go new file mode 100644 index 00000000..171924ec --- /dev/null +++ b/internal/providers/ollama.go @@ -0,0 +1,198 @@ +package providers + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "strings" + + "gophergate/internal/config" + "gophergate/internal/models" + "github.com/go-resty/resty/v2" +) + +type OllamaProvider struct { + client *resty.Client + config config.OllamaConfig +} + +func NewOllamaProvider(cfg config.OllamaConfig) *OllamaProvider { + return &OllamaProvider{ + client: resty.New(), + config: cfg, + } +} + +func (p *OllamaProvider) Name() string { + return "ollama" +} + +func (p *OllamaProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) { + messagesJSON, err := MessagesToOpenAIJSON(req.Messages) + if err != nil { + return nil, fmt.Errorf("failed to convert messages: %w", err) + } + + body := BuildOllamaBody(req, messagesJSON, false) + + resp, err := p.client.R(). + SetContext(ctx). + SetBody(body). + Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL)) + + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + + if !resp.IsSuccess() { + return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), resp.String()) + } + + var respJSON map[string]interface{} + if err := json.Unmarshal(resp.Body(), &respJSON); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return ParseOllamaResponse(respJSON, req.Model) +} + +func (p *OllamaProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) { + messagesJSON, err := MessagesToOpenAIJSON(req.Messages) + if err != nil { + return nil, fmt.Errorf("failed to convert messages: %w", err) + } + + body := BuildOllamaBody(req, messagesJSON, true) + + resp, err := p.client.R(). + SetContext(ctx). + SetBody(body). + SetDoNotParseResponse(true). + Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL)) + + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + + if !resp.IsSuccess() { + return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), resp.String()) + } + + ch := make(chan *models.ChatCompletionStreamResponse) + + go func() { + defer close(ch) + err := StreamOllama(resp.RawBody(), ch, req.Model) + if err != nil { + fmt.Printf("Stream error: %v\n", err) + } + }() + + return ch, nil +} + +func BuildOllamaBody(request *models.UnifiedRequest, messagesJSON []interface{}, stream bool) map[string]interface{} { + body := map[string]interface{}{ + "model": request.Model, + "messages": messagesJSON, + "stream": stream, + } + + if request.Temperature != nil { + body["temperature"] = *request.Temperature + } + if request.MaxTokens != nil { + body["max_tokens"] = *request.MaxTokens + } + if request.TopP != nil { + body["top_p"] = *request.TopP + } + if request.TopK != nil { + body["top_k"] = *request.TopK + } + if len(request.Stop) > 0 { + body["stop"] = request.Stop + } + if len(request.Tools) > 0 { + body["tools"] = request.Tools + } + if request.ToolChoice != nil { + var toolChoice interface{} + if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil { + body["tool_choice"] = toolChoice + } + } + + return body +} + +func ParseOllamaResponse(respJSON map[string]interface{}, model string) (*models.ChatCompletionResponse, error) { + data, err := json.Marshal(respJSON) + if err != nil { + return nil, err + } + + var resp models.ChatCompletionResponse + if err := json.Unmarshal(data, &resp); err != nil { + return nil, err + } + + if usageData, ok := respJSON["usage"]; ok { + var usage models.Usage + usageBytes, _ := json.Marshal(usageData) + if err := json.Unmarshal(usageBytes, &usage); err == nil { + resp.Usage = &usage + } + } + + return &resp, nil +} + +func ParseOllamaStreamChunk(line string) (*models.ChatCompletionStreamResponse, bool, error) { + if line == "" { + return nil, false, nil + } + if !strings.HasPrefix(line, "data: ") { + return nil, false, nil + } + + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" { + return nil, true, nil + } + + var chunk models.ChatCompletionStreamResponse + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + return nil, false, fmt.Errorf("failed to unmarshal stream chunk: %w", err) + } + + var rawChunk struct { + Usage *models.Usage `json:"usage"` + } + if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil { + chunk.Usage = rawChunk.Usage + } + + return &chunk, false, nil +} + +func StreamOllama(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error { + defer ctx.Close() + scanner := bufio.NewScanner(ctx) + for scanner.Scan() { + line := scanner.Text() + chunk, done, err := ParseOllamaStreamChunk(line) + if err != nil { + return err + } + if done { + break + } + if chunk != nil { + ch <- chunk + } + } + return scanner.Err() +} \ No newline at end of file diff --git a/internal/server/server.go b/internal/server/server.go index 3534d147..41469b5d 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -75,7 +75,7 @@ func (s *Server) RefreshProviders() error { dbMap[cfg.ID] = cfg } - providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok"} + providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama"} for _, id := range providerIDs { // Default values from config enabled := false @@ -152,6 +152,10 @@ func (s *Server) RefreshProviders() error { cfg := s.cfg.Providers.Grok cfg.BaseURL = baseURL s.providers["grok"] = providers.NewGrokProvider(cfg, apiKey) + case "ollama": + cfg := s.cfg.Providers.Ollama + cfg.BaseURL = baseURL + s.providers["ollama"] = providers.NewOllamaProvider(cfg) } }