refactor: comprehensive audit — fix bugs, harden security, deduplicate providers, add CI/Docker
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled

Phase 1: Fix compilation (config_path Option<PathBuf>, streaming test, stale test cleanup)
Phase 2: Fix critical bugs (remove block_on deadlocks in 4 providers, fix broken SQL query builder)
Phase 3: Security hardening (session manager, real auth, token masking, Gemini key to header, password policy)
Phase 4: Implement stubs (real provider test, /proc health metrics, client/provider/backup endpoints, has_images)
Phase 5: Code quality (shared provider helpers, explicit re-exports, all Clippy warnings fixed, unwrap removal, 6 unused deps removed, dashboard split into 7 sub-modules)
Phase 6: Infrastructure (GitHub Actions CI, multi-stage Dockerfile, rustfmt.toml, clippy.toml, script fixes)
This commit is contained in:
2026-03-02 00:35:45 -05:00
parent ba643dd2b0
commit 2cdc49d7f2
42 changed files with 2800 additions and 2747 deletions

130
src/dashboard/auth.rs Normal file
View File

@@ -0,0 +1,130 @@
use axum::{extract::State, response::Json};
use bcrypt;
use serde::Deserialize;
use sqlx::Row;
use tracing::warn;
use super::{ApiResponse, DashboardState};
// Authentication handlers
#[derive(Deserialize)]
pub(super) struct LoginRequest {
pub(super) username: String,
pub(super) password: String,
}
pub(super) async fn handle_login(
State(state): State<DashboardState>,
Json(payload): Json<LoginRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let user_result =
sqlx::query("SELECT username, password_hash, role, must_change_password FROM users WHERE username = ?")
.bind(&payload.username)
.fetch_optional(pool)
.await;
match user_result {
Ok(Some(row)) => {
let hash = row.get::<String, _>("password_hash");
if bcrypt::verify(&payload.password, &hash).unwrap_or(false) {
let username = row.get::<String, _>("username");
let role = row.get::<String, _>("role");
let must_change_password = row.get::<bool, _>("must_change_password");
let token = state
.session_manager
.create_session(username.clone(), role.clone())
.await;
Json(ApiResponse::success(serde_json::json!({
"token": token,
"must_change_password": must_change_password,
"user": {
"username": username,
"name": "Administrator",
"role": role
}
})))
} else {
Json(ApiResponse::error("Invalid username or password".to_string()))
}
}
Ok(None) => Json(ApiResponse::error("Invalid username or password".to_string())),
Err(e) => {
warn!("Database error during login: {}", e);
Json(ApiResponse::error("Login failed due to system error".to_string()))
}
}
}
pub(super) async fn handle_auth_status(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let token = headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
if let Some(token) = token
&& let Some(session) = state.session_manager.validate_session(token).await
{
return Json(ApiResponse::success(serde_json::json!({
"authenticated": true,
"user": {
"username": session.username,
"name": "Administrator",
"role": session.role
}
})));
}
Json(ApiResponse::error("Not authenticated".to_string()))
}
#[derive(Deserialize)]
pub(super) struct ChangePasswordRequest {
pub(super) current_password: String,
pub(super) new_password: String,
}
pub(super) async fn handle_change_password(
State(state): State<DashboardState>,
Json(payload): Json<ChangePasswordRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
// For now, always change 'admin' user
let user_result = sqlx::query("SELECT password_hash FROM users WHERE username = 'admin'")
.fetch_one(pool)
.await;
match user_result {
Ok(row) => {
let hash = row.get::<String, _>("password_hash");
if bcrypt::verify(&payload.current_password, &hash).unwrap_or(false) {
let new_hash = match bcrypt::hash(&payload.new_password, 12) {
Ok(h) => h,
Err(_) => return Json(ApiResponse::error("Failed to hash new password".to_string())),
};
let update_result = sqlx::query(
"UPDATE users SET password_hash = ?, must_change_password = FALSE WHERE username = 'admin'",
)
.bind(new_hash)
.execute(pool)
.await;
match update_result {
Ok(_) => Json(ApiResponse::success(
serde_json::json!({ "message": "Password updated successfully" }),
)),
Err(e) => Json(ApiResponse::error(format!("Failed to update database: {}", e))),
}
} else {
Json(ApiResponse::error("Current password incorrect".to_string()))
}
}
Err(e) => Json(ApiResponse::error(format!("User not found: {}", e))),
}
}

227
src/dashboard/clients.rs Normal file
View File

