//! Rate limiting and circuit breaking for LLM proxy //! //! This module provides: //! 1. Per-client rate limiting using governor crate //! 2. Provider circuit breaking to handle API failures //! 3. Global rate limiting for overall system protection use anyhow::Result; use std::collections::HashMap; use std::sync::Arc; use std::time::Instant; use tokio::sync::RwLock; use tracing::{info, warn}; /// Rate limiter configuration #[derive(Debug, Clone)] pub struct RateLimiterConfig { /// Requests per minute per client pub requests_per_minute: u32, /// Burst size (maximum burst capacity) pub burst_size: u32, /// Global requests per minute (across all clients) pub global_requests_per_minute: u32, } impl Default for RateLimiterConfig { fn default() -> Self { Self { requests_per_minute: 60, // 1 request per second per client burst_size: 10, // Allow bursts of up to 10 requests global_requests_per_minute: 600, // 10 requests per second globally } } } /// Circuit breaker state #[derive(Debug, Clone, Copy, PartialEq)] pub enum CircuitState { Closed, // Normal operation Open, // Circuit is open, requests fail fast HalfOpen, // Testing if service has recovered } /// Circuit breaker configuration #[derive(Debug, Clone)] pub struct CircuitBreakerConfig { /// Failure threshold to open circuit pub failure_threshold: u32, /// Time window for failure counting (seconds) pub failure_window_secs: u64, /// Time to wait before trying half-open state (seconds) pub reset_timeout_secs: u64, /// Success threshold to close circuit pub success_threshold: u32, } impl Default for CircuitBreakerConfig { fn default() -> Self { Self { failure_threshold: 5, // 5 failures failure_window_secs: 60, // within 60 seconds reset_timeout_secs: 30, // wait 30 seconds before half-open success_threshold: 3, // 3 successes to close circuit } } } /// Simple token bucket rate limiter for a single client #[derive(Debug)] struct TokenBucket { tokens: f64, capacity: f64, refill_rate: f64, // tokens per second last_refill: Instant, } impl TokenBucket { fn new(capacity: f64, refill_rate: f64) -> Self { Self { tokens: capacity, capacity, refill_rate, last_refill: Instant::now(), } } fn refill(&mut self) { let now = Instant::now(); let elapsed = now.duration_since(self.last_refill).as_secs_f64(); let new_tokens = elapsed * self.refill_rate; self.tokens = (self.tokens + new_tokens).min(self.capacity); self.last_refill = now; } fn try_acquire(&mut self, tokens: f64) -> bool { self.refill(); if self.tokens >= tokens { self.tokens -= tokens; true } else { false } } } /// Circuit breaker for a provider #[derive(Debug)] pub struct ProviderCircuitBreaker { state: CircuitState, failure_count: u32, success_count: u32, last_failure_time: Option, last_state_change: std::time::Instant, config: CircuitBreakerConfig, } impl ProviderCircuitBreaker { pub fn new(config: CircuitBreakerConfig) -> Self { Self { state: CircuitState::Closed, failure_count: 0, success_count: 0, last_failure_time: None, last_state_change: std::time::Instant::now(), config, } } /// Check if request is allowed pub fn allow_request(&mut self) -> bool { match self.state { CircuitState::Closed => true, CircuitState::Open => { // Check if reset timeout has passed let elapsed = self.last_state_change.elapsed(); if elapsed.as_secs() >= self.config.reset_timeout_secs { self.state = CircuitState::HalfOpen; self.last_state_change = std::time::Instant::now(); info!("Circuit breaker transitioning to half-open state"); true } else { false } } CircuitState::HalfOpen => true, } } /// Record a successful request pub fn record_success(&mut self) { match self.state { CircuitState::Closed => { // Reset failure count on success self.failure_count = 0; self.last_failure_time = None; } CircuitState::HalfOpen => { self.success_count += 1; if self.success_count >= self.config.success_threshold { self.state = CircuitState::Closed; self.success_count = 0; self.failure_count = 0; self.last_state_change = std::time::Instant::now(); info!("Circuit breaker closed after successful requests"); } } CircuitState::Open => { // Should not happen, but handle gracefully } } } /// Record a failed request pub fn record_failure(&mut self) { let now = std::time::Instant::now(); // Check if failure window has expired if let Some(last_failure) = self.last_failure_time && now.duration_since(last_failure).as_secs() > self.config.failure_window_secs { // Reset failure count if window expired self.failure_count = 0; } self.failure_count += 1; self.last_failure_time = Some(now); if self.failure_count >= self.config.failure_threshold && self.state == CircuitState::Closed { self.state = CircuitState::Open; self.last_state_change = now; warn!("Circuit breaker opened due to {} failures", self.failure_count); } else if self.state == CircuitState::HalfOpen { // Failure in half-open state, go back to open self.state = CircuitState::Open; self.success_count = 0; self.last_state_change = now; warn!("Circuit breaker re-opened after failure in half-open state"); } } /// Get current state pub fn state(&self) -> CircuitState { self.state } } /// Rate limiting and circuit breaking manager #[derive(Debug)] pub struct RateLimitManager { client_buckets: Arc>>, global_bucket: Arc>, circuit_breakers: Arc>>, config: RateLimiterConfig, circuit_config: CircuitBreakerConfig, } impl RateLimitManager { pub fn new(config: RateLimiterConfig, circuit_config: CircuitBreakerConfig) -> Self { // Convert requests per minute to tokens per second let global_refill_rate = config.global_requests_per_minute as f64 / 60.0; Self { client_buckets: Arc::new(RwLock::new(HashMap::new())), global_bucket: Arc::new(RwLock::new(TokenBucket::new( config.burst_size as f64, global_refill_rate, ))), circuit_breakers: Arc::new(RwLock::new(HashMap::new())), config, circuit_config, } } /// Check if a client request is allowed pub async fn check_client_request(&self, client_id: &str) -> Result { // Check global rate limit first (1 token per request) { let mut global_bucket = self.global_bucket.write().await; if !global_bucket.try_acquire(1.0) { warn!("Global rate limit exceeded"); return Ok(false); } } // Check client-specific rate limit let mut buckets = self.client_buckets.write().await; let bucket = buckets.entry(client_id.to_string()).or_insert_with(|| { TokenBucket::new( self.config.burst_size as f64, self.config.requests_per_minute as f64 / 60.0, ) }); Ok(bucket.try_acquire(1.0)) } /// Check if provider requests are allowed (circuit breaker) pub async fn check_provider_request(&self, provider_name: &str) -> Result { let mut breakers = self.circuit_breakers.write().await; let breaker = breakers .entry(provider_name.to_string()) .or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone())); Ok(breaker.allow_request()) } /// Record provider success pub async fn record_provider_success(&self, provider_name: &str) { let mut breakers = self.circuit_breakers.write().await; if let Some(breaker) = breakers.get_mut(provider_name) { breaker.record_success(); } } /// Record provider failure pub async fn record_provider_failure(&self, provider_name: &str) { let mut breakers = self.circuit_breakers.write().await; let breaker = breakers .entry(provider_name.to_string()) .or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone())); breaker.record_failure(); } /// Get provider circuit state pub async fn get_provider_state(&self, provider_name: &str) -> CircuitState { let breakers = self.circuit_breakers.read().await; breakers .get(provider_name) .map(|b| b.state()) .unwrap_or(CircuitState::Closed) } } /// Axum middleware for rate limiting pub mod middleware { use super::*; use crate::errors::AppError; use crate::state::AppState; use axum::{ extract::{Request, State}, middleware::Next, response::Response, }; use sqlx; /// Rate limiting middleware pub async fn rate_limit_middleware( State(state): State, request: Request, next: Next, ) -> Result { // Extract token synchronously from headers (avoids holding &Request across await) let token = extract_bearer_token(&request); // Resolve client_id: DB token lookup, then prefix fallback let client_id = resolve_client_id(token, &state).await; // Check rate limits if !state.rate_limit_manager.check_client_request(&client_id).await? { return Err(AppError::RateLimitError("Rate limit exceeded".to_string())); } Ok(next.run(request).await) } /// Synchronously extract bearer token from request headers fn extract_bearer_token(request: &Request) -> Option { request.headers().get("Authorization") .and_then(|v| v.to_str().ok()) .and_then(|s| s.strip_prefix("Bearer ")) .map(|t| t.to_string()) } /// Resolve client ID: try DB token first, then fall back to token-prefix derivation async fn resolve_client_id(token: Option, state: &AppState) -> String { if let Some(token) = token { // Try DB token lookup first if let Ok(Some(cid)) = sqlx::query_scalar::<_, String>( "SELECT client_id FROM client_tokens WHERE token = ? AND is_active = TRUE", ) .bind(&token) .fetch_optional(&state.db_pool) .await { return cid; } // Fallback to token-prefix derivation (env tokens / permissive mode) return format!("client_{}", &token[..8.min(token.len())]); } // No token — anonymous "anonymous".to_string() } /// Circuit breaker middleware for provider requests pub async fn circuit_breaker_middleware(provider_name: &str, state: &AppState) -> Result<(), AppError> { if !state.rate_limit_manager.check_provider_request(provider_name).await? { return Err(AppError::ProviderError(format!( "Provider {} is currently unavailable (circuit breaker open)", provider_name ))); } Ok(()) } }