fix: prioritize database provider configs and implement API key encryption
- Added AES-GCM encryption/decryption for provider API keys in the database. - Implemented RefreshProviders to load provider configs from the database with precedence over environment variables. - Updated dashboard handlers to encrypt keys on save and trigger in-memory provider refresh. - Updated Grok test model to grok-3-mini for better compatibility.
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"gophergate/internal/db"
|
"gophergate/internal/db"
|
||||||
"gophergate/internal/models"
|
"gophergate/internal/models"
|
||||||
|
"gophergate/internal/utils"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
@@ -912,24 +913,40 @@ func (s *Server) handleUpdateProvider(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
apiKeyEncrypted := false
|
||||||
|
var apiKey *string = req.APIKey
|
||||||
|
if req.APIKey != nil && *req.APIKey != "" {
|
||||||
|
encrypted, err := utils.Encrypt(*req.APIKey, s.cfg.KeyBytes)
|
||||||
|
if err == nil {
|
||||||
|
apiKey = &encrypted
|
||||||
|
apiKeyEncrypted = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
_, err := s.database.Exec(`
|
_, err := s.database.Exec(`
|
||||||
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode)
|
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode, api_key_encrypted)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
ON CONFLICT(id) DO UPDATE SET
|
ON CONFLICT(id) DO UPDATE SET
|
||||||
enabled = excluded.enabled,
|
enabled = excluded.enabled,
|
||||||
base_url = COALESCE(excluded.base_url, provider_configs.base_url),
|
base_url = COALESCE(excluded.base_url, provider_configs.base_url),
|
||||||
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
|
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
|
||||||
|
api_key_encrypted = excluded.api_key_encrypted,
|
||||||
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
|
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
|
||||||
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
|
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
|
||||||
billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
|
billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
|
||||||
updated_at = CURRENT_TIMESTAMP
|
updated_at = CURRENT_TIMESTAMP
|
||||||
`, name, strings.ToUpper(name), req.Enabled, req.BaseURL, req.APIKey, req.CreditBalance, req.LowCreditThreshold, req.BillingMode)
|
`, name, strings.ToUpper(name), req.Enabled, req.BaseURL, apiKey, req.CreditBalance, req.LowCreditThreshold, req.BillingMode, apiKeyEncrypted)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Refresh in-memory providers
|
||||||
|
if err := s.RefreshProviders(); err != nil {
|
||||||
|
fmt.Printf("Error refreshing providers: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Provider updated"}))
|
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Provider updated"}))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -962,7 +979,7 @@ func (s *Server) handleTestProvider(c *gin.Context) {
|
|||||||
} else if name == "deepseek" {
|
} else if name == "deepseek" {
|
||||||
testReq.Model = "deepseek-chat"
|
testReq.Model = "deepseek-chat"
|
||||||
} else if name == "grok" {
|
} else if name == "grok" {
|
||||||
testReq.Model = "grok-2"
|
testReq.Model = "grok-3-mini"
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := provider.ChatCompletion(c.Request.Context(), testReq)
|
_, err := provider.ChatCompletion(c.Request.Context(), testReq)
|
||||||
|
|||||||
@@ -54,28 +54,102 @@ func NewServer(cfg *config.Config, database *db.DB) *Server {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Initialize providers
|
// Initialize providers from DB and Config
|
||||||
if cfg.Providers.OpenAI.Enabled {
|
if err := s.RefreshProviders(); err != nil {
|
||||||
apiKey, _ := cfg.GetAPIKey("openai")
|
fmt.Printf("Warning: Failed to initial refresh providers: %v\n", err)
|
||||||
s.providers["openai"] = providers.NewOpenAIProvider(cfg.Providers.OpenAI, apiKey)
|
|
||||||
}
|
|
||||||
if cfg.Providers.Gemini.Enabled {
|
|
||||||
apiKey, _ := cfg.GetAPIKey("gemini")
|
|
||||||
s.providers["gemini"] = providers.NewGeminiProvider(cfg.Providers.Gemini, apiKey)
|
|
||||||
}
|
|
||||||
if cfg.Providers.DeepSeek.Enabled {
|
|
||||||
apiKey, _ := cfg.GetAPIKey("deepseek")
|
|
||||||
s.providers["deepseek"] = providers.NewDeepSeekProvider(cfg.Providers.DeepSeek, apiKey)
|
|
||||||
}
|
|
||||||
if cfg.Providers.Grok.Enabled {
|
|
||||||
apiKey, _ := cfg.GetAPIKey("grok")
|
|
||||||
s.providers["grok"] = providers.NewGrokProvider(cfg.Providers.Grok, apiKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.setupRoutes()
|
s.setupRoutes()
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) RefreshProviders() error {
|
||||||
|
var dbConfigs []db.ProviderConfig
|
||||||
|
err := s.database.Select(&dbConfigs, "SELECT * FROM provider_configs")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to fetch provider configs from db: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dbMap := make(map[string]db.ProviderConfig)
|
||||||
|
for _, cfg := range dbConfigs {
|
||||||
|
dbMap[cfg.ID] = cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
providerIDs := []string{"openai", "gemini", "deepseek", "grok"}
|
||||||
|
for _, id := range providerIDs {
|
||||||
|
// Default values from config
|
||||||
|
enabled := false
|
||||||
|
baseURL := ""
|
||||||
|
apiKey := ""
|
||||||
|
|
||||||
|
switch id {
|
||||||
|
case "openai":
|
||||||
|
enabled = s.cfg.Providers.OpenAI.Enabled
|
||||||
|
baseURL = s.cfg.Providers.OpenAI.BaseURL
|
||||||
|
apiKey, _ = s.cfg.GetAPIKey("openai")
|
||||||
|
case "gemini":
|
||||||
|
enabled = s.cfg.Providers.Gemini.Enabled
|
||||||
|
baseURL = s.cfg.Providers.Gemini.BaseURL
|
||||||
|
apiKey, _ = s.cfg.GetAPIKey("gemini")
|
||||||
|
case "deepseek":
|
||||||
|
enabled = s.cfg.Providers.DeepSeek.Enabled
|
||||||
|
baseURL = s.cfg.Providers.DeepSeek.BaseURL
|
||||||
|
apiKey, _ = s.cfg.GetAPIKey("deepseek")
|
||||||
|
case "grok":
|
||||||
|
enabled = s.cfg.Providers.Grok.Enabled
|
||||||
|
baseURL = s.cfg.Providers.Grok.BaseURL
|
||||||
|
apiKey, _ = s.cfg.GetAPIKey("grok")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Overrides from DB
|
||||||
|
if dbCfg, ok := dbMap[id]; ok {
|
||||||
|
enabled = dbCfg.Enabled
|
||||||
|
if dbCfg.BaseURL != nil && *dbCfg.BaseURL != "" {
|
||||||
|
baseURL = *dbCfg.BaseURL
|
||||||
|
}
|
||||||
|
if dbCfg.APIKey != nil && *dbCfg.APIKey != "" {
|
||||||
|
key := *dbCfg.APIKey
|
||||||
|
if dbCfg.APIKeyEncrypted {
|
||||||
|
decrypted, err := utils.Decrypt(key, s.cfg.KeyBytes)
|
||||||
|
if err == nil {
|
||||||
|
key = decrypted
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Warning: Failed to decrypt API key for %s: %v\n", id, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
apiKey = key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !enabled {
|
||||||
|
delete(s.providers, id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize provider
|
||||||
|
switch id {
|
||||||
|
case "openai":
|
||||||
|
cfg := s.cfg.Providers.OpenAI
|
||||||
|
cfg.BaseURL = baseURL
|
||||||
|
s.providers["openai"] = providers.NewOpenAIProvider(cfg, apiKey)
|
||||||
|
case "gemini":
|
||||||
|
cfg := s.cfg.Providers.Gemini
|
||||||
|
cfg.BaseURL = baseURL
|
||||||
|
s.providers["gemini"] = providers.NewGeminiProvider(cfg, apiKey)
|
||||||
|
case "deepseek":
|
||||||
|
cfg := s.cfg.Providers.DeepSeek
|
||||||
|
cfg.BaseURL = baseURL
|
||||||
|
s.providers["deepseek"] = providers.NewDeepSeekProvider(cfg, apiKey)
|
||||||
|
case "grok":
|
||||||
|
cfg := s.cfg.Providers.Grok
|
||||||
|
cfg.BaseURL = baseURL
|
||||||
|
s.providers["grok"] = providers.NewGrokProvider(cfg, apiKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) setupRoutes() {
|
func (s *Server) setupRoutes() {
|
||||||
s.router.Use(middleware.AuthMiddleware(s.database))
|
s.router.Use(middleware.AuthMiddleware(s.database))
|
||||||
|
|
||||||
|
|||||||
71
internal/utils/crypto.go
Normal file
71
internal/utils/crypto.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Encrypt encrypts plain text using AES-GCM with the given 32-byte key.
|
||||||
|
func Encrypt(plainText string, key []byte) (string, error) {
|
||||||
|
if len(key) != 32 {
|
||||||
|
return "", fmt.Errorf("encryption key must be 32 bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce := make([]byte, gcm.NonceSize())
|
||||||
|
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// The nonce should be prepended to the ciphertext
|
||||||
|
cipherText := gcm.Seal(nonce, nonce, []byte(plainText), nil)
|
||||||
|
return base64.StdEncoding.EncodeToString(cipherText), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt decrypts base64-encoded cipher text using AES-GCM with the given 32-byte key.
|
||||||
|
func Decrypt(encodedCipherText string, key []byte) (string, error) {
|
||||||
|
if len(key) != 32 {
|
||||||
|
return "", fmt.Errorf("encryption key must be 32 bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
cipherText, err := base64.StdEncoding.DecodeString(encodedCipherText)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
nonceSize := gcm.NonceSize()
|
||||||
|
if len(cipherText) < nonceSize {
|
||||||
|
return "", fmt.Errorf("cipher text too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce, actualCipherText := cipherText[:nonceSize], cipherText[nonceSize:]
|
||||||
|
plainText, err := gcm.Open(nil, nonce, actualCipherText, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(plainText), nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user