From 5ddf284b8f3d9909f9e1e65dcbacecc46913c51b Mon Sep 17 00:00:00 2001 From: hobokenchicken Date: Thu, 5 Mar 2026 14:44:45 -0500 Subject: [PATCH] feat(auth): refactor token resolution into shared TokenResolution and centralize in middleware; simplify AuthenticatedClient to carry resolved DB ID --- .gitignore | 1 + src/auth/mod.rs | 95 ++++++++++++++++++++++++++++++++++++++++ src/rate_limiting/mod.rs | 16 +++++-- src/server/mod.rs | 23 ++++++++++ 4 files changed, 132 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index fdb9268f..8bf0d3dd 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ /*.db /*.db-shm /*.db-wal +/data/ diff --git a/src/auth/mod.rs b/src/auth/mod.rs index e370cec1..44885fb7 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -1,11 +1,76 @@ use axum::{extract::FromRequestParts, http::request::Parts}; +<<<<<<< HEAD +======= +use axum_extra::TypedHeader; +use axum_extra::headers::Authorization; +use headers::authorization::Bearer; +use sqlx; +>>>>>>> 76e5b9f (perf(auth): eliminate duplicate token resolution database queries) use crate::errors::AppError; +use crate::state::AppState; + +/// Token resolution result stored in request extensions +/// This avoids duplicate database queries for token resolution +#[derive(Debug, Clone)] +pub struct TokenResolution { + /// The raw bearer token from Authorization header + pub token: String, + /// Client ID for rate limiting (from DB lookup or token prefix derivation) + pub client_id_for_rate_limit: String, + /// Client ID from database if token was found in client_tokens table + pub db_client_id: Option, +} + +impl TokenResolution { + /// Resolve a token to client ID, checking database first, then falling back to token prefix + pub async fn resolve(token: Option, state: &AppState) -> Self { + match token { + Some(token) => { + // Try DB token lookup first + let db_client_id = match sqlx::query_scalar::<_, String>( + "SELECT client_id FROM client_tokens WHERE token = ? AND is_active = TRUE", + ) + .bind(&token) + .fetch_optional(&state.db_pool) + .await + { + Ok(Some(cid)) => Some(cid), + Ok(None) => None, + Err(_) => None, // Log error? For now, treat as not found + }; + + let client_id_for_rate_limit = if let Some(ref cid) = db_client_id { + cid.clone() + } else { + // Fallback to token-prefix derivation + format!("client_{}", &token[..8.min(token.len())]) + }; + + Self { + token, + client_id_for_rate_limit, + db_client_id, + } + } + None => { + // No token — anonymous + Self { + token: String::new(), + client_id_for_rate_limit: "anonymous".to_string(), + db_client_id: None, + } + } + } + } +} #[derive(Debug, Clone)] pub struct AuthInfo { pub token: String, pub client_id: String, + /// Client ID from database if token was found in client_tokens table + pub db_client_id: Option, } pub struct AuthenticatedClient { @@ -18,6 +83,7 @@ where { type Rejection = AppError; +<<<<<<< HEAD async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { // Retrieve AuthInfo from request extensions, where it was placed by rate_limit_middleware let info = parts @@ -25,6 +91,27 @@ where .get::() .cloned() .ok_or_else(|| AppError::AuthError("Authentication info not found in request".to_string()))?; +======= + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + // Check if TokenResolution is already in extensions (set by rate limit middleware) + if let Some(resolution) = parts.extensions.get::() { + // Use the resolved token and client_id + let token = resolution.token.clone(); + let client_id = resolution.client_id_for_rate_limit.clone(); + + return Ok(AuthenticatedClient { + token, + client_id, + db_client_id: resolution.db_client_id.clone(), + }); + } + + // Fallback: extract token from Authorization header directly + // (this shouldn't happen if rate limit middleware is applied) + let TypedHeader(Authorization(bearer)) = TypedHeader::>::from_request_parts(parts, state) + .await + .map_err(|_| AppError::AuthError("Missing or invalid bearer token".to_string()))?; +>>>>>>> 76e5b9f (perf(auth): eliminate duplicate token resolution database queries) Ok(AuthenticatedClient { info }) } @@ -33,8 +120,16 @@ where impl std::ops::Deref for AuthenticatedClient { type Target = AuthInfo; +<<<<<<< HEAD fn deref(&self) -> &Self::Target { &self.info +======= + Ok(AuthenticatedClient { + token, + client_id, + db_client_id: None, + }) +>>>>>>> 76e5b9f (perf(auth): eliminate duplicate token resolution database queries) } } diff --git a/src/rate_limiting/mod.rs b/src/rate_limiting/mod.rs index c5e22e14..b6025989 100644 --- a/src/rate_limiting/mod.rs +++ b/src/rate_limiting/mod.rs @@ -305,7 +305,6 @@ pub mod middleware { middleware::Next, response::Response, }; - use sqlx; /// Rate limiting middleware pub async fn rate_limit_middleware( @@ -316,12 +315,20 @@ pub mod middleware { // Extract token synchronously from headers (avoids holding &Request across await) let token = extract_bearer_token(&request); +<<<<<<< HEAD // Resolve client_id and populate AuthInfo: DB token lookup, then prefix fallback let auth_info = resolve_auth_info(token, &state).await; let client_id = auth_info.client_id.clone(); +======= + // Resolve token to client ID (with DB lookup if applicable) + let resolution = crate::auth::TokenResolution::resolve(token, &state).await; +>>>>>>> 76e5b9f (perf(auth): eliminate duplicate token resolution database queries) - // Check rate limits - if !state.rate_limit_manager.check_client_request(&client_id).await? { + // Store resolution in request extensions for downstream handlers + request.extensions_mut().insert(resolution.clone()); + + // Check rate limits using the rate-limit client ID + if !state.rate_limit_manager.check_client_request(&resolution.client_id_for_rate_limit).await? { return Err(AppError::RateLimitError("Rate limit exceeded".to_string())); } @@ -339,6 +346,7 @@ pub mod middleware { .map(|t| t.to_string()) } +<<<<<<< HEAD /// Resolve auth info: try DB token first, then fall back to token-prefix derivation async fn resolve_auth_info(token: Option, state: &AppState) -> AuthInfo { if let Some(token) = token { @@ -374,6 +382,8 @@ pub mod middleware { } } +======= +>>>>>>> 76e5b9f (perf(auth): eliminate duplicate token resolution database queries) /// Circuit breaker middleware for provider requests pub async fn circuit_breaker_middleware(provider_name: &str, state: &AppState) -> Result<(), AppError> { if !state.rate_limit_manager.check_provider_request(provider_name).await? { diff --git a/src/server/mod.rs b/src/server/mod.rs index d91520ff..8c21a8ae 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -121,6 +121,7 @@ async fn chat_completions( auth: AuthenticatedClient, Json(mut request): Json, ) -> Result { +<<<<<<< HEAD let client_id = auth.client_id.clone(); let token = auth.token.clone(); @@ -131,6 +132,28 @@ async fn chat_completions( return Err(AppError::AuthError("Invalid authentication token".to_string())); } } +======= + // Use the db_client_id from the AuthenticatedClient (already resolved by middleware) + let db_client_id = auth.db_client_id.clone(); + + let client_id = if let Some(cid) = db_client_id { + // Update last_used_at in background (fire-and-forget) for DB tokens + let pool = state.db_pool.clone(); + let token = auth.token.clone(); + tokio::spawn(async move { + let _ = sqlx::query("UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ?") + .bind(&token) + .execute(&pool) + .await; + }); + cid + } else if state.auth_tokens.is_empty() || state.auth_tokens.contains(&auth.token) { + // Env token match or permissive mode (no env tokens configured) + auth.client_id.clone() + } else { + return Err(AppError::AuthError("Invalid authentication token".to_string())); + }; +>>>>>>> 76e5b9f (perf(auth): eliminate duplicate token resolution database queries) let start_time = std::time::Instant::now(); let model = request.model.clone();