Refined CalculateCost to correctly handle cached token discounts. Added fuzzy matching to model lookup. Robustified SQL date extraction using SUBSTR and LIKE for better SQLite compatibility.
66 lines
1.7 KiB
Go
66 lines
1.7 KiB
Go
package utils
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"time"
|
|
|
|
"llm-proxy/internal/models"
|
|
"github.com/go-resty/resty/v2"
|
|
)
|
|
|
|
// ModelsDevURL is the endpoint FetchRegistry requests to obtain the model
// registry as JSON.
const ModelsDevURL = "https://models.dev/api.json"
|
|
|
|
func FetchRegistry() (*models.ModelRegistry, error) {
|
|
log.Printf("Fetching model registry from %s", ModelsDevURL)
|
|
|
|
client := resty.New().SetTimeout(10 * time.Second)
|
|
resp, err := client.R().Get(ModelsDevURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch registry: %w", err)
|
|
}
|
|
|
|
if !resp.IsSuccess() {
|
|
return nil, fmt.Errorf("failed to fetch registry: HTTP %d", resp.StatusCode())
|
|
}
|
|
|
|
var providers map[string]models.ProviderInfo
|
|
if err := json.Unmarshal(resp.Body(), &providers); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal registry: %w", err)
|
|
}
|
|
|
|
log.Println("Successfully loaded model registry")
|
|
return &models.ModelRegistry{Providers: providers}, nil
|
|
}
|
|
|
|
func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens, completionTokens, cacheRead, cacheWrite uint32) float64 {
|
|
meta := registry.FindModel(modelID)
|
|
if meta == nil || meta.Cost == nil {
|
|
return 0.0
|
|
}
|
|
|
|
// promptTokens is usually the TOTAL prompt size.
|
|
// We subtract cacheRead from it to get the uncached part.
|
|
uncachedTokens := promptTokens
|
|
if cacheRead > 0 {
|
|
if cacheRead > promptTokens {
|
|
uncachedTokens = 0
|
|
} else {
|
|
uncachedTokens = promptTokens - cacheRead
|
|
}
|
|
}
|
|
|
|
cost := (float64(uncachedTokens) * meta.Cost.Input / 1000000.0) +
|
|
(float64(completionTokens) * meta.Cost.Output / 1000000.0)
|
|
|
|
if meta.Cost.CacheRead != nil {
|
|
cost += float64(cacheRead) * (*meta.Cost.CacheRead) / 1000000.0
|
|
}
|
|
if meta.Cost.CacheWrite != nil {
|
|
cost += float64(cacheWrite) * (*meta.Cost.CacheWrite) / 1000000.0
|
|
}
|
|
|
|
return cost
|
|
}
|