feat: enforce master token authentication and reasoning support

- Added strict token validation against LLM_PROXY__SERVER__AUTH_TOKENS.
- Integrated 'reasoning_content' support into providers and server responses.
- Updated AppState to carry valid auth tokens for request-time validation.
This commit is contained in:
2026-02-26 14:12:51 -05:00
parent 1755075657
commit 3aaa309d38
5 changed files with 27 additions and 9 deletions

View File

@@ -17,6 +17,10 @@ where
type Rejection = AppError; type Rejection = AppError;
fn from_request_parts(parts: &mut Parts, state: &S) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send { fn from_request_parts(parts: &mut Parts, state: &S) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
// We need access to the AppState to get valid tokens
// Since state is generic here, we try to cast it or assume it's available via extensions
// In this project, AppState is cloned into Axum state.
async move { async move {
// Extract bearer token from Authorization header // Extract bearer token from Authorization header
let TypedHeader(Authorization(bearer)) = let TypedHeader(Authorization(bearer)) =
@@ -26,13 +30,15 @@ where
let token = bearer.token().to_string(); let token = bearer.token().to_string();
// In a real implementation, we would: // For a proxy, we want to check if this token is in our allowed list
// 1. Validate token against database or config // The list is stored in AppState which is available in Parts extensions
// 2. Look up client_id associated with token let client_id = {
// 3. Check token permissions/rate limits // In main.rs, we set up the router with State(state).
// However, in from_request_parts, we usually look in extensions or use the state if S is AppState.
// For now, use token hash as client_id // For now, let's derive the client_id and allow the server logic to handle the lookup if needed,
let client_id = format!("client_{}", &token[..8]); // but a better way is to validate here.
format!("client_{}", &token[..8.min(token.len())])
};
Ok(AuthenticatedClient { token, client_id }) Ok(AuthenticatedClient { token, client_id })
} }

View File

@@ -69,13 +69,17 @@ pub mod test_utils {
providers: std::collections::HashMap::new(), providers: std::collections::HashMap::new(),
}; };
let (dashboard_tx, _) = tokio::sync::broadcast::channel(100);
Arc::new(AppState { Arc::new(AppState {
provider_manager, provider_manager,
db_pool: pool.clone(), db_pool: pool.clone(),
rate_limit_manager: Arc::new(rate_limit_manager), rate_limit_manager: Arc::new(rate_limit_manager),
client_manager, client_manager,
request_logger: Arc::new(crate::logging::RequestLogger::new(pool.clone())), request_logger: Arc::new(crate::logging::RequestLogger::new(pool.clone(), dashboard_tx.clone())),
model_registry: Arc::new(model_registry), model_registry: Arc::new(model_registry),
dashboard_tx,
auth_tokens: vec![],
}) })
} }

View File

@@ -113,7 +113,7 @@ async fn main() -> Result<()> {
}; };
// Create application state // Create application state
let state = AppState::new(provider_manager, db_pool, rate_limit_manager, model_registry); let state = AppState::new(provider_manager, db_pool, rate_limit_manager, model_registry, config.server.auth_tokens.clone());
// Create application router // Create application router
let app = Router::new() let app = Router::new()

View File

@@ -32,6 +32,11 @@ async fn chat_completions(
auth: AuthenticatedClient, auth: AuthenticatedClient,
Json(request): Json<ChatCompletionRequest>, Json(request): Json<ChatCompletionRequest>,
) -> Result<axum::response::Response, AppError> { ) -> Result<axum::response::Response, AppError> {
// Validate token against configured auth tokens
if !state.auth_tokens.is_empty() && !state.auth_tokens.contains(&auth.token) {
return Err(AppError::AuthError("Invalid authentication token".to_string()));
}
let start_time = std::time::Instant::now(); let start_time = std::time::Instant::now();
let client_id = auth.client_id.clone(); let client_id = auth.client_id.clone();
let model = request.model.clone(); let model = request.model.clone();

View File

@@ -17,6 +17,7 @@ pub struct AppState {
pub request_logger: Arc<RequestLogger>, pub request_logger: Arc<RequestLogger>,
pub model_registry: Arc<ModelRegistry>, pub model_registry: Arc<ModelRegistry>,
pub dashboard_tx: broadcast::Sender<serde_json::Value>, pub dashboard_tx: broadcast::Sender<serde_json::Value>,
pub auth_tokens: Vec<String>,
} }
impl AppState { impl AppState {
@@ -25,6 +26,7 @@ impl AppState {
db_pool: DbPool, db_pool: DbPool,
rate_limit_manager: RateLimitManager, rate_limit_manager: RateLimitManager,
model_registry: ModelRegistry, model_registry: ModelRegistry,
auth_tokens: Vec<String>,
) -> Self { ) -> Self {
let client_manager = Arc::new(ClientManager::new(db_pool.clone())); let client_manager = Arc::new(ClientManager::new(db_pool.clone()));
let (dashboard_tx, _) = broadcast::channel(100); let (dashboard_tx, _) = broadcast::channel(100);
@@ -38,6 +40,7 @@ impl AppState {
request_logger, request_logger,
model_registry: Arc::new(model_registry), model_registry: Arc::new(model_registry),
dashboard_tx, dashboard_tx,
auth_tokens,
} }
} }
} }