Updated handleGetModels to merge registry data with DB overrides and implemented handleUpdateModel. Verified API response format matches frontend requirements.
1024 lines
28 KiB
Go
1024 lines
28 KiB
Go
package server
|
|
|
|
import (
	"fmt"
	"net/http"
	"sort"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
	"golang.org/x/crypto/bcrypt"

	"llm-proxy/internal/db"
	"llm-proxy/internal/models"
)
|
|
|
|
// ApiResponse is the JSON envelope returned by the admin API handlers.
// SuccessResponse fills Data and ErrorResponse fills Error; both are omitted
// from the encoded JSON when empty.
type ApiResponse struct {
	Success bool        `json:"success"`
	Data    interface{} `json:"data,omitempty"`
	Error   string      `json:"error,omitempty"`
}

// SuccessResponse wraps data in an envelope whose Success flag is true.
func SuccessResponse(data interface{}) ApiResponse {
	var resp ApiResponse
	resp.Success = true
	resp.Data = data
	return resp
}

// ErrorResponse builds an envelope whose Success flag is false and whose Error
// field carries the given message.
func ErrorResponse(err string) ApiResponse {
	var resp ApiResponse
	resp.Error = err
	return resp
}
|
|
|
|
// adminAuthMiddleware returns a gin middleware that only lets requests through
// when they carry a valid session whose role is "admin".
//
// The token comes from the Authorization header with an optional "Bearer "
// prefix removed. A missing token or a session that fails validation aborts
// with 401; a valid non-admin session aborts with 403. On success the session
// is stored in the gin context under "session" for downstream handlers.
func (s *Server) adminAuthMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		header := c.GetHeader("Authorization")
		token := strings.TrimPrefix(header, "Bearer ")
		if token == "" {
			c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse("Not authenticated"))
			return
		}

		session, _, err := s.sessions.ValidateSession(token)
		switch {
		case err != nil:
			c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse("Session expired or invalid"))
		case session.Role != "admin":
			c.AbortWithStatusJSON(http.StatusForbidden, ErrorResponse("Admin access required"))
		default:
			c.Set("session", session)
			c.Next()
		}
	}
}
|
|
|
|
type LoginRequest struct {
|
|
Username string `json:"username" binding:"required"`
|
|
Password string `json:"password" binding:"required"`
|
|
}
|
|
|
|
func (s *Server) handleLogin(c *gin.Context) {
|
|
var req LoginRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
|
return
|
|
}
|
|
|
|
var user db.User
|
|
err := s.database.Get(&user, "SELECT username, password_hash, display_name, role, must_change_password FROM users WHERE username = ?", req.Username)
|
|
if err != nil {
|
|
c.JSON(http.StatusUnauthorized, ErrorResponse("Invalid username or password"))
|
|
return
|
|
}
|
|
|
|
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
|
|
c.JSON(http.StatusUnauthorized, ErrorResponse("Invalid username or password"))
|
|
return
|
|
}
|
|
|
|
token, err := s.sessions.CreateSession(user.Username, user.Role)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, ErrorResponse("Failed to create session"))
|
|
return
|
|
}
|
|
|
|
displayName := user.Username
|
|
if user.DisplayName != nil {
|
|
displayName = *user.DisplayName
|
|
}
|
|
|
|
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
|
"token": token,
|
|
"must_change_password": user.MustChangePassword,
|
|
"user": gin.H{
|
|
"username": user.Username,
|
|
"name": displayName,
|
|
"role": user.Role,
|
|
},
|
|
}))
|
|
}
|
|
|
|
// handleAuthStatus reports whether the Authorization header (optional
// "Bearer " prefix) carries a valid session. When it does, the response
// includes the session's username and role; otherwise the handler returns 401.
func (s *Server) handleAuthStatus(c *gin.Context) {
	bearer := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
	session, _, err := s.sessions.ValidateSession(bearer)
	if err != nil {
		c.JSON(http.StatusUnauthorized, ErrorResponse("Not authenticated"))
		return
	}

	user := gin.H{"username": session.Username, "role": session.Role}
	c.JSON(http.StatusOK, SuccessResponse(gin.H{
		"authenticated": true,
		"user":          user,
	}))
}
|
|
|
|
type ChangePasswordRequest struct {
|
|
CurrentPassword string `json:"current_password" binding:"required"`
|
|
NewPassword string `json:"new_password" binding:"required"`
|
|
}
|
|
|
|
func (s *Server) handleChangePassword(c *gin.Context) {
|
|
token := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
|
|
session, _, err := s.sessions.ValidateSession(token)
|
|
if err != nil {
|
|
c.JSON(http.StatusUnauthorized, ErrorResponse("Not authenticated"))
|
|
return
|
|
}
|
|
|
|
var req ChangePasswordRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
|
return
|
|
}
|
|
|
|
var user db.User
|
|
err = s.database.Get(&user, "SELECT password_hash FROM users WHERE username = ?", session.Username)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, ErrorResponse("User not found"))
|
|
return
|
|
}
|
|
|
|
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.CurrentPassword)); err != nil {
|
|
c.JSON(http.StatusUnauthorized, ErrorResponse("Current password incorrect"))
|
|
return
|
|
}
|
|
|
|
newHash, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), 12)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, ErrorResponse("Failed to hash new password"))
|
|
return
|
|
}
|
|
|
|
_, err = s.database.Exec("UPDATE users SET password_hash = ?, must_change_password = 0 WHERE username = ?", string(newHash), session.Username)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, ErrorResponse("Failed to update password"))
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Password updated successfully"}))
|
|
}
|
|
|
|
// handleLogout revokes the session named by the Authorization header (optional
// "Bearer " prefix) and always answers with success. The outcome of
// RevokeSession is not inspected.
func (s *Server) handleLogout(c *gin.Context) {
	s.sessions.RevokeSession(strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer "))
	c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Logged out"}))
}
|
|
|
|
type UsagePeriodFilter struct {
|
|
Period string `form:"period"`
|
|
From string `form:"from"`
|
|
To string `form:"to"`
|
|
}
|
|
|
|
func (f *UsagePeriodFilter) ToSQL() (string, []interface{}) {
|
|
period := f.Period
|
|
if period == "" {
|
|
period = "all"
|
|
}
|
|
|
|
if period == "custom" {
|
|
var clauses []string
|
|
var binds []interface{}
|
|
if f.From != "" {
|
|
clauses = append(clauses, "timestamp >= ?")
|
|
binds = append(binds, f.From)
|
|
}
|
|
if f.To != "" {
|
|
clauses = append(clauses, "timestamp <= ?")
|
|
binds = append(binds, f.To)
|
|
}
|
|
if len(clauses) > 0 {
|
|
return " AND " + strings.Join(clauses, " AND "), binds
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
var cutoff time.Time
|
|
switch period {
|
|
case "today":
|
|
cutoff = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)
|
|
case "24h":
|
|
cutoff = now.Add(-24 * time.Hour)
|
|
case "7d":
|
|
cutoff = now.Add(-7 * 24 * time.Hour)
|
|
case "30d":
|
|
cutoff = now.Add(-30 * 24 * time.Hour)
|
|
default:
|
|
return "", nil
|
|
}
|
|
|
|
return " AND timestamp >= ?", []interface{}{cutoff.Format(time.RFC3339)}
|
|
}
|
|
|
|
// handleUsageSummary returns the aggregate figures shown on the dashboard.
//
// total_requests, total_tokens, total_cost, active_clients, error_rate and
// avg_response_time all cover the window chosen by the optional period filter
// (see UsagePeriodFilter.ToSQL), so every summary figure describes the same
// slice of traffic. today_requests and today_cost always cover the current UTC
// calendar day, whatever the filter.
//
// Only a failure of the totals query is reported (HTTP 500). The "today" and
// error-rate/latency figures are best-effort and stay zero if their query
// fails.
func (s *Server) handleUsageSummary(c *gin.Context) {
	var filter UsagePeriodFilter
	// An unparsable query string leaves the zero filter, i.e. all time.
	_ = c.ShouldBindQuery(&filter)
	clause, binds := filter.ToSQL()

	// Totals for the selected window.
	var totalStats struct {
		TotalRequests int     `db:"total_requests"`
		TotalTokens   int     `db:"total_tokens"`
		TotalCost     float64 `db:"total_cost"`
		ActiveClients int     `db:"active_clients"`
	}
	err := s.database.Get(&totalStats, fmt.Sprintf(`
		SELECT
			COUNT(*) as total_requests,
			COALESCE(SUM(total_tokens), 0) as total_tokens,
			COALESCE(SUM(cost), 0.0) as total_cost,
			COUNT(DISTINCT client_id) as active_clients
		FROM llm_requests
		WHERE 1=1 %s
	`, clause), binds...)
	if err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
		return
	}

	// Current UTC day, independent of the filter (best-effort).
	var todayStats struct {
		TodayRequests int     `db:"today_requests"`
		TodayCost     float64 `db:"today_cost"`
	}
	today := time.Now().UTC().Format("2006-01-02")
	_ = s.database.Get(&todayStats, `
		SELECT
			COUNT(*) as today_requests,
			COALESCE(SUM(cost), 0.0) as today_cost
		FROM llm_requests
		WHERE strftime('%Y-%m-%d', timestamp) = ?
	`, today)

	// Error rate (percent) and mean latency over the same window as the
	// totals (best-effort). With no matching rows the ratio is NULL; COALESCE
	// maps it to 0 because scanning NULL into a float64 fails, which would
	// otherwise also discard avg_response_time. NULLIF makes the empty-window
	// division explicit rather than relying on SQLite's divide-by-zero NULL.
	var miscStats struct {
		ErrorRate       float64 `db:"error_rate"`
		AvgResponseTime float64 `db:"avg_response_time"`
	}
	_ = s.database.Get(&miscStats, fmt.Sprintf(`
		SELECT
			COALESCE((CAST(SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) AS FLOAT) / NULLIF(COUNT(*), 0)) * 100.0, 0.0) as error_rate,
			COALESCE(AVG(duration_ms), 0.0) as avg_response_time
		FROM llm_requests
		WHERE 1=1 %s
	`, clause), binds...)

	c.JSON(http.StatusOK, SuccessResponse(gin.H{
		"total_requests":    totalStats.TotalRequests,
		"total_tokens":      totalStats.TotalTokens,
		"total_cost":        totalStats.TotalCost,
		"active_clients":    totalStats.ActiveClients,
		"today_requests":    todayStats.TodayRequests,
		"today_cost":        todayStats.TodayCost,
		"error_rate":        miscStats.ErrorRate,
		"avg_response_time": miscStats.AvgResponseTime,
	}))
}
|
|
|
|
// handleTimeSeries returns per-day request, token and cost totals for charts.
// Days are the buckets produced by SQLite's strftime('%Y-%m-%d', timestamp),
// in ascending order. If the period filter yields no constraint (for example
// "all", or "custom" without bounds), the series is limited to the last 30
// days.
func (s *Server) handleTimeSeries(c *gin.Context) {
	var filter UsagePeriodFilter
	_ = c.ShouldBindQuery(&filter) // malformed params fall back to the zero filter

	clause, binds := filter.ToSQL()
	if clause == "" {
		thirtyDaysAgo := time.Now().UTC().Add(-30 * 24 * time.Hour)
		clause = " AND timestamp >= ?"
		binds = []interface{}{thirtyDaysAgo.Format(time.RFC3339)}
	}

	query := `
		SELECT
			strftime('%Y-%m-%d', timestamp) as bucket,
			COUNT(*) as requests,
			COALESCE(SUM(total_tokens), 0) as tokens,
			COALESCE(SUM(cost), 0.0) as cost
		FROM llm_requests
		WHERE 1=1 ` + clause + `
		GROUP BY bucket
		ORDER BY bucket
	`

	type dayBucket struct {
		Bucket   string  `db:"bucket"`
		Requests int     `db:"requests"`
		Tokens   int     `db:"tokens"`
		Cost     float64 `db:"cost"`
	}
	var days []dayBucket
	if err := s.database.Select(&days, query, binds...); err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
		return
	}

	series := make([]gin.H, 0, len(days))
	for _, d := range days {
		series = append(series, gin.H{
			"time":     d.Bucket,
			"requests": d.Requests,
			"tokens":   d.Tokens,
			"cost":     d.Cost,
		})
	}

	c.JSON(http.StatusOK, SuccessResponse(gin.H{"series": series}))
}
|
|
|
|
// handleProvidersUsage returns the number of requests per provider within the
// optional period filter.
//
// Rows with a NULL provider are reported under "unknown", the same fallback
// handleGetLogs uses. Without this, scanning a NULL into the string field
// fails and the whole endpoint returns 500. An empty result is encoded as []
// rather than null.
//
// NOTE(review): the row type has no json tags, so entries serialise with Go
// field names ("Provider", "Requests"). This is kept as-is to avoid breaking
// the current frontend contract; confirm before changing.
func (s *Server) handleProvidersUsage(c *gin.Context) {
	var filter UsagePeriodFilter
	// Malformed params fall back to the zero filter (all time).
	_ = c.ShouldBindQuery(&filter)
	clause, binds := filter.ToSQL()

	type providerCount struct {
		Provider string `db:"provider"`
		Requests int    `db:"requests"`
	}
	var rows []providerCount

	// Group by the COALESCE expression itself so NULL and a literal
	// 'unknown' provider collapse into a single row.
	err := s.database.Select(&rows, fmt.Sprintf(`
		SELECT COALESCE(provider, 'unknown') as provider, COUNT(*) as requests
		FROM llm_requests
		WHERE 1=1 %s
		GROUP BY COALESCE(provider, 'unknown')
	`, clause), binds...)
	if err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
		return
	}

	if rows == nil {
		rows = []providerCount{}
	}
	c.JSON(http.StatusOK, SuccessResponse(rows))
}
|
|
|
|
// handleAnalyticsBreakdown returns request counts grouped by model and by
// client within the optional period filter, each sorted by count descending.
//
// NULL model/client_id values are reported as "unknown", matching
// handleGetLogs. Scanning a NULL into the string Label would otherwise fail
// the whole request. Empty groupings are encoded as [] rather than null.
//
// NOTE(review): entries have no json tags and therefore serialise as
// {"Label": …, "Value": …}. This is kept unchanged to preserve the current
// frontend contract.
func (s *Server) handleAnalyticsBreakdown(c *gin.Context) {
	var filter UsagePeriodFilter
	// Malformed params fall back to the zero filter (all time).
	_ = c.ShouldBindQuery(&filter)
	clause, binds := filter.ToSQL()

	type breakdownEntry struct {
		Label string `db:"label"`
		Value int    `db:"value"`
	}

	// breakdown counts requests per value of column. column is always a
	// literal chosen below, never user input, so interpolating it is safe.
	breakdown := func(column string) ([]breakdownEntry, error) {
		expr := fmt.Sprintf("COALESCE(%s, 'unknown')", column)
		query := fmt.Sprintf(
			"SELECT %s as label, COUNT(*) as value FROM llm_requests WHERE 1=1 %s GROUP BY %s ORDER BY value DESC",
			expr, clause, expr,
		)
		var entries []breakdownEntry
		if err := s.database.Select(&entries, query, binds...); err != nil {
			return nil, err
		}
		if entries == nil {
			entries = []breakdownEntry{}
		}
		return entries, nil
	}

	// Named modelRows (not "models") to avoid shadowing the imported models
	// package.
	modelRows, err := breakdown("model")
	if err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
		return
	}

	clientRows, err := breakdown("client_id")
	if err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
		return
	}

	c.JSON(http.StatusOK, SuccessResponse(gin.H{
		"models":  modelRows,
		"clients": clientRows,
	}))
}
|
|
|
|
// handleGetClients lists all clients, newest first. An empty table is
// returned as an empty JSON array rather than null, so callers can iterate
// the data field unconditionally.
func (s *Server) handleGetClients(c *gin.Context) {
	var clients []db.Client
	if err := s.database.Select(&clients, "SELECT * FROM clients ORDER BY created_at DESC"); err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
		return
	}
	if clients == nil {
		clients = []db.Client{}
	}
	c.JSON(http.StatusOK, SuccessResponse(clients))
}
|
|
|
|
type CreateClientRequest struct {
|
|
Name string `json:"name" binding:"required"`
|
|
ClientID *string `json:"client_id"`
|
|
}
|
|
|
|
func (s *Server) handleCreateClient(c *gin.Context) {
|
|
var req CreateClientRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
|
return
|
|
}
|
|
|
|
clientID := ""
|
|
if req.ClientID != nil {
|
|
clientID = *req.ClientID
|
|
} else {
|
|
clientID = "client-" + uuid.New().String()[:8]
|
|
}
|
|
|
|
_, err := s.database.Exec("INSERT INTO clients (client_id, name, is_active) VALUES (?, ?, 1)", clientID, req.Name)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
|
return
|
|
}
|
|
|
|
token := "sk-" + uuid.New().String() + uuid.New().String()
|
|
token = token[:51]
|
|
|
|
_, err = s.database.Exec("INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, 'default')", clientID, token)
|
|
if err != nil {
|
|
// Log error
|
|
}
|
|
|
|
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
|
"id": clientID,
|
|
"name": req.Name,
|
|
"status": "active",
|
|
"token": token,
|
|
"created_at": time.Now(),
|
|
}))
|
|
}
|
|
|
|
// handleDeleteClient removes the client identified by the :id route parameter
// together with all of its API tokens. The built-in "default" client cannot be
// deleted.
//
// Tokens are deleted first and explicitly. The schema is not visible here, so
// this does not rely on an ON DELETE CASCADE being declared and enforced.
// Orphaned token rows could otherwise stay valid, depending on how token
// lookup treats a missing client. Deleting a non-existent ID still reports
// success.
func (s *Server) handleDeleteClient(c *gin.Context) {
	id := c.Param("id")
	if id == "default" {
		c.JSON(http.StatusBadRequest, ErrorResponse("Cannot delete default client"))
		return
	}

	if _, err := s.database.Exec("DELETE FROM client_tokens WHERE client_id = ?", id); err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
		return
	}

	if _, err := s.database.Exec("DELETE FROM clients WHERE client_id = ?", id); err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
		return
	}

	c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Client deleted"}))
}
|
|
|
|
// handleGetClientTokens lists the API tokens of the client in the :id route
// parameter, newest first, without exposing the secrets.
//
// Each token is shown as its first 3 and last 8 characters around a "••••"
// placeholder. Tokens too short to hide at least one character that way are
// replaced entirely by "••••".
func (s *Server) handleGetClientTokens(c *gin.Context) {
	id := c.Param("id")
	var tokens []db.ClientToken
	err := s.database.Select(&tokens, "SELECT * FROM client_tokens WHERE client_id = ? ORDER BY created_at DESC", id)
	if err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
		return
	}

	type MaskedToken struct {
		ID          int        `json:"id"`
		TokenMasked string     `json:"token_masked"`
		Name        string     `json:"name"`
		IsActive    bool       `json:"is_active"`
		CreatedAt   time.Time  `json:"created_at"`
		LastUsedAt  *time.Time `json:"last_used_at"`
	}

	// Visible prefix/suffix lengths. A token is only partially revealed when
	// it is strictly longer than both combined; otherwise prefix and suffix
	// would overlap or abut and expose every character.
	const visiblePrefix, visibleSuffix = 3, 8

	masked := make([]MaskedToken, len(tokens))
	for i, t := range tokens {
		maskedToken := "••••"
		if len(t.Token) > visiblePrefix+visibleSuffix {
			maskedToken = t.Token[:visiblePrefix] + "••••" + t.Token[len(t.Token)-visibleSuffix:]
		}
		masked[i] = MaskedToken{
			ID:          t.ID,
			TokenMasked: maskedToken,
			Name:        t.Name,
			IsActive:    t.IsActive,
			CreatedAt:   t.CreatedAt,
			LastUsedAt:  t.LastUsedAt,
		}
	}

	c.JSON(http.StatusOK, SuccessResponse(masked))
}
|
|
|
|
type CreateTokenRequest struct {
|
|
Name string `json:"name"`
|
|
}
|
|
|
|
func (s *Server) handleCreateClientToken(c *gin.Context) {
|
|
clientID := c.Param("id")
|
|
var req CreateTokenRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
// optional name
|
|
}
|
|
|
|
name := "default"
|
|
if req.Name != "" {
|
|
name = req.Name
|
|
}
|
|
|
|
token := "sk-" + uuid.New().String() + uuid.New().String()
|
|
token = token[:51]
|
|
|
|
_, err := s.database.Exec("INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, ?)", clientID, token, name)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
|
"token": token,
|
|
"name": name,
|
|
"created_at": time.Now(),
|
|
}))
|
|
}
|
|
|
|
// handleDeleteClientToken deletes the client_tokens row whose id equals the
// :token_id route parameter. It reports success even when no row matched; only
// a database error produces a 500.
func (s *Server) handleDeleteClientToken(c *gin.Context) {
	if _, err := s.database.Exec("DELETE FROM client_tokens WHERE id = ?", c.Param("token_id")); err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
		return
	}
	c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Token revoked"}))
}
|
|
|
|
// handleGetProviders lists the five built-in providers in a fixed order.
// Their static configuration is merged with any persisted provider_configs
// overrides (enabled, base_url, credit balance, low-credit threshold, billing
// mode).
//
// status is "disabled" when the effective enabled flag is false. Otherwise it
// is "online" if the provider is present in s.providers and "error" if not.
// The low-credit threshold defaults to 5.0 when no override row exists.
func (s *Server) handleGetProviders(c *gin.Context) {
	// Overrides are best-effort: on a query error the static configuration
	// is reported (with whatever rows Select managed to load).
	var overrides []db.ProviderConfig
	_ = s.database.Select(&overrides, "SELECT id, enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs")

	byID := make(map[string]db.ProviderConfig, len(overrides))
	for _, o := range overrides {
		byID[o.ID] = o
	}

	cfg := s.cfg.Providers
	catalogue := []struct {
		id, name string
		enabled  bool
		baseURL  string
	}{
		{"openai", "OpenAI", cfg.OpenAI.Enabled, cfg.OpenAI.BaseURL},
		{"gemini", "Google Gemini", cfg.Gemini.Enabled, cfg.Gemini.BaseURL},
		{"deepseek", "DeepSeek", cfg.DeepSeek.Enabled, cfg.DeepSeek.BaseURL},
		{"grok", "xAI Grok", cfg.Grok.Enabled, cfg.Grok.BaseURL},
		{"ollama", "Ollama", cfg.Ollama.Enabled, cfg.Ollama.BaseURL},
	}

	result := make([]gin.H, 0, len(catalogue))
	for _, entry := range catalogue {
		enabled, baseURL := entry.enabled, entry.baseURL
		var balance float64
		threshold := 5.0
		billingMode := ""

		if o, found := byID[entry.id]; found {
			enabled = o.Enabled
			if o.BaseURL != nil {
				baseURL = *o.BaseURL
			}
			balance = o.CreditBalance
			threshold = o.LowCreditThreshold
			if o.BillingMode != nil {
				billingMode = *o.BillingMode
			}
		}

		status := "disabled"
		if enabled {
			status = "error"
			if _, live := s.providers[entry.id]; live {
				status = "online"
			}
		}

		result = append(result, gin.H{
			"id":                   entry.id,
			"name":                 entry.name,
			"enabled":              enabled,
			"status":               status,
			"base_url":             baseURL,
			"credit_balance":       balance,
			"low_credit_threshold": threshold,
			"billing_mode":         billingMode,
		})
	}

	c.JSON(http.StatusOK, SuccessResponse(result))
}
|
|
|
|
type UpdateProviderRequest struct {
|
|
Enabled bool `json:"enabled"`
|
|
BaseURL *string `json:"base_url"`
|
|
APIKey *string `json:"api_key"`
|
|
CreditBalance *float64 `json:"credit_balance"`
|
|
LowCreditThreshold *float64 `json:"low_credit_threshold"`
|
|
BillingMode *string `json:"billing_mode"`
|
|
}
|
|
|
|
func (s *Server) handleUpdateProvider(c *gin.Context) {
|
|
name := c.Param("name")
|
|
var req UpdateProviderRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
|
return
|
|
}
|
|
|
|
_, err := s.database.Exec(`
|
|
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(id) DO UPDATE SET
|
|
enabled = excluded.enabled,
|
|
base_url = COALESCE(excluded.base_url, provider_configs.base_url),
|
|
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
|
|
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
|
|
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
|
|
billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
|
|
updated_at = CURRENT_TIMESTAMP
|
|
`, name, strings.ToUpper(name), req.Enabled, req.BaseURL, req.APIKey, req.CreditBalance, req.LowCreditThreshold, req.BillingMode)
|
|
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Provider updated"}))
|
|
}
|
|
|
|
// handleTestProvider sends a minimal chat completion (one "Hi" user message,
// at most 5 output tokens) through the named, currently registered provider.
// It reports the round-trip latency of that call in milliseconds.
//
// An unknown or disabled provider yields 404. A failed completion is reported
// with HTTP 200 and success=false so the UI can show the message inline.
func (s *Server) handleTestProvider(c *gin.Context) {
	name := c.Param("name")
	provider, ok := s.providers[name]
	if !ok {
		c.JSON(http.StatusNotFound, ErrorResponse(fmt.Sprintf("Provider %s not found or not enabled", name)))
		return
	}

	// Pick a model the provider is expected to serve. "gpt-4o" is the
	// default for OpenAI and for any provider not listed below.
	testModel := "gpt-4o"
	switch name {
	case "gemini":
		testModel = "gemini-2.0-flash"
	case "deepseek":
		testModel = "deepseek-chat"
	case "grok":
		testModel = "grok-2"
	case "ollama":
		// Ollama only serves locally installed models, so "gpt-4o" would
		// normally be rejected. Use the lexicographically first model the
		// registry lists for this provider, if any, and keep the default
		// otherwise. Assumes registry model IDs are names Ollama accepts;
		// TODO confirm.
		if s.registry != nil {
			if pInfo, found := s.registry.Providers[name]; found {
				first := ""
				for mID := range pInfo.Models {
					if first == "" || mID < first {
						first = mID
					}
				}
				if first != "" {
					testModel = first
				}
			}
		}
	}

	maxTokens := uint32(5)
	testReq := &models.UnifiedRequest{
		Model: testModel,
		Messages: []models.UnifiedMessage{
			{
				Role:    "user",
				Content: []models.UnifiedContentPart{{Type: "text", Text: "Hi"}},
			},
		},
		MaxTokens: &maxTokens,
	}

	// Time only the provider call itself.
	startTime := time.Now()
	_, err := provider.ChatCompletion(c.Request.Context(), testReq)
	latency := time.Since(startTime).Milliseconds()

	if err != nil {
		c.JSON(http.StatusOK, ErrorResponse(fmt.Sprintf("Provider test failed: %v", err)))
		return
	}

	c.JSON(http.StatusOK, SuccessResponse(gin.H{
		"message": "Connection test successful",
		"latency": latency,
	}))
}
|
|
|
|
// handleGetModels lists every model known to the registry. Persisted overrides
// from model_configs, matched by model ID, are applied on top of the registry
// defaults: enabled, prompt/completion/cache costs and mapping.
//
// Models are emitted sorted by provider ID and then model ID, so repeated
// calls return a stable order; Go map iteration order is randomised. The data
// field is always a JSON array, even when no registry is loaded.
func (s *Server) handleGetModels(c *gin.Context) {
	// Overrides are best-effort: if the query fails, registry defaults are
	// reported unchanged.
	// NOTE(review): with SELECT *, a column that db.ModelConfig does not map
	// may make the whole Select fail (sqlx reports missing destinations), which
	// silently drops all overrides. Confirm the struct covers the schema.
	var dbModels []db.ModelConfig
	_ = s.database.Select(&dbModels, "SELECT * FROM model_configs")

	dbMap := make(map[string]db.ModelConfig, len(dbModels))
	for _, m := range dbModels {
		dbMap[m.ID] = m
	}

	result := make([]gin.H, 0)
	if s.registry != nil {
		providerIDs := make([]string, 0, len(s.registry.Providers))
		for pID := range s.registry.Providers {
			providerIDs = append(providerIDs, pID)
		}
		sort.Strings(providerIDs)

		for _, pID := range providerIDs {
			pInfo := s.registry.Providers[pID]

			modelIDs := make([]string, 0, len(pInfo.Models))
			for mID := range pInfo.Models {
				modelIDs = append(modelIDs, mID)
			}
			sort.Strings(modelIDs)

			for _, mID := range modelIDs {
				mMeta := pInfo.Models[mID]

				// Registry defaults; a model without an override is enabled.
				enabled := true
				promptCost := 0.0
				completionCost := 0.0
				var cacheReadCost *float64
				var cacheWriteCost *float64
				var mapping *string
				contextLimit := uint32(0)

				if mMeta.Cost != nil {
					promptCost = mMeta.Cost.Input
					completionCost = mMeta.Cost.Output
					cacheReadCost = mMeta.Cost.CacheRead
					cacheWriteCost = mMeta.Cost.CacheWrite
				}
				if mMeta.Limit != nil {
					contextLimit = mMeta.Limit.Context
				}

				// Stored overrides win wherever they are set.
				if dbCfg, ok := dbMap[mID]; ok {
					enabled = dbCfg.Enabled
					if dbCfg.PromptCostPerM != nil {
						promptCost = *dbCfg.PromptCostPerM
					}
					if dbCfg.CompletionCostPerM != nil {
						completionCost = *dbCfg.CompletionCostPerM
					}
					if dbCfg.CacheReadCostPerM != nil {
						cacheReadCost = dbCfg.CacheReadCostPerM
					}
					if dbCfg.CacheWriteCostPerM != nil {
						cacheWriteCost = dbCfg.CacheWriteCostPerM
					}
					mapping = dbCfg.Mapping
				}

				result = append(result, gin.H{
					"id":               mID,
					"name":             mMeta.Name,
					"provider":         pID,
					"enabled":          enabled,
					"prompt_cost":      promptCost,
					"completion_cost":  completionCost,
					"cache_read_cost":  cacheReadCost,
					"cache_write_cost": cacheWriteCost,
					"context_limit":    contextLimit,
					"mapping":          mapping,
					"tool_call":        mMeta.ToolCall != nil && *mMeta.ToolCall,
					"reasoning":        mMeta.Reasoning != nil && *mMeta.Reasoning,
					"modalities":       mMeta.Modalities,
				})
			}
		}
	}

	c.JSON(http.StatusOK, SuccessResponse(result))
}
|
|
|
|
// handleUpdateModel upserts the model_configs override row for the model in
// the :id route parameter. The owning provider is taken from the registry
// ("unknown" if the registry does not list the model) and is only written when
// the row is first created.
//
// enabled, prompt_cost and completion_cost are optional. When omitted or null,
// a new row gets enabled=1 and NULL costs (registry pricing, as
// handleGetModels treats NULL), and an existing row keeps its stored values.
// This mirrors handleUpdateProvider's COALESCE merge. It also stops a partial
// body such as {"enabled": false} from zeroing both prices, and stops a body
// without "enabled" from disabling the model. cache_read_cost,
// cache_write_cost and mapping are stored exactly as sent, so null clears
// them.
func (s *Server) handleUpdateModel(c *gin.Context) {
	id := c.Param("id")
	var req struct {
		Enabled        *bool    `json:"enabled"`
		PromptCost     *float64 `json:"prompt_cost"`
		CompletionCost *float64 `json:"completion_cost"`
		CacheReadCost  *float64 `json:"cache_read_cost"`
		CacheWriteCost *float64 `json:"cache_write_cost"`
		Mapping        *string  `json:"mapping"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
		return
	}

	// Find the provider that lists this model.
	providerID := "unknown"
	if s.registry != nil {
		for pID, pInfo := range s.registry.Providers {
			if _, ok := pInfo.Models[id]; ok {
				providerID = pID
				break
			}
		}
	}

	// The last placeholder (inside DO UPDATE) receives req.Enabled again so an
	// omitted flag can fall back to the stored value on conflict.
	_, err := s.database.Exec(`
		INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, cache_read_cost_per_m, cache_write_cost_per_m, mapping)
		VALUES (?, ?, COALESCE(?, 1), ?, ?, ?, ?, ?)
		ON CONFLICT(id) DO UPDATE SET
			enabled = COALESCE(?, model_configs.enabled),
			prompt_cost_per_m = COALESCE(excluded.prompt_cost_per_m, model_configs.prompt_cost_per_m),
			completion_cost_per_m = COALESCE(excluded.completion_cost_per_m, model_configs.completion_cost_per_m),
			cache_read_cost_per_m = excluded.cache_read_cost_per_m,
			cache_write_cost_per_m = excluded.cache_write_cost_per_m,
			mapping = excluded.mapping,
			updated_at = CURRENT_TIMESTAMP
	`, id, providerID, req.Enabled, req.PromptCost, req.CompletionCost, req.CacheReadCost, req.CacheWriteCost, req.Mapping, req.Enabled)
	if err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
		return
	}

	c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Model updated"}))
}
|
|
|
|
// handleGetUsers lists all users. password_hash is not part of the SELECT, so
// the hash field of every returned db.User stays at its zero value.
func (s *Server) handleGetUsers(c *gin.Context) {
	const listUsers = "SELECT id, username, display_name, role, must_change_password, created_at FROM users"

	var users []db.User
	if err := s.database.Select(&users, listUsers); err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
		return
	}
	c.JSON(http.StatusOK, SuccessResponse(users))
}
|
|
|
|
type CreateUserRequest struct {
|
|
Username string `json:"username" binding:"required"`
|
|
Password string `json:"password" binding:"required"`
|
|
DisplayName *string `json:"display_name"`
|
|
Role *string `json:"role"`
|
|
}
|
|
|
|
func (s *Server) handleCreateUser(c *gin.Context) {
|
|
var req CreateUserRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
|
return
|
|
}
|
|
|
|
hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), 12)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, ErrorResponse("Failed to hash password"))
|
|
return
|
|
}
|
|
|
|
role := "viewer"
|
|
if req.Role != nil {
|
|
role = *req.Role
|
|
}
|
|
|
|
_, err = s.database.Exec("INSERT INTO users (username, password_hash, display_name, role, must_change_password) VALUES (?, ?, ?, ?, 1)",
|
|
req.Username, string(hash), req.DisplayName, role)
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
|
return
|
|
}
|
|
|
|
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User created"}))
|
|
}
|
|
|
|
type UpdateUserRequest struct {
|
|
DisplayName *string `json:"display_name"`
|
|
Role *string `json:"role"`
|
|
Password *string `json:"password"`
|
|
MustChangePassword *bool `json:"must_change_password"`
|
|
}
|
|
|
|
func (s *Server) handleUpdateUser(c *gin.Context) {
|
|
id := c.Param("id")
|
|
var req UpdateUserRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
|
return
|
|
}
|
|
|
|
if req.DisplayName != nil {
|
|
s.database.Exec("UPDATE users SET display_name = ? WHERE id = ?", req.DisplayName, id)
|
|
}
|
|
if req.Role != nil {
|
|
s.database.Exec("UPDATE users SET role = ? WHERE id = ?", req.Role, id)
|
|
}
|
|
if req.MustChangePassword != nil {
|
|
s.database.Exec("UPDATE users SET must_change_password = ? WHERE id = ?", req.MustChangePassword, id)
|
|
}
|
|
if req.Password != nil {
|
|
hash, _ := bcrypt.GenerateFromPassword([]byte(*req.Password), 12)
|
|
s.database.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hash), id)
|
|
}
|
|
|
|
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User updated"}))
|
|
}
|
|
|
|
// handleDeleteUser deletes the user identified by the :id route parameter. An
// admin may not delete their own account: the target's username is compared
// with the username of the session stored under "session" (set by
// adminAuthMiddleware). Deleting a non-existent ID still reports success.
func (s *Server) handleDeleteUser(c *gin.Context) {
	id := c.Param("id")

	// Accept the session in pointer or value form, and tolerate a nil pointer,
	// so a representation mismatch cannot silently disable the self-deletion
	// guard or panic.
	var currentUser string
	if raw, exists := c.Get("session"); exists {
		switch sess := raw.(type) {
		case *Session:
			if sess != nil {
				currentUser = sess.Username
			}
		case Session:
			currentUser = sess.Username
		}
	}

	if currentUser != "" {
		// If the lookup fails, target stays empty, the guard does not
		// trigger, and the DELETE below proceeds.
		var target string
		_ = s.database.Get(&target, "SELECT username FROM users WHERE id = ?", id)
		if target == currentUser {
			c.JSON(http.StatusBadRequest, ErrorResponse("Cannot delete your own account"))
			return
		}
	}

	if _, err := s.database.Exec("DELETE FROM users WHERE id = ?", id); err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
		return
	}

	c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User deleted"}))
}
|
|
|
|
func (s *Server) handleSystemHealth(c *gin.Context) {
|
|
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
|
"status": "ok",
|
|
"components": gin.H{
|
|
"database": "online",
|
|
"proxy": "online",
|
|
},
|
|
}))
|
|
}
|
|
|
|
func (s *Server) handleGetSettings(c *gin.Context) {
|
|
providerCount := 0
|
|
modelCount := 0
|
|
if s.registry != nil {
|
|
providerCount = len(s.registry.Providers)
|
|
for _, p := range s.registry.Providers {
|
|
modelCount += len(p.Models)
|
|
}
|
|
}
|
|
|
|
c.JSON(http.StatusOK, SuccessResponse(gin.H{
|
|
"server": gin.H{
|
|
"version": "1.0.0-go",
|
|
"auth_tokens": s.cfg.Server.AuthTokens,
|
|
},
|
|
"database": gin.H{
|
|
"type": "sqlite",
|
|
"path": s.cfg.Database.Path,
|
|
},
|
|
"registry": gin.H{
|
|
"provider_count": providerCount,
|
|
"model_count": modelCount,
|
|
},
|
|
}))
|
|
}
|
|
|
|
// handleCreateBackup is currently a placeholder. It answers with a backup ID of
// the form "backup-<unix seconds>.db" and status "created", but it neither
// writes a file nor touches the database.
//
// NOTE(review): implement a real backup (for SQLite, e.g. "VACUUM INTO") or
// change the reported status so the UI does not claim a backup exists.
func (s *Server) handleCreateBackup(c *gin.Context) {
	backupID := fmt.Sprintf("backup-%d.db", time.Now().Unix())
	c.JSON(http.StatusOK, SuccessResponse(gin.H{
		"backup_id": backupID,
		"status":    "created",
	}))
}
|
|
|
|
// handleGetLogs returns the 100 most recent llm_requests rows (newest first),
// flattened for the UI. NULL client_id/provider/model fall back to "unknown",
// a NULL total_tokens becomes 0, and timestamps are rendered as RFC 3339.
func (s *Server) handleGetLogs(c *gin.Context) {
	var records []db.LLMRequest
	if err := s.database.Select(&records, "SELECT * FROM llm_requests ORDER BY timestamp DESC LIMIT 100"); err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
		return
	}

	type UILog struct {
		Timestamp string `json:"timestamp"`
		ClientID  string `json:"client_id"`
		Provider  string `json:"provider"`
		Model     string `json:"model"`
		Tokens    int    `json:"tokens"`
		Status    string `json:"status"`
	}

	// orUnknown dereferences a nullable text column, substituting "unknown".
	orUnknown := func(p *string) string {
		if p == nil {
			return "unknown"
		}
		return *p
	}

	uiLogs := make([]UILog, 0, len(records))
	for _, rec := range records {
		entry := UILog{
			Timestamp: rec.Timestamp.Format(time.RFC3339),
			ClientID:  orUnknown(rec.ClientID),
			Provider:  orUnknown(rec.Provider),
			Model:     orUnknown(rec.Model),
			Status:    rec.Status,
		}
		if rec.TotalTokens != nil {
			entry.Tokens = *rec.TotalTokens
		}
		uiLogs = append(uiLogs, entry)
	}

	c.JSON(http.StatusOK, SuccessResponse(uiLogs))
}
|