feat(ollama): improve configuration and dashboard integration
This commit is contained in:
@@ -128,6 +128,9 @@ func Load() (*Config, error) {
|
|||||||
v.BindEnv("encryption_key", "LLM_PROXY__ENCRYPTION_KEY")
|
v.BindEnv("encryption_key", "LLM_PROXY__ENCRYPTION_KEY")
|
||||||
v.BindEnv("server.port", "LLM_PROXY__SERVER__PORT")
|
v.BindEnv("server.port", "LLM_PROXY__SERVER__PORT")
|
||||||
v.BindEnv("server.host", "LLM_PROXY__SERVER__HOST")
|
v.BindEnv("server.host", "LLM_PROXY__SERVER__HOST")
|
||||||
|
v.BindEnv("providers.ollama.enabled", "LLM_PROXY__PROVIDERS__OLLAMA__ENABLED")
|
||||||
|
v.BindEnv("providers.ollama.base_url", "LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL")
|
||||||
|
v.BindEnv("providers.ollama.models", "LLM_PROXY__PROVIDERS__OLLAMA__MODELS")
|
||||||
|
|
||||||
// Config file
|
// Config file
|
||||||
v.SetConfigName("config")
|
v.SetConfigName("config")
|
||||||
@@ -161,6 +164,19 @@ func Load() (*Config, error) {
|
|||||||
fmt.Printf("Overriding host to %s from env\n", cfg.Server.Host)
|
fmt.Printf("Overriding host to %s from env\n", cfg.Server.Host)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ollama overrides
|
||||||
|
if enabled := os.Getenv("LLM_PROXY__PROVIDERS__OLLAMA__ENABLED"); enabled != "" {
|
||||||
|
cfg.Providers.Ollama.Enabled = enabled == "true"
|
||||||
|
}
|
||||||
|
if baseURL := os.Getenv("LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL"); baseURL != "" {
|
||||||
|
cfg.Providers.Ollama.BaseURL = baseURL
|
||||||
|
}
|
||||||
|
if models := os.Getenv("LLM_PROXY__PROVIDERS__OLLAMA__MODELS"); models != "" {
|
||||||
|
cfg.Providers.Ollama.Models = strings.Split(models, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("[DEBUG] Final Ollama Config: Enabled=%v, BaseURL=%s, Models=%v\n", cfg.Providers.Ollama.Enabled, cfg.Providers.Ollama.BaseURL, cfg.Providers.Ollama.Models)
|
||||||
|
|
||||||
// Validate encryption key
|
// Validate encryption key
|
||||||
if cfg.EncryptionKey == "" {
|
if cfg.EncryptionKey == "" {
|
||||||
return nil, fmt.Errorf("encryption key is required (LLM_PROXY__ENCRYPTION_KEY)")
|
return nil, fmt.Errorf("encryption key is required (LLM_PROXY__ENCRYPTION_KEY)")
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"gophergate/internal/config"
|
"gophergate/internal/config"
|
||||||
"gophergate/internal/models"
|
"gophergate/internal/models"
|
||||||
@@ -30,6 +31,15 @@ func (p *GeminiProvider) Name() string {
|
|||||||
|
|
||||||
type GeminiRequest struct {
|
type GeminiRequest struct {
|
||||||
Contents []GeminiContent `json:"contents"`
|
Contents []GeminiContent `json:"contents"`
|
||||||
|
GenerationConfig *GeminiGenerationConfig `json:"generationConfig,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiGenerationConfig struct {
|
||||||
|
Temperature *float32 `json:"temperature,omitempty"`
|
||||||
|
TopP *float32 `json:"topP,omitempty"`
|
||||||
|
TopK *int `json:"topK,omitempty"`
|
||||||
|
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||||
|
StopSequences []string `json:"stopSequences,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GeminiContent struct {
|
type GeminiContent struct {
|
||||||
@@ -125,11 +135,43 @@ func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.Unified
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
body := GeminiRequest{
|
genConfig := &GeminiGenerationConfig{}
|
||||||
Contents: contents,
|
if req.Temperature != nil {
|
||||||
|
t := float32(*req.Temperature)
|
||||||
|
genConfig.Temperature = &t
|
||||||
|
}
|
||||||
|
if req.TopP != nil {
|
||||||
|
tp := float32(*req.TopP)
|
||||||
|
genConfig.TopP = &tp
|
||||||
|
}
|
||||||
|
if req.TopK != nil {
|
||||||
|
tk := int(*req.TopK)
|
||||||
|
genConfig.TopK = &tk
|
||||||
|
}
|
||||||
|
if req.MaxTokens != nil {
|
||||||
|
mt := int(*req.MaxTokens)
|
||||||
|
genConfig.MaxOutputTokens = &mt
|
||||||
|
}
|
||||||
|
if len(req.Stop) > 0 {
|
||||||
|
genConfig.StopSequences = req.Stop
|
||||||
}
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)
|
body := GeminiRequest{
|
||||||
|
Contents: contents,
|
||||||
|
GenerationConfig: genConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := p.config.BaseURL
|
||||||
|
lowerModel := strings.ToLower(req.Model)
|
||||||
|
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
|
||||||
|
// Use v1beta for preview and newer models
|
||||||
|
if !strings.Contains(baseURL, "v1beta") {
|
||||||
|
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", baseURL, req.Model, p.apiKey)
|
||||||
|
fmt.Printf("[Gemini] POST %s\n", url)
|
||||||
|
|
||||||
resp, err := p.client.R().
|
resp, err := p.client.R().
|
||||||
SetContext(ctx).
|
SetContext(ctx).
|
||||||
@@ -219,12 +261,44 @@ func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.U
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
genConfig := &GeminiGenerationConfig{}
|
||||||
|
if req.Temperature != nil {
|
||||||
|
t := float32(*req.Temperature)
|
||||||
|
genConfig.Temperature = &t
|
||||||
|
}
|
||||||
|
if req.TopP != nil {
|
||||||
|
tp := float32(*req.TopP)
|
||||||
|
genConfig.TopP = &tp
|
||||||
|
}
|
||||||
|
if req.TopK != nil {
|
||||||
|
tk := int(*req.TopK)
|
||||||
|
genConfig.TopK = &tk
|
||||||
|
}
|
||||||
|
if req.MaxTokens != nil {
|
||||||
|
mt := int(*req.MaxTokens)
|
||||||
|
genConfig.MaxOutputTokens = &mt
|
||||||
|
}
|
||||||
|
if len(req.Stop) > 0 {
|
||||||
|
genConfig.StopSequences = req.Stop
|
||||||
|
}
|
||||||
|
|
||||||
body := GeminiRequest{
|
body := GeminiRequest{
|
||||||
Contents: contents,
|
Contents: contents,
|
||||||
|
GenerationConfig: genConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := p.config.BaseURL
|
||||||
|
lowerModel := strings.ToLower(req.Model)
|
||||||
|
if strings.Contains(lowerModel, "preview") || strings.Contains(lowerModel, "3.1") || strings.Contains(lowerModel, "2.0") || strings.Contains(lowerModel, "thinking") {
|
||||||
|
// Use v1beta for preview and newer models
|
||||||
|
if !strings.Contains(baseURL, "v1beta") {
|
||||||
|
baseURL = strings.Replace(baseURL, "/v1", "/v1beta", 1)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use streamGenerateContent for streaming
|
// Use streamGenerateContent for streaming
|
||||||
url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)
|
url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s", baseURL, req.Model, p.apiKey)
|
||||||
|
fmt.Printf("[Gemini-Stream] POST %s\n", url)
|
||||||
|
|
||||||
resp, err := p.client.R().
|
resp, err := p.client.R().
|
||||||
SetContext(ctx).
|
SetContext(ctx).
|
||||||
|
|||||||
@@ -884,6 +884,11 @@ func (s *Server) handleGetProviders(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If it's ollama, also include models from config
|
||||||
|
if id == "ollama" {
|
||||||
|
models = append(models, s.cfg.Providers.Ollama.Models...)
|
||||||
|
}
|
||||||
|
|
||||||
result = append(result, gin.H{
|
result = append(result, gin.H{
|
||||||
"id": id,
|
"id": id,
|
||||||
"name": name,
|
"name": name,
|
||||||
@@ -1012,6 +1017,7 @@ func (s *Server) handleGetModels(c *gin.Context) {
|
|||||||
"google": "gemini",
|
"google": "gemini",
|
||||||
"deepseek": "deepseek",
|
"deepseek": "deepseek",
|
||||||
"xai": "grok",
|
"xai": "grok",
|
||||||
|
"ollama": "ollama",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Merge registry models with DB overrides
|
// Merge registry models with DB overrides
|
||||||
@@ -1107,6 +1113,69 @@ func (s *Server) handleGetModels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add configured Ollama models if they aren't in registry
|
||||||
|
if s.cfg.Providers.Ollama.Enabled {
|
||||||
|
for _, mID := range s.cfg.Providers.Ollama.Models {
|
||||||
|
// Check if already added from registry
|
||||||
|
exists := false
|
||||||
|
for _, r := range result {
|
||||||
|
if r["id"] == mID {
|
||||||
|
exists = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if usedOnly && !usedPairs[fmt.Sprintf("%s:ollama", mID)] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
enabled := true
|
||||||
|
promptCost := 0.0
|
||||||
|
completionCost := 0.0
|
||||||
|
var cacheReadCost *float64
|
||||||
|
var cacheWriteCost *float64
|
||||||
|
var mapping *string
|
||||||
|
contextLimit := uint32(0)
|
||||||
|
|
||||||
|
// Override from DB
|
||||||
|
if dbCfg, ok := dbMap[mID]; ok {
|
||||||
|
enabled = dbCfg.Enabled
|
||||||
|
if dbCfg.PromptCostPerM != nil {
|
||||||
|
promptCost = *dbCfg.PromptCostPerM
|
||||||
|
}
|
||||||
|
if dbCfg.CompletionCostPerM != nil {
|
||||||
|
completionCost = *dbCfg.CompletionCostPerM
|
||||||
|
}
|
||||||
|
if dbCfg.CacheReadCostPerM != nil {
|
||||||
|
cacheReadCost = dbCfg.CacheReadCostPerM
|
||||||
|
}
|
||||||
|
if dbCfg.CacheWriteCostPerM != nil {
|
||||||
|
cacheWriteCost = dbCfg.CacheWriteCostPerM
|
||||||
|
}
|
||||||
|
mapping = dbCfg.Mapping
|
||||||
|
}
|
||||||
|
|
||||||
|
result = append(result, gin.H{
|
||||||
|
"id": mID,
|
||||||
|
"name": mID,
|
||||||
|
"provider": "ollama",
|
||||||
|
"enabled": enabled,
|
||||||
|
"prompt_cost": promptCost,
|
||||||
|
"completion_cost": completionCost,
|
||||||
|
"cache_read_cost": cacheReadCost,
|
||||||
|
"cache_write_cost": cacheWriteCost,
|
||||||
|
"context_limit": contextLimit,
|
||||||
|
"modalities": gin.H{"input": []string{"text"}, "output": []string{"text"}},
|
||||||
|
"tool_call": false,
|
||||||
|
"reasoning": false,
|
||||||
|
"mapping": mapping,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, SuccessResponse(result))
|
c.JSON(http.StatusOK, SuccessResponse(result))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -154,6 +154,7 @@ func (s *Server) RefreshProviders() error {
|
|||||||
s.providers["grok"] = providers.NewGrokProvider(cfg, apiKey)
|
s.providers["grok"] = providers.NewGrokProvider(cfg, apiKey)
|
||||||
case "ollama":
|
case "ollama":
|
||||||
cfg := s.cfg.Providers.Ollama
|
cfg := s.cfg.Providers.Ollama
|
||||||
|
fmt.Printf("[DEBUG] Ollama config: Enabled=%v, BaseURL=%s, Models=%v\n", cfg.Enabled, baseURL, cfg.Models)
|
||||||
cfg.BaseURL = baseURL
|
cfg.BaseURL = baseURL
|
||||||
s.providers["ollama"] = providers.NewOllamaProvider(cfg)
|
s.providers["ollama"] = providers.NewOllamaProvider(cfg)
|
||||||
}
|
}
|
||||||
@@ -271,6 +272,28 @@ func (s *Server) handleListModels(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add configured Ollama models
|
||||||
|
if s.cfg.Providers.Ollama.Enabled {
|
||||||
|
for _, mID := range s.cfg.Providers.Ollama.Models {
|
||||||
|
// Check if already added
|
||||||
|
exists := false
|
||||||
|
for _, d := range data {
|
||||||
|
if d.ID == mID {
|
||||||
|
exists = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
data = append(data, OpenAIModel{
|
||||||
|
ID: mID,
|
||||||
|
Object: "model",
|
||||||
|
Created: 1700000000,
|
||||||
|
OwnedBy: "ollama",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"object": "list",
|
"object": "list",
|
||||||
"data": data,
|
"data": data,
|
||||||
@@ -305,9 +328,17 @@ func (s *Server) handleChatCompletions(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Strip common prefixes
|
||||||
|
modelID := req.Model
|
||||||
|
if strings.HasPrefix(modelID, "gemini/") {
|
||||||
|
modelID = strings.TrimPrefix(modelID, "gemini/")
|
||||||
|
} else if strings.HasPrefix(modelID, "google/") {
|
||||||
|
modelID = strings.TrimPrefix(modelID, "google/")
|
||||||
|
}
|
||||||
|
|
||||||
// Convert ChatCompletionRequest to UnifiedRequest
|
// Convert ChatCompletionRequest to UnifiedRequest
|
||||||
unifiedReq := &models.UnifiedRequest{
|
unifiedReq := &models.UnifiedRequest{
|
||||||
Model: req.Model,
|
Model: modelID,
|
||||||
Messages: []models.UnifiedMessage{},
|
Messages: []models.UnifiedMessage{},
|
||||||
Temperature: req.Temperature,
|
Temperature: req.Temperature,
|
||||||
TopP: req.TopP,
|
TopP: req.TopP,
|
||||||
|
|||||||
Reference in New Issue
Block a user