Compare commits
42 Commits
7411d3dbed
...
rust
| Author | SHA1 | Date | |
|---|---|---|---|
| 649371154f | |||
| 78fff61660 | |||
| b131094dfd | |||
| c3d81c1733 | |||
| e123f542f1 | |||
| 0d28241e39 | |||
| 754ee9cb84 | |||
| 5a9086b883 | |||
| cc5eba1957 | |||
| 3ab00fb188 | |||
| c2595f7a74 | |||
| 0526304398 | |||
| 75e2967727 | |||
| e1bc3b35eb | |||
| 0d32d953d2 | |||
| bd5ca2dd98 | |||
| 6a0aca1a6c | |||
| 4c629e17cb | |||
| fc3bc6968d | |||
| d6280abad9 | |||
| 96486b6318 | |||
| e8955fd36c | |||
| a243a3987d | |||
| 4be23629d8 | |||
| dd54c14ff8 | |||
| 633b69a07b | |||
| 975ae124d1 | |||
| 9b8483e797 | |||
| d32386df3f | |||
| 149a7c3a29 | |||
| d9cfffea62 | |||
| 90ef026c96 | |||
| 5ddf284b8f | |||
| f5677afba0 | |||
| 4ffc6452e0 | |||
| 94162a3dcc | |||
| c26925c253 | |||
| d0d64e2064 | |||
| 6a324c08c7 | |||
| 1ddb5277e9 | |||
| 1067ceaecd | |||
| fc5d3ed636 |
12
.env
12
.env
@@ -15,8 +15,14 @@ GROK_API_KEY=gk-demo-grok-key
|
||||
# Authentication tokens (comma-separated list)
|
||||
LLM_PROXY__SERVER__AUTH_TOKENS=demo-token-123456,another-token
|
||||
|
||||
# Server port (optional)
|
||||
LLM_PROXY__SERVER__PORT=8080
|
||||
|
||||
# Database path (optional)
|
||||
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
|
||||
|
||||
# Session Secret (for signed tokens)
|
||||
SESSION_SECRET=ki9khXAk9usDkasMrD2UbK4LOgrDRJz0
|
||||
|
||||
# Encryption key (required)
|
||||
LLM_PROXY__ENCRYPTION_KEY=69879f5b7913ba169982190526ae213e830b3f1f33e785ef2b68cf48c7853fcd
|
||||
|
||||
# Server port (optional)
|
||||
LLM_PROXY__SERVER__PORT=8080
|
||||
|
||||
22
.env.backup
Normal file
22
.env.backup
Normal file
@@ -0,0 +1,22 @@
|
||||
# LLM Proxy Gateway Environment Variables
|
||||
|
||||
# OpenAI
|
||||
OPENAI_API_KEY=sk-demo-openai-key
|
||||
|
||||
# Google Gemini
|
||||
GEMINI_API_KEY=AIza-demo-gemini-key
|
||||
|
||||
# DeepSeek
|
||||
DEEPSEEK_API_KEY=sk-demo-deepseek-key
|
||||
|
||||
# xAI Grok (not yet available)
|
||||
GROK_API_KEY=gk-demo-grok-key
|
||||
|
||||
# Authentication tokens (comma-separated list)
|
||||
LLM_PROXY__SERVER__AUTH_TOKENS=demo-token-123456,another-token
|
||||
|
||||
# Server port (optional)
|
||||
LLM_PROXY__SERVER__PORT=8080
|
||||
|
||||
# Database path (optional)
|
||||
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
|
||||
@@ -26,3 +26,6 @@ LLM_PROXY__SERVER__PORT=8080
|
||||
|
||||
# Database path (optional)
|
||||
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
|
||||
|
||||
# Session secret for HMAC-signed tokens (hex or base64 encoded, 32 bytes)
|
||||
SESSION_SECRET=your_session_secret_here_32_bytes
|
||||
65
CODE_REVIEW_PLAN.md
Normal file
65
CODE_REVIEW_PLAN.md
Normal file
@@ -0,0 +1,65 @@
|
||||
# LLM Proxy Code Review Plan
|
||||
|
||||
## Overview
|
||||
The **LLM Proxy** project is a Rust-based middleware designed to provide a unified interface for multiple Large Language Models (LLMs). Based on the repository structure, the project aims to implement a high-performance proxy server (`src/`) that handles request routing, usage tracking, and billing logic. A static dashboard (`static/`) provides a management interface for monitoring consumption and managing API keys. The architecture leverages Rust's async capabilities for efficient request handling and SQLite for persistent state management.
|
||||
|
||||
## Review Phases
|
||||
|
||||
### Phase 1: Backend Architecture & Rust Logic (@code-reviewer)
|
||||
- **Focus on:**
|
||||
- **Core Proxy Logic:** Efficiency of the request/response pipeline and streaming support.
|
||||
- **State Management:** Thread-safety and shared state patterns using `Arc` and `Mutex`/`RwLock`.
|
||||
- **Error Handling:** Use of idiomatic Rust error types and propagation.
|
||||
- **Async Performance:** Proper use of `tokio` or similar runtimes to avoid blocking the executor.
|
||||
- **Rust Idioms:** Adherence to Clippy suggestions and standard Rust naming conventions.
|
||||
|
||||
### Phase 2: Security & Authentication Audit (@security-auditor)
|
||||
- **Focus on:**
|
||||
- **API Key Management:** Secure storage, masking in logs, and rotation mechanisms.
|
||||
- **JWT Handling:** Validation logic, signature verification, and expiration checks.
|
||||
- **Input Validation:** Sanitization of prompts and configuration parameters to prevent injection.
|
||||
- **Dependency Audit:** Scanning for known vulnerabilities in the `Cargo.lock` using `cargo-audit`.
|
||||
|
||||
### Phase 3: Database & Data Integrity Review (@database-optimizer)
|
||||
- **Focus on:**
|
||||
- **Schema Design:** Efficiency of the SQLite schema for usage tracking and billing.
|
||||
- **Migration Strategy:** Robustness of the migration scripts to prevent data loss.
|
||||
- **Usage Tracking:** Accuracy of token counting and concurrency handling during increments.
|
||||
- **Query Optimization:** Identifying potential bottlenecks in reporting queries.
|
||||
|
||||
### Phase 4: Frontend & Dashboard Review (@frontend-developer)
|
||||
- **Focus on:**
|
||||
- **Vanilla JS Patterns:** Review of Web Components and modular JS in `static/js`.
|
||||
- **Security:** Protection against XSS in the dashboard and secure handling of local storage.
|
||||
- **UI/UX Consistency:** Ensuring the management interface is intuitive and responsive.
|
||||
- **API Integration:** Robustness of the frontend's communication with the Rust backend.
|
||||
|
||||
### Phase 5: Infrastructure & Deployment Review (@devops-engineer)
|
||||
- **Focus on:**
|
||||
- **Dockerfile Optimization:** Multi-stage builds to minimize image size and attack surface.
|
||||
- **Resource Limits:** Configuration of CPU/Memory limits for the proxy container.
|
||||
- **Deployment Docs:** Clarity of the setup process and environment variable documentation.
|
||||
|
||||
## Timeline (Gantt)
|
||||
|
||||
```mermaid
|
||||
gantt
|
||||
title LLM Proxy Code Review Timeline (March 2026)
|
||||
dateFormat YYYY-MM-DD
|
||||
section Backend & Security
|
||||
Architecture & Rust Logic (Phase 1) :active, p1, 2026-03-06, 1d
|
||||
Security & Auth Audit (Phase 2) :p2, 2026-03-07, 1d
|
||||
section Data & Frontend
|
||||
Database & Integrity (Phase 3) :p3, 2026-03-07, 1d
|
||||
Frontend & Dashboard (Phase 4) :p4, 2026-03-08, 1d
|
||||
section DevOps
|
||||
Infra & Deployment (Phase 5) :p5, 2026-03-08, 1d
|
||||
Final Review & Sign-off :2026-03-08, 4h
|
||||
```
|
||||
|
||||
## Success Criteria
|
||||
- **Security:** Zero high-priority vulnerabilities identified; all API keys masked in logs.
|
||||
- **Performance:** Proxy overhead is minimal (<10ms latency addition); queries are indexed.
|
||||
- **Maintainability:** Code passes all linting (`cargo clippy`) and formatting (`cargo fmt`) checks.
|
||||
- **Documentation:** README and deployment guides are up-to-date and accurate.
|
||||
- **Reliability:** Usage tracking matches actual API consumption with 99.9% accuracy.
|
||||
201
Cargo.lock
generated
201
Cargo.lock
generated
@@ -8,6 +8,41 @@ version = "2.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa"
|
||||
|
||||
[[package]]
|
||||
name = "aead"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0"
|
||||
dependencies = [
|
||||
"crypto-common",
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aes"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cipher",
|
||||
"cpufeatures",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aes-gcm"
|
||||
version = "0.10.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1"
|
||||
dependencies = [
|
||||
"aead",
|
||||
"aes",
|
||||
"cipher",
|
||||
"ctr",
|
||||
"ghash",
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.7.8"
|
||||
@@ -541,9 +576,33 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
"rand_core 0.6.4",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ctr"
|
||||
version = "0.9.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835"
|
||||
dependencies = [
|
||||
"cipher",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dashmap"
|
||||
version = "6.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-utils",
|
||||
"hashbrown 0.14.5",
|
||||
"lock_api",
|
||||
"once_cell",
|
||||
"parking_lot_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "data-encoding"
|
||||
version = "2.10.0"
|
||||
@@ -895,6 +954,37 @@ dependencies = [
|
||||
"wasip3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ghash"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1"
|
||||
dependencies = [
|
||||
"opaque-debug",
|
||||
"polyval",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "governor"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0746aa765db78b521451ef74221663b57ba595bf83f75d0ce23cc09447c8139f"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"dashmap",
|
||||
"futures-sink",
|
||||
"futures-timer",
|
||||
"futures-util",
|
||||
"no-std-compat",
|
||||
"nonzero_ext",
|
||||
"parking_lot",
|
||||
"portable-atomic",
|
||||
"quanta",
|
||||
"rand 0.8.5",
|
||||
"smallvec",
|
||||
"spinning_top",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.4.13"
|
||||
@@ -923,6 +1013,12 @@ dependencies = [
|
||||
"ahash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.14.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.15.5"
|
||||
@@ -1431,6 +1527,7 @@ checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77"
|
||||
name = "llm-proxy"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"aes-gcm",
|
||||
"anyhow",
|
||||
"assert_cmd",
|
||||
"async-stream",
|
||||
@@ -1443,8 +1540,10 @@ dependencies = [
|
||||
"config",
|
||||
"dotenvy",
|
||||
"futures",
|
||||
"governor",
|
||||
"headers",
|
||||
"hex",
|
||||
"hmac",
|
||||
"image",
|
||||
"insta",
|
||||
"mime",
|
||||
@@ -1454,6 +1553,7 @@ dependencies = [
|
||||
"reqwest-eventsource",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"sqlx",
|
||||
"tempfile",
|
||||
"thiserror 1.0.69",
|
||||
@@ -1598,6 +1698,12 @@ dependencies = [
|
||||
"pxfm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "no-std-compat"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
|
||||
|
||||
[[package]]
|
||||
name = "nom"
|
||||
version = "7.1.3"
|
||||
@@ -1608,6 +1714,12 @@ dependencies = [
|
||||
"minimal-lexical",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nonzero_ext"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
|
||||
|
||||
[[package]]
|
||||
name = "nu-ansi-term"
|
||||
version = "0.50.3"
|
||||
@@ -1669,6 +1781,12 @@ version = "1.21.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
||||
|
||||
[[package]]
|
||||
name = "opaque-debug"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
|
||||
|
||||
[[package]]
|
||||
name = "ordered-multimap"
|
||||
version = "0.4.3"
|
||||
@@ -1824,6 +1942,24 @@ dependencies = [
|
||||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "polyval"
|
||||
version = "0.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"opaque-debug",
|
||||
"universal-hash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "portable-atomic"
|
||||
version = "1.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
|
||||
|
||||
[[package]]
|
||||
name = "potential_utf"
|
||||
version = "0.1.4"
|
||||
@@ -1897,6 +2033,21 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quanta"
|
||||
version = "0.12.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"raw-cpuid",
|
||||
"wasi",
|
||||
"web-sys",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quick-error"
|
||||
version = "2.0.1"
|
||||
@@ -2032,6 +2183,15 @@ dependencies = [
|
||||
"getrandom 0.3.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "raw-cpuid"
|
||||
version = "11.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.5.18"
|
||||
@@ -2453,6 +2613,15 @@ dependencies = [
|
||||
"lock_api",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spinning_top"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300"
|
||||
dependencies = [
|
||||
"lock_api",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spki"
|
||||
version = "0.7.3"
|
||||
@@ -3158,6 +3327,16 @@ version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
|
||||
|
||||
[[package]]
|
||||
name = "universal-hash"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea"
|
||||
dependencies = [
|
||||
"crypto-common",
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "untrusted"
|
||||
version = "0.9.0"
|
||||
@@ -3411,6 +3590,28 @@ dependencies = [
|
||||
"wasite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
|
||||
dependencies = [
|
||||
"winapi-i686-pc-windows-gnu",
|
||||
"winapi-x86_64-pc-windows-gnu",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-i686-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||
|
||||
[[package]]
|
||||
name = "winapi-x86_64-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
||||
|
||||
[[package]]
|
||||
name = "windows-core"
|
||||
version = "0.62.2"
|
||||
|
||||
@@ -13,7 +13,8 @@ repository = ""
|
||||
axum = { version = "0.8", features = ["macros", "ws"] }
|
||||
tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "net", "time", "signal", "fs"] }
|
||||
tower = "0.5"
|
||||
tower-http = { version = "0.6", features = ["trace", "cors", "compression-gzip", "fs"] }
|
||||
tower-http = { version = "0.6", features = ["trace", "cors", "compression-gzip", "fs", "set-header", "limit"] }
|
||||
governor = "0.7"
|
||||
|
||||
# ========== HTTP Clients ==========
|
||||
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
|
||||
@@ -46,6 +47,9 @@ mime = "0.3"
|
||||
anyhow = "1.0"
|
||||
thiserror = "1.0"
|
||||
bcrypt = "0.15"
|
||||
aes-gcm = "0.10"
|
||||
hmac = "0.12"
|
||||
sha2 = "0.10"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
uuid = { version = "1.0", features = ["v4", "serde"] }
|
||||
futures = "0.3"
|
||||
|
||||
480
DATABASE_REVIEW.md
Normal file
480
DATABASE_REVIEW.md
Normal file
@@ -0,0 +1,480 @@
|
||||
# Database Review Report for LLM-Proxy Repository
|
||||
|
||||
**Review Date:** 2025-03-06
|
||||
**Reviewer:** Database Optimization Expert
|
||||
**Repository:** llm-proxy
|
||||
**Focus Areas:** Schema Design, Query Optimization, Migration Strategy, Data Integrity, Usage Tracking Accuracy
|
||||
|
||||
## Executive Summary
|
||||
|
||||
The llm-proxy database implementation demonstrates solid foundation with appropriate table structures and clear separation of concerns. However, several areas require improvement to ensure scalability, data consistency, and performance as usage grows. Key findings include:
|
||||
|
||||
1. **Schema Design**: Generally normalized but missing foreign key enforcement and some critical indexes.
|
||||
2. **Query Optimization**: Well-optimized for most queries but missing composite indexes for common filtering patterns.
|
||||
3. **Migration Strategy**: Ad-hoc migration approach that may cause issues with schema evolution.
|
||||
4. **Data Integrity**: Potential race conditions in usage tracking and missing transaction boundaries.
|
||||
5. **Usage Tracking**: Generally accurate but risk of inconsistent state between related tables.
|
||||
|
||||
This report provides detailed analysis and actionable recommendations for each area.
|
||||
|
||||
## 1. Schema Design Review
|
||||
|
||||
### Tables Overview
|
||||
|
||||
The database consists of 6 main tables:
|
||||
|
||||
1. **clients**: Client management with usage aggregates
|
||||
2. **llm_requests**: Request logging with token counts and costs
|
||||
3. **provider_configs**: Provider configuration and credit balances
|
||||
4. **model_configs**: Model-specific configuration and cost overrides
|
||||
5. **users**: Dashboard user authentication
|
||||
6. **client_tokens**: API token storage for client authentication
|
||||
|
||||
### Normalization Assessment
|
||||
|
||||
**Strengths:**
|
||||
- Tables follow 3rd Normal Form (3NF) with appropriate separation
|
||||
- Foreign key relationships properly defined
|
||||
- No obvious data duplication across tables
|
||||
|
||||
**Areas for Improvement:**
|
||||
- **Denormalized aggregates**: `clients.total_requests`, `total_tokens`, `total_cost` are derived from `llm_requests`. This introduces risk of inconsistency.
|
||||
- **Provider credit balance**: Stored in `provider_configs` but also updated based on `llm_requests`. No audit trail for balance changes.
|
||||
|
||||
### Data Type Analysis
|
||||
|
||||
**Appropriate Choices:**
|
||||
- INTEGER for token counts (cast from u32 to i64)
|
||||
- REAL for monetary values
|
||||
- DATETIME for timestamps using SQLite's CURRENT_TIMESTAMP
|
||||
- TEXT for identifiers with appropriate length
|
||||
|
||||
**Potential Issues:**
|
||||
- `llm_requests.request_body` and `response_body` defined as TEXT but always set to NULL - consider removing or making optional columns.
|
||||
- `provider_configs.billing_mode` added via migration but default value not consistently applied to existing rows.
|
||||
|
||||
### Constraints and Foreign Keys
|
||||
|
||||
**Current Constraints:**
|
||||
- Primary keys defined for all tables
|
||||
- UNIQUE constraints on `clients.client_id`, `users.username`, `client_tokens.token`
|
||||
- Foreign key definitions present but **not enforced** (SQLite default)
|
||||
|
||||
**Missing Constraints:**
|
||||
- NOT NULL constraints missing on several columns where nullability not intended
|
||||
- CHECK constraints for positive values (`credit_balance >= 0`)
|
||||
- Foreign key enforcement not enabled
|
||||
|
||||
## 2. Query Optimization Analysis
|
||||
|
||||
### Indexing Strategy
|
||||
|
||||
**Existing Indexes:**
|
||||
- `idx_clients_client_id` - Essential for client lookups
|
||||
- `idx_clients_created_at` - Useful for chronological listing
|
||||
- `idx_llm_requests_timestamp` - Critical for time-based queries
|
||||
- `idx_llm_requests_client_id` - Supports client-specific queries
|
||||
- `idx_llm_requests_provider` - Good for provider breakdowns
|
||||
- `idx_llm_requests_status` - Low cardinality but acceptable
|
||||
- `idx_client_tokens_token` UNIQUE - Essential for authentication
|
||||
- `idx_client_tokens_client_id` - Supports token management
|
||||
|
||||
**Missing Critical Indexes:**
|
||||
1. `model_configs.provider_id` - Foreign key column used in JOINs
|
||||
2. `llm_requests(client_id, timestamp)` - Composite index for client time-series queries
|
||||
3. `llm_requests(provider, timestamp)` - For provider performance analysis
|
||||
4. `llm_requests(status, timestamp)` - For error trend analysis
|
||||
|
||||
### N+1 Query Detection
|
||||
|
||||
**Well-Optimized Areas:**
|
||||
- Model configuration caching prevents repeated database hits
|
||||
- Provider configs loaded in batch for dashboard display
|
||||
- Client listing uses single efficient query
|
||||
|
||||
**Potential N+1 Patterns:**
|
||||
- In `server/mod.rs` list_models function, cache lookup per model but this is in-memory
|
||||
- No significant database N+1 issues identified
|
||||
|
||||
### Inefficient Query Patterns
|
||||
|
||||
**Query 1: Time-series aggregation with strftime()**
|
||||
```sql
|
||||
SELECT strftime('%Y-%m-%d', timestamp) as date, ...
|
||||
FROM llm_requests
|
||||
WHERE 1=1 {}
|
||||
GROUP BY date, client_id, provider, model
|
||||
ORDER BY date DESC
|
||||
LIMIT 200
|
||||
```
|
||||
**Issue:** Function on indexed column prevents index utilization for the WHERE clause when filtering by timestamp range.
|
||||
|
||||
**Recommendation:** Store computed date column or use range queries on timestamp directly.
|
||||
|
||||
**Query 2: Today's stats using strftime()**
|
||||
```sql
|
||||
WHERE strftime('%Y-%m-%d', timestamp) = ?
|
||||
```
|
||||
**Issue:** Non-sargable query prevents index usage.
|
||||
|
||||
**Recommendation:** Use range query:
|
||||
```sql
|
||||
WHERE timestamp >= date(?) AND timestamp < date(?, '+1 day')
|
||||
```
|
||||
|
||||
### Recommended Index Additions
|
||||
|
||||
```sql
|
||||
-- Composite indexes for common query patterns
|
||||
CREATE INDEX idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp);
|
||||
CREATE INDEX idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp);
|
||||
CREATE INDEX idx_llm_requests_status_timestamp ON llm_requests(status, timestamp);
|
||||
|
||||
-- Foreign key index
|
||||
CREATE INDEX idx_model_configs_provider_id ON model_configs(provider_id);
|
||||
|
||||
-- Optional: Covering index for client usage queries
|
||||
CREATE INDEX idx_clients_usage ON clients(client_id, total_requests, total_tokens, total_cost);
|
||||
```
|
||||
|
||||
## 3. Migration Strategy Assessment
|
||||
|
||||
### Current Approach
|
||||
|
||||
The migration system uses a hybrid approach:
|
||||
|
||||
1. **Schema synchronization**: `CREATE TABLE IF NOT EXISTS` on startup
|
||||
2. **Ad-hoc migrations**: `ALTER TABLE` statements with error suppression
|
||||
3. **Single migration file**: `migrations/001-add-billing-mode.sql` with transaction wrapper
|
||||
|
||||
**Pros:**
|
||||
- Simple to understand and maintain
|
||||
- Automatic schema creation for new deployments
|
||||
- Error suppression prevents crashes on column existence
|
||||
|
||||
**Cons:**
|
||||
- No version tracking of applied migrations
|
||||
- Potential for inconsistent schema across deployments
|
||||
- `ALTER TABLE` error suppression hides genuine schema issues
|
||||
- No rollback capability
|
||||
|
||||
### Risks and Limitations
|
||||
|
||||
1. **Schema Drift**: Different instances may have different schemas if migrations are applied out of order
|
||||
2. **Data Loss Risk**: No backup/verification before schema changes
|
||||
3. **Production Issues**: Error suppression could mask migration failures until runtime
|
||||
|
||||
### Recommendations
|
||||
|
||||
1. **Implement Proper Migration Tooling**: Use `sqlx migrate` or similar versioned migration system
|
||||
2. **Add Migration Version Table**: Track applied migrations and checksum verification
|
||||
3. **Separate Migration Scripts**: One file per migration with up/down directions
|
||||
4. **Pre-deployment Validation**: Schema checks in CI/CD pipeline
|
||||
5. **Backup Strategy**: Automatic backups before migration execution
|
||||
|
||||
## 4. Data Integrity Evaluation
|
||||
|
||||
### Foreign Key Enforcement
|
||||
|
||||
**Critical Issue:** Foreign key constraints are defined but **not enforced** in SQLite.
|
||||
|
||||
**Impact:** Orphaned records, inconsistent referential integrity.
|
||||
|
||||
**Solution:** Enable foreign key support in connection string:
|
||||
```rust
|
||||
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?
|
||||
.create_if_missing(true)
|
||||
.pragma("foreign_keys", "ON");
|
||||
```
|
||||
|
||||
### Transaction Usage
|
||||
|
||||
**Good Patterns:**
|
||||
- Request logging uses transactions for insert + provider balance update
|
||||
- Atomic UPDATE for client usage statistics
|
||||
|
||||
**Problematic Areas:**
|
||||
|
||||
1. **Split Transactions**: Client usage update and request logging are in separate transactions
|
||||
- In `logging/mod.rs`: `insert_log` transaction includes provider balance update
|
||||
- In `utils/streaming.rs`: Client usage updated separately after logging
|
||||
- **Risk**: Partial updates if one transaction fails
|
||||
|
||||
2. **No Transaction for Client Creation**: Client and token creation not atomic
|
||||
|
||||
**Recommendations:**
|
||||
- Wrap client usage update within the same transaction as request logging
|
||||
- Use transaction for client + token creation
|
||||
- Consider using savepoints for complex operations
|
||||
|
||||
### Race Conditions and Consistency
|
||||
|
||||
**Potential Race Conditions:**
|
||||
1. **Provider credit balance**: Concurrent requests may cause lost updates
|
||||
- Current: `UPDATE provider_configs SET credit_balance = credit_balance - ?`
|
||||
- SQLite provides serializable isolation, but negative balances not prevented
|
||||
|
||||
2. **Client usage aggregates**: Concurrent updates to `total_requests`, `total_tokens`, `total_cost`
|
||||
- Similar UPDATE pattern, generally safe but consider idempotency
|
||||
|
||||
**Recommendations:**
|
||||
- Add check constraint: `CHECK (credit_balance >= 0)`
|
||||
- Implement idempotent request logging with unique request IDs
|
||||
- Consider optimistic concurrency control for critical balances
|
||||
|
||||
## 5. Usage Tracking Accuracy
|
||||
|
||||
### Token Counting Methodology
|
||||
|
||||
**Current Approach:**
|
||||
- Prompt tokens: Estimated using provider-specific estimators
|
||||
- Completion tokens: Estimated or from provider real usage data
|
||||
- Cache tokens: Separately tracked for cache-aware pricing
|
||||
|
||||
**Strengths:**
|
||||
- Fallback to estimation when provider doesn't report usage
|
||||
- Cache token differentiation for accurate pricing
|
||||
|
||||
**Weaknesses:**
|
||||
- Estimation may differ from actual provider counts
|
||||
- No validation of provider-reported token counts
|
||||
|
||||
### Cost Calculation
|
||||
|
||||
**Well Implemented:**
|
||||
- Model-specific cost overrides via `model_configs`
|
||||
- Cache-aware pricing when supported by registry
|
||||
- Provider fallback calculations
|
||||
|
||||
**Potential Issues:**
|
||||
- Floating-point precision for monetary calculations
|
||||
- No rounding strategy for fractional cents
|
||||
|
||||
### Update Consistency
|
||||
|
||||
**Inconsistency Risk:** Client aggregates updated separately from request logging.
|
||||
|
||||
**Example Flow:**
|
||||
1. Request log inserted and provider balance updated (transaction)
|
||||
2. Client usage updated (separate operation)
|
||||
3. If step 2 fails, client stats undercount usage
|
||||
|
||||
**Solution:** Include client update in the same transaction:
|
||||
```rust
|
||||
// In insert_log function, add:
|
||||
UPDATE clients
|
||||
SET total_requests = total_requests + 1,
|
||||
total_tokens = total_tokens + ?,
|
||||
total_cost = total_cost + ?
|
||||
WHERE client_id = ?;
|
||||
```
|
||||
|
||||
### Financial Accuracy
|
||||
|
||||
**Good Practices:**
|
||||
- Token-level granularity for cost calculation
|
||||
- Separation of prompt/completion/cache pricing
|
||||
- Database persistence for audit trail
|
||||
|
||||
**Recommendations:**
|
||||
1. **Audit Trail**: Add `balance_transactions` table for provider credit changes
|
||||
2. **Rounding Policy**: Define rounding strategy (e.g., to 6 decimal places)
|
||||
3. **Validation**: Periodic reconciliation of aggregates vs. detail records
|
||||
|
||||
## 6. Performance Recommendations
|
||||
|
||||
### Schema Improvements
|
||||
|
||||
1. **Partitioning Strategy**: For high-volume `llm_requests`, consider:
|
||||
- Monthly partitioning by timestamp
|
||||
- Archive old data to separate tables
|
||||
|
||||
2. **Data Retention Policy**: Implement automatic cleanup of old request logs
|
||||
```sql
|
||||
DELETE FROM llm_requests WHERE timestamp < date('now', '-90 days');
|
||||
```
|
||||
|
||||
3. **Column Optimization**: Remove unused `request_body`, `response_body` columns or implement compression
|
||||
|
||||
### Query Optimizations
|
||||
|
||||
1. **Avoid Functions on Indexed Columns**: Rewrite date queries as range queries
|
||||
2. **Batch Updates**: Consider batch updates for client usage instead of per-request
|
||||
3. **Read Replicas**: For dashboard queries, consider separate read connection
|
||||
|
||||
### Connection Pooling
|
||||
|
||||
**Current:** SQLx connection pool with default settings
|
||||
|
||||
**Recommendations:**
|
||||
- Configure pool size based on expected concurrency
|
||||
- Implement connection health checks
|
||||
- Monitor pool utilization metrics
|
||||
|
||||
### Monitoring Setup
|
||||
|
||||
**Essential Metrics:**
|
||||
- Query execution times (slow query logging)
|
||||
- Index usage statistics
|
||||
- Table growth trends
|
||||
- Connection pool utilization
|
||||
|
||||
**Implementation:**
|
||||
- Add `sqlx::metrics` integration
|
||||
- Regular `ANALYZE` execution for query planner
|
||||
- Dashboard for database health monitoring
|
||||
|
||||
## 7. Security Considerations
|
||||
|
||||
### Data Protection
|
||||
|
||||
**Sensitive Data:**
|
||||
- `provider_configs.api_key` - Should be encrypted at rest
|
||||
- `users.password_hash` - Already hashed with bcrypt
|
||||
- `client_tokens.token` - Plain text storage
|
||||
|
||||
**Recommendations:**
|
||||
- Encrypt API keys using libsodium or similar
|
||||
- Implement token hashing (similar to password hashing)
|
||||
- Regular security audits of authentication flows
|
||||
|
||||
### SQL Injection Prevention
|
||||
|
||||
**Good Practices:**
|
||||
- Use sqlx query builder with parameter binding
|
||||
- No raw SQL concatenation observed in code review
|
||||
|
||||
**Verification Needed:** Ensure all dynamic SQL uses parameterized queries
|
||||
|
||||
### Access Controls
|
||||
|
||||
**Database Level:**
|
||||
- SQLite lacks built-in user management
|
||||
- Consider file system permissions for database file
|
||||
- Application-level authentication is primary control
|
||||
|
||||
## 8. Summary of Critical Issues
|
||||
|
||||
**Priority 1 (Critical):**
|
||||
1. Foreign key constraints not enabled
|
||||
2. Split transactions risking data inconsistency
|
||||
3. Missing composite indexes for common queries
|
||||
|
||||
**Priority 2 (High):**
|
||||
1. No proper migration versioning system
|
||||
2. Potential race conditions in balance updates
|
||||
3. Non-sargable date queries impacting performance
|
||||
|
||||
**Priority 3 (Medium):**
|
||||
1. Denormalized aggregates without consistency guarantees
|
||||
2. No data retention policy for request logs
|
||||
3. Missing check constraints for data validation
|
||||
|
||||
## 9. Recommended Action Plan
|
||||
|
||||
### Phase 1: Immediate Fixes (1-2 weeks)
|
||||
1. Enable foreign key constraints in database connection
|
||||
2. Add composite indexes for common query patterns
|
||||
3. Fix transaction boundaries for client usage updates
|
||||
4. Rewrite non-sargable date queries
|
||||
|
||||
### Phase 2: Short-term Improvements (3-4 weeks)
|
||||
1. Implement proper migration system with version tracking
|
||||
2. Add check constraints for data validation
|
||||
3. Implement connection pooling configuration
|
||||
4. Create database monitoring dashboard
|
||||
|
||||
### Phase 3: Long-term Enhancements (2-3 months)
|
||||
1. Implement data retention and archiving strategy
|
||||
2. Add audit trail for provider balance changes
|
||||
3. Consider partitioning for high-volume tables
|
||||
4. Implement encryption for sensitive data
|
||||
|
||||
### Phase 4: Ongoing Maintenance
|
||||
1. Regular index maintenance and query plan analysis
|
||||
2. Periodic reconciliation of aggregate vs. detail data
|
||||
3. Security audits and dependency updates
|
||||
4. Performance benchmarking and optimization
|
||||
|
||||
---
|
||||
|
||||
## Appendices
|
||||
|
||||
### A. Sample Migration Implementation
|
||||
|
||||
```sql
|
||||
-- migrations/002-enable-foreign-keys.sql
|
||||
PRAGMA foreign_keys = ON;
|
||||
|
||||
-- migrations/003-add-composite-indexes.sql
|
||||
CREATE INDEX idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp);
|
||||
CREATE INDEX idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp);
|
||||
CREATE INDEX idx_model_configs_provider_id ON model_configs(provider_id);
|
||||
```
|
||||
|
||||
### B. Transaction Fix Example
|
||||
|
||||
```rust
|
||||
async fn insert_log(pool: &SqlitePool, log: RequestLog) -> Result<(), sqlx::Error> {
|
||||
let mut tx = pool.begin().await?;
|
||||
|
||||
// Insert or ignore client
|
||||
sqlx::query("INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')")
|
||||
.bind(&log.client_id)
|
||||
.bind(&log.client_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// Insert request log
|
||||
sqlx::query("INSERT INTO llm_requests ...")
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// Update provider balance
|
||||
if log.cost > 0.0 {
|
||||
sqlx::query("UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')")
|
||||
.bind(log.cost)
|
||||
.bind(&log.provider)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Update client aggregates within same transaction
|
||||
sqlx::query("UPDATE clients SET total_requests = total_requests + 1, total_tokens = total_tokens + ?, total_cost = total_cost + ? WHERE client_id = ?")
|
||||
.bind(log.total_tokens as i64)
|
||||
.bind(log.cost)
|
||||
.bind(&log.client_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
tx.commit().await?;
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
### C. Monitoring Query Examples
|
||||
|
||||
```sql
|
||||
-- Identify unused indexes
|
||||
SELECT * FROM sqlite_master
|
||||
WHERE type = 'index'
|
||||
AND name NOT IN (
|
||||
SELECT DISTINCT name
|
||||
FROM sqlite_stat1
|
||||
WHERE tbl = 'llm_requests'
|
||||
);
|
||||
|
||||
-- Table size analysis
|
||||
SELECT name, (pgsize * page_count) / 1024 / 1024 as size_mb
|
||||
FROM dbstat
|
||||
WHERE name = 'llm_requests';
|
||||
|
||||
-- Query performance analysis (requires EXPLAIN QUERY PLAN)
|
||||
EXPLAIN QUERY PLAN
|
||||
SELECT * FROM llm_requests
|
||||
WHERE client_id = ? AND timestamp >= ?;
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
*This report provides a comprehensive analysis of the current database implementation and actionable recommendations for improvement. Regular review and iteration will ensure the database continues to meet performance, consistency, and scalability requirements as the application grows.*
|
||||
99
PLAN.md
Normal file
99
PLAN.md
Normal file
@@ -0,0 +1,99 @@
|
||||
# Project Plan: LLM Proxy Enhancements & Security Upgrade
|
||||
|
||||
This document outlines the roadmap for standardizing frontend security, cleaning up the codebase, upgrading session management to HMAC-signed tokens, and extending integration testing.
|
||||
|
||||
## Phase 1: Frontend Security Standardization
|
||||
**Primary Agent:** `frontend-developer`
|
||||
|
||||
- [x] Audit `static/js/pages/users.js` for manual HTML string concatenation.
|
||||
- [x] Replace custom escaping or unescaped injections with `window.api.escapeHtml`.
|
||||
- [x] Verify user list and user detail rendering for XSS vulnerabilities.
|
||||
|
||||
## Phase 2: Codebase Cleanup
|
||||
**Primary Agent:** `backend-developer`
|
||||
|
||||
- [x] Identify and remove unused imports in `src/config/mod.rs`.
|
||||
- [x] Identify and remove unused imports in `src/providers/mod.rs`.
|
||||
- [x] Run `cargo clippy` and `cargo fmt` to ensure adherence to standards.
|
||||
|
||||
## Phase 3: HMAC Architectural Upgrade
|
||||
**Primary Agents:** `fullstack-developer`, `security-auditor`, `backend-developer`
|
||||
|
||||
### 3.1 Design (Security Auditor)
|
||||
- [x] Define Token Structure: `base64(payload).signature`.
|
||||
- Payload: `{ "session_id": "...", "username": "...", "role": "...", "exp": ... }`
|
||||
- [x] Select HMAC algorithm (HMAC-SHA256).
|
||||
- [x] Define environment variable for secret key: `SESSION_SECRET`.
|
||||
|
||||
### 3.2 Implementation (Backend Developer)
|
||||
- [x] Refactor `src/dashboard/sessions.rs`:
|
||||
- Integrate `hmac` and `sha2` crates (or similar).
|
||||
- Update `create_session` to return signed tokens.
|
||||
- Update `validate_session` to verify signature before checking store.
|
||||
- [x] Implement activity-based session refresh:
|
||||
- If session is valid and >50% through its TTL, extend `expires_at` and issue new signed token.
|
||||
|
||||
### 3.3 Integration (Fullstack Developer)
|
||||
- [x] Update dashboard API handlers to handle new token format.
|
||||
- [x] Update frontend session storage/retrieval if necessary.
|
||||
|
||||
## Phase 4: Extended Integration Testing
|
||||
**Primary Agent:** `qa-automation`
|
||||
|
||||
- [ ] Setup test environment with encrypted key storage enabled.
|
||||
- [ ] Implement end-to-end flow:
|
||||
1. Store encrypted provider key via API.
|
||||
2. Authenticate through Proxy.
|
||||
3. Make proxied LLM request (verifying decryption and usage).
|
||||
- [ ] Validate HMAC token expiration and refresh logic in automated tests.
|
||||
|
||||
## Phase 5: Code Quality & Refactoring
|
||||
**Primary Agent:** `fullstack-developer`
|
||||
|
||||
- [x] Refactor dashboard monolith into modular sub-modules (`auth.rs`, `usage.rs`, etc.).
|
||||
- [x] Standardize error handling and remove `unwrap()` in production paths.
|
||||
- [x] Implement system health metrics and backup functionality.
|
||||
|
||||
---
|
||||
|
||||
# Phase 6: Cache Cost & Provider Audit (ACTIVE)
|
||||
**Primary Agents:** `frontend-developer`, `backend-developer`, `database-optimizer`, `lab-assistant`
|
||||
|
||||
## 6.1 Dashboard UI Updates (@frontend-developer)
|
||||
- [ ] **Update Models Page Modal:** Add input fields for `Cache Read Cost` and `Cache Write Cost` in `static/js/pages/models.js`.
|
||||
- [ ] **API Integration:** Ensure `window.api.put` includes these new cost fields in the request body.
|
||||
- [ ] **Verify Costs Page:** Confirm `static/js/pages/costs.js` displays these rates correctly in the pricing table.
|
||||
|
||||
## 6.2 Provider Audit & Stream Fixes (@backend-developer)
|
||||
- [ ] **Standard DeepSeek Fix:** Modify `src/providers/deepseek.rs` to stop stripping `stream_options` for `deepseek-chat`.
|
||||
- [ ] **Grok Audit:** Verify if Grok correctly returns usage in streaming; it uses `build_openai_body` and doesn't seem to strip it.
|
||||
- [ ] **Gemini Audit:** Confirm Gemini returns `usage_metadata` reliably in the final chunk.
|
||||
- [ ] **Anthropic Audit:** Check if Anthropic streaming requires `include_usage` or similar flags.
|
||||
|
||||
## 6.3 Database & Migration Validation (@database-optimizer)
|
||||
- [ ] **Test Migrations:** Run the server to ensure `ALTER TABLE` logic in `src/database/mod.rs` applies the new columns correctly.
|
||||
- [ ] **Schema Verification:** Verify `model_configs` has `cache_read_cost_per_m` and `cache_write_cost_per_m` columns.
|
||||
|
||||
## 6.4 Token Estimation Refinement (@lab-assistant)
|
||||
- [ ] **Analyze Heuristic:** Review `chars / 4` in `src/utils/tokens.rs`.
|
||||
- [ ] **Background Precise Recount:** Propose a mechanism for a precise token count (using Tiktoken) after the response is finalized.
|
||||
|
||||
## Critical Path
|
||||
Migration Validation → UI Fields → Provider Stream Usage Reporting.
|
||||
|
||||
```mermaid
|
||||
gantt
|
||||
title Phase 6 Timeline
|
||||
dateFormat YYYY-MM-DD
|
||||
section Frontend
|
||||
Models Page UI :2026-03-06, 1d
|
||||
Costs Table Update:after Models Page UI, 1d
|
||||
section Backend
|
||||
DeepSeek Fix :2026-03-06, 1d
|
||||
Provider Audit (Grok/Gemini):after DeepSeek Fix, 2d
|
||||
section Database
|
||||
Migration Test :2026-03-06, 1d
|
||||
section Optimization
|
||||
Token Heuristic Review :2026-03-06, 1d
|
||||
```
|
||||
|
||||
108
README.md
108
README.md
@@ -26,6 +26,17 @@ A unified, high-performance LLM proxy gateway built in Rust. It provides a singl
|
||||
- **Rate Limiting:** Per-client and global rate limits.
|
||||
- **Cache-Aware Costing:** Tracks cache hit/miss tokens for accurate billing.
|
||||
|
||||
## Security
|
||||
|
||||
LLM Proxy is designed with security in mind:
|
||||
|
||||
- **HMAC Session Tokens:** Management dashboard sessions are secured using HMAC-SHA256 signed tokens.
|
||||
- **Encrypted Provider Keys:** Sensitive LLM provider API keys are stored encrypted (AES-256-GCM) in the database.
|
||||
- **Session Refresh:** Activity-based session extension prevents session hijacking while maintaining user convenience.
|
||||
- **XSS Prevention:** Standardized frontend escaping using `window.api.escapeHtml`.
|
||||
|
||||
**Note:** You must define a `SESSION_SECRET` in your `.env` file for secure session signing.
|
||||
|
||||
## Tech Stack
|
||||
|
||||
- **Runtime:** Rust with Tokio.
|
||||
@@ -39,6 +50,7 @@ A unified, high-performance LLM proxy gateway built in Rust. It provides a singl
|
||||
|
||||
- Rust (1.80+)
|
||||
- SQLite3
|
||||
- Docker (optional, for containerized deployment)
|
||||
|
||||
### Quick Start
|
||||
|
||||
@@ -53,10 +65,9 @@ A unified, high-performance LLM proxy gateway built in Rust. It provides a singl
|
||||
```bash
|
||||
cp .env.example .env
|
||||
# Edit .env and add your API keys:
|
||||
# SESSION_SECRET=... (Generate a strong random secret)
|
||||
# OPENAI_API_KEY=sk-...
|
||||
# GEMINI_API_KEY=AIza...
|
||||
# DEEPSEEK_API_KEY=sk-...
|
||||
# GROK_API_KEY=gk-... (optional)
|
||||
```
|
||||
|
||||
3. Run the proxy:
|
||||
@@ -66,50 +77,31 @@ A unified, high-performance LLM proxy gateway built in Rust. It provides a singl
|
||||
|
||||
The server starts on `http://localhost:8080` by default.
|
||||
|
||||
### Configuration
|
||||
### Deployment (Docker)
|
||||
|
||||
Edit `config.toml` to customize providers, models, and settings:
|
||||
A multi-stage `Dockerfile` is provided for efficient deployment:
|
||||
|
||||
```toml
|
||||
[server]
|
||||
port = 8080
|
||||
host = "0.0.0.0"
|
||||
```bash
|
||||
# Build the container
|
||||
docker build -t llm-proxy .
|
||||
|
||||
[database]
|
||||
path = "./data/llm_proxy.db"
|
||||
|
||||
[providers.openai]
|
||||
enabled = true
|
||||
default_model = "gpt-4o"
|
||||
|
||||
[providers.gemini]
|
||||
enabled = true
|
||||
default_model = "gemini-2.0-flash"
|
||||
|
||||
[providers.deepseek]
|
||||
enabled = true
|
||||
default_model = "deepseek-reasoner"
|
||||
|
||||
[providers.grok]
|
||||
enabled = false
|
||||
default_model = "grok-beta"
|
||||
|
||||
[providers.ollama]
|
||||
enabled = false
|
||||
base_url = "http://localhost:11434/v1"
|
||||
# Run the container
|
||||
docker run -p 8080:8080 \
|
||||
-e SESSION_SECRET=your-secure-secret \
|
||||
-v ./data:/app/data \
|
||||
llm-proxy
|
||||
```
|
||||
|
||||
## Management Dashboard
|
||||
|
||||
Access the dashboard at `http://localhost:8080`:
|
||||
Access the dashboard at `http://localhost:8080`. The dashboard architecture has been refactored into modular sub-components for better maintainability:
|
||||
|
||||
- **Overview:** Real-time request counters, system health, provider status.
|
||||
- **Analytics:** Time-series charts, filterable by date, client, provider, and model.
|
||||
- **Costs:** Budget tracking, cost breakdown by provider/client/model, projections.
|
||||
- **Clients:** Create, revoke, and rotate API tokens; per-client usage stats.
|
||||
- **Providers:** Enable/disable providers, test connections, configure API keys.
|
||||
- **Monitoring:** Live request stream via WebSocket, response times, error rates.
|
||||
- **Users:** Admin/user management with role-based access control.
|
||||
- **Auth (`/api/auth`):** Login, session management, and password changes.
|
||||
- **Usage (`/api/usage`):** Summary stats, time-series analytics, and provider breakdown.
|
||||
- **Clients (`/api/clients`):** API key management and per-client usage tracking.
|
||||
- **Providers (`/api/providers`):** Provider configuration, status monitoring, and connection testing.
|
||||
- **System (`/api/system`):** Health metrics, live logs, database backups, and global settings.
|
||||
- **Monitoring:** Live request stream via WebSocket.
|
||||
|
||||
### Default Credentials
|
||||
|
||||
@@ -137,46 +129,6 @@ response = client.chat.completions.create(
|
||||
)
|
||||
```
|
||||
|
||||
### Open WebUI
|
||||
```
|
||||
API Base URL: http://your-server:8080/v1
|
||||
API Key: YOUR_CLIENT_API_KEY
|
||||
```
|
||||
|
||||
### cURL
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer YOUR_CLIENT_API_KEY" \
|
||||
-d '{
|
||||
"model": "gpt-4o",
|
||||
"messages": [{"role": "user", "content": "Hello!"}],
|
||||
"stream": false
|
||||
}'
|
||||
```
|
||||
|
||||
## Model Discovery
|
||||
|
||||
The proxy exposes `/v1/models` for OpenAI-compatible client model discovery:
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/v1/models \
|
||||
-H "Authorization: Bearer YOUR_CLIENT_API_KEY"
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Streaming Issues
|
||||
If clients timeout or show "TransferEncodingError", ensure:
|
||||
1. Proxy buffering is disabled in nginx: `proxy_buffering off;`
|
||||
2. Chunked transfer is enabled: `chunked_transfer_encoding on;`
|
||||
3. Timeouts are sufficient: `proxy_read_timeout 7200s;`
|
||||
|
||||
### Provider Errors
|
||||
- Check API keys are set in `.env`
|
||||
- Test provider in dashboard (Settings → Providers → Test)
|
||||
- Review logs: `journalctl -u llm-proxy -f`
|
||||
|
||||
## License
|
||||
|
||||
MIT OR Apache-2.0
|
||||
|
||||
58
RUST_BACKEND_REVIEW.md
Normal file
58
RUST_BACKEND_REVIEW.md
Normal file
@@ -0,0 +1,58 @@
|
||||
# LLM Proxy Rust Backend Code Review Report
|
||||
|
||||
## Executive Summary
|
||||
This code review examines the `llm-proxy` Rust backend, focusing on architectural soundness, performance characteristics, and adherence to Rust idioms. The codebase demonstrates solid engineering with well-structured modular design, comprehensive error handling, and thoughtful async patterns. However, several areas require attention for production readiness, particularly around thread safety, memory efficiency, and error recovery.
|
||||
|
||||
## 1. Core Proxy Logic Review
|
||||
### Strengths
|
||||
- Clean provider abstraction (`Provider` trait).
|
||||
- Streaming support with `AggregatingStream` for token counting.
|
||||
- Model mapping and caching system.
|
||||
|
||||
### Issues Found
|
||||
- **Provider Manager Thread Safety Risk:** O(n) lookups using `Vec` with `RwLock`. Use `DashMap` instead.
|
||||
- **Streaming Memory Inefficiency:** Accumulates complete response content in memory.
|
||||
- **Model Registry Cache Invalidation:** No strategy when config changes via dashboard.
|
||||
|
||||
## 2. State Management Review
|
||||
- **Token Bucket Algorithm Flaw:** Custom implementation lacks thread-safe refill sync. Use `governor` crate.
|
||||
- **Broadcast Channel Unbounded Growth Risk:** Fixed-size (100) may drop messages.
|
||||
- **Database Connection Pool Contention:** SQLite connections shared without differentiation.
|
||||
|
||||
## 3. Error Handling Review
|
||||
- **Error Recovery Missing:** No circuit breaker or retry logic for provider calls.
|
||||
- **Stream Error Logging Gap:** Stream errors swallowed without logging partial usage.
|
||||
|
||||
## 4. Rust Idioms and Performance
|
||||
- **Unnecessary String Cloning:** Frequent cloning in authentication hot paths.
|
||||
- **JSON Parsing Inefficiency:** Multiple passes with `serde_json::Value`. Use typed structs.
|
||||
- **Missing `#[derive(Copy)]`:** For small enums like `CircuitState`.
|
||||
|
||||
## 5. Async Performance
|
||||
- **Blocking Calls:** Token estimation may block async runtime.
|
||||
- **Missing Connection Timeouts:** Only overall timeout, no separate read/write timeouts.
|
||||
- **Unbounded Task Spawn:** For client usage updates under load.
|
||||
|
||||
## 6. Security Considerations
|
||||
- **Token Leakage:** Redact tokens in `Debug` and `Display` impls.
|
||||
- **No Request Size Limits:** Vulnerable to memory exhaustion.
|
||||
|
||||
## 7. Testability
|
||||
- **Mocking Difficulty:** Tight coupling to concrete provider implementations.
|
||||
- **Missing Integration Tests:** No E2E tests for streaming.
|
||||
|
||||
## 8. Summary of Critical Actions
|
||||
### High Priority
|
||||
1. Replace custom token bucket with `governor` crate.
|
||||
2. Fix provider lookup O(n) scaling issue.
|
||||
3. Implement proper error recovery with retries.
|
||||
4. Add request size limits and timeout configurations.
|
||||
|
||||
### Medium Priority
|
||||
1. Reduce string cloning in hot paths.
|
||||
2. Implement cache invalidation for model configs.
|
||||
3. Add connection pooling separation.
|
||||
4. Improve streaming memory efficiency.
|
||||
|
||||
---
|
||||
*Review conducted by: Senior Principal Engineer (Code Reviewer)*
|
||||
58
SECURITY_AUDIT.md
Normal file
58
SECURITY_AUDIT.md
Normal file
@@ -0,0 +1,58 @@
|
||||
# LLM Proxy Security Audit Report
|
||||
|
||||
## Executive Summary
|
||||
A comprehensive security audit of the `llm-proxy` repository was conducted. The audit identified **1 critical vulnerability**, **3 high-risk issues**, **4 medium-risk issues**, and **3 low-risk issues**. The most severe findings include Cross-Site Scripting (XSS) in the dashboard interface and insecure storage of provider API keys in the database.
|
||||
|
||||
## Detailed Findings
|
||||
|
||||
### Critical Risk Vulnerabilities
|
||||
#### **CRITICAL-01: Cross-Site Scripting (XSS) in Dashboard Interface**
|
||||
- **Location**: `static/js/pages/clients.js` (multiple locations).
|
||||
- **Description**: User-controlled data (e.g., `client.id`) inserted directly into HTML or `onclick` handlers without escaping.
|
||||
- **Impact**: Arbitrary JavaScript execution in admin context, potentially stealing session tokens.
|
||||
|
||||
#### **CRITICAL-02: Insecure API Key Storage in Database**
|
||||
- **Location**: `src/database/mod.rs`, `src/providers/mod.rs`, `src/dashboard/providers.rs`.
|
||||
- **Description**: Provider API keys are stored in **plaintext** in the SQLite database.
|
||||
- **Impact**: Compromised database file exposes all provider API keys.
|
||||
|
||||
### High Risk Vulnerabilities
|
||||
#### **HIGH-01: Missing Input Validation and Size Limits**
|
||||
- **Location**: `src/server/mod.rs`, `src/models/mod.rs`.
|
||||
- **Impact**: Denial of Service via large payloads.
|
||||
|
||||
#### **HIGH-02: Sensitive Data Logging Without Encryption**
|
||||
- **Location**: `src/database/mod.rs`, `src/logging/mod.rs`.
|
||||
- **Description**: Full request and response bodies stored in `llm_requests` table without encryption or redaction.
|
||||
|
||||
#### **HIGH-03: Weak Default Credentials and Password Policy**
|
||||
- **Description**: Default admin password is 'admin' with only 4-character minimum password length.
|
||||
|
||||
### Medium Risk Vulnerabilities
|
||||
#### **MEDIUM-01: Missing CSRF Protection**
|
||||
- No CSRF tokens or SameSite cookie attributes for state-changing dashboard endpoints.
|
||||
|
||||
#### **MEDIUM-02: Insecure Session Management**
|
||||
- Session tokens stored in localStorage without HttpOnly flag.
|
||||
- Tokens use simple `session-{uuid}` format.
|
||||
|
||||
#### **MEDIUM-03: Error Information Leakage**
|
||||
- Internal error details exposed to clients in some cases.
|
||||
|
||||
#### **MEDIUM-04: Outdated Dependencies**
|
||||
- Outdated versions of `chrono`, `tokio`, and `reqwest`.
|
||||
|
||||
### Low Risk Vulnerabilities
|
||||
- Missing security headers (CSP, HSTS, X-Frame-Options).
|
||||
- Insufficient rate limiting on dashboard authentication.
|
||||
- No database encryption at rest.
|
||||
|
||||
## Recommendations
|
||||
### Immediate Actions
|
||||
1. **Fix XSS Vulnerabilities:** Implement proper HTML escaping for all user-controlled data.
|
||||
2. **Secure API Key Storage:** Encrypt API keys in database using a library like `ring`.
|
||||
3. **Implement Input Validation:** Add maximum payload size limits (e.g., 10MB).
|
||||
4. **Improve Data Protection:** Add option to disable request/response body logging.
|
||||
|
||||
---
|
||||
*Report generated by Security Auditor Agent on March 6, 2026*
|
||||
BIN
data/backups/llm_proxy.db.20260303T204435Z
Normal file
BIN
data/backups/llm_proxy.db.20260303T204435Z
Normal file
Binary file not shown.
BIN
data/backups/llm_proxy.db.20260303T205057Z
Normal file
BIN
data/backups/llm_proxy.db.20260303T205057Z
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
13
migrations/002-add-indexes.sql
Normal file
13
migrations/002-add-indexes.sql
Normal file
@@ -0,0 +1,13 @@
|
||||
-- Migration: add composite indexes for query performance
|
||||
-- Adds three composite indexes:
|
||||
-- 1. idx_llm_requests_client_timestamp on llm_requests(client_id, timestamp)
|
||||
-- 2. idx_llm_requests_provider_timestamp on llm_requests(provider, timestamp)
|
||||
-- 3. idx_model_configs_provider_id on model_configs(provider_id)
|
||||
|
||||
BEGIN TRANSACTION;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp);
|
||||
CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp);
|
||||
CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id);
|
||||
|
||||
COMMIT;
|
||||
11
server.log
Normal file
11
server.log
Normal file
@@ -0,0 +1,11 @@
|
||||
[2m2026-03-06T20:07:36.737914Z[0m [32m INFO[0m Starting LLM Proxy Gateway v0.1.0
|
||||
[2m2026-03-06T20:07:36.738903Z[0m [32m INFO[0m Configuration loaded from Some("/home/newkirk/Documents/projects/web_projects/llm-proxy/config.toml")
|
||||
[2m2026-03-06T20:07:36.738945Z[0m [32m INFO[0m Encryption initialized
|
||||
[2m2026-03-06T20:07:36.739124Z[0m [32m INFO[0m Connecting to database at ./data/llm_proxy.db
|
||||
[2m2026-03-06T20:07:36.753254Z[0m [32m INFO[0m Database migrations completed
|
||||
[2m2026-03-06T20:07:36.753294Z[0m [32m INFO[0m Database initialized at "./data/llm_proxy.db"
|
||||
[2m2026-03-06T20:07:36.755187Z[0m [32m INFO[0m Fetching model registry from https://models.dev/api.json
|
||||
[2m2026-03-06T20:07:37.000853Z[0m [32m INFO[0m Successfully loaded model registry
|
||||
[2m2026-03-06T20:07:37.001382Z[0m [32m INFO[0m Model config cache initialized
|
||||
[2m2026-03-06T20:07:37.001702Z[0m [33m WARN[0m SESSION_SECRET environment variable not set. Using a randomly generated secret. This will invalidate all sessions on restart. Set SESSION_SECRET to a fixed hex or base64 encoded 32-byte value.
|
||||
[2m2026-03-06T20:07:37.002898Z[0m [32m INFO[0m Server listening on http://0.0.0.0:8082
|
||||
1
server.pid
Normal file
1
server.pid
Normal file
@@ -0,0 +1 @@
|
||||
945904
|
||||
@@ -1,33 +1,40 @@
|
||||
use axum::{extract::FromRequestParts, http::request::Parts};
|
||||
use axum_extra::TypedHeader;
|
||||
use axum_extra::headers::Authorization;
|
||||
use headers::authorization::Bearer;
|
||||
|
||||
use crate::errors::AppError;
|
||||
|
||||
pub struct AuthenticatedClient {
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AuthInfo {
|
||||
pub token: String,
|
||||
pub client_id: String,
|
||||
}
|
||||
|
||||
pub struct AuthenticatedClient {
|
||||
pub info: AuthInfo,
|
||||
}
|
||||
|
||||
impl<S> FromRequestParts<S> for AuthenticatedClient
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = AppError;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
// Extract bearer token from Authorization header
|
||||
let TypedHeader(Authorization(bearer)) = TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
|
||||
.await
|
||||
.map_err(|_| AppError::AuthError("Missing or invalid bearer token".to_string()))?;
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
// Retrieve AuthInfo from request extensions, where it was placed by rate_limit_middleware
|
||||
let info = parts
|
||||
.extensions
|
||||
.get::<AuthInfo>()
|
||||
.cloned()
|
||||
.ok_or_else(|| AppError::AuthError("Authentication info not found in request".to_string()))?;
|
||||
|
||||
let token = bearer.token().to_string();
|
||||
Ok(AuthenticatedClient { info })
|
||||
}
|
||||
}
|
||||
|
||||
// Derive client_id from the token prefix
|
||||
let client_id = format!("client_{}", &token[..8.min(token.len())]);
|
||||
impl std::ops::Deref for AuthenticatedClient {
|
||||
type Target = AuthInfo;
|
||||
|
||||
Ok(AuthenticatedClient { token, client_id })
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.info
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use anyhow::Result;
|
||||
use base64::{Engine as _};
|
||||
use config::{Config, File, FileFormat};
|
||||
use hex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
@@ -96,6 +98,7 @@ pub struct AppConfig {
|
||||
pub model_mapping: ModelMappingConfig,
|
||||
pub pricing: PricingConfig,
|
||||
pub config_path: Option<PathBuf>,
|
||||
pub encryption_key: String,
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
@@ -136,7 +139,8 @@ impl AppConfig {
|
||||
.set_default("providers.grok.enabled", true)?
|
||||
.set_default("providers.ollama.base_url", "http://localhost:11434/v1")?
|
||||
.set_default("providers.ollama.enabled", false)?
|
||||
.set_default("providers.ollama.models", Vec::<String>::new())?;
|
||||
.set_default("providers.ollama.models", Vec::<String>::new())?
|
||||
.set_default("encryption_key", "")?;
|
||||
|
||||
// Load from config file if exists
|
||||
// Priority: explicit path arg > LLM_PROXY__CONFIG_PATH env var > ./config.toml
|
||||
@@ -167,6 +171,19 @@ impl AppConfig {
|
||||
let server: ServerConfig = config.get("server")?;
|
||||
let database: DatabaseConfig = config.get("database")?;
|
||||
let providers: ProviderConfig = config.get("providers")?;
|
||||
let encryption_key: String = config.get("encryption_key")?;
|
||||
|
||||
// Validate encryption key length (must be 32 bytes after hex or base64 decoding)
|
||||
if encryption_key.is_empty() {
|
||||
anyhow::bail!("Encryption key is required (LLM_PROXY__ENCRYPTION_KEY environment variable)");
|
||||
}
|
||||
// Try hex decode first, then base64
|
||||
let key_bytes = hex::decode(&encryption_key)
|
||||
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(&encryption_key))
|
||||
.map_err(|e| anyhow::anyhow!("Encryption key must be hex or base64 encoded: {}", e))?;
|
||||
if key_bytes.len() != 32 {
|
||||
anyhow::bail!("Encryption key must be 32 bytes (256 bits), got {} bytes", key_bytes.len());
|
||||
}
|
||||
|
||||
// For now, use empty model mapping and pricing (will be populated later)
|
||||
let model_mapping = ModelMappingConfig { patterns: vec![] };
|
||||
@@ -185,6 +202,7 @@ impl AppConfig {
|
||||
model_mapping,
|
||||
pricing,
|
||||
config_path: Some(config_path),
|
||||
encryption_key,
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use axum::{extract::State, response::Json};
|
||||
use axum::{extract::State, http::{HeaderMap, HeaderValue}, response::{Json, IntoResponse}};
|
||||
use bcrypt;
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
@@ -64,14 +64,14 @@ pub(super) async fn handle_login(
|
||||
pub(super) async fn handle_auth_status(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
) -> impl IntoResponse {
|
||||
let token = headers
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
if let Some(token) = token
|
||||
&& let Some(session) = state.session_manager.validate_session(token).await
|
||||
&& let Some((session, new_token)) = state.session_manager.validate_session_with_refresh(token).await
|
||||
{
|
||||
// Look up display_name from DB
|
||||
let display_name = sqlx::query_scalar::<_, Option<String>>(
|
||||
@@ -85,17 +85,23 @@ pub(super) async fn handle_auth_status(
|
||||
.flatten()
|
||||
.unwrap_or_else(|| session.username.clone());
|
||||
|
||||
return Json(ApiResponse::success(serde_json::json!({
|
||||
let mut headers = HeaderMap::new();
|
||||
if let Some(refreshed_token) = new_token {
|
||||
if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) {
|
||||
headers.insert("X-Refreshed-Token", header_value);
|
||||
}
|
||||
}
|
||||
return (headers, Json(ApiResponse::success(serde_json::json!({
|
||||
"authenticated": true,
|
||||
"user": {
|
||||
"username": session.username,
|
||||
"name": display_name,
|
||||
"role": session.role
|
||||
}
|
||||
})));
|
||||
}))));
|
||||
}
|
||||
|
||||
Json(ApiResponse::error("Not authenticated".to_string()))
|
||||
(HeaderMap::new(), Json(ApiResponse::error("Not authenticated".to_string())))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -108,7 +114,7 @@ pub(super) async fn handle_change_password(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(payload): Json<ChangePasswordRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
) -> impl IntoResponse {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Extract the authenticated user from the session token
|
||||
@@ -117,14 +123,24 @@ pub(super) async fn handle_change_password(
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
let session = match token {
|
||||
Some(t) => state.session_manager.validate_session(t).await,
|
||||
None => None,
|
||||
let (session, new_token) = match token {
|
||||
Some(t) => match state.session_manager.validate_session_with_refresh(t).await {
|
||||
Some((session, new_token)) => (Some(session), new_token),
|
||||
None => (None, None),
|
||||
},
|
||||
None => (None, None),
|
||||
};
|
||||
|
||||
let mut response_headers = HeaderMap::new();
|
||||
if let Some(refreshed_token) = new_token {
|
||||
if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) {
|
||||
response_headers.insert("X-Refreshed-Token", header_value);
|
||||
}
|
||||
}
|
||||
|
||||
let username = match session {
|
||||
Some(s) => s.username,
|
||||
None => return Json(ApiResponse::error("Not authenticated".to_string())),
|
||||
None => return (response_headers, Json(ApiResponse::error("Not authenticated".to_string()))),
|
||||
};
|
||||
|
||||
let user_result = sqlx::query("SELECT password_hash FROM users WHERE username = ?")
|
||||
@@ -138,7 +154,7 @@ pub(super) async fn handle_change_password(
|
||||
if bcrypt::verify(&payload.current_password, &hash).unwrap_or(false) {
|
||||
let new_hash = match bcrypt::hash(&payload.new_password, 12) {
|
||||
Ok(h) => h,
|
||||
Err(_) => return Json(ApiResponse::error("Failed to hash new password".to_string())),
|
||||
Err(_) => return (response_headers, Json(ApiResponse::error("Failed to hash new password".to_string()))),
|
||||
};
|
||||
|
||||
let update_result = sqlx::query(
|
||||
@@ -150,16 +166,16 @@ pub(super) async fn handle_change_password(
|
||||
.await;
|
||||
|
||||
match update_result {
|
||||
Ok(_) => Json(ApiResponse::success(
|
||||
Ok(_) => (response_headers, Json(ApiResponse::success(
|
||||
serde_json::json!({ "message": "Password updated successfully" }),
|
||||
)),
|
||||
Err(e) => Json(ApiResponse::error(format!("Failed to update database: {}", e))),
|
||||
))),
|
||||
Err(e) => (response_headers, Json(ApiResponse::error(format!("Failed to update database: {}", e)))),
|
||||
}
|
||||
} else {
|
||||
Json(ApiResponse::error("Current password incorrect".to_string()))
|
||||
(response_headers, Json(ApiResponse::error("Current password incorrect".to_string())))
|
||||
}
|
||||
}
|
||||
Err(e) => Json(ApiResponse::error(format!("User not found: {}", e))),
|
||||
Err(e) => (response_headers, Json(ApiResponse::error(format!("User not found: {}", e)))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,19 +196,19 @@ pub(super) async fn handle_logout(
|
||||
}
|
||||
|
||||
/// Helper: Extract and validate a session from the Authorization header.
|
||||
/// Returns the Session if valid, or an error response.
|
||||
/// Returns the Session and optional new token if refreshed, or an error response.
|
||||
pub(super) async fn extract_session(
|
||||
state: &DashboardState,
|
||||
headers: &axum::http::HeaderMap,
|
||||
) -> Result<super::sessions::Session, Json<ApiResponse<serde_json::Value>>> {
|
||||
) -> Result<(super::sessions::Session, Option<String>), Json<ApiResponse<serde_json::Value>>> {
|
||||
let token = headers
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
match token {
|
||||
Some(t) => match state.session_manager.validate_session(t).await {
|
||||
Some(session) => Ok(session),
|
||||
Some(t) => match state.session_manager.validate_session_with_refresh(t).await {
|
||||
Some((session, new_token)) => Ok((session, new_token)),
|
||||
None => Err(Json(ApiResponse::error("Session expired or invalid".to_string()))),
|
||||
},
|
||||
None => Err(Json(ApiResponse::error("Not authenticated".to_string()))),
|
||||
@@ -200,13 +216,14 @@ pub(super) async fn extract_session(
|
||||
}
|
||||
|
||||
/// Helper: Extract session and require admin role.
|
||||
/// Returns session and optional new token if refreshed.
|
||||
pub(super) async fn require_admin(
|
||||
state: &DashboardState,
|
||||
headers: &axum::http::HeaderMap,
|
||||
) -> Result<super::sessions::Session, Json<ApiResponse<serde_json::Value>>> {
|
||||
let session = extract_session(state, headers).await?;
|
||||
) -> Result<(super::sessions::Session, Option<String>), Json<ApiResponse<serde_json::Value>>> {
|
||||
let (session, new_token) = extract_session(state, headers).await?;
|
||||
if session.role != "admin" {
|
||||
return Err(Json(ApiResponse::error("Admin access required".to_string())));
|
||||
}
|
||||
Ok(session)
|
||||
Ok((session, new_token))
|
||||
}
|
||||
|
||||
@@ -33,7 +33,15 @@ pub(super) struct UpdateClientPayload {
|
||||
pub(super) rate_limit_per_minute: Option<i64>,
|
||||
}
|
||||
|
||||
pub(super) async fn handle_get_clients(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
pub(super) async fn handle_get_clients(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let result = sqlx::query(
|
||||
@@ -88,9 +96,10 @@ pub(super) async fn handle_create_client(
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(payload): Json<CreateClientRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
@@ -198,9 +207,10 @@ pub(super) async fn handle_update_client(
|
||||
Path(id): Path<String>,
|
||||
Json(payload): Json<UpdateClientPayload>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
@@ -294,9 +304,10 @@ pub(super) async fn handle_delete_client(
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
@@ -318,8 +329,14 @@ pub(super) async fn handle_delete_client(
|
||||
|
||||
pub(super) async fn handle_client_usage(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Get per-model breakdown for this client
|
||||
@@ -378,8 +395,14 @@ pub(super) async fn handle_client_usage(
|
||||
|
||||
pub(super) async fn handle_get_client_tokens(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let result = sqlx::query(
|
||||
@@ -437,9 +460,10 @@ pub(super) async fn handle_create_client_token(
|
||||
Path(id): Path<String>,
|
||||
Json(payload): Json<CreateTokenRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
@@ -485,9 +509,10 @@ pub(super) async fn handle_delete_client_token(
|
||||
headers: axum::http::HeaderMap,
|
||||
Path((client_id, token_id)): Path<(String, i64)>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
|
||||
@@ -11,10 +11,18 @@ mod users;
|
||||
mod websocket;
|
||||
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
Router,
|
||||
routing::{delete, get, post, put},
|
||||
};
|
||||
use axum::http::{header, HeaderValue};
|
||||
use serde::Serialize;
|
||||
use tower_http::{
|
||||
limit::RequestBodyLimitLayer,
|
||||
set_header::SetResponseHeaderLayer,
|
||||
};
|
||||
|
||||
use crate::state::AppState;
|
||||
use sessions::SessionManager;
|
||||
@@ -52,6 +60,18 @@ impl<T> ApiResponse<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Rate limiting middleware for dashboard routes
|
||||
async fn dashboard_rate_limit_middleware(
|
||||
State(_dashboard_state): State<DashboardState>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, crate::errors::AppError> {
|
||||
// Bypass rate limiting for dashboard routes to prevent "Failed to load statistics"
|
||||
// when the UI makes many concurrent requests on load.
|
||||
// Dashboard endpoints are already secured via auth::require_admin.
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
|
||||
// Dashboard routes
|
||||
pub fn router(state: AppState) -> Router {
|
||||
let session_manager = SessionManager::new(24); // 24-hour session TTL
|
||||
@@ -60,6 +80,26 @@ pub fn router(state: AppState) -> Router {
|
||||
session_manager,
|
||||
};
|
||||
|
||||
// Security headers
|
||||
let csp_header: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
|
||||
header::CONTENT_SECURITY_POLICY,
|
||||
"default-src 'self'; script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; style-src 'self' 'unsafe-inline' https://cdnjs.cloudflare.com https://fonts.googleapis.com; font-src 'self' https://cdnjs.cloudflare.com https://fonts.gstatic.com; img-src 'self' data:; connect-src 'self' ws:;"
|
||||
.parse()
|
||||
.unwrap(),
|
||||
);
|
||||
let x_frame_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
|
||||
header::X_FRAME_OPTIONS,
|
||||
"DENY".parse().unwrap(),
|
||||
);
|
||||
let x_content_type_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
|
||||
header::X_CONTENT_TYPE_OPTIONS,
|
||||
"nosniff".parse().unwrap(),
|
||||
);
|
||||
let strict_transport_security: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
|
||||
header::STRICT_TRANSPORT_SECURITY,
|
||||
"max-age=31536000; includeSubDomains".parse().unwrap(),
|
||||
);
|
||||
|
||||
Router::new()
|
||||
// Static file serving
|
||||
.fallback_service(tower_http::services::ServeDir::new("static"))
|
||||
@@ -119,5 +159,16 @@ pub fn router(state: AppState) -> Router {
|
||||
"/api/system/settings",
|
||||
get(system::handle_get_settings).post(system::handle_update_settings),
|
||||
)
|
||||
// Security layers
|
||||
.layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit
|
||||
.layer(csp_header)
|
||||
.layer(x_frame_options)
|
||||
.layer(x_content_type_options)
|
||||
.layer(strict_transport_security)
|
||||
// Rate limiting middleware
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
dashboard_state.clone(),
|
||||
dashboard_rate_limit_middleware,
|
||||
))
|
||||
.with_state(dashboard_state)
|
||||
}
|
||||
|
||||
@@ -43,42 +43,17 @@ pub(super) struct ModelListParams {
|
||||
|
||||
pub(super) async fn handle_get_models(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(params): Query<ModelListParams>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let registry = &state.app_state.model_registry;
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// If used_only, fetch the set of models that appear in llm_requests
|
||||
let used_models: Option<std::collections::HashSet<String>> =
|
||||
if params.used_only.unwrap_or(false) {
|
||||
match sqlx::query_scalar::<_, String>(
|
||||
"SELECT DISTINCT model FROM llm_requests",
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
{
|
||||
Ok(models) => Some(models.into_iter().collect()),
|
||||
Err(_) => Some(std::collections::HashSet::new()),
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Build filter from query params
|
||||
let filter = ModelFilter {
|
||||
provider: params.provider,
|
||||
search: params.search,
|
||||
modality: params.modality,
|
||||
tool_call: params.tool_call,
|
||||
reasoning: params.reasoning,
|
||||
has_cost: params.has_cost,
|
||||
};
|
||||
let sort_by = params.sort_by.unwrap_or_default();
|
||||
let sort_order = params.sort_order.unwrap_or_default();
|
||||
|
||||
// Get filtered and sorted model entries
|
||||
let entries = registry.list_models(&filter, &sort_by, &sort_order);
|
||||
|
||||
// Load overrides from database
|
||||
let db_models_result =
|
||||
sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
|
||||
@@ -95,16 +70,89 @@ pub(super) async fn handle_get_models(
|
||||
|
||||
let mut models_json = Vec::new();
|
||||
|
||||
if params.used_only.unwrap_or(false) {
|
||||
// EXACT USED MODELS LOGIC
|
||||
let used_pairs_result = sqlx::query(
|
||||
"SELECT DISTINCT provider, model FROM llm_requests",
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
if let Ok(rows) = used_pairs_result {
|
||||
for row in rows {
|
||||
let provider: String = row.get("provider");
|
||||
let m_key: String = row.get("model");
|
||||
|
||||
let provider_name = match provider.as_str() {
|
||||
"openai" => "OpenAI",
|
||||
"gemini" => "Google Gemini",
|
||||
"deepseek" => "DeepSeek",
|
||||
"grok" => "xAI Grok",
|
||||
"ollama" => "Ollama",
|
||||
_ => provider.as_str(),
|
||||
}.to_string();
|
||||
|
||||
let m_meta = registry.find_model(&m_key);
|
||||
|
||||
let mut enabled = true;
|
||||
let mut prompt_cost = m_meta.and_then(|m| m.cost.as_ref().map(|c| c.input)).unwrap_or(0.0);
|
||||
let mut completion_cost = m_meta.and_then(|m| m.cost.as_ref().map(|c| c.output)).unwrap_or(0.0);
|
||||
let cache_read_cost = m_meta.and_then(|m| m.cost.as_ref().and_then(|c| c.cache_read));
|
||||
let cache_write_cost = m_meta.and_then(|m| m.cost.as_ref().and_then(|c| c.cache_write));
|
||||
let mut mapping = None::<String>;
|
||||
|
||||
if let Some(db_row) = db_models.get(&m_key) {
|
||||
enabled = db_row.get("enabled");
|
||||
if let Some(p) = db_row.get::<Option<f64>, _>("prompt_cost_per_m") {
|
||||
prompt_cost = p;
|
||||
}
|
||||
if let Some(c) = db_row.get::<Option<f64>, _>("completion_cost_per_m") {
|
||||
completion_cost = c;
|
||||
}
|
||||
mapping = db_row.get("mapping");
|
||||
}
|
||||
|
||||
models_json.push(serde_json::json!({
|
||||
"id": m_key,
|
||||
"provider": provider,
|
||||
"provider_name": provider_name,
|
||||
"name": m_meta.map(|m| m.name.clone()).unwrap_or_else(|| m_key.clone()),
|
||||
"enabled": enabled,
|
||||
"prompt_cost": prompt_cost,
|
||||
"completion_cost": completion_cost,
|
||||
"cache_read_cost": cache_read_cost,
|
||||
"cache_write_cost": cache_write_cost,
|
||||
"mapping": mapping,
|
||||
"context_limit": m_meta.and_then(|m| m.limit.as_ref().map(|l| l.context)).unwrap_or(0),
|
||||
"output_limit": m_meta.and_then(|m| m.limit.as_ref().map(|l| l.output)).unwrap_or(0),
|
||||
"modalities": m_meta.and_then(|m| m.modalities.as_ref().map(|mo| serde_json::json!({
|
||||
"input": mo.input,
|
||||
"output": mo.output,
|
||||
}))),
|
||||
"tool_call": m_meta.and_then(|m| m.tool_call),
|
||||
"reasoning": m_meta.and_then(|m| m.reasoning),
|
||||
}));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// REGISTRY LISTING LOGIC
|
||||
// Build filter from query params
|
||||
let filter = ModelFilter {
|
||||
provider: params.provider,
|
||||
search: params.search,
|
||||
modality: params.modality,
|
||||
tool_call: params.tool_call,
|
||||
reasoning: params.reasoning,
|
||||
has_cost: params.has_cost,
|
||||
};
|
||||
let sort_by = params.sort_by.unwrap_or_default();
|
||||
let sort_order = params.sort_order.unwrap_or_default();
|
||||
|
||||
// Get filtered and sorted model entries
|
||||
let entries = registry.list_models(&filter, &sort_by, &sort_order);
|
||||
|
||||
for entry in &entries {
|
||||
let m_key = entry.model_key;
|
||||
|
||||
// Skip models not in the used set (when used_only is active)
|
||||
if let Some(ref used) = used_models {
|
||||
if !used.contains(m_key) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let m_meta = entry.metadata;
|
||||
|
||||
let mut enabled = true;
|
||||
@@ -146,6 +194,7 @@ pub(super) async fn handle_get_models(
|
||||
"reasoning": m_meta.reasoning,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(models_json)))
|
||||
}
|
||||
@@ -156,9 +205,10 @@ pub(super) async fn handle_update_model(
|
||||
Path(id): Path<String>,
|
||||
Json(payload): Json<UpdateModelRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ use std::collections::HashMap;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
use crate::utils::crypto;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct UpdateProviderRequest {
|
||||
@@ -20,7 +21,15 @@ pub(super) struct UpdateProviderRequest {
|
||||
pub(super) billing_mode: Option<String>,
|
||||
}
|
||||
|
||||
pub(super) async fn handle_get_providers(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
pub(super) async fn handle_get_providers(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let registry = &state.app_state.model_registry;
|
||||
let config = &state.app_state.config;
|
||||
let pool = &state.app_state.db_pool;
|
||||
@@ -153,8 +162,14 @@ pub(super) async fn handle_get_providers(State(state): State<DashboardState>) ->
|
||||
|
||||
pub(super) async fn handle_get_provider(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(name): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let registry = &state.app_state.model_registry;
|
||||
let config = &state.app_state.config;
|
||||
let pool = &state.app_state.db_pool;
|
||||
@@ -265,21 +280,44 @@ pub(super) async fn handle_update_provider(
|
||||
Path(name): Path<String>,
|
||||
Json(payload): Json<UpdateProviderRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Update or insert into database (include billing_mode)
|
||||
// Prepare API key encryption if provided
|
||||
let (api_key_to_store, api_key_encrypted_flag) = match &payload.api_key {
|
||||
Some(key) if !key.is_empty() => {
|
||||
match crypto::encrypt(key) {
|
||||
Ok(encrypted) => (Some(encrypted), Some(true)),
|
||||
Err(e) => {
|
||||
warn!("Failed to encrypt API key for provider {}: {}", name, e);
|
||||
return Json(ApiResponse::error(format!("Failed to encrypt API key: {}", e)));
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(_) => {
|
||||
// Empty string means clear the key
|
||||
(None, Some(false))
|
||||
}
|
||||
None => {
|
||||
// Keep existing key, we'll rely on COALESCE in SQL
|
||||
(None, None)
|
||||
}
|
||||
};
|
||||
|
||||
// Update or insert into database (include billing_mode and api_key_encrypted)
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold, billing_mode)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
enabled = excluded.enabled,
|
||||
base_url = excluded.base_url,
|
||||
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
|
||||
api_key_encrypted = COALESCE(excluded.api_key_encrypted, provider_configs.api_key_encrypted),
|
||||
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
|
||||
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
|
||||
billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
|
||||
@@ -290,7 +328,8 @@ pub(super) async fn handle_update_provider(
|
||||
.bind(name.to_uppercase())
|
||||
.bind(payload.enabled)
|
||||
.bind(&payload.base_url)
|
||||
.bind(&payload.api_key)
|
||||
.bind(&api_key_to_store)
|
||||
.bind(api_key_encrypted_flag)
|
||||
.bind(payload.credit_balance)
|
||||
.bind(payload.low_credit_threshold)
|
||||
.bind(payload.billing_mode)
|
||||
@@ -326,8 +365,14 @@ pub(super) async fn handle_update_provider(
|
||||
|
||||
pub(super) async fn handle_test_provider(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(name): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let provider = match state.app_state.provider_manager.get_provider(&name).await {
|
||||
|
||||
@@ -1,7 +1,17 @@
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use hmac::{Hmac, Mac};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Sha256, digest::generic_array::GenericArray};
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
use base64::{engine::general_purpose::URL_SAFE, Engine as _};
|
||||
|
||||
const TOKEN_VERSION: &str = "v2";
|
||||
const REFRESH_WINDOW_MINUTES: i64 = 15; // refresh if token expires within 15 minutes
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Session {
|
||||
@@ -9,51 +19,136 @@ pub struct Session {
|
||||
pub role: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
pub session_id: String, // unique identifier for the session (UUID)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SessionManager {
|
||||
sessions: Arc<RwLock<HashMap<String, Session>>>,
|
||||
sessions: Arc<RwLock<HashMap<String, Session>>>, // key = session_id
|
||||
ttl_hours: i64,
|
||||
secret: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct SessionPayload {
|
||||
session_id: String,
|
||||
username: String,
|
||||
role: String,
|
||||
iat: i64, // issued at (Unix timestamp)
|
||||
exp: i64, // expiry (Unix timestamp)
|
||||
version: String,
|
||||
}
|
||||
|
||||
impl SessionManager {
|
||||
pub fn new(ttl_hours: i64) -> Self {
|
||||
let secret = load_session_secret();
|
||||
Self {
|
||||
sessions: Arc::new(RwLock::new(HashMap::new())),
|
||||
ttl_hours,
|
||||
secret,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new session and return the session token.
|
||||
/// Create a new session and return a signed session token.
|
||||
pub async fn create_session(&self, username: String, role: String) -> String {
|
||||
let token = format!("session-{}", uuid::Uuid::new_v4());
|
||||
let session_id = Uuid::new_v4().to_string();
|
||||
let now = Utc::now();
|
||||
let expires_at = now + Duration::hours(self.ttl_hours);
|
||||
let session = Session {
|
||||
username,
|
||||
role,
|
||||
username: username.clone(),
|
||||
role: role.clone(),
|
||||
created_at: now,
|
||||
expires_at: now + Duration::hours(self.ttl_hours),
|
||||
expires_at,
|
||||
session_id: session_id.clone(),
|
||||
};
|
||||
self.sessions.write().await.insert(token.clone(), session);
|
||||
token
|
||||
// Store session by session_id
|
||||
self.sessions.write().await.insert(session_id.clone(), session);
|
||||
// Create signed token
|
||||
self.create_signed_token(&session_id, &username, &role, now.timestamp(), expires_at.timestamp())
|
||||
}
|
||||
|
||||
/// Validate a session token and return the session if valid and not expired.
|
||||
/// If the token is within the refresh window, returns a new token as well.
|
||||
pub async fn validate_session(&self, token: &str) -> Option<Session> {
|
||||
self.validate_session_with_refresh(token).await.map(|(session, _)| session)
|
||||
}
|
||||
|
||||
/// Validate a session token and return (session, optional new token if refreshed).
|
||||
pub async fn validate_session_with_refresh(&self, token: &str) -> Option<(Session, Option<String>)> {
|
||||
// Legacy token format (UUID)
|
||||
if token.starts_with("session-") {
|
||||
let sessions = self.sessions.read().await;
|
||||
sessions.get(token).and_then(|s| {
|
||||
return sessions.get(token).and_then(|s| {
|
||||
if s.expires_at > Utc::now() {
|
||||
Some(s.clone())
|
||||
Some((s.clone(), None))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
// Signed token format
|
||||
let payload = match verify_signed_token(token, &self.secret) {
|
||||
Ok(p) => p,
|
||||
Err(_) => return None,
|
||||
};
|
||||
|
||||
// Check expiry
|
||||
let now = Utc::now().timestamp();
|
||||
if payload.exp <= now {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Look up session by session_id
|
||||
let sessions = self.sessions.read().await;
|
||||
let session = match sessions.get(&payload.session_id) {
|
||||
Some(s) => s.clone(),
|
||||
None => return None, // session revoked or not found
|
||||
};
|
||||
|
||||
// Ensure session username/role matches (should always match)
|
||||
if session.username != payload.username || session.role != payload.role {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check if token is within refresh window (last REFRESH_WINDOW_MINUTES of validity)
|
||||
let refresh_threshold = payload.exp - REFRESH_WINDOW_MINUTES * 60;
|
||||
let new_token = if now >= refresh_threshold {
|
||||
// Generate a new token with same session data but updated iat/exp?
|
||||
// According to activity-based refresh, we should extend the session expiry.
|
||||
// We'll extend from now by ttl_hours (or keep original expiry?).
|
||||
// Let's extend from now by ttl_hours (sliding window).
|
||||
let new_exp = Utc::now() + Duration::hours(self.ttl_hours);
|
||||
// Update session expiry in store
|
||||
drop(sessions); // release read lock before acquiring write lock
|
||||
self.update_session_expiry(&payload.session_id, new_exp).await;
|
||||
// Create new token with updated iat/exp
|
||||
let new_token = self.create_signed_token(
|
||||
&payload.session_id,
|
||||
&payload.username,
|
||||
&payload.role,
|
||||
now,
|
||||
new_exp.timestamp(),
|
||||
);
|
||||
Some(new_token)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Some((session, new_token))
|
||||
}
|
||||
|
||||
/// Revoke (delete) a session by token.
|
||||
/// Supports both legacy tokens (token is key) and signed tokens (extract session_id).
|
||||
pub async fn revoke_session(&self, token: &str) {
|
||||
if token.starts_with("session-") {
|
||||
self.sessions.write().await.remove(token);
|
||||
return;
|
||||
}
|
||||
// For signed token, try to extract session_id
|
||||
if let Ok(payload) = verify_signed_token(token, &self.secret) {
|
||||
self.sessions.write().await.remove(&payload.session_id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove all expired sessions from the store.
|
||||
@@ -61,4 +156,156 @@ impl SessionManager {
|
||||
let now = Utc::now();
|
||||
self.sessions.write().await.retain(|_, s| s.expires_at > now);
|
||||
}
|
||||
|
||||
// --- Private helpers ---
|
||||
|
||||
fn create_signed_token(&self, session_id: &str, username: &str, role: &str, iat: i64, exp: i64) -> String {
|
||||
let payload = SessionPayload {
|
||||
session_id: session_id.to_string(),
|
||||
username: username.to_string(),
|
||||
role: role.to_string(),
|
||||
iat,
|
||||
exp,
|
||||
version: TOKEN_VERSION.to_string(),
|
||||
};
|
||||
sign_token(&payload, &self.secret)
|
||||
}
|
||||
|
||||
async fn update_session_expiry(&self, session_id: &str, new_expires_at: DateTime<Utc>) {
|
||||
let mut sessions = self.sessions.write().await;
|
||||
if let Some(session) = sessions.get_mut(session_id) {
|
||||
session.expires_at = new_expires_at;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Load session secret from environment variable SESSION_SECRET (hex or base64 encoded).
|
||||
/// If not set, generates a random 32-byte secret and logs a warning.
|
||||
fn load_session_secret() -> Vec<u8> {
|
||||
let secret_str = env::var("SESSION_SECRET").unwrap_or_else(|_| {
|
||||
// Also check LLM_PROXY__SESSION_SECRET for consistency with config prefix
|
||||
env::var("LLM_PROXY__SESSION_SECRET").unwrap_or_else(|_| {
|
||||
// Generate a random secret (32 bytes) and encode as hex
|
||||
use rand::RngCore;
|
||||
let mut bytes = [0u8; 32];
|
||||
rand::rng().fill_bytes(&mut bytes);
|
||||
let hex_secret = hex::encode(bytes);
|
||||
tracing::warn!(
|
||||
"SESSION_SECRET environment variable not set. Using a randomly generated secret. \
|
||||
This will invalidate all sessions on restart. Set SESSION_SECRET to a fixed hex or base64 encoded 32-byte value."
|
||||
);
|
||||
hex_secret
|
||||
})
|
||||
});
|
||||
|
||||
// Decode hex or base64
|
||||
hex::decode(&secret_str)
|
||||
.or_else(|_| URL_SAFE.decode(&secret_str))
|
||||
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(&secret_str))
|
||||
.unwrap_or_else(|_| {
|
||||
panic!("SESSION_SECRET must be hex or base64 encoded (32 bytes)");
|
||||
})
|
||||
}
|
||||
|
||||
/// Sign a session payload and return a token string in format base64_url(payload).base64_url(signature).
|
||||
fn sign_token(payload: &SessionPayload, secret: &[u8]) -> String {
|
||||
let json = serde_json::to_vec(payload).expect("Failed to serialize payload");
|
||||
let payload_b64 = URL_SAFE.encode(&json);
|
||||
let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
|
||||
mac.update(&json);
|
||||
let signature = mac.finalize().into_bytes();
|
||||
let signature_b64 = URL_SAFE.encode(signature);
|
||||
format!("{}.{}", payload_b64, signature_b64)
|
||||
}
|
||||
|
||||
/// Verify a signed token and return the decoded payload if valid.
|
||||
fn verify_signed_token(token: &str, secret: &[u8]) -> Result<SessionPayload, TokenError> {
|
||||
let parts: Vec<&str> = token.split('.').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err(TokenError::InvalidFormat);
|
||||
}
|
||||
let payload_b64 = parts[0];
|
||||
let signature_b64 = parts[1];
|
||||
|
||||
let json = URL_SAFE.decode(payload_b64).map_err(|_| TokenError::InvalidFormat)?;
|
||||
let signature = URL_SAFE.decode(signature_b64).map_err(|_| TokenError::InvalidFormat)?;
|
||||
|
||||
// Verify HMAC
|
||||
let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
|
||||
mac.update(&json);
|
||||
// Convert signature slice to GenericArray
|
||||
let tag = GenericArray::from_slice(&signature);
|
||||
mac.verify(tag).map_err(|_| TokenError::InvalidSignature)?;
|
||||
|
||||
// Deserialize payload
|
||||
let payload: SessionPayload = serde_json::from_slice(&json).map_err(|_| TokenError::InvalidPayload)?;
|
||||
Ok(payload)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum TokenError {
|
||||
InvalidFormat,
|
||||
InvalidSignature,
|
||||
InvalidPayload,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::env;
|
||||
|
||||
#[test]
|
||||
fn test_sign_and_verify_token() {
|
||||
let secret = b"test-secret-must-be-32-bytes-long!";
|
||||
let payload = SessionPayload {
|
||||
session_id: "test-session".to_string(),
|
||||
username: "testuser".to_string(),
|
||||
role: "user".to_string(),
|
||||
iat: 1000,
|
||||
exp: 2000,
|
||||
version: TOKEN_VERSION.to_string(),
|
||||
};
|
||||
let token = sign_token(&payload, secret);
|
||||
let verified = verify_signed_token(&token, secret).unwrap();
|
||||
assert_eq!(verified.session_id, payload.session_id);
|
||||
assert_eq!(verified.username, payload.username);
|
||||
assert_eq!(verified.role, payload.role);
|
||||
assert_eq!(verified.iat, payload.iat);
|
||||
assert_eq!(verified.exp, payload.exp);
|
||||
assert_eq!(verified.version, payload.version);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tampered_token() {
|
||||
let secret = b"test-secret-must-be-32-bytes-long!";
|
||||
let payload = SessionPayload {
|
||||
session_id: "test-session".to_string(),
|
||||
username: "testuser".to_string(),
|
||||
role: "user".to_string(),
|
||||
iat: 1000,
|
||||
exp: 2000,
|
||||
version: TOKEN_VERSION.to_string(),
|
||||
};
|
||||
let mut token = sign_token(&payload, secret);
|
||||
// Tamper with payload part
|
||||
let mut parts: Vec<&str> = token.split('.').collect();
|
||||
let mut payload_bytes = URL_SAFE.decode(parts[0]).unwrap();
|
||||
payload_bytes[0] ^= 0xFF; // flip some bits
|
||||
let tampered_payload = URL_SAFE.encode(payload_bytes);
|
||||
parts[0] = &tampered_payload;
|
||||
token = parts.join(".");
|
||||
assert!(verify_signed_token(&token, secret).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_session_secret_from_env() {
|
||||
unsafe {
|
||||
env::set_var("SESSION_SECRET", hex::encode([0xAA; 32]));
|
||||
}
|
||||
let secret = load_session_secret();
|
||||
assert_eq!(secret, vec![0xAA; 32]);
|
||||
unsafe {
|
||||
env::remove_var("SESSION_SECRET");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -12,7 +12,15 @@ fn read_proc_file(path: &str) -> Option<String> {
|
||||
std::fs::read_to_string(path).ok()
|
||||
}
|
||||
|
||||
pub(super) async fn handle_system_health(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
pub(super) async fn handle_system_health(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let mut components = HashMap::new();
|
||||
components.insert("api_server".to_string(), "online".to_string());
|
||||
components.insert("database".to_string(), "online".to_string());
|
||||
@@ -67,7 +75,13 @@ pub(super) async fn handle_system_health(State(state): State<DashboardState>) ->
|
||||
/// Real system metrics from /proc (Linux only).
|
||||
pub(super) async fn handle_system_metrics(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
// --- CPU usage (aggregate across all cores) ---
|
||||
// /proc/stat first line: cpu user nice system idle iowait irq softirq steal guest guest_nice
|
||||
let cpu_percent = read_proc_file("/proc/stat")
|
||||
@@ -220,7 +234,15 @@ pub(super) async fn handle_system_metrics(
|
||||
})))
|
||||
}
|
||||
|
||||
pub(super) async fn handle_system_logs(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
pub(super) async fn handle_system_logs(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let result = sqlx::query(
|
||||
@@ -233,7 +255,10 @@ pub(super) async fn handle_system_logs(State(state): State<DashboardState>) -> J
|
||||
model,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
reasoning_tokens,
|
||||
total_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
cost,
|
||||
status,
|
||||
error_message,
|
||||
@@ -257,6 +282,11 @@ pub(super) async fn handle_system_logs(State(state): State<DashboardState>) -> J
|
||||
"client_id": row.get::<String, _>("client_id"),
|
||||
"provider": row.get::<String, _>("provider"),
|
||||
"model": row.get::<String, _>("model"),
|
||||
"prompt_tokens": row.get::<i64, _>("prompt_tokens"),
|
||||
"completion_tokens": row.get::<i64, _>("completion_tokens"),
|
||||
"reasoning_tokens": row.get::<i64, _>("reasoning_tokens"),
|
||||
"cache_read_tokens": row.get::<i64, _>("cache_read_tokens"),
|
||||
"cache_write_tokens": row.get::<i64, _>("cache_write_tokens"),
|
||||
"tokens": row.get::<i64, _>("total_tokens"),
|
||||
"cost": row.get::<f64, _>("cost"),
|
||||
"status": row.get::<String, _>("status"),
|
||||
@@ -279,9 +309,10 @@ pub(super) async fn handle_system_backup(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
let backup_id = format!("backup-{}", chrono::Utc::now().timestamp());
|
||||
@@ -317,7 +348,15 @@ pub(super) async fn handle_system_backup(
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_get_settings(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
pub(super) async fn handle_get_settings(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let registry = &state.app_state.model_registry;
|
||||
let provider_count = registry.providers.len();
|
||||
let model_count: usize = registry.providers.values().map(|p| p.models.len()).sum();
|
||||
@@ -341,9 +380,10 @@ pub(super) async fn handle_update_settings(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
Json(ApiResponse::error(
|
||||
"Changing settings at runtime is not yet supported. Please update your config file and restart the server."
|
||||
|
||||
@@ -71,8 +71,14 @@ impl UsagePeriodFilter {
|
||||
|
||||
pub(super) async fn handle_usage_summary(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(filter): Query<UsagePeriodFilter>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
let (period_clause, period_binds) = filter.to_sql();
|
||||
|
||||
@@ -133,7 +139,9 @@ pub(super) async fn handle_usage_summary(
|
||||
)
|
||||
.fetch_one(pool);
|
||||
|
||||
match tokio::join!(total_stats, today_stats, error_stats, avg_response) {
|
||||
let results = tokio::join!(total_stats, today_stats, error_stats, avg_response);
|
||||
|
||||
match results {
|
||||
(Ok(t), Ok(d), Ok(e), Ok(a)) => {
|
||||
let total_requests: i64 = t.get("total_requests");
|
||||
let total_tokens: i64 = t.get("total_tokens");
|
||||
@@ -168,14 +176,26 @@ pub(super) async fn handle_usage_summary(
|
||||
"total_cache_write_tokens": total_cache_write,
|
||||
})))
|
||||
}
|
||||
_ => Json(ApiResponse::error("Failed to fetch usage statistics".to_string())),
|
||||
(t_res, d_res, e_res, a_res) => {
|
||||
if let Err(e) = t_res { warn!("Total stats query failed: {}", e); }
|
||||
if let Err(e) = d_res { warn!("Today stats query failed: {}", e); }
|
||||
if let Err(e) = e_res { warn!("Error stats query failed: {}", e); }
|
||||
if let Err(e) = a_res { warn!("Avg response query failed: {}", e); }
|
||||
Json(ApiResponse::error("Failed to fetch usage statistics".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_time_series(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(filter): Query<UsagePeriodFilter>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
let (period_clause, period_binds) = filter.to_sql();
|
||||
let granularity = filter.granularity();
|
||||
@@ -248,8 +268,14 @@ pub(super) async fn handle_time_series(
|
||||
|
||||
pub(super) async fn handle_clients_usage(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(filter): Query<UsagePeriodFilter>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
let (period_clause, period_binds) = filter.to_sql();
|
||||
|
||||
@@ -308,8 +334,14 @@ pub(super) async fn handle_clients_usage(
|
||||
|
||||
pub(super) async fn handle_providers_usage(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(filter): Query<UsagePeriodFilter>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
let (period_clause, period_binds) = filter.to_sql();
|
||||
|
||||
@@ -370,8 +402,14 @@ pub(super) async fn handle_providers_usage(
|
||||
|
||||
pub(super) async fn handle_detailed_usage(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(filter): Query<UsagePeriodFilter>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
let (period_clause, period_binds) = filter.to_sql();
|
||||
|
||||
@@ -433,8 +471,14 @@ pub(super) async fn handle_detailed_usage(
|
||||
|
||||
pub(super) async fn handle_analytics_breakdown(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Query(filter): Query<UsagePeriodFilter>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
let (period_clause, period_binds) = filter.to_sql();
|
||||
|
||||
|
||||
@@ -14,9 +14,10 @@ pub(super) async fn handle_get_users(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (_session, _) = match auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
@@ -66,9 +67,10 @@ pub(super) async fn handle_create_user(
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(payload): Json<CreateUserRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (_session, _) = match auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
@@ -147,9 +149,10 @@ pub(super) async fn handle_update_user(
|
||||
Path(id): Path<i64>,
|
||||
Json(payload): Json<UpdateUserRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
let (_session, _) = match auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
@@ -249,8 +252,8 @@ pub(super) async fn handle_delete_user(
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(id): Path<i64>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let session = match auth::require_admin(&state, &headers).await {
|
||||
Ok(s) => s,
|
||||
let (session, _) = match auth::require_admin(&state, &headers).await {
|
||||
Ok((session, new_token)) => (session, new_token),
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
|
||||
@@ -18,7 +18,9 @@ pub async fn init(config: &DatabaseConfig) -> Result<DbPool> {
|
||||
let database_path = config.path.to_string_lossy().to_string();
|
||||
info!("Connecting to database at {}", database_path);
|
||||
|
||||
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?.create_if_missing(true);
|
||||
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?
|
||||
.create_if_missing(true)
|
||||
.pragma("foreign_keys", "ON");
|
||||
|
||||
let pool = SqlitePool::connect_with(options).await?;
|
||||
|
||||
@@ -29,7 +31,7 @@ pub async fn init(config: &DatabaseConfig) -> Result<DbPool> {
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
async fn run_migrations(pool: &DbPool) -> Result<()> {
|
||||
pub async fn run_migrations(pool: &DbPool) -> Result<()> {
|
||||
// Create clients table if it doesn't exist
|
||||
sqlx::query(
|
||||
r#"
|
||||
@@ -62,6 +64,7 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
|
||||
model TEXT,
|
||||
prompt_tokens INTEGER,
|
||||
completion_tokens INTEGER,
|
||||
reasoning_tokens INTEGER DEFAULT 0,
|
||||
total_tokens INTEGER,
|
||||
cost REAL,
|
||||
has_images BOOLEAN DEFAULT FALSE,
|
||||
@@ -88,6 +91,8 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
|
||||
api_key TEXT,
|
||||
credit_balance REAL DEFAULT 0.0,
|
||||
low_credit_threshold REAL DEFAULT 5.0,
|
||||
billing_mode TEXT,
|
||||
api_key_encrypted BOOLEAN DEFAULT FALSE,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"#,
|
||||
@@ -105,6 +110,8 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
|
||||
enabled BOOLEAN DEFAULT TRUE,
|
||||
prompt_cost_per_m REAL,
|
||||
completion_cost_per_m REAL,
|
||||
cache_read_cost_per_m REAL,
|
||||
cache_write_cost_per_m REAL,
|
||||
mapping TEXT,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE
|
||||
@@ -166,6 +173,26 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
|
||||
let _ = sqlx::query("ALTER TABLE llm_requests ADD COLUMN cache_write_tokens INTEGER DEFAULT 0")
|
||||
.execute(pool)
|
||||
.await;
|
||||
let _ = sqlx::query("ALTER TABLE llm_requests ADD COLUMN reasoning_tokens INTEGER DEFAULT 0")
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
// Add billing_mode column if it doesn't exist (migration for existing DBs)
|
||||
let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN billing_mode TEXT")
|
||||
.execute(pool)
|
||||
.await;
|
||||
// Add api_key_encrypted column if it doesn't exist (migration for existing DBs)
|
||||
let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN api_key_encrypted BOOLEAN DEFAULT FALSE")
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
// Add manual cache cost columns to model_configs if they don't exist
|
||||
let _ = sqlx::query("ALTER TABLE model_configs ADD COLUMN cache_read_cost_per_m REAL")
|
||||
.execute(pool)
|
||||
.await;
|
||||
let _ = sqlx::query("ALTER TABLE model_configs ADD COLUMN cache_write_cost_per_m REAL")
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
// Insert default admin user if none exists (default password: admin)
|
||||
let user_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users").fetch_one(pool).await?;
|
||||
@@ -216,6 +243,19 @@ async fn run_migrations(pool: &DbPool) -> Result<()> {
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Composite indexes for performance
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Insert default client if none exists
|
||||
sqlx::query(
|
||||
r#"
|
||||
|
||||
217
src/lib.rs
217
src/lib.rs
@@ -41,23 +41,18 @@ pub use state::AppState;
|
||||
pub mod test_utils {
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{client::ClientManager, providers::ProviderManager, rate_limiting::RateLimitManager, state::AppState};
|
||||
use crate::{client::ClientManager, providers::ProviderManager, rate_limiting::RateLimitManager, state::AppState, utils::crypto, database::run_migrations};
|
||||
use sqlx::sqlite::SqlitePool;
|
||||
|
||||
/// Create a test application state
|
||||
pub async fn create_test_state() -> Arc<AppState> {
|
||||
pub async fn create_test_state() -> AppState {
|
||||
// Create in-memory database
|
||||
let pool = SqlitePool::connect("sqlite::memory:")
|
||||
.await
|
||||
.expect("Failed to create test database");
|
||||
|
||||
// Run migrations
|
||||
crate::database::init(&crate::config::DatabaseConfig {
|
||||
path: std::path::PathBuf::from(":memory:"),
|
||||
max_connections: 5,
|
||||
})
|
||||
.await
|
||||
.expect("Failed to initialize test database");
|
||||
// Run migrations on the pool
|
||||
run_migrations(&pool).await.expect("Failed to run migrations");
|
||||
|
||||
let rate_limit_manager = RateLimitManager::new(
|
||||
crate::rate_limiting::RateLimiterConfig::default(),
|
||||
@@ -73,7 +68,7 @@ pub mod test_utils {
|
||||
providers: std::collections::HashMap::new(),
|
||||
};
|
||||
|
||||
let (dashboard_tx, _) = tokio::sync::broadcast::channel(100);
|
||||
let (dashboard_tx, _) = tokio::sync::broadcast::channel::<serde_json::Value>(100);
|
||||
|
||||
let config = Arc::new(crate::config::AppConfig {
|
||||
server: crate::config::ServerConfig {
|
||||
@@ -125,20 +120,20 @@ pub mod test_utils {
|
||||
ollama: vec![],
|
||||
},
|
||||
config_path: None,
|
||||
encryption_key: "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f".to_string(),
|
||||
});
|
||||
|
||||
Arc::new(AppState {
|
||||
// Initialize encryption with the test key
|
||||
crypto::init_with_key(&config.encryption_key).expect("failed to initialize crypto");
|
||||
|
||||
AppState::new(
|
||||
config,
|
||||
provider_manager,
|
||||
db_pool: pool.clone(),
|
||||
rate_limit_manager: Arc::new(rate_limit_manager),
|
||||
client_manager,
|
||||
request_logger: Arc::new(crate::logging::RequestLogger::new(pool.clone(), dashboard_tx.clone())),
|
||||
model_registry: Arc::new(model_registry),
|
||||
model_config_cache: crate::state::ModelConfigCache::new(pool.clone()),
|
||||
dashboard_tx,
|
||||
auth_tokens: vec![],
|
||||
})
|
||||
pool,
|
||||
rate_limit_manager,
|
||||
model_registry,
|
||||
vec![], // auth_tokens
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a test HTTP client
|
||||
@@ -149,3 +144,185 @@ pub mod test_utils {
|
||||
.expect("Failed to create test HTTP client")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod integration_tests {
|
||||
use super::test_utils::*;
|
||||
use crate::{
|
||||
models::{ChatCompletionRequest, ChatMessage},
|
||||
server::router,
|
||||
utils::crypto,
|
||||
};
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{Request, StatusCode},
|
||||
};
|
||||
use mockito::Server;
|
||||
use serde_json::json;
|
||||
use sqlx::Row;
|
||||
use tower::util::ServiceExt;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_encrypted_provider_key_integration() {
|
||||
// Step 1: Setup test database and state
|
||||
let state = create_test_state().await;
|
||||
let pool = state.db_pool.clone();
|
||||
|
||||
// Step 2: Insert provider with encrypted API key
|
||||
let test_api_key = "test-openai-key-12345";
|
||||
let encrypted_key = crypto::encrypt(test_api_key).expect("Failed to encrypt test key");
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind("openai")
|
||||
.bind("OpenAI")
|
||||
.bind(true)
|
||||
.bind("http://localhost:1234") // Mock server URL
|
||||
.bind(&encrypted_key)
|
||||
.bind(true) // api_key_encrypted flag
|
||||
.bind(100.0)
|
||||
.bind(5.0)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.expect("Failed to update provider URL");
|
||||
|
||||
// Re-initialize provider with new URL
|
||||
state
|
||||
.provider_manager
|
||||
.initialize_provider("openai", &state.config, &pool)
|
||||
.await
|
||||
.expect("Failed to re-initialize provider");
|
||||
|
||||
// Step 4: Mock OpenAI API server
|
||||
let mut server = Server::new_async().await;
|
||||
let mock = server
|
||||
.mock("POST", "/chat/completions")
|
||||
.match_header("authorization", format!("Bearer {}", test_api_key).as_str())
|
||||
.with_status(200)
|
||||
.with_header("content-type", "application/json")
|
||||
.with_body(
|
||||
json!({
|
||||
"id": "chatcmpl-test",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "gpt-3.5-turbo",
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello, world!"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 5,
|
||||
"total_tokens": 15
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
)
|
||||
.create_async()
|
||||
.await;
|
||||
|
||||
// Update provider base URL to use mock server
|
||||
sqlx::query("UPDATE provider_configs SET base_url = ? WHERE id = 'openai'")
|
||||
.bind(&server.url())
|
||||
.execute(&pool)
|
||||
.await
|
||||
.expect("Failed to update provider URL");
|
||||
|
||||
// Re-initialize provider with new URL
|
||||
state
|
||||
.provider_manager
|
||||
.initialize_provider("openai", &state.config, &pool)
|
||||
.await
|
||||
.expect("Failed to re-initialize provider");
|
||||
|
||||
// Step 5: Create test router and make request
|
||||
let app = router(state);
|
||||
|
||||
let request_body = ChatCompletionRequest {
|
||||
model: "gpt-3.5-turbo".to_string(),
|
||||
messages: vec![ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: crate::models::MessageContent::Text {
|
||||
content: "Hello".to_string(),
|
||||
},
|
||||
reasoning_content: None,
|
||||
tool_calls: None,
|
||||
name: None,
|
||||
tool_call_id: None,
|
||||
}],
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
top_k: None,
|
||||
n: None,
|
||||
stop: None,
|
||||
max_tokens: Some(100),
|
||||
presence_penalty: None,
|
||||
frequency_penalty: None,
|
||||
stream: Some(false),
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
};
|
||||
|
||||
let request = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/v1/chat/completions")
|
||||
.header("content-type", "application/json")
|
||||
.header("authorization", "Bearer test-token")
|
||||
.body(Body::from(serde_json::to_string(&request_body).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
// Step 6: Execute request through proxy
|
||||
let response = app
|
||||
.oneshot(request)
|
||||
.await
|
||||
.expect("Failed to execute request");
|
||||
|
||||
let status = response.status();
|
||||
println!("Response status: {}", status);
|
||||
|
||||
if status != StatusCode::OK {
|
||||
let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
|
||||
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
|
||||
println!("Response body: {}", body_str);
|
||||
panic!("Response status is not OK: {}", status);
|
||||
}
|
||||
|
||||
assert_eq!(status, StatusCode::OK);
|
||||
|
||||
// Verify the mock was called
|
||||
mock.assert_async().await;
|
||||
|
||||
// Give the async logging task time to complete
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
||||
// Step 7: Verify usage was logged in database
|
||||
let log_row = sqlx::query("SELECT * FROM llm_requests WHERE client_id = 'client_test-tok' ORDER BY id DESC LIMIT 1")
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.expect("Request log not found");
|
||||
|
||||
assert_eq!(log_row.get::<String, _>("provider"), "openai");
|
||||
assert_eq!(log_row.get::<String, _>("model"), "gpt-3.5-turbo");
|
||||
assert_eq!(log_row.get::<i64, _>("prompt_tokens"), 10);
|
||||
assert_eq!(log_row.get::<i64, _>("completion_tokens"), 5);
|
||||
assert_eq!(log_row.get::<i64, _>("total_tokens"), 15);
|
||||
assert_eq!(log_row.get::<String, _>("status"), "success");
|
||||
|
||||
// Verify client usage was updated
|
||||
let client_row = sqlx::query("SELECT total_requests, total_tokens, total_cost FROM clients WHERE client_id = 'client_test-tok'")
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.expect("Client not found");
|
||||
|
||||
assert_eq!(client_row.get::<i64, _>("total_requests"), 1);
|
||||
assert_eq!(client_row.get::<i64, _>("total_tokens"), 15);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ pub struct RequestLog {
|
||||
pub model: String,
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub reasoning_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
pub cache_read_tokens: u32,
|
||||
pub cache_write_tokens: u32,
|
||||
@@ -77,22 +78,23 @@ impl RequestLogger {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO llm_requests
|
||||
(timestamp, client_id, provider, model, prompt_tokens, completion_tokens, total_tokens, cache_read_tokens, cache_write_tokens, cost, has_images, status, error_message, duration_ms, request_body, response_body)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
(timestamp, client_id, provider, model, prompt_tokens, completion_tokens, reasoning_tokens, total_tokens, cache_read_tokens, cache_write_tokens, cost, has_images, status, error_message, duration_ms, request_body, response_body)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind(log.timestamp)
|
||||
.bind(log.client_id)
|
||||
.bind(&log.client_id)
|
||||
.bind(&log.provider)
|
||||
.bind(log.model)
|
||||
.bind(&log.model)
|
||||
.bind(log.prompt_tokens as i64)
|
||||
.bind(log.completion_tokens as i64)
|
||||
.bind(log.reasoning_tokens as i64)
|
||||
.bind(log.total_tokens as i64)
|
||||
.bind(log.cache_read_tokens as i64)
|
||||
.bind(log.cache_write_tokens as i64)
|
||||
.bind(log.cost)
|
||||
.bind(log.has_images)
|
||||
.bind(log.status)
|
||||
.bind(&log.status)
|
||||
.bind(log.error_message)
|
||||
.bind(log.duration_ms as i64)
|
||||
.bind(None::<String>) // request_body - optional, not stored to save disk space
|
||||
@@ -100,6 +102,23 @@ impl RequestLogger {
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// Update client usage statistics
|
||||
sqlx::query(
|
||||
r#"
|
||||
UPDATE clients SET
|
||||
total_requests = total_requests + 1,
|
||||
total_tokens = total_tokens + ?,
|
||||
total_cost = total_cost + ?,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE client_id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(log.total_tokens as i64)
|
||||
.bind(log.cost)
|
||||
.bind(&log.client_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// Deduct from provider balance if successful.
|
||||
// Providers configured with billing_mode = 'postpaid' will not have their
|
||||
// credit_balance decremented. Use a conditional UPDATE so we don't need
|
||||
|
||||
@@ -10,6 +10,7 @@ use llm_proxy::{
|
||||
rate_limiting::{CircuitBreakerConfig, RateLimitManager, RateLimiterConfig},
|
||||
server,
|
||||
state::AppState,
|
||||
utils::crypto,
|
||||
};
|
||||
|
||||
#[tokio::main]
|
||||
@@ -26,6 +27,10 @@ async fn main() -> Result<()> {
|
||||
let config = AppConfig::load().await?;
|
||||
info!("Configuration loaded from {:?}", config.config_path);
|
||||
|
||||
// Initialize encryption
|
||||
crypto::init_with_key(&config.encryption_key)?;
|
||||
info!("Encryption initialized");
|
||||
|
||||
// Initialize database connection pool
|
||||
let db_pool = database::init(&config.database).await?;
|
||||
info!("Database initialized at {:?}", config.database.path);
|
||||
|
||||
@@ -38,7 +38,7 @@ pub struct ChatMessage {
|
||||
pub role: String, // "system", "user", "assistant", "tool"
|
||||
#[serde(flatten)]
|
||||
pub content: MessageContent,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[serde(alias = "reasoning", alias = "thought", skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
@@ -165,6 +165,8 @@ pub struct Usage {
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cache_read_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cache_write_tokens: Option<u32>,
|
||||
@@ -179,6 +181,8 @@ pub struct ChatCompletionStreamResponse {
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChatStreamChoice>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -192,7 +196,7 @@ pub struct ChatStreamChoice {
|
||||
pub struct ChatStreamDelta {
|
||||
pub role: Option<String>,
|
||||
pub content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[serde(alias = "reasoning", alias = "thought", skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCallDelta>>,
|
||||
|
||||
@@ -58,7 +58,37 @@ impl super::Provider for DeepSeekProvider {
|
||||
|
||||
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
||||
let body = helpers::build_openai_body(&request, messages_json, false);
|
||||
let mut body = helpers::build_openai_body(&request, messages_json, false);
|
||||
|
||||
// Sanitize and fix for deepseek-reasoner (R1)
|
||||
if request.model == "deepseek-reasoner" {
|
||||
if let Some(obj) = body.as_object_mut() {
|
||||
// Remove unsupported parameters
|
||||
obj.remove("temperature");
|
||||
obj.remove("top_p");
|
||||
obj.remove("presence_penalty");
|
||||
obj.remove("frequency_penalty");
|
||||
obj.remove("logit_bias");
|
||||
obj.remove("logprobs");
|
||||
obj.remove("top_logprobs");
|
||||
|
||||
// ENSURE: EVERY assistant message must have reasoning_content and valid content
|
||||
if let Some(messages) = obj.get_mut("messages").and_then(|m| m.as_array_mut()) {
|
||||
for m in messages {
|
||||
if m["role"].as_str() == Some("assistant") {
|
||||
// DeepSeek R1 requires reasoning_content for consistency in history.
|
||||
if m.get("reasoning_content").is_none() || m["reasoning_content"].is_null() {
|
||||
m["reasoning_content"] = serde_json::json!(" ");
|
||||
}
|
||||
// DeepSeek R1 often requires content to be a string, not null/array
|
||||
if m.get("content").is_none() || m["content"].is_null() || m["content"].is_array() {
|
||||
m["content"] = serde_json::json!("");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let response = self
|
||||
.client
|
||||
@@ -73,6 +103,7 @@ impl super::Provider for DeepSeekProvider {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
tracing::error!("DeepSeek API error ({}): {}", status, error_text);
|
||||
tracing::error!("Offending DeepSeek Request Body: {}", serde_json::to_string(&body).unwrap_or_default());
|
||||
return Err(AppError::ProviderError(format!("DeepSeek API error ({}): {}", status, error_text)));
|
||||
}
|
||||
|
||||
@@ -97,7 +128,9 @@ impl super::Provider for DeepSeekProvider {
|
||||
cache_write_tokens: u32,
|
||||
registry: &crate::models::registry::ModelRegistry,
|
||||
) -> f64 {
|
||||
helpers::calculate_cost_with_registry(
|
||||
if let Some(metadata) = registry.find_model(model) {
|
||||
if metadata.cost.is_some() {
|
||||
return helpers::calculate_cost_with_registry(
|
||||
model,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
@@ -105,9 +138,26 @@ impl super::Provider for DeepSeekProvider {
|
||||
cache_write_tokens,
|
||||
registry,
|
||||
&self.pricing,
|
||||
0.14,
|
||||
0.28,
|
||||
)
|
||||
0.42,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Custom DeepSeek fallback that correctly handles cache hits
|
||||
let (prompt_rate, completion_rate) = self
|
||||
.pricing
|
||||
.iter()
|
||||
.find(|p| model.contains(&p.model))
|
||||
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
||||
.unwrap_or((0.28, 0.42)); // Default to DeepSeek's current API pricing
|
||||
|
||||
let cache_hit_rate = prompt_rate / 10.0;
|
||||
let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
|
||||
|
||||
(non_cached_prompt as f64 * prompt_rate / 1_000_000.0)
|
||||
+ (cache_read_tokens as f64 * cache_hit_rate / 1_000_000.0)
|
||||
+ (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
||||
}
|
||||
|
||||
async fn chat_completion_stream(
|
||||
@@ -118,9 +168,37 @@ impl super::Provider for DeepSeekProvider {
|
||||
let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?;
|
||||
let mut body = helpers::build_openai_body(&request, messages_json, true);
|
||||
|
||||
// Standard OpenAI cleanup
|
||||
// Sanitize and fix for deepseek-reasoner (R1)
|
||||
if request.model == "deepseek-reasoner" {
|
||||
if let Some(obj) = body.as_object_mut() {
|
||||
obj.remove("stream_options");
|
||||
// Keep stream_options if present (DeepSeek supports include_usage)
|
||||
|
||||
// Remove unsupported parameters
|
||||
obj.remove("temperature");
|
||||
|
||||
obj.remove("top_p");
|
||||
obj.remove("presence_penalty");
|
||||
obj.remove("frequency_penalty");
|
||||
obj.remove("logit_bias");
|
||||
obj.remove("logprobs");
|
||||
obj.remove("top_logprobs");
|
||||
|
||||
// ENSURE: EVERY assistant message must have reasoning_content and valid content
|
||||
if let Some(messages) = obj.get_mut("messages").and_then(|m| m.as_array_mut()) {
|
||||
for m in messages {
|
||||
if m["role"].as_str() == Some("assistant") {
|
||||
// DeepSeek R1 requires reasoning_content for consistency in history.
|
||||
if m.get("reasoning_content").is_none() || m["reasoning_content"].is_null() {
|
||||
m["reasoning_content"] = serde_json::json!(" ");
|
||||
}
|
||||
// DeepSeek R1 often requires content to be a string, not null/array
|
||||
if m.get("content").is_none() || m["content"].is_null() || m["content"].is_array() {
|
||||
m["content"] = serde_json::json!("");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let url = format!("{}/chat/completions", self.config.base_url);
|
||||
@@ -166,12 +244,18 @@ impl super::Provider for DeepSeekProvider {
|
||||
match probe_resp {
|
||||
Ok(r) if !r.status().is_success() => {
|
||||
let status = r.status();
|
||||
let body = r.text().await.unwrap_or_default();
|
||||
tracing::error!("DeepSeek Stream Error Probe ({}): {}", status, body);
|
||||
Err(AppError::ProviderError(format!("DeepSeek API error ({}): {}", status, body)))?;
|
||||
let error_body = r.text().await.unwrap_or_default();
|
||||
tracing::error!("DeepSeek Stream Error Probe ({}): {}", status, error_body);
|
||||
// Log the offending request body at ERROR level so it shows up in standard logs
|
||||
tracing::error!("Offending DeepSeek Request Body: {}", serde_json::to_string(&probe_body).unwrap_or_default());
|
||||
Err(AppError::ProviderError(format!("DeepSeek API error ({}): {}", status, error_body)))?;
|
||||
}
|
||||
_ => {
|
||||
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
||||
Ok(_) => {
|
||||
Err(AppError::ProviderError(format!("Stream error (probe returned 200): {}", e)))?;
|
||||
}
|
||||
Err(probe_err) => {
|
||||
tracing::error!("DeepSeek Stream Error Probe failed: {}", probe_err);
|
||||
Err(AppError::ProviderError(format!("Stream error (probe failed: {}): {}", probe_err, e)))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,6 +222,16 @@ impl GeminiProvider {
|
||||
let mut contents: Vec<GeminiContent> = Vec::new();
|
||||
let mut system_parts = Vec::new();
|
||||
|
||||
// PRE-PASS: Build tool_id -> function_name mapping for tool responses
|
||||
let mut tool_id_to_name = std::collections::HashMap::new();
|
||||
for msg in &messages {
|
||||
if let Some(tool_calls) = &msg.tool_calls {
|
||||
for tc in tool_calls {
|
||||
tool_id_to_name.insert(tc.id.clone(), tc.function.name.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for msg in messages {
|
||||
if msg.role == "system" {
|
||||
for part in msg.content {
|
||||
@@ -261,7 +271,14 @@ impl GeminiProvider {
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let name = msg.name.clone().or_else(|| msg.tool_call_id.clone()).unwrap_or_else(|| "unknown_function".to_string());
|
||||
// RESOLVE: Use msg.name if present, otherwise look up by tool_call_id
|
||||
let name = msg.name.clone()
|
||||
.or_else(|| {
|
||||
msg.tool_call_id.as_ref()
|
||||
.and_then(|id| tool_id_to_name.get(id).cloned())
|
||||
})
|
||||
.or_else(|| msg.tool_call_id.clone())
|
||||
.unwrap_or_else(|| "unknown_function".to_string());
|
||||
|
||||
// Gemini API requires 'response' to be a JSON object (google.protobuf.Struct).
|
||||
// If it is an array or primitive, wrap it in an object.
|
||||
@@ -322,10 +339,13 @@ impl GeminiProvider {
|
||||
let args = serde_json::from_str::<Value>(&tc.function.arguments)
|
||||
.unwrap_or_else(|_| serde_json::json!({}));
|
||||
|
||||
// RESTORE: Use tc.id as thought_signature.
|
||||
// Gemini 3 models require this field for any function call in the history.
|
||||
// We include it regardless of format to ensure the model has context.
|
||||
let thought_signature = Some(tc.id.clone());
|
||||
// RESTORE: Only use tc.id as thought_signature if it's NOT a synthetic ID.
|
||||
// Synthetic IDs (starting with 'call_') cause 400 errors as they are not valid Base64 for the TYPE_BYTES field.
|
||||
let thought_signature = if tc.id.starts_with("call_") {
|
||||
None
|
||||
} else {
|
||||
Some(tc.id.clone())
|
||||
};
|
||||
|
||||
parts.push(GeminiPart {
|
||||
text: None,
|
||||
@@ -702,6 +722,10 @@ impl super::Provider for GeminiProvider {
|
||||
let reasoning_content = candidate
|
||||
.and_then(|c| c.content.parts.iter().find_map(|p| p.thought.clone()));
|
||||
|
||||
let reasoning_tokens = reasoning_content.as_ref()
|
||||
.map(|r| crate::utils::tokens::estimate_completion_tokens(r, &model))
|
||||
.unwrap_or(0);
|
||||
|
||||
// Extract function calls → OpenAI tool_calls
|
||||
let tool_calls = candidate.and_then(|c| Self::extract_tool_calls(&c.content.parts));
|
||||
|
||||
@@ -732,6 +756,7 @@ impl super::Provider for GeminiProvider {
|
||||
tool_calls,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
reasoning_tokens,
|
||||
total_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens: 0, // Gemini doesn't report cache writes separately
|
||||
@@ -752,7 +777,9 @@ impl super::Provider for GeminiProvider {
|
||||
cache_write_tokens: u32,
|
||||
registry: &crate::models::registry::ModelRegistry,
|
||||
) -> f64 {
|
||||
super::helpers::calculate_cost_with_registry(
|
||||
if let Some(metadata) = registry.find_model(model) {
|
||||
if metadata.cost.is_some() {
|
||||
return super::helpers::calculate_cost_with_registry(
|
||||
model,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
@@ -762,7 +789,24 @@ impl super::Provider for GeminiProvider {
|
||||
&self.pricing,
|
||||
0.075,
|
||||
0.30,
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Custom Gemini fallback that correctly handles cache hits (25% of input cost)
|
||||
let (prompt_rate, completion_rate) = self
|
||||
.pricing
|
||||
.iter()
|
||||
.find(|p| model.contains(&p.model))
|
||||
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
||||
.unwrap_or((0.075, 0.30)); // Default to Gemini 1.5 Flash current API pricing
|
||||
|
||||
let cache_hit_rate = prompt_rate * 0.25;
|
||||
let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
|
||||
|
||||
(non_cached_prompt as f64 * prompt_rate / 1_000_000.0)
|
||||
+ (cache_read_tokens as f64 * cache_hit_rate / 1_000_000.0)
|
||||
+ (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
||||
}
|
||||
|
||||
async fn chat_completion_stream(
|
||||
@@ -863,6 +907,7 @@ impl super::Provider for GeminiProvider {
|
||||
super::StreamUsage {
|
||||
prompt_tokens: u.prompt_token_count,
|
||||
completion_tokens: u.candidates_token_count,
|
||||
reasoning_tokens: 0,
|
||||
total_tokens: u.total_token_count,
|
||||
cache_read_tokens: u.cached_content_token_count,
|
||||
cache_write_tokens: 0,
|
||||
|
||||
@@ -29,7 +29,14 @@ pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result<Vec<
|
||||
"content": text_content
|
||||
});
|
||||
if let Some(tool_call_id) = &m.tool_call_id {
|
||||
msg["tool_call_id"] = serde_json::json!(tool_call_id);
|
||||
// OpenAI and others have a 40-char limit for tool_call_id.
|
||||
// Gemini signatures (56 chars) must be shortened for compatibility.
|
||||
let id = if tool_call_id.len() > 40 {
|
||||
&tool_call_id[..40]
|
||||
} else {
|
||||
tool_call_id
|
||||
};
|
||||
msg["tool_call_id"] = serde_json::json!(id);
|
||||
}
|
||||
if let Some(name) = &m.name {
|
||||
msg["name"] = serde_json::json!(name);
|
||||
@@ -65,14 +72,23 @@ pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result<Vec<
|
||||
msg["reasoning_content"] = serde_json::json!(reasoning);
|
||||
}
|
||||
|
||||
// For assistant messages with tool_calls, content can be null
|
||||
// For assistant messages with tool_calls, content can be empty string
|
||||
if let Some(tool_calls) = &m.tool_calls {
|
||||
// Sanitize tool call IDs for OpenAI compatibility (max 40 chars)
|
||||
let sanitized_calls: Vec<_> = tool_calls.iter().map(|tc| {
|
||||
let mut sanitized = tc.clone();
|
||||
if sanitized.id.len() > 40 {
|
||||
sanitized.id = sanitized.id[..40].to_string();
|
||||
}
|
||||
sanitized
|
||||
}).collect();
|
||||
|
||||
if parts.is_empty() {
|
||||
msg["content"] = serde_json::Value::Null;
|
||||
msg["content"] = serde_json::json!("");
|
||||
} else {
|
||||
msg["content"] = serde_json::json!(parts);
|
||||
}
|
||||
msg["tool_calls"] = serde_json::json!(tool_calls);
|
||||
msg["tool_calls"] = serde_json::json!(sanitized_calls);
|
||||
} else {
|
||||
msg["content"] = serde_json::json!(parts);
|
||||
}
|
||||
@@ -114,7 +130,13 @@ pub async fn messages_to_openai_json_text_only(
|
||||
"content": text_content
|
||||
});
|
||||
if let Some(tool_call_id) = &m.tool_call_id {
|
||||
msg["tool_call_id"] = serde_json::json!(tool_call_id);
|
||||
// OpenAI and others have a 40-char limit for tool_call_id.
|
||||
let id = if tool_call_id.len() > 40 {
|
||||
&tool_call_id[..40]
|
||||
} else {
|
||||
tool_call_id
|
||||
};
|
||||
msg["tool_call_id"] = serde_json::json!(id);
|
||||
}
|
||||
if let Some(name) = &m.name {
|
||||
msg["name"] = serde_json::json!(name);
|
||||
@@ -143,14 +165,23 @@ pub async fn messages_to_openai_json_text_only(
|
||||
msg["reasoning_content"] = serde_json::json!(reasoning);
|
||||
}
|
||||
|
||||
// For assistant messages with tool_calls, content can be null
|
||||
// For assistant messages with tool_calls, content can be empty string
|
||||
if let Some(tool_calls) = &m.tool_calls {
|
||||
// Sanitize tool call IDs for OpenAI compatibility (max 40 chars)
|
||||
let sanitized_calls: Vec<_> = tool_calls.iter().map(|tc| {
|
||||
let mut sanitized = tc.clone();
|
||||
if sanitized.id.len() > 40 {
|
||||
sanitized.id = sanitized.id[..40].to_string();
|
||||
}
|
||||
sanitized
|
||||
}).collect();
|
||||
|
||||
if parts.is_empty() {
|
||||
msg["content"] = serde_json::Value::Null;
|
||||
msg["content"] = serde_json::json!("");
|
||||
} else {
|
||||
msg["content"] = serde_json::json!(parts);
|
||||
}
|
||||
msg["tool_calls"] = serde_json::json!(tool_calls);
|
||||
msg["tool_calls"] = serde_json::json!(sanitized_calls);
|
||||
} else {
|
||||
msg["content"] = serde_json::json!(parts);
|
||||
}
|
||||
@@ -223,6 +254,11 @@ pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<Provide
|
||||
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
|
||||
// Extract reasoning tokens
|
||||
let reasoning_tokens = usage["completion_tokens_details"]["reasoning_tokens"]
|
||||
.as_u64()
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
// Extract cache tokens — try OpenAI/Grok format first, then DeepSeek format
|
||||
let cache_read_tokens = usage["prompt_tokens_details"]["cached_tokens"]
|
||||
.as_u64()
|
||||
@@ -230,9 +266,9 @@ pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<Provide
|
||||
.or_else(|| usage["prompt_cache_hit_tokens"].as_u64())
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
// DeepSeek reports cache_write as prompt_cache_miss_tokens (tokens written to cache for future use).
|
||||
// OpenAI doesn't report cache_write in this location, but may in the future.
|
||||
let cache_write_tokens = usage["prompt_cache_miss_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
// DeepSeek reports prompt_cache_miss_tokens which are just regular non-cached tokens.
|
||||
// They do not incur a separate cache_write fee, so we don't map them here to avoid double-charging.
|
||||
let cache_write_tokens = 0;
|
||||
|
||||
Ok(ProviderResponse {
|
||||
content,
|
||||
@@ -240,6 +276,7 @@ pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<Provide
|
||||
tool_calls,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
reasoning_tokens,
|
||||
total_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
@@ -264,18 +301,21 @@ pub fn parse_openai_stream_chunk(
|
||||
let completion_tokens = u["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
let total_tokens = u["total_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
|
||||
let reasoning_tokens = u["completion_tokens_details"]["reasoning_tokens"]
|
||||
.as_u64()
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
let cache_read_tokens = u["prompt_tokens_details"]["cached_tokens"]
|
||||
.as_u64()
|
||||
.or_else(|| u["prompt_cache_hit_tokens"].as_u64())
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
let cache_write_tokens = u["prompt_cache_miss_tokens"]
|
||||
.as_u64()
|
||||
.unwrap_or(0) as u32;
|
||||
let cache_write_tokens = 0;
|
||||
|
||||
Some(StreamUsage {
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
reasoning_tokens,
|
||||
total_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
|
||||
@@ -7,6 +7,7 @@ use std::sync::Arc;
|
||||
use crate::errors::AppError;
|
||||
use crate::models::UnifiedRequest;
|
||||
|
||||
|
||||
pub mod deepseek;
|
||||
pub mod gemini;
|
||||
pub mod grok;
|
||||
@@ -41,6 +42,16 @@ pub trait Provider: Send + Sync {
|
||||
request: UnifiedRequest,
|
||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError>;
|
||||
|
||||
/// Process a streaming chat request using provider-specific "responses" style endpoint
|
||||
/// Default implementation falls back to `chat_completion_stream` for providers
|
||||
/// that do not implement a dedicated responses endpoint.
|
||||
async fn chat_responses_stream(
|
||||
&self,
|
||||
request: UnifiedRequest,
|
||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||
self.chat_completion_stream(request).await
|
||||
}
|
||||
|
||||
/// Estimate token count for a request (for cost calculation)
|
||||
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32>;
|
||||
|
||||
@@ -64,6 +75,7 @@ pub struct ProviderResponse {
|
||||
pub tool_calls: Option<Vec<crate::models::ToolCall>>,
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub reasoning_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
pub cache_read_tokens: u32,
|
||||
pub cache_write_tokens: u32,
|
||||
@@ -75,6 +87,7 @@ pub struct ProviderResponse {
|
||||
pub struct StreamUsage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub reasoning_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
pub cache_read_tokens: u32,
|
||||
pub cache_write_tokens: u32,
|
||||
@@ -125,17 +138,35 @@ impl ProviderManager {
|
||||
db_pool: &crate::database::DbPool,
|
||||
) -> Result<()> {
|
||||
// Load override from database
|
||||
let db_config = sqlx::query("SELECT enabled, base_url, api_key FROM provider_configs WHERE id = ?")
|
||||
let db_config = sqlx::query("SELECT enabled, base_url, api_key, api_key_encrypted FROM provider_configs WHERE id = ?")
|
||||
.bind(name)
|
||||
.fetch_optional(db_pool)
|
||||
.await?;
|
||||
|
||||
let (enabled, base_url, api_key) = if let Some(row) = db_config {
|
||||
(
|
||||
row.get::<bool, _>("enabled"),
|
||||
row.get::<Option<String>, _>("base_url"),
|
||||
row.get::<Option<String>, _>("api_key"),
|
||||
)
|
||||
let enabled = row.get::<bool, _>("enabled");
|
||||
let base_url = row.get::<Option<String>, _>("base_url");
|
||||
let api_key_encrypted = row.get::<bool, _>("api_key_encrypted");
|
||||
let api_key = row.get::<Option<String>, _>("api_key");
|
||||
// Decrypt API key if encrypted
|
||||
let api_key = match (api_key, api_key_encrypted) {
|
||||
(Some(key), true) => {
|
||||
match crate::utils::crypto::decrypt(&key) {
|
||||
Ok(decrypted) => Some(decrypted),
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to decrypt API key for provider {}: {}", name, e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
(Some(key), false) => {
|
||||
// Plaintext key - optionally encrypt and update database (lazy migration)
|
||||
// For now, just use plaintext
|
||||
Some(key)
|
||||
}
|
||||
(None, _) => None,
|
||||
};
|
||||
(enabled, base_url, api_key)
|
||||
} else {
|
||||
// No database override, use defaults from AppConfig
|
||||
match name {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
use futures::StreamExt;
|
||||
|
||||
use super::helpers;
|
||||
use super::{ProviderResponse, ProviderStreamChunk};
|
||||
@@ -25,7 +26,7 @@ impl OpenAIProvider {
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
||||
.pool_max_idle_per_host(4)
|
||||
.tcp_keepalive(std::time::Duration::from_secs(30))
|
||||
.tcp_keepalive(std::time::Duration::from_secs(15))
|
||||
.build()?;
|
||||
|
||||
Ok(Self {
|
||||
@@ -44,7 +45,13 @@ impl super::Provider for OpenAIProvider {
|
||||
}
|
||||
|
||||
fn supports_model(&self, model: &str) -> bool {
|
||||
model.starts_with("gpt-") || model.starts_with("o1-") || model.starts_with("o3-") || model.starts_with("o4-")
|
||||
model.starts_with("gpt-") ||
|
||||
model.starts_with("o1-") ||
|
||||
model.starts_with("o2-") ||
|
||||
model.starts_with("o3-") ||
|
||||
model.starts_with("o4-") ||
|
||||
model.starts_with("o5-") ||
|
||||
model.contains("gpt-5")
|
||||
}
|
||||
|
||||
fn supports_multimodal(&self) -> bool {
|
||||
@@ -52,8 +59,22 @@ impl super::Provider for OpenAIProvider {
|
||||
}
|
||||
|
||||
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
||||
// Allow proactive routing to Responses API based on heuristic
|
||||
let model_lc = request.model.to_lowercase();
|
||||
if model_lc.contains("gpt-5") || model_lc.contains("codex") {
|
||||
return self.chat_responses(request).await;
|
||||
}
|
||||
|
||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
||||
let body = helpers::build_openai_body(&request, messages_json, false);
|
||||
let mut body = helpers::build_openai_body(&request, messages_json, false);
|
||||
|
||||
// Transition: Newer OpenAI models (o1, o3, gpt-5) require max_completion_tokens
|
||||
// instead of the legacy max_tokens parameter.
|
||||
if request.model.starts_with("o1-") || request.model.starts_with("o3-") || request.model.contains("gpt-5") {
|
||||
if let Some(max_tokens) = body.as_object_mut().and_then(|obj| obj.remove("max_tokens")) {
|
||||
body["max_completion_tokens"] = max_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
let response = self
|
||||
.client
|
||||
@@ -65,108 +86,17 @@ impl super::Provider for OpenAIProvider {
|
||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
|
||||
// Read error body to diagnose. If the model requires the Responses
|
||||
// API (v1/responses), retry against that endpoint.
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
if error_text.to_lowercase().contains("v1/responses") || error_text.to_lowercase().contains("only supported in v1/responses") {
|
||||
// Build a simple `input` string by concatenating message parts.
|
||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
||||
let mut inputs: Vec<String> = Vec::new();
|
||||
for m in &messages_json {
|
||||
let role = m["role"].as_str().unwrap_or("");
|
||||
let parts = m.get("content").and_then(|c| c.as_array()).cloned().unwrap_or_default();
|
||||
let mut text_parts = Vec::new();
|
||||
for p in parts {
|
||||
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
|
||||
text_parts.push(t.to_string());
|
||||
}
|
||||
}
|
||||
inputs.push(format!("{}: {}", role, text_parts.join("")));
|
||||
}
|
||||
let input_text = inputs.join("\n");
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(format!("{}/responses", self.config.base_url))
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.json(&serde_json::json!({ "model": request.model, "input": input_text }))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let err = resp.text().await.unwrap_or_default();
|
||||
return Err(AppError::ProviderError(format!("OpenAI Responses API error: {}", err)));
|
||||
return self.chat_responses(request).await;
|
||||
}
|
||||
|
||||
let resp_json: serde_json::Value = resp.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||
// Try to normalize: if it's chat-style, use existing parser
|
||||
if resp_json.get("choices").is_some() {
|
||||
return helpers::parse_openai_response(&resp_json, request.model);
|
||||
}
|
||||
|
||||
// Responses API: try to extract text from `output` or `candidates`
|
||||
// output -> [{"content": [{"type":..., "text": "..."}, ...]}]
|
||||
let mut content_text = String::new();
|
||||
if let Some(output) = resp_json.get("output").and_then(|o| o.as_array()) {
|
||||
if let Some(first) = output.get(0) {
|
||||
if let Some(contents) = first.get("content").and_then(|c| c.as_array()) {
|
||||
for item in contents {
|
||||
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
|
||||
if !content_text.is_empty() {
|
||||
content_text.push_str("\n");
|
||||
}
|
||||
content_text.push_str(text);
|
||||
} else if let Some(parts) = item.get("parts").and_then(|p| p.as_array()) {
|
||||
for p in parts {
|
||||
if let Some(t) = p.as_str() {
|
||||
if !content_text.is_empty() { content_text.push_str("\n"); }
|
||||
content_text.push_str(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: check `candidates` -> candidate.content.parts.text
|
||||
if content_text.is_empty() {
|
||||
if let Some(cands) = resp_json.get("candidates").and_then(|c| c.as_array()) {
|
||||
if let Some(c0) = cands.get(0) {
|
||||
if let Some(content) = c0.get("content") {
|
||||
if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) {
|
||||
for p in parts {
|
||||
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
|
||||
if !content_text.is_empty() { content_text.push_str("\n"); }
|
||||
content_text.push_str(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract simple usage if present
|
||||
let prompt_tokens = resp_json.get("usage").and_then(|u| u.get("prompt_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
|
||||
let completion_tokens = resp_json.get("usage").and_then(|u| u.get("completion_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
|
||||
let total_tokens = resp_json.get("usage").and_then(|u| u.get("total_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
|
||||
|
||||
return Ok(ProviderResponse {
|
||||
content: content_text,
|
||||
reasoning_content: None,
|
||||
tool_calls: None,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
cache_read_tokens: 0,
|
||||
cache_write_tokens: 0,
|
||||
model: request.model,
|
||||
});
|
||||
}
|
||||
|
||||
return Err(AppError::ProviderError(format!("OpenAI API error: {}", error_text)));
|
||||
tracing::error!("OpenAI API error ({}): {}", status, error_text);
|
||||
return Err(AppError::ProviderError(format!("OpenAI API error ({}): {}", status, error_text)));
|
||||
}
|
||||
|
||||
let resp_json: serde_json::Value = response
|
||||
@@ -178,27 +108,80 @@ impl super::Provider for OpenAIProvider {
|
||||
}
|
||||
|
||||
async fn chat_responses(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
||||
// Build a simple `input` string by concatenating message parts.
|
||||
// Build a structured input for the Responses API.
|
||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
||||
let mut inputs: Vec<String> = Vec::new();
|
||||
let mut input_parts = Vec::new();
|
||||
for m in &messages_json {
|
||||
let role = m["role"].as_str().unwrap_or("");
|
||||
let parts = m.get("content").and_then(|c| c.as_array()).cloned().unwrap_or_default();
|
||||
let mut text_parts = Vec::new();
|
||||
for p in parts {
|
||||
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
|
||||
text_parts.push(t.to_string());
|
||||
let mut role = m["role"].as_str().unwrap_or("user").to_string();
|
||||
// Newer models (gpt-5, o1) prefer "developer" over "system"
|
||||
if role == "system" {
|
||||
role = "developer".to_string();
|
||||
}
|
||||
|
||||
let mut content = m.get("content").cloned().unwrap_or(serde_json::json!([]));
|
||||
|
||||
// Map content types based on role for Responses API
|
||||
if let Some(content_array) = content.as_array_mut() {
|
||||
for part in content_array {
|
||||
if let Some(part_obj) = part.as_object_mut() {
|
||||
if let Some(t) = part_obj.get("type").and_then(|v| v.as_str()) {
|
||||
match t {
|
||||
"text" => {
|
||||
let new_type = if role == "assistant" { "output_text" } else { "input_text" };
|
||||
part_obj.insert("type".to_string(), serde_json::json!(new_type));
|
||||
}
|
||||
"image_url" => {
|
||||
// Assistant typically doesn't have image_url in history this way, but for safety:
|
||||
let new_type = if role == "assistant" { "output_image" } else { "input_image" };
|
||||
part_obj.insert("type".to_string(), serde_json::json!(new_type));
|
||||
if let Some(img_url) = part_obj.remove("image_url") {
|
||||
part_obj.insert("image".to_string(), img_url);
|
||||
}
|
||||
}
|
||||
inputs.push(format!("{}: {}", role, text_parts.join("")));
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if let Some(text) = content.as_str() {
|
||||
let new_type = if role == "assistant" { "output_text" } else { "input_text" };
|
||||
content = serde_json::json!([{ "type": new_type, "text": text }]);
|
||||
}
|
||||
|
||||
input_parts.push(serde_json::json!({
|
||||
"role": role,
|
||||
"content": content
|
||||
}));
|
||||
}
|
||||
|
||||
let mut body = serde_json::json!({
|
||||
"model": request.model,
|
||||
"input": input_parts,
|
||||
});
|
||||
|
||||
// Add standard parameters
|
||||
if let Some(temp) = request.temperature {
|
||||
body["temperature"] = serde_json::json!(temp);
|
||||
}
|
||||
|
||||
// Newer models (gpt-5, o1) in Responses API use max_output_tokens
|
||||
if let Some(max_tokens) = request.max_tokens {
|
||||
if request.model.contains("gpt-5") || request.model.starts_with("o1-") || request.model.starts_with("o3-") {
|
||||
body["max_output_tokens"] = serde_json::json!(max_tokens);
|
||||
} else {
|
||||
body["max_tokens"] = serde_json::json!(max_tokens);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(tools) = &request.tools {
|
||||
body["tools"] = serde_json::json!(tools);
|
||||
}
|
||||
let input_text = inputs.join("\n");
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(format!("{}/responses", self.config.base_url))
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.json(&serde_json::json!({ "model": request.model, "input": input_text }))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||
@@ -210,11 +193,16 @@ impl super::Provider for OpenAIProvider {
|
||||
|
||||
let resp_json: serde_json::Value = resp.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||
|
||||
// Try to normalize: if it's chat-style, use existing parser
|
||||
if resp_json.get("choices").is_some() {
|
||||
return helpers::parse_openai_response(&resp_json, request.model);
|
||||
}
|
||||
|
||||
// Normalize Responses API output into ProviderResponse
|
||||
let mut content_text = String::new();
|
||||
if let Some(output) = resp_json.get("output").and_then(|o| o.as_array()) {
|
||||
if let Some(first) = output.get(0) {
|
||||
if let Some(contents) = first.get("content").and_then(|c| c.as_array()) {
|
||||
for out in output {
|
||||
if let Some(contents) = out.get("content").and_then(|c| c.as_array()) {
|
||||
for item in contents {
|
||||
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
|
||||
if !content_text.is_empty() { content_text.push_str("\n"); }
|
||||
@@ -259,6 +247,7 @@ impl super::Provider for OpenAIProvider {
|
||||
tool_calls: None,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
reasoning_tokens: 0,
|
||||
total_tokens,
|
||||
cache_read_tokens: 0,
|
||||
cache_write_tokens: 0,
|
||||
@@ -296,47 +285,280 @@ impl super::Provider for OpenAIProvider {
|
||||
&self,
|
||||
request: UnifiedRequest,
|
||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
||||
let body = helpers::build_openai_body(&request, messages_json, true);
|
||||
// Allow proactive routing to Responses API based on heuristic
|
||||
let model_lc = request.model.to_lowercase();
|
||||
if model_lc.contains("gpt-5") || model_lc.contains("codex") {
|
||||
return self.chat_responses_stream(request).await;
|
||||
}
|
||||
|
||||
// Try to create an EventSource for streaming; if creation fails or
|
||||
// the stream errors, fall back to a single synchronous request and
|
||||
// emit its result as a single chunk.
|
||||
let es_result = reqwest_eventsource::EventSource::new(
|
||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
||||
let mut body = helpers::build_openai_body(&request, messages_json, true);
|
||||
|
||||
// Standard OpenAI cleanup
|
||||
if let Some(obj) = body.as_object_mut() {
|
||||
// stream_options.include_usage is supported by OpenAI for token usage in streaming
|
||||
// Transition: Newer OpenAI models (o1, o3, gpt-5) require max_completion_tokens
|
||||
if request.model.starts_with("o1-") || request.model.starts_with("o3-") || request.model.contains("gpt-5") {
|
||||
if let Some(max_tokens) = obj.remove("max_tokens") {
|
||||
obj.insert("max_completion_tokens".to_string(), max_tokens);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let url = format!("{}/chat/completions", self.config.base_url);
|
||||
let api_key = self.api_key.clone();
|
||||
let probe_client = self.client.clone();
|
||||
let probe_body = body.clone();
|
||||
let model = request.model.clone();
|
||||
|
||||
let es = reqwest_eventsource::EventSource::new(
|
||||
self.client
|
||||
.post(format!("{}/chat/completions", self.config.base_url))
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.json(&body),
|
||||
);
|
||||
)
|
||||
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||
|
||||
if es_result.is_err() {
|
||||
// Fallback to non-streaming request which itself may retry to
|
||||
// Responses API if necessary (handled in chat_completion).
|
||||
let resp = self.chat_completion(request.clone()).await?;
|
||||
let single_stream = async_stream::try_stream! {
|
||||
let chunk = ProviderStreamChunk {
|
||||
content: resp.content,
|
||||
reasoning_content: resp.reasoning_content,
|
||||
finish_reason: Some("stop".to_string()),
|
||||
tool_calls: None,
|
||||
model: resp.model.clone(),
|
||||
usage: Some(super::StreamUsage {
|
||||
prompt_tokens: resp.prompt_tokens,
|
||||
completion_tokens: resp.completion_tokens,
|
||||
total_tokens: resp.total_tokens,
|
||||
cache_read_tokens: resp.cache_read_tokens,
|
||||
cache_write_tokens: resp.cache_write_tokens,
|
||||
}),
|
||||
};
|
||||
|
||||
yield chunk;
|
||||
};
|
||||
|
||||
return Ok(Box::pin(single_stream));
|
||||
let stream = async_stream::try_stream! {
|
||||
let mut es = es;
|
||||
while let Some(event) = es.next().await {
|
||||
match event {
|
||||
Ok(reqwest_eventsource::Event::Message(msg)) => {
|
||||
if msg.data == "[DONE]" {
|
||||
break;
|
||||
}
|
||||
|
||||
let es = es_result.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||
let chunk: serde_json::Value = serde_json::from_str(&msg.data)
|
||||
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
||||
|
||||
Ok(helpers::create_openai_stream(es, request.model, None))
|
||||
if let Some(p_chunk) = helpers::parse_openai_stream_chunk(&chunk, &model, None) {
|
||||
yield p_chunk?;
|
||||
}
|
||||
}
|
||||
Ok(_) => continue,
|
||||
Err(e) => {
|
||||
// Attempt to probe for the actual error body
|
||||
let probe_resp = probe_client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.json(&probe_body)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match probe_resp {
|
||||
Ok(r) if !r.status().is_success() => {
|
||||
let status = r.status();
|
||||
let error_body = r.text().await.unwrap_or_default();
|
||||
tracing::error!("OpenAI Stream Error Probe ({}): {}", status, error_body);
|
||||
tracing::debug!("Offending OpenAI Request Body: {}", serde_json::to_string(&probe_body).unwrap_or_default());
|
||||
Err(AppError::ProviderError(format!("OpenAI API error ({}): {}", status, error_body)))?;
|
||||
}
|
||||
Ok(_) => {
|
||||
// Probe returned success? This is unexpected if the original stream failed.
|
||||
Err(AppError::ProviderError(format!("Stream error (probe returned 200): {}", e)))?;
|
||||
}
|
||||
Err(probe_err) => {
|
||||
// Probe itself failed
|
||||
tracing::error!("OpenAI Stream Error Probe failed: {}", probe_err);
|
||||
Err(AppError::ProviderError(format!("Stream error (probe failed: {}): {}", probe_err, e)))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
|
||||
async fn chat_responses_stream(
|
||||
&self,
|
||||
request: UnifiedRequest,
|
||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||
// Build a structured input for the Responses API.
|
||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
||||
let mut input_parts = Vec::new();
|
||||
for m in &messages_json {
|
||||
let mut role = m["role"].as_str().unwrap_or("user").to_string();
|
||||
// Newer models (gpt-5, o1) prefer "developer" over "system"
|
||||
if role == "system" {
|
||||
role = "developer".to_string();
|
||||
}
|
||||
|
||||
let mut content = m.get("content").cloned().unwrap_or(serde_json::json!([]));
|
||||
|
||||
// Map content types based on role for Responses API
|
||||
if let Some(content_array) = content.as_array_mut() {
|
||||
for part in content_array {
|
||||
if let Some(part_obj) = part.as_object_mut() {
|
||||
if let Some(t) = part_obj.get("type").and_then(|v| v.as_str()) {
|
||||
match t {
|
||||
"text" => {
|
||||
let new_type = if role == "assistant" { "output_text" } else { "input_text" };
|
||||
part_obj.insert("type".to_string(), serde_json::json!(new_type));
|
||||
}
|
||||
"image_url" => {
|
||||
// Assistant typically doesn't have image_url in history this way, but for safety:
|
||||
let new_type = if role == "assistant" { "output_image" } else { "input_image" };
|
||||
part_obj.insert("type".to_string(), serde_json::json!(new_type));
|
||||
if let Some(img_url) = part_obj.remove("image_url") {
|
||||
part_obj.insert("image".to_string(), img_url);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if let Some(text) = content.as_str() {
|
||||
let new_type = if role == "assistant" { "output_text" } else { "input_text" };
|
||||
content = serde_json::json!([{ "type": new_type, "text": text }]);
|
||||
}
|
||||
|
||||
input_parts.push(serde_json::json!({
|
||||
"role": role,
|
||||
"content": content
|
||||
}));
|
||||
}
|
||||
|
||||
let mut body = serde_json::json!({
|
||||
"model": request.model,
|
||||
"input": input_parts,
|
||||
"stream": true,
|
||||
});
|
||||
|
||||
// Add standard parameters
|
||||
if let Some(temp) = request.temperature {
|
||||
body["temperature"] = serde_json::json!(temp);
|
||||
}
|
||||
|
||||
// Newer models (gpt-5, o1) in Responses API use max_output_tokens
|
||||
if let Some(max_tokens) = request.max_tokens {
|
||||
if request.model.contains("gpt-5") || request.model.starts_with("o1-") || request.model.starts_with("o3-") {
|
||||
body["max_output_tokens"] = serde_json::json!(max_tokens);
|
||||
} else {
|
||||
body["max_tokens"] = serde_json::json!(max_tokens);
|
||||
}
|
||||
}
|
||||
|
||||
let url = format!("{}/responses", self.config.base_url);
|
||||
let api_key = self.api_key.clone();
|
||||
let model = request.model.clone();
|
||||
let probe_client = self.client.clone();
|
||||
let probe_body = body.clone();
|
||||
|
||||
let es = reqwest_eventsource::EventSource::new(
|
||||
self.client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.header("Accept", "text/event-stream")
|
||||
.json(&body),
|
||||
)
|
||||
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource for Responses API: {}", e)))?;
|
||||
|
||||
let stream = async_stream::try_stream! {
|
||||
let mut es = es;
|
||||
while let Some(event) = es.next().await {
|
||||
match event {
|
||||
Ok(reqwest_eventsource::Event::Message(msg)) => {
|
||||
if msg.data == "[DONE]" {
|
||||
break;
|
||||
}
|
||||
|
||||
let chunk: serde_json::Value = serde_json::from_str(&msg.data)
|
||||
.map_err(|e| AppError::ProviderError(format!("Failed to parse Responses stream chunk: {}", e)))?;
|
||||
|
||||
// Try standard OpenAI parsing first (choices/usage)
|
||||
if let Some(p_chunk) = helpers::parse_openai_stream_chunk(&chunk, &model, None) {
|
||||
yield p_chunk?;
|
||||
} else {
|
||||
// Responses API specific parsing for streaming
|
||||
let mut content = String::new();
|
||||
let mut finish_reason = None;
|
||||
|
||||
let event_type = chunk.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||||
|
||||
match event_type {
|
||||
"response.output_text.delta" => {
|
||||
if let Some(delta) = chunk.get("delta").and_then(|v| v.as_str()) {
|
||||
content.push_str(delta);
|
||||
}
|
||||
}
|
||||
"response.output_text.done" => {
|
||||
if let Some(text) = chunk.get("text").and_then(|v| v.as_str()) {
|
||||
// Some implementations send the full text at the end
|
||||
// We usually prefer deltas, but if we haven't seen them, this is the fallback.
|
||||
// However, if we're already yielding deltas, we might not want this.
|
||||
// For now, let's just use it as a signal that we're done.
|
||||
finish_reason = Some("stop".to_string());
|
||||
}
|
||||
}
|
||||
"response.done" => {
|
||||
finish_reason = Some("stop".to_string());
|
||||
}
|
||||
_ => {
|
||||
// Fallback to older nested structure if present
|
||||
if let Some(output) = chunk.get("output").and_then(|o| o.as_array()) {
|
||||
for out in output {
|
||||
if let Some(contents) = out.get("content").and_then(|c| c.as_array()) {
|
||||
for item in contents {
|
||||
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
|
||||
content.push_str(text);
|
||||
} else if let Some(delta) = item.get("delta").and_then(|d| d.get("text")).and_then(|t| t.as_str()) {
|
||||
content.push_str(delta);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !content.is_empty() || finish_reason.is_some() {
|
||||
yield ProviderStreamChunk {
|
||||
content,
|
||||
reasoning_content: None,
|
||||
finish_reason,
|
||||
tool_calls: None,
|
||||
model: model.clone(),
|
||||
usage: None,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(_) => continue,
|
||||
Err(e) => {
|
||||
// Attempt to probe for the actual error body
|
||||
let probe_resp = probe_client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.header("Accept", "application/json") // Ask for JSON during probe
|
||||
.json(&probe_body)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match probe_resp {
|
||||
Ok(r) => {
|
||||
let status = r.status();
|
||||
let body = r.text().await.unwrap_or_default();
|
||||
if status.is_success() {
|
||||
tracing::warn!("Responses stream ended prematurely but probe returned 200 OK. Body: {}", body);
|
||||
Err(AppError::ProviderError(format!("Responses stream ended (server sent 200 OK with body: {})", body)))?;
|
||||
} else {
|
||||
tracing::error!("OpenAI Responses Stream Error Probe ({}): {}", status, body);
|
||||
Err(AppError::ProviderError(format!("OpenAI Responses API error ({}): {}", status, body)))?;
|
||||
}
|
||||
}
|
||||
Err(probe_err) => {
|
||||
tracing::error!("OpenAI Responses Stream Error Probe failed: {}", probe_err);
|
||||
Err(AppError::ProviderError(format!("Responses stream error (probe failed: {}): {}", probe_err, e)))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,12 +6,15 @@
|
||||
//! 3. Global rate limiting for overall system protection
|
||||
|
||||
use anyhow::Result;
|
||||
use governor::{Quota, RateLimiter, DefaultDirectRateLimiter};
|
||||
use std::collections::HashMap;
|
||||
use std::num::NonZeroU32;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{info, warn};
|
||||
|
||||
type GovRateLimiter = DefaultDirectRateLimiter;
|
||||
|
||||
/// Rate limiter configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RateLimiterConfig {
|
||||
@@ -65,45 +68,7 @@ impl Default for CircuitBreakerConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple token bucket rate limiter for a single client
|
||||
#[derive(Debug)]
|
||||
struct TokenBucket {
|
||||
tokens: f64,
|
||||
capacity: f64,
|
||||
refill_rate: f64, // tokens per second
|
||||
last_refill: Instant,
|
||||
}
|
||||
|
||||
impl TokenBucket {
|
||||
fn new(capacity: f64, refill_rate: f64) -> Self {
|
||||
Self {
|
||||
tokens: capacity,
|
||||
capacity,
|
||||
refill_rate,
|
||||
last_refill: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
fn refill(&mut self) {
|
||||
let now = Instant::now();
|
||||
let elapsed = now.duration_since(self.last_refill).as_secs_f64();
|
||||
let new_tokens = elapsed * self.refill_rate;
|
||||
|
||||
self.tokens = (self.tokens + new_tokens).min(self.capacity);
|
||||
self.last_refill = now;
|
||||
}
|
||||
|
||||
fn try_acquire(&mut self, tokens: f64) -> bool {
|
||||
self.refill();
|
||||
|
||||
if self.tokens >= tokens {
|
||||
self.tokens -= tokens;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Circuit breaker for a provider
|
||||
#[derive(Debug)]
|
||||
@@ -209,8 +174,8 @@ impl ProviderCircuitBreaker {
|
||||
/// Rate limiting and circuit breaking manager
|
||||
#[derive(Debug)]
|
||||
pub struct RateLimitManager {
|
||||
client_buckets: Arc<RwLock<HashMap<String, TokenBucket>>>,
|
||||
global_bucket: Arc<RwLock<TokenBucket>>,
|
||||
client_buckets: Arc<RwLock<HashMap<String, GovRateLimiter>>>,
|
||||
global_bucket: Arc<GovRateLimiter>,
|
||||
circuit_breakers: Arc<RwLock<HashMap<String, ProviderCircuitBreaker>>>,
|
||||
config: RateLimiterConfig,
|
||||
circuit_config: CircuitBreakerConfig,
|
||||
@@ -218,15 +183,18 @@ pub struct RateLimitManager {
|
||||
|
||||
impl RateLimitManager {
|
||||
pub fn new(config: RateLimiterConfig, circuit_config: CircuitBreakerConfig) -> Self {
|
||||
// Convert requests per minute to tokens per second
|
||||
let global_refill_rate = config.global_requests_per_minute as f64 / 60.0;
|
||||
// Create global rate limiter quota
|
||||
// Use a much larger burst size for the global bucket to handle concurrent dashboard load
|
||||
let global_burst = config.global_requests_per_minute / 6; // e.g., 100 for 600 req/min
|
||||
let global_quota = Quota::per_minute(
|
||||
NonZeroU32::new(config.global_requests_per_minute).expect("global_requests_per_minute must be positive")
|
||||
)
|
||||
.allow_burst(NonZeroU32::new(global_burst).expect("global_burst must be positive"));
|
||||
let global_bucket = RateLimiter::direct(global_quota);
|
||||
|
||||
Self {
|
||||
client_buckets: Arc::new(RwLock::new(HashMap::new())),
|
||||
global_bucket: Arc::new(RwLock::new(TokenBucket::new(
|
||||
config.burst_size as f64,
|
||||
global_refill_rate,
|
||||
))),
|
||||
global_bucket: Arc::new(global_bucket),
|
||||
circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
|
||||
config,
|
||||
circuit_config,
|
||||
@@ -236,24 +204,22 @@ impl RateLimitManager {
|
||||
/// Check if a client request is allowed
|
||||
pub async fn check_client_request(&self, client_id: &str) -> Result<bool> {
|
||||
// Check global rate limit first (1 token per request)
|
||||
{
|
||||
let mut global_bucket = self.global_bucket.write().await;
|
||||
if !global_bucket.try_acquire(1.0) {
|
||||
if self.global_bucket.check().is_err() {
|
||||
warn!("Global rate limit exceeded");
|
||||
return Ok(false);
|
||||
}
|
||||
}
|
||||
|
||||
// Check client-specific rate limit
|
||||
let mut buckets = self.client_buckets.write().await;
|
||||
let bucket = buckets.entry(client_id.to_string()).or_insert_with(|| {
|
||||
TokenBucket::new(
|
||||
self.config.burst_size as f64,
|
||||
self.config.requests_per_minute as f64 / 60.0,
|
||||
let quota = Quota::per_minute(
|
||||
NonZeroU32::new(self.config.requests_per_minute).expect("requests_per_minute must be positive")
|
||||
)
|
||||
.allow_burst(NonZeroU32::new(self.config.burst_size).expect("burst_size must be positive"));
|
||||
RateLimiter::direct(quota)
|
||||
});
|
||||
|
||||
Ok(bucket.try_acquire(1.0))
|
||||
Ok(bucket.check().is_ok())
|
||||
}
|
||||
|
||||
/// Check if provider requests are allowed (circuit breaker)
|
||||
@@ -299,6 +265,7 @@ pub mod middleware {
|
||||
use super::*;
|
||||
use crate::errors::AppError;
|
||||
use crate::state::AppState;
|
||||
use crate::auth::AuthInfo;
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
middleware::Next,
|
||||
@@ -309,20 +276,24 @@ pub mod middleware {
|
||||
/// Rate limiting middleware
|
||||
pub async fn rate_limit_middleware(
|
||||
State(state): State<AppState>,
|
||||
request: Request,
|
||||
mut request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, AppError> {
|
||||
// Extract token synchronously from headers (avoids holding &Request across await)
|
||||
let token = extract_bearer_token(&request);
|
||||
|
||||
// Resolve client_id: DB token lookup, then prefix fallback
|
||||
let client_id = resolve_client_id(token, &state).await;
|
||||
// Resolve client_id and populate AuthInfo: DB token lookup, then prefix fallback
|
||||
let auth_info = resolve_auth_info(token, &state).await;
|
||||
let client_id = auth_info.client_id.clone();
|
||||
|
||||
// Check rate limits
|
||||
if !state.rate_limit_manager.check_client_request(&client_id).await? {
|
||||
return Err(AppError::RateLimitError("Rate limit exceeded".to_string()));
|
||||
}
|
||||
|
||||
// Store AuthInfo in request extensions for extractors and downstream handlers
|
||||
request.extensions_mut().insert(auth_info);
|
||||
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
|
||||
@@ -334,26 +305,39 @@ pub mod middleware {
|
||||
.map(|t| t.to_string())
|
||||
}
|
||||
|
||||
/// Resolve client ID: try DB token first, then fall back to token-prefix derivation
|
||||
async fn resolve_client_id(token: Option<String>, state: &AppState) -> String {
|
||||
/// Resolve auth info: try DB token first, then fall back to token-prefix derivation
|
||||
async fn resolve_auth_info(token: Option<String>, state: &AppState) -> AuthInfo {
|
||||
if let Some(token) = token {
|
||||
// Try DB token lookup first
|
||||
if let Ok(Some(cid)) = sqlx::query_scalar::<_, String>(
|
||||
"SELECT client_id FROM client_tokens WHERE token = ? AND is_active = TRUE",
|
||||
match sqlx::query_scalar::<_, String>(
|
||||
"UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = TRUE RETURNING client_id",
|
||||
)
|
||||
.bind(&token)
|
||||
.fetch_optional(&state.db_pool)
|
||||
.await
|
||||
{
|
||||
return cid;
|
||||
Ok(Some(cid)) => {
|
||||
return AuthInfo {
|
||||
token,
|
||||
client_id: cid,
|
||||
};
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("DB error during token lookup: {}", e);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Fallback to token-prefix derivation (env tokens / permissive mode)
|
||||
return format!("client_{}", &token[..8.min(token.len())]);
|
||||
let client_id = format!("client_{}", &token[..8.min(token.len())]);
|
||||
return AuthInfo { token, client_id };
|
||||
}
|
||||
|
||||
// No token — anonymous
|
||||
"anonymous".to_string()
|
||||
AuthInfo {
|
||||
token: String::new(),
|
||||
client_id: "anonymous".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Circuit breaker middleware for provider requests
|
||||
|
||||
@@ -5,9 +5,13 @@ use axum::{
|
||||
response::sse::{Event, Sse},
|
||||
routing::{get, post},
|
||||
};
|
||||
use axum::http::{header, HeaderValue};
|
||||
use tower_http::{
|
||||
limit::RequestBodyLimitLayer,
|
||||
set_header::SetResponseHeaderLayer,
|
||||
};
|
||||
|
||||
use futures::StreamExt;
|
||||
use sqlx;
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
use tracing::{info, warn};
|
||||
@@ -24,9 +28,34 @@ use crate::{
|
||||
};
|
||||
|
||||
pub fn router(state: AppState) -> Router {
|
||||
// Security headers
|
||||
let csp_header: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
|
||||
header::CONTENT_SECURITY_POLICY,
|
||||
"default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws:;"
|
||||
.parse()
|
||||
.unwrap(),
|
||||
);
|
||||
let x_frame_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
|
||||
header::X_FRAME_OPTIONS,
|
||||
"DENY".parse().unwrap(),
|
||||
);
|
||||
let x_content_type_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
|
||||
header::X_CONTENT_TYPE_OPTIONS,
|
||||
"nosniff".parse().unwrap(),
|
||||
);
|
||||
let strict_transport_security: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
|
||||
header::STRICT_TRANSPORT_SECURITY,
|
||||
"max-age=31536000; includeSubDomains".parse().unwrap(),
|
||||
);
|
||||
|
||||
Router::new()
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
.route("/v1/models", get(list_models))
|
||||
.layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit
|
||||
.layer(csp_header)
|
||||
.layer(x_frame_options)
|
||||
.layer(x_content_type_options)
|
||||
.layer(strict_transport_security)
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
rate_limiting::middleware::rate_limit_middleware,
|
||||
@@ -108,8 +137,23 @@ async fn get_model_cost(
|
||||
// Check in-memory cache for cost overrides (no SQLite hit)
|
||||
if let Some(cached) = state.model_config_cache.get(model).await {
|
||||
if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
|
||||
// Manual overrides don't have cache-specific rates, so use simple formula
|
||||
return (prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0);
|
||||
// Manual overrides logic: if cache rates are provided, use cache-aware formula.
|
||||
// Formula: (non_cached_prompt * input_rate) + (cache_read * read_rate) + (cache_write * write_rate) + (completion * output_rate)
|
||||
let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
|
||||
let mut total = (non_cached_prompt as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0);
|
||||
|
||||
if let Some(cr) = cached.cache_read_cost_per_m {
|
||||
total += cache_read_tokens as f64 * cr / 1_000_000.0;
|
||||
} else {
|
||||
// No manual cache_read rate — charge cached tokens at full input rate (backwards compatibility)
|
||||
total += cache_read_tokens as f64 * p / 1_000_000.0;
|
||||
}
|
||||
|
||||
if let Some(cw) = cached.cache_write_cost_per_m {
|
||||
total += cache_write_tokens as f64 * cw / 1_000_000.0;
|
||||
}
|
||||
|
||||
return total;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,32 +166,16 @@ async fn chat_completions(
|
||||
auth: AuthenticatedClient,
|
||||
Json(mut request): Json<ChatCompletionRequest>,
|
||||
) -> Result<axum::response::Response, AppError> {
|
||||
// Resolve client_id: try DB token first, then env tokens, then permissive fallback
|
||||
let db_client_id: Option<String> = sqlx::query_scalar::<_, String>(
|
||||
"SELECT client_id FROM client_tokens WHERE token = ? AND is_active = TRUE",
|
||||
)
|
||||
.bind(&auth.token)
|
||||
.fetch_optional(&state.db_pool)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
let client_id = if let Some(cid) = db_client_id {
|
||||
// Update last_used_at in background (fire-and-forget)
|
||||
let pool = state.db_pool.clone();
|
||||
let client_id = auth.client_id.clone();
|
||||
let token = auth.token.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = sqlx::query("UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ?")
|
||||
.bind(&token)
|
||||
.execute(&pool)
|
||||
.await;
|
||||
});
|
||||
cid
|
||||
} else if state.auth_tokens.is_empty() || state.auth_tokens.contains(&auth.token) {
|
||||
// Env token match or permissive mode (no env tokens configured)
|
||||
auth.client_id.clone()
|
||||
} else {
|
||||
|
||||
// Verify token if env tokens are configured
|
||||
if !state.auth_tokens.is_empty() && !state.auth_tokens.contains(&token) {
|
||||
// If not in env tokens, check if it was a DB token (client_id wouldn't be client_XXXX prefix)
|
||||
if client_id.starts_with("client_") {
|
||||
return Err(AppError::AuthError("Invalid authentication token".to_string()));
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
let start_time = std::time::Instant::now();
|
||||
let model = request.model.clone();
|
||||
@@ -213,7 +241,15 @@ async fn chat_completions(
|
||||
let prompt_tokens = crate::utils::tokens::estimate_request_tokens(&model, &unified_request);
|
||||
|
||||
// Handle streaming response
|
||||
let stream_result = provider.chat_completion_stream(unified_request).await;
|
||||
// Allow provider-specific routing for streaming too
|
||||
let use_responses = provider.name() == "openai"
|
||||
&& crate::utils::registry::model_prefers_responses(&state.model_registry, &unified_request.model);
|
||||
|
||||
let stream_result = if use_responses {
|
||||
provider.chat_responses_stream(unified_request).await
|
||||
} else {
|
||||
provider.chat_completion_stream(unified_request).await
|
||||
};
|
||||
|
||||
match stream_result {
|
||||
Ok(stream) => {
|
||||
@@ -236,7 +272,6 @@ async fn chat_completions(
|
||||
prompt_tokens,
|
||||
has_images,
|
||||
logger: state.request_logger.clone(),
|
||||
client_manager: state.client_manager.clone(),
|
||||
model_registry: state.model_registry.clone(),
|
||||
model_config_cache: state.model_config_cache.clone(),
|
||||
},
|
||||
@@ -277,6 +312,14 @@ async fn chat_completions(
|
||||
},
|
||||
finish_reason: chunk.finish_reason,
|
||||
}],
|
||||
usage: chunk.usage.as_ref().map(|u| crate::models::Usage {
|
||||
prompt_tokens: u.prompt_tokens,
|
||||
completion_tokens: u.completion_tokens,
|
||||
total_tokens: u.total_tokens,
|
||||
reasoning_tokens: if u.reasoning_tokens > 0 { Some(u.reasoning_tokens) } else { None },
|
||||
cache_read_tokens: if u.cache_read_tokens > 0 { Some(u.cache_read_tokens) } else { None },
|
||||
cache_write_tokens: if u.cache_write_tokens > 0 { Some(u.cache_write_tokens) } else { None },
|
||||
}),
|
||||
};
|
||||
|
||||
// Use axum's Event directly, wrap in Ok
|
||||
@@ -348,6 +391,7 @@ async fn chat_completions(
|
||||
model: response.model.clone(),
|
||||
prompt_tokens: response.prompt_tokens,
|
||||
completion_tokens: response.completion_tokens,
|
||||
reasoning_tokens: response.reasoning_tokens,
|
||||
total_tokens: response.total_tokens,
|
||||
cache_read_tokens: response.cache_read_tokens,
|
||||
cache_write_tokens: response.cache_write_tokens,
|
||||
@@ -358,15 +402,6 @@ async fn chat_completions(
|
||||
duration_ms: duration.as_millis() as u64,
|
||||
});
|
||||
|
||||
// Update client usage (fire-and-forget, don't block response)
|
||||
{
|
||||
let cm = state.client_manager.clone();
|
||||
let cid = client_id.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = cm.update_client_usage(&cid, response.total_tokens as i64, cost).await;
|
||||
});
|
||||
}
|
||||
|
||||
// Convert ProviderResponse to ChatCompletionResponse
|
||||
let finish_reason = if response.tool_calls.is_some() {
|
||||
"tool_calls".to_string()
|
||||
@@ -397,6 +432,7 @@ async fn chat_completions(
|
||||
prompt_tokens: response.prompt_tokens,
|
||||
completion_tokens: response.completion_tokens,
|
||||
total_tokens: response.total_tokens,
|
||||
reasoning_tokens: if response.reasoning_tokens > 0 { Some(response.reasoning_tokens) } else { None },
|
||||
cache_read_tokens: if response.cache_read_tokens > 0 { Some(response.cache_read_tokens) } else { None },
|
||||
cache_write_tokens: if response.cache_write_tokens > 0 { Some(response.cache_write_tokens) } else { None },
|
||||
}),
|
||||
@@ -426,6 +462,7 @@ async fn chat_completions(
|
||||
model: model.clone(),
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
reasoning_tokens: 0,
|
||||
total_tokens: 0,
|
||||
cache_read_tokens: 0,
|
||||
cache_write_tokens: 0,
|
||||
|
||||
@@ -15,6 +15,8 @@ pub struct CachedModelConfig {
|
||||
pub mapping: Option<String>,
|
||||
pub prompt_cost_per_m: Option<f64>,
|
||||
pub completion_cost_per_m: Option<f64>,
|
||||
pub cache_read_cost_per_m: Option<f64>,
|
||||
pub cache_write_cost_per_m: Option<f64>,
|
||||
}
|
||||
|
||||
/// In-memory cache for model_configs table.
|
||||
@@ -35,15 +37,15 @@ impl ModelConfigCache {
|
||||
|
||||
/// Load all model configs from the database into cache
|
||||
pub async fn refresh(&self) {
|
||||
match sqlx::query_as::<_, (String, bool, Option<String>, Option<f64>, Option<f64>)>(
|
||||
"SELECT id, enabled, mapping, prompt_cost_per_m, completion_cost_per_m FROM model_configs",
|
||||
match sqlx::query_as::<_, (String, bool, Option<String>, Option<f64>, Option<f64>, Option<f64>, Option<f64>)>(
|
||||
"SELECT id, enabled, mapping, prompt_cost_per_m, completion_cost_per_m, cache_read_cost_per_m, cache_write_cost_per_m FROM model_configs",
|
||||
)
|
||||
.fetch_all(&self.db_pool)
|
||||
.await
|
||||
{
|
||||
Ok(rows) => {
|
||||
let mut map = HashMap::with_capacity(rows.len());
|
||||
for (id, enabled, mapping, prompt_cost, completion_cost) in rows {
|
||||
for (id, enabled, mapping, prompt_cost, completion_cost, cache_read_cost, cache_write_cost) in rows {
|
||||
map.insert(
|
||||
id,
|
||||
CachedModelConfig {
|
||||
@@ -51,6 +53,8 @@ impl ModelConfigCache {
|
||||
mapping,
|
||||
prompt_cost_per_m: prompt_cost,
|
||||
completion_cost_per_m: completion_cost,
|
||||
cache_read_cost_per_m: cache_read_cost,
|
||||
cache_write_cost_per_m: cache_write_cost,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
171
src/utils/crypto.rs
Normal file
171
src/utils/crypto.rs
Normal file
@@ -0,0 +1,171 @@
|
||||
use aes_gcm::{
|
||||
aead::{Aead, AeadCore, KeyInit, OsRng},
|
||||
Aes256Gcm, Key, Nonce,
|
||||
};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
|
||||
use std::env;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
static RAW_KEY: OnceLock<[u8; 32]> = OnceLock::new();
|
||||
|
||||
/// Initialize the encryption key from a hex or base64 encoded string.
|
||||
/// Must be called before any encryption/decryption operations.
|
||||
/// Returns error if the key is invalid or already initialized with a different key.
|
||||
pub fn init_with_key(key_str: &str) -> Result<()> {
|
||||
let key_bytes = hex::decode(key_str)
|
||||
.or_else(|_| BASE64.decode(key_str))
|
||||
.context("Encryption key must be hex or base64 encoded")?;
|
||||
if key_bytes.len() != 32 {
|
||||
anyhow::bail!(
|
||||
"Encryption key must be 32 bytes (256 bits), got {} bytes",
|
||||
key_bytes.len()
|
||||
);
|
||||
}
|
||||
let key_array: [u8; 32] = key_bytes.try_into().unwrap(); // safe due to length check
|
||||
// Check if already initialized with same key
|
||||
if let Some(existing) = RAW_KEY.get() {
|
||||
if existing == &key_array {
|
||||
// Same key already initialized, okay
|
||||
return Ok(());
|
||||
} else {
|
||||
anyhow::bail!("Encryption key already initialized with a different key");
|
||||
}
|
||||
}
|
||||
// Store raw key bytes
|
||||
RAW_KEY
|
||||
.set(key_array)
|
||||
.map_err(|_| anyhow::anyhow!("Encryption key already initialized"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Initialize the encryption key from the environment variable `LLM_PROXY__ENCRYPTION_KEY`.
|
||||
/// Must be called before any encryption/decryption operations.
|
||||
/// Panics if the environment variable is missing or invalid.
|
||||
pub fn init_from_env() -> Result<()> {
|
||||
let key_str =
|
||||
env::var("LLM_PROXY__ENCRYPTION_KEY").context("LLM_PROXY__ENCRYPTION_KEY environment variable not set")?;
|
||||
init_with_key(&key_str)
|
||||
}
|
||||
|
||||
/// Get the encryption key bytes, panicking if not initialized.
|
||||
fn get_key() -> &'static [u8; 32] {
|
||||
RAW_KEY
|
||||
.get()
|
||||
.expect("Encryption key not initialized. Call crypto::init_with_key() or crypto::init_from_env() first.")
|
||||
}
|
||||
|
||||
/// Encrypt a plaintext string and return a base64-encoded ciphertext (nonce || ciphertext || tag).
|
||||
pub fn encrypt(plaintext: &str) -> Result<String> {
|
||||
let key = Key::<Aes256Gcm>::from_slice(get_key());
|
||||
let cipher = Aes256Gcm::new(key);
|
||||
let nonce = Aes256Gcm::generate_nonce(&mut OsRng); // 12 bytes
|
||||
let ciphertext = cipher
|
||||
.encrypt(&nonce, plaintext.as_bytes())
|
||||
.map_err(|e| anyhow!("Encryption failed: {}", e))?;
|
||||
// Combine nonce and ciphertext (ciphertext already includes tag)
|
||||
let mut combined = Vec::with_capacity(nonce.len() + ciphertext.len());
|
||||
combined.extend_from_slice(&nonce);
|
||||
combined.extend_from_slice(&ciphertext);
|
||||
Ok(BASE64.encode(combined))
|
||||
}
|
||||
|
||||
/// Decrypt a base64-encoded ciphertext (nonce || ciphertext || tag) to a plaintext string.
|
||||
pub fn decrypt(ciphertext_b64: &str) -> Result<String> {
|
||||
let key = Key::<Aes256Gcm>::from_slice(get_key());
|
||||
let cipher = Aes256Gcm::new(key);
|
||||
let combined = BASE64.decode(ciphertext_b64).context("Invalid base64 ciphertext")?;
|
||||
if combined.len() < 12 {
|
||||
anyhow::bail!("Ciphertext too short");
|
||||
}
|
||||
let (nonce_bytes, ciphertext_and_tag) = combined.split_at(12);
|
||||
let nonce = Nonce::from_slice(nonce_bytes);
|
||||
let plaintext_bytes = cipher
|
||||
.decrypt(nonce, ciphertext_and_tag)
|
||||
.map_err(|e| anyhow!("Decryption failed (invalid key or corrupted ciphertext): {}", e))?;
|
||||
String::from_utf8(plaintext_bytes).context("Decrypted bytes are not valid UTF-8")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
const TEST_KEY_HEX: &str = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f";
|
||||
|
||||
#[test]
|
||||
fn test_encrypt_decrypt() {
|
||||
init_with_key(TEST_KEY_HEX).unwrap();
|
||||
let plaintext = "super secret api key";
|
||||
let ciphertext = encrypt(plaintext).unwrap();
|
||||
assert_ne!(ciphertext, plaintext);
|
||||
let decrypted = decrypt(&ciphertext).unwrap();
|
||||
assert_eq!(decrypted, plaintext);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_different_inputs_produce_different_ciphertexts() {
|
||||
init_with_key(TEST_KEY_HEX).unwrap();
|
||||
let plaintext = "same";
|
||||
let cipher1 = encrypt(plaintext).unwrap();
|
||||
let cipher2 = encrypt(plaintext).unwrap();
|
||||
assert_ne!(cipher1, cipher2, "Nonce should make ciphertexts differ");
|
||||
assert_eq!(decrypt(&cipher1).unwrap(), plaintext);
|
||||
assert_eq!(decrypt(&cipher2).unwrap(), plaintext);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_key_length() {
|
||||
let result = init_with_key("tooshort");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_init_from_env() {
|
||||
unsafe { std::env::set_var("LLM_PROXY__ENCRYPTION_KEY", TEST_KEY_HEX) };
|
||||
let result = init_from_env();
|
||||
assert!(result.is_ok());
|
||||
// Ensure encryption works
|
||||
let ciphertext = encrypt("test").unwrap();
|
||||
let decrypted = decrypt(&ciphertext).unwrap();
|
||||
assert_eq!(decrypted, "test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_missing_env_key() {
|
||||
unsafe { std::env::remove_var("LLM_PROXY__ENCRYPTION_KEY") };
|
||||
let result = init_from_env();
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_key_hex_and_base64() {
|
||||
// Hex key works
|
||||
init_with_key(TEST_KEY_HEX).unwrap();
|
||||
// Base64 key (same bytes encoded as base64)
|
||||
let base64_key = BASE64.encode(hex::decode(TEST_KEY_HEX).unwrap());
|
||||
// Re-initialization with same key (different encoding) is allowed
|
||||
let result = init_with_key(&base64_key);
|
||||
assert!(result.is_ok());
|
||||
// Encryption should still work
|
||||
let ciphertext = encrypt("test").unwrap();
|
||||
let decrypted = decrypt(&ciphertext).unwrap();
|
||||
assert_eq!(decrypted, "test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // conflicts with global state from other tests
|
||||
fn test_already_initialized() {
|
||||
init_with_key(TEST_KEY_HEX).unwrap();
|
||||
let result = init_with_key(TEST_KEY_HEX);
|
||||
assert!(result.is_ok()); // same key allowed
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore] // conflicts with global state from other tests
|
||||
fn test_already_initialized_different_key() {
|
||||
init_with_key(TEST_KEY_HEX).unwrap();
|
||||
let different_key = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e20";
|
||||
let result = init_with_key(different_key);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod crypto;
|
||||
pub mod registry;
|
||||
pub mod streaming;
|
||||
pub mod tokens;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::client::ClientManager;
|
||||
|
||||
use crate::errors::AppError;
|
||||
use crate::logging::{RequestLog, RequestLogger};
|
||||
use crate::models::ToolCall;
|
||||
@@ -18,7 +18,6 @@ pub struct StreamConfig {
|
||||
pub prompt_tokens: u32,
|
||||
pub has_images: bool,
|
||||
pub logger: Arc<RequestLogger>,
|
||||
pub client_manager: Arc<ClientManager>,
|
||||
pub model_registry: Arc<crate::models::registry::ModelRegistry>,
|
||||
pub model_config_cache: ModelConfigCache,
|
||||
}
|
||||
@@ -36,7 +35,6 @@ pub struct AggregatingStream<S> {
|
||||
/// Real usage data from the provider's final stream chunk (when available).
|
||||
real_usage: Option<StreamUsage>,
|
||||
logger: Arc<RequestLogger>,
|
||||
client_manager: Arc<ClientManager>,
|
||||
model_registry: Arc<crate::models::registry::ModelRegistry>,
|
||||
model_config_cache: ModelConfigCache,
|
||||
start_time: std::time::Instant,
|
||||
@@ -60,7 +58,6 @@ where
|
||||
accumulated_tool_calls: Vec::new(),
|
||||
real_usage: None,
|
||||
logger: config.logger,
|
||||
client_manager: config.client_manager,
|
||||
model_registry: config.model_registry,
|
||||
model_config_cache: config.model_config_cache,
|
||||
start_time: std::time::Instant::now(),
|
||||
@@ -79,7 +76,6 @@ where
|
||||
let provider_name = self.provider.name().to_string();
|
||||
let model = self.model.clone();
|
||||
let logger = self.logger.clone();
|
||||
let client_manager = self.client_manager.clone();
|
||||
let provider = self.provider.clone();
|
||||
let estimated_prompt_tokens = self.prompt_tokens;
|
||||
let has_images = self.has_images;
|
||||
@@ -100,11 +96,12 @@ where
|
||||
// Spawn a background task to log the completion
|
||||
tokio::spawn(async move {
|
||||
// Use real usage from the provider when available, otherwise fall back to estimates
|
||||
let (prompt_tokens, completion_tokens, total_tokens, cache_read_tokens, cache_write_tokens) =
|
||||
let (prompt_tokens, completion_tokens, reasoning_tokens, total_tokens, cache_read_tokens, cache_write_tokens) =
|
||||
if let Some(usage) = &real_usage {
|
||||
(
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens,
|
||||
usage.reasoning_tokens,
|
||||
usage.total_tokens,
|
||||
usage.cache_read_tokens,
|
||||
usage.cache_write_tokens,
|
||||
@@ -113,6 +110,7 @@ where
|
||||
(
|
||||
estimated_prompt_tokens,
|
||||
estimated_completion,
|
||||
estimated_reasoning_tokens,
|
||||
estimated_prompt_tokens + estimated_completion,
|
||||
0u32,
|
||||
0u32,
|
||||
@@ -122,8 +120,22 @@ where
|
||||
// Check in-memory cache for cost overrides (no SQLite hit)
|
||||
let cost = if let Some(cached) = config_cache.get(&model).await {
|
||||
if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
|
||||
// Cost override doesn't have cache-aware pricing, use simple formula
|
||||
(prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0)
|
||||
// Manual overrides logic: if cache rates are provided, use cache-aware formula.
|
||||
let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
|
||||
let mut total = (non_cached_prompt as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0);
|
||||
|
||||
if let Some(cr) = cached.cache_read_cost_per_m {
|
||||
total += cache_read_tokens as f64 * cr / 1_000_000.0;
|
||||
} else {
|
||||
// Charge cached tokens at full input rate if no specific rate provided
|
||||
total += cache_read_tokens as f64 * p / 1_000_000.0;
|
||||
}
|
||||
|
||||
if let Some(cw) = cached.cache_write_cost_per_m {
|
||||
total += cache_write_tokens as f64 * cw / 1_000_000.0;
|
||||
}
|
||||
|
||||
total
|
||||
} else {
|
||||
provider.calculate_cost(
|
||||
&model,
|
||||
@@ -153,6 +165,7 @@ where
|
||||
model,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
reasoning_tokens,
|
||||
total_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
@@ -162,11 +175,6 @@ where
|
||||
error_message: None,
|
||||
duration_ms: duration.as_millis() as u64,
|
||||
});
|
||||
|
||||
// Update client usage
|
||||
let _ = client_manager
|
||||
.update_client_usage(&client_id, total_tokens as i64, cost)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -304,7 +312,6 @@ mod tests {
|
||||
let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
|
||||
let (dashboard_tx, _) = tokio::sync::broadcast::channel(16);
|
||||
let logger = Arc::new(RequestLogger::new(pool.clone(), dashboard_tx));
|
||||
let client_manager = Arc::new(ClientManager::new(pool.clone()));
|
||||
let registry = Arc::new(crate::models::registry::ModelRegistry {
|
||||
providers: std::collections::HashMap::new(),
|
||||
});
|
||||
@@ -318,7 +325,6 @@ mod tests {
|
||||
prompt_tokens: 10,
|
||||
has_images: false,
|
||||
logger,
|
||||
client_manager,
|
||||
model_registry: registry,
|
||||
model_config_cache: ModelConfigCache::new(pool.clone()),
|
||||
},
|
||||
|
||||
@@ -61,9 +61,9 @@
|
||||
--text-white: var(--fg0);
|
||||
|
||||
/* Borders */
|
||||
--border-color: var(--bg2);
|
||||
--border-radius: 8px;
|
||||
--border-radius-sm: 4px;
|
||||
--border-color: var(--bg3);
|
||||
--border-radius: 0px;
|
||||
--border-radius-sm: 0px;
|
||||
|
||||
/* Spacing System */
|
||||
--spacing-xs: 0.25rem;
|
||||
@@ -72,15 +72,15 @@
|
||||
--spacing-lg: 1.5rem;
|
||||
--spacing-xl: 2rem;
|
||||
|
||||
/* Shadows */
|
||||
--shadow-sm: 0 1px 2px 0 rgba(0, 0, 0, 0.2);
|
||||
--shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.3);
|
||||
--shadow-md: 0 10px 15px -3px rgba(0, 0, 0, 0.4);
|
||||
--shadow-lg: 0 20px 25px -5px rgba(0, 0, 0, 0.5);
|
||||
/* Shadows - Retro Block Style */
|
||||
--shadow-sm: 2px 2px 0px rgba(0, 0, 0, 0.4);
|
||||
--shadow: 4px 4px 0px rgba(0, 0, 0, 0.5);
|
||||
--shadow-md: 6px 6px 0px rgba(0, 0, 0, 0.6);
|
||||
--shadow-lg: 8px 8px 0px rgba(0, 0, 0, 0.7);
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: 'Inter', -apple-system, sans-serif;
|
||||
font-family: 'JetBrains Mono', 'Fira Code', 'Courier New', monospace;
|
||||
background-color: var(--bg-primary);
|
||||
color: var(--text-primary);
|
||||
line-height: 1.6;
|
||||
@@ -105,12 +105,12 @@ body {
|
||||
|
||||
.login-card {
|
||||
background: var(--bg1);
|
||||
border-radius: 24px;
|
||||
border-radius: var(--border-radius);
|
||||
padding: 4rem 2.5rem 3rem;
|
||||
width: 100%;
|
||||
max-width: 440px;
|
||||
box-shadow: var(--shadow-lg);
|
||||
border: 1px solid var(--bg2);
|
||||
border: 2px solid var(--bg3);
|
||||
text-align: center;
|
||||
animation: slideUp 0.6s cubic-bezier(0.34, 1.56, 0.64, 1);
|
||||
position: relative;
|
||||
@@ -191,7 +191,7 @@ body {
|
||||
color: var(--fg3);
|
||||
pointer-events: none;
|
||||
transition: all 0.25s ease;
|
||||
background: var(--bg1);
|
||||
background: transparent;
|
||||
padding: 0 0.375rem;
|
||||
z-index: 2;
|
||||
font-weight: 500;
|
||||
@@ -202,30 +202,32 @@ body {
|
||||
|
||||
.form-group input:focus ~ label,
|
||||
.form-group input:not(:placeholder-shown) ~ label {
|
||||
top: -0.625rem;
|
||||
top: 0;
|
||||
left: 0.875rem;
|
||||
font-size: 0.7rem;
|
||||
font-size: 0.75rem;
|
||||
color: var(--orange);
|
||||
font-weight: 600;
|
||||
transform: translateY(0);
|
||||
transform: translateY(-50%);
|
||||
background: linear-gradient(180deg, var(--bg1) 50%, var(--bg0) 50%);
|
||||
}
|
||||
|
||||
.form-group input {
|
||||
padding: 1rem 1.25rem;
|
||||
background: var(--bg0);
|
||||
border: 2px solid var(--bg3);
|
||||
border-radius: 12px;
|
||||
border-radius: var(--border-radius);
|
||||
font-family: inherit;
|
||||
font-size: 1rem;
|
||||
color: var(--fg1);
|
||||
transition: all 0.3s;
|
||||
transition: all 0.2s;
|
||||
width: 100%;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.form-group input:focus {
|
||||
border-color: var(--orange);
|
||||
box-shadow: 0 0 0 4px rgba(214, 93, 14, 0.2);
|
||||
outline: none;
|
||||
box-shadow: 4px 4px 0px rgba(214, 93, 14, 0.4);
|
||||
}
|
||||
|
||||
.login-btn {
|
||||
@@ -732,11 +734,11 @@ body {
|
||||
.stat-change.positive { color: var(--green-light); }
|
||||
.stat-change.negative { color: var(--red-light); }
|
||||
|
||||
/* Generic Cards */
|
||||
/* Cards */
|
||||
.card {
|
||||
background: var(--bg1);
|
||||
border-radius: var(--border-radius);
|
||||
border: 1px solid var(--bg2);
|
||||
border: 1px solid var(--bg3);
|
||||
box-shadow: var(--shadow-sm);
|
||||
margin-bottom: 1.5rem;
|
||||
display: flex;
|
||||
@@ -749,6 +751,15 @@ body {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
flex-wrap: wrap;
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
.card-actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.card-title {
|
||||
@@ -817,25 +828,26 @@ body {
|
||||
/* Badges */
|
||||
.status-badge {
|
||||
padding: 0.25rem 0.75rem;
|
||||
border-radius: 9999px;
|
||||
border-radius: var(--border-radius);
|
||||
font-size: 0.7rem;
|
||||
font-weight: 700;
|
||||
text-transform: uppercase;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.375rem;
|
||||
border: 1px solid transparent;
|
||||
}
|
||||
|
||||
.status-badge.online, .status-badge.success { background: rgba(184, 187, 38, 0.2); color: var(--green-light); }
|
||||
.status-badge.offline, .status-badge.danger { background: rgba(251, 73, 52, 0.2); color: var(--red-light); }
|
||||
.status-badge.warning { background: rgba(250, 189, 47, 0.2); color: var(--yellow-light); }
|
||||
.status-badge.online, .status-badge.success { background: rgba(184, 187, 38, 0.2); color: var(--green-light); border-color: rgba(184, 187, 38, 0.4); }
|
||||
.status-badge.offline, .status-badge.danger { background: rgba(251, 73, 52, 0.2); color: var(--red-light); border-color: rgba(251, 73, 52, 0.4); }
|
||||
.status-badge.warning { background: rgba(250, 189, 47, 0.2); color: var(--yellow-light); border-color: rgba(250, 189, 47, 0.4); }
|
||||
|
||||
.badge-client {
|
||||
background: var(--bg2);
|
||||
color: var(--blue-light);
|
||||
padding: 2px 8px;
|
||||
border-radius: 6px;
|
||||
font-family: monospace;
|
||||
border-radius: var(--border-radius);
|
||||
font-family: inherit;
|
||||
font-size: 0.85rem;
|
||||
border: 1px solid var(--bg3);
|
||||
}
|
||||
@@ -889,7 +901,7 @@ body {
|
||||
width: 100%;
|
||||
background: var(--bg0);
|
||||
border: 1px solid var(--bg3);
|
||||
border-radius: 8px;
|
||||
border-radius: var(--border-radius);
|
||||
padding: 0.75rem;
|
||||
font-family: inherit;
|
||||
font-size: 0.875rem;
|
||||
@@ -900,7 +912,7 @@ body {
|
||||
.form-control input:focus, .form-control textarea:focus, .form-control select:focus {
|
||||
outline: none;
|
||||
border-color: var(--orange);
|
||||
box-shadow: 0 0 0 2px rgba(214, 93, 14, 0.2);
|
||||
box-shadow: 2px 2px 0px rgba(214, 93, 14, 0.4);
|
||||
}
|
||||
|
||||
.btn {
|
||||
@@ -908,21 +920,27 @@ body {
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
padding: 0.625rem 1.25rem;
|
||||
border-radius: 8px;
|
||||
border-radius: var(--border-radius);
|
||||
font-weight: 600;
|
||||
font-size: 0.875rem;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
transition: all 0.1s;
|
||||
border: 1px solid transparent;
|
||||
text-transform: uppercase;
|
||||
}
|
||||
|
||||
.btn-primary { background: var(--orange); color: var(--bg0); }
|
||||
.btn:active {
|
||||
transform: translate(2px, 2px);
|
||||
box-shadow: none !important;
|
||||
}
|
||||
|
||||
.btn-primary { background: var(--orange); color: var(--bg0); box-shadow: 2px 2px 0px var(--bg4); }
|
||||
.btn-primary:hover { background: var(--orange-light); }
|
||||
|
||||
.btn-secondary { background: var(--bg2); border-color: var(--bg3); color: var(--fg1); }
|
||||
.btn-secondary { background: var(--bg2); border-color: var(--bg3); color: var(--fg1); box-shadow: 2px 2px 0px var(--bg0); }
|
||||
.btn-secondary:hover { background: var(--bg3); color: var(--fg0); }
|
||||
|
||||
.btn-danger { background: var(--red); color: var(--fg0); }
|
||||
.btn-danger { background: var(--red); color: var(--fg0); box-shadow: 2px 2px 0px var(--bg4); }
|
||||
.btn-danger:hover { background: var(--red-light); }
|
||||
|
||||
/* Small inline action buttons (edit, delete, copy) */
|
||||
@@ -981,13 +999,13 @@ body {
|
||||
|
||||
.modal-content {
|
||||
background: var(--bg1);
|
||||
border-radius: 16px;
|
||||
border-radius: var(--border-radius);
|
||||
width: 90%;
|
||||
max-width: 500px;
|
||||
box-shadow: var(--shadow-lg);
|
||||
border: 1px solid var(--bg3);
|
||||
border: 2px solid var(--bg3);
|
||||
transform: translateY(20px);
|
||||
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.modal.active .modal-content {
|
||||
|
||||
@@ -4,11 +4,11 @@
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>LLM Proxy Gateway - Admin Dashboard</title>
|
||||
<link rel="stylesheet" href="/css/dashboard.css?v=7">
|
||||
<link rel="stylesheet" href="/css/dashboard.css?v=11">
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
|
||||
<link rel="icon" href="img/logo-icon.png" type="image/png" sizes="any">
|
||||
<link rel="apple-touch-icon" href="img/logo-icon.png">
|
||||
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet">
|
||||
<link href="https://fonts.googleapis.com/css2?family=Fira+Code:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;700&display=swap" rel="stylesheet">
|
||||
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/luxon@3.4.4/build/global/luxon.min.js"></script>
|
||||
</head>
|
||||
@@ -17,12 +17,11 @@
|
||||
<div id="login-screen" class="login-container">
|
||||
<div class="login-card">
|
||||
<div class="login-header">
|
||||
<img src="img/logo-full.png" alt="LLM Proxy Logo" class="login-logo" onerror="this.style.display='none'; this.nextElementSibling.style.display='block';">
|
||||
<i class="fas fa-robot login-logo-fallback" style="display: none;"></i>
|
||||
<i class="fas fa-terminal login-logo-fallback"></i>
|
||||
<h1>LLM Proxy Gateway</h1>
|
||||
<p class="login-subtitle">Admin Dashboard</p>
|
||||
</div>
|
||||
<form id="login-form" class="login-form">
|
||||
<form id="login-form" class="login-form" onsubmit="event.preventDefault();">
|
||||
<div class="form-group">
|
||||
<input type="text" id="username" name="username" placeholder=" " required>
|
||||
<label for="username">
|
||||
|
||||
@@ -32,9 +32,28 @@ class ApiClient {
|
||||
}
|
||||
|
||||
if (!response.ok || !result.success) {
|
||||
// Handle authentication errors (session expired, server restarted, etc.)
|
||||
if (response.status === 401 ||
|
||||
result.error === 'Session expired or invalid' ||
|
||||
result.error === 'Not authenticated' ||
|
||||
result.error === 'Admin access required') {
|
||||
|
||||
if (window.authManager) {
|
||||
// Try to logout to clear local state and show login screen
|
||||
window.authManager.logout();
|
||||
}
|
||||
}
|
||||
throw new Error(result.error || `HTTP error! status: ${response.status}`);
|
||||
}
|
||||
|
||||
// Handling X-Refreshed-Token header
|
||||
if (response.headers.get('X-Refreshed-Token') && window.authManager) {
|
||||
window.authManager.token = response.headers.get('X-Refreshed-Token');
|
||||
if (window.authManager.setToken) {
|
||||
window.authManager.setToken(window.authManager.token);
|
||||
}
|
||||
}
|
||||
|
||||
return result.data;
|
||||
}
|
||||
|
||||
@@ -87,6 +106,17 @@ class ApiClient {
|
||||
const date = luxon.DateTime.fromISO(dateStr);
|
||||
return date.toRelative();
|
||||
}
|
||||
|
||||
// Helper for escaping HTML
|
||||
escapeHtml(unsafe) {
|
||||
if (unsafe === undefined || unsafe === null) return '';
|
||||
return unsafe.toString()
|
||||
.replace(/&/g, "&")
|
||||
.replace(/</g, "<")
|
||||
.replace(/>/g, ">")
|
||||
.replace(/"/g, """)
|
||||
.replace(/'/g, "'");
|
||||
}
|
||||
}
|
||||
|
||||
window.api = new ApiClient();
|
||||
|
||||
@@ -50,6 +50,12 @@ class AuthManager {
|
||||
});
|
||||
}
|
||||
|
||||
setToken(newToken) {
|
||||
if (!newToken) return;
|
||||
this.token = newToken;
|
||||
localStorage.setItem('dashboard_token', this.token);
|
||||
}
|
||||
|
||||
async login(username, password) {
|
||||
const errorElement = document.getElementById('login-error');
|
||||
const loginBtn = document.querySelector('.login-btn');
|
||||
|
||||
@@ -285,7 +285,30 @@ class Dashboard {
|
||||
<p class="card-subtitle">Manage model availability and custom pricing</p>
|
||||
</div>
|
||||
<div class="card-actions">
|
||||
<input type="text" id="model-search" placeholder="Search models..." class="form-control" style="margin-bottom: 0; padding: 4px 8px; width: 250px;">
|
||||
<select id="model-provider-filter" class="form-control" style="margin-bottom: 0; padding: 4px 8px; width: auto;">
|
||||
<option value="">All Providers</option>
|
||||
<option value="openai">OpenAI</option>
|
||||
<option value="anthropic">Anthropic / Gemini</option>
|
||||
<option value="google">Google</option>
|
||||
<option value="deepseek">DeepSeek</option>
|
||||
<option value="xai">xAI</option>
|
||||
<option value="meta">Meta</option>
|
||||
<option value="cohere">Cohere</option>
|
||||
<option value="mistral">Mistral</option>
|
||||
<option value="other">Other</option>
|
||||
</select>
|
||||
<select id="model-modality-filter" class="form-control" style="margin-bottom: 0; padding: 4px 8px; width: auto;">
|
||||
<option value="">All Modalities</option>
|
||||
<option value="text">Text</option>
|
||||
<option value="image">Vision/Image</option>
|
||||
<option value="audio">Audio</option>
|
||||
</select>
|
||||
<select id="model-capability-filter" class="form-control" style="margin-bottom: 0; padding: 4px 8px; width: auto;">
|
||||
<option value="">All Capabilities</option>
|
||||
<option value="tool_call">Tool Calling</option>
|
||||
<option value="reasoning">Reasoning</option>
|
||||
</select>
|
||||
<input type="text" id="model-search" placeholder="Search models..." class="form-control" style="margin-bottom: 0; padding: 4px 8px; width: 200px;">
|
||||
</div>
|
||||
</div>
|
||||
<div class="table-container">
|
||||
|
||||
@@ -42,12 +42,15 @@ class ClientsPage {
|
||||
const statusIcon = client.status === 'active' ? 'check-circle' : 'clock';
|
||||
const created = luxon.DateTime.fromISO(client.created_at).toFormat('MMM dd, yyyy');
|
||||
|
||||
const escapedId = window.api.escapeHtml(client.id);
|
||||
const escapedName = window.api.escapeHtml(client.name);
|
||||
|
||||
return `
|
||||
<tr>
|
||||
<td><span class="badge-client">${client.id}</span></td>
|
||||
<td><strong>${client.name}</strong></td>
|
||||
<td><span class="badge-client">${escapedId}</span></td>
|
||||
<td><strong>${escapedName}</strong></td>
|
||||
<td>
|
||||
<code class="token-display">sk-••••${client.id.substring(client.id.length - 4)}</code>
|
||||
<code class="token-display">sk-••••${escapedId.substring(escapedId.length - 4)}</code>
|
||||
</td>
|
||||
<td>${created}</td>
|
||||
<td>${client.last_used ? window.api.formatTimeAgo(client.last_used) : 'Never'}</td>
|
||||
@@ -55,16 +58,16 @@ class ClientsPage {
|
||||
<td>
|
||||
<span class="status-badge ${statusClass}">
|
||||
<i class="fas fa-${statusIcon}"></i>
|
||||
${client.status}
|
||||
${window.api.escapeHtml(client.status)}
|
||||
</span>
|
||||
</td>
|
||||
<td>
|
||||
${window._userRole === 'admin' ? `
|
||||
<div class="action-buttons">
|
||||
<button class="btn-action" title="Edit" onclick="window.clientsPage.editClient('${client.id}')">
|
||||
<button class="btn-action" title="Edit" onclick="window.clientsPage.editClient('${escapedId}')">
|
||||
<i class="fas fa-edit"></i>
|
||||
</button>
|
||||
<button class="btn-action danger" title="Delete" onclick="window.clientsPage.deleteClient('${client.id}')">
|
||||
<button class="btn-action danger" title="Delete" onclick="window.clientsPage.deleteClient('${escapedId}')">
|
||||
<i class="fas fa-trash"></i>
|
||||
</button>
|
||||
</div>
|
||||
@@ -188,10 +191,13 @@ class ClientsPage {
|
||||
showTokenRevealModal(clientName, token) {
|
||||
const modal = document.createElement('div');
|
||||
modal.className = 'modal active';
|
||||
const escapedName = window.api.escapeHtml(clientName);
|
||||
const escapedToken = window.api.escapeHtml(token);
|
||||
|
||||
modal.innerHTML = `
|
||||
<div class="modal-content">
|
||||
<div class="modal-header">
|
||||
<h3 class="modal-title">Client Created: ${clientName}</h3>
|
||||
<h3 class="modal-title">Client Created: ${escapedName}</h3>
|
||||
</div>
|
||||
<div class="modal-body">
|
||||
<p style="margin-bottom: 0.75rem; color: var(--yellow);">
|
||||
@@ -201,7 +207,7 @@ class ClientsPage {
|
||||
<div class="form-control">
|
||||
<label>API Token</label>
|
||||
<div style="display: flex; gap: 0.5rem;">
|
||||
<input type="text" id="revealed-token" value="${token}" readonly
|
||||
<input type="text" id="revealed-token" value="${escapedToken}" readonly
|
||||
style="font-family: monospace; font-size: 0.85rem;">
|
||||
<button class="btn btn-secondary" id="copy-token-btn" title="Copy">
|
||||
<i class="fas fa-copy"></i>
|
||||
@@ -248,10 +254,16 @@ class ClientsPage {
|
||||
showEditClientModal(client) {
|
||||
const modal = document.createElement('div');
|
||||
modal.className = 'modal active';
|
||||
|
||||
const escapedId = window.api.escapeHtml(client.id);
|
||||
const escapedName = window.api.escapeHtml(client.name);
|
||||
const escapedDescription = window.api.escapeHtml(client.description);
|
||||
const escapedRateLimit = window.api.escapeHtml(client.rate_limit_per_minute);
|
||||
|
||||
modal.innerHTML = `
|
||||
<div class="modal-content">
|
||||
<div class="modal-header">
|
||||
<h3 class="modal-title">Edit Client: ${client.id}</h3>
|
||||
<h3 class="modal-title">Edit Client: ${escapedId}</h3>
|
||||
<button class="modal-close" onclick="this.closest('.modal').remove()">
|
||||
<i class="fas fa-times"></i>
|
||||
</button>
|
||||
@@ -259,15 +271,15 @@ class ClientsPage {
|
||||
<div class="modal-body">
|
||||
<div class="form-control">
|
||||
<label for="edit-client-name">Display Name</label>
|
||||
<input type="text" id="edit-client-name" value="${client.name || ''}" placeholder="e.g. My Coding Assistant">
|
||||
<input type="text" id="edit-client-name" value="${escapedName}" placeholder="e.g. My Coding Assistant">
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label for="edit-client-description">Description</label>
|
||||
<textarea id="edit-client-description" rows="3" placeholder="Optional description">${client.description || ''}</textarea>
|
||||
<textarea id="edit-client-description" rows="3" placeholder="Optional description">${escapedDescription}</textarea>
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label for="edit-client-rate-limit">Rate Limit (requests/minute)</label>
|
||||
<input type="number" id="edit-client-rate-limit" min="0" value="${client.rate_limit_per_minute || ''}" placeholder="Leave empty for unlimited">
|
||||
<input type="number" id="edit-client-rate-limit" min="0" value="${escapedRateLimit}" placeholder="Leave empty for unlimited">
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label class="toggle-label">
|
||||
@@ -357,12 +369,16 @@ class ClientsPage {
|
||||
const lastUsed = t.last_used_at
|
||||
? luxon.DateTime.fromISO(t.last_used_at).toRelative()
|
||||
: 'Never';
|
||||
const escapedMaskedToken = window.api.escapeHtml(t.token_masked);
|
||||
const escapedClientId = window.api.escapeHtml(clientId);
|
||||
const tokenId = parseInt(t.id); // Assuming ID is numeric
|
||||
|
||||
return `
|
||||
<div style="display: flex; align-items: center; gap: 0.5rem; padding: 0.4rem 0; border-bottom: 1px solid var(--bg3);">
|
||||
<code style="flex: 1; font-size: 0.8rem; color: var(--fg2);">${t.token_masked}</code>
|
||||
<code style="flex: 1; font-size: 0.8rem; color: var(--fg2);">${escapedMaskedToken}</code>
|
||||
<span style="font-size: 0.75rem; color: var(--fg4);" title="Last used">${lastUsed}</span>
|
||||
<button class="btn-action danger" title="Revoke" style="padding: 0.2rem 0.4rem;"
|
||||
onclick="window.clientsPage.revokeToken('${clientId}', ${t.id}, this)">
|
||||
onclick="window.clientsPage.revokeToken('${escapedClientId}', ${tokenId}, this)">
|
||||
<i class="fas fa-trash" style="font-size: 0.75rem;"></i>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
@@ -38,6 +38,24 @@ class LogsPage {
|
||||
const statusClass = log.status === 'success' ? 'success' : 'danger';
|
||||
const timestamp = luxon.DateTime.fromISO(log.timestamp).toFormat('yyyy-MM-dd HH:mm:ss');
|
||||
|
||||
let tokenDetails = `${log.tokens} total tokens`;
|
||||
if (log.status === 'success') {
|
||||
const parts = [];
|
||||
parts.push(`${log.prompt_tokens} in`);
|
||||
|
||||
let completionStr = `${log.completion_tokens} out`;
|
||||
if (log.reasoning_tokens > 0) {
|
||||
completionStr += ` (${log.reasoning_tokens} reasoning)`;
|
||||
}
|
||||
parts.push(completionStr);
|
||||
|
||||
if (log.cache_read_tokens > 0) {
|
||||
parts.push(`${log.cache_read_tokens} cache-hit`);
|
||||
}
|
||||
|
||||
tokenDetails = parts.join(', ');
|
||||
}
|
||||
|
||||
return `
|
||||
<tr class="log-row">
|
||||
<td class="whitespace-nowrap">${timestamp}</td>
|
||||
@@ -55,7 +73,7 @@ class LogsPage {
|
||||
<td>
|
||||
<div class="log-message-container">
|
||||
<code class="log-model">${log.model}</code>
|
||||
<span class="log-tokens">${log.tokens} tokens</span>
|
||||
<span class="log-tokens" title="${log.tokens} total tokens">${tokenDetails}</span>
|
||||
<span class="log-duration">${log.duration}ms</span>
|
||||
${log.error ? `<div class="log-error-msg">${log.error}</div>` : ''}
|
||||
</div>
|
||||
|
||||
@@ -31,13 +31,58 @@ class ModelsPage {
|
||||
return;
|
||||
}
|
||||
|
||||
const searchInput = document.getElementById('model-search');
|
||||
const providerFilter = document.getElementById('model-provider-filter');
|
||||
const modalityFilter = document.getElementById('model-modality-filter');
|
||||
const capabilityFilter = document.getElementById('model-capability-filter');
|
||||
|
||||
const q = searchInput ? searchInput.value.toLowerCase() : '';
|
||||
const providerVal = providerFilter ? providerFilter.value : '';
|
||||
const modalityVal = modalityFilter ? modalityFilter.value : '';
|
||||
const capabilityVal = capabilityFilter ? capabilityFilter.value : '';
|
||||
|
||||
// Apply filters non-destructively
|
||||
let filteredModels = this.models.filter(m => {
|
||||
// Text search
|
||||
if (q && !(m.id.toLowerCase().includes(q) || m.name.toLowerCase().includes(q) || m.provider.toLowerCase().includes(q))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Provider filter
|
||||
if (providerVal) {
|
||||
if (providerVal === 'other') {
|
||||
const known = ['openai', 'anthropic', 'google', 'deepseek', 'xai', 'meta', 'cohere', 'mistral'];
|
||||
if (known.includes(m.provider.toLowerCase())) return false;
|
||||
} else if (m.provider.toLowerCase() !== providerVal) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Modality filter
|
||||
if (modalityVal) {
|
||||
const mods = m.modalities && m.modalities.input ? m.modalities.input.map(x => x.toLowerCase()) : [];
|
||||
if (!mods.includes(modalityVal)) return false;
|
||||
}
|
||||
|
||||
// Capability filter
|
||||
if (capabilityVal === 'tool_call' && !m.tool_call) return false;
|
||||
if (capabilityVal === 'reasoning' && !m.reasoning) return false;
|
||||
|
||||
return true;
|
||||
});
|
||||
|
||||
if (filteredModels.length === 0) {
|
||||
tableBody.innerHTML = '<tr><td colspan="7" class="text-center">No models match the selected filters</td></tr>';
|
||||
return;
|
||||
}
|
||||
|
||||
// Sort by provider then name
|
||||
this.models.sort((a, b) => {
|
||||
filteredModels.sort((a, b) => {
|
||||
if (a.provider !== b.provider) return a.provider.localeCompare(b.provider);
|
||||
return a.name.localeCompare(b.name);
|
||||
});
|
||||
|
||||
tableBody.innerHTML = this.models.map(model => {
|
||||
tableBody.innerHTML = filteredModels.map(model => {
|
||||
const statusClass = model.enabled ? 'success' : 'secondary';
|
||||
const statusIcon = model.enabled ? 'check-circle' : 'ban';
|
||||
|
||||
@@ -99,6 +144,14 @@ class ModelsPage {
|
||||
<input type="number" id="model-completion-cost" value="${model.completion_cost}" step="0.01">
|
||||
</div>
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label for="model-cache-read-cost">Cache Read Cost (per 1M tokens)</label>
|
||||
<input type="number" id="model-cache-read-cost" value="${model.cache_read_cost || 0}" step="0.01">
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label for="model-cache-write-cost">Cache Write Cost (per 1M tokens)</label>
|
||||
<input type="number" id="model-cache-write-cost" value="${model.cache_write_cost || 0}" step="0.01">
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label for="model-mapping">Internal Mapping (Optional)</label>
|
||||
<input type="text" id="model-mapping" value="${model.mapping || ''}" placeholder="e.g. gpt-4o-2024-05-13">
|
||||
@@ -118,6 +171,8 @@ class ModelsPage {
|
||||
const enabled = modal.querySelector('#model-enabled').checked;
|
||||
const promptCost = parseFloat(modal.querySelector('#model-prompt-cost').value);
|
||||
const completionCost = parseFloat(modal.querySelector('#model-completion-cost').value);
|
||||
const cacheReadCost = parseFloat(modal.querySelector('#model-cache-read-cost').value);
|
||||
const cacheWriteCost = parseFloat(modal.querySelector('#model-cache-write-cost').value);
|
||||
const mapping = modal.querySelector('#model-mapping').value;
|
||||
|
||||
try {
|
||||
@@ -125,6 +180,8 @@ class ModelsPage {
|
||||
enabled,
|
||||
prompt_cost: promptCost,
|
||||
completion_cost: completionCost,
|
||||
cache_read_cost: isNaN(cacheReadCost) ? null : cacheReadCost,
|
||||
cache_write_cost: isNaN(cacheWriteCost) ? null : cacheWriteCost,
|
||||
mapping: mapping || null
|
||||
});
|
||||
|
||||
@@ -138,27 +195,18 @@ class ModelsPage {
|
||||
}
|
||||
|
||||
setupEventListeners() {
|
||||
const searchInput = document.getElementById('model-search');
|
||||
if (searchInput) {
|
||||
searchInput.oninput = (e) => this.filterModels(e.target.value);
|
||||
}
|
||||
const attachFilter = (id) => {
|
||||
const el = document.getElementById(id);
|
||||
if (el) {
|
||||
el.addEventListener('input', () => this.renderModelsTable());
|
||||
el.addEventListener('change', () => this.renderModelsTable());
|
||||
}
|
||||
};
|
||||
|
||||
filterModels(query) {
|
||||
if (!query) {
|
||||
this.renderModelsTable();
|
||||
return;
|
||||
}
|
||||
|
||||
const q = query.toLowerCase();
|
||||
const originalModels = this.models;
|
||||
this.models = this.models.filter(m =>
|
||||
m.id.toLowerCase().includes(q) ||
|
||||
m.name.toLowerCase().includes(q) ||
|
||||
m.provider.toLowerCase().includes(q)
|
||||
);
|
||||
this.renderModelsTable();
|
||||
this.models = originalModels;
|
||||
attachFilter('model-search');
|
||||
attachFilter('model-provider-filter');
|
||||
attachFilter('model-modality-filter');
|
||||
attachFilter('model-capability-filter');
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -47,16 +47,21 @@ class ProvidersPage {
|
||||
const isLowBalance = provider.credit_balance <= provider.low_credit_threshold && provider.id !== 'ollama';
|
||||
const balanceColor = isLowBalance ? 'var(--red-light)' : 'var(--green-light)';
|
||||
|
||||
const escapedId = window.api.escapeHtml(provider.id);
|
||||
const escapedName = window.api.escapeHtml(provider.name);
|
||||
const escapedStatus = window.api.escapeHtml(provider.status);
|
||||
const billingMode = provider.billing_mode ? provider.billing_mode.toUpperCase() : 'PREPAID';
|
||||
|
||||
return `
|
||||
<div class="provider-card ${provider.status}">
|
||||
<div class="provider-card ${escapedStatus}">
|
||||
<div class="provider-card-header">
|
||||
<div class="provider-info">
|
||||
<h4 class="provider-name">${provider.name}</h4>
|
||||
<span class="provider-id">${provider.id}</span>
|
||||
<h4 class="provider-name">${escapedName}</h4>
|
||||
<span class="provider-id">${escapedId}</span>
|
||||
</div>
|
||||
<span class="status-badge ${statusClass}">
|
||||
<i class="fas fa-circle"></i>
|
||||
${provider.status}
|
||||
${escapedStatus}
|
||||
</span>
|
||||
</div>
|
||||
<div class="provider-card-body">
|
||||
@@ -67,12 +72,12 @@ class ProvidersPage {
|
||||
</div>
|
||||
<div class="meta-item" style="color: ${balanceColor}; font-weight: 700;">
|
||||
<i class="fas fa-wallet"></i>
|
||||
<span>Balance: ${provider.id === 'ollama' ? 'FREE' : window.api.formatCurrency(provider.credit_balance)}</span>
|
||||
<span>Balance: ${escapedId === 'ollama' ? 'FREE' : window.api.formatCurrency(provider.credit_balance)}</span>
|
||||
${isLowBalance ? '<i class="fas fa-exclamation-triangle" title="Low Balance"></i>' : ''}
|
||||
</div>
|
||||
<div class="meta-item">
|
||||
<i class="fas fa-exchange-alt"></i>
|
||||
<span>Billing: ${provider.billing_mode ? provider.billing_mode.toUpperCase() : 'PREPAID'}</span>
|
||||
<span>Billing: ${window.api.escapeHtml(billingMode)}</span>
|
||||
</div>
|
||||
<div class="meta-item">
|
||||
<i class="fas fa-clock"></i>
|
||||
@@ -80,16 +85,16 @@ class ProvidersPage {
|
||||
</div>
|
||||
</div>
|
||||
<div class="model-tags">
|
||||
${(provider.models || []).slice(0, 5).map(m => `<span class="model-tag">${m}</span>`).join('')}
|
||||
${(provider.models || []).slice(0, 5).map(m => `<span class="model-tag">${window.api.escapeHtml(m)}</span>`).join('')}
|
||||
${modelCount > 5 ? `<span class="model-tag more">+${modelCount - 5} more</span>` : ''}
|
||||
</div>
|
||||
</div>
|
||||
<div class="provider-card-footer">
|
||||
<button class="btn btn-secondary btn-sm" onclick="window.providersPage.testProvider('${provider.id}')">
|
||||
<button class="btn btn-secondary btn-sm" onclick="window.providersPage.testProvider('${escapedId}')">
|
||||
<i class="fas fa-vial"></i> Test
|
||||
</button>
|
||||
${window._userRole === 'admin' ? `
|
||||
<button class="btn btn-primary btn-sm" onclick="window.providersPage.configureProvider('${provider.id}')">
|
||||
<button class="btn btn-primary btn-sm" onclick="window.providersPage.configureProvider('${escapedId}')">
|
||||
<i class="fas fa-cog"></i> Config
|
||||
</button>
|
||||
` : ''}
|
||||
@@ -144,10 +149,17 @@ class ProvidersPage {
|
||||
|
||||
const modal = document.createElement('div');
|
||||
modal.className = 'modal active';
|
||||
|
||||
const escapedId = window.api.escapeHtml(provider.id);
|
||||
const escapedName = window.api.escapeHtml(provider.name);
|
||||
const escapedBaseUrl = window.api.escapeHtml(provider.base_url);
|
||||
const escapedBalance = window.api.escapeHtml(provider.credit_balance);
|
||||
const escapedThreshold = window.api.escapeHtml(provider.low_credit_threshold);
|
||||
|
||||
modal.innerHTML = `
|
||||
<div class="modal-content">
|
||||
<div class="modal-header">
|
||||
<h3 class="modal-title">Configure ${provider.name}</h3>
|
||||
<h3 class="modal-title">Configure ${escapedName}</h3>
|
||||
<button class="modal-close" onclick="this.closest('.modal').remove()">
|
||||
<i class="fas fa-times"></i>
|
||||
</button>
|
||||
@@ -161,7 +173,7 @@ class ProvidersPage {
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label for="provider-base-url">Base URL</label>
|
||||
<input type="text" id="provider-base-url" value="${provider.base_url || ''}" placeholder="Default API URL">
|
||||
<input type="text" id="provider-base-url" value="${escapedBaseUrl}" placeholder="Default API URL">
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label for="provider-api-key">API Key (Optional / Overwrite)</label>
|
||||
@@ -170,11 +182,11 @@ class ProvidersPage {
|
||||
<div class="grid-2">
|
||||
<div class="form-control">
|
||||
<label for="provider-balance">Current Credit Balance ($)</label>
|
||||
<input type="number" id="provider-balance" value="${provider.credit_balance}" step="0.01">
|
||||
<input type="number" id="provider-balance" value="${escapedBalance}" step="0.01">
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label for="provider-threshold">Low Balance Alert ($)</label>
|
||||
<input type="number" id="provider-threshold" value="${provider.low_credit_threshold}" step="0.50">
|
||||
<input type="number" id="provider-threshold" value="${escapedThreshold}" step="0.50">
|
||||
</div>
|
||||
</div>
|
||||
<div class="form-control">
|
||||
|
||||
@@ -279,9 +279,7 @@
|
||||
|
||||
// ── Helpers ────────────────────────────────────────────────────
|
||||
|
||||
function escapeHtml(str) {
|
||||
const div = document.createElement('div');
|
||||
div.textContent = str;
|
||||
return div.innerHTML;
|
||||
}
|
||||
function escapeHtml(str) {
|
||||
return window.api.escapeHtml(str);
|
||||
}
|
||||
})();
|
||||
|
||||
14
timeline.mmd
Normal file
14
timeline.mmd
Normal file
@@ -0,0 +1,14 @@
|
||||
gantt
|
||||
title LLM Proxy Project Timeline
|
||||
dateFormat YYYY-MM-DD
|
||||
section Frontend
|
||||
Standardize Escaping (users.js) :a1, 2026-03-06, 1d
|
||||
section Backend Cleanup
|
||||
Remove Unused Imports :b1, 2026-03-06, 1d
|
||||
section HMAC Migration
|
||||
Architecture Design :c1, 2026-03-07, 1d
|
||||
Backend Implementation :c2, after c1, 2d
|
||||
Session Refresh Logic :c3, after c2, 1d
|
||||
section Testing
|
||||
Integration Test (Encrypted Keys) :d1, 2026-03-09, 2d
|
||||
HMAC Verification Tests :d2, after c3, 1d
|
||||
Reference in New Issue
Block a user