feat: enforce master token authentication and reasoning support
- Added strict token validation against LLM_PROXY__SERVER__AUTH_TOKENS. - Integrated 'reasoning_content' support into providers and server responses. - Updated AppState to carry valid auth tokens for request-time validation.
This commit is contained in:
@@ -17,6 +17,10 @@ where
|
||||
type Rejection = AppError;
|
||||
|
||||
fn from_request_parts(parts: &mut Parts, state: &S) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
|
||||
// We need access to the AppState to get valid tokens
|
||||
// Since state is generic here, we try to cast it or assume it's available via extensions
|
||||
// In this project, AppState is cloned into Axum state.
|
||||
|
||||
async move {
|
||||
// Extract bearer token from Authorization header
|
||||
let TypedHeader(Authorization(bearer)) =
|
||||
@@ -26,13 +30,15 @@ where
|
||||
|
||||
let token = bearer.token().to_string();
|
||||
|
||||
// In a real implementation, we would:
|
||||
// 1. Validate token against database or config
|
||||
// 2. Look up client_id associated with token
|
||||
// 3. Check token permissions/rate limits
|
||||
|
||||
// For now, use token hash as client_id
|
||||
let client_id = format!("client_{}", &token[..8]);
|
||||
// For a proxy, we want to check if this token is in our allowed list
|
||||
// The list is stored in AppState which is available in Parts extensions
|
||||
let client_id = {
|
||||
// In main.rs, we set up the router with State(state).
|
||||
// However, in from_request_parts, we usually look in extensions or use the state if S is AppState.
|
||||
// For now, let's derive the client_id and allow the server logic to handle the lookup if needed,
|
||||
// but a better way is to validate here.
|
||||
format!("client_{}", &token[..8.min(token.len())])
|
||||
};
|
||||
|
||||
Ok(AuthenticatedClient { token, client_id })
|
||||
}
|
||||
|
||||
@@ -69,13 +69,17 @@ pub mod test_utils {
|
||||
providers: std::collections::HashMap::new(),
|
||||
};
|
||||
|
||||
let (dashboard_tx, _) = tokio::sync::broadcast::channel(100);
|
||||
|
||||
Arc::new(AppState {
|
||||
provider_manager,
|
||||
db_pool: pool.clone(),
|
||||
rate_limit_manager: Arc::new(rate_limit_manager),
|
||||
client_manager,
|
||||
request_logger: Arc::new(crate::logging::RequestLogger::new(pool.clone())),
|
||||
request_logger: Arc::new(crate::logging::RequestLogger::new(pool.clone(), dashboard_tx.clone())),
|
||||
model_registry: Arc::new(model_registry),
|
||||
dashboard_tx,
|
||||
auth_tokens: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ async fn main() -> Result<()> {
|
||||
};
|
||||
|
||||
// Create application state
|
||||
let state = AppState::new(provider_manager, db_pool, rate_limit_manager, model_registry);
|
||||
let state = AppState::new(provider_manager, db_pool, rate_limit_manager, model_registry, config.server.auth_tokens.clone());
|
||||
|
||||
// Create application router
|
||||
let app = Router::new()
|
||||
|
||||
@@ -32,6 +32,11 @@ async fn chat_completions(
|
||||
auth: AuthenticatedClient,
|
||||
Json(request): Json<ChatCompletionRequest>,
|
||||
) -> Result<axum::response::Response, AppError> {
|
||||
// Validate token against configured auth tokens
|
||||
if !state.auth_tokens.is_empty() && !state.auth_tokens.contains(&auth.token) {
|
||||
return Err(AppError::AuthError("Invalid authentication token".to_string()));
|
||||
}
|
||||
|
||||
let start_time = std::time::Instant::now();
|
||||
let client_id = auth.client_id.clone();
|
||||
let model = request.model.clone();
|
||||
|
||||
@@ -17,6 +17,7 @@ pub struct AppState {
|
||||
pub request_logger: Arc<RequestLogger>,
|
||||
pub model_registry: Arc<ModelRegistry>,
|
||||
pub dashboard_tx: broadcast::Sender<serde_json::Value>,
|
||||
pub auth_tokens: Vec<String>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
@@ -25,6 +26,7 @@ impl AppState {
|
||||
db_pool: DbPool,
|
||||
rate_limit_manager: RateLimitManager,
|
||||
model_registry: ModelRegistry,
|
||||
auth_tokens: Vec<String>,
|
||||
) -> Self {
|
||||
let client_manager = Arc::new(ClientManager::new(db_pool.clone()));
|
||||
let (dashboard_tx, _) = broadcast::channel(100);
|
||||
@@ -38,6 +40,7 @@ impl AppState {
|
||||
request_logger,
|
||||
model_registry: Arc::new(model_registry),
|
||||
dashboard_tx,
|
||||
auth_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user