Phase 1: Fix compilation (config_path Option<PathBuf>, streaming test, stale test cleanup) Phase 2: Fix critical bugs (remove block_on deadlocks in 4 providers, fix broken SQL query builder) Phase 3: Security hardening (session manager, real auth, token masking, Gemini key to header, password policy) Phase 4: Implement stubs (real provider test, /proc health metrics, client/provider/backup endpoints, has_images) Phase 5: Code quality (shared provider helpers, explicit re-exports, all Clippy warnings fixed, unwrap removal, 6 unused deps removed, dashboard split into 7 sub-modules) Phase 6: Infrastructure (GitHub Actions CI, multi-stage Dockerfile, rustfmt.toml, clippy.toml, script fixes)
305 lines
9.2 KiB
Rust
305 lines
9.2 KiB
Rust
//! Client management for LLM proxy
|
|
//!
|
|
//! This module handles:
|
|
//! 1. Client registration and management
|
|
//! 2. Client usage tracking
|
|
//! 3. Client rate limit configuration
|
|
|
|
use anyhow::Result;
|
|
use chrono::{DateTime, Utc};
|
|
use serde::{Deserialize, Serialize};
|
|
use sqlx::{Row, SqlitePool};
|
|
use tracing::{info, warn};
|
|
|
|
/// Client information
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct Client {
|
|
pub id: i64,
|
|
pub client_id: String,
|
|
pub name: String,
|
|
pub description: String,
|
|
pub created_at: DateTime<Utc>,
|
|
pub updated_at: DateTime<Utc>,
|
|
pub is_active: bool,
|
|
pub rate_limit_per_minute: i64,
|
|
pub total_requests: i64,
|
|
pub total_tokens: i64,
|
|
pub total_cost: f64,
|
|
}
|
|
|
|
/// Client creation request
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct CreateClientRequest {
|
|
pub client_id: String,
|
|
pub name: String,
|
|
pub description: String,
|
|
pub rate_limit_per_minute: Option<i64>,
|
|
}
|
|
|
|
/// Client update request
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct UpdateClientRequest {
|
|
pub name: Option<String>,
|
|
pub description: Option<String>,
|
|
pub is_active: Option<bool>,
|
|
pub rate_limit_per_minute: Option<i64>,
|
|
}
|
|
|
|
/// Client manager for database operations
|
|
pub struct ClientManager {
|
|
db_pool: SqlitePool,
|
|
}
|
|
|
|
impl ClientManager {
|
|
pub fn new(db_pool: SqlitePool) -> Self {
|
|
Self { db_pool }
|
|
}
|
|
|
|
/// Create a new client
|
|
pub async fn create_client(&self, request: CreateClientRequest) -> Result<Client> {
|
|
let rate_limit = request.rate_limit_per_minute.unwrap_or(60);
|
|
|
|
// First insert the client
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO clients (client_id, name, description, rate_limit_per_minute)
|
|
VALUES (?, ?, ?, ?)
|
|
"#,
|
|
)
|
|
.bind(&request.client_id)
|
|
.bind(&request.name)
|
|
.bind(&request.description)
|
|
.bind(rate_limit)
|
|
.execute(&self.db_pool)
|
|
.await?;
|
|
|
|
// Then fetch the created client
|
|
let client = self
|
|
.get_client(&request.client_id)
|
|
.await?
|
|
.ok_or_else(|| anyhow::anyhow!("Failed to retrieve created client"))?;
|
|
|
|
info!("Created client: {} ({})", client.name, client.client_id);
|
|
Ok(client)
|
|
}
|
|
|
|
/// Get a client by ID
|
|
pub async fn get_client(&self, client_id: &str) -> Result<Option<Client>> {
|
|
let row = sqlx::query(
|
|
r#"
|
|
SELECT
|
|
id, client_id, name, description,
|
|
created_at, updated_at, is_active,
|
|
rate_limit_per_minute, total_requests, total_tokens, total_cost
|
|
FROM clients
|
|
WHERE client_id = ?
|
|
"#,
|
|
)
|
|
.bind(client_id)
|
|
.fetch_optional(&self.db_pool)
|
|
.await?;
|
|
|
|
if let Some(row) = row {
|
|
let client = Client {
|
|
id: row.get("id"),
|
|
client_id: row.get("client_id"),
|
|
name: row.get("name"),
|
|
description: row.get("description"),
|
|
created_at: row.get("created_at"),
|
|
updated_at: row.get("updated_at"),
|
|
is_active: row.get("is_active"),
|
|
rate_limit_per_minute: row.get("rate_limit_per_minute"),
|
|
total_requests: row.get("total_requests"),
|
|
total_tokens: row.get("total_tokens"),
|
|
total_cost: row.get("total_cost"),
|
|
};
|
|
Ok(Some(client))
|
|
} else {
|
|
Ok(None)
|
|
}
|
|
}
|
|
|
|
/// Update a client
|
|
pub async fn update_client(&self, client_id: &str, request: UpdateClientRequest) -> Result<Option<Client>> {
|
|
// First, get the current client to check if it exists
|
|
let current_client = self.get_client(client_id).await?;
|
|
if current_client.is_none() {
|
|
return Ok(None);
|
|
}
|
|
|
|
// Build update query dynamically based on provided fields
|
|
let mut query_builder = sqlx::QueryBuilder::new("UPDATE clients SET ");
|
|
let mut has_updates = false;
|
|
|
|
if let Some(name) = &request.name {
|
|
query_builder.push("name = ");
|
|
query_builder.push_bind(name);
|
|
has_updates = true;
|
|
}
|
|
|
|
if let Some(description) = &request.description {
|
|
if has_updates {
|
|
query_builder.push(", ");
|
|
}
|
|
query_builder.push("description = ");
|
|
query_builder.push_bind(description);
|
|
has_updates = true;
|
|
}
|
|
|
|
if let Some(is_active) = request.is_active {
|
|
if has_updates {
|
|
query_builder.push(", ");
|
|
}
|
|
query_builder.push("is_active = ");
|
|
query_builder.push_bind(is_active);
|
|
has_updates = true;
|
|
}
|
|
|
|
if let Some(rate_limit) = request.rate_limit_per_minute {
|
|
if has_updates {
|
|
query_builder.push(", ");
|
|
}
|
|
query_builder.push("rate_limit_per_minute = ");
|
|
query_builder.push_bind(rate_limit);
|
|
has_updates = true;
|
|
}
|
|
|
|
// Always update the updated_at timestamp
|
|
if has_updates {
|
|
query_builder.push(", ");
|
|
}
|
|
query_builder.push("updated_at = CURRENT_TIMESTAMP");
|
|
|
|
if !has_updates {
|
|
// No updates to make
|
|
return self.get_client(client_id).await;
|
|
}
|
|
|
|
query_builder.push(" WHERE client_id = ");
|
|
query_builder.push_bind(client_id);
|
|
|
|
let query = query_builder.build();
|
|
query.execute(&self.db_pool).await?;
|
|
|
|
// Fetch the updated client
|
|
let updated_client = self.get_client(client_id).await?;
|
|
|
|
if updated_client.is_some() {
|
|
info!("Updated client: {}", client_id);
|
|
}
|
|
|
|
Ok(updated_client)
|
|
}
|
|
|
|
/// List all clients
|
|
pub async fn list_clients(&self, limit: Option<i64>, offset: Option<i64>) -> Result<Vec<Client>> {
|
|
let limit = limit.unwrap_or(100);
|
|
let offset = offset.unwrap_or(0);
|
|
|
|
let rows = sqlx::query(
|
|
r#"
|
|
SELECT
|
|
id, client_id, name, description,
|
|
created_at, updated_at, is_active,
|
|
rate_limit_per_minute, total_requests, total_tokens, total_cost
|
|
FROM clients
|
|
ORDER BY created_at DESC
|
|
LIMIT ? OFFSET ?
|
|
"#,
|
|
)
|
|
.bind(limit)
|
|
.bind(offset)
|
|
.fetch_all(&self.db_pool)
|
|
.await?;
|
|
|
|
let mut clients = Vec::new();
|
|
for row in rows {
|
|
let client = Client {
|
|
id: row.get("id"),
|
|
client_id: row.get("client_id"),
|
|
name: row.get("name"),
|
|
description: row.get("description"),
|
|
created_at: row.get("created_at"),
|
|
updated_at: row.get("updated_at"),
|
|
is_active: row.get("is_active"),
|
|
rate_limit_per_minute: row.get("rate_limit_per_minute"),
|
|
total_requests: row.get("total_requests"),
|
|
total_tokens: row.get("total_tokens"),
|
|
total_cost: row.get("total_cost"),
|
|
};
|
|
clients.push(client);
|
|
}
|
|
|
|
Ok(clients)
|
|
}
|
|
|
|
/// Delete a client
|
|
pub async fn delete_client(&self, client_id: &str) -> Result<bool> {
|
|
let result = sqlx::query("DELETE FROM clients WHERE client_id = ?")
|
|
.bind(client_id)
|
|
.execute(&self.db_pool)
|
|
.await?;
|
|
|
|
let deleted = result.rows_affected() > 0;
|
|
|
|
if deleted {
|
|
info!("Deleted client: {}", client_id);
|
|
} else {
|
|
warn!("Client not found for deletion: {}", client_id);
|
|
}
|
|
|
|
Ok(deleted)
|
|
}
|
|
|
|
/// Update client usage statistics after a request
|
|
pub async fn update_client_usage(&self, client_id: &str, tokens: i64, cost: f64) -> Result<()> {
|
|
sqlx::query(
|
|
r#"
|
|
UPDATE clients
|
|
SET
|
|
total_requests = total_requests + 1,
|
|
total_tokens = total_tokens + ?,
|
|
total_cost = total_cost + ?,
|
|
updated_at = CURRENT_TIMESTAMP
|
|
WHERE client_id = ?
|
|
"#,
|
|
)
|
|
.bind(tokens)
|
|
.bind(cost)
|
|
.bind(client_id)
|
|
.execute(&self.db_pool)
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Get client usage statistics
|
|
pub async fn get_client_usage(&self, client_id: &str) -> Result<Option<(i64, i64, f64)>> {
|
|
let row = sqlx::query(
|
|
r#"
|
|
SELECT total_requests, total_tokens, total_cost
|
|
FROM clients
|
|
WHERE client_id = ?
|
|
"#,
|
|
)
|
|
.bind(client_id)
|
|
.fetch_optional(&self.db_pool)
|
|
.await?;
|
|
|
|
if let Some(row) = row {
|
|
let total_requests: i64 = row.get("total_requests");
|
|
let total_tokens: i64 = row.get("total_tokens");
|
|
let total_cost: f64 = row.get("total_cost");
|
|
Ok(Some((total_requests, total_tokens, total_cost)))
|
|
} else {
|
|
Ok(None)
|
|
}
|
|
}
|
|
|
|
/// Check if a client exists and is active
|
|
pub async fn validate_client(&self, client_id: &str) -> Result<bool> {
|
|
let client = self.get_client(client_id).await?;
|
|
Ok(client.map(|c| c.is_active).unwrap_or(false))
|
|
}
|
|
}
|