Compare commits

...

3 Commits

Author SHA1 Message Date
633b69a07b docs: sync documentation with current implementation and archive stale plan
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled
2026-03-06 14:28:04 -05:00
975ae124d1 merge 2026-03-06 14:21:58 -05:00
9b8483e797 feat(security): implement AES-256-GCM encryption for API keys and HMAC-signed session tokens
This commit introduces:
- AES-256-GCM encryption for LLM provider API keys in the database.
- HMAC-SHA256 signed session tokens with activity-based refresh logic.
- Standardized frontend XSS protection using a global escapeHtml utility.
- Hardened security headers and request body size limits.
- Improved database integrity with foreign key enforcement and atomic transactions.
- Integration tests for the full encrypted key storage and proxy usage lifecycle.
2026-03-06 14:17:56 -05:00
36 changed files with 2027 additions and 305 deletions


@@ -26,3 +26,6 @@ LLM_PROXY__SERVER__PORT=8080
# Database path (optional)
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
# Session secret for HMAC-signed tokens (hex or base64 encoded, 32 bytes)
SESSION_SECRET=your_session_secret_here_32_bytes
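Since the signer expects 32 raw bytes, the secret can be validated at startup. A minimal stdlib-only sketch for the hex-encoded case (the function name and error handling are illustrative, not taken from the repository; a real implementation would likely use the `hex` crate already present in `Cargo.lock`):

```rust
// Hypothetical helper: decode a 64-character hex SESSION_SECRET into 32 raw bytes.
fn decode_hex_secret(s: &str) -> Result<[u8; 32], String> {
    let s = s.trim();
    if s.len() != 64 {
        return Err(format!("expected 64 hex chars (32 bytes), got {}", s.len()));
    }
    let mut out = [0u8; 32];
    for (i, pair) in s.as_bytes().chunks(2).enumerate() {
        // `?` converts the &str error into String via From<&str>.
        let hi = (pair[0] as char).to_digit(16).ok_or("invalid hex digit")?;
        let lo = (pair[1] as char).to_digit(16).ok_or("invalid hex digit")?;
        out[i] = ((hi << 4) | lo) as u8;
    }
    Ok(out)
}
```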

CODE_REVIEW_PLAN.md Normal file

@@ -0,0 +1,65 @@
# LLM Proxy Code Review Plan
## Overview
The **LLM Proxy** project is a Rust-based middleware designed to provide a unified interface for multiple Large Language Models (LLMs). Based on the repository structure, the project aims to implement a high-performance proxy server (`src/`) that handles request routing, usage tracking, and billing logic. A static dashboard (`static/`) provides a management interface for monitoring consumption and managing API keys. The architecture leverages Rust's async capabilities for efficient request handling and SQLite for persistent state management.
## Review Phases
### Phase 1: Backend Architecture & Rust Logic (@code-reviewer)
- **Focus on:**
- **Core Proxy Logic:** Efficiency of the request/response pipeline and streaming support.
- **State Management:** Thread-safety and shared state patterns using `Arc` and `Mutex`/`RwLock`.
- **Error Handling:** Use of idiomatic Rust error types and propagation.
- **Async Performance:** Proper use of `tokio` or similar runtimes to avoid blocking the executor.
- **Rust Idioms:** Adherence to Clippy suggestions and standard Rust naming conventions.
### Phase 2: Security & Authentication Audit (@security-auditor)
- **Focus on:**
- **API Key Management:** Secure storage, masking in logs, and rotation mechanisms.
- **JWT Handling:** Validation logic, signature verification, and expiration checks.
- **Input Validation:** Sanitization of prompts and configuration parameters to prevent injection.
- **Dependency Audit:** Scanning for known vulnerabilities in the `Cargo.lock` using `cargo-audit`.
### Phase 3: Database & Data Integrity Review (@database-optimizer)
- **Focus on:**
- **Schema Design:** Efficiency of the SQLite schema for usage tracking and billing.
- **Migration Strategy:** Robustness of the migration scripts to prevent data loss.
- **Usage Tracking:** Accuracy of token counting and concurrency handling during increments.
- **Query Optimization:** Identifying potential bottlenecks in reporting queries.
### Phase 4: Frontend & Dashboard Review (@frontend-developer)
- **Focus on:**
- **Vanilla JS Patterns:** Review of Web Components and modular JS in `static/js`.
- **Security:** Protection against XSS in the dashboard and secure handling of local storage.
- **UI/UX Consistency:** Ensuring the management interface is intuitive and responsive.
- **API Integration:** Robustness of the frontend's communication with the Rust backend.
### Phase 5: Infrastructure & Deployment Review (@devops-engineer)
- **Focus on:**
- **Dockerfile Optimization:** Multi-stage builds to minimize image size and attack surface.
- **Resource Limits:** Configuration of CPU/Memory limits for the proxy container.
- **Deployment Docs:** Clarity of the setup process and environment variable documentation.
## Timeline (Gantt)
```mermaid
gantt
title LLM Proxy Code Review Timeline (March 2026)
dateFormat YYYY-MM-DD
section Backend & Security
Architecture & Rust Logic (Phase 1) :active, p1, 2026-03-06, 1d
Security & Auth Audit (Phase 2) :p2, 2026-03-07, 1d
section Data & Frontend
Database & Integrity (Phase 3) :p3, 2026-03-07, 1d
Frontend & Dashboard (Phase 4) :p4, 2026-03-08, 1d
section DevOps
Infra & Deployment (Phase 5) :p5, 2026-03-08, 1d
Final Review & Sign-off :2026-03-08, 4h
```
## Success Criteria
- **Security:** Zero high-priority vulnerabilities identified; all API keys masked in logs.
- **Performance:** Proxy overhead is minimal (<10ms latency addition); queries are indexed.
- **Maintainability:** Code passes all linting (`cargo clippy`) and formatting (`cargo fmt`) checks.
- **Documentation:** README and deployment guides are up-to-date and accurate.
- **Reliability:** Usage tracking matches actual API consumption with 99.9% accuracy.

Cargo.lock generated

@@ -8,6 +8,41 @@ version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa"
[[package]]
name = "aead"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0"
dependencies = [
"crypto-common",
"generic-array",
]
[[package]]
name = "aes"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0"
dependencies = [
"cfg-if",
"cipher",
"cpufeatures",
]
[[package]]
name = "aes-gcm"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1"
dependencies = [
"aead",
"aes",
"cipher",
"ctr",
"ghash",
"subtle",
]
[[package]]
name = "ahash"
version = "0.7.8"
@@ -541,9 +576,33 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a"
dependencies = [
"generic-array",
"rand_core 0.6.4",
"typenum",
]
[[package]]
name = "ctr"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835"
dependencies = [
"cipher",
]
[[package]]
name = "dashmap"
version = "6.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf"
dependencies = [
"cfg-if",
"crossbeam-utils",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core",
]
[[package]]
name = "data-encoding"
version = "2.10.0"
@@ -895,6 +954,37 @@ dependencies = [
"wasip3",
]
[[package]]
name = "ghash"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1"
dependencies = [
"opaque-debug",
"polyval",
]
[[package]]
name = "governor"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0746aa765db78b521451ef74221663b57ba595bf83f75d0ce23cc09447c8139f"
dependencies = [
"cfg-if",
"dashmap",
"futures-sink",
"futures-timer",
"futures-util",
"no-std-compat",
"nonzero_ext",
"parking_lot",
"portable-atomic",
"quanta",
"rand 0.8.5",
"smallvec",
"spinning_top",
]
[[package]]
name = "h2"
version = "0.4.13"
@@ -923,6 +1013,12 @@ dependencies = [
"ahash",
]
[[package]]
name = "hashbrown"
version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]]
name = "hashbrown"
version = "0.15.5"
@@ -1431,6 +1527,7 @@ checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77"
name = "llm-proxy"
version = "0.1.0"
dependencies = [
"aes-gcm",
"anyhow",
"assert_cmd",
"async-stream",
@@ -1443,8 +1540,10 @@ dependencies = [
"config",
"dotenvy",
"futures",
"governor",
"headers",
"hex",
"hmac",
"image",
"insta",
"mime",
@@ -1454,6 +1553,7 @@ dependencies = [
"reqwest-eventsource",
"serde",
"serde_json",
"sha2",
"sqlx",
"tempfile",
"thiserror 1.0.69",
@@ -1598,6 +1698,12 @@ dependencies = [
"pxfm",
]
[[package]]
name = "no-std-compat"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
[[package]]
name = "nom"
version = "7.1.3"
@@ -1608,6 +1714,12 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "nonzero_ext"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
[[package]]
name = "nu-ansi-term"
version = "0.50.3"
@@ -1669,6 +1781,12 @@ version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "opaque-debug"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
[[package]]
name = "ordered-multimap"
version = "0.4.3"
@@ -1824,6 +1942,24 @@ dependencies = [
"miniz_oxide",
]
[[package]]
name = "polyval"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25"
dependencies = [
"cfg-if",
"cpufeatures",
"opaque-debug",
"universal-hash",
]
[[package]]
name = "portable-atomic"
version = "1.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
[[package]]
name = "potential_utf"
version = "0.1.4"
@@ -1897,6 +2033,21 @@ dependencies = [
"num-traits",
]
[[package]]
name = "quanta"
version = "0.12.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7"
dependencies = [
"crossbeam-utils",
"libc",
"once_cell",
"raw-cpuid",
"wasi",
"web-sys",
"winapi",
]
[[package]]
name = "quick-error"
version = "2.0.1"
@@ -2032,6 +2183,15 @@ dependencies = [
"getrandom 0.3.4",
]
[[package]]
name = "raw-cpuid"
version = "11.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186"
dependencies = [
"bitflags 2.11.0",
]
[[package]]
name = "redox_syscall"
version = "0.5.18"
@@ -2453,6 +2613,15 @@ dependencies = [
"lock_api",
]
[[package]]
name = "spinning_top"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
dependencies = [
"lock_api",
]
[[package]]
name = "spki"
version = "0.7.3"
@@ -3158,6 +3327,16 @@ version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "universal-hash"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea"
dependencies = [
"crypto-common",
"subtle",
]
[[package]]
name = "untrusted"
version = "0.9.0"
@@ -3411,6 +3590,28 @@ dependencies = [
"wasite",
]
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-core"
version = "0.62.2"

Cargo.toml

@@ -13,7 +13,8 @@ repository = ""
axum = { version = "0.8", features = ["macros", "ws"] }
tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "net", "time", "signal", "fs"] }
tower = "0.5"
tower-http = { version = "0.6", features = ["trace", "cors", "compression-gzip", "fs"] }
tower-http = { version = "0.6", features = ["trace", "cors", "compression-gzip", "fs", "set-header", "limit"] }
governor = "0.7"
# ========== HTTP Clients ==========
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
@@ -46,6 +47,9 @@ mime = "0.3"
anyhow = "1.0"
thiserror = "1.0"
bcrypt = "0.15"
aes-gcm = "0.10"
hmac = "0.12"
sha2 = "0.10"
chrono = { version = "0.4", features = ["serde"] }
uuid = { version = "1.0", features = ["v4", "serde"] }
futures = "0.3"

DATABASE_REVIEW.md Normal file

@@ -0,0 +1,480 @@
# Database Review Report for LLM-Proxy Repository
**Review Date:** 2026-03-06
**Reviewer:** Database Optimization Expert
**Repository:** llm-proxy
**Focus Areas:** Schema Design, Query Optimization, Migration Strategy, Data Integrity, Usage Tracking Accuracy
## Executive Summary
The llm-proxy database implementation demonstrates a solid foundation, with appropriate table structures and a clear separation of concerns. However, several areas require improvement to ensure scalability, data consistency, and performance as usage grows. Key findings include:
1. **Schema Design**: Generally normalized but missing foreign key enforcement and some critical indexes.
2. **Query Optimization**: Well-optimized for most queries but missing composite indexes for common filtering patterns.
3. **Migration Strategy**: Ad-hoc migration approach that may cause issues with schema evolution.
4. **Data Integrity**: Potential race conditions in usage tracking and missing transaction boundaries.
5. **Usage Tracking**: Generally accurate but risk of inconsistent state between related tables.
This report provides detailed analysis and actionable recommendations for each area.
## 1. Schema Design Review
### Tables Overview
The database consists of 6 main tables:
1. **clients**: Client management with usage aggregates
2. **llm_requests**: Request logging with token counts and costs
3. **provider_configs**: Provider configuration and credit balances
4. **model_configs**: Model-specific configuration and cost overrides
5. **users**: Dashboard user authentication
6. **client_tokens**: API token storage for client authentication
### Normalization Assessment
**Strengths:**
- Tables follow 3rd Normal Form (3NF) with appropriate separation
- Foreign key relationships properly defined
- No obvious data duplication across tables
**Areas for Improvement:**
- **Denormalized aggregates**: `clients.total_requests`, `total_tokens`, `total_cost` are derived from `llm_requests`. This introduces risk of inconsistency.
- **Provider credit balance**: Stored in `provider_configs` but also updated based on `llm_requests`. No audit trail for balance changes.
### Data Type Analysis
**Appropriate Choices:**
- INTEGER for token counts (cast from u32 to i64)
- REAL for monetary values
- DATETIME for timestamps using SQLite's CURRENT_TIMESTAMP
- TEXT for identifiers with appropriate length
**Potential Issues:**
- `llm_requests.request_body` and `response_body` defined as TEXT but always set to NULL - consider removing or making optional columns.
- `provider_configs.billing_mode` added via migration but default value not consistently applied to existing rows.
### Constraints and Foreign Keys
**Current Constraints:**
- Primary keys defined for all tables
- UNIQUE constraints on `clients.client_id`, `users.username`, `client_tokens.token`
- Foreign key definitions present but **not enforced** (SQLite default)
**Missing Constraints:**
- NOT NULL constraints are missing on several columns where nullability is not intended
- CHECK constraints for positive values (`credit_balance >= 0`)
- Foreign key enforcement not enabled
## 2. Query Optimization Analysis
### Indexing Strategy
**Existing Indexes:**
- `idx_clients_client_id` - Essential for client lookups
- `idx_clients_created_at` - Useful for chronological listing
- `idx_llm_requests_timestamp` - Critical for time-based queries
- `idx_llm_requests_client_id` - Supports client-specific queries
- `idx_llm_requests_provider` - Good for provider breakdowns
- `idx_llm_requests_status` - Low cardinality but acceptable
- `idx_client_tokens_token` UNIQUE - Essential for authentication
- `idx_client_tokens_client_id` - Supports token management
**Missing Critical Indexes:**
1. `model_configs.provider_id` - Foreign key column used in JOINs
2. `llm_requests(client_id, timestamp)` - Composite index for client time-series queries
3. `llm_requests(provider, timestamp)` - For provider performance analysis
4. `llm_requests(status, timestamp)` - For error trend analysis
### N+1 Query Detection
**Well-Optimized Areas:**
- Model configuration caching prevents repeated database hits
- Provider configs loaded in batch for dashboard display
- Client listing uses single efficient query
**Potential N+1 Patterns:**
- In `server/mod.rs`, the `list_models` function performs a cache lookup per model, but the cache is in-memory, so no extra database queries are issued
- No significant database N+1 issues identified
### Inefficient Query Patterns
**Query 1: Time-series aggregation with strftime()**
```sql
SELECT strftime('%Y-%m-%d', timestamp) as date, ...
FROM llm_requests
WHERE 1=1 {}
GROUP BY date, client_id, provider, model
ORDER BY date DESC
LIMIT 200
```
**Issue:** Applying a function to the indexed `timestamp` column prevents index utilization when the WHERE clause filters by time range.
**Recommendation:** Store a computed date column, or use range queries on `timestamp` directly.
**Query 2: Today's stats using strftime()**
```sql
WHERE strftime('%Y-%m-%d', timestamp) = ?
```
**Issue:** Non-sargable query prevents index usage.
**Recommendation:** Use range query:
```sql
WHERE timestamp >= date(?) AND timestamp < date(?, '+1 day')
```
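As the rewrite might be bound from Rust, the day string can be passed once and the half-open range computed server-side. A sketch (query text is illustrative; the `total_tokens` column name is assumed):

```rust
// Sargable rewrite: bind the day string and let SQLite derive the
// half-open [day, day+1) range, so the timestamp index can be used.
const TODAY_STATS_SQL: &str =
    "SELECT COUNT(*), COALESCE(SUM(total_tokens), 0) \
     FROM llm_requests \
     WHERE timestamp >= date(?1) AND timestamp < date(?1, '+1 day')";
```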
### Recommended Index Additions
```sql
-- Composite indexes for common query patterns
CREATE INDEX idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp);
CREATE INDEX idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp);
CREATE INDEX idx_llm_requests_status_timestamp ON llm_requests(status, timestamp);
-- Foreign key index
CREATE INDEX idx_model_configs_provider_id ON model_configs(provider_id);
-- Optional: Covering index for client usage queries
CREATE INDEX idx_clients_usage ON clients(client_id, total_requests, total_tokens, total_cost);
```
## 3. Migration Strategy Assessment
### Current Approach
The migration system uses a hybrid approach:
1. **Schema synchronization**: `CREATE TABLE IF NOT EXISTS` on startup
2. **Ad-hoc migrations**: `ALTER TABLE` statements with error suppression
3. **Single migration file**: `migrations/001-add-billing-mode.sql` with transaction wrapper
**Pros:**
- Simple to understand and maintain
- Automatic schema creation for new deployments
- Error suppression prevents crashes when columns already exist
**Cons:**
- No version tracking of applied migrations
- Potential for inconsistent schema across deployments
- `ALTER TABLE` error suppression hides genuine schema issues
- No rollback capability
### Risks and Limitations
1. **Schema Drift**: Different instances may have different schemas if migrations are applied out of order
2. **Data Loss Risk**: No backup/verification before schema changes
3. **Production Issues**: Error suppression could mask migration failures until runtime
### Recommendations
1. **Implement Proper Migration Tooling**: Use `sqlx migrate` or similar versioned migration system
2. **Add Migration Version Table**: Track applied migrations and checksum verification
3. **Separate Migration Scripts**: One file per migration with up/down directions
4. **Pre-deployment Validation**: Schema checks in CI/CD pipeline
5. **Backup Strategy**: Automatic backups before migration execution
## 4. Data Integrity Evaluation
### Foreign Key Enforcement
**Critical Issue:** Foreign key constraints are defined but **not enforced** in SQLite.
**Impact:** Orphaned records, inconsistent referential integrity.
**Solution:** Enable foreign key support in connection string:
```rust
use std::str::FromStr;
use sqlx::sqlite::SqliteConnectOptions;

let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?
    .create_if_missing(true)
    .pragma("foreign_keys", "ON");
```
### Transaction Usage
**Good Patterns:**
- Request logging uses transactions for insert + provider balance update
- Atomic UPDATE for client usage statistics
**Problematic Areas:**
1. **Split Transactions**: Client usage update and request logging are in separate transactions
- In `logging/mod.rs`: `insert_log` transaction includes provider balance update
- In `utils/streaming.rs`: Client usage updated separately after logging
- **Risk**: Partial updates if one transaction fails
2. **No Transaction for Client Creation**: Client and token creation not atomic
**Recommendations:**
- Wrap client usage update within the same transaction as request logging
- Use transaction for client + token creation
- Consider using savepoints for complex operations
### Race Conditions and Consistency
**Potential Race Conditions:**
1. **Provider credit balance**: Concurrent requests may cause lost updates
- Current: `UPDATE provider_configs SET credit_balance = credit_balance - ?`
- SQLite provides serializable isolation, but negative balances are not prevented
2. **Client usage aggregates**: Concurrent updates to `total_requests`, `total_tokens`, `total_cost`
- Similar UPDATE pattern, generally safe but consider idempotency
**Recommendations:**
- Add check constraint: `CHECK (credit_balance >= 0)`
- Implement idempotent request logging with unique request IDs
- Consider optimistic concurrency control for critical balances
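Even before a CHECK constraint lands, a guarded UPDATE can enforce the non-negative invariant: the WHERE clause re-checks funds, and callers inspect `rows_affected`. A sketch (SQL text illustrative, with a pure-Rust model of the same rule for clarity):

```rust
// Conditional debit: the row is simply not updated when funds are
// insufficient, so the balance can never be driven negative.
const GUARDED_DEBIT_SQL: &str =
    "UPDATE provider_configs \
     SET credit_balance = credit_balance - ?1 \
     WHERE id = ?2 AND credit_balance >= ?1";

// Pure-Rust model of the guard, for illustration only.
fn try_debit(balance: f64, cost: f64) -> Option<f64> {
    if balance >= cost { Some(balance - cost) } else { None }
}
```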
## 5. Usage Tracking Accuracy
### Token Counting Methodology
**Current Approach:**
- Prompt tokens: Estimated using provider-specific estimators
- Completion tokens: Estimated or from provider real usage data
- Cache tokens: Separately tracked for cache-aware pricing
**Strengths:**
- Fallback to estimation when provider doesn't report usage
- Cache token differentiation for accurate pricing
**Weaknesses:**
- Estimation may differ from actual provider counts
- No validation of provider-reported token counts
### Cost Calculation
**Well Implemented:**
- Model-specific cost overrides via `model_configs`
- Cache-aware pricing when supported by registry
- Provider fallback calculations
**Potential Issues:**
- Floating-point precision for monetary calculations
- No rounding strategy for fractional cents
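A conventional remedy is to carry monetary values as integer micro-dollars internally and round once at the boundary; a minimal sketch (units and helper names are illustrative, not from the codebase):

```rust
// Integer micro-dollars avoid cumulative f64 drift across many small adds.
fn to_micros(cost: f64) -> i64 {
    (cost * 1_000_000.0).round() as i64
}

fn from_micros(m: i64) -> f64 {
    m as f64 / 1_000_000.0
}
```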
### Update Consistency
**Inconsistency Risk:** Client aggregates updated separately from request logging.
**Example Flow:**
1. Request log inserted and provider balance updated (transaction)
2. Client usage updated (separate operation)
3. If step 2 fails, client stats undercount usage
**Solution:** Include client update in the same transaction:
```sql
-- In insert_log, execute within the same transaction:
UPDATE clients
SET total_requests = total_requests + 1,
    total_tokens   = total_tokens + ?,
    total_cost     = total_cost + ?
WHERE client_id = ?;
```
### Financial Accuracy
**Good Practices:**
- Token-level granularity for cost calculation
- Separation of prompt/completion/cache pricing
- Database persistence for audit trail
**Recommendations:**
1. **Audit Trail**: Add `balance_transactions` table for provider credit changes
2. **Rounding Policy**: Define rounding strategy (e.g., to 6 decimal places)
3. **Validation**: Periodic reconciliation of aggregates vs. detail records
## 6. Performance Recommendations
### Schema Improvements
1. **Partitioning Strategy**: For high-volume `llm_requests`, consider:
- Monthly partitioning by timestamp
- Archive old data to separate tables
2. **Data Retention Policy**: Implement automatic cleanup of old request logs
```sql
DELETE FROM llm_requests WHERE timestamp < date('now', '-90 days');
```
3. **Column Optimization**: Remove unused `request_body`, `response_body` columns or implement compression
### Query Optimizations
1. **Avoid Functions on Indexed Columns**: Rewrite date queries as range queries
2. **Batch Updates**: Consider batch updates for client usage instead of per-request
3. **Read Replicas**: For dashboard queries, consider separate read connection
### Connection Pooling
**Current:** SQLx connection pool with default settings
**Recommendations:**
- Configure pool size based on expected concurrency
- Implement connection health checks
- Monitor pool utilization metrics
### Monitoring Setup
**Essential Metrics:**
- Query execution times (slow query logging)
- Index usage statistics
- Table growth trends
- Connection pool utilization
**Implementation:**
- Add `sqlx::metrics` integration
- Regular `ANALYZE` execution for query planner
- Dashboard for database health monitoring
## 7. Security Considerations
### Data Protection
**Sensitive Data:**
- `provider_configs.api_key` - Should be encrypted at rest
- `users.password_hash` - Already hashed with bcrypt
- `client_tokens.token` - Plain text storage
**Recommendations:**
- Encrypt API keys using libsodium or similar
- Implement token hashing (similar to password hashing)
- Regular security audits of authentication flows
### SQL Injection Prevention
**Good Practices:**
- Use sqlx query builder with parameter binding
- No raw SQL concatenation observed in code review
**Verification Needed:** Ensure all dynamic SQL uses parameterized queries
### Access Controls
**Database Level:**
- SQLite lacks built-in user management
- Consider file system permissions for database file
- Application-level authentication is primary control
## 8. Summary of Critical Issues
**Priority 1 (Critical):**
1. Foreign key constraints not enabled
2. Split transactions risking data inconsistency
3. Missing composite indexes for common queries
**Priority 2 (High):**
1. No proper migration versioning system
2. Potential race conditions in balance updates
3. Non-sargable date queries impacting performance
**Priority 3 (Medium):**
1. Denormalized aggregates without consistency guarantees
2. No data retention policy for request logs
3. Missing check constraints for data validation
## 9. Recommended Action Plan
### Phase 1: Immediate Fixes (1-2 weeks)
1. Enable foreign key constraints in database connection
2. Add composite indexes for common query patterns
3. Fix transaction boundaries for client usage updates
4. Rewrite non-sargable date queries
### Phase 2: Short-term Improvements (3-4 weeks)
1. Implement proper migration system with version tracking
2. Add check constraints for data validation
3. Implement connection pooling configuration
4. Create database monitoring dashboard
### Phase 3: Long-term Enhancements (2-3 months)
1. Implement data retention and archiving strategy
2. Add audit trail for provider balance changes
3. Consider partitioning for high-volume tables
4. Implement encryption for sensitive data
### Phase 4: Ongoing Maintenance
1. Regular index maintenance and query plan analysis
2. Periodic reconciliation of aggregate vs. detail data
3. Security audits and dependency updates
4. Performance benchmarking and optimization
---
## Appendices
### A. Sample Migration Implementation
```sql
-- NOTE: PRAGMA foreign_keys is per-connection in SQLite and cannot be enabled
-- once via a migration file; set it in the connection options instead (see Section 4).
-- migrations/003-add-composite-indexes.sql
CREATE INDEX idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp);
CREATE INDEX idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp);
CREATE INDEX idx_model_configs_provider_id ON model_configs(provider_id);
```
### B. Transaction Fix Example
```rust
async fn insert_log(pool: &SqlitePool, log: RequestLog) -> Result<(), sqlx::Error> {
    let mut tx = pool.begin().await?;

    // Insert or ignore client
    sqlx::query("INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')")
        .bind(&log.client_id)
        .bind(&log.client_id)
        .execute(&mut *tx)
        .await?;

    // Insert request log
    sqlx::query("INSERT INTO llm_requests ...")
        .execute(&mut *tx)
        .await?;

    // Update provider balance
    if log.cost > 0.0 {
        sqlx::query("UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')")
            .bind(log.cost)
            .bind(&log.provider)
            .execute(&mut *tx)
            .await?;
    }

    // Update client aggregates within the same transaction
    sqlx::query("UPDATE clients SET total_requests = total_requests + 1, total_tokens = total_tokens + ?, total_cost = total_cost + ? WHERE client_id = ?")
        .bind(log.total_tokens as i64)
        .bind(log.cost)
        .bind(&log.client_id)
        .execute(&mut *tx)
        .await?;

    tx.commit().await?;
    Ok(())
}
```
### C. Monitoring Query Examples
```sql
-- Indexes with no optimizer statistics (run ANALYZE first; note that
-- sqlite_stat1 exposes (tbl, idx, stat) columns, not `name`)
SELECT name FROM sqlite_master
WHERE type = 'index'
  AND tbl_name = 'llm_requests'
  AND name NOT IN (
      SELECT idx FROM sqlite_stat1
      WHERE tbl = 'llm_requests' AND idx IS NOT NULL
  );
-- Table size analysis (requires SQLITE_ENABLE_DBSTAT_VTAB)
SELECT name, SUM(pgsize) / 1024 / 1024 AS size_mb
FROM dbstat
WHERE name = 'llm_requests'
GROUP BY name;
-- Query performance analysis (requires EXPLAIN QUERY PLAN)
EXPLAIN QUERY PLAN
SELECT * FROM llm_requests
WHERE client_id = ? AND timestamp >= ?;
```
---
*This report provides a comprehensive analysis of the current database implementation and actionable recommendations for improvement. Regular review and iteration will ensure the database continues to meet performance, consistency, and scalability requirements as the application grows.*

PLAN.md Normal file

@@ -0,0 +1,62 @@
# Project Plan: LLM Proxy Enhancements & Security Upgrade
This document outlines the roadmap for standardizing frontend security, cleaning up the codebase, upgrading session management to HMAC-signed tokens, and extending integration testing.
## Phase 1: Frontend Security Standardization
**Primary Agent:** `frontend-developer`
- [x] Audit `static/js/pages/users.js` for manual HTML string concatenation.
- [x] Replace custom escaping or unescaped injections with `window.api.escapeHtml`.
- [x] Verify user list and user detail rendering for XSS vulnerabilities.
## Phase 2: Codebase Cleanup
**Primary Agent:** `backend-developer`
- [x] Identify and remove unused imports in `src/config/mod.rs`.
- [x] Identify and remove unused imports in `src/providers/mod.rs`.
- [x] Run `cargo clippy` and `cargo fmt` to ensure adherence to standards.
## Phase 3: HMAC Architectural Upgrade
**Primary Agents:** `fullstack-developer`, `security-auditor`, `backend-developer`
### 3.1 Design (Security Auditor)
- [x] Define Token Structure: `base64(payload).signature`.
- Payload: `{ "session_id": "...", "username": "...", "role": "...", "exp": ... }`
- [x] Select HMAC algorithm (HMAC-SHA256).
- [x] Define environment variable for secret key: `SESSION_SECRET`.
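The `base64(payload).signature` layout splits unambiguously on the first `.`, since standard base64 output never contains one. A small parsing sketch (function name hypothetical):

```rust
// Split "base64(payload).signature" into its two parts.
// Returns None when the token has no '.' separator at all.
fn split_token(tok: &str) -> Option<(&str, &str)> {
    let mut parts = tok.splitn(2, '.');
    Some((parts.next()?, parts.next()?))
}
```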
### 3.2 Implementation (Backend Developer)
- [x] Refactor `src/dashboard/sessions.rs`:
- Integrate `hmac` and `sha2` crates (or similar).
- Update `create_session` to return signed tokens.
- Update `validate_session` to verify signature before checking store.
- [x] Implement activity-based session refresh:
- If session is valid and >50% through its TTL, extend `expires_at` and issue new signed token.
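The refresh rule above can be stated as a pure predicate; a sketch with illustrative names and second-granularity timestamps:

```rust
// Refresh when the session is still valid and more than half of its TTL
// has elapsed, i.e. less than ttl/2 remains before expiry.
fn should_refresh(now: u64, expires_at: u64, ttl: u64) -> bool {
    now < expires_at && (expires_at - now) < ttl / 2
}
```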
### 3.3 Integration (Fullstack Developer)
- [x] Update dashboard API handlers to handle new token format.
- [x] Update frontend session storage/retrieval if necessary.
## Phase 4: Extended Integration Testing
**Primary Agent:** `qa-automation`
- [ ] Setup test environment with encrypted key storage enabled.
- [ ] Implement end-to-end flow:
1. Store encrypted provider key via API.
2. Authenticate through Proxy.
3. Make proxied LLM request (verifying decryption and usage).
- [ ] Validate HMAC token expiration and refresh logic in automated tests.
## Phase 5: Code Quality & Refactoring
**Primary Agent:** `fullstack-developer`
- [x] Refactor dashboard monolith into modular sub-modules (`auth.rs`, `usage.rs`, etc.).
- [x] Standardize error handling and remove `unwrap()` in production paths.
- [x] Implement system health metrics and backup functionality.
---
## Technical Standards
- **Rust:** No `unwrap()` in production code; use proper error handling (`Result`).
- **Frontend:** Always use `window.api` wrappers for sensitive operations.
- **Security:** Secrets must never be logged or hardcoded.
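The first rule in practice — a generic sketch (not project code) showing `Result` propagation in place of `unwrap()`:

```rust
use std::num::ParseIntError;

/// Return the error to the caller instead of panicking on bad input.
fn parse_port(raw: &str) -> Result<u16, ParseIntError> {
    raw.trim().parse::<u16>() // `?`-compatible: the error flows upward
}

fn main() {
    assert_eq!(parse_port("8080"), Ok(8080));
    assert!(parse_port("not-a-port").is_err());
    println!("ok");
}
```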

README.md

@@ -26,6 +26,17 @@ A unified, high-performance LLM proxy gateway built in Rust. It provides a singl
- **Rate Limiting:** Per-client and global rate limits.
- **Cache-Aware Costing:** Tracks cache hit/miss tokens for accurate billing.
## Security
LLM Proxy is designed with security in mind:
- **HMAC Session Tokens:** Management dashboard sessions are secured using HMAC-SHA256 signed tokens.
- **Encrypted Provider Keys:** Sensitive LLM provider API keys are stored encrypted (AES-256-GCM) in the database.
- **Session Refresh:** Activity-based session extension prevents session hijacking while maintaining user convenience.
- **XSS Prevention:** Standardized frontend escaping using `window.api.escapeHtml`.
**Note:** You must define a `SESSION_SECRET` in your `.env` file for secure session signing.
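A suitable secret is 32 random bytes, hex encoded. `openssl` is one common way to generate it; any CSPRNG producing 32 bytes works:

```bash
# Print a 32-byte (256-bit) secret as 64 hex characters.
openssl rand -hex 32
```

Copy the printed value into `.env` as `SESSION_SECRET=...`.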
## Tech Stack
- **Runtime:** Rust with Tokio.
@@ -39,6 +50,7 @@ A unified, high-performance LLM proxy gateway built in Rust. It provides a singl
- Rust (1.80+)
- SQLite3
- Docker (optional, for containerized deployment)
### Quick Start
@@ -53,10 +65,9 @@ A unified, high-performance LLM proxy gateway built in Rust. It provides a singl
```bash
cp .env.example .env
# Edit .env and add your API keys:
# SESSION_SECRET=... (Generate a strong random secret)
# OPENAI_API_KEY=sk-...
# GEMINI_API_KEY=AIza...
# DEEPSEEK_API_KEY=sk-...
# GROK_API_KEY=gk-... (optional)
```
3. Run the proxy:
@@ -66,50 +77,31 @@ A unified, high-performance LLM proxy gateway built in Rust. It provides a singl
The server starts on `http://localhost:8080` by default.
### Configuration
Edit `config.toml` to customize providers, models, and settings:
```toml
[server]
port = 8080
host = "0.0.0.0"

[database]
path = "./data/llm_proxy.db"

[providers.openai]
enabled = true
default_model = "gpt-4o"

[providers.gemini]
enabled = true
default_model = "gemini-2.0-flash"

[providers.deepseek]
enabled = true
default_model = "deepseek-reasoner"

[providers.grok]
enabled = false
default_model = "grok-beta"

[providers.ollama]
enabled = false
base_url = "http://localhost:11434/v1"
```
### Deployment (Docker)
A multi-stage `Dockerfile` is provided for efficient deployment:
```bash
# Build the container
docker build -t llm-proxy .

# Run the container
docker run -p 8080:8080 \
  -e SESSION_SECRET=your-secure-secret \
  -v ./data:/app/data \
  llm-proxy
```
## Management Dashboard
Access the dashboard at `http://localhost:8080`:
- **Overview:** Real-time request counters, system health, provider status.
- **Analytics:** Time-series charts, filterable by date, client, provider, and model.
- **Costs:** Budget tracking, cost breakdown by provider/client/model, projections.
- **Clients:** Create, revoke, and rotate API tokens; per-client usage stats.
- **Providers:** Enable/disable providers, test connections, configure API keys.
- **Monitoring:** Live request stream via WebSocket, response times, error rates.
- **Users:** Admin/user management with role-based access control.

The dashboard API has been refactored into modular sub-components for better maintainability:
- **Auth (`/api/auth`):** Login, session management, and password changes.
- **Usage (`/api/usage`):** Summary stats, time-series analytics, and provider breakdown.
- **Clients (`/api/clients`):** API key management and per-client usage tracking.
- **Providers (`/api/providers`):** Provider configuration, status monitoring, and connection testing.
- **System (`/api/system`):** Health metrics, live logs, database backups, and global settings.
### Default Credentials
@@ -137,46 +129,6 @@ response = client.chat.completions.create(
)
```
### Open WebUI
```
API Base URL: http://your-server:8080/v1
API Key: YOUR_CLIENT_API_KEY
```
### cURL
```bash
curl -X POST http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer YOUR_CLIENT_API_KEY" \
-d '{
"model": "gpt-4o",
"messages": [{"role": "user", "content": "Hello!"}],
"stream": false
}'
```
## Model Discovery
The proxy exposes `/v1/models` for OpenAI-compatible client model discovery:
```bash
curl http://localhost:8080/v1/models \
-H "Authorization: Bearer YOUR_CLIENT_API_KEY"
```
## Troubleshooting
### Streaming Issues
If clients timeout or show "TransferEncodingError", ensure:
1. Proxy buffering is disabled in nginx: `proxy_buffering off;`
2. Chunked transfer is enabled: `chunked_transfer_encoding on;`
3. Timeouts are sufficient: `proxy_read_timeout 7200s;`
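Collected into a single nginx `location` block — a sketch only; the upstream address and timeout values are assumptions to adjust for your deployment:

```nginx
location / {
    proxy_pass http://127.0.0.1:8080;   # assumes the proxy runs locally on 8080
    proxy_http_version 1.1;
    proxy_set_header Connection "";
    proxy_buffering off;                # stream SSE chunks as they arrive
    chunked_transfer_encoding on;
    proxy_read_timeout 7200s;           # allow long-running completions
}
```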
### Provider Errors
- Check API keys are set in `.env`
- Test provider in dashboard (Settings → Providers → Test)
- Review logs: `journalctl -u llm-proxy -f`
## License
MIT OR Apache-2.0

RUST_BACKEND_REVIEW.md (new file)

@@ -0,0 +1,58 @@
# LLM Proxy Rust Backend Code Review Report
## Executive Summary
This code review examines the `llm-proxy` Rust backend, focusing on architectural soundness, performance characteristics, and adherence to Rust idioms. The codebase demonstrates solid engineering with well-structured modular design, comprehensive error handling, and thoughtful async patterns. However, several areas require attention for production readiness, particularly around thread safety, memory efficiency, and error recovery.
## 1. Core Proxy Logic Review
### Strengths
- Clean provider abstraction (`Provider` trait).
- Streaming support with `AggregatingStream` for token counting.
- Model mapping and caching system.
### Issues Found
- **Provider Manager Thread Safety Risk:** O(n) lookups using `Vec` with `RwLock`. Use `DashMap` instead.
- **Streaming Memory Inefficiency:** Accumulates complete response content in memory.
- **Model Registry Cache Invalidation:** No strategy when config changes via dashboard.
## 2. State Management Review
- **Token Bucket Algorithm Flaw:** Custom implementation lacks thread-safe refill sync. Use `governor` crate.
- **Broadcast Channel Message Loss Risk:** The fixed capacity (100) silently drops messages for lagging subscribers.
- **Database Connection Pool Contention:** All workloads share a single SQLite pool with no separation between request-path writes and analytics reads.
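The token-bucket flaw above disappears when refill is computed lazily under the same lock as the acquire — a minimal std-only sketch of that idea (the review's actual recommendation remains the `governor` crate):

```rust
use std::sync::Mutex;
use std::time::Instant;

/// Minimal thread-safe token bucket: refill happens under the same lock
/// as the take, so no separate refill task can race it. Sketch only.
struct TokenBucket {
    inner: Mutex<BucketState>,
    capacity: f64,
    refill_per_sec: f64,
}

struct BucketState {
    tokens: f64,
    last_refill: Instant,
}

impl TokenBucket {
    fn new(capacity: f64, refill_per_sec: f64) -> Self {
        Self {
            inner: Mutex::new(BucketState {
                tokens: capacity,
                last_refill: Instant::now(),
            }),
            capacity,
            refill_per_sec,
        }
    }

    /// Try to take one token, refilling for elapsed time first.
    fn try_acquire(&self) -> bool {
        let mut state = self.inner.lock().expect("bucket lock poisoned");
        let now = Instant::now();
        let elapsed = now.duration_since(state.last_refill).as_secs_f64();
        state.tokens = (state.tokens + elapsed * self.refill_per_sec).min(self.capacity);
        state.last_refill = now;
        if state.tokens >= 1.0 {
            state.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}

fn main() {
    let bucket = TokenBucket::new(2.0, 0.0); // zero refill makes behavior deterministic
    assert!(bucket.try_acquire());
    assert!(bucket.try_acquire());
    assert!(!bucket.try_acquire()); // bucket drained
    println!("ok");
}
```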
## 3. Error Handling Review
- **Error Recovery Missing:** No circuit breaker or retry logic for provider calls.
- **Stream Error Logging Gap:** Stream errors swallowed without logging partial usage.
## 4. Rust Idioms and Performance
- **Unnecessary String Cloning:** Frequent cloning in authentication hot paths.
- **JSON Parsing Inefficiency:** Multiple passes with `serde_json::Value`. Use typed structs.
- **Missing `#[derive(Copy)]`:** For small enums like `CircuitState`.
## 5. Async Performance
- **Blocking Calls:** Token estimation may block the async runtime.
- **Missing Connection Timeouts:** Only overall timeout, no separate read/write timeouts.
- **Unbounded Task Spawn:** For client usage updates under load.
## 6. Security Considerations
- **Token Leakage:** Redact tokens in `Debug` and `Display` impls.
- **No Request Size Limits:** Vulnerable to memory exhaustion.
## 7. Testability
- **Mocking Difficulty:** Tight coupling to concrete provider implementations.
- **Missing Integration Tests:** No E2E tests for streaming.
## 8. Summary of Critical Actions
### High Priority
1. Replace custom token bucket with `governor` crate.
2. Fix provider lookup O(n) scaling issue.
3. Implement proper error recovery with retries.
4. Add request size limits and timeout configurations.
### Medium Priority
1. Reduce string cloning in hot paths.
2. Implement cache invalidation for model configs.
3. Add connection pooling separation.
4. Improve streaming memory efficiency.
---
*Review conducted by: Senior Principal Engineer (Code Reviewer)*

SECURITY_AUDIT.md (new file)

@@ -0,0 +1,58 @@
# LLM Proxy Security Audit Report
## Executive Summary
A comprehensive security audit of the `llm-proxy` repository was conducted. The audit identified **2 critical vulnerabilities**, **3 high-risk issues**, **4 medium-risk issues**, and **3 low-risk issues**. The most severe findings are Cross-Site Scripting (XSS) in the dashboard interface and insecure storage of provider API keys in the database.
## Detailed Findings
### Critical Risk Vulnerabilities
#### **CRITICAL-01: Cross-Site Scripting (XSS) in Dashboard Interface**
- **Location**: `static/js/pages/clients.js` (multiple locations).
- **Description**: User-controlled data (e.g., `client.id`) inserted directly into HTML or `onclick` handlers without escaping.
- **Impact**: Arbitrary JavaScript execution in admin context, potentially stealing session tokens.
#### **CRITICAL-02: Insecure API Key Storage in Database**
- **Location**: `src/database/mod.rs`, `src/providers/mod.rs`, `src/dashboard/providers.rs`.
- **Description**: Provider API keys are stored in **plaintext** in the SQLite database.
- **Impact**: Compromised database file exposes all provider API keys.
### High Risk Vulnerabilities
#### **HIGH-01: Missing Input Validation and Size Limits**
- **Location**: `src/server/mod.rs`, `src/models/mod.rs`.
- **Impact**: Denial of Service via large payloads.
#### **HIGH-02: Sensitive Data Logging Without Encryption**
- **Location**: `src/database/mod.rs`, `src/logging/mod.rs`.
- **Description**: Full request and response bodies stored in `llm_requests` table without encryption or redaction.
#### **HIGH-03: Weak Default Credentials and Password Policy**
- **Description**: The default admin password is 'admin', and the minimum password length is only 4 characters.
### Medium Risk Vulnerabilities
#### **MEDIUM-01: Missing CSRF Protection**
- No CSRF tokens or SameSite cookie attributes for state-changing dashboard endpoints.
#### **MEDIUM-02: Insecure Session Management**
- Session tokens stored in localStorage without HttpOnly flag.
- Tokens use simple `session-{uuid}` format.
#### **MEDIUM-03: Error Information Leakage**
- Internal error details exposed to clients in some cases.
#### **MEDIUM-04: Outdated Dependencies**
- Outdated versions of `chrono`, `tokio`, and `reqwest`.
### Low Risk Vulnerabilities
- Missing security headers (CSP, HSTS, X-Frame-Options).
- Insufficient rate limiting on dashboard authentication.
- No database encryption at rest.
## Recommendations
### Immediate Actions
1. **Fix XSS Vulnerabilities:** Implement proper HTML escaping for all user-controlled data.
2. **Secure API Key Storage:** Encrypt API keys in the database using a vetted cryptography library such as `ring`.
3. **Implement Input Validation:** Add maximum payload size limits (e.g., 10MB).
4. **Improve Data Protection:** Add option to disable request/response body logging.
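For the XSS fix, a minimal HTML escaper of the kind the remediation standardizes on looks as follows. This is an illustrative sketch, not the project's code — the actual helper is exposed as `window.api.escapeHtml`:

```javascript
// Escape the five characters that matter for element content and
// double-quoted attribute values before interpolating into HTML.
function escapeHtml(value) {
  return String(value).replace(/[&<>"']/g, (ch) => ({
    "&": "&amp;",
    "<": "&lt;",
    ">": "&gt;",
    '"': "&quot;",
    "'": "&#39;",
  }[ch]));
}

// Usage: always escape user-controlled data such as client IDs.
console.log(escapeHtml('<img src=x onerror="alert(1)">'));
```

Note that this covers HTML contexts only; data placed into `onclick` handlers or URLs needs context-specific encoding instead.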
---
*Report generated by Security Auditor Agent on March 6, 2026*


@@ -0,0 +1,13 @@
-- Migration: add composite indexes for query performance
-- Adds three composite indexes:
-- 1. idx_llm_requests_client_timestamp on llm_requests(client_id, timestamp)
-- 2. idx_llm_requests_provider_timestamp on llm_requests(provider, timestamp)
-- 3. idx_model_configs_provider_id on model_configs(provider_id)
BEGIN TRANSACTION;
CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp);
CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp);
CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id);
COMMIT;


@@ -1,5 +1,7 @@
use anyhow::Result;
use base64::{Engine as _};
use config::{Config, File, FileFormat};
use hex;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
@@ -96,6 +98,7 @@ pub struct AppConfig {
pub model_mapping: ModelMappingConfig,
pub pricing: PricingConfig,
pub config_path: Option<PathBuf>,
pub encryption_key: String,
}
impl AppConfig {
@@ -136,7 +139,8 @@ impl AppConfig {
.set_default("providers.grok.enabled", true)?
.set_default("providers.ollama.base_url", "http://localhost:11434/v1")?
.set_default("providers.ollama.enabled", false)?
.set_default("providers.ollama.models", Vec::<String>::new())?;
.set_default("providers.ollama.models", Vec::<String>::new())?
.set_default("encryption_key", "")?;
// Load from config file if exists
// Priority: explicit path arg > LLM_PROXY__CONFIG_PATH env var > ./config.toml
@@ -167,6 +171,19 @@ impl AppConfig {
let server: ServerConfig = config.get("server")?;
let database: DatabaseConfig = config.get("database")?;
let providers: ProviderConfig = config.get("providers")?;
let encryption_key: String = config.get("encryption_key")?;
// Validate encryption key length (must be 32 bytes after hex or base64 decoding)
if encryption_key.is_empty() {
anyhow::bail!("Encryption key is required (LLM_PROXY__ENCRYPTION_KEY environment variable)");
}
// Try hex decode first, then base64
let key_bytes = hex::decode(&encryption_key)
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(&encryption_key))
.map_err(|e| anyhow::anyhow!("Encryption key must be hex or base64 encoded: {}", e))?;
if key_bytes.len() != 32 {
anyhow::bail!("Encryption key must be 32 bytes (256 bits), got {} bytes", key_bytes.len());
}
// For now, use empty model mapping and pricing (will be populated later)
let model_mapping = ModelMappingConfig { patterns: vec![] };
@@ -185,6 +202,7 @@ impl AppConfig {
model_mapping,
pricing,
config_path: Some(config_path),
encryption_key,
}))
}


@@ -1,4 +1,4 @@
use axum::{extract::State, response::Json};
use axum::{extract::State, http::{HeaderMap, HeaderValue}, response::{Json, IntoResponse}};
use bcrypt;
use serde::Deserialize;
use sqlx::Row;
@@ -64,14 +64,14 @@ pub(super) async fn handle_login(
pub(super) async fn handle_auth_status(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
) -> impl IntoResponse {
let token = headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
if let Some(token) = token
&& let Some(session) = state.session_manager.validate_session(token).await
&& let Some((session, new_token)) = state.session_manager.validate_session_with_refresh(token).await
{
// Look up display_name from DB
let display_name = sqlx::query_scalar::<_, Option<String>>(
@@ -85,17 +85,23 @@ pub(super) async fn handle_auth_status(
.flatten()
.unwrap_or_else(|| session.username.clone());
return Json(ApiResponse::success(serde_json::json!({
let mut headers = HeaderMap::new();
if let Some(refreshed_token) = new_token {
if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) {
headers.insert("X-Refreshed-Token", header_value);
}
}
return (headers, Json(ApiResponse::success(serde_json::json!({
"authenticated": true,
"user": {
"username": session.username,
"name": display_name,
"role": session.role
}
})));
}))));
}
Json(ApiResponse::error("Not authenticated".to_string()))
(HeaderMap::new(), Json(ApiResponse::error("Not authenticated".to_string())))
}
#[derive(Deserialize)]
@@ -108,7 +114,7 @@ pub(super) async fn handle_change_password(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Json(payload): Json<ChangePasswordRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
) -> impl IntoResponse {
let pool = &state.app_state.db_pool;
// Extract the authenticated user from the session token
@@ -117,14 +123,24 @@ pub(super) async fn handle_change_password(
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
let session = match token {
Some(t) => state.session_manager.validate_session(t).await,
None => None,
let (session, new_token) = match token {
Some(t) => match state.session_manager.validate_session_with_refresh(t).await {
Some((session, new_token)) => (Some(session), new_token),
None => (None, None),
},
None => (None, None),
};
let mut response_headers = HeaderMap::new();
if let Some(refreshed_token) = new_token {
if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) {
response_headers.insert("X-Refreshed-Token", header_value);
}
}
let username = match session {
Some(s) => s.username,
None => return Json(ApiResponse::error("Not authenticated".to_string())),
None => return (response_headers, Json(ApiResponse::error("Not authenticated".to_string()))),
};
let user_result = sqlx::query("SELECT password_hash FROM users WHERE username = ?")
@@ -138,7 +154,7 @@ pub(super) async fn handle_change_password(
if bcrypt::verify(&payload.current_password, &hash).unwrap_or(false) {
let new_hash = match bcrypt::hash(&payload.new_password, 12) {
Ok(h) => h,
Err(_) => return Json(ApiResponse::error("Failed to hash new password".to_string())),
Err(_) => return (response_headers, Json(ApiResponse::error("Failed to hash new password".to_string()))),
};
let update_result = sqlx::query(
@@ -150,16 +166,16 @@ pub(super) async fn handle_change_password(
.await;
match update_result {
Ok(_) => Json(ApiResponse::success(
Ok(_) => (response_headers, Json(ApiResponse::success(
serde_json::json!({ "message": "Password updated successfully" }),
)),
Err(e) => Json(ApiResponse::error(format!("Failed to update database: {}", e))),
))),
Err(e) => (response_headers, Json(ApiResponse::error(format!("Failed to update database: {}", e)))),
}
} else {
Json(ApiResponse::error("Current password incorrect".to_string()))
(response_headers, Json(ApiResponse::error("Current password incorrect".to_string())))
}
}
Err(e) => Json(ApiResponse::error(format!("User not found: {}", e))),
Err(e) => (response_headers, Json(ApiResponse::error(format!("User not found: {}", e)))),
}
}
@@ -180,19 +196,19 @@ pub(super) async fn handle_logout(
}
/// Helper: Extract and validate a session from the Authorization header.
/// Returns the Session if valid, or an error response.
/// Returns the Session and optional new token if refreshed, or an error response.
pub(super) async fn extract_session(
state: &DashboardState,
headers: &axum::http::HeaderMap,
) -> Result<super::sessions::Session, Json<ApiResponse<serde_json::Value>>> {
) -> Result<(super::sessions::Session, Option<String>), Json<ApiResponse<serde_json::Value>>> {
let token = headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
match token {
Some(t) => match state.session_manager.validate_session(t).await {
Some(session) => Ok(session),
Some(t) => match state.session_manager.validate_session_with_refresh(t).await {
Some((session, new_token)) => Ok((session, new_token)),
None => Err(Json(ApiResponse::error("Session expired or invalid".to_string()))),
},
None => Err(Json(ApiResponse::error("Not authenticated".to_string()))),
@@ -200,13 +216,14 @@ pub(super) async fn extract_session(
}
/// Helper: Extract session and require admin role.
/// Returns session and optional new token if refreshed.
pub(super) async fn require_admin(
state: &DashboardState,
headers: &axum::http::HeaderMap,
) -> Result<super::sessions::Session, Json<ApiResponse<serde_json::Value>>> {
let session = extract_session(state, headers).await?;
) -> Result<(super::sessions::Session, Option<String>), Json<ApiResponse<serde_json::Value>>> {
let (session, new_token) = extract_session(state, headers).await?;
if session.role != "admin" {
return Err(Json(ApiResponse::error("Admin access required".to_string())));
}
Ok(session)
Ok((session, new_token))
}


@@ -88,9 +88,10 @@ pub(super) async fn handle_create_client(
headers: axum::http::HeaderMap,
Json(payload): Json<CreateClientRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
@@ -198,9 +199,10 @@ pub(super) async fn handle_update_client(
Path(id): Path<String>,
Json(payload): Json<UpdateClientPayload>,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
@@ -294,9 +296,10 @@ pub(super) async fn handle_delete_client(
headers: axum::http::HeaderMap,
Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
@@ -437,9 +440,10 @@ pub(super) async fn handle_create_client_token(
Path(id): Path<String>,
Json(payload): Json<CreateTokenRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
@@ -485,9 +489,10 @@ pub(super) async fn handle_delete_client_token(
headers: axum::http::HeaderMap,
Path((client_id, token_id)): Path<(String, i64)>,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;


@@ -11,10 +11,18 @@ mod users;
mod websocket;
use axum::{
extract::{Request, State},
middleware::Next,
response::Response,
Router,
routing::{delete, get, post, put},
};
use axum::http::{header, HeaderValue};
use serde::Serialize;
use tower_http::{
limit::RequestBodyLimitLayer,
set_header::SetResponseHeaderLayer,
};
use crate::state::AppState;
use sessions::SessionManager;
@@ -52,6 +60,21 @@ impl<T> ApiResponse<T> {
}
}
/// Rate limiting middleware for dashboard routes that extracts AppState from DashboardState.
async fn dashboard_rate_limit_middleware(
State(dashboard_state): State<DashboardState>,
request: Request,
next: Next,
) -> Result<Response, crate::errors::AppError> {
// Delegate to the existing rate limit middleware with AppState
crate::rate_limiting::middleware::rate_limit_middleware(
State(dashboard_state.app_state),
request,
next,
)
.await
}
// Dashboard routes
pub fn router(state: AppState) -> Router {
let session_manager = SessionManager::new(24); // 24-hour session TTL
@@ -60,6 +83,26 @@ pub fn router(state: AppState) -> Router {
session_manager,
};
// Security headers
let csp_header: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::CONTENT_SECURITY_POLICY,
"default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws:;"
.parse()
.unwrap(),
);
let x_frame_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_FRAME_OPTIONS,
"DENY".parse().unwrap(),
);
let x_content_type_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_CONTENT_TYPE_OPTIONS,
"nosniff".parse().unwrap(),
);
let strict_transport_security: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::STRICT_TRANSPORT_SECURITY,
"max-age=31536000; includeSubDomains".parse().unwrap(),
);
Router::new()
// Static file serving
.fallback_service(tower_http::services::ServeDir::new("static"))
@@ -119,5 +162,16 @@ pub fn router(state: AppState) -> Router {
"/api/system/settings",
get(system::handle_get_settings).post(system::handle_update_settings),
)
// Security layers
.layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit
.layer(csp_header)
.layer(x_frame_options)
.layer(x_content_type_options)
.layer(strict_transport_security)
// Rate limiting middleware
.layer(axum::middleware::from_fn_with_state(
dashboard_state.clone(),
dashboard_rate_limit_middleware,
))
.with_state(dashboard_state)
}


@@ -156,9 +156,10 @@ pub(super) async fn handle_update_model(
Path(id): Path<String>,
Json(payload): Json<UpdateModelRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;


@@ -9,6 +9,7 @@ use std::collections::HashMap;
use tracing::warn;
use super::{ApiResponse, DashboardState};
use crate::utils::crypto;
#[derive(Deserialize)]
pub(super) struct UpdateProviderRequest {
@@ -265,21 +266,44 @@ pub(super) async fn handle_update_provider(
Path(name): Path<String>,
Json(payload): Json<UpdateProviderRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Update or insert into database (include billing_mode)
// Prepare API key encryption if provided
let (api_key_to_store, api_key_encrypted_flag) = match &payload.api_key {
Some(key) if !key.is_empty() => {
match crypto::encrypt(key) {
Ok(encrypted) => (Some(encrypted), Some(true)),
Err(e) => {
warn!("Failed to encrypt API key for provider {}: {}", name, e);
return Json(ApiResponse::error(format!("Failed to encrypt API key: {}", e)));
}
}
}
Some(_) => {
// Empty string means clear the key
(None, Some(false))
}
None => {
// Keep existing key, we'll rely on COALESCE in SQL
(None, None)
}
};
// Update or insert into database (include billing_mode and api_key_encrypted)
let result = sqlx::query(
r#"
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold, billing_mode)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
enabled = excluded.enabled,
base_url = excluded.base_url,
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
api_key_encrypted = COALESCE(excluded.api_key_encrypted, provider_configs.api_key_encrypted),
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
@@ -290,7 +314,8 @@ pub(super) async fn handle_update_provider(
.bind(name.to_uppercase())
.bind(payload.enabled)
.bind(&payload.base_url)
.bind(&payload.api_key)
.bind(&api_key_to_store)
.bind(api_key_encrypted_flag)
.bind(payload.credit_balance)
.bind(payload.low_credit_threshold)
.bind(payload.billing_mode)


@@ -1,7 +1,17 @@
use chrono::{DateTime, Duration, Utc};
use hmac::{Hmac, Mac};
use serde::{Deserialize, Serialize};
use sha2::{Sha256, digest::generic_array::GenericArray};
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
use base64::{engine::general_purpose::URL_SAFE, Engine as _};
const TOKEN_VERSION: &str = "v2";
const REFRESH_WINDOW_MINUTES: i64 = 15; // refresh if token expires within 15 minutes
#[derive(Clone, Debug)]
pub struct Session {
@@ -9,51 +19,136 @@ pub struct Session {
pub role: String,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
pub session_id: String, // unique identifier for the session (UUID)
}
#[derive(Clone)]
pub struct SessionManager {
sessions: Arc<RwLock<HashMap<String, Session>>>,
sessions: Arc<RwLock<HashMap<String, Session>>>, // key = session_id
ttl_hours: i64,
secret: Vec<u8>,
}
#[derive(Debug, Serialize, Deserialize)]
struct SessionPayload {
session_id: String,
username: String,
role: String,
iat: i64, // issued at (Unix timestamp)
exp: i64, // expiry (Unix timestamp)
version: String,
}
impl SessionManager {
pub fn new(ttl_hours: i64) -> Self {
let secret = load_session_secret();
Self {
sessions: Arc::new(RwLock::new(HashMap::new())),
ttl_hours,
secret,
}
}
/// Create a new session and return the session token.
/// Create a new session and return a signed session token.
pub async fn create_session(&self, username: String, role: String) -> String {
let token = format!("session-{}", uuid::Uuid::new_v4());
let session_id = Uuid::new_v4().to_string();
let now = Utc::now();
let expires_at = now + Duration::hours(self.ttl_hours);
let session = Session {
username,
role,
username: username.clone(),
role: role.clone(),
created_at: now,
expires_at: now + Duration::hours(self.ttl_hours),
expires_at,
session_id: session_id.clone(),
};
self.sessions.write().await.insert(token.clone(), session);
token
// Store session by session_id
self.sessions.write().await.insert(session_id.clone(), session);
// Create signed token
self.create_signed_token(&session_id, &username, &role, now.timestamp(), expires_at.timestamp())
}
/// Validate a session token and return the session if valid and not expired.
/// If the token is within the refresh window, returns a new token as well.
pub async fn validate_session(&self, token: &str) -> Option<Session> {
self.validate_session_with_refresh(token).await.map(|(session, _)| session)
}
/// Validate a session token and return (session, optional new token if refreshed).
pub async fn validate_session_with_refresh(&self, token: &str) -> Option<(Session, Option<String>)> {
// Legacy token format (UUID)
if token.starts_with("session-") {
let sessions = self.sessions.read().await;
sessions.get(token).and_then(|s| {
return sessions.get(token).and_then(|s| {
if s.expires_at > Utc::now() {
Some(s.clone())
Some((s.clone(), None))
} else {
None
}
})
});
}
// Signed token format
let payload = match verify_signed_token(token, &self.secret) {
Ok(p) => p,
Err(_) => return None,
};
// Check expiry
let now = Utc::now().timestamp();
if payload.exp <= now {
return None;
}
// Look up session by session_id
let sessions = self.sessions.read().await;
let session = match sessions.get(&payload.session_id) {
Some(s) => s.clone(),
None => return None, // session revoked or not found
};
// Ensure session username/role matches (should always match)
if session.username != payload.username || session.role != payload.role {
return None;
}
// Check if token is within refresh window (last REFRESH_WINDOW_MINUTES of validity)
let refresh_threshold = payload.exp - REFRESH_WINDOW_MINUTES * 60;
let new_token = if now >= refresh_threshold {
// Activity-based refresh: extend the session expiry by a full TTL from now
// (sliding window) and issue a new token with updated iat/exp.
let new_exp = Utc::now() + Duration::hours(self.ttl_hours);
// Update session expiry in store
drop(sessions); // release read lock before acquiring write lock
self.update_session_expiry(&payload.session_id, new_exp).await;
// Create new token with updated iat/exp
let new_token = self.create_signed_token(
&payload.session_id,
&payload.username,
&payload.role,
now,
new_exp.timestamp(),
);
Some(new_token)
} else {
None
};
Some((session, new_token))
}
/// Revoke (delete) a session by token.
/// Supports both legacy tokens (token is key) and signed tokens (extract session_id).
pub async fn revoke_session(&self, token: &str) {
if token.starts_with("session-") {
self.sessions.write().await.remove(token);
return;
}
// For signed token, try to extract session_id
if let Ok(payload) = verify_signed_token(token, &self.secret) {
self.sessions.write().await.remove(&payload.session_id);
}
}
/// Remove all expired sessions from the store.
@@ -61,4 +156,156 @@ impl SessionManager {
let now = Utc::now();
self.sessions.write().await.retain(|_, s| s.expires_at > now);
}
// --- Private helpers ---
fn create_signed_token(&self, session_id: &str, username: &str, role: &str, iat: i64, exp: i64) -> String {
let payload = SessionPayload {
session_id: session_id.to_string(),
username: username.to_string(),
role: role.to_string(),
iat,
exp,
version: TOKEN_VERSION.to_string(),
};
sign_token(&payload, &self.secret)
}
async fn update_session_expiry(&self, session_id: &str, new_expires_at: DateTime<Utc>) {
let mut sessions = self.sessions.write().await;
if let Some(session) = sessions.get_mut(session_id) {
session.expires_at = new_expires_at;
}
}
}
/// Load session secret from environment variable SESSION_SECRET (hex or base64 encoded).
/// If not set, generates a random 32-byte secret and logs a warning.
fn load_session_secret() -> Vec<u8> {
let secret_str = env::var("SESSION_SECRET").unwrap_or_else(|_| {
// Also check LLM_PROXY__SESSION_SECRET for consistency with config prefix
env::var("LLM_PROXY__SESSION_SECRET").unwrap_or_else(|_| {
// Generate a random secret (32 bytes) and encode as hex
use rand::RngCore;
let mut bytes = [0u8; 32];
rand::rng().fill_bytes(&mut bytes);
let hex_secret = hex::encode(bytes);
tracing::warn!(
"SESSION_SECRET environment variable not set. Using a randomly generated secret. \
This will invalidate all sessions on restart. Set SESSION_SECRET to a fixed hex or base64 encoded 32-byte value."
);
hex_secret
})
});
// Decode hex or base64
hex::decode(&secret_str)
.or_else(|_| URL_SAFE.decode(&secret_str))
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(&secret_str))
.unwrap_or_else(|_| {
panic!("SESSION_SECRET must be hex or base64 encoded (32 bytes)");
})
}
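Since a missing `SESSION_SECRET` invalidates every session on restart, it is worth generating a fixed secret once and exporting it. One way to produce a 32-byte hex value (assuming `openssl` is installed; any CSPRNG source works):

```shell
# Generate a 32-byte secret (64 hex characters) and export it for the proxy.
export SESSION_SECRET="$(openssl rand -hex 32)"
echo "${#SESSION_SECRET}"   # prints 64
```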
/// Sign a session payload and return a token string in format base64_url(payload).base64_url(signature).
fn sign_token(payload: &SessionPayload, secret: &[u8]) -> String {
let json = serde_json::to_vec(payload).expect("Failed to serialize payload");
let payload_b64 = URL_SAFE.encode(&json);
let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
mac.update(&json);
let signature = mac.finalize().into_bytes();
let signature_b64 = URL_SAFE.encode(signature);
format!("{}.{}", payload_b64, signature_b64)
}
/// Verify a signed token and return the decoded payload if valid.
fn verify_signed_token(token: &str, secret: &[u8]) -> Result<SessionPayload, TokenError> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 2 {
return Err(TokenError::InvalidFormat);
}
let payload_b64 = parts[0];
let signature_b64 = parts[1];
let json = URL_SAFE.decode(payload_b64).map_err(|_| TokenError::InvalidFormat)?;
let signature = URL_SAFE.decode(signature_b64).map_err(|_| TokenError::InvalidFormat)?;
// Verify HMAC
let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
mac.update(&json);
// Convert signature slice to GenericArray
let tag = GenericArray::from_slice(&signature);
mac.verify(tag).map_err(|_| TokenError::InvalidSignature)?;
// Deserialize payload
let payload: SessionPayload = serde_json::from_slice(&json).map_err(|_| TokenError::InvalidPayload)?;
Ok(payload)
}
#[derive(Debug)]
enum TokenError {
InvalidFormat,
InvalidSignature,
InvalidPayload,
}
#[cfg(test)]
mod tests {
use super::*;
use std::env;
#[test]
fn test_sign_and_verify_token() {
let secret = b"test-secret-must-be-32-bytes-long!";
let payload = SessionPayload {
session_id: "test-session".to_string(),
username: "testuser".to_string(),
role: "user".to_string(),
iat: 1000,
exp: 2000,
version: TOKEN_VERSION.to_string(),
};
let token = sign_token(&payload, secret);
let verified = verify_signed_token(&token, secret).unwrap();
assert_eq!(verified.session_id, payload.session_id);
assert_eq!(verified.username, payload.username);
assert_eq!(verified.role, payload.role);
assert_eq!(verified.iat, payload.iat);
assert_eq!(verified.exp, payload.exp);
assert_eq!(verified.version, payload.version);
}
#[test]
fn test_tampered_token() {
let secret = b"test-secret-must-be-32-bytes-long!";
let payload = SessionPayload {
session_id: "test-session".to_string(),
username: "testuser".to_string(),
role: "user".to_string(),
iat: 1000,
exp: 2000,
version: TOKEN_VERSION.to_string(),
};
let mut token = sign_token(&payload, secret);
// Tamper with payload part
let mut parts: Vec<&str> = token.split('.').collect();
let mut payload_bytes = URL_SAFE.decode(parts[0]).unwrap();
payload_bytes[0] ^= 0xFF; // flip some bits
let tampered_payload = URL_SAFE.encode(payload_bytes);
parts[0] = &tampered_payload;
token = parts.join(".");
assert!(verify_signed_token(&token, secret).is_err());
}
#[test]
fn test_load_session_secret_from_env() {
unsafe {
env::set_var("SESSION_SECRET", hex::encode([0xAA; 32]));
}
let secret = load_session_secret();
assert_eq!(secret, vec![0xAA; 32]);
unsafe {
env::remove_var("SESSION_SECRET");
}
}
}


@@ -279,9 +279,10 @@ pub(super) async fn handle_system_backup(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let backup_id = format!("backup-{}", chrono::Utc::now().timestamp());
@@ -341,9 +342,10 @@ pub(super) async fn handle_update_settings(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = super::auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
Json(ApiResponse::error(
"Changing settings at runtime is not yet supported. Please update your config file and restart the server."


@@ -14,9 +14,10 @@ pub(super) async fn handle_get_users(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
@@ -66,9 +67,10 @@ pub(super) async fn handle_create_user(
headers: axum::http::HeaderMap,
Json(payload): Json<CreateUserRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
@@ -147,9 +149,10 @@ pub(super) async fn handle_update_user(
Path(id): Path<i64>,
Json(payload): Json<UpdateUserRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
if let Err(e) = auth::require_admin(&state, &headers).await {
return e;
}
let (session, _) = match auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
@@ -249,8 +252,8 @@ pub(super) async fn handle_delete_user(
headers: axum::http::HeaderMap,
Path(id): Path<i64>,
) -> Json<ApiResponse<serde_json::Value>> {
let session = match auth::require_admin(&state, &headers).await {
Ok(s) => s,
let (session, _) = match auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};


@@ -18,7 +18,9 @@ pub async fn init(config: &DatabaseConfig) -> Result<DbPool> {
let database_path = config.path.to_string_lossy().to_string();
info!("Connecting to database at {}", database_path);
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?.create_if_missing(true);
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?
.create_if_missing(true)
.pragma("foreign_keys", "ON");
let pool = SqlitePool::connect_with(options).await?;
@@ -29,7 +31,7 @@ pub async fn init(config: &DatabaseConfig) -> Result<DbPool> {
Ok(pool)
}
async fn run_migrations(pool: &DbPool) -> Result<()> {
pub async fn run_migrations(pool: &DbPool) -> Result<()> {
// Create clients table if it doesn't exist
sqlx::query(
r#"
@@ -88,6 +90,8 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
api_key TEXT,
credit_balance REAL DEFAULT 0.0,
low_credit_threshold REAL DEFAULT 5.0,
billing_mode TEXT,
api_key_encrypted BOOLEAN DEFAULT FALSE,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
"#,
@@ -167,6 +171,15 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
.execute(pool)
.await;
// Add billing_mode column if it doesn't exist (migration for existing DBs)
let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN billing_mode TEXT")
.execute(pool)
.await;
// Add api_key_encrypted column if it doesn't exist (migration for existing DBs)
let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN api_key_encrypted BOOLEAN DEFAULT FALSE")
.execute(pool)
.await;
// Insert default admin user if none exists (default password: admin)
let user_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users").fetch_one(pool).await?;
@@ -216,6 +229,19 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
.execute(pool)
.await?;
// Composite indexes for performance
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id)")
.execute(pool)
.await?;
// Insert default client if none exists
sqlx::query(
r#"


@@ -41,23 +41,18 @@ pub use state::AppState;
pub mod test_utils {
use std::sync::Arc;
use crate::{client::ClientManager, providers::ProviderManager, rate_limiting::RateLimitManager, state::AppState};
use crate::{client::ClientManager, providers::ProviderManager, rate_limiting::RateLimitManager, state::AppState, utils::crypto, database::run_migrations};
use sqlx::sqlite::SqlitePool;
/// Create a test application state
pub async fn create_test_state() -> Arc<AppState> {
pub async fn create_test_state() -> AppState {
// Create in-memory database
let pool = SqlitePool::connect("sqlite::memory:")
.await
.expect("Failed to create test database");
// Run migrations
crate::database::init(&crate::config::DatabaseConfig {
path: std::path::PathBuf::from(":memory:"),
max_connections: 5,
})
.await
.expect("Failed to initialize test database");
// Run migrations on the pool
run_migrations(&pool).await.expect("Failed to run migrations");
let rate_limit_manager = RateLimitManager::new(
crate::rate_limiting::RateLimiterConfig::default(),
@@ -73,7 +68,7 @@ pub mod test_utils {
providers: std::collections::HashMap::new(),
};
let (dashboard_tx, _) = tokio::sync::broadcast::channel(100);
let (dashboard_tx, _) = tokio::sync::broadcast::channel::<serde_json::Value>(100);
let config = Arc::new(crate::config::AppConfig {
server: crate::config::ServerConfig {
@@ -125,20 +120,20 @@ pub mod test_utils {
ollama: vec![],
},
config_path: None,
encryption_key: "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f".to_string(),
});
Arc::new(AppState {
// Initialize encryption with the test key
crypto::init_with_key(&config.encryption_key).expect("failed to initialize crypto");
AppState::new(
config,
provider_manager,
db_pool: pool.clone(),
rate_limit_manager: Arc::new(rate_limit_manager),
client_manager,
request_logger: Arc::new(crate::logging::RequestLogger::new(pool.clone(), dashboard_tx.clone())),
model_registry: Arc::new(model_registry),
model_config_cache: crate::state::ModelConfigCache::new(pool.clone()),
dashboard_tx,
auth_tokens: vec![],
})
pool,
rate_limit_manager,
model_registry,
vec![], // auth_tokens
)
}
/// Create a test HTTP client
@@ -149,3 +144,185 @@ pub mod test_utils {
.expect("Failed to create test HTTP client")
}
}
#[cfg(test)]
mod integration_tests {
use super::test_utils::*;
use crate::{
models::{ChatCompletionRequest, ChatMessage},
server::router,
utils::crypto,
};
use axum::{
body::Body,
http::{Request, StatusCode},
};
use mockito::Server;
use serde_json::json;
use sqlx::Row;
use tower::util::ServiceExt;
#[tokio::test]
async fn test_encrypted_provider_key_integration() {
// Step 1: Setup test database and state
let state = create_test_state().await;
let pool = state.db_pool.clone();
// Step 2: Insert provider with encrypted API key
let test_api_key = "test-openai-key-12345";
let encrypted_key = crypto::encrypt(test_api_key).expect("Failed to encrypt test key");
sqlx::query(
r#"
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind("openai")
.bind("OpenAI")
.bind(true)
.bind("http://localhost:1234") // Mock server URL
.bind(&encrypted_key)
.bind(true) // api_key_encrypted flag
.bind(100.0)
.bind(5.0)
.execute(&pool)
.await
.expect("Failed to insert provider config");
// Step 3: Initialize provider (loads and decrypts the stored API key)
state
.provider_manager
.initialize_provider("openai", &state.config, &pool)
.await
.expect("Failed to initialize provider");
// Step 4: Mock OpenAI API server
let mut server = Server::new_async().await;
let mock = server
.mock("POST", "/chat/completions")
.match_header("authorization", format!("Bearer {}", test_api_key).as_str())
.with_status(200)
.with_header("content-type", "application/json")
.with_body(
json!({
"id": "chatcmpl-test",
"object": "chat.completion",
"created": 1234567890,
"model": "gpt-3.5-turbo",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello, world!"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15
}
})
.to_string(),
)
.create_async()
.await;
// Update provider base URL to use mock server
sqlx::query("UPDATE provider_configs SET base_url = ? WHERE id = 'openai'")
.bind(&server.url())
.execute(&pool)
.await
.expect("Failed to update provider URL");
// Re-initialize provider with new URL
state
.provider_manager
.initialize_provider("openai", &state.config, &pool)
.await
.expect("Failed to re-initialize provider");
// Step 5: Create test router and make request
let app = router(state);
let request_body = ChatCompletionRequest {
model: "gpt-3.5-turbo".to_string(),
messages: vec![ChatMessage {
role: "user".to_string(),
content: crate::models::MessageContent::Text {
content: "Hello".to_string(),
},
reasoning_content: None,
tool_calls: None,
name: None,
tool_call_id: None,
}],
temperature: None,
top_p: None,
top_k: None,
n: None,
stop: None,
max_tokens: Some(100),
presence_penalty: None,
frequency_penalty: None,
stream: Some(false),
tools: None,
tool_choice: None,
};
let request = Request::builder()
.method("POST")
.uri("/v1/chat/completions")
.header("content-type", "application/json")
.header("authorization", "Bearer test-token")
.body(Body::from(serde_json::to_string(&request_body).unwrap()))
.unwrap();
// Step 6: Execute request through proxy
let response = app
.oneshot(request)
.await
.expect("Failed to execute request");
let status = response.status();
println!("Response status: {}", status);
if status != StatusCode::OK {
let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
println!("Response body: {}", body_str);
panic!("Response status is not OK: {}", status);
}
assert_eq!(status, StatusCode::OK);
// Verify the mock was called
mock.assert_async().await;
// Give the async logging task time to complete
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Step 7: Verify usage was logged in database
let log_row = sqlx::query("SELECT * FROM llm_requests WHERE client_id = 'client_test-tok' ORDER BY id DESC LIMIT 1")
.fetch_one(&pool)
.await
.expect("Request log not found");
assert_eq!(log_row.get::<String, _>("provider"), "openai");
assert_eq!(log_row.get::<String, _>("model"), "gpt-3.5-turbo");
assert_eq!(log_row.get::<i64, _>("prompt_tokens"), 10);
assert_eq!(log_row.get::<i64, _>("completion_tokens"), 5);
assert_eq!(log_row.get::<i64, _>("total_tokens"), 15);
assert_eq!(log_row.get::<String, _>("status"), "success");
// Verify client usage was updated
let client_row = sqlx::query("SELECT total_requests, total_tokens, total_cost FROM clients WHERE client_id = 'client_test-tok'")
.fetch_one(&pool)
.await
.expect("Client not found");
assert_eq!(client_row.get::<i64, _>("total_requests"), 1);
assert_eq!(client_row.get::<i64, _>("total_tokens"), 15);
}
}


@@ -82,9 +82,9 @@ impl RequestLogger {
"#,
)
.bind(log.timestamp)
.bind(&log.client_id)
.bind(&log.provider)
.bind(&log.model)
.bind(log.prompt_tokens as i64)
.bind(log.completion_tokens as i64)
.bind(log.total_tokens as i64)
@@ -92,7 +92,7 @@ impl RequestLogger {
.bind(log.cache_write_tokens as i64)
.bind(log.cost)
.bind(log.has_images)
.bind(&log.status)
.bind(log.error_message)
.bind(log.duration_ms as i64)
.bind(None::<String>) // request_body - optional, not stored to save disk space
@@ -100,6 +100,23 @@ impl RequestLogger {
.execute(&mut *tx)
.await?;
// Update client usage statistics
sqlx::query(
r#"
UPDATE clients SET
total_requests = total_requests + 1,
total_tokens = total_tokens + ?,
total_cost = total_cost + ?,
updated_at = CURRENT_TIMESTAMP
WHERE client_id = ?
"#,
)
.bind(log.total_tokens as i64)
.bind(log.cost)
.bind(&log.client_id)
.execute(&mut *tx)
.await?;
// Deduct from provider balance if successful.
// Providers configured with billing_mode = 'postpaid' will not have their
// credit_balance decremented. Use a conditional UPDATE so we don't need
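The comment above is cut off at the hunk boundary. A hypothetical sketch of such a conditional deduction (the exact SQL in the codebase may differ) puts the billing-mode guard in the WHERE clause, so postpaid rows simply match nothing and no pre-read of `billing_mode` is needed:

```rust
// Hypothetical sketch of the conditional balance deduction described
// above; the real statement in the codebase may differ. Postpaid
// providers match zero rows and keep their balance untouched.
const DEDUCT_BALANCE_SQL: &str = "UPDATE provider_configs \
     SET credit_balance = credit_balance - ?, updated_at = CURRENT_TIMESTAMP \
     WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')";

fn main() {
    // The guard lives in the WHERE clause, not in application logic.
    assert!(DEDUCT_BALANCE_SQL.contains("billing_mode != 'postpaid'"));
    println!("ok");
}
```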


@@ -10,6 +10,7 @@ use llm_proxy::{
rate_limiting::{CircuitBreakerConfig, RateLimitManager, RateLimiterConfig},
server,
state::AppState,
utils::crypto,
};
#[tokio::main]
@@ -26,6 +27,10 @@ async fn main() -> Result<()> {
let config = AppConfig::load().await?;
info!("Configuration loaded from {:?}", config.config_path);
// Initialize encryption
crypto::init_with_key(&config.encryption_key)?;
info!("Encryption initialized");
// Initialize database connection pool
let db_pool = database::init(&config.database).await?;
info!("Database initialized at {:?}", config.database.path);


@@ -7,6 +7,7 @@ use std::sync::Arc;
use crate::errors::AppError;
use crate::models::UnifiedRequest;
pub mod deepseek;
pub mod gemini;
pub mod grok;
@@ -125,17 +126,35 @@ impl ProviderManager {
db_pool: &crate::database::DbPool,
) -> Result<()> {
// Load override from database
let db_config = sqlx::query("SELECT enabled, base_url, api_key FROM provider_configs WHERE id = ?")
let db_config = sqlx::query("SELECT enabled, base_url, api_key, api_key_encrypted FROM provider_configs WHERE id = ?")
.bind(name)
.fetch_optional(db_pool)
.await?;
let (enabled, base_url, api_key) = if let Some(row) = db_config {
(
row.get::<bool, _>("enabled"),
row.get::<Option<String>, _>("base_url"),
row.get::<Option<String>, _>("api_key"),
)
let enabled = row.get::<bool, _>("enabled");
let base_url = row.get::<Option<String>, _>("base_url");
let api_key_encrypted = row.get::<bool, _>("api_key_encrypted");
let api_key = row.get::<Option<String>, _>("api_key");
// Decrypt API key if encrypted
let api_key = match (api_key, api_key_encrypted) {
(Some(key), true) => {
match crate::utils::crypto::decrypt(&key) {
Ok(decrypted) => Some(decrypted),
Err(e) => {
tracing::error!("Failed to decrypt API key for provider {}: {}", name, e);
None
}
}
}
(Some(key), false) => {
// Plaintext key: a lazy migration could encrypt it here and write it
// back to the database; for now, use it as-is.
Some(key)
}
(None, _) => None,
};
(enabled, base_url, api_key)
} else {
// No database override, use defaults from AppConfig
match name {


@@ -6,12 +6,15 @@
//! 3. Global rate limiting for overall system protection
use anyhow::Result;
use governor::{Quota, RateLimiter, DefaultDirectRateLimiter};
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tracing::{info, warn};
type GovRateLimiter = DefaultDirectRateLimiter;
/// Rate limiter configuration
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
@@ -65,45 +68,7 @@ impl Default for CircuitBreakerConfig {
}
}
/// Simple token bucket rate limiter for a single client
#[derive(Debug)]
struct TokenBucket {
tokens: f64,
capacity: f64,
refill_rate: f64, // tokens per second
last_refill: Instant,
}
impl TokenBucket {
fn new(capacity: f64, refill_rate: f64) -> Self {
Self {
tokens: capacity,
capacity,
refill_rate,
last_refill: Instant::now(),
}
}
fn refill(&mut self) {
let now = Instant::now();
let elapsed = now.duration_since(self.last_refill).as_secs_f64();
let new_tokens = elapsed * self.refill_rate;
self.tokens = (self.tokens + new_tokens).min(self.capacity);
self.last_refill = now;
}
fn try_acquire(&mut self, tokens: f64) -> bool {
self.refill();
if self.tokens >= tokens {
self.tokens -= tokens;
true
} else {
false
}
}
}
/// Circuit breaker for a provider
#[derive(Debug)]
@@ -209,8 +174,8 @@ impl ProviderCircuitBreaker {
/// Rate limiting and circuit breaking manager
#[derive(Debug)]
pub struct RateLimitManager {
client_buckets: Arc<RwLock<HashMap<String, TokenBucket>>>,
global_bucket: Arc<RwLock<TokenBucket>>,
client_buckets: Arc<RwLock<HashMap<String, GovRateLimiter>>>,
global_bucket: Arc<GovRateLimiter>,
circuit_breakers: Arc<RwLock<HashMap<String, ProviderCircuitBreaker>>>,
config: RateLimiterConfig,
circuit_config: CircuitBreakerConfig,
@@ -218,15 +183,16 @@ pub struct RateLimitManager {
impl RateLimitManager {
pub fn new(config: RateLimiterConfig, circuit_config: CircuitBreakerConfig) -> Self {
// Convert requests per minute to tokens per second
let global_refill_rate = config.global_requests_per_minute as f64 / 60.0;
// Create global rate limiter quota
let global_quota = Quota::per_minute(
NonZeroU32::new(config.global_requests_per_minute).expect("global_requests_per_minute must be positive")
)
.allow_burst(NonZeroU32::new(config.burst_size).expect("burst_size must be positive"));
let global_bucket = RateLimiter::direct(global_quota);
Self {
client_buckets: Arc::new(RwLock::new(HashMap::new())),
global_bucket: Arc::new(RwLock::new(TokenBucket::new(
config.burst_size as f64,
global_refill_rate,
))),
global_bucket: Arc::new(global_bucket),
circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
config,
circuit_config,
@@ -236,24 +202,22 @@ impl RateLimitManager {
/// Check if a client request is allowed
pub async fn check_client_request(&self, client_id: &str) -> Result<bool> {
// Check global rate limit first (1 token per request)
{
let mut global_bucket = self.global_bucket.write().await;
if !global_bucket.try_acquire(1.0) {
if self.global_bucket.check().is_err() {
warn!("Global rate limit exceeded");
return Ok(false);
}
}
// Check client-specific rate limit
let mut buckets = self.client_buckets.write().await;
let bucket = buckets.entry(client_id.to_string()).or_insert_with(|| {
TokenBucket::new(
self.config.burst_size as f64,
self.config.requests_per_minute as f64 / 60.0,
let quota = Quota::per_minute(
NonZeroU32::new(self.config.requests_per_minute).expect("requests_per_minute must be positive")
)
.allow_burst(NonZeroU32::new(self.config.burst_size).expect("burst_size must be positive"));
RateLimiter::direct(quota)
});
Ok(bucket.try_acquire(1.0))
Ok(bucket.check().is_ok())
}
/// Check if provider requests are allowed (circuit breaker)


@@ -5,6 +5,11 @@ use axum::{
response::sse::{Event, Sse},
routing::{get, post},
};
use axum::http::{header, HeaderValue};
use tower_http::{
limit::RequestBodyLimitLayer,
set_header::SetResponseHeaderLayer,
};
use futures::StreamExt;
use std::sync::Arc;
@@ -23,9 +28,34 @@ use crate::{
};
pub fn router(state: AppState) -> Router {
// Security headers
let csp_header: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::CONTENT_SECURITY_POLICY,
"default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws:;"
.parse()
.unwrap(),
);
let x_frame_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_FRAME_OPTIONS,
"DENY".parse().unwrap(),
);
let x_content_type_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_CONTENT_TYPE_OPTIONS,
"nosniff".parse().unwrap(),
);
let strict_transport_security: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::STRICT_TRANSPORT_SECURITY,
"max-age=31536000; includeSubDomains".parse().unwrap(),
);
Router::new()
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/models", get(list_models))
.layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit
.layer(csp_header)
.layer(x_frame_options)
.layer(x_content_type_options)
.layer(strict_transport_security)
.layer(axum::middleware::from_fn_with_state(
state.clone(),
rate_limiting::middleware::rate_limit_middleware,
@@ -219,7 +249,6 @@ async fn chat_completions(
prompt_tokens,
has_images,
logger: state.request_logger.clone(),
client_manager: state.client_manager.clone(),
model_registry: state.model_registry.clone(),
model_config_cache: state.model_config_cache.clone(),
},
@@ -341,15 +370,6 @@ async fn chat_completions(
duration_ms: duration.as_millis() as u64,
});
// Update client usage (fire-and-forget, don't block response)
{
let cm = state.client_manager.clone();
let cid = client_id.clone();
tokio::spawn(async move {
let _ = cm.update_client_usage(&cid, response.total_tokens as i64, cost).await;
});
}
// Convert ProviderResponse to ChatCompletionResponse
let finish_reason = if response.tool_calls.is_some() {
"tool_calls".to_string()

src/utils/crypto.rs Normal file

@@ -0,0 +1,171 @@
use aes_gcm::{
aead::{Aead, AeadCore, KeyInit, OsRng},
Aes256Gcm, Key, Nonce,
};
use anyhow::{anyhow, Context, Result};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use std::env;
use std::sync::OnceLock;
static RAW_KEY: OnceLock<[u8; 32]> = OnceLock::new();
/// Initialize the encryption key from a hex or base64 encoded string.
/// Must be called before any encryption/decryption operations.
/// Returns error if the key is invalid or already initialized with a different key.
pub fn init_with_key(key_str: &str) -> Result<()> {
let key_bytes = hex::decode(key_str)
.or_else(|_| BASE64.decode(key_str))
.context("Encryption key must be hex or base64 encoded")?;
if key_bytes.len() != 32 {
anyhow::bail!(
"Encryption key must be 32 bytes (256 bits), got {} bytes",
key_bytes.len()
);
}
let key_array: [u8; 32] = key_bytes.try_into().unwrap(); // safe due to length check
// Check if already initialized with same key
if let Some(existing) = RAW_KEY.get() {
if existing == &key_array {
// Same key already initialized, okay
return Ok(());
} else {
anyhow::bail!("Encryption key already initialized with a different key");
}
}
// Store raw key bytes
RAW_KEY
.set(key_array)
.map_err(|_| anyhow::anyhow!("Encryption key already initialized"))?;
Ok(())
}
/// Initialize the encryption key from the environment variable `LLM_PROXY__ENCRYPTION_KEY`.
/// Must be called before any encryption/decryption operations.
/// Panics if the environment variable is missing or invalid.
pub fn init_from_env() -> Result<()> {
let key_str =
env::var("LLM_PROXY__ENCRYPTION_KEY").context("LLM_PROXY__ENCRYPTION_KEY environment variable not set")?;
init_with_key(&key_str)
}
/// Get the encryption key bytes, panicking if not initialized.
fn get_key() -> &'static [u8; 32] {
RAW_KEY
.get()
.expect("Encryption key not initialized. Call crypto::init_with_key() or crypto::init_from_env() first.")
}
/// Encrypt a plaintext string and return a base64-encoded ciphertext (nonce || ciphertext || tag).
pub fn encrypt(plaintext: &str) -> Result<String> {
let key = Key::<Aes256Gcm>::from_slice(get_key());
let cipher = Aes256Gcm::new(key);
let nonce = Aes256Gcm::generate_nonce(&mut OsRng); // 12 bytes
let ciphertext = cipher
.encrypt(&nonce, plaintext.as_bytes())
.map_err(|e| anyhow!("Encryption failed: {}", e))?;
// Combine nonce and ciphertext (ciphertext already includes tag)
let mut combined = Vec::with_capacity(nonce.len() + ciphertext.len());
combined.extend_from_slice(&nonce);
combined.extend_from_slice(&ciphertext);
Ok(BASE64.encode(combined))
}
/// Decrypt a base64-encoded ciphertext (nonce || ciphertext || tag) to a plaintext string.
pub fn decrypt(ciphertext_b64: &str) -> Result<String> {
let key = Key::<Aes256Gcm>::from_slice(get_key());
let cipher = Aes256Gcm::new(key);
let combined = BASE64.decode(ciphertext_b64).context("Invalid base64 ciphertext")?;
if combined.len() < 12 {
anyhow::bail!("Ciphertext too short");
}
let (nonce_bytes, ciphertext_and_tag) = combined.split_at(12);
let nonce = Nonce::from_slice(nonce_bytes);
let plaintext_bytes = cipher
.decrypt(nonce, ciphertext_and_tag)
.map_err(|e| anyhow!("Decryption failed (invalid key or corrupted ciphertext): {}", e))?;
String::from_utf8(plaintext_bytes).context("Decrypted bytes are not valid UTF-8")
}
#[cfg(test)]
mod tests {
use super::*;
const TEST_KEY_HEX: &str = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f";
#[test]
    fn test_encrypt_decrypt() {
        init_with_key(TEST_KEY_HEX).unwrap();
        let plaintext = "super secret api key";
        let ciphertext = encrypt(plaintext).unwrap();
        assert_ne!(ciphertext, plaintext);
        let decrypted = decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }
    #[test]
    fn test_different_inputs_produce_different_ciphertexts() {
        init_with_key(TEST_KEY_HEX).unwrap();
        let plaintext = "same";
        let cipher1 = encrypt(plaintext).unwrap();
        let cipher2 = encrypt(plaintext).unwrap();
        assert_ne!(cipher1, cipher2, "Nonce should make ciphertexts differ");
        assert_eq!(decrypt(&cipher1).unwrap(), plaintext);
        assert_eq!(decrypt(&cipher2).unwrap(), plaintext);
    }
    #[test]
    fn test_invalid_key_length() {
        let result = init_with_key("tooshort");
        assert!(result.is_err());
    }
    #[test]
    fn test_init_from_env() {
        unsafe { std::env::set_var("LLM_PROXY__ENCRYPTION_KEY", TEST_KEY_HEX) };
        let result = init_from_env();
        assert!(result.is_ok());
        // Ensure encryption works
        let ciphertext = encrypt("test").unwrap();
        let decrypted = decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, "test");
    }
    #[test]
    fn test_missing_env_key() {
        unsafe { std::env::remove_var("LLM_PROXY__ENCRYPTION_KEY") };
        let result = init_from_env();
        assert!(result.is_err());
    }
    #[test]
    fn test_key_hex_and_base64() {
        // Hex key works
        init_with_key(TEST_KEY_HEX).unwrap();
        // Base64 key (same bytes encoded as base64)
        let base64_key = BASE64.encode(hex::decode(TEST_KEY_HEX).unwrap());
        // Re-initialization with same key (different encoding) is allowed
        let result = init_with_key(&base64_key);
        assert!(result.is_ok());
        // Encryption should still work
        let ciphertext = encrypt("test").unwrap();
        let decrypted = decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, "test");
    }
    #[test]
    #[ignore] // conflicts with global state from other tests
    fn test_already_initialized() {
        init_with_key(TEST_KEY_HEX).unwrap();
        let result = init_with_key(TEST_KEY_HEX);
        assert!(result.is_ok()); // same key allowed
    }
    #[test]
    #[ignore] // conflicts with global state from other tests
    fn test_already_initialized_different_key() {
        init_with_key(TEST_KEY_HEX).unwrap();
        let different_key = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e20";
        let result = init_with_key(different_key);
        assert!(result.is_err());
    }
}

View File

@@ -1,3 +1,4 @@
pub mod crypto;
pub mod registry;
pub mod streaming;
pub mod tokens;

View File

@@ -1,4 +1,4 @@
use crate::client::ClientManager;
use crate::errors::AppError;
use crate::logging::{RequestLog, RequestLogger};
use crate::models::ToolCall;
@@ -18,7 +18,6 @@ pub struct StreamConfig {
pub prompt_tokens: u32,
pub has_images: bool,
pub logger: Arc<RequestLogger>,
pub client_manager: Arc<ClientManager>,
pub model_registry: Arc<crate::models::registry::ModelRegistry>,
pub model_config_cache: ModelConfigCache,
}
@@ -36,7 +35,6 @@ pub struct AggregatingStream<S> {
/// Real usage data from the provider's final stream chunk (when available).
real_usage: Option<StreamUsage>,
logger: Arc<RequestLogger>,
client_manager: Arc<ClientManager>,
model_registry: Arc<crate::models::registry::ModelRegistry>,
model_config_cache: ModelConfigCache,
start_time: std::time::Instant,
@@ -60,7 +58,6 @@ where
accumulated_tool_calls: Vec::new(),
real_usage: None,
logger: config.logger,
client_manager: config.client_manager,
model_registry: config.model_registry,
model_config_cache: config.model_config_cache,
start_time: std::time::Instant::now(),
@@ -79,7 +76,6 @@ where
let provider_name = self.provider.name().to_string();
let model = self.model.clone();
let logger = self.logger.clone();
let client_manager = self.client_manager.clone();
let provider = self.provider.clone();
let estimated_prompt_tokens = self.prompt_tokens;
let has_images = self.has_images;
@@ -162,11 +158,6 @@ where
error_message: None,
duration_ms: duration.as_millis() as u64,
});
// Update client usage
let _ = client_manager
.update_client_usage(&client_id, total_tokens as i64, cost)
.await;
});
}
}
@@ -304,7 +295,6 @@ mod tests {
let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
let (dashboard_tx, _) = tokio::sync::broadcast::channel(16);
let logger = Arc::new(RequestLogger::new(pool.clone(), dashboard_tx));
let client_manager = Arc::new(ClientManager::new(pool.clone()));
let registry = Arc::new(crate::models::registry::ModelRegistry {
providers: std::collections::HashMap::new(),
});
@@ -318,7 +308,6 @@ mod tests {
prompt_tokens: 10,
has_images: false,
logger,
client_manager,
model_registry: registry,
model_config_cache: ModelConfigCache::new(pool.clone()),
},

View File

@@ -35,6 +35,14 @@ class ApiClient {
throw new Error(result.error || `HTTP error! status: ${response.status}`);
}
// If the server rotated the session token, persist the refreshed one
const refreshedToken = response.headers.get('X-Refreshed-Token');
if (refreshedToken && window.authManager) {
window.authManager.token = refreshedToken;
if (window.authManager.setToken) {
window.authManager.setToken(refreshedToken);
}
}
return result.data;
}
@@ -87,6 +95,17 @@ class ApiClient {
const date = luxon.DateTime.fromISO(dateStr);
return date.toRelative();
}
// Helper for escaping HTML
escapeHtml(unsafe) {
if (unsafe === undefined || unsafe === null) return '';
return unsafe.toString()
.replace(/&/g, "&amp;")
.replace(/</g, "&lt;")
.replace(/>/g, "&gt;")
.replace(/"/g, "&quot;")
.replace(/'/g, "&#039;");
}
}
window.api = new ApiClient();
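The new `escapeHtml` helper can be exercised standalone. A small usage sketch (same replacement chain as the method above; note that `&` must be replaced first, otherwise already-escaped entities would be double-escaped):

```javascript
// Standalone copy of the ApiClient.escapeHtml replacement chain
function escapeHtml(unsafe) {
  if (unsafe === undefined || unsafe === null) return '';
  return unsafe.toString()
    .replace(/&/g, "&amp;")   // must run first to avoid double-escaping
    .replace(/</g, "&lt;")
    .replace(/>/g, "&gt;")
    .replace(/"/g, "&quot;")
    .replace(/'/g, "&#039;");
}

console.log(escapeHtml('<img src=x onerror="alert(1)">'));
// → &lt;img src=x onerror=&quot;alert(1)&quot;&gt;
console.log(escapeHtml(null)); // → '' (null/undefined become empty string)
```

Because the output contains no `<`, `>`, or quote characters, interpolating it into `innerHTML` templates (as the clients and providers pages below do) renders it as inert text rather than markup.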

View File

@@ -50,6 +50,12 @@ class AuthManager {
});
}
setToken(newToken) {
if (!newToken) return;
this.token = newToken;
localStorage.setItem('dashboard_token', this.token);
}
async login(username, password) {
const errorElement = document.getElementById('login-error');
const loginBtn = document.querySelector('.login-btn');

View File

@@ -42,12 +42,15 @@ class ClientsPage {
const statusIcon = client.status === 'active' ? 'check-circle' : 'clock';
const created = luxon.DateTime.fromISO(client.created_at).toFormat('MMM dd, yyyy');
const escapedId = window.api.escapeHtml(client.id);
const escapedName = window.api.escapeHtml(client.name);
return `
<tr>
<td><span class="badge-client">${client.id}</span></td>
<td><strong>${client.name}</strong></td>
<td><span class="badge-client">${escapedId}</span></td>
<td><strong>${escapedName}</strong></td>
<td>
<code class="token-display">sk-••••${client.id.substring(client.id.length - 4)}</code>
<code class="token-display">sk-••••${escapedId.substring(escapedId.length - 4)}</code>
</td>
<td>${created}</td>
<td>${client.last_used ? window.api.formatTimeAgo(client.last_used) : 'Never'}</td>
@@ -55,16 +58,16 @@ class ClientsPage {
<td>
<span class="status-badge ${statusClass}">
<i class="fas fa-${statusIcon}"></i>
${client.status}
${window.api.escapeHtml(client.status)}
</span>
</td>
<td>
${window._userRole === 'admin' ? `
<div class="action-buttons">
<button class="btn-action" title="Edit" onclick="window.clientsPage.editClient('${client.id}')">
<button class="btn-action" title="Edit" onclick="window.clientsPage.editClient('${escapedId}')">
<i class="fas fa-edit"></i>
</button>
<button class="btn-action danger" title="Delete" onclick="window.clientsPage.deleteClient('${client.id}')">
<button class="btn-action danger" title="Delete" onclick="window.clientsPage.deleteClient('${escapedId}')">
<i class="fas fa-trash"></i>
</button>
</div>
@@ -188,10 +191,13 @@ class ClientsPage {
showTokenRevealModal(clientName, token) {
const modal = document.createElement('div');
modal.className = 'modal active';
const escapedName = window.api.escapeHtml(clientName);
const escapedToken = window.api.escapeHtml(token);
modal.innerHTML = `
<div class="modal-content">
<div class="modal-header">
<h3 class="modal-title">Client Created: ${clientName}</h3>
<h3 class="modal-title">Client Created: ${escapedName}</h3>
</div>
<div class="modal-body">
<p style="margin-bottom: 0.75rem; color: var(--yellow);">
@@ -201,7 +207,7 @@ class ClientsPage {
<div class="form-control">
<label>API Token</label>
<div style="display: flex; gap: 0.5rem;">
<input type="text" id="revealed-token" value="${token}" readonly
<input type="text" id="revealed-token" value="${escapedToken}" readonly
style="font-family: monospace; font-size: 0.85rem;">
<button class="btn btn-secondary" id="copy-token-btn" title="Copy">
<i class="fas fa-copy"></i>
@@ -248,10 +254,16 @@ class ClientsPage {
showEditClientModal(client) {
const modal = document.createElement('div');
modal.className = 'modal active';
const escapedId = window.api.escapeHtml(client.id);
const escapedName = window.api.escapeHtml(client.name);
const escapedDescription = window.api.escapeHtml(client.description);
const escapedRateLimit = window.api.escapeHtml(client.rate_limit_per_minute);
modal.innerHTML = `
<div class="modal-content">
<div class="modal-header">
<h3 class="modal-title">Edit Client: ${client.id}</h3>
<h3 class="modal-title">Edit Client: ${escapedId}</h3>
<button class="modal-close" onclick="this.closest('.modal').remove()">
<i class="fas fa-times"></i>
</button>
@@ -259,15 +271,15 @@ class ClientsPage {
<div class="modal-body">
<div class="form-control">
<label for="edit-client-name">Display Name</label>
<input type="text" id="edit-client-name" value="${client.name || ''}" placeholder="e.g. My Coding Assistant">
<input type="text" id="edit-client-name" value="${escapedName}" placeholder="e.g. My Coding Assistant">
</div>
<div class="form-control">
<label for="edit-client-description">Description</label>
<textarea id="edit-client-description" rows="3" placeholder="Optional description">${client.description || ''}</textarea>
<textarea id="edit-client-description" rows="3" placeholder="Optional description">${escapedDescription}</textarea>
</div>
<div class="form-control">
<label for="edit-client-rate-limit">Rate Limit (requests/minute)</label>
<input type="number" id="edit-client-rate-limit" min="0" value="${client.rate_limit_per_minute || ''}" placeholder="Leave empty for unlimited">
<input type="number" id="edit-client-rate-limit" min="0" value="${escapedRateLimit}" placeholder="Leave empty for unlimited">
</div>
<div class="form-control">
<label class="toggle-label">
@@ -357,12 +369,16 @@ class ClientsPage {
const lastUsed = t.last_used_at
? luxon.DateTime.fromISO(t.last_used_at).toRelative()
: 'Never';
const escapedMaskedToken = window.api.escapeHtml(t.token_masked);
const escapedClientId = window.api.escapeHtml(clientId);
const tokenId = parseInt(t.id, 10); // Assuming ID is numeric
return `
<div style="display: flex; align-items: center; gap: 0.5rem; padding: 0.4rem 0; border-bottom: 1px solid var(--bg3);">
<code style="flex: 1; font-size: 0.8rem; color: var(--fg2);">${t.token_masked}</code>
<code style="flex: 1; font-size: 0.8rem; color: var(--fg2);">${escapedMaskedToken}</code>
<span style="font-size: 0.75rem; color: var(--fg4);" title="Last used">${lastUsed}</span>
<button class="btn-action danger" title="Revoke" style="padding: 0.2rem 0.4rem;"
onclick="window.clientsPage.revokeToken('${clientId}', ${t.id}, this)">
onclick="window.clientsPage.revokeToken('${escapedClientId}', ${tokenId}, this)">
<i class="fas fa-trash" style="font-size: 0.75rem;"></i>
</button>
</div>

View File

@@ -47,16 +47,21 @@ class ProvidersPage {
const isLowBalance = provider.credit_balance <= provider.low_credit_threshold && provider.id !== 'ollama';
const balanceColor = isLowBalance ? 'var(--red-light)' : 'var(--green-light)';
const escapedId = window.api.escapeHtml(provider.id);
const escapedName = window.api.escapeHtml(provider.name);
const escapedStatus = window.api.escapeHtml(provider.status);
const billingMode = provider.billing_mode ? provider.billing_mode.toUpperCase() : 'PREPAID';
return `
<div class="provider-card ${provider.status}">
<div class="provider-card ${escapedStatus}">
<div class="provider-card-header">
<div class="provider-info">
<h4 class="provider-name">${provider.name}</h4>
<span class="provider-id">${provider.id}</span>
<h4 class="provider-name">${escapedName}</h4>
<span class="provider-id">${escapedId}</span>
</div>
<span class="status-badge ${statusClass}">
<i class="fas fa-circle"></i>
${provider.status}
${escapedStatus}
</span>
</div>
<div class="provider-card-body">
@@ -67,12 +72,12 @@ class ProvidersPage {
</div>
<div class="meta-item" style="color: ${balanceColor}; font-weight: 700;">
<i class="fas fa-wallet"></i>
<span>Balance: ${provider.id === 'ollama' ? 'FREE' : window.api.formatCurrency(provider.credit_balance)}</span>
<span>Balance: ${escapedId === 'ollama' ? 'FREE' : window.api.formatCurrency(provider.credit_balance)}</span>
${isLowBalance ? '<i class="fas fa-exclamation-triangle" title="Low Balance"></i>' : ''}
</div>
<div class="meta-item">
<i class="fas fa-exchange-alt"></i>
<span>Billing: ${provider.billing_mode ? provider.billing_mode.toUpperCase() : 'PREPAID'}</span>
<span>Billing: ${window.api.escapeHtml(billingMode)}</span>
</div>
<div class="meta-item">
<i class="fas fa-clock"></i>
@@ -80,16 +85,16 @@ class ProvidersPage {
</div>
</div>
<div class="model-tags">
${(provider.models || []).slice(0, 5).map(m => `<span class="model-tag">${m}</span>`).join('')}
${(provider.models || []).slice(0, 5).map(m => `<span class="model-tag">${window.api.escapeHtml(m)}</span>`).join('')}
${modelCount > 5 ? `<span class="model-tag more">+${modelCount - 5} more</span>` : ''}
</div>
</div>
<div class="provider-card-footer">
<button class="btn btn-secondary btn-sm" onclick="window.providersPage.testProvider('${provider.id}')">
<button class="btn btn-secondary btn-sm" onclick="window.providersPage.testProvider('${escapedId}')">
<i class="fas fa-vial"></i> Test
</button>
${window._userRole === 'admin' ? `
<button class="btn btn-primary btn-sm" onclick="window.providersPage.configureProvider('${provider.id}')">
<button class="btn btn-primary btn-sm" onclick="window.providersPage.configureProvider('${escapedId}')">
<i class="fas fa-cog"></i> Config
</button>
` : ''}
@@ -144,10 +149,17 @@ class ProvidersPage {
const modal = document.createElement('div');
modal.className = 'modal active';
const escapedId = window.api.escapeHtml(provider.id);
const escapedName = window.api.escapeHtml(provider.name);
const escapedBaseUrl = window.api.escapeHtml(provider.base_url);
const escapedBalance = window.api.escapeHtml(provider.credit_balance);
const escapedThreshold = window.api.escapeHtml(provider.low_credit_threshold);
modal.innerHTML = `
<div class="modal-content">
<div class="modal-header">
<h3 class="modal-title">Configure ${provider.name}</h3>
<h3 class="modal-title">Configure ${escapedName}</h3>
<button class="modal-close" onclick="this.closest('.modal').remove()">
<i class="fas fa-times"></i>
</button>
@@ -161,7 +173,7 @@ class ProvidersPage {
</div>
<div class="form-control">
<label for="provider-base-url">Base URL</label>
<input type="text" id="provider-base-url" value="${provider.base_url || ''}" placeholder="Default API URL">
<input type="text" id="provider-base-url" value="${escapedBaseUrl}" placeholder="Default API URL">
</div>
<div class="form-control">
<label for="provider-api-key">API Key (Optional / Overwrite)</label>
@@ -170,11 +182,11 @@ class ProvidersPage {
<div class="grid-2">
<div class="form-control">
<label for="provider-balance">Current Credit Balance ($)</label>
<input type="number" id="provider-balance" value="${provider.credit_balance}" step="0.01">
<input type="number" id="provider-balance" value="${escapedBalance}" step="0.01">
</div>
<div class="form-control">
<label for="provider-threshold">Low Balance Alert ($)</label>
<input type="number" id="provider-threshold" value="${provider.low_credit_threshold}" step="0.50">
<input type="number" id="provider-threshold" value="${escapedThreshold}" step="0.50">
</div>
</div>
<div class="form-control">

View File

@@ -280,8 +280,6 @@
// ── Helpers ────────────────────────────────────────────────────
function escapeHtml(str) {
const div = document.createElement('div');
div.textContent = str;
return div.innerHTML;
return window.api.escapeHtml(str);
}
})();

timeline.mmd Normal file
View File

@@ -0,0 +1,14 @@
gantt
    title LLM Proxy Project Timeline
    dateFormat YYYY-MM-DD
    section Frontend
    Standardize Escaping (users.js) :a1, 2026-03-06, 1d
    section Backend Cleanup
    Remove Unused Imports :b1, 2026-03-06, 1d
    section HMAC Migration
    Architecture Design :c1, 2026-03-07, 1d
    Backend Implementation :c2, after c1, 2d
    Session Refresh Logic :c3, after c2, 1d
    section Testing
    Integration Test (Encrypted Keys) :d1, 2026-03-09, 2d
    HMAC Verification Tests :d2, after c3, 1d