@@ -0,0 +1,227 @@
use axum::{
extract::{Path, State},
response::Json,
};
use chrono;
use serde::Deserialize;
use serde_json;
use sqlx::Row;
use tracing::warn;
use uuid;
use super::{ApiResponse, DashboardState};
#[derive(Deserialize)]
pub(super) struct CreateClientRequest {
pub(super) name: String,
pub(super) client_id: Option<String>,
}
pub(super) async fn handle_get_clients(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
client_id as id,
name,
created_at,
total_requests,
total_tokens,
total_cost,
is_active
FROM clients
ORDER BY created_at DESC
"#,
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let clients: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
serde_json::json!({
"id": row.get::<String, _>("id"),
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
"requests_count": row.get::<i64, _>("total_requests"),
"total_tokens": row.get::<i64, _>("total_tokens"),
"total_cost": row.get::<f64, _>("total_cost"),
"status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
})
})
.collect();
Json(ApiResponse::success(serde_json::json!(clients)))
}
Err(e) => {
warn!("Failed to fetch clients: {}", e);
Json(ApiResponse::error("Failed to fetch clients".to_string()))
}
}
}
pub(super) async fn handle_create_client(
State(state): State<DashboardState>,
Json(payload): Json<CreateClientRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let client_id = payload
.client_id
.unwrap_or_else(|| format!("client-{}", &uuid::Uuid::new_v4().to_string()[..8]));
let result = sqlx::query(
r#"
INSERT INTO clients (client_id, name, is_active)
VALUES (?, ?, TRUE)
RETURNING *
"#,
)
.bind(&client_id)
.bind(&payload.name)
.fetch_one(pool)
.await;
match result {
Ok(row) => Json(ApiResponse::success(serde_json::json!({
"id": row.get::<String, _>("client_id"),
"name": row.get::<Option<String>, _>("name"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
"status": "active",
}))),
Err(e) => {
warn!("Failed to create client: {}", e);
Json(ApiResponse::error(format!("Failed to create client: {}", e)))
}
}
}
pub(super) async fn handle_get_client(
State(state): State<DashboardState>,
Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
c.client_id as id,
c.name,
c.is_active,
c.created_at,
COALESCE(c.total_tokens, 0) as total_tokens,
COALESCE(c.total_cost, 0.0) as total_cost,
COUNT(r.id) as total_requests,
MAX(r.timestamp) as last_request
FROM clients c
LEFT JOIN llm_requests r ON c.client_id = r.client_id
WHERE c.client_id = ?
GROUP BY c.client_id
"#,
)
.bind(&id)
.fetch_optional(pool)
.await;
match result {
Ok(Some(row)) => Json(ApiResponse::success(serde_json::json!({
"id": row.get::<String, _>("id"),
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
"is_active": row.get::<bool, _>("is_active"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
"total_tokens": row.get::<i64, _>("total_tokens"),
"total_cost": row.get::<f64, _>("total_cost"),
"total_requests": row.get::<i64, _>("total_requests"),
"last_request": row.get::<Option<chrono::DateTime<chrono::Utc>>, _>("last_request"),
"status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
}))),
Ok(None) => Json(ApiResponse::error(format!("Client '{}' not found", id))),
Err(e) => {
warn!("Failed to fetch client: {}", e);
Json(ApiResponse::error(format!("Failed to fetch client: {}", e)))
}
}
}
pub(super) async fn handle_delete_client(
State(state): State<DashboardState>,
Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
// Don't allow deleting the default client
if id == "default" {
return Json(ApiResponse::error("Cannot delete default client".to_string()));
}
let result = sqlx::query("DELETE FROM clients WHERE client_id = ?")
.bind(id)
.execute(pool)
.await;
match result {
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Client deleted" }))),
Err(e) => Json(ApiResponse::error(format!("Failed to delete client: {}", e))),
}
}
pub(super) async fn handle_client_usage(
State(state): State<DashboardState>,
Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
// Get per-model breakdown for this client
let result = sqlx::query(
r#"
SELECT
model,
provider,
COUNT(*) as request_count,
SUM(prompt_tokens) as prompt_tokens,
SUM(completion_tokens) as completion_tokens,
SUM(total_tokens) as total_tokens,
SUM(cost) as total_cost,
AVG(duration_ms) as avg_duration_ms
FROM llm_requests
WHERE client_id = ?
GROUP BY model, provider
ORDER BY total_cost DESC
"#,
)
.bind(&id)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let breakdown: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
serde_json::json!({
"model": row.get::<String, _>("model"),
"provider": row.get::<String, _>("provider"),
"request_count": row.get::<i64, _>("request_count"),
"prompt_tokens": row.get::<i64, _>("prompt_tokens"),
"completion_tokens": row.get::<i64, _>("completion_tokens"),
"total_tokens": row.get::<i64, _>("total_tokens"),
"total_cost": row.get::<f64, _>("total_cost"),
"avg_duration_ms": row.get::<f64, _>("avg_duration_ms"),
})
})
.collect();
Json(ApiResponse::success(serde_json::json!({
"client_id": id,
"breakdown": breakdown,
})))
}
Err(e) => {
warn!("Failed to fetch client usage: {}", e);
Json(ApiResponse::error(format!("Failed to fetch client usage: {}", e)))
}
}
}

