diff --git a/internal/server/dashboard.go b/internal/server/dashboard.go index 34505f27..0a7f518d 100644 --- a/internal/server/dashboard.go +++ b/internal/server/dashboard.go @@ -676,7 +676,7 @@ func (s *Server) handleTestProvider(c *gin.Context) { // Prepare a simple test request testReq := &models.UnifiedRequest{ - Model: "gpt-4o", // Default test model, might need dynamic selection + Model: "gpt-4o", // Default test model Messages: []models.UnifiedMessage{ { Role: "user", @@ -711,13 +711,122 @@ func (s *Server) handleTestProvider(c *gin.Context) { } func (s *Server) handleGetModels(c *gin.Context) { - var models []db.ModelConfig - err := s.database.Select(&models, "SELECT * FROM model_configs") + // Merge registry models with DB overrides + var dbModels []db.ModelConfig + _ = s.database.Select(&dbModels, "SELECT * FROM model_configs") + + dbMap := make(map[string]db.ModelConfig) + for _, m := range dbModels { + dbMap[m.ID] = m + } + + var result []gin.H + if s.registry != nil { + for pID, pInfo := range s.registry.Providers { + for mID, mMeta := range pInfo.Models { + enabled := true + promptCost := 0.0 + completionCost := 0.0 + var cacheReadCost *float64 + var cacheWriteCost *float64 + var mapping *string + contextLimit := uint32(0) + + if mMeta.Cost != nil { + promptCost = mMeta.Cost.Input + completionCost = mMeta.Cost.Output + cacheReadCost = mMeta.Cost.CacheRead + cacheWriteCost = mMeta.Cost.CacheWrite + } + if mMeta.Limit != nil { + contextLimit = mMeta.Limit.Context + } + + // Override from DB + if dbCfg, ok := dbMap[mID]; ok { + enabled = dbCfg.Enabled + if dbCfg.PromptCostPerM != nil { + promptCost = *dbCfg.PromptCostPerM + } + if dbCfg.CompletionCostPerM != nil { + completionCost = *dbCfg.CompletionCostPerM + } + if dbCfg.CacheReadCostPerM != nil { + cacheReadCost = dbCfg.CacheReadCostPerM + } + if dbCfg.CacheWriteCostPerM != nil { + cacheWriteCost = dbCfg.CacheWriteCostPerM + } + mapping = dbCfg.Mapping + } + + result = append(result, gin.H{ + "id": mID, + "name": mMeta.Name, + "provider": pID, + "enabled": enabled, + "prompt_cost": promptCost, + "completion_cost": completionCost, + "cache_read_cost": cacheReadCost, + "cache_write_cost": cacheWriteCost, + "context_limit": contextLimit, + "mapping": mapping, + "tool_call": mMeta.ToolCall != nil && *mMeta.ToolCall, + "reasoning": mMeta.Reasoning != nil && *mMeta.Reasoning, + "modalities": mMeta.Modalities, + }) + } + } + } + + c.JSON(http.StatusOK, SuccessResponse(result)) +} + +func (s *Server) handleUpdateModel(c *gin.Context) { + id := c.Param("id") + var req struct { + Enabled bool `json:"enabled"` + PromptCost float64 `json:"prompt_cost"` + CompletionCost float64 `json:"completion_cost"` + CacheReadCost *float64 `json:"cache_read_cost"` + CacheWriteCost *float64 `json:"cache_write_cost"` + Mapping *string `json:"mapping"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request")) + return + } + + // Find provider for this model + providerID := "unknown" + if s.registry != nil { + for pID, pInfo := range s.registry.Providers { + if _, ok := pInfo.Models[id]; ok { + providerID = pID + break + } + } + } + + _, err := s.database.Exec(` + INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, cache_read_cost_per_m, cache_write_cost_per_m, mapping) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + enabled = excluded.enabled, + prompt_cost_per_m = excluded.prompt_cost_per_m, + completion_cost_per_m = excluded.completion_cost_per_m, + cache_read_cost_per_m = excluded.cache_read_cost_per_m, + cache_write_cost_per_m = excluded.cache_write_cost_per_m, + mapping = excluded.mapping, + updated_at = CURRENT_TIMESTAMP + `, id, providerID, req.Enabled, req.PromptCost, req.CompletionCost, req.CacheReadCost, req.CacheWriteCost, req.Mapping) + if err != nil { c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error())) return } - c.JSON(http.StatusOK, SuccessResponse(models)) + + c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Model updated"})) } func (s *Server) handleGetUsers(c *gin.Context) { diff --git a/internal/server/server.go b/internal/server/server.go index 96d7750f..facce6ce 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -119,7 +119,9 @@ func (s *Server) setupRoutes() { admin.GET("/providers", s.handleGetProviders) admin.PUT("/providers/:name", s.handleUpdateProvider) admin.POST("/providers/:name/test", s.handleTestProvider) + admin.GET("/models", s.handleGetModels) + admin.PUT("/models/:id", s.handleUpdateModel) admin.GET("/users", s.handleGetUsers) admin.POST("/users", s.handleCreateUser)