diff --git a/internal/server/dashboard.go b/internal/server/dashboard.go index f97bdf05..5b0074aa 100644 --- a/internal/server/dashboard.go +++ b/internal/server/dashboard.go @@ -10,6 +10,7 @@ import ( "gophergate/internal/db" "gophergate/internal/models" + "gophergate/internal/utils" "github.com/gin-gonic/gin" "github.com/google/uuid" "golang.org/x/crypto/bcrypt" @@ -912,24 +913,40 @@ func (s *Server) handleUpdateProvider(c *gin.Context) { return } + apiKeyEncrypted := false + var apiKey *string = req.APIKey + if req.APIKey != nil && *req.APIKey != "" { + encrypted, err := utils.Encrypt(*req.APIKey, s.cfg.KeyBytes) + if err == nil { + apiKey = &encrypted + apiKeyEncrypted = true + } + } + _, err := s.database.Exec(` - INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode, api_key_encrypted) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET enabled = excluded.enabled, base_url = COALESCE(excluded.base_url, provider_configs.base_url), api_key = COALESCE(excluded.api_key, provider_configs.api_key), + api_key_encrypted = excluded.api_key_encrypted, credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance), low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold), billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode), updated_at = CURRENT_TIMESTAMP - `, name, strings.ToUpper(name), req.Enabled, req.BaseURL, req.APIKey, req.CreditBalance, req.LowCreditThreshold, req.BillingMode) + `, name, strings.ToUpper(name), req.Enabled, req.BaseURL, apiKey, req.CreditBalance, req.LowCreditThreshold, req.BillingMode, apiKeyEncrypted) if err != nil { c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) return } + // Refresh in-memory providers + if err := s.RefreshProviders(); err != nil { + fmt.Printf("Error refreshing providers: %v\n", err) + } + c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Provider updated"})) } @@ -962,7 +979,7 @@ func (s *Server) handleTestProvider(c *gin.Context) { } else if name == "deepseek" { testReq.Model = "deepseek-chat" } else if name == "grok" { - testReq.Model = "grok-2" + testReq.Model = "grok-3-mini" } _, err := provider.ChatCompletion(c.Request.Context(), testReq) diff --git a/internal/server/server.go b/internal/server/server.go index 0ff0ac1b..9ab98f05 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -54,28 +54,102 @@ func NewServer(cfg *config.Config, database *db.DB) *Server { } }() - // Initialize providers - if cfg.Providers.OpenAI.Enabled { - apiKey, _ := cfg.GetAPIKey("openai") - s.providers["openai"] = providers.NewOpenAIProvider(cfg.Providers.OpenAI, apiKey) - } - if cfg.Providers.Gemini.Enabled { - apiKey, _ := cfg.GetAPIKey("gemini") - s.providers["gemini"] = providers.NewGeminiProvider(cfg.Providers.Gemini, apiKey) - } - if cfg.Providers.DeepSeek.Enabled { - apiKey, _ := cfg.GetAPIKey("deepseek") - s.providers["deepseek"] = providers.NewDeepSeekProvider(cfg.Providers.DeepSeek, apiKey) - } - if cfg.Providers.Grok.Enabled { - apiKey, _ := cfg.GetAPIKey("grok") - s.providers["grok"] = providers.NewGrokProvider(cfg.Providers.Grok, apiKey) + // Initialize providers from DB and Config + if err := s.RefreshProviders(); err != nil { + fmt.Printf("Warning: Failed to initial refresh providers: %v\n", err) } s.setupRoutes() return s } +func (s *Server) RefreshProviders() error { + var dbConfigs []db.ProviderConfig + err := s.database.Select(&dbConfigs, "SELECT * FROM provider_configs") + if err != nil { + return fmt.Errorf("failed to fetch provider configs from db: %w", err) + } + + dbMap := make(map[string]db.ProviderConfig) + for _, cfg := range dbConfigs { + dbMap[cfg.ID] = cfg + } + + providerIDs := []string{"openai", "gemini", "deepseek", "grok"} + for _, id := range providerIDs { + // Default values from config + enabled := false + baseURL := "" + apiKey := "" + + switch id { + case "openai": + enabled = s.cfg.Providers.OpenAI.Enabled + baseURL = s.cfg.Providers.OpenAI.BaseURL + apiKey, _ = s.cfg.GetAPIKey("openai") + case "gemini": + enabled = s.cfg.Providers.Gemini.Enabled + baseURL = s.cfg.Providers.Gemini.BaseURL + apiKey, _ = s.cfg.GetAPIKey("gemini") + case "deepseek": + enabled = s.cfg.Providers.DeepSeek.Enabled + baseURL = s.cfg.Providers.DeepSeek.BaseURL + apiKey, _ = s.cfg.GetAPIKey("deepseek") + case "grok": + enabled = s.cfg.Providers.Grok.Enabled + baseURL = s.cfg.Providers.Grok.BaseURL + apiKey, _ = s.cfg.GetAPIKey("grok") + } + + // Overrides from DB + if dbCfg, ok := dbMap[id]; ok { + enabled = dbCfg.Enabled + if dbCfg.BaseURL != nil && *dbCfg.BaseURL != "" { + baseURL = *dbCfg.BaseURL + } + if dbCfg.APIKey != nil && *dbCfg.APIKey != "" { + key := *dbCfg.APIKey + if dbCfg.APIKeyEncrypted { + decrypted, err := utils.Decrypt(key, s.cfg.KeyBytes) + if err == nil { + key = decrypted + } else { + fmt.Printf("Warning: Failed to decrypt API key for %s: %v\n", id, err) + } + } + apiKey = key + } + } + + if !enabled { + delete(s.providers, id) + continue + } + + // Initialize provider + switch id { + case "openai": + cfg := s.cfg.Providers.OpenAI + cfg.BaseURL = baseURL + s.providers["openai"] = providers.NewOpenAIProvider(cfg, apiKey) + case "gemini": + cfg := s.cfg.Providers.Gemini + cfg.BaseURL = baseURL + s.providers["gemini"] = providers.NewGeminiProvider(cfg, apiKey) + case "deepseek": + cfg := s.cfg.Providers.DeepSeek + cfg.BaseURL = baseURL + s.providers["deepseek"] = providers.NewDeepSeekProvider(cfg, apiKey) + case "grok": + cfg := s.cfg.Providers.Grok + cfg.BaseURL = baseURL + s.providers["grok"] = providers.NewGrokProvider(cfg, apiKey) + } + } + + return nil +} + func (s *Server) setupRoutes() { s.router.Use(middleware.AuthMiddleware(s.database)) diff --git a/internal/utils/crypto.go b/internal/utils/crypto.go new file mode 100644 index 00000000..ca77fd55 --- /dev/null +++ b/internal/utils/crypto.go @@ -0,0 +1,71 @@ +package utils + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "fmt" + "io" +) + +// Encrypt encrypts plain text using AES-GCM with the given 32-byte key. +func Encrypt(plainText string, key []byte) (string, error) { + if len(key) != 32 { + return "", fmt.Errorf("encryption key must be 32 bytes") + } + + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", err + } + + // The nonce should be prepended to the ciphertext + cipherText := gcm.Seal(nonce, nonce, []byte(plainText), nil) + return base64.StdEncoding.EncodeToString(cipherText), nil +} + +// Decrypt decrypts base64-encoded cipher text using AES-GCM with the given 32-byte key. +func Decrypt(encodedCipherText string, key []byte) (string, error) { + if len(key) != 32 { + return "", fmt.Errorf("encryption key must be 32 bytes") + } + + cipherText, err := base64.StdEncoding.DecodeString(encodedCipherText) + if err != nil { + return "", err + } + + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + + nonceSize := gcm.NonceSize() + if len(cipherText) < nonceSize { + return "", fmt.Errorf("cipher text too short") + } + + nonce, actualCipherText := cipherText[:nonceSize], cipherText[nonceSize:] + plainText, err := gcm.Open(nil, nonce, actualCipherText, nil) + if err != nil { + return "", err + } + + return string(plainText), nil +}