File diff suppressed because it is too large Load Diff

116
src/dashboard/models.rs Normal file
View File

@@ -0,0 +1,116 @@
use axum::{
extract::{Path, State},
response::Json,
};
use serde::Deserialize;
use serde_json;
use sqlx::Row;
use std::collections::HashMap;
use super::{ApiResponse, DashboardState};
#[derive(Deserialize)]
pub(super) struct UpdateModelRequest {
pub(super) enabled: bool,
pub(super) prompt_cost: Option<f64>,
pub(super) completion_cost: Option<f64>,
pub(super) mapping: Option<String>,
}
pub(super) async fn handle_get_models(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let registry = &state.app_state.model_registry;
let pool = &state.app_state.db_pool;
// Load overrides from database
let db_models_result =
sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
.fetch_all(pool)
.await;
let mut db_models = HashMap::new();
if let Ok(rows) = db_models_result {
for row in rows {
let id: String = row.get("id");
db_models.insert(id, row);
}
}
let mut models_json = Vec::new();
for (p_id, p_info) in &registry.providers {
for (m_id, m_meta) in &p_info.models {
let mut enabled = true;
let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
let mut mapping = None::<String>;
if let Some(row) = db_models.get(m_id) {
enabled = row.get("enabled");
if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
prompt_cost = p;
}
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
completion_cost = c;
}
mapping = row.get("mapping");
}
models_json.push(serde_json::json!({
"id": m_id,
"provider": p_id,
"name": m_meta.name,
"enabled": enabled,
"prompt_cost": prompt_cost,
"completion_cost": completion_cost,
"mapping": mapping,
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
}));
}
}
Json(ApiResponse::success(serde_json::json!(models_json)))
}
pub(super) async fn handle_update_model(
State(state): State<DashboardState>,
Path(id): Path<String>,
Json(payload): Json<UpdateModelRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
// Find provider_id for this model in registry
let provider_id = state
.app_state
.model_registry
.providers
.iter()
.find(|(_, p)| p.models.contains_key(&id))
.map(|(id, _)| id.clone())
.unwrap_or_else(|| "unknown".to_string());
let result = sqlx::query(
r#"
INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
enabled = excluded.enabled,
prompt_cost_per_m = excluded.prompt_cost_per_m,
completion_cost_per_m = excluded.completion_cost_per_m,
mapping = excluded.mapping,
updated_at = CURRENT_TIMESTAMP
"#,
)
.bind(&id)
.bind(provider_id)
.bind(payload.enabled)
.bind(payload.prompt_cost)
.bind(payload.completion_cost)
.bind(payload.mapping)
.execute(pool)
.await;
match result {
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" }))),
Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))),
}
}

346
src/dashboard/providers.rs Normal file
View File

