Track cache_read_tokens and cache_write_tokens end-to-end. Parse them from provider responses (OpenAI, DeepSeek, Grok, Gemini), persist them to SQLite, apply cache-aware pricing from the model registry, and surface them in API responses and the dashboard.

- Add cache fields to the ProviderResponse, StreamUsage, and RequestLog structs
- Parse cached_tokens (OpenAI/Grok), prompt_cache_hit/miss (DeepSeek), and cachedContentTokenCount (Gemini) from provider responses
- Send stream_options.include_usage for streaming requests; capture the real usage from the final SSE chunk in AggregatingStream
- Add an ALTER TABLE migration for the cache_read_tokens/cache_write_tokens columns
- Compute a cache-aware cost using the registry cache_read/cache_write rates
- Update the Provider trait's calculate_cost signature across all providers
- Add cache_read_tokens/cache_write_tokens to the Usage API response
- Dashboard: add a cache hit rate card, cache columns in the pricing and usage tables, and cache token aggregation in SQL queries
- Remove the API debug panel and verbose console logging from api.js
- Bump the static asset cache-bust to v5
174 lines
5.8 KiB
Rust
174 lines
5.8 KiB
Rust
use axum::{
|
|
extract::{Path, Query, State},
|
|
response::Json,
|
|
};
|
|
use serde::Deserialize;
|
|
use serde_json;
|
|
use sqlx::Row;
|
|
use std::collections::HashMap;
|
|
|
|
use super::{ApiResponse, DashboardState};
|
|
use crate::models::registry::{ModelFilter, ModelSortBy, SortOrder};
|
|
|
|
#[derive(Deserialize)]
|
|
pub(super) struct UpdateModelRequest {
|
|
pub(super) enabled: bool,
|
|
pub(super) prompt_cost: Option<f64>,
|
|
pub(super) completion_cost: Option<f64>,
|
|
pub(super) mapping: Option<String>,
|
|
}
|
|
|
|
/// Query parameters for `GET /api/models`.
|
|
#[derive(Debug, Deserialize, Default)]
|
|
pub(super) struct ModelListParams {
|
|
/// Filter by provider ID.
|
|
pub provider: Option<String>,
|
|
/// Text search on model ID or name.
|
|
pub search: Option<String>,
|
|
/// Filter by input modality (e.g. "image").
|
|
pub modality: Option<String>,
|
|
/// Only models that support tool calling.
|
|
pub tool_call: Option<bool>,
|
|
/// Only models that support reasoning.
|
|
pub reasoning: Option<bool>,
|
|
/// Only models that have pricing data.
|
|
pub has_cost: Option<bool>,
|
|
/// Sort field (name, id, provider, context_limit, input_cost, output_cost).
|
|
pub sort_by: Option<ModelSortBy>,
|
|
/// Sort direction (asc, desc).
|
|
pub sort_order: Option<SortOrder>,
|
|
}
|
|
|
|
pub(super) async fn handle_get_models(
|
|
State(state): State<DashboardState>,
|
|
Query(params): Query<ModelListParams>,
|
|
) -> Json<ApiResponse<serde_json::Value>> {
|
|
let registry = &state.app_state.model_registry;
|
|
let pool = &state.app_state.db_pool;
|
|
|
|
// Build filter from query params
|
|
let filter = ModelFilter {
|
|
provider: params.provider,
|
|
search: params.search,
|
|
modality: params.modality,
|
|
tool_call: params.tool_call,
|
|
reasoning: params.reasoning,
|
|
has_cost: params.has_cost,
|
|
};
|
|
let sort_by = params.sort_by.unwrap_or_default();
|
|
let sort_order = params.sort_order.unwrap_or_default();
|
|
|
|
// Get filtered and sorted model entries
|
|
let entries = registry.list_models(&filter, &sort_by, &sort_order);
|
|
|
|
// Load overrides from database
|
|
let db_models_result =
|
|
sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
|
|
.fetch_all(pool)
|
|
.await;
|
|
|
|
let mut db_models = HashMap::new();
|
|
if let Ok(rows) = db_models_result {
|
|
for row in rows {
|
|
let id: String = row.get("id");
|
|
db_models.insert(id, row);
|
|
}
|
|
}
|
|
|
|
let mut models_json = Vec::new();
|
|
|
|
for entry in &entries {
|
|
let m_key = entry.model_key;
|
|
let m_meta = entry.metadata;
|
|
|
|
let mut enabled = true;
|
|
let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
|
let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
|
let cache_read_cost = m_meta.cost.as_ref().and_then(|c| c.cache_read);
|
|
let cache_write_cost = m_meta.cost.as_ref().and_then(|c| c.cache_write);
|
|
let mut mapping = None::<String>;
|
|
|
|
if let Some(row) = db_models.get(m_key) {
|
|
enabled = row.get("enabled");
|
|
if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
|
|
prompt_cost = p;
|
|
}
|
|
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
|
|
completion_cost = c;
|
|
}
|
|
mapping = row.get("mapping");
|
|
}
|
|
|
|
models_json.push(serde_json::json!({
|
|
"id": m_key,
|
|
"provider": entry.provider_id,
|
|
"provider_name": entry.provider_name,
|
|
"name": m_meta.name,
|
|
"enabled": enabled,
|
|
"prompt_cost": prompt_cost,
|
|
"completion_cost": completion_cost,
|
|
"cache_read_cost": cache_read_cost,
|
|
"cache_write_cost": cache_write_cost,
|
|
"mapping": mapping,
|
|
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
|
|
"output_limit": m_meta.limit.as_ref().map(|l| l.output).unwrap_or(0),
|
|
"modalities": m_meta.modalities.as_ref().map(|m| serde_json::json!({
|
|
"input": m.input,
|
|
"output": m.output,
|
|
})),
|
|
"tool_call": m_meta.tool_call,
|
|
"reasoning": m_meta.reasoning,
|
|
}));
|
|
}
|
|
|
|
Json(ApiResponse::success(serde_json::json!(models_json)))
|
|
}
|
|
|
|
pub(super) async fn handle_update_model(
|
|
State(state): State<DashboardState>,
|
|
Path(id): Path<String>,
|
|
Json(payload): Json<UpdateModelRequest>,
|
|
) -> Json<ApiResponse<serde_json::Value>> {
|
|
let pool = &state.app_state.db_pool;
|
|
|
|
// Find provider_id for this model in registry
|
|
let provider_id = state
|
|
.app_state
|
|
.model_registry
|
|
.providers
|
|
.iter()
|
|
.find(|(_, p)| p.models.contains_key(&id))
|
|
.map(|(id, _)| id.clone())
|
|
.unwrap_or_else(|| "unknown".to_string());
|
|
|
|
let result = sqlx::query(
|
|
r#"
|
|
INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping)
|
|
VALUES (?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(id) DO UPDATE SET
|
|
enabled = excluded.enabled,
|
|
prompt_cost_per_m = excluded.prompt_cost_per_m,
|
|
completion_cost_per_m = excluded.completion_cost_per_m,
|
|
mapping = excluded.mapping,
|
|
updated_at = CURRENT_TIMESTAMP
|
|
"#,
|
|
)
|
|
.bind(&id)
|
|
.bind(provider_id)
|
|
.bind(payload.enabled)
|
|
.bind(payload.prompt_cost)
|
|
.bind(payload.completion_cost)
|
|
.bind(payload.mapping)
|
|
.execute(pool)
|
|
.await;
|
|
|
|
match result {
|
|
Ok(_) => {
|
|
// Invalidate the in-memory cache so the proxy picks up the change immediately
|
|
state.app_state.model_config_cache.invalidate().await;
|
|
Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" })))
|
|
}
|
|
Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))),
|
|
}
|
|
}
|