Add client_tokens table with auto-generated sk-{hex} tokens so clients
created in the dashboard get working API keys. Auth flow: DB token lookup
first, then env token fallback, then permissive mode. Includes token
management CRUD endpoints and copy-once reveal modal in the frontend.
370 lines
12 KiB
Rust
370 lines
12 KiB
Rust
//! Rate limiting and circuit breaking for LLM proxy
|
|
//!
|
|
//! This module provides:
|
|
//! 1. Per-client rate limiting using an in-process token bucket
|
|
//! 2. Provider circuit breaking to handle API failures
|
|
//! 3. Global rate limiting for overall system protection
|
|
|
|
use anyhow::Result;
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::time::Instant;
|
|
use tokio::sync::RwLock;
|
|
use tracing::{info, warn};
|
|
|
|
/// Token-bucket limits applied to incoming client requests.
///
/// `burst_size` is used as the capacity of every per-client bucket and of the
/// global bucket; the per-minute rates become per-second refill rates.
#[derive(Clone, Debug)]
pub struct RateLimiterConfig {
    /// Sustained per-client rate, in requests per minute.
    pub requests_per_minute: u32,
    /// Bucket capacity: how many requests may be issued back-to-back.
    pub burst_size: u32,
    /// Sustained rate across all clients combined, in requests per minute.
    pub global_requests_per_minute: u32,
}

impl Default for RateLimiterConfig {
    /// 60 req/min per client (1 req/s), bursts of up to 10, and
    /// 600 req/min (10 req/s) across all clients.
    fn default() -> Self {
        RateLimiterConfig {
            burst_size: 10,
            global_requests_per_minute: 600,
            requests_per_minute: 60,
        }
    }
}
|
|
|
|
/// Lifecycle state of a provider circuit breaker.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CircuitState {
    /// Normal operation: requests pass through.
    Closed,
    /// Tripped: requests are rejected immediately (fail fast).
    Open,
    /// Probing: requests pass through to test whether the provider has recovered.
    HalfOpen,
}
|
|
|
|
/// Thresholds and timings that drive a provider circuit breaker.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Number of failures that trips a closed circuit open.
    pub failure_threshold: u32,
    /// Length of the failure-counting window, in seconds.
    pub failure_window_secs: u64,
    /// How long an open circuit waits before allowing half-open probes, in seconds.
    pub reset_timeout_secs: u64,
    /// Number of half-open successes required to close the circuit again.
    pub success_threshold: u32,
}

impl Default for CircuitBreakerConfig {
    /// Trip after 5 failures within 60 s, probe again after 30 s,
    /// and close after 3 successful probes.
    fn default() -> Self {
        let failure_threshold = 5;
        let failure_window_secs = 60;
        let reset_timeout_secs = 30;
        let success_threshold = 3;
        CircuitBreakerConfig {
            failure_threshold,
            failure_window_secs,
            reset_timeout_secs,
            success_threshold,
        }
    }
}
|
|
|
|
/// Token-bucket limiter for a single key.
///
/// The bucket starts full, holds at most `capacity` tokens, and regains
/// `refill_rate` tokens per elapsed second (measured with the monotonic clock).
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available; may be fractional.
    tokens: f64,
    /// Upper bound on `tokens`.
    capacity: f64,
    /// Tokens credited per elapsed second.
    refill_rate: f64,
    /// When tokens were last credited.
    last_refill: Instant,
}

impl TokenBucket {
    /// Builds a full bucket with the given capacity and per-second refill rate.
    fn new(capacity: f64, refill_rate: f64) -> Self {
        let last_refill = Instant::now();
        TokenBucket {
            capacity,
            refill_rate,
            tokens: capacity,
            last_refill,
        }
    }

    /// Credits tokens for the time since the previous refill, clamped to `capacity`.
    fn refill(&mut self) {
        let current = Instant::now();
        let seconds = current.duration_since(self.last_refill).as_secs_f64();
        self.tokens = self.capacity.min(self.tokens + seconds * self.refill_rate);
        self.last_refill = current;
    }