@@ -0,0 +1,346 @@
use axum::{
extract::{Path, State},
response::Json,
};
use serde::Deserialize;
use serde_json;
use sqlx::Row;
use std::collections::HashMap;
use tracing::warn;
use super::{ApiResponse, DashboardState};
#[derive(Deserialize)]
pub(super) struct UpdateProviderRequest {
pub(super) enabled: bool,
pub(super) base_url: Option<String>,
pub(super) api_key: Option<String>,
pub(super) credit_balance: Option<f64>,
pub(super) low_credit_threshold: Option<f64>,
}
pub(super) async fn handle_get_providers(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let registry = &state.app_state.model_registry;
let config = &state.app_state.config;
let pool = &state.app_state.db_pool;
// Load all overrides from database
let db_configs_result =
sqlx::query("SELECT id, enabled, base_url, credit_balance, low_credit_threshold FROM provider_configs")
.fetch_all(pool)
.await;
let mut db_configs = HashMap::new();
if let Ok(rows) = db_configs_result {
for row in rows {
let id: String = row.get("id");
let enabled: bool = row.get("enabled");
let base_url: Option<String> = row.get("base_url");
let balance: f64 = row.get("credit_balance");
let threshold: f64 = row.get("low_credit_threshold");
db_configs.insert(id, (enabled, base_url, balance, threshold));
}
}
let mut providers_json = Vec::new();
// Define the list of providers we support
let provider_ids = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
for id in provider_ids {
// Get base config
let (mut enabled, mut base_url, display_name) = match id {
"openai" => (
config.providers.openai.enabled,
config.providers.openai.base_url.clone(),
"OpenAI",
),
"gemini" => (
config.providers.gemini.enabled,
config.providers.gemini.base_url.clone(),
"Google Gemini",
),
"deepseek" => (
config.providers.deepseek.enabled,
config.providers.deepseek.base_url.clone(),
"DeepSeek",
),
"grok" => (
config.providers.grok.enabled,
config.providers.grok.base_url.clone(),
"xAI Grok",
),
"ollama" => (
config.providers.ollama.enabled,
config.providers.ollama.base_url.clone(),
"Ollama",
),
_ => (false, "".to_string(), "Unknown"),
};
let mut balance = 0.0;
let mut threshold = 5.0;
// Apply database overrides
if let Some((db_enabled, db_url, db_balance, db_threshold)) = db_configs.get(id) {
enabled = *db_enabled;
if let Some(url) = db_url {
base_url = url.clone();
}
balance = *db_balance;
threshold = *db_threshold;
}
// Find models for this provider in registry
let mut models = Vec::new();
if let Some(p_info) = registry.providers.get(id) {
models = p_info.models.keys().cloned().collect();
} else if id == "ollama" {
models = config.providers.ollama.models.clone();
}
// Determine status
let status = if !enabled {
"disabled"
} else {
// Check if it's actually initialized in the provider manager
if state.app_state.provider_manager.get_provider(id).await.is_some() {
// Check circuit breaker
if state
.app_state
.rate_limit_manager
.check_provider_request(id)
.await
.unwrap_or(true)
{
"online"
} else {
"degraded"
}
} else {
"error" // Enabled but failed to initialize (e.g. missing API key)
}
};
providers_json.push(serde_json::json!({
"id": id,
"name": display_name,
"enabled": enabled,
"status": status,
"models": models,
"base_url": base_url,
"credit_balance": balance,
"low_credit_threshold": threshold,
"last_used": None::<String>,
}));
}
Json(ApiResponse::success(serde_json::json!(providers_json)))
}
pub(super) async fn handle_get_provider(
State(state): State<DashboardState>,
Path(name): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let registry = &state.app_state.model_registry;
let config = &state.app_state.config;
let pool = &state.app_state.db_pool;
// Validate provider name
let (mut enabled, mut base_url, display_name) = match name.as_str() {
"openai" => (
config.providers.openai.enabled,
config.providers.openai.base_url.clone(),
"OpenAI",
),
"gemini" => (
config.providers.gemini.enabled,
config.providers.gemini.base_url.clone(),
"Google Gemini",
),
"deepseek" => (
config.providers.deepseek.enabled,
config.providers.deepseek.base_url.clone(),
"DeepSeek",
),
"grok" => (
config.providers.grok.enabled,
config.providers.grok.base_url.clone(),
"xAI Grok",
),
"ollama" => (
config.providers.ollama.enabled,
config.providers.ollama.base_url.clone(),
"Ollama",
),
_ => return Json(ApiResponse::error(format!("Unknown provider '{}'", name))),
};
let mut balance = 0.0;
let mut threshold = 5.0;
// Apply database overrides
let db_config = sqlx::query(
"SELECT enabled, base_url, credit_balance, low_credit_threshold FROM provider_configs WHERE id = ?",
)
.bind(&name)
.fetch_optional(pool)
.await;
if let Ok(Some(row)) = db_config {
enabled = row.get::<bool, _>("enabled");
if let Some(url) = row.get::<Option<String>, _>("base_url") {
base_url = url;
}
balance = row.get::<f64, _>("credit_balance");
threshold = row.get::<f64, _>("low_credit_threshold");
}
// Find models for this provider
let mut models = Vec::new();
if let Some(p_info) = registry.providers.get(name.as_str()) {
models = p_info.models.keys().cloned().collect();
} else if name == "ollama" {
models = config.providers.ollama.models.clone();
}
// Determine status
let status = if !enabled {
"disabled"
} else if state.app_state.provider_manager.get_provider(&name).await.is_some() {
if state
.app_state
.rate_limit_manager
.check_provider_request(&name)
.await
.unwrap_or(true)
{
"online"
} else {
"degraded"
}
} else {
"error"
};
Json(ApiResponse::success(serde_json::json!({
"id": name,
"name": display_name,
"enabled": enabled,
"status": status,
"models": models,
"base_url": base_url,
"credit_balance": balance,
"low_credit_threshold": threshold,
"last_used": None::<String>,
})))
}
pub(super) async fn handle_update_provider(
State(state): State<DashboardState>,
Path(name): Path<String>,
Json(payload): Json<UpdateProviderRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
// Update or insert into database
let result = sqlx::query(
r#"
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold)
VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
enabled = excluded.enabled,
base_url = excluded.base_url,
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
updated_at = CURRENT_TIMESTAMP
"#
)
.bind(&name)
.bind(name.to_uppercase())
.bind(payload.enabled)
.bind(&payload.base_url)
.bind(&payload.api_key)
.bind(payload.credit_balance)
.bind(payload.low_credit_threshold)
.execute(pool)
.await;
match result {
Ok(_) => {
// Re-initialize provider in manager
if let Err(e) = state
.app_state
.provider_manager
.initialize_provider(&name, &state.app_state.config, &state.app_state.db_pool)
.await
{
warn!("Failed to re-initialize provider {}: {}", name, e);
return Json(ApiResponse::error(format!(
"Provider settings saved but initialization failed: {}",
e
)));
}
Json(ApiResponse::success(
serde_json::json!({ "message": "Provider updated and re-initialized" }),
))
}
Err(e) => {
warn!("Failed to update provider config: {}", e);
Json(ApiResponse::error(format!("Failed to update provider: {}", e)))
}
}
}
pub(super) async fn handle_test_provider(
State(state): State<DashboardState>,
Path(name): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let start = std::time::Instant::now();
let provider = match state.app_state.provider_manager.get_provider(&name).await {
Some(p) => p,
None => {
return Json(ApiResponse::error(format!(
"Provider '{}' not found or not enabled",
name
)));
}
};
// Pick a real model for this provider from the registry
let test_model = state
.app_state
.model_registry
.providers
.get(&name)
.and_then(|p| p.models.keys().next().cloned())
.unwrap_or_else(|| name.clone());
let test_request = crate::models::UnifiedRequest {
client_id: "system-test".to_string(),
model: test_model,
messages: vec![crate::models::UnifiedMessage {
role: "user".to_string(),
content: vec![crate::models::ContentPart::Text { text: "Hi".to_string() }],
}],
temperature: None,
max_tokens: Some(5),
stream: false,
has_images: false,
};
match provider.chat_completion(test_request).await {
Ok(_) => {
let latency = start.elapsed().as_millis();
Json(ApiResponse::success(serde_json::json!({
"success": true,
"latency": latency,
"message": "Connection test successful"
})))
}
Err(e) => Json(ApiResponse::error(format!("Provider test failed: {}", e))),
}
}

