//! Rate limiting and circuit breaking for LLM proxy //! //! This module provides: //! 1. Per-client rate limiting using governor crate //! 2. Provider circuit breaking to handle API failures //! 3. Global rate limiting for overall system protection use anyhow::Result; use governor::{Quota, RateLimiter, DefaultDirectRateLimiter}; use std::collections::HashMap; use std::num::NonZeroU32; use std::sync::Arc; use tokio::sync::RwLock; use tracing::{info, warn}; type GovRateLimiter = DefaultDirectRateLimiter; /// Rate limiter configuration #[derive(Debug, Clone)] pub struct RateLimiterConfig { /// Requests per minute per client pub requests_per_minute: u32, /// Burst size (maximum burst capacity) pub burst_size: u32, /// Global requests per minute (across all clients) pub global_requests_per_minute: u32, } impl Default for RateLimiterConfig { fn default() -> Self { Self { requests_per_minute: 60, // 1 request per second per client burst_size: 10, // Allow bursts of up to 10 requests global_requests_per_minute: 600, // 10 requests per second globally } } } /// Circuit breaker state #[derive(Debug, Clone, Copy, PartialEq)] pub enum CircuitState { Closed, // Normal operation Open, // Circuit is open, requests fail fast HalfOpen, // Testing if service has recovered } /// Circuit breaker configuration #[derive(Debug, Clone)] pub struct CircuitBreakerConfig { /// Failure threshold to open circuit pub failure_threshold: u32, /// Time window for failure counting (seconds) pub failure_window_secs: u64, /// Time to wait before trying half-open state (seconds) pub reset_timeout_secs: u64, /// Success threshold to close circuit pub success_threshold: u32, } impl Default for CircuitBreakerConfig { fn default() -> Self { Self { failure_threshold: 5, // 5 failures failure_window_secs: 60, // within 60 seconds reset_timeout_secs: 30, // wait 30 seconds before half-open success_threshold: 3, // 3 successes to close circuit } } } /// Circuit breaker for a provider #[derive(Debug)] pub struct ProviderCircuitBreaker { state: CircuitState, failure_count: u32, success_count: u32, last_failure_time: Option, last_state_change: std::time::Instant, config: CircuitBreakerConfig, } impl ProviderCircuitBreaker { pub fn new(config: CircuitBreakerConfig) -> Self { Self { state: CircuitState::Closed, failure_count: 0, success_count: 0, last_failure_time: None, last_state_change: std::time::Instant::now(), config, } } /// Check if request is allowed pub fn allow_request(&mut self) -> bool { match self.state { CircuitState::Closed => true, CircuitState::Open => { // Check if reset timeout has passed let elapsed = self.last_state_change.elapsed(); if elapsed.as_secs() >= self.config.reset_timeout_secs { self.state = CircuitState::HalfOpen; self.last_state_change = std::time::Instant::now(); info!("Circuit breaker transitioning to half-open state"); true } else { false } } CircuitState::HalfOpen => true, } } /// Record a successful request pub fn record_success(&mut self) { match self.state { CircuitState::Closed => { // Reset failure count on success self.failure_count = 0; self.last_failure_time = None; } CircuitState::HalfOpen => { self.success_count += 1; if self.success_count >= self.config.success_threshold { self.state = CircuitState::Closed; self.success_count = 0; self.failure_count = 0; self.last_state_change = std::time::Instant::now(); info!("Circuit breaker closed after successful requests"); } } CircuitState::Open => { // Should not happen, but handle gracefully } } } /// Record a failed request pub fn record_failure(&mut self) { let now = std::time::Instant::now(); // Check if failure window has expired if let Some(last_failure) = self.last_failure_time && now.duration_since(last_failure).as_secs() > self.config.failure_window_secs { // Reset failure count if window expired self.failure_count = 0; } self.failure_count += 1; self.last_failure_time = Some(now); if self.failure_count >= self.config.failure_threshold && self.state == CircuitState::Closed { self.state = CircuitState::Open; self.last_state_change = now; warn!("Circuit breaker opened due to {} failures", self.failure_count); } else if self.state == CircuitState::HalfOpen { // Failure in half-open state, go back to open self.state = CircuitState::Open; self.success_count = 0; self.last_state_change = now; warn!("Circuit breaker re-opened after failure in half-open state"); } } /// Get current state pub fn state(&self) -> CircuitState { self.state } } /// Rate limiting and circuit breaking manager #[derive(Debug)] pub struct RateLimitManager { client_buckets: Arc>>, global_bucket: Arc, circuit_breakers: Arc>>, config: RateLimiterConfig, circuit_config: CircuitBreakerConfig, } impl RateLimitManager { pub fn new(config: RateLimiterConfig, circuit_config: CircuitBreakerConfig) -> Self { // Create global rate limiter quota // Use a much larger burst size for the global bucket to handle concurrent dashboard load let global_burst = config.global_requests_per_minute / 6; // e.g., 100 for 600 req/min let global_quota = Quota::per_minute( NonZeroU32::new(config.global_requests_per_minute).expect("global_requests_per_minute must be positive") ) .allow_burst(NonZeroU32::new(global_burst).expect("global_burst must be positive")); let global_bucket = RateLimiter::direct(global_quota); Self { client_buckets: Arc::new(RwLock::new(HashMap::new())), global_bucket: Arc::new(global_bucket), circuit_breakers: Arc::new(RwLock::new(HashMap::new())), config, circuit_config, } } /// Check if a client request is allowed pub async fn check_client_request(&self, client_id: &str) -> Result { // Check global rate limit first (1 token per request) if self.global_bucket.check().is_err() { warn!("Global rate limit exceeded"); return Ok(false); } // Check client-specific rate limit let mut buckets = self.client_buckets.write().await; let bucket = buckets.entry(client_id.to_string()).or_insert_with(|| { let quota = Quota::per_minute( NonZeroU32::new(self.config.requests_per_minute).expect("requests_per_minute must be positive") ) .allow_burst(NonZeroU32::new(self.config.burst_size).expect("burst_size must be positive")); RateLimiter::direct(quota) }); Ok(bucket.check().is_ok()) } /// Check if provider requests are allowed (circuit breaker) pub async fn check_provider_request(&self, provider_name: &str) -> Result { let mut breakers = self.circuit_breakers.write().await; let breaker = breakers .entry(provider_name.to_string()) .or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone())); Ok(breaker.allow_request()) } /// Record provider success pub async fn record_provider_success(&self, provider_name: &str) { let mut breakers = self.circuit_breakers.write().await; if let Some(breaker) = breakers.get_mut(provider_name) { breaker.record_success(); } } /// Record provider failure pub async fn record_provider_failure(&self, provider_name: &str) { let mut breakers = self.circuit_breakers.write().await; let breaker = breakers .entry(provider_name.to_string()) .or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone())); breaker.record_failure(); } /// Get provider circuit state pub async fn get_provider_state(&self, provider_name: &str) -> CircuitState { let breakers = self.circuit_breakers.read().await; breakers .get(provider_name) .map(|b| b.state()) .unwrap_or(CircuitState::Closed) } } /// Axum middleware for rate limiting pub mod middleware { use super::*; use crate::errors::AppError; use crate::state::AppState; use crate::auth::AuthInfo; use axum::{ extract::{Request, State}, middleware::Next, response::Response, }; use sqlx; /// Rate limiting middleware pub async fn rate_limit_middleware( State(state): State, mut request: Request, next: Next, ) -> Result { // Extract token synchronously from headers (avoids holding &Request across await) let token = extract_bearer_token(&request); // Resolve client_id and populate AuthInfo: DB token lookup, then prefix fallback let auth_info = resolve_auth_info(token, &state).await; let client_id = auth_info.client_id.clone(); // Check rate limits if !state.rate_limit_manager.check_client_request(&client_id).await? { return Err(AppError::RateLimitError("Rate limit exceeded".to_string())); } // Store AuthInfo in request extensions for extractors and downstream handlers request.extensions_mut().insert(auth_info); Ok(next.run(request).await) } /// Synchronously extract bearer token from request headers fn extract_bearer_token(request: &Request) -> Option { request.headers().get("Authorization") .and_then(|v| v.to_str().ok()) .and_then(|s| s.strip_prefix("Bearer ")) .map(|t| t.to_string()) } /// Resolve auth info: try DB token first, then fall back to token-prefix derivation async fn resolve_auth_info(token: Option, state: &AppState) -> AuthInfo { if let Some(token) = token { // Try DB token lookup first match sqlx::query_scalar::<_, String>( "UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = TRUE RETURNING client_id", ) .bind(&token) .fetch_optional(&state.db_pool) .await { Ok(Some(cid)) => { return AuthInfo { token, client_id: cid, }; } Err(e) => { warn!("DB error during token lookup: {}", e); } _ => {} } // Fallback to token-prefix derivation (env tokens / permissive mode) let client_id = format!("client_{}", &token[..8.min(token.len())]); return AuthInfo { token, client_id }; } // No token — anonymous AuthInfo { token: String::new(), client_id: "anonymous".to_string(), } } /// Circuit breaker middleware for provider requests pub async fn circuit_breaker_middleware(provider_name: &str, state: &AppState) -> Result<(), AppError> { if !state.rate_limit_manager.check_provider_request(provider_name).await? { return Err(AppError::ProviderError(format!( "Provider {} is currently unavailable (circuit breaker open)", provider_name ))); } Ok(()) } }