refactor: comprehensive audit — fix bugs, harden security, deduplicate providers, add CI/Docker
Phase 1: Fix compilation (config_path Option<PathBuf>, streaming test, stale test cleanup) Phase 2: Fix critical bugs (remove block_on deadlocks in 4 providers, fix broken SQL query builder) Phase 3: Security hardening (session manager, real auth, token masking, Gemini key to header, password policy) Phase 4: Implement stubs (real provider test, /proc health metrics, client/provider/backup endpoints, has_images) Phase 5: Code quality (shared provider helpers, explicit re-exports, all Clippy warnings fixed, unwrap removal, 6 unused deps removed, dashboard split into 7 sub-modules) Phase 6: Infrastructure (GitHub Actions CI, multi-stage Dockerfile, rustfmt.toml, clippy.toml, script fixes)
This commit is contained in:
130
src/dashboard/auth.rs
Normal file
130
src/dashboard/auth.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
use axum::{extract::State, response::Json};
|
||||
use bcrypt;
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
|
||||
// Authentication handlers
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct LoginRequest {
|
||||
pub(super) username: String,
|
||||
pub(super) password: String,
|
||||
}
|
||||
|
||||
pub(super) async fn handle_login(
|
||||
State(state): State<DashboardState>,
|
||||
Json(payload): Json<LoginRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let user_result =
|
||||
sqlx::query("SELECT username, password_hash, role, must_change_password FROM users WHERE username = ?")
|
||||
.bind(&payload.username)
|
||||
.fetch_optional(pool)
|
||||
.await;
|
||||
|
||||
match user_result {
|
||||
Ok(Some(row)) => {
|
||||
let hash = row.get::<String, _>("password_hash");
|
||||
if bcrypt::verify(&payload.password, &hash).unwrap_or(false) {
|
||||
let username = row.get::<String, _>("username");
|
||||
let role = row.get::<String, _>("role");
|
||||
let must_change_password = row.get::<bool, _>("must_change_password");
|
||||
let token = state
|
||||
.session_manager
|
||||
.create_session(username.clone(), role.clone())
|
||||
.await;
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"token": token,
|
||||
"must_change_password": must_change_password,
|
||||
"user": {
|
||||
"username": username,
|
||||
"name": "Administrator",
|
||||
"role": role
|
||||
}
|
||||
})))
|
||||
} else {
|
||||
Json(ApiResponse::error("Invalid username or password".to_string()))
|
||||
}
|
||||
}
|
||||
Ok(None) => Json(ApiResponse::error("Invalid username or password".to_string())),
|
||||
Err(e) => {
|
||||
warn!("Database error during login: {}", e);
|
||||
Json(ApiResponse::error("Login failed due to system error".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_auth_status(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let token = headers
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
if let Some(token) = token
|
||||
&& let Some(session) = state.session_manager.validate_session(token).await
|
||||
{
|
||||
return Json(ApiResponse::success(serde_json::json!({
|
||||
"authenticated": true,
|
||||
"user": {
|
||||
"username": session.username,
|
||||
"name": "Administrator",
|
||||
"role": session.role
|
||||
}
|
||||
})));
|
||||
}
|
||||
|
||||
Json(ApiResponse::error("Not authenticated".to_string()))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct ChangePasswordRequest {
|
||||
pub(super) current_password: String,
|
||||
pub(super) new_password: String,
|
||||
}
|
||||
|
||||
pub(super) async fn handle_change_password(
|
||||
State(state): State<DashboardState>,
|
||||
Json(payload): Json<ChangePasswordRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// For now, always change 'admin' user
|
||||
let user_result = sqlx::query("SELECT password_hash FROM users WHERE username = 'admin'")
|
||||
.fetch_one(pool)
|
||||
.await;
|
||||
|
||||
match user_result {
|
||||
Ok(row) => {
|
||||
let hash = row.get::<String, _>("password_hash");
|
||||
if bcrypt::verify(&payload.current_password, &hash).unwrap_or(false) {
|
||||
let new_hash = match bcrypt::hash(&payload.new_password, 12) {
|
||||
Ok(h) => h,
|
||||
Err(_) => return Json(ApiResponse::error("Failed to hash new password".to_string())),
|
||||
};
|
||||
|
||||
let update_result = sqlx::query(
|
||||
"UPDATE users SET password_hash = ?, must_change_password = FALSE WHERE username = 'admin'",
|
||||
)
|
||||
.bind(new_hash)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match update_result {
|
||||
Ok(_) => Json(ApiResponse::success(
|
||||
serde_json::json!({ "message": "Password updated successfully" }),
|
||||
)),
|
||||
Err(e) => Json(ApiResponse::error(format!("Failed to update database: {}", e))),
|
||||
}
|
||||
} else {
|
||||
Json(ApiResponse::error("Current password incorrect".to_string()))
|
||||
}
|
||||
}
|
||||
Err(e) => Json(ApiResponse::error(format!("User not found: {}", e))),
|
||||
}
|
||||
}
|
||||
227
src/dashboard/clients.rs
Normal file
227
src/dashboard/clients.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::Json,
|
||||
};
|
||||
use chrono;
|
||||
use serde::Deserialize;
|
||||
use serde_json;
|
||||
use sqlx::Row;
|
||||
use tracing::warn;
|
||||
use uuid;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct CreateClientRequest {
|
||||
pub(super) name: String,
|
||||
pub(super) client_id: Option<String>,
|
||||
}
|
||||
|
||||
pub(super) async fn handle_get_clients(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
client_id as id,
|
||||
name,
|
||||
created_at,
|
||||
total_requests,
|
||||
total_tokens,
|
||||
total_cost,
|
||||
is_active
|
||||
FROM clients
|
||||
ORDER BY created_at DESC
|
||||
"#,
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let clients: Vec<serde_json::Value> = rows
|
||||
.into_iter()
|
||||
.map(|row| {
|
||||
serde_json::json!({
|
||||
"id": row.get::<String, _>("id"),
|
||||
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
|
||||
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
||||
"requests_count": row.get::<i64, _>("total_requests"),
|
||||
"total_tokens": row.get::<i64, _>("total_tokens"),
|
||||
"total_cost": row.get::<f64, _>("total_cost"),
|
||||
"status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(clients)))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch clients: {}", e);
|
||||
Json(ApiResponse::error("Failed to fetch clients".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_create_client(
|
||||
State(state): State<DashboardState>,
|
||||
Json(payload): Json<CreateClientRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let client_id = payload
|
||||
.client_id
|
||||
.unwrap_or_else(|| format!("client-{}", &uuid::Uuid::new_v4().to_string()[..8]));
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO clients (client_id, name, is_active)
|
||||
VALUES (?, ?, TRUE)
|
||||
RETURNING *
|
||||
"#,
|
||||
)
|
||||
.bind(&client_id)
|
||||
.bind(&payload.name)
|
||||
.fetch_one(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(row) => Json(ApiResponse::success(serde_json::json!({
|
||||
"id": row.get::<String, _>("client_id"),
|
||||
"name": row.get::<Option<String>, _>("name"),
|
||||
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
||||
"status": "active",
|
||||
}))),
|
||||
Err(e) => {
|
||||
warn!("Failed to create client: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to create client: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_get_client(
|
||||
State(state): State<DashboardState>,
|
||||
Path(id): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
c.client_id as id,
|
||||
c.name,
|
||||
c.is_active,
|
||||
c.created_at,
|
||||
COALESCE(c.total_tokens, 0) as total_tokens,
|
||||
COALESCE(c.total_cost, 0.0) as total_cost,
|
||||
COUNT(r.id) as total_requests,
|
||||
MAX(r.timestamp) as last_request
|
||||
FROM clients c
|
||||
LEFT JOIN llm_requests r ON c.client_id = r.client_id
|
||||
WHERE c.client_id = ?
|
||||
GROUP BY c.client_id
|
||||
"#,
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_optional(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Some(row)) => Json(ApiResponse::success(serde_json::json!({
|
||||
"id": row.get::<String, _>("id"),
|
||||
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
|
||||
"is_active": row.get::<bool, _>("is_active"),
|
||||
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
||||
"total_tokens": row.get::<i64, _>("total_tokens"),
|
||||
"total_cost": row.get::<f64, _>("total_cost"),
|
||||
"total_requests": row.get::<i64, _>("total_requests"),
|
||||
"last_request": row.get::<Option<chrono::DateTime<chrono::Utc>>, _>("last_request"),
|
||||
"status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
|
||||
}))),
|
||||
Ok(None) => Json(ApiResponse::error(format!("Client '{}' not found", id))),
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch client: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to fetch client: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_delete_client(
|
||||
State(state): State<DashboardState>,
|
||||
Path(id): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Don't allow deleting the default client
|
||||
if id == "default" {
|
||||
return Json(ApiResponse::error("Cannot delete default client".to_string()));
|
||||
}
|
||||
|
||||
let result = sqlx::query("DELETE FROM clients WHERE client_id = ?")
|
||||
.bind(id)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Client deleted" }))),
|
||||
Err(e) => Json(ApiResponse::error(format!("Failed to delete client: {}", e))),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_client_usage(
|
||||
State(state): State<DashboardState>,
|
||||
Path(id): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Get per-model breakdown for this client
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
model,
|
||||
provider,
|
||||
COUNT(*) as request_count,
|
||||
SUM(prompt_tokens) as prompt_tokens,
|
||||
SUM(completion_tokens) as completion_tokens,
|
||||
SUM(total_tokens) as total_tokens,
|
||||
SUM(cost) as total_cost,
|
||||
AVG(duration_ms) as avg_duration_ms
|
||||
FROM llm_requests
|
||||
WHERE client_id = ?
|
||||
GROUP BY model, provider
|
||||
ORDER BY total_cost DESC
|
||||
"#,
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let breakdown: Vec<serde_json::Value> = rows
|
||||
.into_iter()
|
||||
.map(|row| {
|
||||
serde_json::json!({
|
||||
"model": row.get::<String, _>("model"),
|
||||
"provider": row.get::<String, _>("provider"),
|
||||
"request_count": row.get::<i64, _>("request_count"),
|
||||
"prompt_tokens": row.get::<i64, _>("prompt_tokens"),
|
||||
"completion_tokens": row.get::<i64, _>("completion_tokens"),
|
||||
"total_tokens": row.get::<i64, _>("total_tokens"),
|
||||
"total_cost": row.get::<f64, _>("total_cost"),
|
||||
"avg_duration_ms": row.get::<f64, _>("avg_duration_ms"),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"client_id": id,
|
||||
"breakdown": breakdown,
|
||||
})))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch client usage: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to fetch client usage: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
1077
src/dashboard/mod.rs
1077
src/dashboard/mod.rs
File diff suppressed because it is too large
Load Diff
116
src/dashboard/models.rs
Normal file
116
src/dashboard/models.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::Json,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json;
|
||||
use sqlx::Row;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct UpdateModelRequest {
|
||||
pub(super) enabled: bool,
|
||||
pub(super) prompt_cost: Option<f64>,
|
||||
pub(super) completion_cost: Option<f64>,
|
||||
pub(super) mapping: Option<String>,
|
||||
}
|
||||
|
||||
pub(super) async fn handle_get_models(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let registry = &state.app_state.model_registry;
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Load overrides from database
|
||||
let db_models_result =
|
||||
sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
let mut db_models = HashMap::new();
|
||||
if let Ok(rows) = db_models_result {
|
||||
for row in rows {
|
||||
let id: String = row.get("id");
|
||||
db_models.insert(id, row);
|
||||
}
|
||||
}
|
||||
|
||||
let mut models_json = Vec::new();
|
||||
|
||||
for (p_id, p_info) in ®istry.providers {
|
||||
for (m_id, m_meta) in &p_info.models {
|
||||
let mut enabled = true;
|
||||
let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
||||
let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
||||
let mut mapping = None::<String>;
|
||||
|
||||
if let Some(row) = db_models.get(m_id) {
|
||||
enabled = row.get("enabled");
|
||||
if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
|
||||
prompt_cost = p;
|
||||
}
|
||||
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
|
||||
completion_cost = c;
|
||||
}
|
||||
mapping = row.get("mapping");
|
||||
}
|
||||
|
||||
models_json.push(serde_json::json!({
|
||||
"id": m_id,
|
||||
"provider": p_id,
|
||||
"name": m_meta.name,
|
||||
"enabled": enabled,
|
||||
"prompt_cost": prompt_cost,
|
||||
"completion_cost": completion_cost,
|
||||
"mapping": mapping,
|
||||
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(models_json)))
|
||||
}
|
||||
|
||||
pub(super) async fn handle_update_model(
|
||||
State(state): State<DashboardState>,
|
||||
Path(id): Path<String>,
|
||||
Json(payload): Json<UpdateModelRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Find provider_id for this model in registry
|
||||
let provider_id = state
|
||||
.app_state
|
||||
.model_registry
|
||||
.providers
|
||||
.iter()
|
||||
.find(|(_, p)| p.models.contains_key(&id))
|
||||
.map(|(id, _)| id.clone())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
enabled = excluded.enabled,
|
||||
prompt_cost_per_m = excluded.prompt_cost_per_m,
|
||||
completion_cost_per_m = excluded.completion_cost_per_m,
|
||||
mapping = excluded.mapping,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
"#,
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(provider_id)
|
||||
.bind(payload.enabled)
|
||||
.bind(payload.prompt_cost)
|
||||
.bind(payload.completion_cost)
|
||||
.bind(payload.mapping)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" }))),
|
||||
Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))),
|
||||
}
|
||||
}
|
||||
346
src/dashboard/providers.rs
Normal file
346
src/dashboard/providers.rs
Normal file
@@ -0,0 +1,346 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::Json,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json;
|
||||
use sqlx::Row;
|
||||
use std::collections::HashMap;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct UpdateProviderRequest {
|
||||
pub(super) enabled: bool,
|
||||
pub(super) base_url: Option<String>,
|
||||
pub(super) api_key: Option<String>,
|
||||
pub(super) credit_balance: Option<f64>,
|
||||
pub(super) low_credit_threshold: Option<f64>,
|
||||
}
|
||||
|
||||
pub(super) async fn handle_get_providers(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let registry = &state.app_state.model_registry;
|
||||
let config = &state.app_state.config;
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Load all overrides from database
|
||||
let db_configs_result =
|
||||
sqlx::query("SELECT id, enabled, base_url, credit_balance, low_credit_threshold FROM provider_configs")
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
let mut db_configs = HashMap::new();
|
||||
if let Ok(rows) = db_configs_result {
|
||||
for row in rows {
|
||||
let id: String = row.get("id");
|
||||
let enabled: bool = row.get("enabled");
|
||||
let base_url: Option<String> = row.get("base_url");
|
||||
let balance: f64 = row.get("credit_balance");
|
||||
let threshold: f64 = row.get("low_credit_threshold");
|
||||
db_configs.insert(id, (enabled, base_url, balance, threshold));
|
||||
}
|
||||
}
|
||||
|
||||
let mut providers_json = Vec::new();
|
||||
|
||||
// Define the list of providers we support
|
||||
let provider_ids = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
|
||||
|
||||
for id in provider_ids {
|
||||
// Get base config
|
||||
let (mut enabled, mut base_url, display_name) = match id {
|
||||
"openai" => (
|
||||
config.providers.openai.enabled,
|
||||
config.providers.openai.base_url.clone(),
|
||||
"OpenAI",
|
||||
),
|
||||
"gemini" => (
|
||||
config.providers.gemini.enabled,
|
||||
config.providers.gemini.base_url.clone(),
|
||||
"Google Gemini",
|
||||
),
|
||||
"deepseek" => (
|
||||
config.providers.deepseek.enabled,
|
||||
config.providers.deepseek.base_url.clone(),
|
||||
"DeepSeek",
|
||||
),
|
||||
"grok" => (
|
||||
config.providers.grok.enabled,
|
||||
config.providers.grok.base_url.clone(),
|
||||
"xAI Grok",
|
||||
),
|
||||
"ollama" => (
|
||||
config.providers.ollama.enabled,
|
||||
config.providers.ollama.base_url.clone(),
|
||||
"Ollama",
|
||||
),
|
||||
_ => (false, "".to_string(), "Unknown"),
|
||||
};
|
||||
|
||||
let mut balance = 0.0;
|
||||
let mut threshold = 5.0;
|
||||
|
||||
// Apply database overrides
|
||||
if let Some((db_enabled, db_url, db_balance, db_threshold)) = db_configs.get(id) {
|
||||
enabled = *db_enabled;
|
||||
if let Some(url) = db_url {
|
||||
base_url = url.clone();
|
||||
}
|
||||
balance = *db_balance;
|
||||
threshold = *db_threshold;
|
||||
}
|
||||
|
||||
// Find models for this provider in registry
|
||||
let mut models = Vec::new();
|
||||
if let Some(p_info) = registry.providers.get(id) {
|
||||
models = p_info.models.keys().cloned().collect();
|
||||
} else if id == "ollama" {
|
||||
models = config.providers.ollama.models.clone();
|
||||
}
|
||||
|
||||
// Determine status
|
||||
let status = if !enabled {
|
||||
"disabled"
|
||||
} else {
|
||||
// Check if it's actually initialized in the provider manager
|
||||
if state.app_state.provider_manager.get_provider(id).await.is_some() {
|
||||
// Check circuit breaker
|
||||
if state
|
||||
.app_state
|
||||
.rate_limit_manager
|
||||
.check_provider_request(id)
|
||||
.await
|
||||
.unwrap_or(true)
|
||||
{
|
||||
"online"
|
||||
} else {
|
||||
"degraded"
|
||||
}
|
||||
} else {
|
||||
"error" // Enabled but failed to initialize (e.g. missing API key)
|
||||
}
|
||||
};
|
||||
|
||||
providers_json.push(serde_json::json!({
|
||||
"id": id,
|
||||
"name": display_name,
|
||||
"enabled": enabled,
|
||||
"status": status,
|
||||
"models": models,
|
||||
"base_url": base_url,
|
||||
"credit_balance": balance,
|
||||
"low_credit_threshold": threshold,
|
||||
"last_used": None::<String>,
|
||||
}));
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(providers_json)))
|
||||
}
|
||||
|
||||
pub(super) async fn handle_get_provider(
|
||||
State(state): State<DashboardState>,
|
||||
Path(name): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let registry = &state.app_state.model_registry;
|
||||
let config = &state.app_state.config;
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Validate provider name
|
||||
let (mut enabled, mut base_url, display_name) = match name.as_str() {
|
||||
"openai" => (
|
||||
config.providers.openai.enabled,
|
||||
config.providers.openai.base_url.clone(),
|
||||
"OpenAI",
|
||||
),
|
||||
"gemini" => (
|
||||
config.providers.gemini.enabled,
|
||||
config.providers.gemini.base_url.clone(),
|
||||
"Google Gemini",
|
||||
),
|
||||
"deepseek" => (
|
||||
config.providers.deepseek.enabled,
|
||||
config.providers.deepseek.base_url.clone(),
|
||||
"DeepSeek",
|
||||
),
|
||||
"grok" => (
|
||||
config.providers.grok.enabled,
|
||||
config.providers.grok.base_url.clone(),
|
||||
"xAI Grok",
|
||||
),
|
||||
"ollama" => (
|
||||
config.providers.ollama.enabled,
|
||||
config.providers.ollama.base_url.clone(),
|
||||
"Ollama",
|
||||
),
|
||||
_ => return Json(ApiResponse::error(format!("Unknown provider '{}'", name))),
|
||||
};
|
||||
|
||||
let mut balance = 0.0;
|
||||
let mut threshold = 5.0;
|
||||
|
||||
// Apply database overrides
|
||||
let db_config = sqlx::query(
|
||||
"SELECT enabled, base_url, credit_balance, low_credit_threshold FROM provider_configs WHERE id = ?",
|
||||
)
|
||||
.bind(&name)
|
||||
.fetch_optional(pool)
|
||||
.await;
|
||||
|
||||
if let Ok(Some(row)) = db_config {
|
||||
enabled = row.get::<bool, _>("enabled");
|
||||
if let Some(url) = row.get::<Option<String>, _>("base_url") {
|
||||
base_url = url;
|
||||
}
|
||||
balance = row.get::<f64, _>("credit_balance");
|
||||
threshold = row.get::<f64, _>("low_credit_threshold");
|
||||
}
|
||||
|
||||
// Find models for this provider
|
||||
let mut models = Vec::new();
|
||||
if let Some(p_info) = registry.providers.get(name.as_str()) {
|
||||
models = p_info.models.keys().cloned().collect();
|
||||
} else if name == "ollama" {
|
||||
models = config.providers.ollama.models.clone();
|
||||
}
|
||||
|
||||
// Determine status
|
||||
let status = if !enabled {
|
||||
"disabled"
|
||||
} else if state.app_state.provider_manager.get_provider(&name).await.is_some() {
|
||||
if state
|
||||
.app_state
|
||||
.rate_limit_manager
|
||||
.check_provider_request(&name)
|
||||
.await
|
||||
.unwrap_or(true)
|
||||
{
|
||||
"online"
|
||||
} else {
|
||||
"degraded"
|
||||
}
|
||||
} else {
|
||||
"error"
|
||||
};
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"id": name,
|
||||
"name": display_name,
|
||||
"enabled": enabled,
|
||||
"status": status,
|
||||
"models": models,
|
||||
"base_url": base_url,
|
||||
"credit_balance": balance,
|
||||
"low_credit_threshold": threshold,
|
||||
"last_used": None::<String>,
|
||||
})))
|
||||
}
|
||||
|
||||
pub(super) async fn handle_update_provider(
|
||||
State(state): State<DashboardState>,
|
||||
Path(name): Path<String>,
|
||||
Json(payload): Json<UpdateProviderRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Update or insert into database
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
enabled = excluded.enabled,
|
||||
base_url = excluded.base_url,
|
||||
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
|
||||
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
|
||||
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
"#
|
||||
)
|
||||
.bind(&name)
|
||||
.bind(name.to_uppercase())
|
||||
.bind(payload.enabled)
|
||||
.bind(&payload.base_url)
|
||||
.bind(&payload.api_key)
|
||||
.bind(payload.credit_balance)
|
||||
.bind(payload.low_credit_threshold)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => {
|
||||
// Re-initialize provider in manager
|
||||
if let Err(e) = state
|
||||
.app_state
|
||||
.provider_manager
|
||||
.initialize_provider(&name, &state.app_state.config, &state.app_state.db_pool)
|
||||
.await
|
||||
{
|
||||
warn!("Failed to re-initialize provider {}: {}", name, e);
|
||||
return Json(ApiResponse::error(format!(
|
||||
"Provider settings saved but initialization failed: {}",
|
||||
e
|
||||
)));
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(
|
||||
serde_json::json!({ "message": "Provider updated and re-initialized" }),
|
||||
))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to update provider config: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to update provider: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_test_provider(
|
||||
State(state): State<DashboardState>,
|
||||
Path(name): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let provider = match state.app_state.provider_manager.get_provider(&name).await {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
return Json(ApiResponse::error(format!(
|
||||
"Provider '{}' not found or not enabled",
|
||||
name
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
// Pick a real model for this provider from the registry
|
||||
let test_model = state
|
||||
.app_state
|
||||
.model_registry
|
||||
.providers
|
||||
.get(&name)
|
||||
.and_then(|p| p.models.keys().next().cloned())
|
||||
.unwrap_or_else(|| name.clone());
|
||||
|
||||
let test_request = crate::models::UnifiedRequest {
|
||||
client_id: "system-test".to_string(),
|
||||
model: test_model,
|
||||
messages: vec![crate::models::UnifiedMessage {
|
||||
role: "user".to_string(),
|
||||
content: vec![crate::models::ContentPart::Text { text: "Hi".to_string() }],
|
||||
}],
|
||||
temperature: None,
|
||||
max_tokens: Some(5),
|
||||
stream: false,
|
||||
has_images: false,
|
||||
};
|
||||
|
||||
match provider.chat_completion(test_request).await {
|
||||
Ok(_) => {
|
||||
let latency = start.elapsed().as_millis();
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"success": true,
|
||||
"latency": latency,
|
||||
"message": "Connection test successful"
|
||||
})))
|
||||
}
|
||||
Err(e) => Json(ApiResponse::error(format!("Provider test failed: {}", e))),
|
||||
}
|
||||
}
|
||||
64
src/dashboard/sessions.rs
Normal file
64
src/dashboard/sessions.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Session {
|
||||
pub username: String,
|
||||
pub role: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SessionManager {
|
||||
sessions: Arc<RwLock<HashMap<String, Session>>>,
|
||||
ttl_hours: i64,
|
||||
}
|
||||
|
||||
impl SessionManager {
|
||||
pub fn new(ttl_hours: i64) -> Self {
|
||||
Self {
|
||||
sessions: Arc::new(RwLock::new(HashMap::new())),
|
||||
ttl_hours,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new session and return the session token.
|
||||
pub async fn create_session(&self, username: String, role: String) -> String {
|
||||
let token = format!("session-{}", uuid::Uuid::new_v4());
|
||||
let now = Utc::now();
|
||||
let session = Session {
|
||||
username,
|
||||
role,
|
||||
created_at: now,
|
||||
expires_at: now + Duration::hours(self.ttl_hours),
|
||||
};
|
||||
self.sessions.write().await.insert(token.clone(), session);
|
||||
token
|
||||
}
|
||||
|
||||
/// Validate a session token and return the session if valid and not expired.
|
||||
pub async fn validate_session(&self, token: &str) -> Option<Session> {
|
||||
let sessions = self.sessions.read().await;
|
||||
sessions.get(token).and_then(|s| {
|
||||
if s.expires_at > Utc::now() {
|
||||
Some(s.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Revoke (delete) a session by token.
|
||||
pub async fn revoke_session(&self, token: &str) {
|
||||
self.sessions.write().await.remove(token);
|
||||
}
|
||||
|
||||
/// Remove all expired sessions from the store.
|
||||
pub async fn cleanup_expired(&self) {
|
||||
let now = Utc::now();
|
||||
self.sessions.write().await.retain(|_, s| s.expires_at > now);
|
||||
}
|
||||
}
|
||||
193
src/dashboard/system.rs
Normal file
193
src/dashboard/system.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
use axum::{extract::State, response::Json};
|
||||
use chrono;
|
||||
use serde_json;
|
||||
use sqlx::Row;
|
||||
use std::collections::HashMap;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
|
||||
pub(super) async fn handle_system_health(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let mut components = HashMap::new();
|
||||
components.insert("api_server".to_string(), "online".to_string());
|
||||
components.insert("database".to_string(), "online".to_string());
|
||||
|
||||
// Check provider health via circuit breakers
|
||||
let provider_ids: Vec<String> = state
|
||||
.app_state
|
||||
.provider_manager
|
||||
.get_all_providers()
|
||||
.await
|
||||
.iter()
|
||||
.map(|p| p.name().to_string())
|
||||
.collect();
|
||||
|
||||
for p_id in provider_ids {
|
||||
if state
|
||||
.app_state
|
||||
.rate_limit_manager
|
||||
.check_provider_request(&p_id)
|
||||
.await
|
||||
.unwrap_or(true)
|
||||
{
|
||||
components.insert(p_id, "online".to_string());
|
||||
} else {
|
||||
components.insert(p_id, "degraded".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Read real memory usage from /proc/self/status
|
||||
let memory_mb = std::fs::read_to_string("/proc/self/status")
|
||||
.ok()
|
||||
.and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string()))
|
||||
.and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::<f64>().ok()))
|
||||
.map(|kb| kb / 1024.0)
|
||||
.unwrap_or(0.0);
|
||||
|
||||
// Get real database pool stats
|
||||
let db_pool_size = state.app_state.db_pool.size();
|
||||
let db_pool_idle = state.app_state.db_pool.num_idle();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"status": "healthy",
|
||||
"timestamp": chrono::Utc::now().to_rfc3339(),
|
||||
"components": components,
|
||||
"metrics": {
|
||||
"memory_usage_mb": (memory_mb * 10.0).round() / 10.0,
|
||||
"db_connections_active": db_pool_size - db_pool_idle as u32,
|
||||
"db_connections_idle": db_pool_idle,
|
||||
}
|
||||
})))
|
||||
}
|
||||
|
||||
/// GET /system/logs — return the 100 most recent LLM request records,
/// newest first, reshaped into the dashboard's log-entry JSON schema.
pub(super) async fn handle_system_logs(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;

    // Note: prompt_tokens and completion_tokens are selected but not read
    // in the mapping below; only total_tokens is surfaced as "tokens".
    let result = sqlx::query(
        r#"
        SELECT
            id,
            timestamp,
            client_id,
            provider,
            model,
            prompt_tokens,
            completion_tokens,
            total_tokens,
            cost,
            status,
            error_message,
            duration_ms
        FROM llm_requests
        ORDER BY timestamp DESC
        LIMIT 100
        "#,
    )
    .fetch_all(pool)
    .await;

    match result {
        Ok(rows) => {
            // NOTE(review): `row.get` panics on a NULL or mistyped column.
            // This assumes total_tokens/cost/duration_ms are NOT NULL for
            // every logged row — confirm against the table schema
            // (error_message is the only column fetched as Option).
            let logs: Vec<serde_json::Value> = rows
                .into_iter()
                .map(|row| {
                    serde_json::json!({
                        "id": row.get::<i64, _>("id"),
                        "timestamp": row.get::<chrono::DateTime<chrono::Utc>, _>("timestamp"),
                        "client_id": row.get::<String, _>("client_id"),
                        "provider": row.get::<String, _>("provider"),
                        "model": row.get::<String, _>("model"),
                        "tokens": row.get::<i64, _>("total_tokens"),
                        "cost": row.get::<f64, _>("cost"),
                        "status": row.get::<String, _>("status"),
                        "error": row.get::<Option<String>, _>("error_message"),
                        "duration": row.get::<i64, _>("duration_ms"),
                    })
                })
                .collect();

            Json(ApiResponse::success(serde_json::json!(logs)))
        }
        Err(e) => {
            // Log the real error server-side; return a generic message.
            warn!("Failed to fetch system logs: {}", e);
            Json(ApiResponse::error("Failed to fetch system logs".to_string()))
        }
    }
}
|
||||
|
||||
/// POST /system/backup — snapshot the SQLite database into `data/` using
/// `VACUUM INTO`, which produces a consistent copy while the pool stays live.
pub(super) async fn handle_system_backup(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;
    // The backup name is derived solely from the current Unix timestamp,
    // so no user-controlled input ever reaches the SQL string below.
    let backup_id = format!("backup-{}", chrono::Utc::now().timestamp());
    let backup_path = format!("data/{}.db", backup_id);

    // Ensure the data directory exists.
    // NOTE(review): std::fs calls briefly block the async executor;
    // consider tokio::task::spawn_blocking if this path gets hot.
    if let Err(e) = std::fs::create_dir_all("data") {
        return Json(ApiResponse::error(format!("Failed to create backup directory: {}", e)));
    }

    // Use SQLite VACUUM INTO for a consistent backup. VACUUM INTO does not
    // accept bound parameters, hence the format! — acceptable only because
    // backup_path is server-generated (see above).
    let result = sqlx::query(&format!("VACUUM INTO '{}'", backup_path))
        .execute(pool)
        .await;

    match result {
        Ok(_) => {
            // Get backup file size; reported as 0 if the metadata read fails.
            let size_bytes = std::fs::metadata(&backup_path).map(|m| m.len()).unwrap_or(0);

            Json(ApiResponse::success(serde_json::json!({
                "success": true,
                "message": "Backup completed successfully",
                "backup_id": backup_id,
                "backup_path": backup_path,
                "size_bytes": size_bytes,
            })))
        }
        Err(e) => {
            warn!("Database backup failed: {}", e);
            // Unlike other handlers, the DB error text is returned to the
            // caller here (backup is an admin-facing operation).
            Json(ApiResponse::error(format!("Backup failed: {}", e)))
        }
    }
}
|
||||
|
||||
pub(super) async fn handle_get_settings(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let registry = &state.app_state.model_registry;
|
||||
let provider_count = registry.providers.len();
|
||||
let model_count: usize = registry.providers.values().map(|p| p.models.len()).sum();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"server": {
|
||||
"auth_tokens": state.app_state.auth_tokens.iter().map(|t| mask_token(t)).collect::<Vec<_>>(),
|
||||
"version": env!("CARGO_PKG_VERSION"),
|
||||
},
|
||||
"registry": {
|
||||
"provider_count": provider_count,
|
||||
"model_count": model_count,
|
||||
},
|
||||
"database": {
|
||||
"type": "SQLite",
|
||||
}
|
||||
})))
|
||||
}
|
||||
|
||||
pub(super) async fn handle_update_settings(
|
||||
State(_state): State<DashboardState>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
Json(ApiResponse::error(
|
||||
"Changing settings at runtime is not yet supported. Please update your config file and restart the server."
|
||||
.to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
// Helper functions
/// Mask an auth token for display, keeping only its last four characters.
///
/// Tokens of 8 characters or fewer are fully masked as `"*****"` so short
/// secrets never leak a suffix. Longer tokens render as 5–8 asterisks
/// followed by the final four characters; the asterisk run is capped so the
/// output never exceeds 12 characters and does not reveal the true length.
fn mask_token(token: &str) -> String {
    // Count and slice by *characters*, not bytes: the previous byte-indexed
    // slice (`&token[token.len() - 4..]`) panics if the cut lands inside a
    // multi-byte UTF-8 character. Behavior for ASCII tokens is unchanged.
    let char_count = token.chars().count();
    if char_count <= 8 {
        return "*****".to_string();
    }

    let visible_len = 4;
    // Cap the displayed width at 12 so very long tokens stay compact.
    let mask_len = char_count.min(12) - visible_len;

    let suffix: String = token.chars().skip(char_count - visible_len).collect();
    format!("{}{}", "*".repeat(mask_len), suffix)
}
|
||||
330
src/dashboard/usage.rs
Normal file
330
src/dashboard/usage.rs
Normal file
@@ -0,0 +1,330 @@
|
||||
use axum::{extract::State, response::Json};
|
||||
use chrono;
|
||||
use serde_json;
|
||||
use sqlx::Row;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
|
||||
pub(super) async fn handle_usage_summary(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Total stats
|
||||
let total_stats = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(total_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(cost), 0.0) as total_cost,
|
||||
COUNT(DISTINCT client_id) as active_clients
|
||||
FROM llm_requests
|
||||
"#,
|
||||
)
|
||||
.fetch_one(pool);
|
||||
|
||||
// Today's stats
|
||||
let today = chrono::Utc::now().format("%Y-%m-%d").to_string();
|
||||
let today_stats = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
COUNT(*) as today_requests,
|
||||
COALESCE(SUM(total_tokens), 0) as today_tokens,
|
||||
COALESCE(SUM(cost), 0.0) as today_cost
|
||||
FROM llm_requests
|
||||
WHERE strftime('%Y-%m-%d', timestamp) = ?
|
||||
"#,
|
||||
)
|
||||
.bind(today)
|
||||
.fetch_one(pool);
|
||||
|
||||
// Error stats
|
||||
let error_stats = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as errors
|
||||
FROM llm_requests
|
||||
"#,
|
||||
)
|
||||
.fetch_one(pool);
|
||||
|
||||
// Average response time
|
||||
let avg_response = sqlx::query(
|
||||
r#"
|
||||
SELECT COALESCE(AVG(duration_ms), 0.0) as avg_duration
|
||||
FROM llm_requests
|
||||
WHERE status = 'success'
|
||||
"#,
|
||||
)
|
||||
.fetch_one(pool);
|
||||
|
||||
match tokio::join!(total_stats, today_stats, error_stats, avg_response) {
|
||||
(Ok(t), Ok(d), Ok(e), Ok(a)) => {
|
||||
let total_requests: i64 = t.get("total_requests");
|
||||
let total_tokens: i64 = t.get("total_tokens");
|
||||
let total_cost: f64 = t.get("total_cost");
|
||||
let active_clients: i64 = t.get("active_clients");
|
||||
|
||||
let today_requests: i64 = d.get("today_requests");
|
||||
let today_cost: f64 = d.get("today_cost");
|
||||
|
||||
let total_count: i64 = e.get("total");
|
||||
let error_count: i64 = e.get("errors");
|
||||
let error_rate = if total_count > 0 {
|
||||
(error_count as f64 / total_count as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let avg_response_time: f64 = a.get("avg_duration");
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"total_requests": total_requests,
|
||||
"total_tokens": total_tokens,
|
||||
"total_cost": total_cost,
|
||||
"active_clients": active_clients,
|
||||
"today_requests": today_requests,
|
||||
"today_cost": today_cost,
|
||||
"error_rate": error_rate,
|
||||
"avg_response_time": avg_response_time,
|
||||
})))
|
||||
}
|
||||
_ => Json(ApiResponse::error("Failed to fetch usage statistics".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_time_series(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let twenty_four_hours_ago = now - chrono::Duration::hours(24);
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
strftime('%H:00', timestamp) as hour,
|
||||
COUNT(*) as requests,
|
||||
SUM(total_tokens) as tokens,
|
||||
SUM(cost) as cost
|
||||
FROM llm_requests
|
||||
WHERE timestamp >= ?
|
||||
GROUP BY hour
|
||||
ORDER BY hour
|
||||
"#,
|
||||
)
|
||||
.bind(twenty_four_hours_ago)
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let mut series = Vec::new();
|
||||
|
||||
for row in rows {
|
||||
let hour: String = row.get("hour");
|
||||
let requests: i64 = row.get("requests");
|
||||
let tokens: i64 = row.get("tokens");
|
||||
let cost: f64 = row.get("cost");
|
||||
|
||||
series.push(serde_json::json!({
|
||||
"time": hour,
|
||||
"requests": requests,
|
||||
"tokens": tokens,
|
||||
"cost": cost,
|
||||
}));
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"series": series,
|
||||
"period": "24h"
|
||||
})))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch time series data: {}", e);
|
||||
Json(ApiResponse::error("Failed to fetch time series data".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_clients_usage(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
// Query database for client usage statistics
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
client_id,
|
||||
COUNT(*) as requests,
|
||||
SUM(total_tokens) as tokens,
|
||||
SUM(cost) as cost,
|
||||
MAX(timestamp) as last_request
|
||||
FROM llm_requests
|
||||
GROUP BY client_id
|
||||
ORDER BY requests DESC
|
||||
"#,
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let mut client_usage = Vec::new();
|
||||
|
||||
for row in rows {
|
||||
let client_id: String = row.get("client_id");
|
||||
let requests: i64 = row.get("requests");
|
||||
let tokens: i64 = row.get("tokens");
|
||||
let cost: f64 = row.get("cost");
|
||||
let last_request: Option<chrono::DateTime<chrono::Utc>> = row.get("last_request");
|
||||
|
||||
client_usage.push(serde_json::json!({
|
||||
"client_id": client_id,
|
||||
"client_name": client_id,
|
||||
"requests": requests,
|
||||
"tokens": tokens,
|
||||
"cost": cost,
|
||||
"last_request": last_request,
|
||||
}));
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(client_usage)))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch client usage data: {}", e);
|
||||
Json(ApiResponse::error("Failed to fetch client usage data".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_providers_usage(
|
||||
State(state): State<DashboardState>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
// Query database for provider usage statistics
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
provider,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(total_tokens), 0) as tokens,
|
||||
COALESCE(SUM(cost), 0.0) as cost
|
||||
FROM llm_requests
|
||||
GROUP BY provider
|
||||
ORDER BY requests DESC
|
||||
"#,
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let mut provider_usage = Vec::new();
|
||||
|
||||
for row in rows {
|
||||
let provider: String = row.get("provider");
|
||||
let requests: i64 = row.get("requests");
|
||||
let tokens: i64 = row.get("tokens");
|
||||
let cost: f64 = row.get("cost");
|
||||
|
||||
provider_usage.push(serde_json::json!({
|
||||
"provider": provider,
|
||||
"requests": requests,
|
||||
"tokens": tokens,
|
||||
"cost": cost,
|
||||
}));
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(provider_usage)))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch provider usage data: {}", e);
|
||||
Json(ApiResponse::error("Failed to fetch provider usage data".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_detailed_usage(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
strftime('%Y-%m-%d', timestamp) as date,
|
||||
client_id,
|
||||
provider,
|
||||
model,
|
||||
COUNT(*) as requests,
|
||||
SUM(total_tokens) as tokens,
|
||||
SUM(cost) as cost
|
||||
FROM llm_requests
|
||||
GROUP BY date, client_id, provider, model
|
||||
ORDER BY date DESC
|
||||
LIMIT 100
|
||||
"#,
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let usage: Vec<serde_json::Value> = rows
|
||||
.into_iter()
|
||||
.map(|row| {
|
||||
serde_json::json!({
|
||||
"date": row.get::<String, _>("date"),
|
||||
"client": row.get::<String, _>("client_id"),
|
||||
"provider": row.get::<String, _>("provider"),
|
||||
"model": row.get::<String, _>("model"),
|
||||
"requests": row.get::<i64, _>("requests"),
|
||||
"tokens": row.get::<i64, _>("tokens"),
|
||||
"cost": row.get::<f64, _>("cost"),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(usage)))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch detailed usage: {}", e);
|
||||
Json(ApiResponse::error("Failed to fetch detailed usage".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_analytics_breakdown(
|
||||
State(state): State<DashboardState>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Model breakdown
|
||||
let models =
|
||||
sqlx::query("SELECT model as label, COUNT(*) as value FROM llm_requests GROUP BY model ORDER BY value DESC")
|
||||
.fetch_all(pool);
|
||||
|
||||
// Client breakdown
|
||||
let clients = sqlx::query(
|
||||
"SELECT client_id as label, COUNT(*) as value FROM llm_requests GROUP BY client_id ORDER BY value DESC",
|
||||
)
|
||||
.fetch_all(pool);
|
||||
|
||||
match tokio::join!(models, clients) {
|
||||
(Ok(m_rows), Ok(c_rows)) => {
|
||||
let model_breakdown: Vec<serde_json::Value> = m_rows
|
||||
.into_iter()
|
||||
.map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
|
||||
.collect();
|
||||
|
||||
let client_breakdown: Vec<serde_json::Value> = c_rows
|
||||
.into_iter()
|
||||
.map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
|
||||
.collect();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"models": model_breakdown,
|
||||
"clients": client_breakdown
|
||||
})))
|
||||
}
|
||||
_ => Json(ApiResponse::error("Failed to fetch analytics breakdown".to_string())),
|
||||
}
|
||||
}
|
||||
75
src/dashboard/websocket.rs
Normal file
75
src/dashboard/websocket.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
use axum::{
|
||||
extract::{
|
||||
State,
|
||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use serde_json;
|
||||
use tracing::info;
|
||||
|
||||
use super::DashboardState;
|
||||
|
||||
// WebSocket handler
|
||||
pub(super) async fn handle_websocket(ws: WebSocketUpgrade, State(state): State<DashboardState>) -> impl IntoResponse {
|
||||
ws.on_upgrade(|socket| handle_websocket_connection(socket, state))
|
||||
}
|
||||
|
||||
/// Drive one dashboard WebSocket session: greet the client, then relay
/// broadcast events from the global bus until either side disconnects.
pub(super) async fn handle_websocket_connection(mut socket: WebSocket, state: DashboardState) {
    info!("WebSocket connection established");

    // Subscribe to events from the global bus.
    let mut rx = state.app_state.dashboard_tx.subscribe();

    // Send initial connection message. A failure here is deliberately
    // ignored; a dead socket will surface as a send error in the loop.
    let _ = socket
        .send(Message::Text(
            serde_json::json!({
                "type": "connected",
                "message": "Connected to LLM Proxy Dashboard"
            })
            .to_string()
            .into(),
        ))
        .await;

    // Handle incoming messages and broadcast events.
    loop {
        tokio::select! {
            // Receive broadcast events. The Ok(..) pattern means a recv
            // error (e.g. lagged receiver) disables this branch for the
            // current select round; the next loop iteration re-subscribes
            // to the same receiver.
            Ok(event) = rx.recv() => {
                let Ok(json_str) = serde_json::to_string(&event) else {
                    continue; // Unserializable events are silently dropped.
                };
                let message = Message::Text(json_str.into());
                if socket.send(message).await.is_err() {
                    break; // Client went away.
                }
            }

            // Receive WebSocket messages; only text frames are processed.
            // Close frames, protocol errors, and end-of-stream (None) all
            // terminate the session.
            result = socket.recv() => {
                match result {
                    Some(Ok(Message::Text(text))) => {
                        handle_websocket_message(&text, &state).await;
                    }
                    _ => break,
                }
            }
        }
    }

    info!("WebSocket connection closed");
}
|
||||
|
||||
pub(super) async fn handle_websocket_message(text: &str, state: &DashboardState) {
|
||||
// Parse and handle WebSocket messages
|
||||
if let Ok(data) = serde_json::from_str::<serde_json::Value>(text)
|
||||
&& let Some("ping") = data.get("type").and_then(|v| v.as_str())
|
||||
{
|
||||
let _ = state.app_state.dashboard_tx.send(serde_json::json!({
|
||||
"type": "pong",
|
||||
"payload": {}
|
||||
}));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user