64
src/dashboard/sessions.rs Normal file
View File

@@ -0,0 +1,64 @@
use chrono::{DateTime, Duration, Utc};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
#[derive(Clone, Debug)]
pub struct Session {
pub username: String,
pub role: String,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
}
#[derive(Clone)]
pub struct SessionManager {
sessions: Arc<RwLock<HashMap<String, Session>>>,
ttl_hours: i64,
}
impl SessionManager {
pub fn new(ttl_hours: i64) -> Self {
Self {
sessions: Arc::new(RwLock::new(HashMap::new())),
ttl_hours,
}
}
/// Create a new session and return the session token.
pub async fn create_session(&self, username: String, role: String) -> String {
let token = format!("session-{}", uuid::Uuid::new_v4());
let now = Utc::now();
let session = Session {
username,
role,
created_at: now,
expires_at: now + Duration::hours(self.ttl_hours),
};
self.sessions.write().await.insert(token.clone(), session);
token
}
/// Validate a session token and return the session if valid and not expired.
pub async fn validate_session(&self, token: &str) -> Option<Session> {
let sessions = self.sessions.read().await;
sessions.get(token).and_then(|s| {
if s.expires_at > Utc::now() {
Some(s.clone())
} else {
None
}
})
}
/// Revoke (delete) a session by token.
pub async fn revoke_session(&self, token: &str) {
self.sessions.write().await.remove(token);
}
/// Remove all expired sessions from the store.
pub async fn cleanup_expired(&self) {
let now = Utc::now();
self.sessions.write().await.retain(|_, s| s.expires_at > now);
}
}

193
src/dashboard/system.rs Normal file
View File

