chore: consolidate env files and update gitignore
Removed .env and .env.backup from git tracking and consolidated configuration into .env.example. Updated .gitignore to robustly prevent accidental inclusion of sensitive files.
This commit is contained in:
58
internal/models/registry.go
Normal file
58
internal/models/registry.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package models
|
||||
|
||||
type ModelRegistry struct {
|
||||
Providers map[string]ProviderInfo `json:"-"`
|
||||
}
|
||||
|
||||
type ProviderInfo struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Models map[string]ModelMetadata `json:"models"`
|
||||
}
|
||||
|
||||
type ModelMetadata struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Cost *ModelCost `json:"cost,omitempty"`
|
||||
Limit *ModelLimit `json:"limit,omitempty"`
|
||||
Modalities *ModelModalities `json:"modalities,omitempty"`
|
||||
ToolCall *bool `json:"tool_call,omitempty"`
|
||||
Reasoning *bool `json:"reasoning,omitempty"`
|
||||
}
|
||||
|
||||
type ModelCost struct {
|
||||
Input float64 `json:"input"`
|
||||
Output float64 `json:"output"`
|
||||
CacheRead *float64 `json:"cache_read,omitempty"`
|
||||
CacheWrite *float64 `json:"cache_write,omitempty"`
|
||||
}
|
||||
|
||||
type ModelLimit struct {
|
||||
Context uint32 `json:"context"`
|
||||
Output uint32 `json:"output"`
|
||||
}
|
||||
|
||||
type ModelModalities struct {
|
||||
Input []string `json:"input"`
|
||||
Output []string `json:"output"`
|
||||
}
|
||||
|
||||
func (r *ModelRegistry) FindModel(modelID string) *ModelMetadata {
|
||||
// First try exact match in models map
|
||||
for _, provider := range r.Providers {
|
||||
if model, ok := provider.Models[modelID]; ok {
|
||||
return &model
|
||||
}
|
||||
}
|
||||
|
||||
// Try searching by ID in metadata
|
||||
for _, provider := range r.Providers {
|
||||
for _, model := range provider.Models {
|
||||
if model.ID == modelID {
|
||||
return &model
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -26,12 +26,20 @@ type Server struct {
|
||||
sessions *SessionManager
|
||||
hub *Hub
|
||||
logger *RequestLogger
|
||||
registry *models.ModelRegistry
|
||||
}
|
||||
|
||||
func NewServer(cfg *config.Config, database *db.DB) *Server {
|
||||
router := gin.Default()
|
||||
hub := NewHub()
|
||||
|
||||
// Fetch registry (non-blocking for startup if it fails, but we'll try once)
|
||||
registry, err := utils.FetchRegistry()
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: Failed to fetch initial model registry: %v\n", err)
|
||||
registry = &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)}
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
router: router,
|
||||
cfg: cfg,
|
||||
@@ -40,6 +48,7 @@ func NewServer(cfg *config.Config, database *db.DB) *Server {
|
||||
sessions: NewSessionManager(cfg.KeyBytes, 24*time.Hour),
|
||||
hub: hub,
|
||||
logger: NewRequestLogger(database, hub),
|
||||
registry: registry,
|
||||
}
|
||||
|
||||
// Initialize providers
|
||||
@@ -311,8 +320,9 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model string, u
|
||||
if usage.CacheWriteTokens != nil {
|
||||
entry.CacheWriteTokens = *usage.CacheWriteTokens
|
||||
}
|
||||
// TODO: Calculate cost properly based on pricing
|
||||
entry.Cost = 0.0
|
||||
|
||||
// Calculate cost using registry
|
||||
entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
|
||||
}
|
||||
|
||||
s.logger.LogRequest(entry)
|
||||
@@ -321,6 +331,18 @@ func (s *Server) logRequest(start time.Time, clientID, provider, model string, u
|
||||
func (s *Server) Run() error {
|
||||
go s.hub.Run()
|
||||
s.logger.Start()
|
||||
|
||||
// Start registry refresher
|
||||
go func() {
|
||||
ticker := time.NewTicker(24 * time.Hour)
|
||||
for range ticker.C {
|
||||
newRegistry, err := utils.FetchRegistry()
|
||||
if err == nil {
|
||||
s.registry = newRegistry
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
|
||||
return s.router.Run(addr)
|
||||
}
|
||||
|
||||
54
internal/utils/registry.go
Normal file
54
internal/utils/registry.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"llm-proxy/internal/models"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
const ModelsDevURL = "https://models.dev/api.json"
|
||||
|
||||
func FetchRegistry() (*models.ModelRegistry, error) {
|
||||
log.Printf("Fetching model registry from %s", ModelsDevURL)
|
||||
|
||||
client := resty.New().SetTimeout(10 * time.Second)
|
||||
resp, err := client.R().Get(ModelsDevURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch registry: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("failed to fetch registry: HTTP %d", resp.StatusCode())
|
||||
}
|
||||
|
||||
var providers map[string]models.ProviderInfo
|
||||
if err := json.Unmarshal(resp.Body(), &providers); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal registry: %w", err)
|
||||
}
|
||||
|
||||
log.Println("Successfully loaded model registry")
|
||||
return &models.ModelRegistry{Providers: providers}, nil
|
||||
}
|
||||
|
||||
func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens, completionTokens, cacheRead, cacheWrite uint32) float64 {
|
||||
meta := registry.FindModel(modelID)
|
||||
if meta == nil || meta.Cost == nil {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
cost := (float64(promptTokens) * meta.Cost.Input / 1000000.0) +
|
||||
(float64(completionTokens) * meta.Cost.Output / 1000000.0)
|
||||
|
||||
if meta.Cost.CacheRead != nil {
|
||||
cost += float64(cacheRead) * (*meta.Cost.CacheRead) / 1000000.0
|
||||
}
|
||||
if meta.Cost.CacheWrite != nil {
|
||||
cost += float64(cacheWrite) * (*meta.Cost.CacheWrite) / 1000000.0
|
||||
}
|
||||
|
||||
return cost
|
||||
}
|
||||
Reference in New Issue
Block a user