diff --git a/src/auth/mod.rs b/src/auth/mod.rs
index 8bb3ceef..e370cec1 100644
--- a/src/auth/mod.rs
+++ b/src/auth/mod.rs
@@ -1,33 +1,61 @@
 use axum::{extract::FromRequestParts, http::request::Parts};
-use axum_extra::TypedHeader;
-use axum_extra::headers::Authorization;
-use headers::authorization::Bearer;
 
 use crate::errors::AppError;
 
-pub struct AuthenticatedClient {
+/// Identity resolved for a request by `rate_limit_middleware`.
+///
+/// The derived `Debug` output includes the raw token, so avoid logging this value verbatim.
+#[derive(Debug, Clone)]
+pub struct AuthInfo {
+    /// Bearer token sent by the client; empty when the request carried none.
     pub token: String,
+    /// Client identifier the middleware uses for rate limiting.
     pub client_id: String,
 }
 
+/// Extractor that hands handlers the [`AuthInfo`] of an authenticated request.
+pub struct AuthenticatedClient {
+    pub info: AuthInfo,
+}
+
 impl<S> FromRequestParts<S> for AuthenticatedClient
 where
     S: Send + Sync,
 {
     type Rejection = AppError;
 
-    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
-        // Extract bearer token from Authorization header
-        let TypedHeader(Authorization(bearer)) = TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
-            .await
-            .map_err(|_| AppError::AuthError("Missing or invalid bearer token".to_string()))?;
+    /// Builds the extractor from the `AuthInfo` that `rate_limit_middleware` stored in the
+    /// request extensions.
+    ///
+    /// # Errors
+    ///
+    /// Returns `AppError::AuthError` when the extensions hold no `AuthInfo`, or when the stored
+    /// token is empty, which is how the middleware records a request without a bearer token.
+    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
+        // Retrieve AuthInfo from request extensions, where it was placed by rate_limit_middleware
+        let info = parts
+            .extensions
+            .get::<AuthInfo>()
+            .cloned()
+            .ok_or_else(|| AppError::AuthError("Authentication info not found in request".to_string()))?;
+
+        // Requests without a bearer token were rejected here before AuthInfo moved into the
+        // request extensions; keep rejecting them so handlers that take this extractor never
+        // run for an anonymous caller.
+        if info.token.is_empty() {
+            return Err(AppError::AuthError("Missing or invalid bearer token".to_string()));
+        }
 
-        let token = bearer.token().to_string();
+        Ok(AuthenticatedClient { info })
+    }
+}
 
-        // Derive client_id from the token prefix
-        let client_id = format!("client_{}", &token[..8.min(token.len())]);
+/// Lets existing call sites keep reading `auth.token` and `auth.client_id` directly.
+impl std::ops::Deref for AuthenticatedClient {
+    type Target = AuthInfo;
 
-        Ok(AuthenticatedClient { token, client_id })
+    fn deref(&self) -> &Self::Target {
+        &self.info
     }
 }
