diff --git a/internal/config/config.go b/internal/config/config.go index 48356cd2..711da6ca 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -128,6 +128,9 @@ func Load() (*Config, error) { v.BindEnv("encryption_key", "LLM_PROXY__ENCRYPTION_KEY") v.BindEnv("server.port", "LLM_PROXY__SERVER__PORT") v.BindEnv("server.host", "LLM_PROXY__SERVER__HOST") + v.BindEnv("providers.ollama.enabled", "LLM_PROXY__PROVIDERS__OLLAMA__ENABLED") + v.BindEnv("providers.ollama.base_url", "LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL") + v.BindEnv("providers.ollama.models", "LLM_PROXY__PROVIDERS__OLLAMA__MODELS") // Config file v.SetConfigName("config") @@ -161,6 +164,19 @@ func Load() (*Config, error) { fmt.Printf("Overriding host to %s from env\n", cfg.Server.Host) } + // Ollama overrides + if enabled := os.Getenv("LLM_PROXY__PROVIDERS__OLLAMA__ENABLED"); enabled != "" { + cfg.Providers.Ollama.Enabled = enabled == "true" + } + if baseURL := os.Getenv("LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL"); baseURL != "" { + cfg.Providers.Ollama.BaseURL = baseURL + } + if models := os.Getenv("LLM_PROXY__PROVIDERS__OLLAMA__MODELS"); models != "" { + cfg.Providers.Ollama.Models = strings.Split(models, ",") + } + + fmt.Printf("[DEBUG] Final Ollama Config: Enabled=%v, BaseURL=%s, Models=%v\n", cfg.Providers.Ollama.Enabled, cfg.Providers.Ollama.BaseURL, cfg.Providers.Ollama.Models) + // Validate encryption key if cfg.EncryptionKey == "" { return nil, fmt.Errorf("encryption key is required (LLM_PROXY__ENCRYPTION_KEY)") diff --git a/internal/providers/gemini.go b/internal/providers/gemini.go index ae8cc566..90d265d7 100644 --- a/internal/providers/gemini.go +++ b/internal/providers/gemini.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "gophergate/internal/config" "gophergate/internal/models" @@ -29,7 +30,16 @@ func (p *GeminiProvider) Name() string { } type GeminiRequest struct { - Contents []GeminiContent `json:"contents"` + Contents []GeminiContent `json:"contents"` + GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"` +} + +type GeminiGenerationConfig struct { + Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"topP,omitempty"` + TopK *int `json:"topK,omitempty"` + MaxOutputTokens *int `json:"maxOutputTokens,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` } type GeminiContent struct { @@ -125,11 +135,43 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified }) } - body := GeminiRequest{ - Contents: contents, + genConfig := &GeminiGenerationConfig{} + if req.Temperature != nil { + t := float32(*req.Temperature) + genConfig.Temperature = &t + } + if req.TopP != nil { + tp := float32(*req.TopP) + genConfig.TopP = &tp + } + if req.TopK != nil { + tk := int(*req.TopK) + genConfig.TopK = &tk + } + if req.MaxTokens != nil { + mt := int(*req.MaxTokens) + genConfig.MaxOutputTokens = &mt + } + if len(req.Stop) > 0 { + genConfig.StopSequences = req.Stop } - url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey) + body := GeminiRequest{ + Contents: contents, + GenerationConfig: genConfig, + } + + baseURL := p.config.BaseURL + lowerModel := strings.ToLower(req.Model) + if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") { + // Use v1beta for preview and newer models + if !strings.Contains(baseURL, "v1beta") { + baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1) + } + } + + url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", baseURL, req.Model, p.apiKey) + fmt.Printf("[Gemini] POST %s\n", url) resp, err := p.client.R(). SetContext(ctx). @@ -219,12 +261,44 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U }) } + genConfig := &GeminiGenerationConfig{} + if req.Temperature != nil { + t := float32(*req.Temperature) + genConfig.Temperature = &t + } + if req.TopP != nil { + tp := float32(*req.TopP) + genConfig.TopP = &tp + } + if req.TopK != nil { + tk := int(*req.TopK) + genConfig.TopK = &tk + } + if req.MaxTokens != nil { + mt := int(*req.MaxTokens) + genConfig.MaxOutputTokens = &mt + } + if len(req.Stop) > 0 { + genConfig.StopSequences = req.Stop + } + body := GeminiRequest{ - Contents: contents, + Contents: contents, + GenerationConfig: genConfig, + } + + baseURL := p.config.BaseURL + lowerModel := strings.ToLower(req.Model) + if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") { + // Use v1beta for preview and newer models + if !strings.Contains(baseURL, "v1beta") { + baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1) + } } // Use streamGenerateContent for streaming - url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey) + url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s", baseURL, req.Model, p.apiKey) + fmt.Printf("[Gemini-Stream] POST %s\n", url) resp, err := p.client.R(). SetContext(ctx). diff --git a/internal/server/dashboard.go b/internal/server/dashboard.go index 0fcbc0b0..c53c3666 100644 --- a/internal/server/dashboard.go +++ b/internal/server/dashboard.go @@ -884,6 +884,11 @@ func (s *Server) handleGetProviders(c *gin.Context) { } } + // If it's ollama, also include models from config + if id == "ollama" { + models = append(models, s.cfg.Providers.Ollama.Models...) + } + result = append(result, gin.H{ "id": id, "name": name, @@ -1012,6 +1017,7 @@ func (s *Server) handleGetModels(c *gin.Context) { "google": "gemini", "deepseek": "deepseek", "xai": "grok", + "ollama": "ollama", } // Merge registry models with DB overrides @@ -1107,6 +1113,69 @@ func (s *Server) handleGetModels(c *gin.Context) { } } + // Add configured Ollama models if they aren't in registry + if s.cfg.Providers.Ollama.Enabled { + for _, mID := range s.cfg.Providers.Ollama.Models { + // Check if already added from registry + exists := false + for _, r := range result { + if r["id"] == mID { + exists = true + break + } + } + if exists { + continue + } + + if usedOnly && !usedPairs[fmt.Sprintf("%s:ollama", mID)] { + continue + } + + enabled := true + promptCost := 0.0 + completionCost := 0.0 + var cacheReadCost *float64 + var cacheWriteCost *float64 + var mapping *string + contextLimit := uint32(0) + + // Override from DB + if dbCfg, ok := dbMap[mID]; ok { + enabled = dbCfg.Enabled + if dbCfg.PromptCostPerM != nil { + promptCost = *dbCfg.PromptCostPerM + } + if dbCfg.CompletionCostPerM != nil { + completionCost = *dbCfg.CompletionCostPerM + } + if dbCfg.CacheReadCostPerM != nil { + cacheReadCost = dbCfg.CacheReadCostPerM + } + if dbCfg.CacheWriteCostPerM != nil { + cacheWriteCost = dbCfg.CacheWriteCostPerM + } + mapping = dbCfg.Mapping + } + + result = append(result, gin.H{ + "id": mID, + "name": mID, + "provider": "ollama", + "enabled": enabled, + "prompt_cost": promptCost, + "completion_cost": completionCost, + "cache_read_cost": cacheReadCost, + "cache_write_cost": cacheWriteCost, + "context_limit": contextLimit, + "modalities": gin.H{"input": []string{"text"}, "output": []string{"text"}}, + "tool_call": false, + "reasoning": false, + "mapping": mapping, + }) + } + } + c.JSON(http.StatusOK, SuccessResponse(result)) } diff --git a/internal/server/server.go b/internal/server/server.go index d814b250..dccb3af3 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -154,6 +154,7 @@ func (s *Server) RefreshProviders() error { s.providers["grok"] = providers.NewGrokProvider(cfg, apiKey) case "ollama": cfg := s.cfg.Providers.Ollama + fmt.Printf("[DEBUG] Ollama config: Enabled=%v, BaseURL=%s, Models=%v\n", cfg.Enabled, baseURL, cfg.Models) cfg.BaseURL = baseURL s.providers["ollama"] = providers.NewOllamaProvider(cfg) } @@ -271,6 +272,28 @@ func (s *Server) handleListModels(c *gin.Context) { } } + // Add configured Ollama models + if s.cfg.Providers.Ollama.Enabled { + for _, mID := range s.cfg.Providers.Ollama.Models { + // Check if already added + exists := false + for _, d := range data { + if d.ID == mID { + exists = true + break + } + } + if !exists { + data = append(data, OpenAIModel{ + ID: mID, + Object: "model", + Created: 1700000000, + OwnedBy: "ollama", + }) + } + } + } + c.JSON(http.StatusOK, gin.H{ "object": "list", "data": data, @@ -305,9 +328,17 @@ func (s *Server) handleChatCompletions(c *gin.Context) { return } + // Strip common prefixes + modelID := req.Model + if strings.HasPrefix(modelID, "gemini/") { + modelID = strings.TrimPrefix(modelID, "gemini/") + } else if strings.HasPrefix(modelID, "google/") { + modelID = strings.TrimPrefix(modelID, "google/") + } + // Convert ChatCompletionRequest to UnifiedRequest unifiedReq := &models.UnifiedRequest{ - Model: req.Model, + Model: modelID, Messages: []models.UnifiedMessage{}, Temperature: req.Temperature, TopP: req.TopP,