244 lines
6.5 KiB
Go
244 lines
6.5 KiB
Go
package server
|
|
|
|
import (
	"database/sql"
	"fmt"
	"log/slog"
	"net/http"
	"slices"
	"strings"
	"time"

	"github.com/gin-gonic/gin"

	"gophergate/internal/db"
	"gophergate/internal/models"
	"gophergate/internal/utils"
)
|
|
|
|
func (s *Server) handleGetProviders(c *gin.Context) {
|
|
var dbConfigs []db.ProviderConfig
|
|
err := s.database.Select(&dbConfigs, "SELECT id, enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs")
|
|
if err != nil {
|
|
// Log error
|
|
}
|
|
|
|
dbMap := make(map[string]db.ProviderConfig)
|
|
for _, cfg := range dbConfigs {
|
|
dbMap[cfg.ID] = cfg
|
|
}
|
|
|
|
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama"}
|
|
var result []gin.H
|
|
|
|
for _, id := range providerIDs {
|
|
var name string
|
|
var enabled bool
|
|
var baseURL string
|
|
|
|
switch id {
|
|
case "openai":
|
|
name = "OpenAI"
|
|
enabled = s.cfg.Providers.OpenAI.Enabled
|
|
baseURL = s.cfg.Providers.OpenAI.BaseURL
|
|
case "gemini":
|
|
name = "Google Gemini"
|
|
enabled = s.cfg.Providers.Gemini.Enabled
|
|
baseURL = s.cfg.Providers.Gemini.BaseURL
|
|
case "deepseek":
|
|
name = "DeepSeek"
|
|
enabled = s.cfg.Providers.DeepSeek.Enabled
|
|
baseURL = s.cfg.Providers.DeepSeek.BaseURL
|
|
case "moonshot":
|
|
name = "Moonshot"
|
|
enabled = s.cfg.Providers.Moonshot.Enabled
|
|
baseURL = s.cfg.Providers.Moonshot.BaseURL
|
|
case "grok":
|
|
name = "xAI Grok"
|
|
enabled = s.cfg.Providers.Grok.Enabled
|
|
baseURL = s.cfg.Providers.Grok.BaseURL
|
|
case "ollama":
|
|
name = "Ollama"
|
|
enabled = s.cfg.Providers.Ollama.Enabled
|
|
baseURL = s.cfg.Providers.Ollama.BaseURL
|
|
}
|
|
|
|
var balance float64
|
|
var threshold float64 = 5.0
|
|
var billingMode string
|
|
|
|
if dbCfg, ok := dbMap[id]; ok {
|
|
enabled = dbCfg.Enabled
|
|
if dbCfg.BaseURL != nil {
|
|
baseURL = *dbCfg.BaseURL
|
|
}
|
|
balance = dbCfg.CreditBalance
|
|
threshold = dbCfg.LowCreditThreshold
|
|
if dbCfg.BillingMode != nil {
|
|
billingMode = *dbCfg.BillingMode
|
|
}
|
|
}
|
|
|
|
status := "disabled"
|
|
if enabled {
|
|
if _, ok := s.providers[id]; ok {
|
|
status = "online"
|
|
} else {
|
|
status = "error"
|
|
}
|
|
}
|
|
|
|
// Get last used for this provider
|
|
var lastUsedTime sql.NullTime
|
|
_ = s.database.Get(&lastUsedTime, "SELECT MAX(timestamp) FROM llm_requests WHERE provider = ?", id)
|
|
var lastUsed interface{}
|
|
if lastUsedTime.Valid && !lastUsedTime.Time.IsZero() {
|
|
lastUsed = lastUsedTime.Time
|
|
}
|
|
|
|
// Get models for this provider from registry
|
|
var models []string
|
|
s.registryMu.RLock()
|
|
if s.registry != nil {
|
|
registryID := id
|
|
if id == "gemini" {
|
|
registryID = "google"
|
|
}
|
|
if id == "moonshot" {
|
|
registryID = "moonshot"
|
|
}
|
|
if id == "grok" {
|
|
registryID = "xai"
|
|
}
|
|
|
|
if pInfo, ok := s.registry.Providers[registryID]; ok {
|
|
for mID := range pInfo.Models {
|
|
models = append(models, mID)
|
|
}
|
|
}
|
|
}
|
|
s.registryMu.RUnlock()
|
|
|
|
// If it's ollama, also include models from config
|
|
if id == "ollama" {
|
|
models = append(models, s.cfg.Providers.Ollama.Models...)
|
|
}
|
|
|
|
result = append(result, gin.H{
|
|
"id": id,
|
|
"name": name,
|
|
"enabled": enabled,
|
|
"status": status,
|
|
"base_url": baseURL,
|
|
"credit_balance": balance,
|
|
"low_credit_threshold": threshold,
|
|
"billing_mode": billingMode,
|
|
"last_used": lastUsed,
|
|
"models": models,
|
|
})
|
|
}
|
|
|
|
c.JSON(http.StatusOK, SuccessResponse(result))
|
|
}
|
|
|
|
// UpdateProviderRequest is the JSON body accepted when updating a provider's
// stored settings.
//
// Enabled is a plain bool, so an omitted "enabled" key decodes as false. Every
// other field is a pointer: it stays nil when the key is absent or JSON null,
// which lets the handler tell "not provided" apart from a zero value.
type UpdateProviderRequest struct {
	// Enabled turns the provider on or off.
	Enabled bool `json:"enabled"`

	// BaseURL optionally overrides the provider's API endpoint.
	BaseURL *string `json:"base_url"`

	// APIKey is the provider credential supplied by the client.
	APIKey *string `json:"api_key"`

	// CreditBalance is the remaining credit recorded for the provider.
	CreditBalance *float64 `json:"credit_balance"`

	// LowCreditThreshold is the balance level at which credit is considered low.
	LowCreditThreshold *float64 `json:"low_credit_threshold"`

	// BillingMode is a free-form billing mode label stored with the provider.
	BillingMode *string `json:"billing_mode"`
}
|
|
|
|
func (s *Server) handleUpdateProvider(c *gin.Context) {
|
|
name := c.Param("name")
|
|
var req UpdateProviderRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
|
return
|
|
}
|
|
|
|
apiKeyEncrypted := false
|
|
var apiKey *string = req.APIKey
|
|
if req.APIKey != nil && *req.APIKey != "" {
|
|
encrypted, err := utils.Encrypt(*req.APIKey, s.cfg.KeyBytes)
|
|
if err == nil {
|
|
apiKey = &encrypted
|
|
apiKeyEncrypted = true
|
|
}
|
|
}
|
|
|
|
_, err := s.database.Exec(`
|
|
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode, api_key_encrypted)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(id) DO UPDATE SET
|
|
enabled = excluded.enabled,
|
|
base_url = COALESCE(excluded.base_url, provider_configs.base_url),
|
|
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
|
|
api_key_encrypted = excluded.api_key_encrypted,
|
|
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
|
|
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
|
|
billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
|
|
updated_at = CURRENT_TIMESTAMP
|
|
`, name, strings.ToUpper(name), req.Enabled, req.BaseURL, apiKey, req.CreditBalance, req.LowCreditThreshold, req.BillingMode, apiKeyEncrypted)
|
|
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
|
return
|
|
}
|
|
|
|
// Refresh in-memory providers
|
|
if err := s.RefreshProviders(); err != nil {
|
|
fmt.Printf("Error refreshing providers: %v\n", err)
|
|
}
|
|
|
|
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Provider updated"}))
|
|
}
|
|
|
|
func (s *Server) handleTestProvider(c *gin.Context) {
|
|
name := c.Param("name")
|
|
provider, ok := s.providers[name]
|
|
if !ok {
|
|
c.JSON(http.StatusNotFound, ErrorResponse(fmt.Sprintf("Provider %s not found or not enabled", name)))
|
|
return
|
|
}
|
|
|
|
startTime := time.Now()
|
|
|
|
// Prepare a simple test request
|
|
testReq := &models.UnifiedRequest{
|
|
Model: "gpt-4o-mini", // Default cheap test model
|
|
Messages: []models.UnifiedMessage{
|
|
{
|
|
Role: "user",
|
|
Content: []models.UnifiedContentPart{{Type: "text", Text: "Hi"}},
|
|
},
|
|
},
|
|
MaxTokens: new(uint32),
|
|
}
|
|
*testReq.MaxTokens = 5
|
|
|
|
// Adjust model for non-openai providers
|
|
if name == "gemini" {
|
|
testReq.Model = "gemini-2.0-flash"
|
|
} else if name == "deepseek" {
|
|
testReq.Model = "deepseek-chat"
|
|
} else if name == "moonshot" {
|
|
testReq.Model = "kimi-k2.5"
|
|
} else if name == "grok" {
|
|
testReq.Model = "grok-4-1-fast-non-reasoning"
|
|
}
|
|
|
|
_, err := provider.ChatCompletion(c.Request.Context(), testReq)
|
|
latency := time.Since(startTime).Milliseconds()
|
|
|
|
if err != nil {
|
|
c.JSON(http.StatusOK, ErrorResponse(fmt.Sprintf("Provider test failed: %v", err)))
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
|
"message": "Connection test successful",
|
|
"latency": latency,
|
|
}))
|
|
}
|
|
|