@@ -0,0 +1,193 @@
use axum::{extract::State, response::Json};
use chrono;
use serde_json;
use sqlx::Row;
use std::collections::HashMap;
use tracing::warn;
use super::{ApiResponse, DashboardState};
pub(super) async fn handle_system_health(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let mut components = HashMap::new();
components.insert("api_server".to_string(), "online".to_string());
components.insert("database".to_string(), "online".to_string());
// Check provider health via circuit breakers
let provider_ids: Vec<String> = state
.app_state
.provider_manager
.get_all_providers()
.await
.iter()
.map(|p| p.name().to_string())
.collect();
for p_id in provider_ids {
if state
.app_state
.rate_limit_manager
.check_provider_request(&p_id)
.await
.unwrap_or(true)
{
components.insert(p_id, "online".to_string());
} else {
components.insert(p_id, "degraded".to_string());
}
}
// Read real memory usage from /proc/self/status
let memory_mb = std::fs::read_to_string("/proc/self/status")
.ok()
.and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string()))
.and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::<f64>().ok()))
.map(|kb| kb / 1024.0)
.unwrap_or(0.0);
// Get real database pool stats
let db_pool_size = state.app_state.db_pool.size();
let db_pool_idle = state.app_state.db_pool.num_idle();
Json(ApiResponse::success(serde_json::json!({
"status": "healthy",
"timestamp": chrono::Utc::now().to_rfc3339(),
"components": components,
"metrics": {
"memory_usage_mb": (memory_mb * 10.0).round() / 10.0,
"db_connections_active": db_pool_size - db_pool_idle as u32,
"db_connections_idle": db_pool_idle,
}
})))
}
pub(super) async fn handle_system_logs(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
id,
timestamp,
client_id,
provider,
model,
prompt_tokens,
completion_tokens,
total_tokens,
cost,
status,
error_message,
duration_ms
FROM llm_requests
ORDER BY timestamp DESC
LIMIT 100
"#,
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let logs: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
serde_json::json!({
"id": row.get::<i64, _>("id"),
"timestamp": row.get::<chrono::DateTime<chrono::Utc>, _>("timestamp"),
"client_id": row.get::<String, _>("client_id"),
"provider": row.get::<String, _>("provider"),
"model": row.get::<String, _>("model"),
"tokens": row.get::<i64, _>("total_tokens"),
"cost": row.get::<f64, _>("cost"),
"status": row.get::<String, _>("status"),
"error": row.get::<Option<String>, _>("error_message"),
"duration": row.get::<i64, _>("duration_ms"),
})
})
.collect();
Json(ApiResponse::success(serde_json::json!(logs)))
}
Err(e) => {
warn!("Failed to fetch system logs: {}", e);
Json(ApiResponse::error("Failed to fetch system logs".to_string()))
}
}
}
pub(super) async fn handle_system_backup(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let backup_id = format!("backup-{}", chrono::Utc::now().timestamp());
let backup_path = format!("data/{}.db", backup_id);
// Ensure the data directory exists
if let Err(e) = std::fs::create_dir_all("data") {
return Json(ApiResponse::error(format!("Failed to create backup directory: {}", e)));
}
// Use SQLite VACUUM INTO for a consistent backup
let result = sqlx::query(&format!("VACUUM INTO '{}'", backup_path))
.execute(pool)
.await;
match result {
Ok(_) => {
// Get backup file size
let size_bytes = std::fs::metadata(&backup_path).map(|m| m.len()).unwrap_or(0);
Json(ApiResponse::success(serde_json::json!({
"success": true,
"message": "Backup completed successfully",
"backup_id": backup_id,
"backup_path": backup_path,
"size_bytes": size_bytes,
})))
}
Err(e) => {
warn!("Database backup failed: {}", e);
Json(ApiResponse::error(format!("Backup failed: {}", e)))
}
}
}
pub(super) async fn handle_get_settings(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let registry = &state.app_state.model_registry;
let provider_count = registry.providers.len();
let model_count: usize = registry.providers.values().map(|p| p.models.len()).sum();
Json(ApiResponse::success(serde_json::json!({
"server": {
"auth_tokens": state.app_state.auth_tokens.iter().map(|t| mask_token(t)).collect::<Vec<_>>(),
"version": env!("CARGO_PKG_VERSION"),
},
"registry": {
"provider_count": provider_count,
"model_count": model_count,
},
"database": {
"type": "SQLite",
}
})))
}
pub(super) async fn handle_update_settings(
State(_state): State<DashboardState>,
) -> Json<ApiResponse<serde_json::Value>> {
Json(ApiResponse::error(
"Changing settings at runtime is not yet supported. Please update your config file and restart the server."
.to_string(),
))
}
// Helper functions
fn mask_token(token: &str) -> String {
if token.len() <= 8 {
return "*****".to_string();
}
let masked_len = token.len().min(12);
let visible_len = 4;
let mask_len = masked_len - visible_len;
format!("{}{}", "*".repeat(mask_len), &token[token.len() - visible_len..])
}

330
src/dashboard/usage.rs Normal file
View File