diff --git a/src/rate_limiting/mod.rs b/src/rate_limiting/mod.rs
index fbdc1fcb..c5e22e14 100644
--- a/src/rate_limiting/mod.rs
+++ b/src/rate_limiting/mod.rs
@@ -299,6 +299,7 @@ pub mod middleware {
     use super::*;
     use crate::errors::AppError;
     use crate::state::AppState;
+    use crate::auth::AuthInfo;
     use axum::{
         extract::{Request, State},
         middleware::Next,
@@ -309,20 +310,24 
@@ pub mod middleware {
     /// Rate limiting middleware
     pub async fn rate_limit_middleware(
         State(state): State<AppState>,
-        request: Request,
+        mut request: Request,
         next: Next,
     ) -> Result<Response, AppError> {
         // Extract token synchronously from headers (avoids holding &Request across await)
         let token = extract_bearer_token(&request);
 
-        // Resolve client_id: DB token lookup, then prefix fallback
-        let client_id = resolve_client_id(token, &state).await;
+        // Resolve client_id and populate AuthInfo: DB token lookup, then prefix fallback
+        let auth_info = resolve_auth_info(token, &state).await;
+        let client_id = auth_info.client_id.clone();
 
         // Check rate limits
         if !state.rate_limit_manager.check_client_request(&client_id).await? {
             return Err(AppError::RateLimitError("Rate limit exceeded".to_string()));
         }
 
+        // Store AuthInfo in request extensions for extractors and downstream handlers
+        request.extensions_mut().insert(auth_info);
+
         Ok(next.run(request).await)
     }
 
@@ -334,26 +339,51 @@ pub mod middleware {
             .map(|t| t.to_string())
     }
 
-    /// Resolve client ID: try DB token first, then fall back to token-prefix derivation
-    async fn resolve_client_id(token: Option<String>, state: &AppState) -> String {
+    /// Resolve auth info: try DB token first, then fall back to token-prefix derivation
+    ///
+    /// - Token matches an active `client_tokens` row: uses that row's `client_id`; the same
+    ///   statement refreshes the row's `last_used_at`.
+    /// - Token present but unmatched, or the query failed: `client_` plus up to the first 8
+    ///   characters of the token.
+    /// - No token: `client_id` is `"anonymous"` and `token` is empty.
+    async fn resolve_auth_info(token: Option<String>, state: &AppState) -> AuthInfo {
         if let Some(token) = token {
             // Try DB token lookup first
-            if let Ok(Some(cid)) = sqlx::query_scalar::<_, String>(
-                "SELECT client_id FROM client_tokens WHERE token = ? AND is_active = TRUE",
+            // NOTE(review): UPDATE ... RETURNING needs SQLite 3.35+ (or PostgreSQL); confirm the
+            // deployed database supports it.
+            match sqlx::query_scalar::<_, String>(
+                "UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = TRUE RETURNING client_id",
             )
             .bind(&token)
             .fetch_optional(&state.db_pool)
             .await
             {
-                return cid;
+                Ok(Some(cid)) => {
+                    return AuthInfo {
+                        token,
+                        client_id: cid,
+                    };
+                }
+                Err(e) => {
+                    warn!("DB error during token lookup: {}", e);
+                }
+                // No active row for this token: fall through to the prefix-derived id.
+                Ok(None) => {}
             }
 
             // Fallback to token-prefix derivation (env tokens / permissive mode)
-            return format!("client_{}", &token[..8.min(token.len())]);
+            // Take characters rather than a byte range so a multi-byte character at the cut
+            // cannot trigger a slice panic; for ASCII tokens the result is unchanged.
+            let prefix: String = token.chars().take(8).collect();
+            let client_id = format!("client_{}", prefix);
+            return AuthInfo { token, client_id };
         }
 
         // No token — anonymous
-        "anonymous".to_string()
+        AuthInfo {
+            token: String::new(),
+            client_id: "anonymous".to_string(),
+        }
     }
 
     /// Circuit breaker middleware for provider requests
diff --git a/src/server/mod.rs b/src/server/mod.rs
index 2b6b68e1..b00095a1 100644
--- a/src/server/mod.rs
+++ b/src/server/mod.rs
@@ -122,32 +122,22 @@ async fn chat_completions(
     auth: AuthenticatedClient,
     Json(mut request): Json,
 ) -> Result {
-    // Resolve client_id: try DB token first, then env tokens, then permissive fallback
-    let db_client_id: Option<String> = sqlx::query_scalar::<_, String>(
-        "SELECT client_id FROM client_tokens WHERE token = ? AND is_active = TRUE",
-    )
-    .bind(&auth.token)
-    .fetch_optional(&state.db_pool)
-    .await
-    .unwrap_or(None);
+    let client_id = auth.client_id.clone();
+    let token = auth.token.clone();
 
-    let client_id = if let Some(cid) = db_client_id {
-        // Update last_used_at in background (fire-and-forget)
-        let pool = state.db_pool.clone();
-        let token = auth.token.clone();
-        tokio::spawn(async move {
-            let _ = sqlx::query("UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ?")
-                .bind(&token)
-                .execute(&pool)
-                .await;
-        });
-        cid
-    } else if state.auth_tokens.is_empty() || state.auth_tokens.contains(&auth.token) {
-        // Env token match or permissive mode (no env tokens configured)
-        auth.client_id.clone()
-    } else {
-        return Err(AppError::AuthError("Invalid authentication token".to_string()));
-    };
+    // Verify token if env tokens are configured
+    if !state.auth_tokens.is_empty() && !state.auth_tokens.contains(&token) {
+        // Not an env token, so it is only acceptable if the middleware resolved it from the DB.
+        // AuthInfo has no explicit marker for that, so detect the two non-DB results of
+        // resolve_auth_info instead: an empty token (request without a bearer token) or exactly
+        // the prefix-derived fallback id. Testing for a bare "client_" prefix would reject DB
+        // clients whose ids start with "client_" and would accept the "anonymous" id.
+        // NOTE(review): an explicit "resolved from DB" flag on AuthInfo would replace this inference.
+        let fallback_id = format!("client_{}", token.chars().take(8).collect::<String>());
+        if token.is_empty() || client_id == fallback_id {
+            return Err(AppError::AuthError("Invalid authentication token".to_string()));
+        }
+    }
 
     let start_time = std::time::Instant::now();
     let model = request.model.clone();