fix: restore models page functionality
Updated handleGetModels to merge registry data with DB overrides and implemented handleUpdateModel. Verified API response format matches frontend requirements.
This commit is contained in:
@@ -676,7 +676,7 @@ func (s *Server) handleTestProvider(c *gin.Context) {
|
||||
|
||||
// Prepare a simple test request
|
||||
testReq := &models.UnifiedRequest{
|
||||
Model: "gpt-4o", // Default test model, might need dynamic selection
|
||||
Model: "gpt-4o", // Default test model
|
||||
Messages: []models.UnifiedMessage{
|
||||
{
|
||||
Role: "user",
|
||||
@@ -711,13 +711,122 @@ func (s *Server) handleTestProvider(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (s *Server) handleGetModels(c *gin.Context) {
|
||||
var models []db.ModelConfig
|
||||
err := s.database.Select(&models, "SELECT * FROM model_configs")
|
||||
// Merge registry models with DB overrides
|
||||
var dbModels []db.ModelConfig
|
||||
_ = s.database.Select(&dbModels, "SELECT * FROM model_configs")
|
||||
|
||||
dbMap := make(map[string]db.ModelConfig)
|
||||
for _, m := range dbModels {
|
||||
dbMap[m.ID] = m
|
||||
}
|
||||
|
||||
var result []gin.H
|
||||
if s.registry != nil {
|
||||
for pID, pInfo := range s.registry.Providers {
|
||||
for mID, mMeta := range pInfo.Models {
|
||||
enabled := true
|
||||
promptCost := 0.0
|
||||
completionCost := 0.0
|
||||
var cacheReadCost *float64
|
||||
var cacheWriteCost *float64
|
||||
var mapping *string
|
||||
contextLimit := uint32(0)
|
||||
|
||||
if mMeta.Cost != nil {
|
||||
promptCost = mMeta.Cost.Input
|
||||
completionCost = mMeta.Cost.Output
|
||||
cacheReadCost = mMeta.Cost.CacheRead
|
||||
cacheWriteCost = mMeta.Cost.CacheWrite
|
||||
}
|
||||
if mMeta.Limit != nil {
|
||||
contextLimit = mMeta.Limit.Context
|
||||
}
|
||||
|
||||
// Override from DB
|
||||
if dbCfg, ok := dbMap[mID]; ok {
|
||||
enabled = dbCfg.Enabled
|
||||
if dbCfg.PromptCostPerM != nil {
|
||||
promptCost = *dbCfg.PromptCostPerM
|
||||
}
|
||||
if dbCfg.CompletionCostPerM != nil {
|
||||
completionCost = *dbCfg.CompletionCostPerM
|
||||
}
|
||||
if dbCfg.CacheReadCostPerM != nil {
|
||||
cacheReadCost = dbCfg.CacheReadCostPerM
|
||||
}
|
||||
if dbCfg.CacheWriteCostPerM != nil {
|
||||
cacheWriteCost = dbCfg.CacheWriteCostPerM
|
||||
}
|
||||
mapping = dbCfg.Mapping
|
||||
}
|
||||
|
||||
result = append(result, gin.H{
|
||||
"id": mID,
|
||||
"name": mMeta.Name,
|
||||
"provider": pID,
|
||||
"enabled": enabled,
|
||||
"prompt_cost": promptCost,
|
||||
"completion_cost": completionCost,
|
||||
"cache_read_cost": cacheReadCost,
|
||||
"cache_write_cost": cacheWriteCost,
|
||||
"context_limit": contextLimit,
|
||||
"mapping": mapping,
|
||||
"tool_call": mMeta.ToolCall != nil && *mMeta.ToolCall,
|
||||
"reasoning": mMeta.Reasoning != nil && *mMeta.Reasoning,
|
||||
"modalities": mMeta.Modalities,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(result))
|
||||
}
|
||||
|
||||
func (s *Server) handleUpdateModel(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var req struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
PromptCost float64 `json:"prompt_cost"`
|
||||
CompletionCost float64 `json:"completion_cost"`
|
||||
CacheReadCost *float64 `json:"cache_read_cost"`
|
||||
CacheWriteCost *float64 `json:"cache_write_cost"`
|
||||
Mapping *string `json:"mapping"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
// Find provider for this model
|
||||
providerID := "unknown"
|
||||
if s.registry != nil {
|
||||
for pID, pInfo := range s.registry.Providers {
|
||||
if _, ok := pInfo.Models[id]; ok {
|
||||
providerID = pID
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, err := s.database.Exec(`
|
||||
INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, cache_read_cost_per_m, cache_write_cost_per_m, mapping)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
enabled = excluded.enabled,
|
||||
prompt_cost_per_m = excluded.prompt_cost_per_m,
|
||||
completion_cost_per_m = excluded.completion_cost_per_m,
|
||||
cache_read_cost_per_m = excluded.cache_read_cost_per_m,
|
||||
cache_write_cost_per_m = excluded.cache_write_cost_per_m,
|
||||
mapping = excluded.mapping,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
`, id, providerID, req.Enabled, req.PromptCost, req.CompletionCost, req.CacheReadCost, req.CacheWriteCost, req.Mapping)
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, SuccessResponse(models))
|
||||
|
||||
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Model updated"}))
|
||||
}
|
||||
|
||||
func (s *Server) handleGetUsers(c *gin.Context) {
|
||||
|
||||
@@ -119,7 +119,9 @@ func (s *Server) setupRoutes() {
|
||||
admin.GET("/providers", s.handleGetProviders)
|
||||
admin.PUT("/providers/:name", s.handleUpdateProvider)
|
||||
admin.POST("/providers/:name/test", s.handleTestProvider)
|
||||
|
||||
admin.GET("/models", s.handleGetModels)
|
||||
admin.PUT("/models/:id", s.handleUpdateModel)
|
||||
|
||||
admin.GET("/users", s.handleGetUsers)
|
||||
admin.POST("/users", s.handleCreateUser)
|
||||
|
||||
Reference in New Issue
Block a user