    /// Refills, then takes `amount` tokens if that many are available.
    ///
    /// Returns `true` when the tokens were taken; on `false` the balance is untouched.
    fn try_acquire(&mut self, amount: f64) -> bool {
        self.refill();
        let granted = self.tokens >= amount;
        if granted {
            self.tokens -= amount;
        }
        granted
    }
}
|
|
|
|
/// Circuit breaker for a provider
|
|
#[derive(Debug)]
|
|
pub struct ProviderCircuitBreaker {
|
|
state: CircuitState,
|
|
failure_count: u32,
|
|
success_count: u32,
|
|
last_failure_time: Option<std::time::Instant>,
|
|
last_state_change: std::time::Instant,
|
|
config: CircuitBreakerConfig,
|
|
}
|
|
|
|
impl ProviderCircuitBreaker {
|
|
pub fn new(config: CircuitBreakerConfig) -> Self {
|
|
Self {
|
|
state: CircuitState::Closed,
|
|
failure_count: 0,
|
|
success_count: 0,
|
|
last_failure_time: None,
|
|
last_state_change: std::time::Instant::now(),
|
|
config,
|
|
}
|
|
}
|
|
|
|
/// Check if request is allowed
|
|
pub fn allow_request(&mut self) -> bool {
|
|
match self.state {
|
|
CircuitState::Closed => true,
|
|
CircuitState::Open => {
|
|
// Check if reset timeout has passed
|
|
let elapsed = self.last_state_change.elapsed();
|
|
if elapsed.as_secs() >= self.config.reset_timeout_secs {
|
|
self.state = CircuitState::HalfOpen;
|
|
self.last_state_change = std::time::Instant::now();
|
|
info!("Circuit breaker transitioning to half-open state");
|
|
true
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
CircuitState::HalfOpen => true,
|
|
}
|
|
}
|
|
|
|
/// Record a successful request
|
|
pub fn record_success(&mut self) {
|
|
match self.state {
|
|
CircuitState::Closed => {
|
|
// Reset failure count on success
|
|
self.failure_count = 0;
|
|
self.last_failure_time = None;
|
|
}
|
|
CircuitState::HalfOpen => {
|
|
self.success_count += 1;
|
|
if self.success_count >= self.config.success_threshold {
|
|
self.state = CircuitState::Closed;
|
|
self.success_count = 0;
|
|
self.failure_count = 0;
|
|
self.last_state_change = std::time::Instant::now();
|
|
info!("Circuit breaker closed after successful requests");
|
|
}
|
|
}
|
|
CircuitState::Open => {
|
|
// Should not happen, but handle gracefully
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Record a failed request
|
|
pub fn record_failure(&mut self) {
|
|
let now = std::time::Instant::now();
|
|
|
|
// Check if failure window has expired
|
|
if let Some(last_failure) = self.last_failure_time
|
|
&& now.duration_since(last_failure).as_secs() > self.config.failure_window_secs
|
|
{
|
|
// Reset failure count if window expired
|
|
self.failure_count = 0;
|
|
}
|
|
|
|
self.failure_count += 1;
|
|
self.last_failure_time = Some(now);
|
|
|
|
if self.failure_count >= self.config.failure_threshold && self.state == CircuitState::Closed {
|
|
self.state = CircuitState::Open;
|
|
self.last_state_change = now;
|
|
warn!("Circuit breaker opened due to {} failures", self.failure_count);
|
|
} else if self.state == CircuitState::HalfOpen {
|
|
// Failure in half-open state, go back to open
|
|
self.state = CircuitState::Open;
|
|
self.success_count = 0;
|
|
self.last_state_change = now;
|
|
warn!("Circuit breaker re-opened after failure in half-open state");
|
|
}
|
|
}
|
|
|
|
/// Get current state
|
|
pub fn state(&self) -> CircuitState {
|
|
self.state
|
|
}
|
|
}
|
|
|
|
/// Rate limiting and circuit breaking manager
|
|
#[derive(Debug)]
|
|
pub struct RateLimitManager {
|
|
client_buckets: Arc<RwLock<HashMap<String, TokenBucket>>>,
|
|
global_bucket: Arc<RwLock<TokenBucket>>,
|
|
circuit_breakers: Arc<RwLock<HashMap<String, ProviderCircuitBreaker>>>,
|
|
config: RateLimiterConfig,
|
|
circuit_config: CircuitBreakerConfig,
|
|
}
|
|
|
|
impl RateLimitManager {
|
|
pub fn new(config: RateLimiterConfig, circuit_config: CircuitBreakerConfig) -> Self {
|
|
// Convert requests per minute to tokens per second
|
|
let global_refill_rate = config.global_requests_per_minute as f64 / 60.0;
|
|
|
|
Self {
|
|
client_buckets: Arc::new(RwLock::new(HashMap::new())),
|
|
global_bucket: Arc::new(RwLock::new(TokenBucket::new(
|
|
config.burst_size as f64,
|
|
global_refill_rate,
|
|
))),
|
|
circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
|
|
config,
|
|
circuit_config,
|
|
}
|
|
}
|
|
|
|
/// Check if a client request is allowed
|
|
pub async fn check_client_request(&self, client_id: &str) -> Result<bool> {
|
|
// Check global rate limit first (1 token per request)
|
|
{
|
|
let mut global_bucket = self.global_bucket.write().await;
|
|
if !global_bucket.try_acquire(1.0) {
|
|
warn!("Global rate limit exceeded");
|
|
return Ok(false);
|
|
}
|
|
}
|
|
|
|
// Check client-specific rate limit
|
|
let mut buckets = self.client_buckets.write().await;
|
|
let bucket = buckets.entry(client_id.to_string()).or_insert_with(|| {
|
|
TokenBucket::new(
|
|
self.config.burst_size as f64,
|
|
self.config.requests_per_minute as f64 / 60.0,
|
|
)
|
|
});
|
|
|
|
Ok(bucket.try_acquire(1.0))
|
|
}
|
|
|
|
/// Check if provider requests are allowed (circuit breaker)
|
|
pub async fn check_provider_request(&self, provider_name: &str) -> Result<bool> {
|
|
let mut breakers = self.circuit_breakers.write().await;
|
|
let breaker = breakers
|
|
.entry(provider_name.to_string())
|
|
.or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone()));
|
|
|
|
Ok(breaker.allow_request())
|
|
}
|
|
|
|
/// Record provider success
|
|
pub async fn record_provider_success(&self, provider_name: &str) {
|
|
let mut breakers = self.circuit_breakers.write().await;
|
|
if let Some(breaker) = breakers.get_mut(provider_name) {
|
|
breaker.record_success();
|
|
}
|
|
}
|
|
|
|
/// Record provider failure
|
|
pub async fn record_provider_failure(&self, provider_name: &str) {
|
|
let mut breakers = self.circuit_breakers.write().await;
|
|
let breaker = breakers
|
|
.entry(provider_name.to_string())
|
|
.or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone()));
|
|
|
|
breaker.record_failure();
|
|
}
|
|
|
|
/// Get provider circuit state
|
|
pub async fn get_provider_state(&self, provider_name: &str) -> CircuitState {
|
|
let breakers = self.circuit_breakers.read().await;
|
|
breakers
|
|
.get(provider_name)
|
|
.map(|b| b.state())
|
|
.unwrap_or(CircuitState::Closed)
|
|
}
|
|
}
|
|
|
|
/// Axum middleware for rate limiting
|
|
pub mod middleware {
|
|
use super::*;
|
|
use crate::errors::AppError;
|
|
use crate::state::AppState;
|
|
use axum::{
|
|
extract::{Request, State},
|
|
middleware::Next,
|
|
response::Response,
|
|
};
|
|
use sqlx;
|
|
|
|
/// Rate limiting middleware
|
|
pub async fn rate_limit_middleware(
|
|
State(state): State<AppState>,
|
|
request: Request,
|
|
next: Next,
|
|
) -> Result<Response, AppError> {
|
|
// Extract token synchronously from headers (avoids holding &Request across await)
|
|
let token = extract_bearer_token(&request);
|
|
|
|
// Resolve client_id: DB token lookup, then prefix fallback
|
|
let client_id = resolve_client_id(token, &state).await;
|
|
|
|
// Check rate limits
|
|
if !state.rate_limit_manager.check_client_request(&client_id).await? {
|
|
return Err(AppError::RateLimitError("Rate limit exceeded".to_string()));
|
|
}
|
|
|
|
Ok(next.run(request).await)
|
|
}
|
|
|
|
/// Synchronously extract bearer token from request headers
|
|
fn extract_bearer_token(request: &Request) -> Option<String> {
|
|
request.headers().get("Authorization")
|
|
.and_then(|v| v.to_str().ok())
|
|
.and_then(|s| s.strip_prefix("Bearer "))
|
|
.map(|t| t.to_string())
|
|
}
|
|
|
|
/// Resolve client ID: try DB token first, then fall back to token-prefix derivation
|
|
async fn resolve_client_id(token: Option<String>, state: &AppState) -> String {
|
|
if let Some(token) = token {
|
|
// Try DB token lookup first
|
|
if let Ok(Some(cid)) = sqlx::query_scalar::<_, String>(
|
|
"SELECT client_id FROM client_tokens WHERE token = ? AND is_active = TRUE",
|
|
)
|
|
.bind(&token)
|
|
.fetch_optional(&state.db_pool)
|
|
.await
|
|
{
|
|
return cid;
|
|
}
|
|
|
|
// Fallback to token-prefix derivation (env tokens / permissive mode)
|
|
return format!("client_{}", &token[..8.min(token.len())]);
|
|
}
|
|
|
|
// No token — anonymous
|
|
"anonymous".to_string()
|
|
}
|
|
|
|
/// Circuit breaker middleware for provider requests
|
|
pub async fn circuit_breaker_middleware(provider_name: &str, state: &AppState) -> Result<(), AppError> {
|
|
if !state.rate_limit_manager.check_provider_request(provider_name).await? {
|
|
return Err(AppError::ProviderError(format!(
|
|
"Provider {} is currently unavailable (circuit breaker open)",
|
|
provider_name
|
|
)));
|
|
}
|
|
Ok(())
|
|
}
|
|
}
|