193 lines
6.1 KiB
Go
193 lines
6.1 KiB
Go
package config
|
|
|
|
import (
	"encoding/base64"
	"encoding/hex"
	"errors"
	"fmt"
	"os"
	"strconv"
	"strings"

	"github.com/spf13/viper"
)
|
|
|
|
type Config struct {
|
|
Server ServerConfig `mapstructure:"server"`
|
|
Database DatabaseConfig `mapstructure:"database"`
|
|
Providers ProviderConfig `mapstructure:"providers"`
|
|
EncryptionKey string `mapstructure:"encryption_key"`
|
|
KeyBytes []byte
|
|
}
|
|
|
|
// ServerConfig is the "server" section of the configuration.
//
// Load's defaults are Port 8080, Host "0.0.0.0" and an empty AuthTokens
// list. NOTE(review): AuthTokens presumably lists the tokens clients must
// present; its consumer is not in this file.
type ServerConfig struct {
	Port       int      `mapstructure:"port"`
	Host       string   `mapstructure:"host"`
	AuthTokens []string `mapstructure:"auth_tokens"`
}
|
|
|
|
// DatabaseConfig is the "database" section of the configuration.
//
// Load defaults Path to "./data/llm_proxy.db" and MaxConnections to 10.
type DatabaseConfig struct {
	Path           string `mapstructure:"path"`
	MaxConnections int    `mapstructure:"max_connections"`
}
|
|
|
|
type ProviderConfig struct {
|
|
OpenAI OpenAIConfig `mapstructure:"openai"`
|
|
Gemini GeminiConfig `mapstructure:"gemini"`
|
|
DeepSeek DeepSeekConfig `mapstructure:"deepseek"`
|
|
Grok GrokConfig `mapstructure:"grok"`
|
|
Ollama OllamaConfig `mapstructure:"ollama"`
|
|
}
|
|
|
|
// OpenAIConfig is the "providers.openai" section.
//
// APIKeyEnv names the environment variable that GetAPIKey reads the key
// from. Load's defaults are APIKeyEnv "OPENAI_API_KEY", BaseURL
// "https://api.openai.com/v1", DefaultModel "gpt-4o" and Enabled true.
type OpenAIConfig struct {
	APIKeyEnv    string `mapstructure:"api_key_env"`
	BaseURL      string `mapstructure:"base_url"`
	DefaultModel string `mapstructure:"default_model"`
	Enabled      bool   `mapstructure:"enabled"`
}
|
|
|
|
// GeminiConfig is the "providers.gemini" section.
//
// APIKeyEnv names the environment variable that GetAPIKey reads the key
// from. Load's defaults are APIKeyEnv "GEMINI_API_KEY", BaseURL
// "https://generativelanguage.googleapis.com/v1", DefaultModel
// "gemini-2.0-flash" and Enabled true.
type GeminiConfig struct {
	APIKeyEnv    string `mapstructure:"api_key_env"`
	BaseURL      string `mapstructure:"base_url"`
	DefaultModel string `mapstructure:"default_model"`
	Enabled      bool   `mapstructure:"enabled"`
}
|
|
|
|
// DeepSeekConfig is the "providers.deepseek" section.
//
// APIKeyEnv names the environment variable that GetAPIKey reads the key
// from. Load's defaults are APIKeyEnv "DEEPSEEK_API_KEY", BaseURL
// "https://api.deepseek.com", DefaultModel "deepseek-reasoner" and
// Enabled true.
type DeepSeekConfig struct {
	APIKeyEnv    string `mapstructure:"api_key_env"`
	BaseURL      string `mapstructure:"base_url"`
	DefaultModel string `mapstructure:"default_model"`
	Enabled      bool   `mapstructure:"enabled"`
}
|
|
|
|
// GrokConfig is the "providers.grok" section.
//
// APIKeyEnv names the environment variable that GetAPIKey reads the key
// from. Load's defaults are APIKeyEnv "GROK_API_KEY", BaseURL
// "https://api.x.ai/v1", DefaultModel "grok-4-1-fast-non-reasoning" and
// Enabled true.
type GrokConfig struct {
	APIKeyEnv    string `mapstructure:"api_key_env"`
	BaseURL      string `mapstructure:"base_url"`
	DefaultModel string `mapstructure:"default_model"`
	Enabled      bool   `mapstructure:"enabled"`
}
|
|
|
|
// OllamaConfig is the "providers.ollama" section.
//
// Unlike the other providers it has no APIKeyEnv field, and GetAPIKey never
// returns a key for it. Load defaults BaseURL to
// "http://localhost:11434/v1", Enabled to false and Models to an empty
// list; DefaultModel has no default.
type OllamaConfig struct {
	BaseURL      string   `mapstructure:"base_url"`
	Enabled      bool     `mapstructure:"enabled"`
	DefaultModel string   `mapstructure:"default_model"`
	Models       []string `mapstructure:"models"`
}
|
|
|
|
func Load() (*Config, error) {
|
|
v := viper.New()
|
|
|
|
// Defaults
|
|
v.SetDefault("server.port", 8080)
|
|
v.SetDefault("server.host", "0.0.0.0")
|
|
v.SetDefault("server.auth_tokens", []string{})
|
|
v.SetDefault("database.path", "./data/llm_proxy.db")
|
|
v.SetDefault("database.max_connections", 10)
|
|
|
|
v.SetDefault("providers.openai.api_key_env", "OPENAI_API_KEY")
|
|
v.SetDefault("providers.openai.base_url", "https://api.openai.com/v1")
|
|
v.SetDefault("providers.openai.default_model", "gpt-4o")
|
|
v.SetDefault("providers.openai.enabled", true)
|
|
|
|
v.SetDefault("providers.gemini.api_key_env", "GEMINI_API_KEY")
|
|
v.SetDefault("providers.gemini.base_url", "https://generativelanguage.googleapis.com/v1")
|
|
v.SetDefault("providers.gemini.default_model", "gemini-2.0-flash")
|
|
v.SetDefault("providers.gemini.enabled", true)
|
|
|
|
v.SetDefault("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")
|
|
v.SetDefault("providers.deepseek.base_url", "https://api.deepseek.com")
|
|
v.SetDefault("providers.deepseek.default_model", "deepseek-reasoner")
|
|
v.SetDefault("providers.deepseek.enabled", true)
|
|
|
|
v.SetDefault("providers.grok.api_key_env", "GROK_API_KEY")
|
|
v.SetDefault("providers.grok.base_url", "https://api.x.ai/v1")
|
|
v.SetDefault("providers.grok.default_model", "grok-4-1-fast-non-reasoning")
|
|
v.SetDefault("providers.grok.enabled", true)
|
|
|
|
v.SetDefault("providers.ollama.base_url", "http://localhost:11434/v1")
|
|
v.SetDefault("providers.ollama.enabled", false)
|
|
v.SetDefault("providers.ollama.models", []string{})
|
|
|
|
// Environment variables
|
|
v.SetEnvPrefix("LLM_PROXY")
|
|
v.SetEnvKeyReplacer(strings.NewReplacer(".", "__"))
|
|
v.AutomaticEnv()
|
|
|
|
// Explicitly bind keys that might use double underscores in .env
|
|
v.BindEnv("encryption_key", "LLM_PROXY__ENCRYPTION_KEY")
|
|
v.BindEnv("server.port", "LLM_PROXY__SERVER__PORT")
|
|
v.BindEnv("server.host", "LLM_PROXY__SERVER__HOST")
|
|
|
|
// Config file
|
|
v.SetConfigName("config")
|
|
v.SetConfigType("toml")
|
|
v.AddConfigPath(".")
|
|
if envPath := os.Getenv("LLM_PROXY__CONFIG_PATH"); envPath != "" {
|
|
v.SetConfigFile(envPath)
|
|
}
|
|
|
|
if err := v.ReadInConfig(); err != nil {
|
|
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
|
return nil, fmt.Errorf("failed to read config file: %w", err)
|
|
}
|
|
}
|
|
|
|
var cfg Config
|
|
if err := v.Unmarshal(&cfg); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
|
}
|
|
|
|
fmt.Printf("Debug Config: port from viper=%d, host from viper=%s\n", cfg.Server.Port, cfg.Server.Host)
|
|
fmt.Printf("Debug Env: LLM_PROXY__SERVER__PORT=%s, LLM_PROXY__SERVER__HOST=%s\n", os.Getenv("LLM_PROXY__SERVER__PORT"), os.Getenv("LLM_PROXY__SERVER__HOST"))
|
|
|
|
// Manual overrides for nested keys which Viper doesn't always bind correctly with AutomaticEnv + SetEnvPrefix
|
|
if port := os.Getenv("LLM_PROXY__SERVER__PORT"); port != "" {
|
|
fmt.Sscanf(port, "%d", &cfg.Server.Port)
|
|
fmt.Printf("Overriding port to %d from env\n", cfg.Server.Port)
|
|
}
|
|
if host := os.Getenv("LLM_PROXY__SERVER__HOST"); host != "" {
|
|
cfg.Server.Host = host
|
|
fmt.Printf("Overriding host to %s from env\n", cfg.Server.Host)
|
|
}
|
|
|
|
// Validate encryption key
|
|
if cfg.EncryptionKey == "" {
|
|
return nil, fmt.Errorf("encryption key is required (LLM_PROXY__ENCRYPTION_KEY)")
|
|
}
|
|
|
|
keyBytes, err := hex.DecodeString(cfg.EncryptionKey)
|
|
if err != nil {
|
|
keyBytes, err = base64.StdEncoding.DecodeString(cfg.EncryptionKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("encryption key must be hex or base64 encoded")
|
|
}
|
|
}
|
|
|
|
if len(keyBytes) != 32 {
|
|
return nil, fmt.Errorf("encryption key must be 32 bytes, got %d", len(keyBytes))
|
|
}
|
|
cfg.KeyBytes = keyBytes
|
|
|
|
return &cfg, nil
|
|
}
|
|
|
|
func (c *Config) GetAPIKey(provider string) (string, error) {
|
|
var envVar string
|
|
switch provider {
|
|
case "openai":
|
|
envVar = c.Providers.OpenAI.APIKeyEnv
|
|
case "gemini":
|
|
envVar = c.Providers.Gemini.APIKeyEnv
|
|
case "deepseek":
|
|
envVar = c.Providers.DeepSeek.APIKeyEnv
|
|
case "grok":
|
|
envVar = c.Providers.Grok.APIKeyEnv
|
|
default:
|
|
return "", fmt.Errorf("unknown provider: %s", provider)
|
|
}
|
|
|
|
val := os.Getenv(envVar)
|
|
if val == "" {
|
|
return "", fmt.Errorf("environment variable %s not set for %s", envVar, provider)
|
|
}
|
|
return val, nil
|
|
}
|