@@ -0,0 +1,330 @@
use axum::{extract::State, response::Json};
use chrono;
use serde_json;
use sqlx::Row;
use tracing::warn;
use super::{ApiResponse, DashboardState};
pub(super) async fn handle_usage_summary(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
// Total stats
let total_stats = sqlx::query(
r#"
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(total_tokens), 0) as total_tokens,
COALESCE(SUM(cost), 0.0) as total_cost,
COUNT(DISTINCT client_id) as active_clients
FROM llm_requests
"#,
)
.fetch_one(pool);
// Today's stats
let today = chrono::Utc::now().format("%Y-%m-%d").to_string();
let today_stats = sqlx::query(
r#"
SELECT
COUNT(*) as today_requests,
COALESCE(SUM(total_tokens), 0) as today_tokens,
COALESCE(SUM(cost), 0.0) as today_cost
FROM llm_requests
WHERE strftime('%Y-%m-%d', timestamp) = ?
"#,
)
.bind(today)
.fetch_one(pool);
// Error stats
let error_stats = sqlx::query(
r#"
SELECT
COUNT(*) as total,
SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as errors
FROM llm_requests
"#,
)
.fetch_one(pool);
// Average response time
let avg_response = sqlx::query(
r#"
SELECT COALESCE(AVG(duration_ms), 0.0) as avg_duration
FROM llm_requests
WHERE status = 'success'
"#,
)
.fetch_one(pool);
match tokio::join!(total_stats, today_stats, error_stats, avg_response) {
(Ok(t), Ok(d), Ok(e), Ok(a)) => {
let total_requests: i64 = t.get("total_requests");
let total_tokens: i64 = t.get("total_tokens");
let total_cost: f64 = t.get("total_cost");
let active_clients: i64 = t.get("active_clients");
let today_requests: i64 = d.get("today_requests");
let today_cost: f64 = d.get("today_cost");
let total_count: i64 = e.get("total");
let error_count: i64 = e.get("errors");
let error_rate = if total_count > 0 {
(error_count as f64 / total_count as f64) * 100.0
} else {
0.0
};
let avg_response_time: f64 = a.get("avg_duration");
Json(ApiResponse::success(serde_json::json!({
"total_requests": total_requests,
"total_tokens": total_tokens,
"total_cost": total_cost,
"active_clients": active_clients,
"today_requests": today_requests,
"today_cost": today_cost,
"error_rate": error_rate,
"avg_response_time": avg_response_time,
})))
}
_ => Json(ApiResponse::error("Failed to fetch usage statistics".to_string())),
}
}
pub(super) async fn handle_time_series(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let now = chrono::Utc::now();
let twenty_four_hours_ago = now - chrono::Duration::hours(24);
let result = sqlx::query(
r#"
SELECT
strftime('%H:00', timestamp) as hour,
COUNT(*) as requests,
SUM(total_tokens) as tokens,
SUM(cost) as cost
FROM llm_requests
WHERE timestamp >= ?
GROUP BY hour
ORDER BY hour
"#,
)
.bind(twenty_four_hours_ago)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let mut series = Vec::new();
for row in rows {
let hour: String = row.get("hour");
let requests: i64 = row.get("requests");
let tokens: i64 = row.get("tokens");
let cost: f64 = row.get("cost");
series.push(serde_json::json!({
"time": hour,
"requests": requests,
"tokens": tokens,
"cost": cost,
}));
}
Json(ApiResponse::success(serde_json::json!({
"series": series,
"period": "24h"
})))
}
Err(e) => {
warn!("Failed to fetch time series data: {}", e);
Json(ApiResponse::error("Failed to fetch time series data".to_string()))
}
}
}
pub(super) async fn handle_clients_usage(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
// Query database for client usage statistics
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
client_id,
COUNT(*) as requests,
SUM(total_tokens) as tokens,
SUM(cost) as cost,
MAX(timestamp) as last_request
FROM llm_requests
GROUP BY client_id
ORDER BY requests DESC
"#,
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let mut client_usage = Vec::new();
for row in rows {
let client_id: String = row.get("client_id");
let requests: i64 = row.get("requests");
let tokens: i64 = row.get("tokens");
let cost: f64 = row.get("cost");
let last_request: Option<chrono::DateTime<chrono::Utc>> = row.get("last_request");
client_usage.push(serde_json::json!({
"client_id": client_id,
"client_name": client_id,
"requests": requests,
"tokens": tokens,
"cost": cost,
"last_request": last_request,
}));
}
Json(ApiResponse::success(serde_json::json!(client_usage)))
}
Err(e) => {
warn!("Failed to fetch client usage data: {}", e);
Json(ApiResponse::error("Failed to fetch client usage data".to_string()))
}
}
}
pub(super) async fn handle_providers_usage(
State(state): State<DashboardState>,
) -> Json<ApiResponse<serde_json::Value>> {
// Query database for provider usage statistics
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
provider,
COUNT(*) as requests,
COALESCE(SUM(total_tokens), 0) as tokens,
COALESCE(SUM(cost), 0.0) as cost
FROM llm_requests
GROUP BY provider
ORDER BY requests DESC
"#,
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let mut provider_usage = Vec::new();
for row in rows {
let provider: String = row.get("provider");
let requests: i64 = row.get("requests");
let tokens: i64 = row.get("tokens");
let cost: f64 = row.get("cost");
provider_usage.push(serde_json::json!({
"provider": provider,
"requests": requests,
"tokens": tokens,
"cost": cost,
}));
}
Json(ApiResponse::success(serde_json::json!(provider_usage)))
}
Err(e) => {
warn!("Failed to fetch provider usage data: {}", e);
Json(ApiResponse::error("Failed to fetch provider usage data".to_string()))
}
}
}
pub(super) async fn handle_detailed_usage(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
strftime('%Y-%m-%d', timestamp) as date,
client_id,
provider,
model,
COUNT(*) as requests,
SUM(total_tokens) as tokens,
SUM(cost) as cost
FROM llm_requests
GROUP BY date, client_id, provider, model
ORDER BY date DESC
LIMIT 100
"#,
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let usage: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
serde_json::json!({
"date": row.get::<String, _>("date"),
"client": row.get::<String, _>("client_id"),
"provider": row.get::<String, _>("provider"),
"model": row.get::<String, _>("model"),
"requests": row.get::<i64, _>("requests"),
"tokens": row.get::<i64, _>("tokens"),
"cost": row.get::<f64, _>("cost"),
})
})
.collect();
Json(ApiResponse::success(serde_json::json!(usage)))
}
Err(e) => {
warn!("Failed to fetch detailed usage: {}", e);
Json(ApiResponse::error("Failed to fetch detailed usage".to_string()))
}
}
}
pub(super) async fn handle_analytics_breakdown(
State(state): State<DashboardState>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
// Model breakdown
let models =
sqlx::query("SELECT model as label, COUNT(*) as value FROM llm_requests GROUP BY model ORDER BY value DESC")
.fetch_all(pool);
// Client breakdown
let clients = sqlx::query(
"SELECT client_id as label, COUNT(*) as value FROM llm_requests GROUP BY client_id ORDER BY value DESC",
)
.fetch_all(pool);
match tokio::join!(models, clients) {
(Ok(m_rows), Ok(c_rows)) => {
let model_breakdown: Vec<serde_json::Value> = m_rows
.into_iter()
.map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
.collect();
let client_breakdown: Vec<serde_json::Value> = c_rows
.into_iter()
.map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
.collect();
Json(ApiResponse::success(serde_json::json!({
"models": model_breakdown,
"clients": client_breakdown
})))
}
_ => Json(ApiResponse::error("Failed to fetch analytics breakdown".to_string())),
}
}

