//! Client management for LLM proxy //! //! This module handles: //! 1. Client registration and management //! 2. Client usage tracking //! 3. Client rate limit configuration use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use sqlx::{SqlitePool, Row}; use anyhow::Result; use tracing::{info, warn}; /// Client information #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Client { pub id: i64, pub client_id: String, pub name: String, pub description: String, pub created_at: DateTime, pub updated_at: DateTime, pub is_active: bool, pub rate_limit_per_minute: i64, pub total_requests: i64, pub total_tokens: i64, pub total_cost: f64, } /// Client creation request #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CreateClientRequest { pub client_id: String, pub name: String, pub description: String, pub rate_limit_per_minute: Option, } /// Client update request #[derive(Debug, Clone, Serialize, Deserialize)] pub struct UpdateClientRequest { pub name: Option, pub description: Option, pub is_active: Option, pub rate_limit_per_minute: Option, } /// Client manager for database operations pub struct ClientManager { db_pool: SqlitePool, } impl ClientManager { pub fn new(db_pool: SqlitePool) -> Self { Self { db_pool } } /// Create a new client pub async fn create_client(&self, request: CreateClientRequest) -> Result { let rate_limit = request.rate_limit_per_minute.unwrap_or(60); // First insert the client sqlx::query( r#" INSERT INTO clients (client_id, name, description, rate_limit_per_minute) VALUES (?, ?, ?, ?) "#, ) .bind(&request.client_id) .bind(&request.name) .bind(&request.description) .bind(rate_limit) .execute(&self.db_pool) .await?; // Then fetch the created client let client = self.get_client(&request.client_id).await? .ok_or_else(|| anyhow::anyhow!("Failed to retrieve created client"))?; info!("Created client: {} ({})", client.name, client.client_id); Ok(client) } /// Get a client by ID pub async fn get_client(&self, client_id: &str) -> Result> { let row = sqlx::query( r#" SELECT id, client_id, name, description, created_at, updated_at, is_active, rate_limit_per_minute, total_requests, total_tokens, total_cost FROM clients WHERE client_id = ? "#, ) .bind(client_id) .fetch_optional(&self.db_pool) .await?; if let Some(row) = row { let client = Client { id: row.get("id"), client_id: row.get("client_id"), name: row.get("name"), description: row.get("description"), created_at: row.get("created_at"), updated_at: row.get("updated_at"), is_active: row.get("is_active"), rate_limit_per_minute: row.get("rate_limit_per_minute"), total_requests: row.get("total_requests"), total_tokens: row.get("total_tokens"), total_cost: row.get("total_cost"), }; Ok(Some(client)) } else { Ok(None) } } /// Update a client pub async fn update_client(&self, client_id: &str, request: UpdateClientRequest) -> Result> { // First, get the current client to check if it exists let current_client = self.get_client(client_id).await?; if current_client.is_none() { return Ok(None); } // Build update query dynamically based on provided fields let mut updates = Vec::new(); let mut query_builder = sqlx::QueryBuilder::new("UPDATE clients SET "); let mut has_updates = false; if let Some(name) = &request.name { updates.push("name = "); query_builder.push_bind(name); has_updates = true; } if let Some(description) = &request.description { if has_updates { query_builder.push(", "); } updates.push("description = "); query_builder.push_bind(description); has_updates = true; } if let Some(is_active) = request.is_active { if has_updates { query_builder.push(", "); } updates.push("is_active = "); query_builder.push_bind(is_active); has_updates = true; } if let Some(rate_limit) = request.rate_limit_per_minute { if has_updates { query_builder.push(", "); } updates.push("rate_limit_per_minute = "); query_builder.push_bind(rate_limit); has_updates = true; } // Always update the updated_at timestamp if has_updates { query_builder.push(", "); } query_builder.push("updated_at = CURRENT_TIMESTAMP"); if !has_updates { // No updates to make return self.get_client(client_id).await; } query_builder.push(" WHERE client_id = "); query_builder.push_bind(client_id); let query = query_builder.build(); query.execute(&self.db_pool).await?; // Fetch the updated client let updated_client = self.get_client(client_id).await?; if updated_client.is_some() { info!("Updated client: {}", client_id); } Ok(updated_client) } /// List all clients pub async fn list_clients(&self, limit: Option, offset: Option) -> Result> { let limit = limit.unwrap_or(100); let offset = offset.unwrap_or(0); let rows = sqlx::query( r#" SELECT id, client_id, name, description, created_at, updated_at, is_active, rate_limit_per_minute, total_requests, total_tokens, total_cost FROM clients ORDER BY created_at DESC LIMIT ? OFFSET ? "# ) .bind(limit) .bind(offset) .fetch_all(&self.db_pool) .await?; let mut clients = Vec::new(); for row in rows { let client = Client { id: row.get("id"), client_id: row.get("client_id"), name: row.get("name"), description: row.get("description"), created_at: row.get("created_at"), updated_at: row.get("updated_at"), is_active: row.get("is_active"), rate_limit_per_minute: row.get("rate_limit_per_minute"), total_requests: row.get("total_requests"), total_tokens: row.get("total_tokens"), total_cost: row.get("total_cost"), }; clients.push(client); } Ok(clients) } /// Delete a client pub async fn delete_client(&self, client_id: &str) -> Result { let result = sqlx::query( "DELETE FROM clients WHERE client_id = ?" ) .bind(client_id) .execute(&self.db_pool) .await?; let deleted = result.rows_affected() > 0; if deleted { info!("Deleted client: {}", client_id); } else { warn!("Client not found for deletion: {}", client_id); } Ok(deleted) } /// Update client usage statistics after a request pub async fn update_client_usage( &self, client_id: &str, tokens: i64, cost: f64, ) -> Result<()> { sqlx::query( r#" UPDATE clients SET total_requests = total_requests + 1, total_tokens = total_tokens + ?, total_cost = total_cost + ?, updated_at = CURRENT_TIMESTAMP WHERE client_id = ? "# ) .bind(tokens) .bind(cost) .bind(client_id) .execute(&self.db_pool) .await?; Ok(()) } /// Get client usage statistics pub async fn get_client_usage(&self, client_id: &str) -> Result> { let row = sqlx::query( r#" SELECT total_requests, total_tokens, total_cost FROM clients WHERE client_id = ? "# ) .bind(client_id) .fetch_optional(&self.db_pool) .await?; if let Some(row) = row { let total_requests: i64 = row.get("total_requests"); let total_tokens: i64 = row.get("total_tokens"); let total_cost: f64 = row.get("total_cost"); Ok(Some((total_requests, total_tokens, total_cost))) } else { Ok(None) } } /// Check if a client exists and is active pub async fn validate_client(&self, client_id: &str) -> Result { let client = self.get_client(client_id).await?; Ok(client.map(|c| c.is_active).unwrap_or(false)) } }