chore: initial clean commit

This commit is contained in:
2026-02-26 13:56:21 -05:00
commit 1755075657
53 changed files with 18068 additions and 0 deletions

359
src/rate_limiting/mod.rs Normal file
View File

@@ -0,0 +1,359 @@
//! Rate limiting and circuit breaking for LLM proxy
//!
//! This module provides:
//! 1. Per-client rate limiting using governor crate
//! 2. Provider circuit breaking to handle API failures
//! 3. Global rate limiting for overall system protection
use std::sync::Arc;
use std::collections::HashMap;
use std::time::Instant;
use tokio::sync::RwLock;
use tracing::{info, warn};
use anyhow::Result;
/// Rate limiter configuration
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
/// Requests per minute per client
pub requests_per_minute: u32,
/// Burst size (maximum burst capacity)
pub burst_size: u32,
/// Global requests per minute (across all clients)
pub global_requests_per_minute: u32,
}
impl Default for RateLimiterConfig {
fn default() -> Self {
Self {
requests_per_minute: 60, // 1 request per second per client
burst_size: 10, // Allow bursts of up to 10 requests
global_requests_per_minute: 600, // 10 requests per second globally
}
}
}
/// Circuit breaker state
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CircuitState {
Closed, // Normal operation
Open, // Circuit is open, requests fail fast
HalfOpen, // Testing if service has recovered
}
/// Circuit breaker configuration
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
/// Failure threshold to open circuit
pub failure_threshold: u32,
/// Time window for failure counting (seconds)
pub failure_window_secs: u64,
/// Time to wait before trying half-open state (seconds)
pub reset_timeout_secs: u64,
/// Success threshold to close circuit
pub success_threshold: u32,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5, // 5 failures
failure_window_secs: 60, // within 60 seconds
reset_timeout_secs: 30, // wait 30 seconds before half-open
success_threshold: 3, // 3 successes to close circuit
}
}
}
/// Simple token bucket rate limiter for a single client
#[derive(Debug)]
struct TokenBucket {
tokens: f64,
capacity: f64,
refill_rate: f64, // tokens per second
last_refill: Instant,
}
impl TokenBucket {
fn new(capacity: f64, refill_rate: f64) -> Self {
Self {
tokens: capacity,
capacity,
refill_rate,
last_refill: Instant::now(),
}
}
fn refill(&mut self) {
let now = Instant::now();
let elapsed = now.duration_since(self.last_refill).as_secs_f64();
let new_tokens = elapsed * self.refill_rate;
self.tokens = (self.tokens + new_tokens).min(self.capacity);
self.last_refill = now;
}
fn try_acquire(&mut self, tokens: f64) -> bool {
self.refill();
if self.tokens >= tokens {
self.tokens -= tokens;
true
} else {
false
}
}
}
/// Circuit breaker for a provider
#[derive(Debug)]
pub struct ProviderCircuitBreaker {
state: CircuitState,
failure_count: u32,
success_count: u32,
last_failure_time: Option<std::time::Instant>,
last_state_change: std::time::Instant,
config: CircuitBreakerConfig,
}
impl ProviderCircuitBreaker {
pub fn new(config: CircuitBreakerConfig) -> Self {
Self {
state: CircuitState::Closed,
failure_count: 0,
success_count: 0,
last_failure_time: None,
last_state_change: std::time::Instant::now(),
config,
}
}
/// Check if request is allowed
pub fn allow_request(&mut self) -> bool {
match self.state {
CircuitState::Closed => true,
CircuitState::Open => {
// Check if reset timeout has passed
let elapsed = self.last_state_change.elapsed();
if elapsed.as_secs() >= self.config.reset_timeout_secs {
self.state = CircuitState::HalfOpen;
self.last_state_change = std::time::Instant::now();
info!("Circuit breaker transitioning to half-open state");
true
} else {
false
}
}
CircuitState::HalfOpen => true,
}
}
/// Record a successful request
pub fn record_success(&mut self) {
match self.state {
CircuitState::Closed => {
// Reset failure count on success
self.failure_count = 0;
self.last_failure_time = None;
}
CircuitState::HalfOpen => {
self.success_count += 1;
if self.success_count >= self.config.success_threshold {
self.state = CircuitState::Closed;
self.success_count = 0;
self.failure_count = 0;
self.last_state_change = std::time::Instant::now();
info!("Circuit breaker closed after successful requests");
}
}
CircuitState::Open => {
// Should not happen, but handle gracefully
}
}
}
/// Record a failed request
pub fn record_failure(&mut self) {
let now = std::time::Instant::now();
// Check if failure window has expired
if let Some(last_failure) = self.last_failure_time {
if now.duration_since(last_failure).as_secs() > self.config.failure_window_secs {
// Reset failure count if window expired
self.failure_count = 0;
}
}
self.failure_count += 1;
self.last_failure_time = Some(now);
if self.failure_count >= self.config.failure_threshold && self.state == CircuitState::Closed {
self.state = CircuitState::Open;
self.last_state_change = now;
warn!("Circuit breaker opened due to {} failures", self.failure_count);
} else if self.state == CircuitState::HalfOpen {
// Failure in half-open state, go back to open
self.state = CircuitState::Open;
self.success_count = 0;
self.last_state_change = now;
warn!("Circuit breaker re-opened after failure in half-open state");
}
}
/// Get current state
pub fn state(&self) -> CircuitState {
self.state
}
}
/// Rate limiting and circuit breaking manager
#[derive(Debug)]
pub struct RateLimitManager {
client_buckets: Arc<RwLock<HashMap<String, TokenBucket>>>,
global_bucket: Arc<RwLock<TokenBucket>>,
circuit_breakers: Arc<RwLock<HashMap<String, ProviderCircuitBreaker>>>,
config: RateLimiterConfig,
circuit_config: CircuitBreakerConfig,
}
impl RateLimitManager {
pub fn new(config: RateLimiterConfig, circuit_config: CircuitBreakerConfig) -> Self {
// Convert requests per minute to tokens per second
let global_refill_rate = config.global_requests_per_minute as f64 / 60.0;
Self {
client_buckets: Arc::new(RwLock::new(HashMap::new())),
global_bucket: Arc::new(RwLock::new(TokenBucket::new(
config.burst_size as f64,
global_refill_rate,
))),
circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
config,
circuit_config,
}
}
/// Check if a client request is allowed
pub async fn check_client_request(&self, client_id: &str) -> Result<bool> {
// Check global rate limit first (1 token per request)
{
let mut global_bucket = self.global_bucket.write().await;
if !global_bucket.try_acquire(1.0) {
warn!("Global rate limit exceeded");
return Ok(false);
}
}
// Check client-specific rate limit
let mut buckets = self.client_buckets.write().await;
let bucket = buckets
.entry(client_id.to_string())
.or_insert_with(|| {
TokenBucket::new(
self.config.burst_size as f64,
self.config.requests_per_minute as f64 / 60.0,
)
});
Ok(bucket.try_acquire(1.0))
}
/// Check if provider requests are allowed (circuit breaker)
pub async fn check_provider_request(&self, provider_name: &str) -> Result<bool> {
let mut breakers = self.circuit_breakers.write().await;
let breaker = breakers
.entry(provider_name.to_string())
.or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone()));
Ok(breaker.allow_request())
}
/// Record provider success
pub async fn record_provider_success(&self, provider_name: &str) {
let mut breakers = self.circuit_breakers.write().await;
if let Some(breaker) = breakers.get_mut(provider_name) {
breaker.record_success();
}
}
/// Record provider failure
pub async fn record_provider_failure(&self, provider_name: &str) {
let mut breakers = self.circuit_breakers.write().await;
let breaker = breakers
.entry(provider_name.to_string())
.or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone()));
breaker.record_failure();
}
/// Get provider circuit state
pub async fn get_provider_state(&self, provider_name: &str) -> CircuitState {
let breakers = self.circuit_breakers.read().await;
breakers
.get(provider_name)
.map(|b| b.state())
.unwrap_or(CircuitState::Closed)
}
}
/// Axum middleware for rate limiting
pub mod middleware {
use super::*;
use axum::{
extract::{Request, State},
middleware::Next,
response::Response,
};
use crate::errors::AppError;
use crate::state::AppState;
/// Rate limiting middleware
pub async fn rate_limit_middleware(
State(state): State<AppState>,
request: Request,
next: Next,
) -> Result<Response, AppError> {
// Extract client ID from authentication header
let client_id = extract_client_id_from_request(&request);
// Check rate limits
if !state.rate_limit_manager.check_client_request(&client_id).await? {
return Err(AppError::RateLimitError(
"Rate limit exceeded".to_string()
));
}
Ok(next.run(request).await)
}
/// Extract client ID from request (helper function)
fn extract_client_id_from_request(request: &Request) -> String {
// Try to extract from Authorization header
if let Some(auth_header) = request.headers().get("Authorization") {
if let Ok(auth_str) = auth_header.to_str() {
if auth_str.starts_with("Bearer ") {
let token = &auth_str[7..];
// Use token hash as client ID (same logic as auth module)
return format!("client_{}", &token[..8.min(token.len())]);
}
}
}
// Fallback to anonymous
"anonymous".to_string()
}
/// Circuit breaker middleware for provider requests
pub async fn circuit_breaker_middleware(
provider_name: &str,
state: &AppState,
) -> Result<(), AppError> {
if !state.rate_limit_manager.check_provider_request(provider_name).await? {
return Err(AppError::ProviderError(
format!("Provider {} is currently unavailable (circuit breaker open)", provider_name)
));
}
Ok(())
}
}