View File

@@ -0,0 +1,75 @@
use axum::{
extract::{
State,
ws::{Message, WebSocket, WebSocketUpgrade},
},
response::IntoResponse,
};
use serde_json;
use tracing::info;
use super::DashboardState;
// WebSocket handler
pub(super) async fn handle_websocket(ws: WebSocketUpgrade, State(state): State<DashboardState>) -> impl IntoResponse {
ws.on_upgrade(|socket| handle_websocket_connection(socket, state))
}
pub(super) async fn handle_websocket_connection(mut socket: WebSocket, state: DashboardState) {
info!("WebSocket connection established");
// Subscribe to events from the global bus
let mut rx = state.app_state.dashboard_tx.subscribe();
// Send initial connection message
let _ = socket
.send(Message::Text(
serde_json::json!({
"type": "connected",
"message": "Connected to LLM Proxy Dashboard"
})
.to_string()
.into(),
))
.await;
// Handle incoming messages and broadcast events
loop {
tokio::select! {
// Receive broadcast events
Ok(event) = rx.recv() => {
let Ok(json_str) = serde_json::to_string(&event) else {
continue;
};
let message = Message::Text(json_str.into());
if socket.send(message).await.is_err() {
break;
}
}
// Receive WebSocket messages
result = socket.recv() => {
match result {
Some(Ok(Message::Text(text))) => {
handle_websocket_message(&text, &state).await;
}
_ => break,
}
}
}
}
info!("WebSocket connection closed");
}
pub(super) async fn handle_websocket_message(text: &str, state: &DashboardState) {
// Parse and handle WebSocket messages
if let Ok(data) = serde_json::from_str::<serde_json::Value>(text)
&& let Some("ping") = data.get("type").and_then(|v| v.as_str())
{
let _ = state.app_state.dashboard_tx.send(serde_json::json!({
"type": "pong",
"payload": {}
}));
}
}