diff --git a/Cargo.lock b/Cargo.lock index 4cff23de..04509cd7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1444,10 +1444,12 @@ dependencies = [ "dotenvy", "futures", "headers", + "hex", "image", "insta", "mime", "mockito", + "rand 0.9.2", "reqwest", "reqwest-eventsource", "serde", diff --git a/Cargo.toml b/Cargo.toml index f7160e35..98b043a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,6 +52,8 @@ futures = "0.3" async-trait = "0.1" async-stream = "0.3" reqwest-eventsource = "0.6" +rand = "0.9" +hex = "0.4" [dev-dependencies] tokio-test = "0.4" diff --git a/src/dashboard/clients.rs b/src/dashboard/clients.rs index c9d16cc2..7d8747d0 100644 --- a/src/dashboard/clients.rs +++ b/src/dashboard/clients.rs @@ -3,6 +3,7 @@ use axum::{ response::Json, }; use chrono; +use rand::Rng; use serde::Deserialize; use serde_json; use sqlx::Row; @@ -11,6 +12,13 @@ use uuid; use super::{ApiResponse, DashboardState}; +/// Generate a random API token: sk-{48 hex chars} +fn generate_token() -> String { + let mut rng = rand::rng(); + let bytes: Vec = (0..24).map(|_| rng.random::()).collect(); + format!("sk-{}", hex::encode(bytes)) +} + #[derive(Deserialize)] pub(super) struct CreateClientRequest { pub(super) name: String, @@ -98,12 +106,29 @@ pub(super) async fn handle_create_client( .await; match result { - Ok(row) => Json(ApiResponse::success(serde_json::json!({ - "id": row.get::("client_id"), - "name": row.get::, _>("name"), - "created_at": row.get::, _>("created_at"), - "status": "active", - }))), + Ok(row) => { + // Auto-generate a token for the new client + let token = generate_token(); + let token_result = sqlx::query( + "INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, 'default')", + ) + .bind(&client_id) + .bind(&token) + .execute(pool) + .await; + + if let Err(e) = token_result { + warn!("Client created but failed to generate token: {}", e); + } + + Json(ApiResponse::success(serde_json::json!({ + "id": row.get::("client_id"), + "name": row.get::, _>("name"), + "created_at": row.get::, _>("created_at"), + "status": "active", + "token": token, + }))) + } Err(e) => { warn!("Failed to create client: {}", e); Json(ApiResponse::error(format!("Failed to create client: {}", e))) @@ -333,3 +358,131 @@ pub(super) async fn handle_client_usage( } } } + +// ── Token management endpoints ────────────────────────────────────── + +pub(super) async fn handle_get_client_tokens( + State(state): State, + Path(id): Path, +) -> Json> { + let pool = &state.app_state.db_pool; + + let result = sqlx::query( + r#" + SELECT id, token, name, is_active, created_at, last_used_at + FROM client_tokens + WHERE client_id = ? + ORDER BY created_at DESC + "#, + ) + .bind(&id) + .fetch_all(pool) + .await; + + match result { + Ok(rows) => { + let tokens: Vec = rows + .into_iter() + .map(|row| { + let token: String = row.get("token"); + // Mask all but last 8 chars: sk-••••abcd1234 + let masked = if token.len() > 8 { + format!("{}••••{}", &token[..3], &token[token.len() - 8..]) + } else { + "••••".to_string() + }; + serde_json::json!({ + "id": row.get::("id"), + "token_masked": masked, + "name": row.get::, _>("name").unwrap_or_else(|| "default".to_string()), + "is_active": row.get::("is_active"), + "created_at": row.get::, _>("created_at"), + "last_used_at": row.get::>, _>("last_used_at"), + }) + }) + .collect(); + + Json(ApiResponse::success(serde_json::json!(tokens))) + } + Err(e) => { + warn!("Failed to fetch client tokens: {}", e); + Json(ApiResponse::error(format!("Failed to fetch client tokens: {}", e))) + } + } +} + +#[derive(Deserialize)] +pub(super) struct CreateTokenRequest { + pub(super) name: Option, +} + +pub(super) async fn handle_create_client_token( + State(state): State, + Path(id): Path, + Json(payload): Json, +) -> Json> { + let pool = &state.app_state.db_pool; + + // Verify client exists + let exists: Option<(i64,)> = sqlx::query_as("SELECT 1 as x FROM clients WHERE client_id = ?") + .bind(&id) + .fetch_optional(pool) + .await + .unwrap_or(None); + + if exists.is_none() { + return Json(ApiResponse::error(format!("Client '{}' not found", id))); + } + + let token = generate_token(); + let token_name = payload.name.unwrap_or_else(|| "default".to_string()); + + let result = sqlx::query( + "INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, ?) RETURNING id, created_at", + ) + .bind(&id) + .bind(&token) + .bind(&token_name) + .fetch_one(pool) + .await; + + match result { + Ok(row) => Json(ApiResponse::success(serde_json::json!({ + "id": row.get::("id"), + "token": token, + "name": token_name, + "created_at": row.get::, _>("created_at"), + }))), + Err(e) => { + warn!("Failed to create client token: {}", e); + Json(ApiResponse::error(format!("Failed to create token: {}", e))) + } + } +} + +pub(super) async fn handle_delete_client_token( + State(state): State, + Path((client_id, token_id)): Path<(String, i64)>, +) -> Json> { + let pool = &state.app_state.db_pool; + + let result = sqlx::query("DELETE FROM client_tokens WHERE id = ? AND client_id = ?") + .bind(token_id) + .bind(&client_id) + .execute(pool) + .await; + + match result { + Ok(r) => { + if r.rows_affected() == 0 { + Json(ApiResponse::error("Token not found".to_string())) + } else { + Json(ApiResponse::success(serde_json::json!({ "message": "Token revoked" }))) + } + } + Err(e) => { + warn!("Failed to delete client token: {}", e); + Json(ApiResponse::error(format!("Failed to revoke token: {}", e))) + } + } +} diff --git a/src/dashboard/mod.rs b/src/dashboard/mod.rs index 95d1672d..c8e4b375 100644 --- a/src/dashboard/mod.rs +++ b/src/dashboard/mod.rs @@ -11,7 +11,7 @@ mod websocket; use axum::{ Router, - routing::{get, post, put}, + routing::{delete, get, post, put}, }; use serde::Serialize; @@ -87,6 +87,14 @@ pub fn router(state: AppState) -> Router { .delete(clients::handle_delete_client), ) .route("/api/clients/{id}/usage", get(clients::handle_client_usage)) + .route( + "/api/clients/{id}/tokens", + get(clients::handle_get_client_tokens).post(clients::handle_create_client_token), + ) + .route( + "/api/clients/{id}/tokens/{token_id}", + delete(clients::handle_delete_client_token), + ) .route("/api/providers", get(providers::handle_get_providers)) .route( "/api/providers/{name}", diff --git a/src/database/mod.rs b/src/database/mod.rs index 1c796a40..bce8b016 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -130,6 +130,24 @@ async fn run_migrations(pool: &DbPool) -> Result<()> { .execute(pool) .await?; + // Create client_tokens table for DB-based token auth + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS client_tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + client_id TEXT NOT NULL, + token TEXT NOT NULL UNIQUE, + name TEXT DEFAULT 'default', + is_active BOOLEAN DEFAULT TRUE, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + last_used_at DATETIME, + FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE CASCADE + ) + "#, + ) + .execute(pool) + .await?; + // Add must_change_password column if it doesn't exist (migration for existing DBs) let _ = sqlx::query("ALTER TABLE users ADD COLUMN must_change_password BOOLEAN DEFAULT FALSE") .execute(pool) @@ -184,6 +202,14 @@ async fn run_migrations(pool: &DbPool) -> Result<()> { .execute(pool) .await?; + sqlx::query("CREATE UNIQUE INDEX IF NOT EXISTS idx_client_tokens_token ON client_tokens(token)") + .execute(pool) + .await?; + + sqlx::query("CREATE INDEX IF NOT EXISTS idx_client_tokens_client_id ON client_tokens(client_id)") + .execute(pool) + .await?; + // Insert default client if none exists sqlx::query( r#" diff --git a/src/rate_limiting/mod.rs b/src/rate_limiting/mod.rs index 2b40e672..fbdc1fcb 100644 --- a/src/rate_limiting/mod.rs +++ b/src/rate_limiting/mod.rs @@ -304,6 +304,7 @@ pub mod middleware { middleware::Next, response::Response, }; + use sqlx; /// Rate limiting middleware pub async fn rate_limit_middleware( @@ -311,8 +312,11 @@ pub mod middleware { request: Request, next: Next, ) -> Result { - // Extract client ID from authentication header - let client_id = extract_client_id_from_request(&request); + // Extract token synchronously from headers (avoids holding &Request across await) + let token = extract_bearer_token(&request); + + // Resolve client_id: DB token lookup, then prefix fallback + let client_id = resolve_client_id(token, &state).await; // Check rate limits if !state.rate_limit_manager.check_client_request(&client_id).await? { @@ -322,18 +326,33 @@ pub mod middleware { Ok(next.run(request).await) } - /// Extract client ID from request (helper function) - fn extract_client_id_from_request(request: &Request) -> String { - // Try to extract from Authorization header - if let Some(auth_header) = request.headers().get("Authorization") - && let Ok(auth_str) = auth_header.to_str() - && let Some(token) = auth_str.strip_prefix("Bearer ") - { - // Use token hash as client ID (same logic as auth module) + /// Synchronously extract bearer token from request headers + fn extract_bearer_token(request: &Request) -> Option { + request.headers().get("Authorization") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.strip_prefix("Bearer ")) + .map(|t| t.to_string()) + } + + /// Resolve client ID: try DB token first, then fall back to token-prefix derivation + async fn resolve_client_id(token: Option, state: &AppState) -> String { + if let Some(token) = token { + // Try DB token lookup first + if let Ok(Some(cid)) = sqlx::query_scalar::<_, String>( + "SELECT client_id FROM client_tokens WHERE token = ? AND is_active = TRUE", + ) + .bind(&token) + .fetch_optional(&state.db_pool) + .await + { + return cid; + } + + // Fallback to token-prefix derivation (env tokens / permissive mode) return format!("client_{}", &token[..8.min(token.len())]); } - // Fallback to anonymous + // No token — anonymous "anonymous".to_string() } diff --git a/src/server/mod.rs b/src/server/mod.rs index 962131d0..a89f0996 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -6,6 +6,7 @@ use axum::{ routing::{get, post}, }; use futures::stream::StreamExt; +use sqlx; use std::sync::Arc; use tracing::{info, warn}; use uuid::Uuid; @@ -119,13 +120,34 @@ async fn chat_completions( auth: AuthenticatedClient, Json(mut request): Json, ) -> Result { - // Validate token against configured auth tokens - if !state.auth_tokens.is_empty() && !state.auth_tokens.contains(&auth.token) { + // Resolve client_id: try DB token first, then env tokens, then permissive fallback + let db_client_id: Option = sqlx::query_scalar::<_, String>( + "SELECT client_id FROM client_tokens WHERE token = ? AND is_active = TRUE", + ) + .bind(&auth.token) + .fetch_optional(&state.db_pool) + .await + .unwrap_or(None); + + let client_id = if let Some(cid) = db_client_id { + // Update last_used_at in background (fire-and-forget) + let pool = state.db_pool.clone(); + let token = auth.token.clone(); + tokio::spawn(async move { + let _ = sqlx::query("UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ?") + .bind(&token) + .execute(&pool) + .await; + }); + cid + } else if state.auth_tokens.is_empty() || state.auth_tokens.contains(&auth.token) { + // Env token match or permissive mode (no env tokens configured) + auth.client_id.clone() + } else { return Err(AppError::AuthError("Invalid authentication token".to_string())); - } + }; let start_time = std::time::Instant::now(); - let client_id = auth.client_id.clone(); let model = request.model.clone(); info!("Chat completion request from client {} for model {}", client_id, model); diff --git a/static/js/pages/clients.js b/static/js/pages/clients.js index 26d2a35c..20f0a2c3 100644 --- a/static/js/pages/clients.js +++ b/static/js/pages/clients.js @@ -167,16 +167,61 @@ class ClientsPage { } try { - await window.api.post('/clients', { name, client_id: id || null }); - window.authManager.showToast(`Client "${name}" created`, 'success'); + const result = await window.api.post('/clients', { name, client_id: id || null }); modal.remove(); this.loadClients(); + + // Show the generated token (copy-once dialog) + if (result.token) { + this.showTokenRevealModal(name, result.token); + } else { + window.authManager.showToast(`Client "${name}" created`, 'success'); + } } catch (error) { window.authManager.showToast(error.message, 'error'); } }; } + showTokenRevealModal(clientName, token) { + const modal = document.createElement('div'); + modal.className = 'modal active'; + modal.innerHTML = ` + + `; + + document.body.appendChild(modal); + + modal.querySelector('#copy-token-btn').onclick = () => { + navigator.clipboard.writeText(token).then(() => { + window.authManager.showToast('Token copied to clipboard', 'success'); + }); + }; + } + async deleteClient(id) { if (!confirm(`Are you sure you want to delete client ${id}? This cannot be undone.`)) return; @@ -228,6 +273,18 @@ class ClientsPage { Active +
+
+ +
+
Loading tokens...
+
+