feat(security): implement AES-256-GCM encryption for API keys and HMAC-signed session tokens
This commit introduces: - AES-256-GCM encryption for LLM provider API keys in the database. - HMAC-SHA256 signed session tokens with activity-based refresh logic. - Standardized frontend XSS protection using a global escapeHtml utility. - Hardened security headers and request body size limits. - Improved database integrity with foreign key enforcement and atomic transactions. - Integration tests for the full encrypted key storage and proxy usage lifecycle.
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
use axum::{extract::State, response::Json};
|
||||
use axum::{extract::State, http::{HeaderMap, HeaderValue}, response::{Json, IntoResponse}};
|
||||
use bcrypt;
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
@@ -64,14 +64,14 @@ pub(super) async fn handle_login(
|
||||
pub(super) async fn handle_auth_status(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
) -> impl IntoResponse {
|
||||
let token = headers
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
if let Some(token) = token
|
||||
&& let Some(session) = state.session_manager.validate_session(token).await
|
||||
&& let Some((session, new_token)) = state.session_manager.validate_session_with_refresh(token).await
|
||||
{
|
||||
// Look up display_name from DB
|
||||
let display_name = sqlx::query_scalar::<_, Option<String>>(
|
||||
@@ -85,17 +85,23 @@ pub(super) async fn handle_auth_status(
|
||||
.flatten()
|
||||
.unwrap_or_else(|| session.username.clone());
|
||||
|
||||
return Json(ApiResponse::success(serde_json::json!({
|
||||
let mut headers = HeaderMap::new();
|
||||
if let Some(refreshed_token) = new_token {
|
||||
if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) {
|
||||
headers.insert("X-Refreshed-Token", header_value);
|
||||
}
|
||||
}
|
||||
return (headers, Json(ApiResponse::success(serde_json::json!({
|
||||
"authenticated": true,
|
||||
"user": {
|
||||
"username": session.username,
|
||||
"name": display_name,
|
||||
"role": session.role
|
||||
}
|
||||
})));
|
||||
}))));
|
||||
}
|
||||
|
||||
Json(ApiResponse::error("Not authenticated".to_string()))
|
||||
(HeaderMap::new(), Json(ApiResponse::error("Not authenticated".to_string())))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -108,7 +114,7 @@ pub(super) async fn handle_change_password(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(payload): Json<ChangePasswordRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
) -> impl IntoResponse {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Extract the authenticated user from the session token
|
||||
@@ -117,14 +123,24 @@ pub(super) async fn handle_change_password(
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
let session = match token {
|
||||
Some(t) => state.session_manager.validate_session(t).await,
|
||||
None => None,
|
||||
let (session, new_token) = match token {
|
||||
Some(t) => match state.session_manager.validate_session_with_refresh(t).await {
|
||||
Some((session, new_token)) => (Some(session), new_token),
|
||||
None => (None, None),
|
||||
},
|
||||
None => (None, None),
|
||||
};
|
||||
|
||||
let mut response_headers = HeaderMap::new();
|
||||
if let Some(refreshed_token) = new_token {
|
||||
if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) {
|
||||
response_headers.insert("X-Refreshed-Token", header_value);
|
||||
}
|
||||
}
|
||||
|
||||
let username = match session {
|
||||
Some(s) => s.username,
|
||||
None => return Json(ApiResponse::error("Not authenticated".to_string())),
|
||||
None => return (response_headers, Json(ApiResponse::error("Not authenticated".to_string()))),
|
||||
};
|
||||
|
||||
let user_result = sqlx::query("SELECT password_hash FROM users WHERE username = ?")
|
||||
@@ -138,7 +154,7 @@ pub(super) async fn handle_change_password(
|
||||
if bcrypt::verify(&payload.current_password, &hash).unwrap_or(false) {
|
||||
let new_hash = match bcrypt::hash(&payload.new_password, 12) {
|
||||
Ok(h) => h,
|
||||
Err(_) => return Json(ApiResponse::error("Failed to hash new password".to_string())),
|
||||
Err(_) => return (response_headers, Json(ApiResponse::error("Failed to hash new password".to_string()))),
|
||||
};
|
||||
|
||||
let update_result = sqlx::query(
|
||||
@@ -150,16 +166,16 @@ pub(super) async fn handle_change_password(
|
||||
.await;
|
||||
|
||||
match update_result {
|
||||
Ok(_) => Json(ApiResponse::success(
|
||||
Ok(_) => (response_headers, Json(ApiResponse::success(
|
||||
serde_json::json!({ "message": "Password updated successfully" }),
|
||||
)),
|
||||
Err(e) => Json(ApiResponse::error(format!("Failed to update database: {}", e))),
|
||||
))),
|
||||
Err(e) => (response_headers, Json(ApiResponse::error(format!("Failed to update database: {}", e)))),
|
||||
}
|
||||
} else {
|
||||
Json(ApiResponse::error("Current password incorrect".to_string()))
|
||||
(response_headers, Json(ApiResponse::error("Current password incorrect".to_string())))
|
||||
}
|
||||
}
|
||||
Err(e) => Json(ApiResponse::error(format!("User not found: {}", e))),
|
||||
Err(e) => (response_headers, Json(ApiResponse::error(format!("User not found: {}", e)))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,19 +196,19 @@ pub(super) async fn handle_logout(
|
||||
}
|
||||
|
||||
/// Helper: Extract and validate a session from the Authorization header.
|
||||
/// Returns the Session if valid, or an error response.
|
||||
/// Returns the Session and optional new token if refreshed, or an error response.
|
||||
pub(super) async fn extract_session(
|
||||
state: &DashboardState,
|
||||
headers: &axum::http::HeaderMap,
|
||||
) -> Result<super::sessions::Session, Json<ApiResponse<serde_json::Value>>> {
|
||||
) -> Result<(super::sessions::Session, Option<String>), Json<ApiResponse<serde_json::Value>>> {
|
||||
let token = headers
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
match token {
|
||||
Some(t) => match state.session_manager.validate_session(t).await {
|
||||
Some(session) => Ok(session),
|
||||
Some(t) => match state.session_manager.validate_session_with_refresh(t).await {
|
||||
Some((session, new_token)) => Ok((session, new_token)),
|
||||
None => Err(Json(ApiResponse::error("Session expired or invalid".to_string()))),
|
||||
},
|
||||
None => Err(Json(ApiResponse::error("Not authenticated".to_string()))),
|
||||
@@ -200,13 +216,14 @@ pub(super) async fn extract_session(
|
||||
}
|
||||
|
||||
/// Helper: Extract session and require admin role.
|
||||
/// Returns session and optional new token if refreshed.
|
||||
pub(super) async fn require_admin(
|
||||
state: &DashboardState,
|
||||
headers: &axum::http::HeaderMap,
|
||||
) -> Result<super::sessions::Session, Json<ApiResponse<serde_json::Value>>> {
|
||||
let session = extract_session(state, headers).await?;
|
||||
) -> Result<(super::sessions::Session, Option<String>), Json<ApiResponse<serde_json::Value>>> {
|
||||
let (session, new_token) = extract_session(state, headers).await?;
|
||||
if session.role != "admin" {
|
||||
return Err(Json(ApiResponse::error("Admin access required".to_string())));
|
||||
}
|
||||
Ok(session)
|
||||
Ok((session, new_token))
|
||||
}
|
||||
|
||||
@@ -88,9 +88,10 @@ pub(super) async fn handle_create_client(
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(payload): Json<CreateClientRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
@@ -198,9 +199,10 @@ pub(super) async fn handle_update_client(
|
||||
Path(id): Path<String>,
|
||||
Json(payload): Json<UpdateClientPayload>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
@@ -294,9 +296,10 @@ pub(super) async fn handle_delete_client(
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
@@ -437,9 +440,10 @@ pub(super) async fn handle_create_client_token(
|
||||
Path(id): Path<String>,
|
||||
Json(payload): Json<CreateTokenRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
@@ -485,9 +489,10 @@ pub(super) async fn handle_delete_client_token(
|
||||
headers: axum::http::HeaderMap,
|
||||
Path((client_id, token_id)): Path<(String, i64)>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
|
||||
@@ -11,10 +11,18 @@ mod users;
|
||||
mod websocket;
|
||||
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
Router,
|
||||
routing::{delete, get, post, put},
|
||||
};
|
||||
use axum::http::{header, HeaderValue};
|
||||
use serde::Serialize;
|
||||
use tower_http::{
|
||||
limit::RequestBodyLimitLayer,
|
||||
set_header::SetResponseHeaderLayer,
|
||||
};
|
||||
|
||||
use crate::state::AppState;
|
||||
use sessions::SessionManager;
|
||||
@@ -52,6 +60,21 @@ impl<T> ApiResponse<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Rate limiting middleware for dashboard routes that extracts AppState from DashboardState.
|
||||
async fn dashboard_rate_limit_middleware(
|
||||
State(dashboard_state): State<DashboardState>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, crate::errors::AppError> {
|
||||
// Delegate to the existing rate limit middleware with AppState
|
||||
crate::rate_limiting::middleware::rate_limit_middleware(
|
||||
State(dashboard_state.app_state),
|
||||
request,
|
||||
next,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
// Dashboard routes
|
||||
pub fn router(state: AppState) -> Router {
|
||||
let session_manager = SessionManager::new(24); // 24-hour session TTL
|
||||
@@ -60,6 +83,26 @@ pub fn router(state: AppState) -> Router {
|
||||
session_manager,
|
||||
};
|
||||
|
||||
// Security headers
|
||||
let csp_header: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
|
||||
header::CONTENT_SECURITY_POLICY,
|
||||
"default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws:;"
|
||||
.parse()
|
||||
.unwrap(),
|
||||
);
|
||||
let x_frame_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
|
||||
header::X_FRAME_OPTIONS,
|
||||
"DENY".parse().unwrap(),
|
||||
);
|
||||
let x_content_type_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
|
||||
header::X_CONTENT_TYPE_OPTIONS,
|
||||
"nosniff".parse().unwrap(),
|
||||
);
|
||||
let strict_transport_security: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
|
||||
header::STRICT_TRANSPORT_SECURITY,
|
||||
"max-age=31536000; includeSubDomains".parse().unwrap(),
|
||||
);
|
||||
|
||||
Router::new()
|
||||
// Static file serving
|
||||
.fallback_service(tower_http::services::ServeDir::new("static"))
|
||||
@@ -119,5 +162,16 @@ pub fn router(state: AppState) -> Router {
|
||||
"/api/system/settings",
|
||||
get(system::handle_get_settings).post(system::handle_update_settings),
|
||||
)
|
||||
// Security layers
|
||||
.layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit
|
||||
.layer(csp_header)
|
||||
.layer(x_frame_options)
|
||||
.layer(x_content_type_options)
|
||||
.layer(strict_transport_security)
|
||||
// Rate limiting middleware
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
dashboard_state.clone(),
|
||||
dashboard_rate_limit_middleware,
|
||||
))
|
||||
.with_state(dashboard_state)
|
||||
}
|
||||
|
||||
@@ -156,9 +156,10 @@ pub(super) async fn handle_update_model(
|
||||
Path(id): Path<String>,
|
||||
Json(payload): Json<UpdateModelRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ use std::collections::HashMap;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
use crate::utils::crypto;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct UpdateProviderRequest {
|
||||
@@ -265,21 +266,44 @@ pub(super) async fn handle_update_provider(
|
||||
Path(name): Path<String>,
|
||||
Json(payload): Json<UpdateProviderRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Update or insert into database (include billing_mode)
|
||||
// Prepare API key encryption if provided
|
||||
let (api_key_to_store, api_key_encrypted_flag) = match &payload.api_key {
|
||||
Some(key) if !key.is_empty() => {
|
||||
match crypto::encrypt(key) {
|
||||
Ok(encrypted) => (Some(encrypted), Some(true)),
|
||||
Err(e) => {
|
||||
warn!("Failed to encrypt API key for provider {}: {}", name, e);
|
||||
return Json(ApiResponse::error(format!("Failed to encrypt API key: {}", e)));
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(_) => {
|
||||
// Empty string means clear the key
|
||||
(None, Some(false))
|
||||
}
|
||||
None => {
|
||||
// Keep existing key, we'll rely on COALESCE in SQL
|
||||
(None, None)
|
||||
}
|
||||
};
|
||||
|
||||
// Update or insert into database (include billing_mode and api_key_encrypted)
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold, billing_mode)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
enabled = excluded.enabled,
|
||||
base_url = excluded.base_url,
|
||||
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
|
||||
api_key_encrypted = COALESCE(excluded.api_key_encrypted, provider_configs.api_key_encrypted),
|
||||
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
|
||||
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
|
||||
billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
|
||||
@@ -290,7 +314,8 @@ pub(super) async fn handle_update_provider(
|
||||
.bind(name.to_uppercase())
|
||||
.bind(payload.enabled)
|
||||
.bind(&payload.base_url)
|
||||
.bind(&payload.api_key)
|
||||
.bind(&api_key_to_store)
|
||||
.bind(api_key_encrypted_flag)
|
||||
.bind(payload.credit_balance)
|
||||
.bind(payload.low_credit_threshold)
|
||||
.bind(payload.billing_mode)
|
||||
|
||||
@@ -1,7 +1,17 @@
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use hmac::{Hmac, Mac};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Sha256, digest::generic_array::GenericArray};
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
use base64::{engine::general_purpose::URL_SAFE, Engine as _};
|
||||
|
||||
const TOKEN_VERSION: &str = "v2";
|
||||
const REFRESH_WINDOW_MINUTES: i64 = 15; // refresh if token expires within 15 minutes
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Session {
|
||||
@@ -9,51 +19,136 @@ pub struct Session {
|
||||
pub role: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
pub session_id: String, // unique identifier for the session (UUID)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SessionManager {
|
||||
sessions: Arc<RwLock<HashMap<String, Session>>>,
|
||||
sessions: Arc<RwLock<HashMap<String, Session>>>, // key = session_id
|
||||
ttl_hours: i64,
|
||||
secret: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct SessionPayload {
|
||||
session_id: String,
|
||||
username: String,
|
||||
role: String,
|
||||
iat: i64, // issued at (Unix timestamp)
|
||||
exp: i64, // expiry (Unix timestamp)
|
||||
version: String,
|
||||
}
|
||||
|
||||
impl SessionManager {
|
||||
pub fn new(ttl_hours: i64) -> Self {
|
||||
let secret = load_session_secret();
|
||||
Self {
|
||||
sessions: Arc::new(RwLock::new(HashMap::new())),
|
||||
ttl_hours,
|
||||
secret,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new session and return the session token.
|
||||
/// Create a new session and return a signed session token.
|
||||
pub async fn create_session(&self, username: String, role: String) -> String {
|
||||
let token = format!("session-{}", uuid::Uuid::new_v4());
|
||||
let session_id = Uuid::new_v4().to_string();
|
||||
let now = Utc::now();
|
||||
let expires_at = now + Duration::hours(self.ttl_hours);
|
||||
let session = Session {
|
||||
username,
|
||||
role,
|
||||
username: username.clone(),
|
||||
role: role.clone(),
|
||||
created_at: now,
|
||||
expires_at: now + Duration::hours(self.ttl_hours),
|
||||
expires_at,
|
||||
session_id: session_id.clone(),
|
||||
};
|
||||
self.sessions.write().await.insert(token.clone(), session);
|
||||
token
|
||||
// Store session by session_id
|
||||
self.sessions.write().await.insert(session_id.clone(), session);
|
||||
// Create signed token
|
||||
self.create_signed_token(&session_id, &username, &role, now.timestamp(), expires_at.timestamp())
|
||||
}
|
||||
|
||||
/// Validate a session token and return the session if valid and not expired.
|
||||
/// If the token is within the refresh window, returns a new token as well.
|
||||
pub async fn validate_session(&self, token: &str) -> Option<Session> {
|
||||
self.validate_session_with_refresh(token).await.map(|(session, _)| session)
|
||||
}
|
||||
|
||||
/// Validate a session token and return (session, optional new token if refreshed).
|
||||
pub async fn validate_session_with_refresh(&self, token: &str) -> Option<(Session, Option<String>)> {
|
||||
// Legacy token format (UUID)
|
||||
if token.starts_with("session-") {
|
||||
let sessions = self.sessions.read().await;
|
||||
return sessions.get(token).and_then(|s| {
|
||||
if s.expires_at > Utc::now() {
|
||||
Some((s.clone(), None))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Signed token format
|
||||
let payload = match verify_signed_token(token, &self.secret) {
|
||||
Ok(p) => p,
|
||||
Err(_) => return None,
|
||||
};
|
||||
|
||||
// Check expiry
|
||||
let now = Utc::now().timestamp();
|
||||
if payload.exp <= now {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Look up session by session_id
|
||||
let sessions = self.sessions.read().await;
|
||||
sessions.get(token).and_then(|s| {
|
||||
if s.expires_at > Utc::now() {
|
||||
Some(s.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
let session = match sessions.get(&payload.session_id) {
|
||||
Some(s) => s.clone(),
|
||||
None => return None, // session revoked or not found
|
||||
};
|
||||
|
||||
// Ensure session username/role matches (should always match)
|
||||
if session.username != payload.username || session.role != payload.role {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check if token is within refresh window (last REFRESH_WINDOW_MINUTES of validity)
|
||||
let refresh_threshold = payload.exp - REFRESH_WINDOW_MINUTES * 60;
|
||||
let new_token = if now >= refresh_threshold {
|
||||
// Generate a new token with same session data but updated iat/exp?
|
||||
// According to activity-based refresh, we should extend the session expiry.
|
||||
// We'll extend from now by ttl_hours (or keep original expiry?).
|
||||
// Let's extend from now by ttl_hours (sliding window).
|
||||
let new_exp = Utc::now() + Duration::hours(self.ttl_hours);
|
||||
// Update session expiry in store
|
||||
drop(sessions); // release read lock before acquiring write lock
|
||||
self.update_session_expiry(&payload.session_id, new_exp).await;
|
||||
// Create new token with updated iat/exp
|
||||
let new_token = self.create_signed_token(
|
||||
&payload.session_id,
|
||||
&payload.username,
|
||||
&payload.role,
|
||||
now,
|
||||
new_exp.timestamp(),
|
||||
);
|
||||
Some(new_token)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Some((session, new_token))
|
||||
}
|
||||
|
||||
/// Revoke (delete) a session by token.
|
||||
/// Supports both legacy tokens (token is key) and signed tokens (extract session_id).
|
||||
pub async fn revoke_session(&self, token: &str) {
|
||||
self.sessions.write().await.remove(token);
|
||||
if token.starts_with("session-") {
|
||||
self.sessions.write().await.remove(token);
|
||||
return;
|
||||
}
|
||||
// For signed token, try to extract session_id
|
||||
if let Ok(payload) = verify_signed_token(token, &self.secret) {
|
||||
self.sessions.write().await.remove(&payload.session_id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove all expired sessions from the store.
|
||||
@@ -61,4 +156,156 @@ impl SessionManager {
|
||||
let now = Utc::now();
|
||||
self.sessions.write().await.retain(|_, s| s.expires_at > now);
|
||||
}
|
||||
|
||||
// --- Private helpers ---
|
||||
|
||||
fn create_signed_token(&self, session_id: &str, username: &str, role: &str, iat: i64, exp: i64) -> String {
|
||||
let payload = SessionPayload {
|
||||
session_id: session_id.to_string(),
|
||||
username: username.to_string(),
|
||||
role: role.to_string(),
|
||||
iat,
|
||||
exp,
|
||||
version: TOKEN_VERSION.to_string(),
|
||||
};
|
||||
sign_token(&payload, &self.secret)
|
||||
}
|
||||
|
||||
async fn update_session_expiry(&self, session_id: &str, new_expires_at: DateTime<Utc>) {
|
||||
let mut sessions = self.sessions.write().await;
|
||||
if let Some(session) = sessions.get_mut(session_id) {
|
||||
session.expires_at = new_expires_at;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Load session secret from environment variable SESSION_SECRET (hex or base64 encoded).
|
||||
/// If not set, generates a random 32-byte secret and logs a warning.
|
||||
fn load_session_secret() -> Vec<u8> {
|
||||
let secret_str = env::var("SESSION_SECRET").unwrap_or_else(|_| {
|
||||
// Also check LLM_PROXY__SESSION_SECRET for consistency with config prefix
|
||||
env::var("LLM_PROXY__SESSION_SECRET").unwrap_or_else(|_| {
|
||||
// Generate a random secret (32 bytes) and encode as hex
|
||||
use rand::RngCore;
|
||||
let mut bytes = [0u8; 32];
|
||||
rand::rng().fill_bytes(&mut bytes);
|
||||
let hex_secret = hex::encode(bytes);
|
||||
tracing::warn!(
|
||||
"SESSION_SECRET environment variable not set. Using a randomly generated secret. \
|
||||
This will invalidate all sessions on restart. Set SESSION_SECRET to a fixed hex or base64 encoded 32-byte value."
|
||||
);
|
||||
hex_secret
|
||||
})
|
||||
});
|
||||
|
||||
// Decode hex or base64
|
||||
hex::decode(&secret_str)
|
||||
.or_else(|_| URL_SAFE.decode(&secret_str))
|
||||
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(&secret_str))
|
||||
.unwrap_or_else(|_| {
|
||||
panic!("SESSION_SECRET must be hex or base64 encoded (32 bytes)");
|
||||
})
|
||||
}
|
||||
|
||||
/// Sign a session payload and return a token string in format base64_url(payload).base64_url(signature).
|
||||
fn sign_token(payload: &SessionPayload, secret: &[u8]) -> String {
|
||||
let json = serde_json::to_vec(payload).expect("Failed to serialize payload");
|
||||
let payload_b64 = URL_SAFE.encode(&json);
|
||||
let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
|
||||
mac.update(&json);
|
||||
let signature = mac.finalize().into_bytes();
|
||||
let signature_b64 = URL_SAFE.encode(signature);
|
||||
format!("{}.{}", payload_b64, signature_b64)
|
||||
}
|
||||
|
||||
/// Verify a signed token and return the decoded payload if valid.
|
||||
fn verify_signed_token(token: &str, secret: &[u8]) -> Result<SessionPayload, TokenError> {
|
||||
let parts: Vec<&str> = token.split('.').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err(TokenError::InvalidFormat);
|
||||
}
|
||||
let payload_b64 = parts[0];
|
||||
let signature_b64 = parts[1];
|
||||
|
||||
let json = URL_SAFE.decode(payload_b64).map_err(|_| TokenError::InvalidFormat)?;
|
||||
let signature = URL_SAFE.decode(signature_b64).map_err(|_| TokenError::InvalidFormat)?;
|
||||
|
||||
// Verify HMAC
|
||||
let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
|
||||
mac.update(&json);
|
||||
// Convert signature slice to GenericArray
|
||||
let tag = GenericArray::from_slice(&signature);
|
||||
mac.verify(tag).map_err(|_| TokenError::InvalidSignature)?;
|
||||
|
||||
// Deserialize payload
|
||||
let payload: SessionPayload = serde_json::from_slice(&json).map_err(|_| TokenError::InvalidPayload)?;
|
||||
Ok(payload)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum TokenError {
|
||||
InvalidFormat,
|
||||
InvalidSignature,
|
||||
InvalidPayload,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::env;
|
||||
|
||||
#[test]
|
||||
fn test_sign_and_verify_token() {
|
||||
let secret = b"test-secret-must-be-32-bytes-long!";
|
||||
let payload = SessionPayload {
|
||||
session_id: "test-session".to_string(),
|
||||
username: "testuser".to_string(),
|
||||
role: "user".to_string(),
|
||||
iat: 1000,
|
||||
exp: 2000,
|
||||
version: TOKEN_VERSION.to_string(),
|
||||
};
|
||||
let token = sign_token(&payload, secret);
|
||||
let verified = verify_signed_token(&token, secret).unwrap();
|
||||
assert_eq!(verified.session_id, payload.session_id);
|
||||
assert_eq!(verified.username, payload.username);
|
||||
assert_eq!(verified.role, payload.role);
|
||||
assert_eq!(verified.iat, payload.iat);
|
||||
assert_eq!(verified.exp, payload.exp);
|
||||
assert_eq!(verified.version, payload.version);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tampered_token() {
|
||||
let secret = b"test-secret-must-be-32-bytes-long!";
|
||||
let payload = SessionPayload {
|
||||
session_id: "test-session".to_string(),
|
||||
username: "testuser".to_string(),
|
||||
role: "user".to_string(),
|
||||
iat: 1000,
|
||||
exp: 2000,
|
||||
version: TOKEN_VERSION.to_string(),
|
||||
};
|
||||
let mut token = sign_token(&payload, secret);
|
||||
// Tamper with payload part
|
||||
let mut parts: Vec<&str> = token.split('.').collect();
|
||||
let mut payload_bytes = URL_SAFE.decode(parts[0]).unwrap();
|
||||
payload_bytes[0] ^= 0xFF; // flip some bits
|
||||
let tampered_payload = URL_SAFE.encode(payload_bytes);
|
||||
parts[0] = &tampered_payload;
|
||||
token = parts.join(".");
|
||||
assert!(verify_signed_token(&token, secret).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_session_secret_from_env() {
|
||||
unsafe {
|
||||
env::set_var("SESSION_SECRET", hex::encode([0xAA; 32]));
|
||||
}
|
||||
let secret = load_session_secret();
|
||||
assert_eq!(secret, vec![0xAA; 32]);
|
||||
unsafe {
|
||||
env::remove_var("SESSION_SECRET");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -279,9 +279,10 @@ pub(super) async fn handle_system_backup(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
let backup_id = format!("backup-{}", chrono::Utc::now().timestamp());
|
||||
@@ -341,9 +342,10 @@ pub(super) async fn handle_update_settings(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
Json(ApiResponse::error(
|
||||
"Changing settings at runtime is not yet supported. Please update your config file and restart the server."
|
||||
|
||||
@@ -14,9 +14,10 @@ pub(super) async fn handle_get_users(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (session, _) = match auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
@@ -66,9 +67,10 @@ pub(super) async fn handle_create_user(
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(payload): Json<CreateUserRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (session, _) = match auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
@@ -147,9 +149,10 @@ pub(super) async fn handle_update_user(
|
||||
Path(id): Path<i64>,
|
||||
Json(payload): Json<UpdateUserRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (session, _) = match auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
@@ -249,8 +252,8 @@ pub(super) async fn handle_delete_user(
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(id): Path<i64>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let session = match auth::require_admin(&state, &headers).await {
|
||||
Ok(s) => s,
|
||||
let (session, _) = match auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user