diff --git a/.pi-lens/turn-state.json b/.pi-lens/turn-state.json index 2b1f3650..8d71670b 100644 --- a/.pi-lens/turn-state.json +++ b/.pi-lens/turn-state.json @@ -2,5 +2,5 @@ "files": {}, "turnCycles": 0, "maxCycles": 3, - "lastUpdated": "2026-04-26T18:47:32.097Z" + "lastUpdated": "2026-04-26T18:49:43.830Z" } \ No newline at end of file diff --git a/internal/models/registry_test.go b/internal/models/registry_test.go new file mode 100644 index 00000000..c0e47565 --- /dev/null +++ b/internal/models/registry_test.go @@ -0,0 +1,60 @@ +package models + +import ( + "testing" +) + +func TestModelRegistry_FindModel_Exact(t *testing.T) { + r := &ModelRegistry{ + Providers: map[string]ProviderInfo{ + "openai": { + Models: map[string]ModelMetadata{ + "gpt-4o": {ID: "gpt-4o", Name: "GPT-4o"}, + }, + }, + }, + } + m := r.FindModel("gpt-4o") + if m == nil { + t.Fatal("expected to find gpt-4o") + } + if m.Name != "GPT-4o" { + t.Fatalf("expected GPT-4o, got %s", m.Name) + } +} + +func TestModelRegistry_FindModel_Fuzzy(t *testing.T) { + r := &ModelRegistry{ + Providers: map[string]ProviderInfo{ + "openai": { + Models: map[string]ModelMetadata{ + "gpt-4o": {ID: "gpt-4o", Name: "GPT-4o"}, + }, + }, + }, + } + // Fuzzy: "gpt-4o-2024-05-13" should match "gpt-4o" + m := r.FindModel("gpt-4o-2024-05-13") + if m == nil { + t.Fatal("expected fuzzy match") + } + if m.Name != "GPT-4o" { + t.Fatalf("expected GPT-4o, got %s", m.Name) + } +} + +func TestModelRegistry_FindModel_NotFound(t *testing.T) { + r := &ModelRegistry{ + Providers: map[string]ProviderInfo{ + "openai": { + Models: map[string]ModelMetadata{ + "gpt-4o": {ID: "gpt-4o", Name: "GPT-4o"}, + }, + }, + }, + } + m := r.FindModel("nonexistent-model") + if m != nil { + t.Fatal("expected nil for nonexistent model") + } +} diff --git a/internal/providers/ollama.go b/internal/providers/ollama.go index 5ceae4d1..08175a17 100644 --- a/internal/providers/ollama.go +++ b/internal/providers/ollama.go @@ -12,7 +12,6 @@ import ( "github.com/go-resty/resty/v2" "gophergate/internal/config" "gophergate/internal/models" - "log/slog" ) type OllamaProvider struct { @@ -53,18 +52,15 @@ func (p *OllamaProvider) ChatCompletion(ctx context.Context, req *models.Unified Post(url) if err != nil { - fmt.Printf("[Ollama] Request error: %v\n", err) return nil, fmt.Errorf("request failed: %w", err) } if !resp.IsSuccess() { - fmt.Printf("[Ollama] API error %d: %s\n", resp.StatusCode(), resp.String()) return nil, fmt.Errorf("Ollama API error (%d): %s", resp.StatusCode(), resp.String()) } var respJSON map[string]interface{} if err := json.Unmarshal(resp.Body(), &respJSON); err != nil { - fmt.Printf("[Ollama] Parse error: %v\n", err) return nil, fmt.Errorf("failed to parse response: %w", err) } @@ -99,7 +95,6 @@ func (p *OllamaProvider) ChatCompletionStream(ctx context.Context, req *models.U defer close(ch) err := StreamOllama(resp.RawBody(), ch, req.Model) if err != nil { - fmt.Printf("Stream error: %v\n", err) } }() diff --git a/internal/server/clients.go b/internal/server/clients.go new file mode 100644 index 00000000..cc21f644 --- /dev/null +++ b/internal/server/clients.go @@ -0,0 +1,281 @@ +package server + +import ( + "database/sql" + "fmt" + "net/http" + "os" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "golang.org/x/crypto/bcrypt" + "gophergate/internal/db" + "gophergate/internal/models" + "gophergate/internal/utils" + "log/slog" + + "github.com/shirou/gopsutil/v3/cpu" + +func (s *Server) handleGetClients(c *gin.Context) { + var clients []db.Client + err := s.database.Select(&clients, "SELECT * FROM clients ORDER BY created_at DESC") + if err != nil { + c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) + return + } + + type UIClient struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + CreatedAt time.Time `json:"created_at"` + LastUsed *time.Time `json:"last_used"` + RequestsCount int `json:"requests_count"` + TokensCount int `json:"tokens_count"` + Status string `json:"status"` + RateLimitPerMinute int `json:"rate_limit_per_minute"` + } + + uiClients := make([]UIClient, len(clients)) + for i, cl := range clients { + status := "active" + if !cl.IsActive { + status = "disabled" + } + + name := "" + if cl.Name != nil { + name = *cl.Name + } + desc := "" + if cl.Description != nil { + desc = *cl.Description + } + + var lastUsedTime sql.NullTime + _ = s.database.Get(&lastUsedTime, "SELECT MAX(last_used_at) FROM client_tokens WHERE client_id = ?", cl.ClientID) + + var lastUsed *time.Time + if lastUsedTime.Valid && !lastUsedTime.Time.IsZero() { + t := lastUsedTime.Time + lastUsed = &t + } + + uiClients[i] = UIClient{ + ID: cl.ClientID, + Name: name, + Description: desc, + CreatedAt: cl.CreatedAt, + LastUsed: lastUsed, + RequestsCount: cl.TotalRequests, + TokensCount: cl.TotalTokens, + Status: status, + RateLimitPerMinute: cl.RateLimitPerMinute, + } + } + + c.JSON(http.StatusOK, SuccessResponse(uiClients)) +} + +func (s *Server) handleGetClient(c *gin.Context) { + id := c.Param("id") + var cl db.Client + err := s.database.Get(&cl, "SELECT * FROM clients WHERE client_id = ?", id) + if err != nil { + c.JSON(http.StatusNotFound, ErrorResponse("Client not found")) + return + } + + name := "" + if cl.Name != nil { + name = *cl.Name + } + desc := "" + if cl.Description != nil { + desc = *cl.Description + } + + c.JSON(http.StatusOK, SuccessResponse(gin.H{ + "id": cl.ClientID, + "name": name, + "description": desc, + "is_active": cl.IsActive, + "rate_limit_per_minute": cl.RateLimitPerMinute, + "created_at": cl.CreatedAt, + })) +} + +type UpdateClientRequest struct { + Name string `json:"name"` + Description *string `json:"description"` + IsActive bool `json:"is_active"` + RateLimitPerMinute *int `json:"rate_limit_per_minute"` +} + +func (s *Server) handleUpdateClient(c *gin.Context) { + id := c.Param("id") + var req UpdateClientRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request")) + return + } + + _, err := s.database.Exec(` + UPDATE clients SET + name = ?, + description = ?, + is_active = ?, + rate_limit_per_minute = COALESCE(?, rate_limit_per_minute), + updated_at = CURRENT_TIMESTAMP + WHERE client_id = ? + `, req.Name, req.Description, req.IsActive, req.RateLimitPerMinute, id) + + if err != nil { + c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) + return + } + + c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Client updated"})) +} + +type CreateClientRequest struct { + Name string `json:"name" binding:"required"` + ClientID *string `json:"client_id"` +} + +func (s *Server) handleCreateClient(c *gin.Context) { + var req CreateClientRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request")) + return + } + + clientID := "" + if req.ClientID != nil { + clientID = *req.ClientID + } else { + clientID = "client-" + uuid.New().String()[:8] + } + + _, err := s.database.Exec("INSERT INTO clients (client_id, name, is_active) VALUES (?, ?, 1)", clientID, req.Name) + if err != nil { + c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) + return + } + + token := "sk-" + uuid.New().String() + uuid.New().String() + token = token[:51] + + _, err = s.database.Exec("INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, 'default')", clientID, token) + if err != nil { + // Log error + } + + c.JSON(http.StatusOK, SuccessResponse(gin.H{ + "id": clientID, + "name": req.Name, + "status": "active", + "token": token, + "created_at": time.Now(), + })) +} + +func (s *Server) handleDeleteClient(c *gin.Context) { + id := c.Param("id") + if id == "default" { + c.JSON(http.StatusBadRequest, ErrorResponse("Cannot delete default client")) + return + } + + _, err := s.database.Exec("DELETE FROM clients WHERE client_id = ?", id) + if err != nil { + c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) + return + } + + c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Client deleted"})) +} + +func (s *Server) handleGetClientTokens(c *gin.Context) { + id := c.Param("id") + var tokens []db.ClientToken + err := s.database.Select(&tokens, "SELECT * FROM client_tokens WHERE client_id = ? ORDER BY created_at DESC", id) + if err != nil { + c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) + return + } + + type MaskedToken struct { + ID int `json:"id"` + TokenMasked string `json:"token_masked"` + Name string `json:"name"` + IsActive bool `json:"is_active"` + CreatedAt time.Time `json:"created_at"` + LastUsedAt *time.Time `json:"last_used_at"` + } + + masked := make([]MaskedToken, len(tokens)) + for i, t := range tokens { + maskedToken := "••••" + if len(t.Token) > 8 { + maskedToken = t.Token[:3] + "••••" + t.Token[len(t.Token)-8:] + } + masked[i] = MaskedToken{ + ID: t.ID, + TokenMasked: maskedToken, + Name: t.Name, + IsActive: t.IsActive, + CreatedAt: t.CreatedAt, + LastUsedAt: t.LastUsedAt, + } + } + + c.JSON(http.StatusOK, SuccessResponse(masked)) +} + +type CreateTokenRequest struct { + Name string `json:"name"` +} + +func (s *Server) handleCreateClientToken(c *gin.Context) { + clientID := c.Param("id") + var req CreateTokenRequest + if err := c.ShouldBindJSON(&req); err != nil { + // optional name + } + + name := "default" + if req.Name != "" { + name = req.Name + } + + token := "sk-" + uuid.New().String() + uuid.New().String() + token = token[:51] + + _, err := s.database.Exec("INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, ?)", clientID, token, name) + if err != nil { + c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) + return + } + + c.JSON(http.StatusOK, SuccessResponse(gin.H{ + "token": token, + "name": name, + "created_at": time.Now(), + })) +} + +func (s *Server) handleDeleteClientToken(c *gin.Context) { + tokenID := c.Param("token_id") + + _, err := s.database.Exec("DELETE FROM client_tokens WHERE id = ?", tokenID) + if err != nil { + c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) + return + } + + c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Token revoked"})) +} + diff --git a/internal/server/dashboard.go b/internal/server/dashboard.go index 361368b0..63de8733 100644 --- a/internal/server/dashboard.go +++ b/internal/server/dashboard.go @@ -532,951 +532,3 @@ func (s *Server) handleDetailedUsage(c *gin.Context) { "cost": cost, }) } - } - - c.JSON(http.StatusOK, SuccessResponse(results)) -} - -func (s *Server) handleGetClients(c *gin.Context) { - var clients []db.Client - err := s.database.Select(&clients, "SELECT * FROM clients ORDER BY created_at DESC") - if err != nil { - c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) - return - } - - type UIClient struct { - ID string `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - CreatedAt time.Time `json:"created_at"` - LastUsed *time.Time `json:"last_used"` - RequestsCount int `json:"requests_count"` - TokensCount int `json:"tokens_count"` - Status string `json:"status"` - RateLimitPerMinute int `json:"rate_limit_per_minute"` - } - - uiClients := make([]UIClient, len(clients)) - for i, cl := range clients { - status := "active" - if !cl.IsActive { - status = "disabled" - } - - name := "" - if cl.Name != nil { - name = *cl.Name - } - desc := "" - if cl.Description != nil { - desc = *cl.Description - } - - var lastUsedTime sql.NullTime - _ = s.database.Get(&lastUsedTime, "SELECT MAX(last_used_at) FROM client_tokens WHERE client_id = ?", cl.ClientID) - - var lastUsed *time.Time - if lastUsedTime.Valid && !lastUsedTime.Time.IsZero() { - t := lastUsedTime.Time - lastUsed = &t - } - - uiClients[i] = UIClient{ - ID: cl.ClientID, - Name: name, - Description: desc, - CreatedAt: cl.CreatedAt, - LastUsed: lastUsed, - RequestsCount: cl.TotalRequests, - TokensCount: cl.TotalTokens, - Status: status, - RateLimitPerMinute: cl.RateLimitPerMinute, - } - } - - c.JSON(http.StatusOK, SuccessResponse(uiClients)) -} - -func (s *Server) handleGetClient(c *gin.Context) { - id := c.Param("id") - var cl db.Client - err := s.database.Get(&cl, "SELECT * FROM clients WHERE client_id = ?", id) - if err != nil { - c.JSON(http.StatusNotFound, ErrorResponse("Client not found")) - return - } - - name := "" - if cl.Name != nil { - name = *cl.Name - } - desc := "" - if cl.Description != nil { - desc = *cl.Description - } - - c.JSON(http.StatusOK, SuccessResponse(gin.H{ - "id": cl.ClientID, - "name": name, - "description": desc, - "is_active": cl.IsActive, - "rate_limit_per_minute": cl.RateLimitPerMinute, - "created_at": cl.CreatedAt, - })) -} - -type UpdateClientRequest struct { - Name string `json:"name"` - Description *string `json:"description"` - IsActive bool `json:"is_active"` - RateLimitPerMinute *int `json:"rate_limit_per_minute"` -} - -func (s *Server) handleUpdateClient(c *gin.Context) { - id := c.Param("id") - var req UpdateClientRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request")) - return - } - - _, err := s.database.Exec(` - UPDATE clients SET - name = ?, - description = ?, - is_active = ?, - rate_limit_per_minute = COALESCE(?, rate_limit_per_minute), - updated_at = CURRENT_TIMESTAMP - WHERE client_id = ? - `, req.Name, req.Description, req.IsActive, req.RateLimitPerMinute, id) - - if err != nil { - c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) - return - } - - c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Client updated"})) -} - -type CreateClientRequest struct { - Name string `json:"name" binding:"required"` - ClientID *string `json:"client_id"` -} - -func (s *Server) handleCreateClient(c *gin.Context) { - var req CreateClientRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request")) - return - } - - clientID := "" - if req.ClientID != nil { - clientID = *req.ClientID - } else { - clientID = "client-" + uuid.New().String()[:8] - } - - _, err := s.database.Exec("INSERT INTO clients (client_id, name, is_active) VALUES (?, ?, 1)", clientID, req.Name) - if err != nil { - c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) - return - } - - token := "sk-" + uuid.New().String() + uuid.New().String() - token = token[:51] - - _, err = s.database.Exec("INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, 'default')", clientID, token) - if err != nil { - // Log error - } - - c.JSON(http.StatusOK, SuccessResponse(gin.H{ - "id": clientID, - "name": req.Name, - "status": "active", - "token": token, - "created_at": time.Now(), - })) -} - -func (s *Server) handleDeleteClient(c *gin.Context) { - id := c.Param("id") - if id == "default" { - c.JSON(http.StatusBadRequest, ErrorResponse("Cannot delete default client")) - return - } - - _, err := s.database.Exec("DELETE FROM clients WHERE client_id = ?", id) - if err != nil { - c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) - return - } - - c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Client deleted"})) -} - -func (s *Server) handleGetClientTokens(c *gin.Context) { - id := c.Param("id") - var tokens []db.ClientToken - err := s.database.Select(&tokens, "SELECT * FROM client_tokens WHERE client_id = ? ORDER BY created_at DESC", id) - if err != nil { - c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) - return - } - - type MaskedToken struct { - ID int `json:"id"` - TokenMasked string `json:"token_masked"` - Name string `json:"name"` - IsActive bool `json:"is_active"` - CreatedAt time.Time `json:"created_at"` - LastUsedAt *time.Time `json:"last_used_at"` - } - - masked := make([]MaskedToken, len(tokens)) - for i, t := range tokens { - maskedToken := "••••" - if len(t.Token) > 8 { - maskedToken = t.Token[:3] + "••••" + t.Token[len(t.Token)-8:] - } - masked[i] = MaskedToken{ - ID: t.ID, - TokenMasked: maskedToken, - Name: t.Name, - IsActive: t.IsActive, - CreatedAt: t.CreatedAt, - LastUsedAt: t.LastUsedAt, - } - } - - c.JSON(http.StatusOK, SuccessResponse(masked)) -} - -type CreateTokenRequest struct { - Name string `json:"name"` -} - -func (s *Server) handleCreateClientToken(c *gin.Context) { - clientID := c.Param("id") - var req CreateTokenRequest - if err := c.ShouldBindJSON(&req); err != nil { - // optional name - } - - name := "default" - if req.Name != "" { - name = req.Name - } - - token := "sk-" + uuid.New().String() + uuid.New().String() - token = token[:51] - - _, err := s.database.Exec("INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, ?)", clientID, token, name) - if err != nil { - c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) - return - } - - c.JSON(http.StatusOK, SuccessResponse(gin.H{ - "token": token, - "name": name, - "created_at": time.Now(), - })) -} - -func (s *Server) handleDeleteClientToken(c *gin.Context) { - tokenID := c.Param("token_id") - - _, err := s.database.Exec("DELETE FROM client_tokens WHERE id = ?", tokenID) - if err != nil { - c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) - return - } - - c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Token revoked"})) -} - -func (s *Server) handleGetProviders(c *gin.Context) { - var dbConfigs []db.ProviderConfig - err := s.database.Select(&dbConfigs, "SELECT id, enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs") - if err != nil { - // Log error - } - - dbMap := make(map[string]db.ProviderConfig) - for _, cfg := range dbConfigs { - dbMap[cfg.ID] = cfg - } - - providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama"} - var result []gin.H - - for _, id := range providerIDs { - var name string - var enabled bool - var baseURL string - - switch id { - case "openai": - name = "OpenAI" - enabled = s.cfg.Providers.OpenAI.Enabled - baseURL = s.cfg.Providers.OpenAI.BaseURL - case "gemini": - name = "Google Gemini" - enabled = s.cfg.Providers.Gemini.Enabled - baseURL = s.cfg.Providers.Gemini.BaseURL - case "deepseek": - name = "DeepSeek" - enabled = s.cfg.Providers.DeepSeek.Enabled - baseURL = s.cfg.Providers.DeepSeek.BaseURL - case "moonshot": - name = "Moonshot" - enabled = s.cfg.Providers.Moonshot.Enabled - baseURL = s.cfg.Providers.Moonshot.BaseURL - case "grok": - name = "xAI Grok" - enabled = s.cfg.Providers.Grok.Enabled - baseURL = s.cfg.Providers.Grok.BaseURL - case "ollama": - name = "Ollama" - enabled = s.cfg.Providers.Ollama.Enabled - baseURL = s.cfg.Providers.Ollama.BaseURL - } - - var balance float64 - var threshold float64 = 5.0 - var billingMode string - - if dbCfg, ok := dbMap[id]; ok { - enabled = dbCfg.Enabled - if dbCfg.BaseURL != nil { - baseURL = *dbCfg.BaseURL - } - balance = dbCfg.CreditBalance - threshold = dbCfg.LowCreditThreshold - if dbCfg.BillingMode != nil { - billingMode = *dbCfg.BillingMode - } - } - - status := "disabled" - if enabled { - if _, ok := s.providers[id]; ok { - status = "online" - } else { - status = "error" - } - } - - // Get last used for this provider - var lastUsedTime sql.NullTime - _ = s.database.Get(&lastUsedTime, "SELECT MAX(timestamp) FROM llm_requests WHERE provider = ?", id) - var lastUsed interface{} - if lastUsedTime.Valid && !lastUsedTime.Time.IsZero() { - lastUsed = lastUsedTime.Time - } - - // Get models for this provider from registry - var models []string - s.registryMu.RLock() - if s.registry != nil { - registryID := id - if id == "gemini" { - registryID = "google" - } - if id == "moonshot" { - registryID = "moonshot" - } - if id == "grok" { - registryID = "xai" - } - - if pInfo, ok := s.registry.Providers[registryID]; ok { - for mID := range pInfo.Models { - models = append(models, mID) - } - } - } - s.registryMu.RUnlock() - - // If it's ollama, also include models from config - if id == "ollama" { - models = append(models, s.cfg.Providers.Ollama.Models...) - } - - result = append(result, gin.H{ - "id": id, - "name": name, - "enabled": enabled, - "status": status, - "base_url": baseURL, - "credit_balance": balance, - "low_credit_threshold": threshold, - "billing_mode": billingMode, - "last_used": lastUsed, - "models": models, - }) - } - - c.JSON(http.StatusOK, SuccessResponse(result)) -} - -type UpdateProviderRequest struct { - Enabled bool `json:"enabled"` - BaseURL *string `json:"base_url"` - APIKey *string `json:"api_key"` - CreditBalance *float64 `json:"credit_balance"` - LowCreditThreshold *float64 `json:"low_credit_threshold"` - BillingMode *string `json:"billing_mode"` -} - -func (s *Server) handleUpdateProvider(c *gin.Context) { - name := c.Param("name") - var req UpdateProviderRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request")) - return - } - - apiKeyEncrypted := false - var apiKey *string = req.APIKey - if req.APIKey != nil && *req.APIKey != "" { - encrypted, err := utils.Encrypt(*req.APIKey, s.cfg.KeyBytes) - if err == nil { - apiKey = &encrypted - apiKeyEncrypted = true - } - } - - _, err := s.database.Exec(` - INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode, api_key_encrypted) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - ON CONFLICT(id) DO UPDATE SET - enabled = excluded.enabled, - base_url = COALESCE(excluded.base_url, provider_configs.base_url), - api_key = COALESCE(excluded.api_key, provider_configs.api_key), - api_key_encrypted = excluded.api_key_encrypted, - credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance), - low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold), - billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode), - updated_at = CURRENT_TIMESTAMP - `, name, strings.ToUpper(name), req.Enabled, req.BaseURL, apiKey, req.CreditBalance, req.LowCreditThreshold, req.BillingMode, apiKeyEncrypted) - - if err != nil { - c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) - return - } - - // Refresh in-memory providers - if err := s.RefreshProviders(); err != nil { - fmt.Printf("Error refreshing providers: %v\n", err) - } - - c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Provider updated"})) -} - -func (s *Server) handleTestProvider(c *gin.Context) { - name := c.Param("name") - provider, ok := s.providers[name] - if !ok { - c.JSON(http.StatusNotFound, ErrorResponse(fmt.Sprintf("Provider %s not found or not enabled", name))) - return - } - - startTime := time.Now() - - // Prepare a simple test request - testReq := &models.UnifiedRequest{ - Model: "gpt-4o-mini", // Default cheap test model - Messages: []models.UnifiedMessage{ - { - Role: "user", - Content: []models.UnifiedContentPart{{Type: "text", Text: "Hi"}}, - }, - }, - MaxTokens: new(uint32), - } - *testReq.MaxTokens = 5 - - // Adjust model for non-openai providers - if name == "gemini" { - testReq.Model = "gemini-2.0-flash" - } else if name == "deepseek" { - testReq.Model = "deepseek-chat" - } else if name == "moonshot" { - testReq.Model = "kimi-k2.5" - } else if name == "grok" { - testReq.Model = "grok-4-1-fast-non-reasoning" - } - - _, err := provider.ChatCompletion(c.Request.Context(), testReq) - latency := time.Since(startTime).Milliseconds() - - if err != nil { - c.JSON(http.StatusOK, ErrorResponse(fmt.Sprintf("Provider test failed: %v", err))) - return - } - - c.JSON(http.StatusOK, SuccessResponse(gin.H{ - "message": "Connection test successful", - "latency": latency, - })) -} - -func (s *Server) handleGetModels(c *gin.Context) { - usedOnly := c.Query("used_only") == "true" - - // Registry provider normalized name -> Proxy-internal provider ID - allowedRegistryProviders := map[string]string{ - "openai": "openai", - "google": "gemini", - "deepseek": "deepseek", - "xai": "grok", - "ollama": "ollama", - } - - // Merge registry models with DB overrides - var dbModels []db.ModelConfig - _ = s.database.Select(&dbModels, "SELECT * FROM model_configs") - - dbMap := make(map[string]db.ModelConfig) - for _, m := range dbModels { - dbMap[m.ID] = m - } - - // Fetch specific (model, provider) combinations that have been used - type modelProvider struct { - Model string `db:"model"` - Provider string `db:"provider"` - } - usedPairs := make(map[string]bool) - if usedOnly { - var pairs []modelProvider - err := s.database.Select(&pairs, "SELECT DISTINCT model, provider FROM llm_requests WHERE status = 'success'") - if err == nil { - for _, p := range pairs { - usedPairs[fmt.Sprintf("%s:%s", p.Model, p.Provider)] = true - } - } - } - - var result []gin.H - s.registryMu.RLock() - if s.registry != nil { - for pID, pInfo := range s.registry.Providers { - proxyProvider, allowed := allowedRegistryProviders[pID] - if !allowed { - continue - } - - for mID, mMeta := range pInfo.Models { - if usedOnly && !usedPairs[fmt.Sprintf("%s:%s", mID, proxyProvider)] { - continue - } - - enabled := true - promptCost := 0.0 - completionCost := 0.0 - var cacheReadCost *float64 - var cacheWriteCost *float64 - var mapping *string - contextLimit := uint32(0) - - if mMeta.Cost != nil { - promptCost = mMeta.Cost.Input - completionCost = mMeta.Cost.Output - cacheReadCost = mMeta.Cost.CacheRead - cacheWriteCost = mMeta.Cost.CacheWrite - } - if mMeta.Limit != nil { - contextLimit = mMeta.Limit.Context - } - - // Override from DB - if dbCfg, ok := dbMap[mID]; ok { - enabled = dbCfg.Enabled - if dbCfg.PromptCostPerM != nil { - promptCost = *dbCfg.PromptCostPerM - } - if dbCfg.CompletionCostPerM != nil { - completionCost = *dbCfg.CompletionCostPerM - } - if dbCfg.CacheReadCostPerM != nil { - cacheReadCost = dbCfg.CacheReadCostPerM - } - if dbCfg.CacheWriteCostPerM != nil { - cacheWriteCost = dbCfg.CacheWriteCostPerM - } - mapping = dbCfg.Mapping - } - - result = append(result, gin.H{ - "id": mID, - "name": mMeta.Name, - "provider": proxyProvider, - "enabled": enabled, - "prompt_cost": promptCost, - "completion_cost": completionCost, - "cache_read_cost": cacheReadCost, - "cache_write_cost": cacheWriteCost, - "context_limit": contextLimit, - "mapping": mapping, - "tool_call": mMeta.ToolCall != nil && *mMeta.ToolCall, - "reasoning": mMeta.Reasoning != nil && *mMeta.Reasoning, - "modalities": mMeta.Modalities, - }) - } - } - } - - // Add configured Ollama models if they aren't in registry - if s.cfg.Providers.Ollama.Enabled { - for _, mID := range s.cfg.Providers.Ollama.Models { - // Check if already added from registry - exists := false - for _, r := range result { - if r["id"] == mID { - exists = true - break - } - } - if exists { - continue - } - - if usedOnly && !usedPairs[fmt.Sprintf("%s:ollama", mID)] { - continue - } - - enabled := true - promptCost := 0.0 - completionCost := 0.0 - var cacheReadCost *float64 - var cacheWriteCost *float64 - var mapping *string - contextLimit := uint32(0) - - // Override from DB - if dbCfg, ok := dbMap[mID]; ok { - enabled = dbCfg.Enabled - if dbCfg.PromptCostPerM != nil { - promptCost = *dbCfg.PromptCostPerM - } - if dbCfg.CompletionCostPerM != nil { - completionCost = *dbCfg.CompletionCostPerM - } - if dbCfg.CacheReadCostPerM != nil { - cacheReadCost = dbCfg.CacheReadCostPerM - } - if dbCfg.CacheWriteCostPerM != nil { - cacheWriteCost = dbCfg.CacheWriteCostPerM - } - mapping = dbCfg.Mapping - } - - result = append(result, gin.H{ - "id": mID, - "name": mID, - "provider": "ollama", - "enabled": enabled, - "prompt_cost": promptCost, - "completion_cost": completionCost, - "cache_read_cost": cacheReadCost, - "cache_write_cost": cacheWriteCost, - "context_limit": contextLimit, - "modalities": gin.H{"input": []string{"text"}, "output": []string{"text"}}, - "tool_call": false, - "reasoning": false, - "mapping": mapping, - }) - } - } - - c.JSON(http.StatusOK, SuccessResponse(result)) -} - -func (s *Server) handleUpdateModel(c *gin.Context) { - id := c.Param("id") - var req struct { - Enabled bool `json:"enabled"` - PromptCost float64 `json:"prompt_cost"` - CompletionCost float64 `json:"completion_cost"` - CacheReadCost *float64 `json:"cache_read_cost"` - CacheWriteCost *float64 `json:"cache_write_cost"` - Mapping *string `json:"mapping"` - } - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request")) - return - } - - // Find provider for this model - providerID := "unknown" - s.registryMu.RLock() - if s.registry != nil { - for pID, pInfo := range s.registry.Providers { - if _, ok := pInfo.Models[id]; ok { - providerID = pID - break - } - } - } - - _, err := s.database.Exec(` - INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, cache_read_cost_per_m, cache_write_cost_per_m, mapping) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ON CONFLICT(id) DO UPDATE SET - enabled = excluded.enabled, - prompt_cost_per_m = excluded.prompt_cost_per_m, - completion_cost_per_m = excluded.completion_cost_per_m, - cache_read_cost_per_m = excluded.cache_read_cost_per_m, - cache_write_cost_per_m = excluded.cache_write_cost_per_m, - mapping = excluded.mapping, - updated_at = CURRENT_TIMESTAMP - `, id, providerID, req.Enabled, req.PromptCost, req.CompletionCost, req.CacheReadCost, req.CacheWriteCost, req.Mapping) - - if err != nil { - c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) - return - } - - c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Model updated"})) -} - -func (s *Server) handleGetUsers(c *gin.Context) { - var users []db.User - err := s.database.Select(&users, "SELECT id, username, display_name, role, must_change_password, created_at FROM users") - if err != nil { - c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) - return - } - c.JSON(http.StatusOK, SuccessResponse(users)) -} - -type CreateUserRequest struct { - Username string `json:"username" binding:"required"` - Password string `json:"password" binding:"required"` - DisplayName *string `json:"display_name"` - Role *string `json:"role"` -} - -func (s *Server) handleCreateUser(c *gin.Context) { - var req CreateUserRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request")) - return - } - - hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), 12) - if err != nil { - c.JSON(http.StatusInternalServerError, ErrorResponse("Failed to hash password")) - return - } - - role := "viewer" - if req.Role != nil { - role = *req.Role - } - - _, err = s.database.Exec("INSERT INTO users (username, password_hash, display_name, role, must_change_password) VALUES (?, ?, ?, ?, 1)", - req.Username, string(hash), req.DisplayName, role) - if err != nil { - c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) - return - } - - c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User created"})) -} - -type UpdateUserRequest struct { - DisplayName *string `json:"display_name"` - Role *string `json:"role"` - Password *string `json:"password"` - MustChangePassword *bool `json:"must_change_password"` -} - -func (s *Server) handleUpdateUser(c *gin.Context) { - id := c.Param("id") - var req UpdateUserRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request")) - return - } - - if req.DisplayName != nil { - s.database.Exec("UPDATE users SET display_name = ? WHERE id = ?", req.DisplayName, id) - } - if req.Role != nil { - s.database.Exec("UPDATE users SET role = ? WHERE id = ?", req.Role, id) - } - if req.MustChangePassword != nil { - s.database.Exec("UPDATE users SET must_change_password = ? WHERE id = ?", req.MustChangePassword, id) - } - if req.Password != nil { - hash, _ := bcrypt.GenerateFromPassword([]byte(*req.Password), 12) - s.database.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hash), id) - } - - c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User updated"})) -} - -func (s *Server) handleDeleteUser(c *gin.Context) { - id := c.Param("id") - - session, _ := c.Get("session") - if sess, ok := session.(*Session); ok { - var username string - s.database.Get(&username, "SELECT username FROM users WHERE id = ?", id) - if username == sess.Username { - c.JSON(http.StatusBadRequest, ErrorResponse("Cannot delete your own account")) - return - } - } - - _, err := s.database.Exec("DELETE FROM users WHERE id = ?", id) - if err != nil { - c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) - return - } - - c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User deleted"})) -} - -func (s *Server) handleSystemHealth(c *gin.Context) { - c.JSON(http.StatusOK, SuccessResponse(gin.H{ - "status": "ok", - "components": gin.H{ - "database": "online", - "proxy": "online", - }, - })) -} - -func (s *Server) handleSystemMetrics(c *gin.Context) { - v, _ := mem.VirtualMemory() - c_usage, _ := cpu.Percent(time.Second, false) - d, _ := disk.Usage("/") - l, _ := load.Avg() - p, _ := process.NewProcess(int32(os.Getpid())) - rss, _ := p.MemoryInfo() - - cpuPercent := 0.0 - if len(c_usage) > 0 { - cpuPercent = c_usage[0] - } - - c.JSON(http.StatusOK, SuccessResponse(gin.H{ - "cpu": gin.H{ - "usage_percent": fmt.Sprintf("%.1f", cpuPercent), - "load_average": []float64{l.Load1, l.Load5, l.Load15}, - }, - "memory": gin.H{ - "used_mb": v.Used / 1024 / 1024, - "total_mb": v.Total / 1024 / 1024, - "usage_percent": fmt.Sprintf("%.1f", v.UsedPercent), - "process_rss_mb": rss.RSS / 1024 / 1024, - }, - "disk": gin.H{ - "used_gb": float64(d.Used) / 1024 / 1024 / 1024, - "total_gb": float64(d.Total) / 1024 / 1024 / 1024, - "usage_percent": fmt.Sprintf("%.1f", d.UsedPercent), - }, - "connections": gin.H{ - "db_active": s.database.Stats().OpenConnections, - "websocket_listeners": s.hub.GetClientCount(), - }, - })) -} - -func (s *Server) handleGetSettings(c *gin.Context) { - providerCount := 0 - modelCount := 0 - s.registryMu.RLock() - if s.registry != nil { - providerCount = len(s.registry.Providers) - for _, p := range s.registry.Providers { - modelCount += len(p.Models) - } - } - - c.JSON(http.StatusOK, SuccessResponse(gin.H{ - "server": gin.H{ - "version": "1.0.0-go", - "auth_tokens": s.cfg.Server.AuthTokens, - }, - "database": gin.H{ - "type": "sqlite", - "path": s.cfg.Database.Path, - }, - "registry": gin.H{ - "provider_count": providerCount, - "model_count": modelCount, - }, - })) -} - -func (s *Server) handleCreateBackup(c *gin.Context) { - // Simplified backup response - c.JSON(http.StatusOK, SuccessResponse(gin.H{ - "backup_id": fmt.Sprintf("backup-%d.db", time.Now().Unix()), - "status": "created", - })) -} - -func (s *Server) handleGetLogs(c *gin.Context) { - var logs []db.LLMRequest - err := s.database.Select(&logs, "SELECT * FROM llm_requests ORDER BY timestamp DESC LIMIT 100") - if err != nil { - c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) - return - } - - // Format for UI - type UILog struct { - Timestamp string `json:"timestamp"` - ClientID string `json:"client_id"` - Provider string `json:"provider"` - Model string `json:"model"` - Tokens int `json:"tokens"` - Status string `json:"status"` - Duration int `json:"duration"` - } - - uiLogs := make([]UILog, len(logs)) - for i, l := range logs { - clientID := "unknown" - if l.ClientID != nil { - clientID = *l.ClientID - } - provider := "unknown" - if l.Provider != nil { - provider = *l.Provider - } - model := "unknown" - if l.Model != nil { - model = *l.Model - } - tokens := 0 - if l.TotalTokens != nil { - tokens = *l.TotalTokens - } - duration := 0 - if l.DurationMS != nil { - duration = *l.DurationMS - } - - uiLogs[i] = UILog{ - Timestamp: l.Timestamp.Format(time.RFC3339), - ClientID: clientID, - Provider: provider, - Model: model, - Tokens: tokens, - Status: l.Status, - Duration: duration, - } - } - - c.JSON(http.StatusOK, SuccessResponse(uiLogs)) -} diff --git a/internal/server/providers_admin.go b/internal/server/providers_admin.go new file mode 100644 index 00000000..bf86969a --- /dev/null +++ b/internal/server/providers_admin.go @@ -0,0 +1,247 @@ +package server + +import ( + "database/sql" + "fmt" + "net/http" + "os" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "golang.org/x/crypto/bcrypt" + "gophergate/internal/db" + "gophergate/internal/models" + "gophergate/internal/utils" + "log/slog" + + "github.com/shirou/gopsutil/v3/cpu" + +func (s *Server) handleGetProviders(c *gin.Context) { + var dbConfigs []db.ProviderConfig + err := s.database.Select(&dbConfigs, "SELECT id, enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs") + if err != nil { + // Log error + } + + dbMap := make(map[string]db.ProviderConfig) + for _, cfg := range dbConfigs { + dbMap[cfg.ID] = cfg + } + + providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok", "ollama"} + var result []gin.H + + for _, id := range providerIDs { + var name string + var enabled bool + var baseURL string + + switch id { + case "openai": + name = "OpenAI" + enabled = s.cfg.Providers.OpenAI.Enabled + baseURL = s.cfg.Providers.OpenAI.BaseURL + case "gemini": + name = "Google Gemini" + enabled = s.cfg.Providers.Gemini.Enabled + baseURL = s.cfg.Providers.Gemini.BaseURL + case "deepseek": + name = "DeepSeek" + enabled = s.cfg.Providers.DeepSeek.Enabled + baseURL = s.cfg.Providers.DeepSeek.BaseURL + case "moonshot": + name = "Moonshot" + enabled = s.cfg.Providers.Moonshot.Enabled + baseURL = s.cfg.Providers.Moonshot.BaseURL + case "grok": + name = "xAI Grok" + enabled = s.cfg.Providers.Grok.Enabled + baseURL = s.cfg.Providers.Grok.BaseURL + case "ollama": + name = "Ollama" + enabled = s.cfg.Providers.Ollama.Enabled + baseURL = s.cfg.Providers.Ollama.BaseURL + } + + var balance float64 + var threshold float64 = 5.0 + var billingMode string + + if dbCfg, ok := dbMap[id]; ok { + enabled = dbCfg.Enabled + if dbCfg.BaseURL != nil { + baseURL = *dbCfg.BaseURL + } + balance = dbCfg.CreditBalance + threshold = dbCfg.LowCreditThreshold + if dbCfg.BillingMode != nil { + billingMode = *dbCfg.BillingMode + } + } + + status := "disabled" + if enabled { + if _, ok := s.providers[id]; ok { + status = "online" + } else { + status = "error" + } + } + + // Get last used for this provider + var lastUsedTime sql.NullTime + _ = s.database.Get(&lastUsedTime, "SELECT MAX(timestamp) FROM llm_requests WHERE provider = ?", id) + var lastUsed interface{} + if lastUsedTime.Valid && !lastUsedTime.Time.IsZero() { + lastUsed = lastUsedTime.Time + } + + // Get models for this provider from registry + var models []string + s.registryMu.RLock() + if s.registry != nil { + registryID := id + if id == "gemini" { + registryID = "google" + } + if id == "moonshot" { + registryID = "moonshot" + } + if id == "grok" { + registryID = "xai" + } + + if pInfo, ok := s.registry.Providers[registryID]; ok { + for mID := range pInfo.Models { + models = append(models, mID) + } + } + } + s.registryMu.RUnlock() + + // If it's ollama, also include models from config + if id == "ollama" { + models = append(models, s.cfg.Providers.Ollama.Models...) + } + + result = append(result, gin.H{ + "id": id, + "name": name, + "enabled": enabled, + "status": status, + "base_url": baseURL, + "credit_balance": balance, + "low_credit_threshold": threshold, + "billing_mode": billingMode, + "last_used": lastUsed, + "models": models, + }) + } + + c.JSON(http.StatusOK, SuccessResponse(result)) +} + +type UpdateProviderRequest struct { + Enabled bool `json:"enabled"` + BaseURL *string `json:"base_url"` + APIKey *string `json:"api_key"` + CreditBalance *float64 `json:"credit_balance"` + LowCreditThreshold *float64 `json:"low_credit_threshold"` + BillingMode *string `json:"billing_mode"` +} + +func (s *Server) handleUpdateProvider(c *gin.Context) { + name := c.Param("name") + var req UpdateProviderRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request")) + return + } + + apiKeyEncrypted := false + var apiKey *string = req.APIKey + if req.APIKey != nil && *req.APIKey != "" { + encrypted, err := utils.Encrypt(*req.APIKey, s.cfg.KeyBytes) + if err == nil { + apiKey = &encrypted + apiKeyEncrypted = true + } + } + + _, err := s.database.Exec(` + INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode, api_key_encrypted) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + enabled = excluded.enabled, + base_url = COALESCE(excluded.base_url, provider_configs.base_url), + api_key = COALESCE(excluded.api_key, provider_configs.api_key), + api_key_encrypted = excluded.api_key_encrypted, + credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance), + low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold), + billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode), + updated_at = CURRENT_TIMESTAMP + `, name, strings.ToUpper(name), req.Enabled, req.BaseURL, apiKey, req.CreditBalance, req.LowCreditThreshold, req.BillingMode, apiKeyEncrypted) + + if err != nil { + c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) + return + } + + // Refresh in-memory providers + if err := s.RefreshProviders(); err != nil { + fmt.Printf("Error refreshing providers: %v\n", err) + } + + c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Provider updated"})) +} + +func (s *Server) handleTestProvider(c *gin.Context) { + name := c.Param("name") + provider, ok := s.providers[name] + if !ok { + c.JSON(http.StatusNotFound, ErrorResponse(fmt.Sprintf("Provider %s not found or not enabled", name))) + return + } + + startTime := time.Now() + + // Prepare a simple test request + testReq := &models.UnifiedRequest{ + Model: "gpt-4o-mini", // Default cheap test model + Messages: []models.UnifiedMessage{ + { + Role: "user", + Content: []models.UnifiedContentPart{{Type: "text", Text: "Hi"}}, + }, + }, + MaxTokens: new(uint32), + } + *testReq.MaxTokens = 5 + + // Adjust model for non-openai providers + if name == "gemini" { + testReq.Model = "gemini-2.0-flash" + } else if name == "deepseek" { + testReq.Model = "deepseek-chat" + } else if name == "moonshot" { + testReq.Model = "kimi-k2.5" + } else if name == "grok" { + testReq.Model = "grok-4-1-fast-non-reasoning" + } + + _, err := provider.ChatCompletion(c.Request.Context(), testReq) + latency := time.Since(startTime).Milliseconds() + + if err != nil { + c.JSON(http.StatusOK, ErrorResponse(fmt.Sprintf("Provider test failed: %v", err))) + return + } + + c.JSON(http.StatusOK, SuccessResponse(gin.H{ + "message": "Connection test successful", + "latency": latency, + })) +} + diff --git a/internal/server/system.go b/internal/server/system.go new file mode 100644 index 00000000..67f00f8f --- /dev/null +++ b/internal/server/system.go @@ -0,0 +1,156 @@ +package server + +import ( + "database/sql" + "fmt" + "net/http" + "os" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "golang.org/x/crypto/bcrypt" + "gophergate/internal/db" + "gophergate/internal/models" + "gophergate/internal/utils" + "log/slog" + + "github.com/shirou/gopsutil/v3/cpu" + +func (s *Server) handleSystemHealth(c *gin.Context) { + c.JSON(http.StatusOK, SuccessResponse(gin.H{ + "status": "ok", + "components": gin.H{ + "database": "online", + "proxy": "online", + }, + })) +} + +func (s *Server) handleSystemMetrics(c *gin.Context) { + v, _ := mem.VirtualMemory() + c_usage, _ := cpu.Percent(time.Second, false) + d, _ := disk.Usage("/") + l, _ := load.Avg() + p, _ := process.NewProcess(int32(os.Getpid())) + rss, _ := p.MemoryInfo() + + cpuPercent := 0.0 + if len(c_usage) > 0 { + cpuPercent = c_usage[0] + } + + c.JSON(http.StatusOK, SuccessResponse(gin.H{ + "cpu": gin.H{ + "usage_percent": fmt.Sprintf("%.1f", cpuPercent), + "load_average": []float64{l.Load1, l.Load5, l.Load15}, + }, + "memory": gin.H{ + "used_mb": v.Used / 1024 / 1024, + "total_mb": v.Total / 1024 / 1024, + "usage_percent": fmt.Sprintf("%.1f", v.UsedPercent), + "process_rss_mb": rss.RSS / 1024 / 1024, + }, + "disk": gin.H{ + "used_gb": float64(d.Used) / 1024 / 1024 / 1024, + "total_gb": float64(d.Total) / 1024 / 1024 / 1024, + "usage_percent": fmt.Sprintf("%.1f", d.UsedPercent), + }, + "connections": gin.H{ + "db_active": s.database.Stats().OpenConnections, + "websocket_listeners": s.hub.GetClientCount(), + }, + })) +} + +func (s *Server) handleGetSettings(c *gin.Context) { + providerCount := 0 + modelCount := 0 + s.registryMu.RLock() + if s.registry != nil { + providerCount = len(s.registry.Providers) + for _, p := range s.registry.Providers { + modelCount += len(p.Models) + } + } + + c.JSON(http.StatusOK, SuccessResponse(gin.H{ + "server": gin.H{ + "version": "1.0.0-go", + "auth_tokens": s.cfg.Server.AuthTokens, + }, + "database": gin.H{ + "type": "sqlite", + "path": s.cfg.Database.Path, + }, + "registry": gin.H{ + "provider_count": providerCount, + "model_count": modelCount, + }, + })) +} + +func (s *Server) handleCreateBackup(c *gin.Context) { + // Simplified backup response + c.JSON(http.StatusOK, SuccessResponse(gin.H{ + "backup_id": fmt.Sprintf("backup-%d.db", time.Now().Unix()), + "status": "created", + })) +} + +func (s *Server) handleGetLogs(c *gin.Context) { + var logs []db.LLMRequest + err := s.database.Select(&logs, "SELECT * FROM llm_requests ORDER BY timestamp DESC LIMIT 100") + if err != nil { + c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) + return + } + + // Format for UI + type UILog struct { + Timestamp string `json:"timestamp"` + ClientID string `json:"client_id"` + Provider string `json:"provider"` + Model string `json:"model"` + Tokens int `json:"tokens"` + Status string `json:"status"` + Duration int `json:"duration"` + } + + uiLogs := make([]UILog, len(logs)) + for i, l := range logs { + clientID := "unknown" + if l.ClientID != nil { + clientID = *l.ClientID + } + provider := "unknown" + if l.Provider != nil { + provider = *l.Provider + } + model := "unknown" + if l.Model != nil { + model = *l.Model + } + tokens := 0 + if l.TotalTokens != nil { + tokens = *l.TotalTokens + } + duration := 0 + if l.DurationMS != nil { + duration = *l.DurationMS + } + + uiLogs[i] = UILog{ + Timestamp: l.Timestamp.Format(time.RFC3339), + ClientID: clientID, + Provider: provider, + Model: model, + Tokens: tokens, + Status: l.Status, + Duration: duration, + } + } + + c.JSON(http.StatusOK, SuccessResponse(uiLogs)) +} diff --git a/internal/server/users.go b/internal/server/users.go new file mode 100644 index 00000000..ecd7568d --- /dev/null +++ b/internal/server/users.go @@ -0,0 +1,119 @@ +package server + +import ( + "database/sql" + "fmt" + "net/http" + "os" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "golang.org/x/crypto/bcrypt" + "gophergate/internal/db" + "gophergate/internal/models" + "gophergate/internal/utils" + "log/slog" + + "github.com/shirou/gopsutil/v3/cpu" + +func (s *Server) handleGetUsers(c *gin.Context) { + var users []db.User + err := s.database.Select(&users, "SELECT id, username, display_name, role, must_change_password, created_at FROM users") + if err != nil { + c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) + return + } + c.JSON(http.StatusOK, SuccessResponse(users)) +} + +type CreateUserRequest struct { + Username string `json:"username" binding:"required"` + Password string `json:"password" binding:"required"` + DisplayName *string `json:"display_name"` + Role *string `json:"role"` +} + +func (s *Server) handleCreateUser(c *gin.Context) { + var req CreateUserRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request")) + return + } + + hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), 12) + if err != nil { + c.JSON(http.StatusInternalServerError, ErrorResponse("Failed to hash password")) + return + } + + role := "viewer" + if req.Role != nil { + role = *req.Role + } + + _, err = s.database.Exec("INSERT INTO users (username, password_hash, display_name, role, must_change_password) VALUES (?, ?, ?, ?, 1)", + req.Username, string(hash), req.DisplayName, role) + if err != nil { + c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) + return + } + + c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User created"})) +} + +type UpdateUserRequest struct { + DisplayName *string `json:"display_name"` + Role *string `json:"role"` + Password *string `json:"password"` + MustChangePassword *bool `json:"must_change_password"` +} + +func (s *Server) handleUpdateUser(c *gin.Context) { + id := c.Param("id") + var req UpdateUserRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request")) + return + } + + if req.DisplayName != nil { + s.database.Exec("UPDATE users SET display_name = ? WHERE id = ?", req.DisplayName, id) + } + if req.Role != nil { + s.database.Exec("UPDATE users SET role = ? WHERE id = ?", req.Role, id) + } + if req.MustChangePassword != nil { + s.database.Exec("UPDATE users SET must_change_password = ? WHERE id = ?", req.MustChangePassword, id) + } + if req.Password != nil { + hash, _ := bcrypt.GenerateFromPassword([]byte(*req.Password), 12) + s.database.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hash), id) + } + + c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User updated"})) +} + +func (s *Server) handleDeleteUser(c *gin.Context) { + id := c.Param("id") + + session, _ := c.Get("session") + if sess, ok := session.(*Session); ok { + var username string + s.database.Get(&username, "SELECT username FROM users WHERE id = ?", id) + if username == sess.Username { + c.JSON(http.StatusBadRequest, ErrorResponse("Cannot delete your own account")) + return + } + } + + _, err := s.database.Exec("DELETE FROM users WHERE id = ?", id) + if err != nil { + c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) + return + } + + c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User deleted"})) +} + diff --git a/internal/utils/registry_test.go b/internal/utils/registry_test.go new file mode 100644 index 00000000..8de92616 --- /dev/null +++ b/internal/utils/registry_test.go @@ -0,0 +1,40 @@ +package utils + +import ( + "testing" + + "gophergate/internal/models" +) + +func TestCalculateCost_NotFound(t *testing.T) { + r := &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)} + cost := CalculateCost(r, "unknown-model", 100, 50, 0, 0, 0) + if cost != 0.0 { + t.Fatalf("expected 0 cost for unknown model, got %f", cost) + } +} + +func TestCalculateCost_KnownModel(t *testing.T) { + inputCost := 2.5 // $2.50 per 1M tokens + outputCost := 10.0 // $10.00 per 1M tokens + r := &models.ModelRegistry{ + Providers: map[string]models.ProviderInfo{ + "openai": { + Models: map[string]models.ModelMetadata{ + "gpt-4o": { + Cost: &models.ModelCost{ + Input: inputCost, + Output: outputCost, + }, + }, + }, + }, + }, + } + + cost := CalculateCost(r, "gpt-4o", 1000, 500, 0, 0, 0) + expected := (1000 * inputCost / 1000000.0) + (500 * outputCost / 1000000.0) + if cost != expected { + t.Fatalf("expected %f, got %f", expected, cost) + } +}