diff --git a/.env.example b/.env.example index 689ca7b7..fed54e09 100644 --- a/.env.example +++ b/.env.example @@ -1,31 +1,28 @@ # LLM Proxy Gateway Environment Variables # Copy to .env and fill in your API keys -# OpenAI +# MANDATORY: Encryption key for sessions and stored API keys +# Must be a 32-byte hex or base64 encoded string +# Example (hex): LLM_PROXY__ENCRYPTION_KEY=0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef +LLM_PROXY__ENCRYPTION_KEY=your_secure_32_byte_key_here + +# LLM Provider API Keys (Standard Environment Variables) OPENAI_API_KEY=your_openai_api_key_here - -# Google Gemini GEMINI_API_KEY=your_gemini_api_key_here - -# DeepSeek DEEPSEEK_API_KEY=your_deepseek_api_key_here - -# xAI Grok (not yet available) GROK_API_KEY=your_grok_api_key_here -# Ollama (local server) -# LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL=http://your-ollama-host:11434/v1 +# Provider Overrides (Optional) +# LLM_PROXY__PROVIDERS__OPENAI__BASE_URL=https://api.openai.com/v1 +# LLM_PROXY__PROVIDERS__GEMINI__ENABLED=true +# LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL=http://localhost:11434/v1 # LLM_PROXY__PROVIDERS__OLLAMA__ENABLED=true # LLM_PROXY__PROVIDERS__OLLAMA__MODELS=llama3,mistral,llava -# Authentication tokens (comma-separated list) -LLM_PROXY__SERVER__AUTH_TOKENS=your_bearer_token_here,another_token - -# Server port (optional) +# Server Configuration LLM_PROXY__SERVER__PORT=8080 +LLM_PROXY__SERVER__HOST=0.0.0.0 -# Database path (optional) +# Database Configuration LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db - -# Session secret for HMAC-signed tokens (hex or base64 encoded, 32 bytes) -SESSION_SECRET=your_session_secret_here_32_bytes \ No newline at end of file +LLM_PROXY__DATABASE__MAX_CONNECTIONS=10 diff --git a/BACKEND_ARCHITECTURE.md b/BACKEND_ARCHITECTURE.md new file mode 100644 index 00000000..0ee3e0b9 --- /dev/null +++ b/BACKEND_ARCHITECTURE.md @@ -0,0 +1,61 @@ +# Backend Architecture (Go) + +The LLM Proxy backend is implemented in Go, focusing on high 
performance, clear concurrency patterns, and maintainability. + +## Core Technologies + +- **Runtime:** Go 1.22+ +- **Web Framework:** [Gin Gonic](https://github.com/gin-gonic/gin) - Fast and lightweight HTTP routing. +- **Database:** [sqlx](https://github.com/jmoiron/sqlx) - Lightweight wrapper for standard `database/sql`. +- **SQLite Driver:** [modernc.org/sqlite](https://modernc.org/sqlite) - CGO-free SQLite implementation for ease of cross-compilation. +- **Config:** [Viper](https://github.com/spf13/viper) - Robust configuration management supporting environment variables and files. + +## Project Structure + +```text +├── cmd/ +│ └── llm-proxy/ # Entry point (main.go) +├── internal/ +│ ├── config/ # Configuration loading and validation +│ ├── db/ # Database schema, migrations, and models +│ ├── middleware/ # Auth and logging middleware +│ ├── models/ # Unified request/response structs +│ ├── providers/ # LLM provider implementations (OpenAI, Gemini, etc.) +│ ├── server/ # HTTP server, dashboard handlers, and WebSocket hub +│ └── utils/ # Common utilities (multimodal, etc.) +└── static/ # Frontend assets (served by the backend) +``` + +## Key Components + +### 1. Provider Interface (`internal/providers/provider.go`) +Standardized interface for all LLM backends: +```go +type Provider interface { + Name() string + ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) + ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) +} +``` + +### 2. Asynchronous Logging (`internal/server/logging.go`) +Uses a buffered channel and background worker to log every request to SQLite without blocking the client response. It also broadcasts logs to the WebSocket hub for real-time dashboard updates. + +### 3. Session Management (`internal/server/sessions.go`) +Implements HMAC-SHA256 signed tokens for dashboard authentication. 
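Signing and verification along these lines can be sketched as follows. This is a minimal illustration only: the `signSession`/`verifySession` names are hypothetical (the real `sessions.go` also tracks expiry for the TTL described below), and the hard-coded demo key stands in for the configured 32-byte encryption key.

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"strings"
)

// signSession returns "payload.signature", where the signature is
// HMAC-SHA256 over the payload, base64url-encoded without padding.
// (Hypothetical helper; the real implementation also embeds expiry.)
func signSession(payload string, key []byte) string {
	mac := hmac.New(sha256.New, key)
	mac.Write([]byte(payload))
	sig := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
	return payload + "." + sig
}

// verifySession recomputes the HMAC over the payload and compares
// signatures in constant time via hmac.Equal.
func verifySession(token string, key []byte) bool {
	i := strings.LastIndex(token, ".")
	if i < 0 {
		return false
	}
	payload, sig := token[:i], token[i+1:]
	mac := hmac.New(sha256.New, key)
	mac.Write([]byte(payload))
	expected := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
	return hmac.Equal([]byte(sig), []byte(expected))
}

func main() {
	key := []byte("0123456789abcdef0123456789abcdef") // 32-byte demo key
	tok := signSession("session:user42", key)
	fmt.Println(verifySession(tok, key))     // genuine token verifies
	fmt.Println(verifySession(tok+"x", key)) // tampered token is rejected
}
```

Because the token is signed rather than encrypted, the server can validate it statelessly; the in-memory store described next only has to hold per-session metadata such as expiry.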
Sessions are stored in-memory with configurable TTL. + +### 4. WebSocket Hub (`internal/server/websocket.go`) +A centralized hub for managing WebSocket connections, allowing real-time broadcast of system events and request logs to the dashboard. + +## Concurrency Model + +Go's goroutines and channels are used extensively: +- **Streaming:** Each streaming request uses a goroutine to read and parse the provider's response, feeding chunks into a channel. +- **Logging:** A single background worker processes the `logChan` to perform database writes. +- **WebSocket:** The `Hub` runs in a dedicated goroutine, handling registration and broadcasting. + +## Security + +- **Encryption Key:** A mandatory 32-byte key is used for both session signing and encryption of sensitive data in the database. +- **Auth Middleware:** Verifies client API keys against the database before proxying requests to LLM providers. +- **Bcrypt:** Passwords for dashboard users are hashed using Bcrypt with a work factor of 12. diff --git a/Cargo.lock b/Cargo.lock deleted file mode 100644 index 936d5c1d..00000000 --- a/Cargo.lock +++ /dev/null @@ -1,4139 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. 
-version = 4 - -[[package]] -name = "adler2" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" - -[[package]] -name = "aead" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" -dependencies = [ - "crypto-common", - "generic-array", -] - -[[package]] -name = "aes" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" -dependencies = [ - "cfg-if", - "cipher", - "cpufeatures", -] - -[[package]] -name = "aes-gcm" -version = "0.10.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" -dependencies = [ - "aead", - "aes", - "cipher", - "ctr", - "ghash", - "subtle", -] - -[[package]] -name = "ahash" -version = "0.7.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "891477e0c6a8957309ee5c45a6368af3ae14bb510732d2684ffa19af310920f9" -dependencies = [ - "getrandom 0.2.17", - "once_cell", - "version_check", -] - -[[package]] -name = "aho-corasick" -version = "1.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" -dependencies = [ - "memchr", -] - -[[package]] -name = "allocator-api2" -version = "0.2.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" - -[[package]] -name = "android_system_properties" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" -dependencies = [ - "libc", -] - -[[package]] -name = "anstyle" -version 
= "1.0.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" - -[[package]] -name = "anyhow" -version = "1.0.102" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" - -[[package]] -name = "assert-json-diff" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" -dependencies = [ - "serde", - "serde_json", -] - -[[package]] -name = "assert_cmd" -version = "2.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c5bcfa8749ac45dd12cb11055aeeb6b27a3895560d60d71e3c23bf979e60514" -dependencies = [ - "anstyle", - "bstr", - "libc", - "predicates", - "predicates-core", - "predicates-tree", - "wait-timeout", -] - -[[package]] -name = "async-compression" -version = "0.4.41" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0f9ee0f6e02ffd7ad5816e9464499fba7b3effd01123b515c41d1697c43dad1" -dependencies = [ - "compression-codecs", - "compression-core", - "pin-project-lite", - "tokio", -] - -[[package]] -name = "async-stream" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" -dependencies = [ - "async-stream-impl", - "futures-core", - "pin-project-lite", -] - -[[package]] -name = "async-stream-impl" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "async-trait" -version = "0.1.89" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" 
-dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "atoi" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" -dependencies = [ - "num-traits", -] - -[[package]] -name = "atomic-waker" -version = "1.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" - -[[package]] -name = "autocfg" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" - -[[package]] -name = "axum" -version = "0.8.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" -dependencies = [ - "axum-core", - "axum-macros", - "base64 0.22.1", - "bytes", - "form_urlencoded", - "futures-util", - "http", - "http-body", - "http-body-util", - "hyper", - "hyper-util", - "itoa", - "matchit", - "memchr", - "mime", - "percent-encoding", - "pin-project-lite", - "serde_core", - "serde_json", - "serde_path_to_error", - "serde_urlencoded", - "sha1", - "sync_wrapper", - "tokio", - "tokio-tungstenite", - "tower", - "tower-layer", - "tower-service", - "tracing", -] - -[[package]] -name = "axum-core" -version = "0.5.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" -dependencies = [ - "bytes", - "futures-core", - "http", - "http-body", - "http-body-util", - "mime", - "pin-project-lite", - "sync_wrapper", - "tower-layer", - "tower-service", - "tracing", -] - -[[package]] -name = "axum-extra" -version = "0.12.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fef252edff26ddba56bbcdf2ee3307b8129acb86f5749b68990c168a6fcc9c76" -dependencies = [ - "axum", - 
"axum-core", - "bytes", - "futures-core", - "futures-util", - "headers", - "http", - "http-body", - "http-body-util", - "mime", - "pin-project-lite", - "tower-layer", - "tower-service", - "tracing", -] - -[[package]] -name = "axum-macros" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "base64" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" - -[[package]] -name = "base64" -version = "0.21.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" - -[[package]] -name = "base64" -version = "0.22.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" - -[[package]] -name = "base64ct" -version = "1.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" - -[[package]] -name = "bcrypt" -version = "0.15.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e65938ed058ef47d92cf8b346cc76ef48984572ade631927e9937b5ffc7662c7" -dependencies = [ - "base64 0.22.1", - "blowfish", - "getrandom 0.2.17", - "subtle", - "zeroize", -] - -[[package]] -name = "bit-set" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" -dependencies = [ - "bit-vec", -] - -[[package]] -name = "bit-vec" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" - -[[package]] -name = 
"bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - -[[package]] -name = "bitflags" -version = "2.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" -dependencies = [ - "serde_core", -] - -[[package]] -name = "block-buffer" -version = "0.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" -dependencies = [ - "generic-array", -] - -[[package]] -name = "blowfish" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e412e2cd0f2b2d93e02543ceae7917b3c70331573df19ee046bcbc35e45e87d7" -dependencies = [ - "byteorder", - "cipher", -] - -[[package]] -name = "bstr" -version = "1.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63044e1ae8e69f3b5a92c736ca6269b8d12fa7efe39bf34ddb06d102cf0e2cab" -dependencies = [ - "memchr", - "regex-automata", - "serde", -] - -[[package]] -name = "bumpalo" -version = "3.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" - -[[package]] -name = "bytemuck" -version = "1.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" - -[[package]] -name = "byteorder" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" - -[[package]] -name = "byteorder-lite" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" - -[[package]] -name = "bytes" -version 
= "1.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" - -[[package]] -name = "cc" -version = "1.2.56" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" -dependencies = [ - "find-msvc-tools", - "shlex", -] - -[[package]] -name = "cfg-if" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" - -[[package]] -name = "cfg_aliases" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" - -[[package]] -name = "chrono" -version = "0.4.44" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" -dependencies = [ - "iana-time-zone", - "js-sys", - "num-traits", - "serde", - "wasm-bindgen", - "windows-link", -] - -[[package]] -name = "cipher" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" -dependencies = [ - "crypto-common", - "inout", -] - -[[package]] -name = "colored" -version = "3.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "faf9468729b8cbcea668e36183cb69d317348c2e08e994829fb56ebfdfbaac34" -dependencies = [ - "windows-sys 0.61.2", -] - -[[package]] -name = "compression-codecs" -version = "0.4.37" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb7b51a7d9c967fc26773061ba86150f19c50c0d65c887cb1fbe295fd16619b7" -dependencies = [ - "compression-core", - "flate2", - "memchr", -] - -[[package]] -name = "compression-core" -version = "0.4.31" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "75984efb6ed102a0d42db99afb6c1948f0380d1d91808d5529916e6c08b49d8d" - -[[package]] -name = "concurrent-queue" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" -dependencies = [ - "crossbeam-utils", -] - -[[package]] -name = "config" -version = "0.13.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23738e11972c7643e4ec947840fc463b6a571afcd3e735bdfce7d03c7a784aca" -dependencies = [ - "async-trait", - "json5", - "lazy_static", - "nom", - "pathdiff", - "ron", - "rust-ini", - "serde", - "serde_json", - "toml 0.5.11", - "yaml-rust", -] - -[[package]] -name = "console" -version = "0.15.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" -dependencies = [ - "encode_unicode", - "libc", - "once_cell", - "windows-sys 0.59.0", -] - -[[package]] -name = "const-oid" -version = "0.9.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" - -[[package]] -name = "core-foundation-sys" -version = "0.8.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" - -[[package]] -name = "cpufeatures" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" -dependencies = [ - "libc", -] - -[[package]] -name = "crc" -version = "3.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5eb8a2a1cd12ab0d987a5d5e825195d372001a4094a0376319d5a0ad71c1ba0d" -dependencies = [ - "crc-catalog", -] - -[[package]] -name = "crc-catalog" -version = "2.4.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" - -[[package]] -name = "crc32fast" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "crossbeam-queue" -version = "0.3.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" -dependencies = [ - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-utils" -version = "0.8.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" - -[[package]] -name = "crypto-common" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" -dependencies = [ - "generic-array", - "rand_core 0.6.4", - "typenum", -] - -[[package]] -name = "ctr" -version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" -dependencies = [ - "cipher", -] - -[[package]] -name = "dashmap" -version = "6.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" -dependencies = [ - "cfg-if", - "crossbeam-utils", - "hashbrown 0.14.5", - "lock_api", - "once_cell", - "parking_lot_core", -] - -[[package]] -name = "data-encoding" -version = "2.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea" - -[[package]] -name = "der" -version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" -dependencies = [ - "const-oid", - "pem-rfc7468", - "zeroize", -] - -[[package]] -name = "difflib" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6184e33543162437515c2e2b48714794e37845ec9851711914eec9d308f6ebe8" - -[[package]] -name = "digest" -version = "0.10.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" -dependencies = [ - "block-buffer", - "const-oid", - "crypto-common", - "subtle", -] - -[[package]] -name = "displaydoc" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "dlv-list" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0688c2a7f92e427f44895cd63841bff7b29f8d7a1648b9e7e07a4a365b2e1257" - -[[package]] -name = "dotenvy" -version = "0.15.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" - -[[package]] -name = "either" -version = "1.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" -dependencies = [ - "serde", -] - -[[package]] -name = "encode_unicode" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" - -[[package]] -name = "equivalent" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" - -[[package]] -name = "errno" -version = "0.3.14" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" -dependencies = [ - "libc", - "windows-sys 0.61.2", -] - -[[package]] -name = "etcetera" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943" -dependencies = [ - "cfg-if", - "home", - "windows-sys 0.48.0", -] - -[[package]] -name = "event-listener" -version = "5.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13b66accf52311f30a0db42147dadea9850cb48cd070028831ae5f5d4b856ab" -dependencies = [ - "concurrent-queue", - "parking", - "pin-project-lite", -] - -[[package]] -name = "eventsource-stream" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" -dependencies = [ - "futures-core", - "nom", - "pin-project-lite", -] - -[[package]] -name = "fancy-regex" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "531e46835a22af56d1e3b66f04844bed63158bc094a628bec1d321d9b4c44bf2" -dependencies = [ - "bit-set", - "regex-automata", - "regex-syntax", -] - -[[package]] -name = "fastrand" -version = "2.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" - -[[package]] -name = "fdeflate" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c" -dependencies = [ - "simd-adler32", -] - -[[package]] -name = "find-msvc-tools" -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" - -[[package]] -name = "flate2" -version = "1.1.9" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" -dependencies = [ - "crc32fast", - "miniz_oxide", -] - -[[package]] -name = "flume" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" -dependencies = [ - "futures-core", - "futures-sink", - "spin", -] - -[[package]] -name = "fnv" -version = "1.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" - -[[package]] -name = "foldhash" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" - -[[package]] -name = "form_urlencoded" -version = "1.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" -dependencies = [ - "percent-encoding", -] - -[[package]] -name = "futures" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" -dependencies = [ - "futures-channel", - "futures-core", - "futures-executor", - "futures-io", - "futures-sink", - "futures-task", - "futures-util", -] - -[[package]] -name = "futures-channel" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" -dependencies = [ - "futures-core", - "futures-sink", -] - -[[package]] -name = "futures-core" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" - -[[package]] -name = "futures-executor" -version = "0.3.32" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" -dependencies = [ - "futures-core", - "futures-task", - "futures-util", -] - -[[package]] -name = "futures-intrusive" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" -dependencies = [ - "futures-core", - "lock_api", - "parking_lot", -] - -[[package]] -name = "futures-io" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" - -[[package]] -name = "futures-macro" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "futures-sink" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" - -[[package]] -name = "futures-task" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" - -[[package]] -name = "futures-timer" -version = "3.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" - -[[package]] -name = "futures-util" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" -dependencies = [ - "futures-channel", - "futures-core", - "futures-io", - "futures-macro", - "futures-sink", - "futures-task", - "memchr", - "pin-project-lite", - "slab", -] - -[[package]] -name = "generic-array" -version = "0.14.7" 
-source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" -dependencies = [ - "typenum", - "version_check", -] - -[[package]] -name = "getrandom" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" -dependencies = [ - "cfg-if", - "js-sys", - "libc", - "wasi", - "wasm-bindgen", -] - -[[package]] -name = "getrandom" -version = "0.3.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" -dependencies = [ - "cfg-if", - "js-sys", - "libc", - "r-efi", - "wasip2", - "wasm-bindgen", -] - -[[package]] -name = "getrandom" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "139ef39800118c7683f2fd3c98c1b23c09ae076556b435f8e9064ae108aaeeec" -dependencies = [ - "cfg-if", - "libc", - "r-efi", - "wasip2", - "wasip3", -] - -[[package]] -name = "ghash" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" -dependencies = [ - "opaque-debug", - "polyval", -] - -[[package]] -name = "governor" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0746aa765db78b521451ef74221663b57ba595bf83f75d0ce23cc09447c8139f" -dependencies = [ - "cfg-if", - "dashmap", - "futures-sink", - "futures-timer", - "futures-util", - "no-std-compat", - "nonzero_ext", - "parking_lot", - "portable-atomic", - "quanta", - "rand 0.8.5", - "smallvec", - "spinning_top", -] - -[[package]] -name = "h2" -version = "0.4.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" -dependencies = [ - "atomic-waker", - "bytes", - "fnv", - "futures-core", 
- "futures-sink", - "http", - "indexmap", - "slab", - "tokio", - "tokio-util", - "tracing", -] - -[[package]] -name = "hashbrown" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" -dependencies = [ - "ahash", -] - -[[package]] -name = "hashbrown" -version = "0.14.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" - -[[package]] -name = "hashbrown" -version = "0.15.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" -dependencies = [ - "allocator-api2", - "equivalent", - "foldhash", -] - -[[package]] -name = "hashbrown" -version = "0.16.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" - -[[package]] -name = "hashlink" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" -dependencies = [ - "hashbrown 0.15.5", -] - -[[package]] -name = "headers" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3314d5adb5d94bcdf56771f2e50dbbc80bb4bdf88967526706205ac9eff24eb" -dependencies = [ - "base64 0.22.1", - "bytes", - "headers-core", - "http", - "httpdate", - "mime", - "sha1", -] - -[[package]] -name = "headers-core" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" -dependencies = [ - "http", -] - -[[package]] -name = "heck" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" - -[[package]] -name = 
"hex" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" - -[[package]] -name = "hkdf" -version = "0.12.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" -dependencies = [ - "hmac", -] - -[[package]] -name = "hmac" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" -dependencies = [ - "digest", -] - -[[package]] -name = "home" -version = "0.5.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d" -dependencies = [ - "windows-sys 0.61.2", -] - -[[package]] -name = "http" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" -dependencies = [ - "bytes", - "itoa", -] - -[[package]] -name = "http-body" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" -dependencies = [ - "bytes", - "http", -] - -[[package]] -name = "http-body-util" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" -dependencies = [ - "bytes", - "futures-core", - "http", - "http-body", - "pin-project-lite", -] - -[[package]] -name = "http-range-header" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9171a2ea8a68358193d15dd5d70c1c10a2afc3e7e4c5bc92bc9f025cebd7359c" - -[[package]] -name = "httparse" -version = "1.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" - -[[package]] -name = "httpdate" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" - -[[package]] -name = "hyper" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" -dependencies = [ - "atomic-waker", - "bytes", - "futures-channel", - "futures-core", - "h2", - "http", - "http-body", - "httparse", - "httpdate", - "itoa", - "pin-project-lite", - "pin-utils", - "smallvec", - "tokio", - "want", -] - -[[package]] -name = "hyper-rustls" -version = "0.27.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" -dependencies = [ - "http", - "hyper", - "hyper-util", - "rustls", - "rustls-pki-types", - "tokio", - "tokio-rustls", - "tower-service", - "webpki-roots", -] - -[[package]] -name = "hyper-util" -version = "0.1.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" -dependencies = [ - "base64 0.22.1", - "bytes", - "futures-channel", - "futures-util", - "http", - "http-body", - "hyper", - "ipnet", - "libc", - "percent-encoding", - "pin-project-lite", - "socket2", - "tokio", - "tower-service", - "tracing", -] - -[[package]] -name = "iana-time-zone" -version = "0.1.65" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" -dependencies = [ - "android_system_properties", - "core-foundation-sys", - "iana-time-zone-haiku", - "js-sys", - "log", - "wasm-bindgen", - "windows-core", -] - -[[package]] -name = "iana-time-zone-haiku" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" -dependencies = [ - "cc", -] - -[[package]] -name = "icu_collections" -version = "2.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" -dependencies = [ - "displaydoc", - "potential_utf", - "yoke", - "zerofrom", - "zerovec", -] - -[[package]] -name = "icu_locale_core" -version = "2.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6" -dependencies = [ - "displaydoc", - "litemap", - "tinystr", - "writeable", - "zerovec", -] - -[[package]] -name = "icu_normalizer" -version = "2.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599" -dependencies = [ - "icu_collections", - "icu_normalizer_data", - "icu_properties", - "icu_provider", - "smallvec", - "zerovec", -] - -[[package]] -name = "icu_normalizer_data" -version = "2.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" - -[[package]] -name = "icu_properties" -version = "2.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "020bfc02fe870ec3a66d93e677ccca0562506e5872c650f893269e08615d74ec" -dependencies = [ - "icu_collections", - "icu_locale_core", - "icu_properties_data", - "icu_provider", - "zerotrie", - "zerovec", -] - -[[package]] -name = "icu_properties_data" -version = "2.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "616c294cf8d725c6afcd8f55abc17c56464ef6211f9ed59cccffe534129c77af" - -[[package]] -name = "icu_provider" -version = "2.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614" 
-dependencies = [ - "displaydoc", - "icu_locale_core", - "writeable", - "yoke", - "zerofrom", - "zerotrie", - "zerovec", -] - -[[package]] -name = "id-arena" -version = "2.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" - -[[package]] -name = "idna" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" -dependencies = [ - "idna_adapter", - "smallvec", - "utf8_iter", -] - -[[package]] -name = "idna_adapter" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" -dependencies = [ - "icu_normalizer", - "icu_properties", -] - -[[package]] -name = "image" -version = "0.25.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6506c6c10786659413faa717ceebcb8f70731c0a60cbae39795fdf114519c1a" -dependencies = [ - "bytemuck", - "byteorder-lite", - "image-webp", - "moxcms", - "num-traits", - "png", - "zune-core", - "zune-jpeg", -] - -[[package]] -name = "image-webp" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525e9ff3e1a4be2fbea1fdf0e98686a6d98b4d8f937e1bf7402245af1909e8c3" -dependencies = [ - "byteorder-lite", - "quick-error", -] - -[[package]] -name = "indexmap" -version = "2.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" -dependencies = [ - "equivalent", - "hashbrown 0.16.1", - "serde", - "serde_core", -] - -[[package]] -name = "inout" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" -dependencies = [ - "generic-array", -] - -[[package]] -name = "insta" -version 
= "1.46.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e82db8c87c7f1ccecb34ce0c24399b8a73081427f3c7c50a5d597925356115e4" -dependencies = [ - "console", - "once_cell", - "similar", - "tempfile", -] - -[[package]] -name = "ipnet" -version = "2.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" - -[[package]] -name = "iri-string" -version = "0.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" -dependencies = [ - "memchr", - "serde", -] - -[[package]] -name = "itoa" -version = "1.0.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" - -[[package]] -name = "js-sys" -version = "0.3.90" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14dc6f6450b3f6d4ed5b16327f38fed626d375a886159ca555bd7822c0c3a5a6" -dependencies = [ - "once_cell", - "wasm-bindgen", -] - -[[package]] -name = "json5" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96b0db21af676c1ce64250b5f40f3ce2cf27e4e47cb91ed91eb6fe9350b430c1" -dependencies = [ - "pest", - "pest_derive", - "serde", -] - -[[package]] -name = "lazy_static" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" -dependencies = [ - "spin", -] - -[[package]] -name = "leb128fmt" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" - -[[package]] -name = "libc" -version = "0.2.182" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" 
- -[[package]] -name = "libm" -version = "0.2.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" - -[[package]] -name = "libredox" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616" -dependencies = [ - "bitflags 2.11.0", - "libc", - "redox_syscall 0.7.2", -] - -[[package]] -name = "libsqlite3-sys" -version = "0.30.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" -dependencies = [ - "cc", - "pkg-config", - "vcpkg", -] - -[[package]] -name = "linked-hash-map" -version = "0.5.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" - -[[package]] -name = "linux-raw-sys" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" - -[[package]] -name = "litemap" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" - -[[package]] -name = "llm-proxy" -version = "0.1.0" -dependencies = [ - "aes-gcm", - "anyhow", - "assert_cmd", - "async-stream", - "async-trait", - "axum", - "axum-extra", - "base64 0.21.7", - "bcrypt", - "chrono", - "config", - "dotenvy", - "futures", - "governor", - "headers", - "hex", - "hmac", - "image", - "insta", - "mime", - "mockito", - "rand 0.9.2", - "reqwest", - "reqwest-eventsource", - "serde", - "serde_json", - "sha2", - "sqlx", - "tempfile", - "thiserror 1.0.69", - "tiktoken-rs", - "tokio", - "tokio-test", - "toml 0.8.23", - "tower", - "tower-http", - "tracing", - "tracing-subscriber", - "uuid", -] - -[[package]] -name = "lock_api" 
-version = "0.4.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" -dependencies = [ - "scopeguard", -] - -[[package]] -name = "log" -version = "0.4.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" - -[[package]] -name = "lru-slab" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" - -[[package]] -name = "matchers" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" -dependencies = [ - "regex-automata", -] - -[[package]] -name = "matchit" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" - -[[package]] -name = "md-5" -version = "0.10.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" -dependencies = [ - "cfg-if", - "digest", -] - -[[package]] -name = "memchr" -version = "2.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" - -[[package]] -name = "mime" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" - -[[package]] -name = "mime_guess" -version = "2.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" -dependencies = [ - "mime", - "unicase", -] - -[[package]] -name = "minimal-lexical" -version = "0.2.1" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" - -[[package]] -name = "miniz_oxide" -version = "0.8.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" -dependencies = [ - "adler2", - "simd-adler32", -] - -[[package]] -name = "mio" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" -dependencies = [ - "libc", - "wasi", - "windows-sys 0.61.2", -] - -[[package]] -name = "mockito" -version = "1.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90820618712cab19cfc46b274c6c22546a82affcb3c3bdf0f29e3db8e1bb92c0" -dependencies = [ - "assert-json-diff", - "bytes", - "colored", - "futures-core", - "http", - "http-body", - "http-body-util", - "hyper", - "hyper-util", - "log", - "pin-project-lite", - "rand 0.9.2", - "regex", - "serde_json", - "serde_urlencoded", - "similar", - "tokio", -] - -[[package]] -name = "moxcms" -version = "0.7.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac9557c559cd6fc9867e122e20d2cbefc9ca29d80d027a8e39310920ed2f0a97" -dependencies = [ - "num-traits", - "pxfm", -] - -[[package]] -name = "no-std-compat" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" - -[[package]] -name = "nom" -version = "7.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" -dependencies = [ - "memchr", - "minimal-lexical", -] - -[[package]] -name = "nonzero_ext" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" - -[[package]] -name = "nu-ansi-term" -version = "0.50.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" -dependencies = [ - "windows-sys 0.61.2", -] - -[[package]] -name = "num-bigint-dig" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e661dda6640fad38e827a6d4a310ff4763082116fe217f279885c97f511bb0b7" -dependencies = [ - "lazy_static", - "libm", - "num-integer", - "num-iter", - "num-traits", - "rand 0.8.5", - "smallvec", - "zeroize", -] - -[[package]] -name = "num-integer" -version = "0.1.46" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" -dependencies = [ - "num-traits", -] - -[[package]] -name = "num-iter" -version = "0.1.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" -dependencies = [ - "autocfg", - "num-integer", - "num-traits", -] - -[[package]] -name = "num-traits" -version = "0.2.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" -dependencies = [ - "autocfg", - "libm", -] - -[[package]] -name = "once_cell" -version = "1.21.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" - -[[package]] -name = "opaque-debug" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" - -[[package]] -name = "ordered-multimap" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"ccd746e37177e1711c20dd619a1620f34f5c8b569c53590a72dedd5344d8924a" -dependencies = [ - "dlv-list", - "hashbrown 0.12.3", -] - -[[package]] -name = "parking" -version = "2.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" - -[[package]] -name = "parking_lot" -version = "0.12.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall 0.5.18", - "smallvec", - "windows-link", -] - -[[package]] -name = "pathdiff" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" - -[[package]] -name = "pem-rfc7468" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" -dependencies = [ - "base64ct", -] - -[[package]] -name = "percent-encoding" -version = "2.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" - -[[package]] -name = "pest" -version = "2.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0848c601009d37dfa3430c4666e147e49cdcf1b92ecd3e63657d8a5f19da662" -dependencies = [ - "memchr", - "ucd-trie", -] - -[[package]] -name = "pest_derive" -version = "2.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11f486f1ea21e6c10ed15d5a7c77165d0ee443402f0780849d1768e7d9d6fe77" -dependencies = [ - "pest", - 
"pest_generator", -] - -[[package]] -name = "pest_generator" -version = "2.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8040c4647b13b210a963c1ed407c1ff4fdfa01c31d6d2a098218702e6664f94f" -dependencies = [ - "pest", - "pest_meta", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "pest_meta" -version = "2.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89815c69d36021a140146f26659a81d6c2afa33d216d736dd4be5381a7362220" -dependencies = [ - "pest", - "sha2", -] - -[[package]] -name = "pin-project-lite" -version = "0.2.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" - -[[package]] -name = "pin-utils" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" - -[[package]] -name = "pkcs1" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" -dependencies = [ - "der", - "pkcs8", - "spki", -] - -[[package]] -name = "pkcs8" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" -dependencies = [ - "der", - "spki", -] - -[[package]] -name = "pkg-config" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" - -[[package]] -name = "png" -version = "0.18.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60769b8b31b2a9f263dae2776c37b1b28ae246943cf719eb6946a1db05128a61" -dependencies = [ - "bitflags 2.11.0", - "crc32fast", - "fdeflate", - "flate2", - "miniz_oxide", -] - -[[package]] -name = "polyval" -version = "0.6.2" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" -dependencies = [ - "cfg-if", - "cpufeatures", - "opaque-debug", - "universal-hash", -] - -[[package]] -name = "portable-atomic" -version = "1.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" - -[[package]] -name = "potential_utf" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" -dependencies = [ - "zerovec", -] - -[[package]] -name = "ppv-lite86" -version = "0.2.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" -dependencies = [ - "zerocopy", -] - -[[package]] -name = "predicates" -version = "3.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ada8f2932f28a27ee7b70dd6c1c39ea0675c55a36879ab92f3a715eaa1e63cfe" -dependencies = [ - "anstyle", - "difflib", - "predicates-core", -] - -[[package]] -name = "predicates-core" -version = "1.0.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cad38746f3166b4031b1a0d39ad9f954dd291e7854fcc0eed52ee41a0b50d144" - -[[package]] -name = "predicates-tree" -version = "1.0.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0de1b847b39c8131db0467e9df1ff60e6d0562ab8e9a16e568ad0fdb372e2f2" -dependencies = [ - "predicates-core", - "termtree", -] - -[[package]] -name = "prettyplease" -version = "0.2.37" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" -dependencies = [ - "proc-macro2", - "syn", -] - -[[package]] -name = "proc-macro2" -version = "1.0.106" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "pxfm" -version = "0.1.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7186d3822593aa4393561d186d1393b3923e9d6163d3fbfd6e825e3e6cf3e6a8" -dependencies = [ - "num-traits", -] - -[[package]] -name = "quanta" -version = "0.12.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" -dependencies = [ - "crossbeam-utils", - "libc", - "once_cell", - "raw-cpuid", - "wasi", - "web-sys", - "winapi", -] - -[[package]] -name = "quick-error" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" - -[[package]] -name = "quinn" -version = "0.11.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" -dependencies = [ - "bytes", - "cfg_aliases", - "pin-project-lite", - "quinn-proto", - "quinn-udp", - "rustc-hash 2.1.1", - "rustls", - "socket2", - "thiserror 2.0.18", - "tokio", - "tracing", - "web-time", -] - -[[package]] -name = "quinn-proto" -version = "0.11.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" -dependencies = [ - "bytes", - "getrandom 0.3.4", - "lru-slab", - "rand 0.9.2", - "ring", - "rustc-hash 2.1.1", - "rustls", - "rustls-pki-types", - "slab", - "thiserror 2.0.18", - "tinyvec", - "tracing", - "web-time", -] - -[[package]] -name = "quinn-udp" -version = "0.5.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" -dependencies = [ - "cfg_aliases", - "libc", - 
"once_cell", - "socket2", - "tracing", - "windows-sys 0.60.2", -] - -[[package]] -name = "quote" -version = "1.0.44" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "r-efi" -version = "5.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" - -[[package]] -name = "rand" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" -dependencies = [ - "libc", - "rand_chacha 0.3.1", - "rand_core 0.6.4", -] - -[[package]] -name = "rand" -version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" -dependencies = [ - "rand_chacha 0.9.0", - "rand_core 0.9.5", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core 0.6.4", -] - -[[package]] -name = "rand_chacha" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" -dependencies = [ - "ppv-lite86", - "rand_core 0.9.5", -] - -[[package]] -name = "rand_core" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" -dependencies = [ - "getrandom 0.2.17", -] - -[[package]] -name = "rand_core" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" -dependencies = [ 
- "getrandom 0.3.4", -] - -[[package]] -name = "raw-cpuid" -version = "11.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" -dependencies = [ - "bitflags 2.11.0", -] - -[[package]] -name = "redox_syscall" -version = "0.5.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" -dependencies = [ - "bitflags 2.11.0", -] - -[[package]] -name = "redox_syscall" -version = "0.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d94dd2f7cd932d4dc02cc8b2b50dfd38bd079a4e5d79198b99743d7fcf9a4b4" -dependencies = [ - "bitflags 2.11.0", -] - -[[package]] -name = "regex" -version = "1.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" -dependencies = [ - "aho-corasick", - "memchr", - "regex-automata", - "regex-syntax", -] - -[[package]] -name = "regex-automata" -version = "0.4.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" -dependencies = [ - "aho-corasick", - "memchr", - "regex-syntax", -] - -[[package]] -name = "regex-syntax" -version = "0.8.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" - -[[package]] -name = "reqwest" -version = "0.12.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" -dependencies = [ - "base64 0.22.1", - "bytes", - "futures-core", - "futures-util", - "http", - "http-body", - "http-body-util", - "hyper", - "hyper-rustls", - "hyper-util", - "js-sys", - "log", - "percent-encoding", - "pin-project-lite", - "quinn", - "rustls", - 
"rustls-pki-types", - "serde", - "serde_json", - "serde_urlencoded", - "sync_wrapper", - "tokio", - "tokio-rustls", - "tokio-util", - "tower", - "tower-http", - "tower-service", - "url", - "wasm-bindgen", - "wasm-bindgen-futures", - "wasm-streams", - "web-sys", - "webpki-roots", -] - -[[package]] -name = "reqwest-eventsource" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde" -dependencies = [ - "eventsource-stream", - "futures-core", - "futures-timer", - "mime", - "nom", - "pin-project-lite", - "reqwest", - "thiserror 1.0.69", -] - -[[package]] -name = "ring" -version = "0.17.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" -dependencies = [ - "cc", - "cfg-if", - "getrandom 0.2.17", - "libc", - "untrusted", - "windows-sys 0.52.0", -] - -[[package]] -name = "ron" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88073939a61e5b7680558e6be56b419e208420c2adb92be54921fa6b72283f1a" -dependencies = [ - "base64 0.13.1", - "bitflags 1.3.2", - "serde", -] - -[[package]] -name = "rsa" -version = "0.9.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8573f03f5883dcaebdfcf4725caa1ecb9c15b2ef50c43a07b816e06799bb12d" -dependencies = [ - "const-oid", - "digest", - "num-bigint-dig", - "num-integer", - "num-traits", - "pkcs1", - "pkcs8", - "rand_core 0.6.4", - "signature", - "spki", - "subtle", - "zeroize", -] - -[[package]] -name = "rust-ini" -version = "0.18.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6d5f2436026b4f6e79dc829837d467cc7e9a55ee40e750d716713540715a2df" -dependencies = [ - "cfg-if", - "ordered-multimap", -] - -[[package]] -name = "rustc-hash" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" - -[[package]] -name = "rustc-hash" -version = "2.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" - -[[package]] -name = "rustix" -version = "1.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" -dependencies = [ - "bitflags 2.11.0", - "errno", - "libc", - "linux-raw-sys", - "windows-sys 0.61.2", -] - -[[package]] -name = "rustls" -version = "0.23.37" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" -dependencies = [ - "once_cell", - "ring", - "rustls-pki-types", - "rustls-webpki", - "subtle", - "zeroize", -] - -[[package]] -name = "rustls-pki-types" -version = "1.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" -dependencies = [ - "web-time", - "zeroize", -] - -[[package]] -name = "rustls-webpki" -version = "0.103.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" -dependencies = [ - "ring", - "rustls-pki-types", - "untrusted", -] - -[[package]] -name = "rustversion" -version = "1.0.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" - -[[package]] -name = "ryu" -version = "1.0.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" - -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - -[[package]] -name = "semver" -version = "1.0.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" - -[[package]] -name = "serde" -version = "1.0.228" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" -dependencies = [ - "serde_core", - "serde_derive", -] - -[[package]] -name = "serde_core" -version = "1.0.228" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" -dependencies = [ - "serde_derive", -] - -[[package]] -name = "serde_derive" -version = "1.0.228" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "serde_json" -version = "1.0.149" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" -dependencies = [ - "itoa", - "memchr", - "serde", - "serde_core", - "zmij", -] - -[[package]] -name = "serde_path_to_error" -version = "0.1.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" -dependencies = [ - "itoa", - "serde", - "serde_core", -] - -[[package]] -name = "serde_spanned" -version = "0.6.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3" -dependencies = [ - "serde", -] - -[[package]] -name = "serde_urlencoded" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" -dependencies = [ - "form_urlencoded", - "itoa", - "ryu", - "serde", -] - -[[package]] -name = "sha1" -version = "0.10.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - -[[package]] -name = "sha2" -version = "0.10.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - -[[package]] -name = "sharded-slab" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" -dependencies = [ - "lazy_static", -] - -[[package]] -name = "shlex" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" - -[[package]] -name = "signal-hook-registry" -version = "1.4.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" -dependencies = [ - "errno", - "libc", -] - -[[package]] -name = "signature" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" -dependencies = [ - "digest", - "rand_core 0.6.4", -] - -[[package]] -name = "simd-adler32" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" - -[[package]] -name = "similar" -version = "2.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" - 
-[[package]] -name = "slab" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" - -[[package]] -name = "smallvec" -version = "1.15.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" -dependencies = [ - "serde", -] - -[[package]] -name = "socket2" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86f4aa3ad99f2088c990dfa82d367e19cb29268ed67c574d10d0a4bfe71f07e0" -dependencies = [ - "libc", - "windows-sys 0.60.2", -] - -[[package]] -name = "spin" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" -dependencies = [ - "lock_api", -] - -[[package]] -name = "spinning_top" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" -dependencies = [ - "lock_api", -] - -[[package]] -name = "spki" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" -dependencies = [ - "base64ct", - "der", -] - -[[package]] -name = "sqlx" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fefb893899429669dcdd979aff487bd78f4064e5e7907e4269081e0ef7d97dc" -dependencies = [ - "sqlx-core", - "sqlx-macros", - "sqlx-mysql", - "sqlx-postgres", - "sqlx-sqlite", -] - -[[package]] -name = "sqlx-core" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee6798b1838b6a0f69c007c133b8df5866302197e404e8b6ee8ed3e3a5e68dc6" -dependencies = [ - "base64 0.22.1", - "bytes", - "chrono", - "crc", - "crossbeam-queue", - "either", - "event-listener", 
- "futures-core", - "futures-intrusive", - "futures-io", - "futures-util", - "hashbrown 0.15.5", - "hashlink", - "indexmap", - "log", - "memchr", - "once_cell", - "percent-encoding", - "serde", - "serde_json", - "sha2", - "smallvec", - "thiserror 2.0.18", - "tokio", - "tokio-stream", - "tracing", - "url", -] - -[[package]] -name = "sqlx-macros" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2d452988ccaacfbf5e0bdbc348fb91d7c8af5bee192173ac3636b5fb6e6715d" -dependencies = [ - "proc-macro2", - "quote", - "sqlx-core", - "sqlx-macros-core", - "syn", -] - -[[package]] -name = "sqlx-macros-core" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19a9c1841124ac5a61741f96e1d9e2ec77424bf323962dd894bdb93f37d5219b" -dependencies = [ - "dotenvy", - "either", - "heck", - "hex", - "once_cell", - "proc-macro2", - "quote", - "serde", - "serde_json", - "sha2", - "sqlx-core", - "sqlx-mysql", - "sqlx-postgres", - "sqlx-sqlite", - "syn", - "tokio", - "url", -] - -[[package]] -name = "sqlx-mysql" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa003f0038df784eb8fecbbac13affe3da23b45194bd57dba231c8f48199c526" -dependencies = [ - "atoi", - "base64 0.22.1", - "bitflags 2.11.0", - "byteorder", - "bytes", - "chrono", - "crc", - "digest", - "dotenvy", - "either", - "futures-channel", - "futures-core", - "futures-io", - "futures-util", - "generic-array", - "hex", - "hkdf", - "hmac", - "itoa", - "log", - "md-5", - "memchr", - "once_cell", - "percent-encoding", - "rand 0.8.5", - "rsa", - "serde", - "sha1", - "sha2", - "smallvec", - "sqlx-core", - "stringprep", - "thiserror 2.0.18", - "tracing", - "whoami", -] - -[[package]] -name = "sqlx-postgres" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db58fcd5a53cf07c184b154801ff91347e4c30d17a3562a635ff028ad5deda46" -dependencies = [ - "atoi", - "base64 
0.22.1", - "bitflags 2.11.0", - "byteorder", - "chrono", - "crc", - "dotenvy", - "etcetera", - "futures-channel", - "futures-core", - "futures-util", - "hex", - "hkdf", - "hmac", - "home", - "itoa", - "log", - "md-5", - "memchr", - "once_cell", - "rand 0.8.5", - "serde", - "serde_json", - "sha2", - "smallvec", - "sqlx-core", - "stringprep", - "thiserror 2.0.18", - "tracing", - "whoami", -] - -[[package]] -name = "sqlx-sqlite" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2d12fe70b2c1b4401038055f90f151b78208de1f9f89a7dbfd41587a10c3eea" -dependencies = [ - "atoi", - "chrono", - "flume", - "futures-channel", - "futures-core", - "futures-executor", - "futures-intrusive", - "futures-util", - "libsqlite3-sys", - "log", - "percent-encoding", - "serde", - "serde_urlencoded", - "sqlx-core", - "thiserror 2.0.18", - "tracing", - "url", -] - -[[package]] -name = "stable_deref_trait" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" - -[[package]] -name = "stringprep" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1" -dependencies = [ - "unicode-bidi", - "unicode-normalization", - "unicode-properties", -] - -[[package]] -name = "subtle" -version = "2.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" - -[[package]] -name = "syn" -version = "2.0.117" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "sync_wrapper" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" -dependencies = [ - "futures-core", -] - -[[package]] -name = "synstructure" -version = "0.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "tempfile" -version = "3.26.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82a72c767771b47409d2345987fda8628641887d5466101319899796367354a0" -dependencies = [ - "fastrand", - "getrandom 0.4.1", - "once_cell", - "rustix", - "windows-sys 0.61.2", -] - -[[package]] -name = "termtree" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" - -[[package]] -name = "thiserror" -version = "1.0.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" -dependencies = [ - "thiserror-impl 1.0.69", -] - -[[package]] -name = "thiserror" -version = "2.0.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" -dependencies = [ - "thiserror-impl 2.0.18", -] - -[[package]] -name = "thiserror-impl" -version = "1.0.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "thiserror-impl" -version = "2.0.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "thread_local" -version = "1.1.9" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "tiktoken-rs" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a19830747d9034cd9da43a60eaa8e552dfda7712424aebf187b7a60126bae0d" -dependencies = [ - "anyhow", - "base64 0.22.1", - "bstr", - "fancy-regex", - "lazy_static", - "regex", - "rustc-hash 1.1.0", -] - -[[package]] -name = "tinystr" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869" -dependencies = [ - "displaydoc", - "zerovec", -] - -[[package]] -name = "tinyvec" -version = "1.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" -dependencies = [ - "tinyvec_macros", -] - -[[package]] -name = "tinyvec_macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" - -[[package]] -name = "tokio" -version = "1.49.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72a2903cd7736441aac9df9d7688bd0ce48edccaadf181c3b90be801e81d3d86" -dependencies = [ - "bytes", - "libc", - "mio", - "parking_lot", - "pin-project-lite", - "signal-hook-registry", - "socket2", - "tokio-macros", - "windows-sys 0.61.2", -] - -[[package]] -name = "tokio-macros" -version = "2.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "tokio-rustls" -version = "0.26.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" -dependencies = [ - "rustls", - "tokio", -] - -[[package]] -name = "tokio-stream" -version = "0.1.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32da49809aab5c3bc678af03902d4ccddea2a87d028d86392a4b1560c6906c70" -dependencies = [ - "futures-core", - "pin-project-lite", - "tokio", -] - -[[package]] -name = "tokio-test" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f6d24790a10a7af737693a3e8f1d03faef7e6ca0cc99aae5066f533766de545" -dependencies = [ - "futures-core", - "tokio", - "tokio-stream", -] - -[[package]] -name = "tokio-tungstenite" -version = "0.28.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857" -dependencies = [ - "futures-util", - "log", - "tokio", - "tungstenite", -] - -[[package]] -name = "tokio-util" -version = "0.7.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" -dependencies = [ - "bytes", - "futures-core", - "futures-sink", - "pin-project-lite", - "tokio", -] - -[[package]] -name = "toml" -version = "0.5.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234" -dependencies = [ - "serde", -] - -[[package]] -name = "toml" -version = "0.8.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" -dependencies = [ - "serde", - "serde_spanned", - "toml_datetime", - "toml_edit", -] - -[[package]] -name = "toml_datetime" -version = "0.6.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" -dependencies = [ - "serde", -] - -[[package]] 
-name = "toml_edit" -version = "0.22.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" -dependencies = [ - "indexmap", - "serde", - "serde_spanned", - "toml_datetime", - "toml_write", - "winnow", -] - -[[package]] -name = "toml_write" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" - -[[package]] -name = "tower" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" -dependencies = [ - "futures-core", - "futures-util", - "pin-project-lite", - "sync_wrapper", - "tokio", - "tower-layer", - "tower-service", - "tracing", -] - -[[package]] -name = "tower-http" -version = "0.6.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" -dependencies = [ - "async-compression", - "bitflags 2.11.0", - "bytes", - "futures-core", - "futures-util", - "http", - "http-body", - "http-body-util", - "http-range-header", - "httpdate", - "iri-string", - "mime", - "mime_guess", - "percent-encoding", - "pin-project-lite", - "tokio", - "tokio-util", - "tower", - "tower-layer", - "tower-service", - "tracing", -] - -[[package]] -name = "tower-layer" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" - -[[package]] -name = "tower-service" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" - -[[package]] -name = "tracing" -version = "0.1.44" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" -dependencies = [ - "log", - "pin-project-lite", - "tracing-attributes", - "tracing-core", -] - -[[package]] -name = "tracing-attributes" -version = "0.1.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "tracing-core" -version = "0.1.36" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" -dependencies = [ - "once_cell", - "valuable", -] - -[[package]] -name = "tracing-log" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" -dependencies = [ - "log", - "once_cell", - "tracing-core", -] - -[[package]] -name = "tracing-serde" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "704b1aeb7be0d0a84fc9828cae51dab5970fee5088f83d1dd7ee6f6246fc6ff1" -dependencies = [ - "serde", - "tracing-core", -] - -[[package]] -name = "tracing-subscriber" -version = "0.3.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" -dependencies = [ - "matchers", - "nu-ansi-term", - "once_cell", - "regex-automata", - "serde", - "serde_json", - "sharded-slab", - "smallvec", - "thread_local", - "tracing", - "tracing-core", - "tracing-log", - "tracing-serde", -] - -[[package]] -name = "try-lock" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" - -[[package]] -name = "tungstenite" -version = "0.28.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"8628dcc84e5a09eb3d8423d6cb682965dea9133204e8fb3efee74c2a0c259442" -dependencies = [ - "bytes", - "data-encoding", - "http", - "httparse", - "log", - "rand 0.9.2", - "sha1", - "thiserror 2.0.18", - "utf-8", -] - -[[package]] -name = "typenum" -version = "1.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" - -[[package]] -name = "ucd-trie" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2896d95c02a80c6d6a5d6e953d479f5ddf2dfdb6a244441010e373ac0fb88971" - -[[package]] -name = "unicase" -version = "2.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" - -[[package]] -name = "unicode-bidi" -version = "0.3.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" - -[[package]] -name = "unicode-ident" -version = "1.0.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" - -[[package]] -name = "unicode-normalization" -version = "0.1.25" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fd4f6878c9cb28d874b009da9e8d183b5abc80117c40bbd187a1fde336be6e8" -dependencies = [ - "tinyvec", -] - -[[package]] -name = "unicode-properties" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7df058c713841ad818f1dc5d3fd88063241cc61f49f5fbea4b951e8cf5a8d71d" - -[[package]] -name = "unicode-xid" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" - -[[package]] -name = "universal-hash" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum 
= "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" -dependencies = [ - "crypto-common", - "subtle", -] - -[[package]] -name = "untrusted" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" - -[[package]] -name = "url" -version = "2.5.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" -dependencies = [ - "form_urlencoded", - "idna", - "percent-encoding", - "serde", -] - -[[package]] -name = "utf-8" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" - -[[package]] -name = "utf8_iter" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" - -[[package]] -name = "uuid" -version = "1.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b672338555252d43fd2240c714dc444b8c6fb0a5c5335e65a07bba7742735ddb" -dependencies = [ - "getrandom 0.4.1", - "js-sys", - "serde_core", - "wasm-bindgen", -] - -[[package]] -name = "valuable" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" - -[[package]] -name = "vcpkg" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" - -[[package]] -name = "version_check" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" - -[[package]] -name = "wait-timeout" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "09ac3b126d3914f9849036f826e054cbabdc8519970b8998ddaf3b5bd3c65f11" -dependencies = [ - "libc", -] - -[[package]] -name = "want" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" -dependencies = [ - "try-lock", -] - -[[package]] -name = "wasi" -version = "0.11.1+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" - -[[package]] -name = "wasip2" -version = "1.0.2+wasi-0.2.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" -dependencies = [ - "wit-bindgen", -] - -[[package]] -name = "wasip3" -version = "0.4.0+wasi-0.3.0-rc-2026-01-06" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" -dependencies = [ - "wit-bindgen", -] - -[[package]] -name = "wasite" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" - -[[package]] -name = "wasm-bindgen" -version = "0.2.113" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60722a937f594b7fde9adb894d7c092fc1bb6612897c46368d18e7a20208eff2" -dependencies = [ - "cfg-if", - "once_cell", - "rustversion", - "wasm-bindgen-macro", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-futures" -version = "0.4.63" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a89f4650b770e4521aa6573724e2aed4704372151bd0de9d16a3bbabb87441a" -dependencies = [ - "cfg-if", - "futures-util", - "js-sys", - "once_cell", - "wasm-bindgen", - "web-sys", -] - -[[package]] -name = "wasm-bindgen-macro" -version = "0.2.113" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fac8c6395094b6b91c4af293f4c79371c163f9a6f56184d2c9a85f5a95f3950" -dependencies = [ - "quote", - "wasm-bindgen-macro-support", -] - -[[package]] -name = "wasm-bindgen-macro-support" -version = "0.2.113" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab3fabce6159dc20728033842636887e4877688ae94382766e00b180abac9d60" -dependencies = [ - "bumpalo", - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-shared" -version = "0.2.113" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de0e091bdb824da87dc01d967388880d017a0a9bc4f3bdc0d86ee9f9336e3bb5" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "wasm-encoder" -version = "0.244.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" -dependencies = [ - "leb128fmt", - "wasmparser", -] - -[[package]] -name = "wasm-metadata" -version = "0.244.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" -dependencies = [ - "anyhow", - "indexmap", - "wasm-encoder", - "wasmparser", -] - -[[package]] -name = "wasm-streams" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" -dependencies = [ - "futures-util", - "js-sys", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", -] - -[[package]] -name = "wasmparser" -version = "0.244.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" -dependencies = [ - "bitflags 2.11.0", - "hashbrown 0.15.5", - "indexmap", - "semver", -] - -[[package]] -name = "web-sys" -version = "0.3.90" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "705eceb4ce901230f8625bd1d665128056ccbe4b7408faa625eec1ba80f59a97" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - -[[package]] -name = "web-time" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - -[[package]] -name = "webpki-roots" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22cfaf3c063993ff62e73cb4311efde4db1efb31ab78a3e5c457939ad5cc0bed" -dependencies = [ - "rustls-pki-types", -] - -[[package]] -name = "whoami" -version = "1.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d4a4db5077702ca3015d3d02d74974948aba2ad9e12ab7df718ee64ccd7e97d" -dependencies = [ - "libredox", - "wasite", -] - -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" - -[[package]] -name = "windows-core" -version = "0.62.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" -dependencies = [ - "windows-implement", - "windows-interface", - "windows-link", - "windows-result", - "windows-strings", -] - -[[package]] -name = 
"windows-implement" -version = "0.60.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "windows-interface" -version = "0.59.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "windows-link" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" - -[[package]] -name = "windows-result" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" -dependencies = [ - "windows-link", -] - -[[package]] -name = "windows-strings" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" -dependencies = [ - "windows-link", -] - -[[package]] -name = "windows-sys" -version = "0.48.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" -dependencies = [ - "windows-targets 0.48.5", -] - -[[package]] -name = "windows-sys" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" -dependencies = [ - "windows-targets 0.52.6", -] - -[[package]] -name = "windows-sys" -version = "0.59.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" -dependencies = [ - "windows-targets 0.52.6", -] - -[[package]] -name = "windows-sys" -version = 
"0.60.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" -dependencies = [ - "windows-targets 0.53.5", -] - -[[package]] -name = "windows-sys" -version = "0.61.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" -dependencies = [ - "windows-link", -] - -[[package]] -name = "windows-targets" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" -dependencies = [ - "windows_aarch64_gnullvm 0.48.5", - "windows_aarch64_msvc 0.48.5", - "windows_i686_gnu 0.48.5", - "windows_i686_msvc 0.48.5", - "windows_x86_64_gnu 0.48.5", - "windows_x86_64_gnullvm 0.48.5", - "windows_x86_64_msvc 0.48.5", -] - -[[package]] -name = "windows-targets" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" -dependencies = [ - "windows_aarch64_gnullvm 0.52.6", - "windows_aarch64_msvc 0.52.6", - "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm 0.52.6", - "windows_i686_msvc 0.52.6", - "windows_x86_64_gnu 0.52.6", - "windows_x86_64_gnullvm 0.52.6", - "windows_x86_64_msvc 0.52.6", -] - -[[package]] -name = "windows-targets" -version = "0.53.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" -dependencies = [ - "windows-link", - "windows_aarch64_gnullvm 0.53.1", - "windows_aarch64_msvc 0.53.1", - "windows_i686_gnu 0.53.1", - "windows_i686_gnullvm 0.53.1", - "windows_i686_msvc 0.53.1", - "windows_x86_64_gnu 0.53.1", - "windows_x86_64_gnullvm 0.53.1", - "windows_x86_64_msvc 0.53.1", -] - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.48.5" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" - -[[package]] -name = "windows_i686_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" - -[[package]] -name = "windows_i686_gnu" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" - -[[package]] -name = "windows_i686_gnu" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" - -[[package]] -name = "windows_i686_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" - -[[package]] -name = "windows_i686_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" - -[[package]] -name = "windows_i686_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" - -[[package]] -name = "windows_i686_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" - -[[package]] -name = "windows_i686_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" - -[[package]] -name = 
"windows_x86_64_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" - -[[package]] -name = "winnow" -version = "0.7.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829" -dependencies = [ - "memchr", -] - -[[package]] -name = "wit-bindgen" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" -dependencies = [ - "wit-bindgen-rust-macro", -] - -[[package]] -name = "wit-bindgen-core" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" -dependencies = [ - "anyhow", - "heck", - "wit-parser", -] - -[[package]] -name = "wit-bindgen-rust" -version = "0.51.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" -dependencies = [ - "anyhow", - "heck", - "indexmap", - "prettyplease", - "syn", - "wasm-metadata", - "wit-bindgen-core", - "wit-component", -] - -[[package]] -name = "wit-bindgen-rust-macro" -version = "0.51.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" -dependencies = [ - "anyhow", - "prettyplease", - "proc-macro2", - "quote", - "syn", - "wit-bindgen-core", - "wit-bindgen-rust", -] - -[[package]] -name = "wit-component" -version = "0.244.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" -dependencies = [ - "anyhow", - "bitflags 2.11.0", - "indexmap", - "log", - "serde", - "serde_derive", - "serde_json", - "wasm-encoder", - "wasm-metadata", - "wasmparser", - "wit-parser", -] - -[[package]] -name = "wit-parser" -version = "0.244.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" -dependencies = [ - "anyhow", - "id-arena", - "indexmap", - "log", - "semver", - "serde", - "serde_derive", - "serde_json", - "unicode-xid", - "wasmparser", -] - -[[package]] -name = "writeable" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" - -[[package]] -name = "yaml-rust" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56c1936c4cc7a1c9ab21a1ebb602eb942ba868cbd44a99cb7cdc5892335e1c85" -dependencies = [ - "linked-hash-map", -] - -[[package]] -name = "yoke" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" -dependencies = [ - "stable_deref_trait", - "yoke-derive", - "zerofrom", -] - -[[package]] -name = "yoke-derive" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" -dependencies = [ - "proc-macro2", - "quote", - "syn", - 
"synstructure", -] - -[[package]] -name = "zerocopy" -version = "0.8.39" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db6d35d663eadb6c932438e763b262fe1a70987f9ae936e60158176d710cae4a" -dependencies = [ - "zerocopy-derive", -] - -[[package]] -name = "zerocopy-derive" -version = "0.8.39" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4122cd3169e94605190e77839c9a40d40ed048d305bfdc146e7df40ab0f3e517" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "zerofrom" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" -dependencies = [ - "zerofrom-derive", -] - -[[package]] -name = "zerofrom-derive" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "synstructure", -] - -[[package]] -name = "zeroize" -version = "1.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" - -[[package]] -name = "zerotrie" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" -dependencies = [ - "displaydoc", - "yoke", - "zerofrom", -] - -[[package]] -name = "zerovec" -version = "0.11.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" -dependencies = [ - "yoke", - "zerofrom", - "zerovec-derive", -] - -[[package]] -name = "zerovec-derive" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" -dependencies = [ - 
"proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "zmij" -version = "1.0.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" - -[[package]] -name = "zune-core" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb8a0807f7c01457d0379ba880ba6322660448ddebc890ce29bb64da71fb40f9" - -[[package]] -name = "zune-jpeg" -version = "0.5.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "410e9ecef634c709e3831c2cfdb8d9c32164fae1c67496d5b68fff728eec37fe" -dependencies = [ - "zune-core", -] diff --git a/Cargo.toml b/Cargo.toml deleted file mode 100644 index f2d8f48b..00000000 --- a/Cargo.toml +++ /dev/null @@ -1,75 +0,0 @@ -[package] -name = "llm-proxy" -version = "0.1.0" -edition = "2024" -rust-version = "1.87" -description = "Unified LLM proxy gateway supporting OpenAI, Gemini, DeepSeek, and Grok with token tracking and cost calculation" -authors = ["newkirk"] -license = "MIT OR Apache-2.0" -repository = "" - -[dependencies] -# ========== Web Framework & Async Runtime ========== -axum = { version = "0.8", features = ["macros", "ws"] } -tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "net", "time", "signal", "fs"] } -tower = "0.5" -tower-http = { version = "0.6", features = ["trace", "cors", "compression-gzip", "fs", "set-header", "limit"] } -governor = "0.7" - -# ========== HTTP Clients ========== -reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] } -tiktoken-rs = "0.9" - -# ========== Database & ORM ========== -sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite", "macros", "migrate", "chrono"] } - -# ========== Authentication & Middleware ========== -axum-extra = { version = "0.12", features = ["typed-header"] } -headers = "0.4" - -# ========== Configuration Management ========== -config = "0.13" -dotenvy = 
"0.15" -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" -toml = "0.8" - -# ========== Logging & Monitoring ========== -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } - -# ========== Multimodal & Image Processing ========== -base64 = "0.21" -image = { version = "0.25", default-features = false, features = ["jpeg", "png", "webp"] } -mime = "0.3" - -# ========== Error Handling & Utilities ========== -anyhow = "1.0" -thiserror = "1.0" -bcrypt = "0.15" -aes-gcm = "0.10" -hmac = "0.12" -sha2 = "0.10" -chrono = { version = "0.4", features = ["serde"] } -uuid = { version = "1.0", features = ["v4", "serde"] } -futures = "0.3" -async-trait = "0.1" -async-stream = "0.3" -reqwest-eventsource = "0.6" -rand = "0.9" -hex = "0.4" - -[dev-dependencies] -tokio-test = "0.4" -mockito = "1.0" -tempfile = "3.10" -assert_cmd = "2.0" -insta = "1.39" -anyhow = "1.0" - -[profile.release] -opt-level = 3 -lto = true -codegen-units = 1 -strip = true -panic = "abort" diff --git a/Dockerfile b/Dockerfile index 8c349f9a..ad1fe879 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,35 +1,34 @@ -# ── Build stage ────────────────────────────────────────────── -FROM rust:1-bookworm AS builder +# Build stage +FROM golang:1.22-alpine AS builder WORKDIR /app -# Cache dependency build -COPY Cargo.toml Cargo.lock ./ -RUN mkdir src && echo 'fn main() {}' > src/main.rs && \ - cargo build --release && \ - rm -rf src +# Copy go mod and sum files +COPY go.mod go.sum ./ +RUN go mod download -# Build the actual binary -COPY src/ src/ -RUN touch src/main.rs && cargo build --release +# Copy the source code +COPY . . 
-# ── Runtime stage ──────────────────────────────────────────── -FROM debian:bookworm-slim +# Build the application +RUN CGO_ENABLED=0 GOOS=linux go build -o llm-proxy ./cmd/llm-proxy -RUN apt-get update && \ - apt-get install -y --no-install-recommends ca-certificates && \ - rm -rf /var/lib/apt/lists/* +# Final stage +FROM alpine:latest + +RUN apk --no-cache add ca-certificates tzdata WORKDIR /app -COPY --from=builder /app/target/release/llm-proxy /app/llm-proxy -COPY static/ /app/static/ +# Copy the binary from the builder stage +COPY --from=builder /app/llm-proxy . +COPY --from=builder /app/static ./static -# Default config location -VOLUME ["/app/config", "/app/data"] +# Create data directory +RUN mkdir -p /app/data +# Expose port EXPOSE 8080 -ENV RUST_LOG=info - -ENTRYPOINT ["/app/llm-proxy"] +# Run the application +CMD ["./llm-proxy"] diff --git a/README.md b/README.md index 5cff16fe..30fa4cb0 100644 --- a/README.md +++ b/README.md @@ -1,114 +1,108 @@ # LLM Proxy Gateway -A unified, high-performance LLM proxy gateway built in Rust. It provides a single OpenAI-compatible API to access multiple providers (OpenAI, Gemini, DeepSeek, Grok, Ollama) with built-in token tracking, real-time cost calculation, multi-user authentication, and a management dashboard. +A unified, high-performance LLM proxy gateway built in Go. It provides a single OpenAI-compatible API to access multiple providers (OpenAI, Gemini, DeepSeek, Grok, Ollama) with built-in token tracking, real-time cost calculation, multi-user authentication, and a management dashboard. ## Features - **Unified API:** OpenAI-compatible `/v1/chat/completions` and `/v1/models` endpoints. - **Multi-Provider Support:** - **OpenAI:** GPT-4o, GPT-4o Mini, o1, o3 reasoning models. - - **Google Gemini:** Gemini 2.0 Flash, Pro, and vision models. - - **DeepSeek:** DeepSeek Chat and Reasoner models. + - **Google Gemini:** Gemini 2.0 Flash, Pro, and vision models (with native CoT support). 
+ - **DeepSeek:** DeepSeek Chat and Reasoner (R1) models. - **xAI Grok:** Grok-beta models. - **Ollama:** Local LLMs running on your network. - **Observability & Tracking:** - - **Real-time Costing:** Fetches live pricing and context specs from `models.dev` on startup. - - **Token Counting:** Precise estimation using `tiktoken-rs`. - - **Database Logging:** Every request logged to SQLite for historical analysis. - - **Streaming Support:** Full SSE (Server-Sent Events) with `[DONE]` termination for client compatibility. + - **Asynchronous Logging:** Non-blocking request logging to SQLite using background workers. + - **Token Counting:** Precise estimation and tracking of prompt, completion, and reasoning tokens. + - **Database Persistence:** Every request logged to SQLite for historical analysis and dashboard analytics. + - **Streaming Support:** Full SSE (Server-Sent Events) support for all providers. - **Multimodal (Vision):** Image processing (Base64 and remote URLs) across compatible providers. - **Multi-User Access Control:** - **Admin Role:** Full access to all dashboard features, user management, and system configuration. - **Viewer Role:** Read-only access to usage analytics, costs, and monitoring. - **Client API Keys:** Create and manage multiple client tokens for external integrations. - **Reliability:** - - **Circuit Breaking:** Automatically protects when providers are down. - - **Rate Limiting:** Per-client and global rate limits. - - **Cache-Aware Costing:** Tracks cache hit/miss tokens for accurate billing. + - **Circuit Breaking:** Automatically protects when providers are down (coming soon). + - **Rate Limiting:** Per-client and global rate limits (coming soon). ## Security LLM Proxy is designed with security in mind: -- **HMAC Session Tokens:** Management dashboard sessions are secured using HMAC-SHA256 signed tokens. -- **Encrypted Provider Keys:** Sensitive LLM provider API keys are stored encrypted (AES-256-GCM) in the database. 
-- **Session Refresh:** Activity-based session extension prevents session hijacking while maintaining user convenience. -- **XSS Prevention:** Standardized frontend escaping using `window.api.escapeHtml`. +- **Signed Session Tokens:** Management dashboard sessions are secured using HMAC-SHA256 signed tokens. +- **Encrypted Storage:** Support for encrypted provider API keys in the database. +- **Auth Middleware:** Secure client authentication via database-backed API keys. -**Note:** You must define a `SESSION_SECRET` in your `.env` file for secure session signing. +**Note:** You must define an `LLM_PROXY__ENCRYPTION_KEY` in your `.env` file for secure session signing and encryption. ## Tech Stack -- **Runtime:** Rust with Tokio. -- **Web Framework:** Axum. -- **Database:** SQLx with SQLite. -- **Frontend:** Vanilla JS/CSS with Chart.js for visualizations. +- **Runtime:** Go 1.22+ +- **Web Framework:** Gin Gonic +- **Database:** sqlx with SQLite (CGO-free via `modernc.org/sqlite`) +- **Frontend:** Vanilla JS/CSS with Chart.js for visualizations ## Getting Started ### Prerequisites -- Rust (1.80+) -- SQLite3 +- Go (1.22+) +- SQLite3 (optional, driver is built-in) - Docker (optional, for containerized deployment) ### Quick Start 1. Clone and build: ```bash - git clone ssh://git.dustin.coffee:2222/hobokenchicken/llm-proxy.git + git clone cd llm-proxy - cargo build --release + go build -o llm-proxy ./cmd/llm-proxy ``` 2. Configure environment: ```bash cp .env.example .env - # Edit .env and add your API keys: - # SESSION_SECRET=... (Generate a strong random secret) + # Edit .env and add your configuration: + # LLM_PROXY__ENCRYPTION_KEY=... (32-byte hex or base64 string) # OPENAI_API_KEY=sk-... # GEMINI_API_KEY=AIza... ``` 3. Run the proxy: ```bash - cargo run --release + ./llm-proxy ``` -The server starts on `http://localhost:8080` by default. +The server starts on `http://0.0.0.0:8080` by default. 
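The 32-byte encryption key required in step 2 can be generated with standard tooling — a sketch assuming `openssl` is available (any CSPRNG source producing 32 bytes works):

```shell
# Generate a 32-byte key, hex-encoded (64 hex characters),
# and print it in the form expected by .env.
KEY=$(openssl rand -hex 32)
echo "LLM_PROXY__ENCRYPTION_KEY=$KEY"
```

Paste the printed line into your `.env` file; a base64-encoded 32-byte value (`openssl rand -base64 32`) should work as well, per the `.env.example` comment.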
### Deployment (Docker) -A multi-stage `Dockerfile` is provided for efficient deployment: - ```bash # Build the container docker build -t llm-proxy . # Run the container docker run -p 8080:8080 \ - -e SESSION_SECRET=your-secure-secret \ + -e LLM_PROXY__ENCRYPTION_KEY=your-secure-key \ -v ./data:/app/data \ llm-proxy ``` ## Management Dashboard -Access the dashboard at `http://localhost:8080`. The dashboard architecture has been refactored into modular sub-components for better maintainability: +Access the dashboard at `http://localhost:8080`. -- **Auth (`/api/auth`):** Login, session management, and password changes. -- **Usage (`/api/usage`):** Summary stats, time-series analytics, and provider breakdown. -- **Clients (`/api/clients`):** API key management and per-client usage tracking. -- **Providers (`/api/providers`):** Provider configuration, status monitoring, and connection testing. -- **System (`/api/system`):** Health metrics, live logs, database backups, and global settings. +- **Auth:** Login, session management, and status tracking. +- **Usage:** Summary stats, time-series analytics, and provider breakdown. +- **Clients:** API key management and per-client usage tracking. +- **Providers:** Provider configuration and status monitoring. +- **Users:** Admin-only user management for dashboard access. - **Monitoring:** Live request stream via WebSocket. ### Default Credentials - **Username:** `admin` -- **Password:** `admin123` - -Change the admin password in the dashboard after first login! 
+- **Password:** `admin` (change it immediately after first login, either at the prompt or manually in the dashboard) ## API Usage @@ -131,4 +125,4 @@ response = client.chat.completions.create( ## License -MIT OR Apache-2.0 +MIT diff --git a/RUST_BACKEND_REVIEW.md b/RUST_BACKEND_REVIEW.md deleted file mode 100644 index 544e7ac4..00000000 --- a/RUST_BACKEND_REVIEW.md +++ /dev/null @@ -1,58 +0,0 @@ -# LLM Proxy Rust Backend Code Review Report - -## Executive Summary -This code review examines the `llm-proxy` Rust backend, focusing on architectural soundness, performance characteristics, and adherence to Rust idioms. The codebase demonstrates solid engineering with well-structured modular design, comprehensive error handling, and thoughtful async patterns. However, several areas require attention for production readiness, particularly around thread safety, memory efficiency, and error recovery. - -## 1. Core Proxy Logic Review -### Strengths -- Clean provider abstraction (`Provider` trait). -- Streaming support with `AggregatingStream` for token counting. -- Model mapping and caching system. - -### Issues Found -- **Provider Manager Thread Safety Risk:** O(n) lookups using `Vec` with `RwLock`. Use `DashMap` instead. -- **Streaming Memory Inefficiency:** Accumulates complete response content in memory. -- **Model Registry Cache Invalidation:** No strategy when config changes via dashboard. - -## 2. State Management Review -- **Token Bucket Algorithm Flaw:** Custom implementation lacks thread-safe refill sync. Use `governor` crate. -- **Broadcast Channel Unbounded Growth Risk:** Fixed-size (100) may drop messages. -- **Database Connection Pool Contention:** SQLite connections shared without differentiation. - -## 3. Error Handling Review -- **Error Recovery Missing:** No circuit breaker or retry logic for provider calls. -- **Stream Error Logging Gap:** Stream errors swallowed without logging partial usage. - -## 4.
Rust Idioms and Performance -- **Unnecessary String Cloning:** Frequent cloning in authentication hot paths. -- **JSON Parsing Inefficiency:** Multiple passes with `serde_json::Value`. Use typed structs. -- **Missing `#[derive(Copy)]`:** For small enums like `CircuitState`. - -## 5. Async Performance -- **Blocking Calls:** Token estimation may block async runtime. -- **Missing Connection Timeouts:** Only overall timeout, no separate read/write timeouts. -- **Unbounded Task Spawn:** For client usage updates under load. - -## 6. Security Considerations -- **Token Leakage:** Redact tokens in `Debug` and `Display` impls. -- **No Request Size Limits:** Vulnerable to memory exhaustion. - -## 7. Testability -- **Mocking Difficulty:** Tight coupling to concrete provider implementations. -- **Missing Integration Tests:** No E2E tests for streaming. - -## 8. Summary of Critical Actions -### High Priority -1. Replace custom token bucket with `governor` crate. -2. Fix provider lookup O(n) scaling issue. -3. Implement proper error recovery with retries. -4. Add request size limits and timeout configurations. - -### Medium Priority -1. Reduce string cloning in hot paths. -2. Implement cache invalidation for model configs. -3. Add connection pooling separation. -4. Improve streaming memory efficiency. 
- ---- -*Review conducted by: Senior Principal Engineer (Code Reviewer)* diff --git a/TODO.md b/TODO.md new file mode 100644 index 00000000..5d608255 --- /dev/null +++ b/TODO.md @@ -0,0 +1,48 @@ +# Migration TODO List + +## Completed Tasks +- [x] Initial Go project setup +- [x] Database schema & migrations +- [x] Configuration loader (Viper) +- [x] Auth Middleware +- [x] Basic Provider implementations (OpenAI, Gemini, DeepSeek, Grok) +- [x] Streaming Support (SSE & Gemini custom streaming) +- [x] Move Rust files to `rust_backup` +- [x] Enhanced `helpers.go` for Multimodal & Tool Calling (OpenAI compatible) +- [x] Enhanced `server.go` for robust request conversion +- [x] Dashboard Management APIs (Clients, Tokens, Users, Providers) +- [x] Dashboard Analytics & Usage Summary +- [x] WebSocket for real-time dashboard updates + +## Feature Parity Checklist (High Priority) + +### OpenAI Provider +- [x] Tool Calling +- [x] Multimodal (Images) support +- [ ] Reasoning Content (CoT) support for `o1`, `o3` (need to ensure it's parsed in responses) +- [ ] Support for `/v1/responses` API (required for some gpt-5/o1 models) + +### Gemini Provider +- [x] Tool Calling (mapping to Gemini format) +- [x] Multimodal (Images) support +- [x] Reasoning/Thought support +- [x] Handle Tool Response role in unified format + +### DeepSeek Provider +- [x] Reasoning Content (CoT) support +- [x] Parameter sanitization for `deepseek-reasoner` +- [x] Tool Calling support + +### Grok Provider +- [x] Tool Calling support +- [x] Multimodal support + +## Infrastructure & Middleware +- [ ] Implement Request Logging to SQLite (asynchronous) +- [ ] Implement Rate Limiting (`golang.org/x/time/rate`) +- [ ] Implement Circuit Breaker (`github.com/sony/gobreaker`) +- [ ] Implement Model Cost Calculation logic + +## Verification +- [ ] Unit tests for feature-specific mapping (CoT, Tools, Images) +- [ ] Integration tests with live LLM APIs diff --git a/clippy.toml b/clippy.toml deleted file mode 100644 index 
ba34bf46..00000000 --- a/clippy.toml +++ /dev/null @@ -1 +0,0 @@ -too-many-arguments-threshold = 8 diff --git a/cmd/llm-proxy/main.go b/cmd/llm-proxy/main.go new file mode 100644 index 00000000..ace43654 --- /dev/null +++ b/cmd/llm-proxy/main.go @@ -0,0 +1,39 @@ +package main + +import ( + "log" + + "llm-proxy/internal/config" + "llm-proxy/internal/db" + "llm-proxy/internal/server" + + "github.com/joho/godotenv" +) + +func main() { + // Load environment variables + if err := godotenv.Load(); err != nil { + log.Println("No .env file found") + } + + // Load configuration + cfg, err := config.Load() + if err != nil { + log.Fatalf("Failed to load configuration: %v", err) + } + + // Initialize database + database, err := db.Init(cfg.Database.Path) + if err != nil { + log.Fatalf("Failed to initialize database: %v", err) + } + + // Initialize server + s := server.NewServer(cfg, database) + + // Run server + log.Printf("Starting LLM Proxy on %s:%d", cfg.Server.Host, cfg.Server.Port) + if err := s.Run(); err != nil { + log.Fatalf("Server failed: %v", err) + } +} diff --git a/data/llm_proxy.dbLLM_PROXY__ENCRYPTION_KEY=69879f5b7913ba169982190526ae213e830b3f1f33e785ef2b68cf48c7853fcd b/data/llm_proxy.dbLLM_PROXY__ENCRYPTION_KEY=69879f5b7913ba169982190526ae213e830b3f1f33e785ef2b68cf48c7853fcd deleted file mode 100644 index f9b058c5..00000000 Binary files a/data/llm_proxy.dbLLM_PROXY__ENCRYPTION_KEY=69879f5b7913ba169982190526ae213e830b3f1f33e785ef2b68cf48c7853fcd and /dev/null differ diff --git a/deployment.md b/deployment.md index e8a57711..976dcaf1 100644 --- a/deployment.md +++ b/deployment.md @@ -1,322 +1,52 @@ -# LLM Proxy Gateway - Deployment Guide +# Deployment Guide (Go) -## Overview -A unified LLM proxy gateway supporting OpenAI, Google Gemini, DeepSeek, and xAI Grok with token tracking, cost calculation, and admin dashboard. +This guide covers deploying the Go-based LLM Proxy Gateway. 
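For bare-metal installs, the systemd unit from the previous (Rust) guide can be adapted to the Go binary — a sketch assuming the binary and its `data/` directory live under `/opt/llm-proxy` and that a dedicated `llmproxy` user exists:

```ini
# /etc/systemd/system/llm-proxy.service
[Unit]
Description=LLM Proxy Gateway
After=network.target

[Service]
Type=simple
User=llmproxy
Group=llmproxy
WorkingDirectory=/opt/llm-proxy
# Loads LLM_PROXY__ENCRYPTION_KEY and provider API keys
EnvironmentFile=/opt/llm-proxy/.env
ExecStart=/opt/llm-proxy/llm-proxy
Restart=always
RestartSec=10

[Install]
WantedBy=multi-user.target
```

Enable it with `systemctl enable --now llm-proxy`.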
-## System Requirements -- **CPU**: 2 cores minimum -- **RAM**: 512MB minimum (1GB recommended) -- **Storage**: 10GB minimum -- **OS**: Linux (tested on Arch Linux, Ubuntu, Debian) -- **Runtime**: Rust 1.70+ with Cargo +## Environment Setup -## Deployment Options - -### Option 1: Docker (Recommended) -```dockerfile -FROM rust:1.70-alpine as builder -WORKDIR /app -COPY . . -RUN cargo build --release - -FROM alpine:latest -RUN apk add --no-cache libgcc -COPY --from=builder /app/target/release/llm-proxy /usr/local/bin/ -COPY --from=builder /app/static /app/static -WORKDIR /app -EXPOSE 8080 -CMD ["llm-proxy"] -``` - -### Option 2: Systemd Service (Bare Metal/LXC) -```ini -# /etc/systemd/system/llm-proxy.service -[Unit] -Description=LLM Proxy Gateway -After=network.target - -[Service] -Type=simple -User=llmproxy -Group=llmproxy -WorkingDirectory=/opt/llm-proxy -ExecStart=/opt/llm-proxy/llm-proxy -Restart=always -RestartSec=10 -Environment="RUST_LOG=info" -Environment="LLM_PROXY__SERVER__PORT=8080" -Environment="LLM_PROXY__SERVER__AUTH_TOKENS=sk-test-123,sk-test-456" - -[Install] -WantedBy=multi-user.target -``` - -### Option 3: LXC Container (Proxmox) -1. Create Alpine Linux LXC container -2. Install Rust: `apk add rust cargo` -3. Copy application files -4. Build: `cargo build --release` -5. Run: `./target/release/llm-proxy` - -## Configuration - -### Environment Variables -```bash -# Required API Keys -OPENAI_API_KEY=sk-... -GEMINI_API_KEY=AIza... -DEEPSEEK_API_KEY=sk-... -GROK_API_KEY=gk-... 
# Optional
-
-# Server Configuration (with LLM_PROXY__ prefix)
-LLM_PROXY__SERVER__PORT=8080
-LLM_PROXY__SERVER__HOST=0.0.0.0
-LLM_PROXY__SERVER__AUTH_TOKENS=sk-test-123,sk-test-456
-
-# Database Configuration
-LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
-LLM_PROXY__DATABASE__MAX_CONNECTIONS=10
-
-# Provider Configuration
-LLM_PROXY__PROVIDERS__OPENAI__ENABLED=true
-LLM_PROXY__PROVIDERS__GEMINI__ENABLED=true
-LLM_PROXY__PROVIDERS__DEEPSEEK__ENABLED=true
-LLM_PROXY__PROVIDERS__GROK__ENABLED=false
-```
-
-### Configuration File (config.toml)
-Create `config.toml` in the application directory:
-```toml
-[server]
-port = 8080
-host = "0.0.0.0"
-auth_tokens = ["sk-test-123", "sk-test-456"]
-
-[database]
-path = "./data/llm_proxy.db"
-max_connections = 10
-
-[providers.openai]
-enabled = true
-base_url = "https://api.openai.com/v1"
-default_model = "gpt-4o"
-
-[providers.gemini]
-enabled = true
-base_url = "https://generativelanguage.googleapis.com/v1"
-default_model = "gemini-2.0-flash"
-
-[providers.deepseek]
-enabled = true
-base_url = "https://api.deepseek.com"
-default_model = "deepseek-reasoner"
-
-[providers.grok]
-enabled = false
-base_url = "https://api.x.ai/v1"
-default_model = "grok-beta"
-```
-
-## Nginx Reverse Proxy Configuration
-
-**Important for SSE/Streaming:** Disable buffering and configure timeouts for proper SSE support.
-
-```nginx
-server {
-    listen 80;
-    server_name llm-proxy.yourdomain.com;
-
-    location / {
-        proxy_pass http://localhost:8080;
-        proxy_http_version 1.1;
-
-        # SSE/Streaming support
-        proxy_buffering off;
-        chunked_transfer_encoding on;
-        proxy_set_header Connection '';
-
-        # Timeouts for long-running streams
-        proxy_connect_timeout 7200s;
-        proxy_read_timeout 7200s;
-        proxy_send_timeout 7200s;
-
-        # Disable gzip for streaming
-        gzip off;
-
-        # Headers
-        proxy_set_header Host $host;
-        proxy_set_header X-Real-IP $remote_addr;
-        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
-        proxy_set_header X-Forwarded-Proto $scheme;
-    }
-
-    # SSL configuration (recommended)
-    listen 443 ssl http2;
-    ssl_certificate /etc/letsencrypt/live/llm-proxy.yourdomain.com/fullchain.pem;
-    ssl_certificate_key /etc/letsencrypt/live/llm-proxy.yourdomain.com/privkey.pem;
-}
-```
-
-### NGINX Proxy Manager
-
-If using NGINX Proxy Manager, add this to **Advanced Settings**:
-
-```nginx
-proxy_buffering off;
-proxy_http_version 1.1;
-proxy_set_header Connection '';
-chunked_transfer_encoding on;
-proxy_connect_timeout 7200s;
-proxy_read_timeout 7200s;
-proxy_send_timeout 7200s;
-gzip off;
-```
-
-## Security Considerations
-
-### 1. Authentication
-- Use strong Bearer tokens
-- Rotate tokens regularly
-- Consider implementing JWT for production
-
-### 2. Rate Limiting
-- Implement per-client rate limiting
-- Consider using `governor` crate for advanced rate limiting
-
-### 3. Network Security
-- Run behind reverse proxy (nginx)
-- Enable HTTPS
-- Restrict access by IP if needed
-- Use firewall rules
-
-### 4. Data Security
-- Database encryption (SQLCipher for SQLite)
-- Secure API key storage
-- Regular backups
-
-## Monitoring & Maintenance
-
-### Logging
-- Application logs: `RUST_LOG=info` (or `debug` for troubleshooting)
-- Access logs via nginx
-- Database logs for audit trail
-
-### Health Checks
-```bash
-# Health endpoint
-curl http://localhost:8080/health
-
-# Database check
-sqlite3 ./data/llm_proxy.db "SELECT COUNT(*) FROM llm_requests;"
-```
-
-### Backup Strategy
-```bash
-#!/bin/bash
-# backup.sh
-BACKUP_DIR="/backups/llm-proxy"
-DATE=$(date +%Y%m%d_%H%M%S)
-
-# Backup database
-sqlite3 ./data/llm_proxy.db ".backup $BACKUP_DIR/llm_proxy_$DATE.db"
-
-# Backup configuration
-cp config.toml $BACKUP_DIR/config_$DATE.toml
-
-# Rotate old backups (keep 30 days)
-find $BACKUP_DIR -name "*.db" -mtime +30 -delete
-find $BACKUP_DIR -name "*.toml" -mtime +30 -delete
-```
-
-## Performance Tuning
-
-### Database Optimization
-```sql
--- Run these SQL commands periodically
-VACUUM;
-ANALYZE;
-```
-
-### Memory Management
-- Monitor memory usage with `htop` or `ps aux`
-- Adjust `max_connections` based on load
-- Consider connection pooling for high traffic
-
-### Scaling
-1. **Vertical Scaling**: Increase container resources
-2. **Horizontal Scaling**: Deploy multiple instances behind load balancer
-3. **Database**: Migrate to PostgreSQL for high-volume usage
-
-## Troubleshooting
-
-### Common Issues
-
-1. **Port already in use**
+1. **Mandatory Configuration:**
+   Create a `.env` file from the example:
    ```bash
-   netstat -tulpn | grep :8080
-   kill # or change port in config
+   cp .env.example .env
    ```
+   Ensure `LLM_PROXY__ENCRYPTION_KEY` is set to a secure 32-byte string.
-2. **Database permissions**
-   ```bash
-   chown -R llmproxy:llmproxy /opt/llm-proxy/data
-   chmod 600 /opt/llm-proxy/data/llm_proxy.db
-   ```
+2. **Data Directory:**
+   The proxy stores its database in `./data/llm_proxy.db` by default. Ensure this directory exists and is writable.
-3. **API key errors**
-   - Verify environment variables are set
-   - Check provider status (dashboard)
-   - Test connectivity: `curl https://api.openai.com/v1/models`
+## Binary Deployment
-4. **High memory usage**
-   - Check for memory leaks
-   - Reduce `max_connections`
-   - Implement connection timeouts
-
-### Debug Mode
+### 1. Build
 ```bash
-# Run with debug logging
-RUST_LOG=debug ./llm-proxy
-
-# Check system logs
-journalctl -u llm-proxy -f
+go build -o llm-proxy ./cmd/llm-proxy
 ```
-## Integration
-
-### Open-WebUI Compatibility
-The proxy provides OpenAI-compatible API, so configure Open-WebUI:
-```
-API Base URL: http://your-proxy-address:8080
-API Key: sk-test-123 (or your configured token)
+### 2. Run
+```bash
+./llm-proxy
 ```
-### Custom Clients
-```python
-import openai
+## Docker Deployment
-client = openai.OpenAI(
-    base_url="http://localhost:8080/v1",
-    api_key="sk-test-123"
-)
+The project includes a multi-stage `Dockerfile` for minimal image size.
-response = client.chat.completions.create(
-    model="gpt-4",
-    messages=[{"role": "user", "content": "Hello"}]
-)
+### 1. Build Image
+```bash
+docker build -t llm-proxy .
 ```
-## Updates & Upgrades
+### 2. Run Container
+```bash
+docker run -d \
+  --name llm-proxy \
+  -p 8080:8080 \
+  -v $(pwd)/data:/app/data \
+  --env-file .env \
+  llm-proxy
+```
-1. **Backup** current configuration and database
-2. **Stop** the service: `systemctl stop llm-proxy`
-3. **Update** code: `git pull` or copy new binaries
-4. **Migrate** database if needed (check migrations/)
-5. **Restart**: `systemctl start llm-proxy`
-6. **Verify**: Check logs and test endpoints
+## Production Considerations
-## Support
-- Check logs in `/var/log/llm-proxy/`
-- Monitor dashboard at `http://your-server:8080`
-- Review database metrics in dashboard
-- Enable debug logging for troubleshooting
\ No newline at end of file
+- **SSL/TLS:** It is recommended to run the proxy behind a reverse proxy like Nginx or Caddy for SSL termination.
+- **Backups:** Regularly back up the `data/llm_proxy.db` file.
+- **Monitoring:** Monitor the `/health` endpoint for system status.
diff --git a/go.mod b/go.mod
new file mode 100644
index 00000000..55f66ee6
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,62 @@
+module llm-proxy
+
+go 1.26.1
+
+require (
+	github.com/gin-gonic/gin v1.12.0
+	github.com/go-resty/resty/v2 v2.17.2
+	github.com/google/uuid v1.6.0
+	github.com/gorilla/websocket v1.5.3
+	github.com/jmoiron/sqlx v1.4.0
+	github.com/joho/godotenv v1.5.1
+	github.com/spf13/viper v1.21.0
+	golang.org/x/crypto v0.48.0
+	modernc.org/sqlite v1.47.0
+)
+
+require (
+	github.com/bytedance/gopkg v0.1.3 // indirect
+	github.com/bytedance/sonic v1.15.0 // indirect
+	github.com/bytedance/sonic/loader v0.5.0 // indirect
+	github.com/cloudwego/base64x v0.1.6 // indirect
+	github.com/dustin/go-humanize v1.0.1 // indirect
+	github.com/fsnotify/fsnotify v1.9.0 // indirect
+	github.com/gabriel-vasile/mimetype v1.4.12 // indirect
+	github.com/gin-contrib/sse v1.1.0 // indirect
+	github.com/go-playground/locales v0.14.1 // indirect
+	github.com/go-playground/universal-translator v0.18.1 // indirect
+	github.com/go-playground/validator/v10 v10.30.1 // indirect
+	github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
+	github.com/goccy/go-json v0.10.5 // indirect
+	github.com/goccy/go-yaml v1.19.2 // indirect
+	github.com/json-iterator/go v1.1.12 // indirect
+	github.com/klauspost/cpuid/v2 v2.3.0 // indirect
+	github.com/leodido/go-urn v1.4.0 // indirect
+	github.com/mattn/go-isatty v0.0.20 // indirect
+	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
+	github.com/modern-go/reflect2 v1.0.2 // indirect
+	github.com/ncruces/go-strftime v1.0.0 // indirect
+	github.com/pelletier/go-toml/v2 v2.2.4 // indirect
+	github.com/quic-go/qpack v0.6.0 // indirect
+	github.com/quic-go/quic-go v0.59.0 // indirect
+	github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
+	
github.com/sagikazarmark/locafero v0.11.0 // indirect + github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect + github.com/spf13/afero v1.15.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.3.1 // indirect + go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect + go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/arch v0.22.0 // indirect + golang.org/x/net v0.51.0 // indirect + golang.org/x/sys v0.42.0 // indirect + golang.org/x/text v0.34.0 // indirect + golang.org/x/time v0.15.0 // indirect + google.golang.org/protobuf v1.36.10 // indirect + modernc.org/libc v1.70.0 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..93a0c3b3 --- /dev/null +++ b/go.sum @@ -0,0 +1,183 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= +github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k= +github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= +github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw= +github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= +github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w= +github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM= +github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8= +github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w= +github.com/go-playground/validator/v10 v10.30.1/go.mod 
h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM= +github.com/go-resty/resty/v2 v2.17.2 h1:FQW5oHYcIlkCNrMD2lloGScxcHJ0gkjshV3qcQAyHQk= +github.com/go-resty/resty/v2 v2.17.2/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= +github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= +github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM= +github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/jmoiron/sqlx v1.4.0 
h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= +github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod 
h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= +github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= +github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw= +github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= +github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= +github.com/spf13/afero v1.15.0 
h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= +github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= +github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= 
+github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY= +github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= +go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE= +go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0= +go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= +go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI= +golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= +golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= 
+golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= +google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= +google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.32.0 h1:hjG66bI/kqIPX1b2yT6fr/jt+QedtP2fqojG2VrFuVw= +modernc.org/ccgo/v4 v4.32.0/go.mod h1:6F08EBCx5uQc38kMGl+0Nm0oWczoo1c7cgpzEry7Uc0= +modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM= +modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo= +modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw= +modernc.org/libc v1.70.0/go.mod 
h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.47.0 h1:R1XyaNpoW4Et9yly+I2EeX7pBza/w+pmYee/0HJDyKk= +modernc.org/sqlite v1.47.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 00000000..aa3237b6 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,174 @@ +package config + +import ( + "encoding/base64" + "encoding/hex" + "fmt" + "os" + "strings" + + "github.com/spf13/viper" +) + +type Config struct { + Server ServerConfig `mapstructure:"server"` + Database DatabaseConfig `mapstructure:"database"` + Providers ProviderConfig `mapstructure:"providers"` + EncryptionKey string `mapstructure:"encryption_key"` + KeyBytes []byte +} + +type ServerConfig struct { + Port int `mapstructure:"port"` + Host string `mapstructure:"host"` + AuthTokens []string `mapstructure:"auth_tokens"` +} + +type DatabaseConfig struct { + Path string `mapstructure:"path"` + MaxConnections int `mapstructure:"max_connections"` +} + +type 
ProviderConfig struct {
+	OpenAI   OpenAIConfig   `mapstructure:"openai"`
+	Gemini   GeminiConfig   `mapstructure:"gemini"`
+	DeepSeek DeepSeekConfig `mapstructure:"deepseek"`
+	Grok     GrokConfig     `mapstructure:"grok"`
+	Ollama   OllamaConfig   `mapstructure:"ollama"`
+}
+
+type OpenAIConfig struct {
+	APIKeyEnv    string `mapstructure:"api_key_env"`
+	BaseURL      string `mapstructure:"base_url"`
+	DefaultModel string `mapstructure:"default_model"`
+	Enabled      bool   `mapstructure:"enabled"`
+}
+
+type GeminiConfig struct {
+	APIKeyEnv    string `mapstructure:"api_key_env"`
+	BaseURL      string `mapstructure:"base_url"`
+	DefaultModel string `mapstructure:"default_model"`
+	Enabled      bool   `mapstructure:"enabled"`
+}
+
+type DeepSeekConfig struct {
+	APIKeyEnv    string `mapstructure:"api_key_env"`
+	BaseURL      string `mapstructure:"base_url"`
+	DefaultModel string `mapstructure:"default_model"`
+	Enabled      bool   `mapstructure:"enabled"`
+}
+
+type GrokConfig struct {
+	APIKeyEnv    string `mapstructure:"api_key_env"`
+	BaseURL      string `mapstructure:"base_url"`
+	DefaultModel string `mapstructure:"default_model"`
+	Enabled      bool   `mapstructure:"enabled"`
+}
+
+type OllamaConfig struct {
+	BaseURL      string   `mapstructure:"base_url"`
+	Enabled      bool     `mapstructure:"enabled"`
+	DefaultModel string   `mapstructure:"default_model"`
+	Models       []string `mapstructure:"models"`
+}
+
+func Load() (*Config, error) {
+	v := viper.New()
+
+	// Defaults
+	v.SetDefault("server.port", 8080)
+	v.SetDefault("server.host", "0.0.0.0")
+	v.SetDefault("server.auth_tokens", []string{})
+	v.SetDefault("database.path", "./data/llm_proxy.db")
+	v.SetDefault("database.max_connections", 10)
+
+	v.SetDefault("providers.openai.api_key_env", "OPENAI_API_KEY")
+	v.SetDefault("providers.openai.base_url", "https://api.openai.com/v1")
+	v.SetDefault("providers.openai.default_model", "gpt-4o")
+	v.SetDefault("providers.openai.enabled", true)
+
+	v.SetDefault("providers.gemini.api_key_env", "GEMINI_API_KEY")
+	v.SetDefault("providers.gemini.base_url", "https://generativelanguage.googleapis.com/v1")
+	v.SetDefault("providers.gemini.default_model", "gemini-2.0-flash")
+	v.SetDefault("providers.gemini.enabled", true)
+
+	v.SetDefault("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")
+	v.SetDefault("providers.deepseek.base_url", "https://api.deepseek.com")
+	v.SetDefault("providers.deepseek.default_model", "deepseek-reasoner")
+	v.SetDefault("providers.deepseek.enabled", true)
+
+	v.SetDefault("providers.grok.api_key_env", "GROK_API_KEY")
+	v.SetDefault("providers.grok.base_url", "https://api.x.ai/v1")
+	v.SetDefault("providers.grok.default_model", "grok-beta")
+	v.SetDefault("providers.grok.enabled", true)
+
+	v.SetDefault("providers.ollama.base_url", "http://localhost:11434/v1")
+	v.SetDefault("providers.ollama.enabled", false)
+	v.SetDefault("providers.ollama.models", []string{})
+
+	// Register env-only keys: Viper's Unmarshal only sees keys known from
+	// defaults, a config file, or explicit bindings, so without this default
+	// an encryption key supplied purely via the environment would be dropped.
+	v.SetDefault("encryption_key", "")
+
+	// Environment variables. Viper joins the prefix and the key with a single
+	// "_", so the trailing underscore here produces the documented
+	// LLM_PROXY__* names (e.g. LLM_PROXY__SERVER__PORT) once "." is replaced
+	// with "__".
+	v.SetEnvPrefix("LLM_PROXY_")
+	v.SetEnvKeyReplacer(strings.NewReplacer(".", "__"))
+	v.AutomaticEnv()
+
+	// Config file
+	v.SetConfigName("config")
+	v.SetConfigType("toml")
+	v.AddConfigPath(".")
+	if envPath := os.Getenv("LLM_PROXY__CONFIG_PATH"); envPath != "" {
+		v.SetConfigFile(envPath)
+	}
+
+	if err := v.ReadInConfig(); err != nil {
+		if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
+			return nil, fmt.Errorf("failed to read config file: %w", err)
+		}
+	}
+
+	var cfg Config
+	if err := v.Unmarshal(&cfg); err != nil {
+		return nil, fmt.Errorf("failed to unmarshal config: %w", err)
+	}
+
+	// Validate encryption key
+	if cfg.EncryptionKey == "" {
+		return nil, fmt.Errorf("encryption key is required (LLM_PROXY__ENCRYPTION_KEY)")
+	}
+
+	keyBytes, err := hex.DecodeString(cfg.EncryptionKey)
+	if err != nil {
+		keyBytes, err = base64.StdEncoding.DecodeString(cfg.EncryptionKey)
+		if err != nil {
+			return nil, fmt.Errorf("encryption key must be hex or base64 encoded")
+		}
+	}
+
+	if len(keyBytes) != 32 {
+		return nil, fmt.Errorf("encryption key must be 32 bytes, got %d", len(keyBytes))
+	}
+	cfg.KeyBytes = keyBytes
+
+	return &cfg, nil
+}
+
+func (c *Config) GetAPIKey(provider string) (string, error) {
+	var envVar string
+	switch provider {
+	case "openai":
+		envVar = c.Providers.OpenAI.APIKeyEnv
+	case "gemini":
+		envVar = c.Providers.Gemini.APIKeyEnv
+	case "deepseek":
+		envVar = c.Providers.DeepSeek.APIKeyEnv
+	case "grok":
+		envVar = c.Providers.Grok.APIKeyEnv
+	default:
+		return "", fmt.Errorf("unknown provider: %s", provider)
+	}
+
+	val := os.Getenv(envVar)
+	if val == "" {
+		return "", fmt.Errorf("environment variable %s not set for %s", envVar, provider)
+	}
+	return val, nil
+}
diff --git a/internal/db/db.go b/internal/db/db.go
new file mode 100644
index 00000000..508358e8
--- /dev/null
+++ b/internal/db/db.go
@@ -0,0 +1,264 @@
+package db
+
+import (
+	"fmt"
+	"log"
+	"os"
+	"path/filepath"
+	"time"
+
+	"github.com/jmoiron/sqlx"
+	"golang.org/x/crypto/bcrypt"
+	_ "modernc.org/sqlite"
+)
+
+type DB struct {
+	*sqlx.DB
+}
+
+func Init(path string) (*DB, error) {
+	// Ensure directory exists
+	dir := filepath.Dir(path)
+	if dir != "." && dir != "/" {
+		if err := os.MkdirAll(dir, 0755); err != nil {
+			return nil, fmt.Errorf("failed to create database directory: %w", err)
+		}
+	}
+
+	// Connect to SQLite
+	dsn := fmt.Sprintf("file:%s?_pragma=foreign_keys(1)", path)
+	db, err := sqlx.Connect("sqlite", dsn)
+	if err != nil {
+		return nil, fmt.Errorf("failed to connect to database: %w", err)
+	}
+
+	instance := &DB{db}
+
+	// Run migrations
+	if err := instance.RunMigrations(); err != nil {
+		return nil, fmt.Errorf("failed to run migrations: %w", err)
+	}
+
+	return instance, nil
+}
+
+func (db *DB) RunMigrations() error {
+	// Create tables
+	queries := []string{
+		`CREATE TABLE IF NOT EXISTS clients (
+			id INTEGER PRIMARY KEY AUTOINCREMENT,
+			client_id TEXT UNIQUE NOT NULL,
+			name TEXT,
+			description TEXT,
+			created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
+			updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
+			is_active BOOLEAN DEFAULT TRUE,
+			rate_limit_per_minute INTEGER DEFAULT 60,
+			total_requests INTEGER DEFAULT 0,
+			total_tokens INTEGER DEFAULT 0,
+			total_cost REAL DEFAULT 0.0
+		)`,
+		`CREATE TABLE IF NOT EXISTS llm_requests (
+			id INTEGER PRIMARY KEY AUTOINCREMENT,
+			timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
+			client_id TEXT,
+			provider TEXT,
+			model TEXT,
+			prompt_tokens INTEGER,
+			completion_tokens INTEGER,
+			reasoning_tokens INTEGER DEFAULT 0,
+			total_tokens INTEGER,
+			cost REAL,
+			has_images BOOLEAN DEFAULT FALSE,
+			status TEXT DEFAULT 'success',
+			error_message TEXT,
+			duration_ms INTEGER,
+			request_body TEXT,
+			response_body TEXT,
+			cache_read_tokens INTEGER DEFAULT 0,
+			cache_write_tokens INTEGER DEFAULT 0,
+			FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE SET NULL
+		)`,
+		`CREATE TABLE IF NOT EXISTS provider_configs (
+			id TEXT PRIMARY KEY,
+			display_name TEXT NOT NULL,
+			enabled BOOLEAN DEFAULT TRUE,
+			base_url TEXT,
+			api_key TEXT,
+			credit_balance REAL DEFAULT 0.0,
+			low_credit_threshold REAL DEFAULT 5.0,
+			billing_mode TEXT,
+			api_key_encrypted BOOLEAN DEFAULT FALSE,
+			updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
+		)`,
+		`CREATE TABLE IF NOT EXISTS model_configs (
+			id TEXT PRIMARY KEY,
+			provider_id TEXT NOT NULL,
+			display_name TEXT,
+			enabled BOOLEAN DEFAULT TRUE,
+			prompt_cost_per_m REAL,
+			completion_cost_per_m REAL,
+			cache_read_cost_per_m REAL,
+			cache_write_cost_per_m REAL,
+			mapping TEXT,
+			updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
+			FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE
+		)`,
+		`CREATE TABLE IF NOT EXISTS users (
+			id INTEGER PRIMARY KEY AUTOINCREMENT,
+			username TEXT UNIQUE NOT NULL,
+			password_hash TEXT NOT NULL,
+			display_name TEXT,
+			role TEXT DEFAULT 'admin',
+			must_change_password BOOLEAN DEFAULT FALSE,
+			created_at DATETIME DEFAULT CURRENT_TIMESTAMP
+		)`,
+		`CREATE TABLE IF NOT EXISTS client_tokens (
+			id INTEGER PRIMARY KEY AUTOINCREMENT,
+			client_id TEXT NOT NULL,
+			token TEXT NOT NULL UNIQUE,
+			name TEXT DEFAULT 'default',
+			is_active BOOLEAN DEFAULT TRUE,
+			created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
+			last_used_at DATETIME,
+			FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE CASCADE
+		)`,
+	}
+
+	for _, q := range queries {
+		if _, err := db.Exec(q); err != nil {
+			return fmt.Errorf("migration failed for query [%s]: %w", q, err)
+		}
+	}
+
+	// Add indices
+	indices := []string{
+		"CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)",
+		"CREATE INDEX IF NOT EXISTS idx_clients_created_at ON clients(created_at)",
+		"CREATE INDEX IF NOT EXISTS idx_llm_requests_timestamp ON llm_requests(timestamp)",
+		"CREATE INDEX IF NOT EXISTS idx_llm_requests_client_id ON llm_requests(client_id)",
+		"CREATE INDEX IF NOT EXISTS idx_llm_requests_provider ON llm_requests(provider)",
+		"CREATE INDEX IF NOT EXISTS idx_llm_requests_status ON llm_requests(status)",
+		"CREATE UNIQUE INDEX IF NOT EXISTS idx_client_tokens_token ON client_tokens(token)",
+		"CREATE INDEX IF NOT EXISTS idx_client_tokens_client_id ON client_tokens(client_id)",
+		"CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp)",
+		"CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp)",
+		"CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id)",
+	}
+
+	for _, idx := range indices {
+		if _, err := db.Exec(idx); err != nil {
+			return fmt.Errorf("failed to create index [%s]: %w", idx, err)
+		}
+	}
+
+	// Default admin user
+	var count int
+	if err := db.Get(&count, "SELECT COUNT(*) FROM users"); err != nil {
+		return fmt.Errorf("failed to count users: %w", err)
+	}
+
+	if count == 0 {
+		hash, err := bcrypt.GenerateFromPassword([]byte("admin"), 12)
+		if err != nil {
+			return fmt.Errorf("failed to hash default password: %w", err)
+		}
+		_, err = db.Exec("INSERT INTO users (username, password_hash, role, must_change_password) VALUES ('admin', ?, 'admin', 1)", string(hash))
+		if err != nil {
+			return fmt.Errorf("failed to insert default admin: %w", err)
+		}
+		log.Println("Created default admin user with password 'admin' (must change on first login)")
+	}
+
+	// Default client
+	_, err := db.Exec(`INSERT OR IGNORE INTO clients (client_id, name, description)
+		VALUES ('default', 'Default Client', 'Default client for anonymous requests')`)
+	if err != nil {
+		return fmt.Errorf("failed to insert default client: %w", err)
+	}
+
+	return nil
+}
+
+// Data models for DB tables
+
+type Client struct {
+	ID                 int       `db:"id"`
+	ClientID           string    `db:"client_id"`
+	Name               *string   `db:"name"`
+	Description        *string   `db:"description"`
+	CreatedAt          time.Time `db:"created_at"`
+	UpdatedAt          time.Time `db:"updated_at"`
+	IsActive           bool      `db:"is_active"`
+	RateLimitPerMinute int       `db:"rate_limit_per_minute"`
+	TotalRequests      int       `db:"total_requests"`
+	TotalTokens        int       `db:"total_tokens"`
+	TotalCost          float64   `db:"total_cost"`
+}
+
+type LLMRequest struct {
+	ID               int       `db:"id"`
+	Timestamp        time.Time `db:"timestamp"`
+	ClientID         *string   `db:"client_id"`
+	Provider         *string   `db:"provider"`
+	Model            *string   `db:"model"`
+	PromptTokens     *int      `db:"prompt_tokens"`
+	CompletionTokens *int      `db:"completion_tokens"`
+	ReasoningTokens  int       `db:"reasoning_tokens"`
+	TotalTokens      *int      `db:"total_tokens"`
+	Cost             *float64  `db:"cost"`
+	HasImages        bool      `db:"has_images"`
+	Status           string    `db:"status"`
+	ErrorMessage     *string   `db:"error_message"`
+	DurationMS       *int      `db:"duration_ms"`
+	RequestBody      *string   `db:"request_body"`
+	ResponseBody     *string   `db:"response_body"`
+	CacheReadTokens  int       `db:"cache_read_tokens"`
+	CacheWriteTokens int       `db:"cache_write_tokens"`
+}
+
+type ProviderConfig struct {
+	ID                 string    `db:"id"`
+	DisplayName        string    `db:"display_name"`
+	Enabled            bool      `db:"enabled"`
+	BaseURL            *string   `db:"base_url"`
+	APIKey             *string   `db:"api_key"`
+	CreditBalance      float64   `db:"credit_balance"`
+	LowCreditThreshold float64   `db:"low_credit_threshold"`
+	BillingMode        *string   `db:"billing_mode"`
+	APIKeyEncrypted    bool      `db:"api_key_encrypted"`
+	UpdatedAt          time.Time `db:"updated_at"`
+}
+
+type ModelConfig struct {
+	ID                 string    `db:"id"`
+	ProviderID         string    `db:"provider_id"`
+	DisplayName        *string   `db:"display_name"`
+	Enabled            bool      `db:"enabled"`
+	PromptCostPerM     *float64  `db:"prompt_cost_per_m"`
+	CompletionCostPerM *float64  `db:"completion_cost_per_m"`
+	CacheReadCostPerM  *float64  `db:"cache_read_cost_per_m"`
+	CacheWriteCostPerM *float64  `db:"cache_write_cost_per_m"`
+	Mapping            *string   `db:"mapping"`
+	UpdatedAt          time.Time `db:"updated_at"`
+}
+
+type User struct {
+	ID                 int       `db:"id"`
+	Username           string    `db:"username"`
+	PasswordHash       string    `db:"password_hash"`
+	DisplayName        *string   `db:"display_name"`
+	Role               string    `db:"role"`
+	MustChangePassword bool      `db:"must_change_password"`
+	CreatedAt          time.Time `db:"created_at"`
+}
+
+type ClientToken struct {
+	ID         int       `db:"id"`
+	ClientID   string    `db:"client_id"`
+	Token      string    `db:"token"`
+	Name       string    `db:"name"`
+	IsActive   bool      `db:"is_active"`
+	CreatedAt  time.Time `db:"created_at"`
+	LastUsedAt *time.Time
`db:"last_used_at"` +} diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go new file mode 100644 index 00000000..706e56d6 --- /dev/null +++ b/internal/middleware/auth.go @@ -0,0 +1,52 @@ +package middleware + +import ( + "log" + "strings" + + "llm-proxy/internal/db" + "llm-proxy/internal/models" + + "github.com/gin-gonic/gin" +) + +func AuthMiddleware(database *db.DB) gin.HandlerFunc { + return func(c *gin.Context) { + authHeader := c.GetHeader("Authorization") + if authHeader == "" { + c.Next() + return + } + + token := strings.TrimPrefix(authHeader, "Bearer ") + if token == authHeader { // No "Bearer " prefix + c.Next() + return + } + + // Try to resolve client from database + var clientID string + err := database.Get(&clientID, "UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = 1 RETURNING client_id", token) + + if err == nil { + c.Set("auth", models.AuthInfo{ + Token: token, + ClientID: clientID, + }) + } else { + // Fallback to token-prefix derivation (matches Rust behavior) + prefixLen := len(token) + if prefixLen > 8 { + prefixLen = 8 + } + clientID = "client_" + token[:prefixLen] + c.Set("auth", models.AuthInfo{ + Token: token, + ClientID: clientID, + }) + log.Printf("Token not found in DB, using fallback client ID: %s", clientID) + } + + c.Next() + } +} diff --git a/internal/models/models.go b/internal/models/models.go new file mode 100644 index 00000000..921bd206 --- /dev/null +++ b/internal/models/models.go @@ -0,0 +1,216 @@ +package models + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "github.com/go-resty/resty/v2" +) + +// OpenAI-compatible Request/Response Structs + +type ChatCompletionRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + TopK *uint32 `json:"top_k,omitempty"` + N *uint32 `json:"n,omitempty"` + Stop json.RawMessage 
`json:"stop,omitempty"` // Can be string or array of strings + MaxTokens *uint32 `json:"max_tokens,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + Stream *bool `json:"stream,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` +} + +type ChatMessage struct { + Role string `json:"role"` // "system", "user", "assistant", "tool" + Content interface{} `json:"content"` + ReasoningContent *string `json:"reasoning_content,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + Name *string `json:"name,omitempty"` + ToolCallID *string `json:"tool_call_id,omitempty"` +} + +type ContentPart struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + ImageUrl *ImageUrl `json:"image_url,omitempty"` +} + +type ImageUrl struct { + URL string `json:"url"` + Detail *string `json:"detail,omitempty"` +} + +// Tool-Calling Types + +type Tool struct { + Type string `json:"type"` + Function FunctionDef `json:"function"` +} + +type FunctionDef struct { + Name string `json:"name"` + Description *string `json:"description,omitempty"` + Parameters json.RawMessage `json:"parameters,omitempty"` +} + +type ToolCall struct { + ID string `json:"id"` + Type string `json:"type"` + Function FunctionCall `json:"function"` +} + +type FunctionCall struct { + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +type ToolCallDelta struct { + Index uint32 `json:"index"` + ID *string `json:"id,omitempty"` + Type *string `json:"type,omitempty"` + Function *FunctionCallDelta `json:"function,omitempty"` +} + +type FunctionCallDelta struct { + Name *string `json:"name,omitempty"` + Arguments *string `json:"arguments,omitempty"` +} + +// OpenAI-compatible Response Structs + +type ChatCompletionResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string 
`json:"model"` + Choices []ChatChoice `json:"choices"` + Usage *Usage `json:"usage,omitempty"` +} + +type ChatChoice struct { + Index uint32 `json:"index"` + Message ChatMessage `json:"message"` + FinishReason *string `json:"finish_reason,omitempty"` +} + +type Usage struct { + PromptTokens uint32 `json:"prompt_tokens"` + CompletionTokens uint32 `json:"completion_tokens"` + TotalTokens uint32 `json:"total_tokens"` + ReasoningTokens *uint32 `json:"reasoning_tokens,omitempty"` + CacheReadTokens *uint32 `json:"cache_read_tokens,omitempty"` + CacheWriteTokens *uint32 `json:"cache_write_tokens,omitempty"` +} + +// Streaming Response Structs + +type ChatCompletionStreamResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatStreamChoice `json:"choices"` + Usage *Usage `json:"usage,omitempty"` +} + +type ChatStreamChoice struct { + Index uint32 `json:"index"` + Delta ChatStreamDelta `json:"delta"` + FinishReason *string `json:"finish_reason,omitempty"` +} + +type ChatStreamDelta struct { + Role *string `json:"role,omitempty"` + Content *string `json:"content,omitempty"` + ReasoningContent *string `json:"reasoning_content,omitempty"` + ToolCalls []ToolCallDelta `json:"tool_calls,omitempty"` +} + +type StreamUsage struct { + PromptTokens uint32 `json:"prompt_tokens"` + CompletionTokens uint32 `json:"completion_tokens"` + TotalTokens uint32 `json:"total_tokens"` + ReasoningTokens uint32 `json:"reasoning_tokens"` + CacheReadTokens uint32 `json:"cache_read_tokens"` + CacheWriteTokens uint32 `json:"cache_write_tokens"` +} + +// Unified Request Format (for internal use) + +type UnifiedRequest struct { + ClientID string + Model string + Messages []UnifiedMessage + Temperature *float64 + TopP *float64 + TopK *uint32 + N *uint32 + Stop []string + MaxTokens *uint32 + PresencePenalty *float64 + FrequencyPenalty *float64 + Stream bool + HasImages bool + Tools []Tool + ToolChoice 
json.RawMessage +} + +type UnifiedMessage struct { + Role string + Content []UnifiedContentPart + ReasoningContent *string + ToolCalls []ToolCall + Name *string + ToolCallID *string +} + +type UnifiedContentPart struct { + Type string + Text string + Image *ImageInput +} + +type ImageInput struct { + Base64 string `json:"base64,omitempty"` + URL string `json:"url,omitempty"` + MimeType string `json:"mime_type,omitempty"` +} + +func (i *ImageInput) ToBase64() (string, string, error) { + if i.Base64 != "" { + return i.Base64, i.MimeType, nil + } + + if i.URL != "" { + client := resty.New() + resp, err := client.R().Get(i.URL) + if err != nil { + return "", "", fmt.Errorf("failed to fetch image: %w", err) + } + + if !resp.IsSuccess() { + return "", "", fmt.Errorf("failed to fetch image: HTTP %d", resp.StatusCode()) + } + + mimeType := resp.Header().Get("Content-Type") + if mimeType == "" { + mimeType = "image/jpeg" + } + + encoded := base64.StdEncoding.EncodeToString(resp.Body()) + return encoded, mimeType, nil + } + + return "", "", fmt.Errorf("empty image input") +} + +// AuthInfo for context +type AuthInfo struct { + Token string + ClientID string +} diff --git a/internal/providers/deepseek.go b/internal/providers/deepseek.go new file mode 100644 index 00000000..e145cbda --- /dev/null +++ b/internal/providers/deepseek.go @@ -0,0 +1,143 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + + "llm-proxy/internal/config" + "llm-proxy/internal/models" + "github.com/go-resty/resty/v2" +) + +type DeepSeekProvider struct { + client *resty.Client + config config.DeepSeekConfig + apiKey string +} + +func NewDeepSeekProvider(cfg config.DeepSeekConfig, apiKey string) *DeepSeekProvider { + return &DeepSeekProvider{ + client: resty.New(), + config: cfg, + apiKey: apiKey, + } +} + +func (p *DeepSeekProvider) Name() string { + return "deepseek" +} + +func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) 
(*models.ChatCompletionResponse, error) { + messagesJSON, err := MessagesToOpenAIJSON(req.Messages) + if err != nil { + return nil, fmt.Errorf("failed to convert messages: %w", err) + } + + body := BuildOpenAIBody(req, messagesJSON, false) + + // Sanitize for deepseek-reasoner + if req.Model == "deepseek-reasoner" { + delete(body, "temperature") + delete(body, "top_p") + delete(body, "presence_penalty") + delete(body, "frequency_penalty") + + // Ensure assistant messages have content and reasoning_content + if msgs, ok := body["messages"].([]interface{}); ok { + for _, m := range msgs { + if msg, ok := m.(map[string]interface{}); ok { + if msg["role"] == "assistant" { + if msg["reasoning_content"] == nil { + msg["reasoning_content"] = " " + } + if msg["content"] == nil || msg["content"] == "" { + msg["content"] = "" + } + } + } + } + } + } + + resp, err := p.client.R(). + SetContext(ctx). + SetHeader("Authorization", "Bearer "+p.apiKey). + SetBody(body). + Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL)) + + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + + if !resp.IsSuccess() { + return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String()) + } + + var respJSON map[string]interface{} + if err := json.Unmarshal(resp.Body(), &respJSON); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return ParseOpenAIResponse(respJSON, req.Model) +} + +func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) { + messagesJSON, err := MessagesToOpenAIJSON(req.Messages) + if err != nil { + return nil, fmt.Errorf("failed to convert messages: %w", err) + } + + body := BuildOpenAIBody(req, messagesJSON, true) + + // Sanitize for deepseek-reasoner + if req.Model == "deepseek-reasoner" { + delete(body, "temperature") + delete(body, "top_p") + delete(body, "presence_penalty") + delete(body, 
"frequency_penalty") + + // Ensure assistant messages have content and reasoning_content + if msgs, ok := body["messages"].([]interface{}); ok { + for _, m := range msgs { + if msg, ok := m.(map[string]interface{}); ok { + if msg["role"] == "assistant" { + if msg["reasoning_content"] == nil { + msg["reasoning_content"] = " " + } + if msg["content"] == nil || msg["content"] == "" { + msg["content"] = "" + } + } + } + } + } + } + + resp, err := p.client.R(). + SetContext(ctx). + SetHeader("Authorization", "Bearer "+p.apiKey). + SetBody(body). + SetDoNotParseResponse(true). + Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL)) + + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + + if !resp.IsSuccess() { + return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String()) + } + + ch := make(chan *models.ChatCompletionStreamResponse) + + go func() { + defer close(ch) + err := StreamOpenAI(resp.RawBody(), ch) + if err != nil { + fmt.Printf("DeepSeek Stream error: %v\n", err) + } + }() + + return ch, nil +} diff --git a/internal/providers/gemini.go b/internal/providers/gemini.go new file mode 100644 index 00000000..14a4d396 --- /dev/null +++ b/internal/providers/gemini.go @@ -0,0 +1,254 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + + "llm-proxy/internal/config" + "llm-proxy/internal/models" + "github.com/go-resty/resty/v2" +) + +type GeminiProvider struct { + client *resty.Client + config config.GeminiConfig + apiKey string +} + +func NewGeminiProvider(cfg config.GeminiConfig, apiKey string) *GeminiProvider { + return &GeminiProvider{ + client: resty.New(), + config: cfg, + apiKey: apiKey, + } +} + +func (p *GeminiProvider) Name() string { + return "gemini" +} + +type GeminiRequest struct { + Contents []GeminiContent `json:"contents"` +} + +type GeminiContent struct { + Role string `json:"role,omitempty"` + Parts []GeminiPart `json:"parts"` +} + +type GeminiPart struct { + Text string 
`json:"text,omitempty"` + InlineData *GeminiInlineData `json:"inlineData,omitempty"` + FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"` + FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"` +} + +type GeminiInlineData struct { + MimeType string `json:"mimeType"` + Data string `json:"data"` +} + +type GeminiFunctionCall struct { + Name string `json:"name"` + Args json.RawMessage `json:"args"` +} + +type GeminiFunctionResponse struct { + Name string `json:"name"` + Response json.RawMessage `json:"response"` +} + +func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) { + // Gemini mapping + var contents []GeminiContent + for _, msg := range req.Messages { + role := "user" + if msg.Role == "assistant" { + role = "model" + } else if msg.Role == "tool" { + role = "user" // Tool results are user-side in Gemini + } + + var parts []GeminiPart + + // Handle tool responses + if msg.Role == "tool" { + text := "" + if len(msg.Content) > 0 { + text = msg.Content[0].Text + } + + // Gemini expects functionResponse to be an object + name := "unknown_function" + if msg.Name != nil { + name = *msg.Name + } + + parts = append(parts, GeminiPart{ + FunctionResponse: &GeminiFunctionResponse{ + Name: name, + Response: json.RawMessage(text), + }, + }) + } else { + for _, cp := range msg.Content { + if cp.Type == "text" { + parts = append(parts, GeminiPart{Text: cp.Text}) + } else if cp.Image != nil { + base64Data, mimeType, err := cp.Image.ToBase64() + if err != nil { + return nil, fmt.Errorf("failed to convert image: %w", err) + } + parts = append(parts, GeminiPart{ + InlineData: &GeminiInlineData{ + MimeType: mimeType, + Data: base64Data, + }, + }) + } + } + + // Handle assistant tool calls + if msg.Role == "assistant" && len(msg.ToolCalls) > 0 { + for _, tc := range msg.ToolCalls { + parts = append(parts, GeminiPart{ + FunctionCall: &GeminiFunctionCall{ + Name: tc.Function.Name, + Args: json.RawMessage(tc.Function.Arguments), + }, + }) + } + } + } + + 
contents = append(contents, GeminiContent{ + Role: role, + Parts: parts, + }) + } + + body := GeminiRequest{ + Contents: contents, + } + + url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey) + + resp, err := p.client.R(). + SetContext(ctx). + SetBody(body). + Post(url) + + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + + if !resp.IsSuccess() { + return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String()) + } + + // Parse Gemini response and convert to OpenAI format + var geminiResp struct { + Candidates []struct { + Content struct { + Parts []struct { + Text string `json:"text"` + } `json:"parts"` + } `json:"content"` + FinishReason string `json:"finishReason"` + } `json:"candidates"` + UsageMetadata struct { + PromptTokenCount uint32 `json:"promptTokenCount"` + CandidatesTokenCount uint32 `json:"candidatesTokenCount"` + TotalTokenCount uint32 `json:"totalTokenCount"` + } `json:"usageMetadata"` + } + + if err := json.Unmarshal(resp.Body(), &geminiResp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + if len(geminiResp.Candidates) == 0 { + return nil, fmt.Errorf("no candidates in Gemini response") + } + + content := "" + for _, p := range geminiResp.Candidates[0].Content.Parts { + content += p.Text + } + + openAIResp := &models.ChatCompletionResponse{ + ID: "gemini-" + req.Model, + Object: "chat.completion", + Created: 0, // Should be current timestamp + Model: req.Model, + Choices: []models.ChatChoice{ + { + Index: 0, + Message: models.ChatMessage{ + Role: "assistant", + Content: content, + }, + FinishReason: &geminiResp.Candidates[0].FinishReason, + }, + }, + Usage: &models.Usage{ + PromptTokens: geminiResp.UsageMetadata.PromptTokenCount, + CompletionTokens: geminiResp.UsageMetadata.CandidatesTokenCount, + TotalTokens: geminiResp.UsageMetadata.TotalTokenCount, + }, + } + + return openAIResp, nil +} + +func (p *GeminiProvider) 
ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) { + // Simplified Gemini mapping + var contents []GeminiContent + for _, msg := range req.Messages { + role := "user" + if msg.Role == "assistant" { + role = "model" + } + + var parts []GeminiPart + for _, p := range msg.Content { + parts = append(parts, GeminiPart{Text: p.Text}) + } + + contents = append(contents, GeminiContent{ + Role: role, + Parts: parts, + }) + } + + body := GeminiRequest{ + Contents: contents, + } + + // Use streamGenerateContent for streaming + url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey) + + resp, err := p.client.R(). + SetContext(ctx). + SetBody(body). + SetDoNotParseResponse(true). + Post(url) + + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + + if !resp.IsSuccess() { + return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String()) + } + + ch := make(chan *models.ChatCompletionStreamResponse) + + go func() { + defer close(ch) + err := StreamGemini(resp.RawBody(), ch, req.Model) + if err != nil { + fmt.Printf("Gemini Stream error: %v\n", err) + } + }() + + return ch, nil +} diff --git a/internal/providers/grok.go b/internal/providers/grok.go new file mode 100644 index 00000000..a14ed62b --- /dev/null +++ b/internal/providers/grok.go @@ -0,0 +1,95 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + + "llm-proxy/internal/config" + "llm-proxy/internal/models" + "github.com/go-resty/resty/v2" +) + +type GrokProvider struct { + client *resty.Client + config config.GrokConfig + apiKey string +} + +func NewGrokProvider(cfg config.GrokConfig, apiKey string) *GrokProvider { + return &GrokProvider{ + client: resty.New(), + config: cfg, + apiKey: apiKey, + } +} + +func (p *GrokProvider) Name() string { + return "grok" +} + +func (p *GrokProvider) ChatCompletion(ctx context.Context, req 
*models.UnifiedRequest) (*models.ChatCompletionResponse, error) { + messagesJSON, err := MessagesToOpenAIJSON(req.Messages) + if err != nil { + return nil, fmt.Errorf("failed to convert messages: %w", err) + } + + body := BuildOpenAIBody(req, messagesJSON, false) + + resp, err := p.client.R(). + SetContext(ctx). + SetHeader("Authorization", "Bearer "+p.apiKey). + SetBody(body). + Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL)) + + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + + if !resp.IsSuccess() { + return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String()) + } + + var respJSON map[string]interface{} + if err := json.Unmarshal(resp.Body(), &respJSON); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return ParseOpenAIResponse(respJSON, req.Model) +} + +func (p *GrokProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) { + messagesJSON, err := MessagesToOpenAIJSON(req.Messages) + if err != nil { + return nil, fmt.Errorf("failed to convert messages: %w", err) + } + + body := BuildOpenAIBody(req, messagesJSON, true) + + resp, err := p.client.R(). + SetContext(ctx). + SetHeader("Authorization", "Bearer "+p.apiKey). + SetBody(body). + SetDoNotParseResponse(true). 
+ Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL)) + + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + + if !resp.IsSuccess() { + return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String()) + } + + ch := make(chan *models.ChatCompletionStreamResponse) + + go func() { + defer close(ch) + err := StreamOpenAI(resp.RawBody(), ch) + if err != nil { + fmt.Printf("Grok Stream error: %v\n", err) + } + }() + + return ch, nil +} diff --git a/internal/providers/helpers.go b/internal/providers/helpers.go new file mode 100644 index 00000000..a66c1537 --- /dev/null +++ b/internal/providers/helpers.go @@ -0,0 +1,259 @@ +package providers + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "strings" + + "llm-proxy/internal/models" +) + +// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON, including tools and images. +func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) { + var result []interface{} + for _, m := range messages { + if m.Role == "tool" { + text := "" + if len(m.Content) > 0 { + text = m.Content[0].Text + } + msg := map[string]interface{}{ + "role": "tool", + "content": text, + } + if m.ToolCallID != nil { + id := *m.ToolCallID + if len(id) > 40 { + id = id[:40] + } + msg["tool_call_id"] = id + } + if m.Name != nil { + msg["name"] = *m.Name + } + result = append(result, msg) + continue + } + + var parts []interface{} + for _, p := range m.Content { + if p.Type == "text" { + parts = append(parts, map[string]interface{}{ + "type": "text", + "text": p.Text, + }) + } else if p.Image != nil { + base64Data, mimeType, err := p.Image.ToBase64() + if err != nil { + return nil, fmt.Errorf("failed to convert image to base64: %w", err) + } + parts = append(parts, map[string]interface{}{ + "type": "image_url", + "image_url": map[string]interface{}{ + "url": fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data), + }, + }) + } + } + + msg := map[string]interface{}{ 
+ "role": m.Role, + "content": parts, + } + + if m.ReasoningContent != nil { + msg["reasoning_content"] = *m.ReasoningContent + } + + if len(m.ToolCalls) > 0 { + sanitizedCalls := make([]models.ToolCall, len(m.ToolCalls)) + copy(sanitizedCalls, m.ToolCalls) + for i := range sanitizedCalls { + if len(sanitizedCalls[i].ID) > 40 { + sanitizedCalls[i].ID = sanitizedCalls[i].ID[:40] + } + } + msg["tool_calls"] = sanitizedCalls + if len(parts) == 0 { + msg["content"] = "" + } + } + + if m.Name != nil { + msg["name"] = *m.Name + } + + result = append(result, msg) + } + return result, nil +} + +func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{}, stream bool) map[string]interface{} { + body := map[string]interface{}{ + "model": request.Model, + "messages": messagesJSON, + "stream": stream, + } + + if stream { + body["stream_options"] = map[string]interface{}{ + "include_usage": true, + } + } + + if request.Temperature != nil { + body["temperature"] = *request.Temperature + } + if request.MaxTokens != nil { + body["max_tokens"] = *request.MaxTokens + } + if len(request.Tools) > 0 { + body["tools"] = request.Tools + } + if request.ToolChoice != nil { + var toolChoice interface{} + if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil { + body["tool_choice"] = toolChoice + } + } + + return body +} + +func ParseOpenAIResponse(respJSON map[string]interface{}, model string) (*models.ChatCompletionResponse, error) { + data, err := json.Marshal(respJSON) + if err != nil { + return nil, err + } + + var resp models.ChatCompletionResponse + if err := json.Unmarshal(data, &resp); err != nil { + return nil, err + } + + return &resp, nil +} + +// Streaming support + +func ParseOpenAIStreamChunk(line string) (*models.ChatCompletionStreamResponse, bool, error) { + if line == "" { + return nil, false, nil + } + if !strings.HasPrefix(line, "data: ") { + return nil, false, nil + } + + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" 
{ + return nil, true, nil + } + + var chunk models.ChatCompletionStreamResponse + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + return nil, false, fmt.Errorf("failed to unmarshal stream chunk: %w", err) + } + + return &chunk, false, nil +} + +func StreamOpenAI(body io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse) error { + defer body.Close() + scanner := bufio.NewScanner(body) + // Allow SSE lines larger than bufio.Scanner's 64KB default token size. + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + for scanner.Scan() { + line := scanner.Text() + chunk, done, err := ParseOpenAIStreamChunk(line) + if err != nil { + return err + } + if done { + break + } + if chunk != nil { + ch <- chunk + } + } + return scanner.Err() +} + +func StreamGemini(body io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error { + defer body.Close() + + dec := json.NewDecoder(body) + + t, err := dec.Token() + if err != nil { + return err + } + if delim, ok := t.(json.Delim); ok && delim == '[' { + for dec.More() { + var geminiChunk struct { + Candidates []struct { + Content struct { + Parts []struct { + Text string `json:"text,omitempty"` + Thought string `json:"thought,omitempty"` + } `json:"parts"` + } `json:"content"` + FinishReason string `json:"finishReason"` + } `json:"candidates"` + UsageMetadata struct { + PromptTokenCount uint32 `json:"promptTokenCount"` + CandidatesTokenCount uint32 `json:"candidatesTokenCount"` + TotalTokenCount uint32 `json:"totalTokenCount"` + } `json:"usageMetadata"` + } + + if err := dec.Decode(&geminiChunk); err != nil { + return err + } + + if len(geminiChunk.Candidates) > 0 { + content := "" + var reasoning *string + for _, p := range geminiChunk.Candidates[0].Content.Parts { + if p.Text != "" { + content += p.Text + } + if p.Thought != "" { + if reasoning == nil { + reasoning = new(string) + } + *reasoning += p.Thought + } + } + + // Gemini reports finish reasons in upper case (e.g. "STOP"); lowercase to match OpenAI's convention. + finishReason := strings.ToLower(geminiChunk.Candidates[0].FinishReason) + + ch <- &models.ChatCompletionStreamResponse{ + ID: 
"gemini-stream", + Object: "chat.completion.chunk", + Created: 0, + Model: model, + Choices: []models.ChatStreamChoice{ + { + Index: 0, + Delta: models.ChatStreamDelta{ + Content: &content, + ReasoningContent: reasoning, + }, + FinishReason: &finishReason, + }, + }, + Usage: &models.Usage{ + PromptTokens: geminiChunk.UsageMetadata.PromptTokenCount, + CompletionTokens: geminiChunk.UsageMetadata.CandidatesTokenCount, + TotalTokens: geminiChunk.UsageMetadata.TotalTokenCount, + }, + } + } + } + } + + return nil +} diff --git a/internal/providers/openai.go b/internal/providers/openai.go new file mode 100644 index 00000000..a9842bb3 --- /dev/null +++ b/internal/providers/openai.go @@ -0,0 +1,113 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "llm-proxy/internal/config" + "llm-proxy/internal/models" + "github.com/go-resty/resty/v2" +) + +type OpenAIProvider struct { + client *resty.Client + config config.OpenAIConfig + apiKey string +} + +func NewOpenAIProvider(cfg config.OpenAIConfig, apiKey string) *OpenAIProvider { + return &OpenAIProvider{ + client: resty.New(), + config: cfg, + apiKey: apiKey, + } +} + +func (p *OpenAIProvider) Name() string { + return "openai" +} + +func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) { + messagesJSON, err := MessagesToOpenAIJSON(req.Messages) + if err != nil { + return nil, fmt.Errorf("failed to convert messages: %w", err) + } + + body := BuildOpenAIBody(req, messagesJSON, false) + + // Transition: Newer models require max_completion_tokens + if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") { + if maxTokens, ok := body["max_tokens"]; ok { + delete(body, "max_tokens") + body["max_completion_tokens"] = maxTokens + } + } + + resp, err := p.client.R(). + SetContext(ctx). + SetHeader("Authorization", "Bearer "+p.apiKey). + SetBody(body). 
+ Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL)) + + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + + if !resp.IsSuccess() { + return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String()) + } + + var respJSON map[string]interface{} + if err := json.Unmarshal(resp.Body(), &respJSON); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return ParseOpenAIResponse(respJSON, req.Model) +} + +func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) { + messagesJSON, err := MessagesToOpenAIJSON(req.Messages) + if err != nil { + return nil, fmt.Errorf("failed to convert messages: %w", err) + } + + body := BuildOpenAIBody(req, messagesJSON, true) + + // Transition: Newer models require max_completion_tokens + if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") { + if maxTokens, ok := body["max_tokens"]; ok { + delete(body, "max_tokens") + body["max_completion_tokens"] = maxTokens + } + } + + resp, err := p.client.R(). + SetContext(ctx). + SetHeader("Authorization", "Bearer "+p.apiKey). + SetBody(body). + SetDoNotParseResponse(true). 
+        Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
+
+    if err != nil {
+        return nil, fmt.Errorf("request failed: %w", err)
+    }
+
+    if !resp.IsSuccess() {
+        return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
+    }
+
+    ch := make(chan *models.ChatCompletionStreamResponse)
+
+    go func() {
+        defer close(ch)
+        err := StreamOpenAI(resp.RawBody(), ch)
+        if err != nil {
+            // In a real app, you might want to send an error chunk or log it
+            fmt.Printf("Stream error: %v\n", err)
+        }
+    }()
+
+    return ch, nil
+}
diff --git a/internal/providers/provider.go b/internal/providers/provider.go
new file mode 100644
index 00000000..36738f53
--- /dev/null
+++ b/internal/providers/provider.go
@@ -0,0 +1,13 @@
+package providers
+
+import (
+    "context"
+
+    "llm-proxy/internal/models"
+)
+
+type Provider interface {
+    Name() string
+    ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error)
+    ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error)
+}
diff --git a/internal/server/dashboard.go b/internal/server/dashboard.go
new file mode 100644
index 00000000..c43a9beb
--- /dev/null
+++ b/internal/server/dashboard.go
@@ -0,0 +1,675 @@
+package server
+
+import (
+    "fmt"
+    "net/http"
+    "strings"
+    "time"
+
+    "llm-proxy/internal/db"
+    "github.com/gin-gonic/gin"
+    "github.com/google/uuid"
+    "golang.org/x/crypto/bcrypt"
+)
+
+type ApiResponse struct {
+    Success bool        `json:"success"`
+    Data    interface{} `json:"data,omitempty"`
+    Error   string      `json:"error,omitempty"`
+}
+
+func SuccessResponse(data interface{}) ApiResponse {
+    return ApiResponse{Success: true, Data: data}
+}
+
+func ErrorResponse(err string) ApiResponse {
+    return ApiResponse{Success: false, Error: err}
+}
+
+func (s *Server) adminAuthMiddleware() gin.HandlerFunc {
+    return func(c *gin.Context) {
+        token := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
+        if token == "" {
+            c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse("Not authenticated"))
+            return
+        }
+
+        session, _, err := s.sessions.ValidateSession(token)
+        if err != nil {
+            c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse("Session expired or invalid"))
+            return
+        }
+
+        if session.Role != "admin" {
+            c.AbortWithStatusJSON(http.StatusForbidden, ErrorResponse("Admin access required"))
+            return
+        }
+
+        c.Set("session", session)
+        c.Next()
+    }
+}
+
+type LoginRequest struct {
+    Username string `json:"username" binding:"required"`
+    Password string `json:"password" binding:"required"`
+}
+
+func (s *Server) handleLogin(c *gin.Context) {
+    var req LoginRequest
+    if err := c.ShouldBindJSON(&req); err != nil {
+        c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
+        return
+    }
+
+    var user db.User
+    err := s.database.Get(&user, "SELECT username, password_hash, display_name, role, must_change_password FROM users WHERE username = ?", req.Username)
+    if err != nil {
+        c.JSON(http.StatusUnauthorized, ErrorResponse("Invalid username or password"))
+        return
+    }
+
+    if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
+        c.JSON(http.StatusUnauthorized, ErrorResponse("Invalid username or password"))
+        return
+    }
+
+    token, err := s.sessions.CreateSession(user.Username, user.Role)
+    if err != nil {
+        c.JSON(http.StatusInternalServerError, ErrorResponse("Failed to create session"))
+        return
+    }
+
+    displayName := user.Username
+    if user.DisplayName != nil {
+        displayName = *user.DisplayName
+    }
+
+    c.JSON(http.StatusOK, SuccessResponse(gin.H{
+        "token":                token,
+        "must_change_password": user.MustChangePassword,
+        "user": gin.H{
+            "username": user.Username,
+            "name":     displayName,
+            "role":     user.Role,
+        },
+    }))
+}
+
+func (s *Server) handleAuthStatus(c *gin.Context) {
+    token := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
+    session, _, err := s.sessions.ValidateSession(token)
+    if err != nil {
+        c.JSON(http.StatusUnauthorized, ErrorResponse("Not authenticated"))
+        return
+    }
+
+    c.JSON(http.StatusOK, SuccessResponse(gin.H{
+        "authenticated": true,
+        "user": gin.H{
+            "username": session.Username,
+            "role":     session.Role,
+        },
+    }))
+}
+
+func (s *Server) handleLogout(c *gin.Context) {
+    token := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
+    s.sessions.RevokeSession(token)
+    c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Logged out"}))
+}
+
+type UsagePeriodFilter struct {
+    Period string `form:"period"`
+    From   string `form:"from"`
+    To     string `form:"to"`
+}
+
+func (f *UsagePeriodFilter) ToSQL() (string, []interface{}) {
+    period := f.Period
+    if period == "" {
+        period = "all"
+    }
+
+    if period == "custom" {
+        var clauses []string
+        var binds []interface{}
+        if f.From != "" {
+            clauses = append(clauses, "timestamp >= ?")
+            binds = append(binds, f.From)
+        }
+        if f.To != "" {
+            clauses = append(clauses, "timestamp <= ?")
+            binds = append(binds, f.To)
+        }
+        if len(clauses) > 0 {
+            return " AND " + strings.Join(clauses, " AND "), binds
+        }
+        return "", nil
+    }
+
+    now := time.Now().UTC()
+    var cutoff time.Time
+    switch period {
+    case "today":
+        cutoff = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)
+    case "24h":
+        cutoff = now.Add(-24 * time.Hour)
+    case "7d":
+        cutoff = now.Add(-7 * 24 * time.Hour)
+    case "30d":
+        cutoff = now.Add(-30 * 24 * time.Hour)
+    default:
+        return "", nil
+    }
+
+    return " AND timestamp >= ?", []interface{}{cutoff.Format(time.RFC3339)}
+}
+
+func (s *Server) handleUsageSummary(c *gin.Context) {
+    var filter UsagePeriodFilter
+    if err := c.ShouldBindQuery(&filter); err != nil {
+        // ignore
+    }
+
+    clause, binds := filter.ToSQL()
+
+    query := fmt.Sprintf(`
+        SELECT
+            COUNT(*) as total_requests,
+            COALESCE(SUM(total_tokens), 0) as total_tokens,
+            COALESCE(SUM(cost), 0.0) as total_cost,
+            COUNT(DISTINCT client_id) as active_clients
+        FROM llm_requests
+        WHERE 1=1 %s
+    `, clause)
+
+    var stats struct {
+        TotalRequests int     `db:"total_requests"`
+        TotalTokens   int     `db:"total_tokens"`
+        TotalCost     float64 `db:"total_cost"`
+        ActiveClients int     `db:"active_clients"`
+    }
+
+    err := s.database.Get(&stats, query, binds...)
+    if err != nil {
+        c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
+        return
+    }
+
+    c.JSON(http.StatusOK, SuccessResponse(stats))
+}
+
+func (s *Server) handleTimeSeries(c *gin.Context) {
+    var filter UsagePeriodFilter
+    if err := c.ShouldBindQuery(&filter); err != nil {
+        // ignore
+    }
+
+    clause, binds := filter.ToSQL()
+
+    if clause == "" {
+        cutoff := time.Now().UTC().Add(-30 * 24 * time.Hour)
+        clause = " AND timestamp >= ?"
+        binds = []interface{}{cutoff.Format(time.RFC3339)}
+    }
+
+    query := fmt.Sprintf(`
+        SELECT
+            strftime('%%Y-%%m-%%d', timestamp) as bucket,
+            COUNT(*) as requests,
+            COALESCE(SUM(total_tokens), 0) as tokens,
+            COALESCE(SUM(cost), 0.0) as cost
+        FROM llm_requests
+        WHERE 1=1 %s
+        GROUP BY bucket
+        ORDER BY bucket
+    `, clause)
+
+    var rows []struct {
+        Bucket   string  `db:"bucket"`
+        Requests int     `db:"requests"`
+        Tokens   int     `db:"tokens"`
+        Cost     float64 `db:"cost"`
+    }
+
+    err := s.database.Select(&rows, query, binds...)
+    if err != nil {
+        c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
+        return
+    }
+
+    series := make([]gin.H, len(rows))
+    for i, r := range rows {
+        series[i] = gin.H{
+            "time":     r.Bucket,
+            "requests": r.Requests,
+            "tokens":   r.Tokens,
+            "cost":     r.Cost,
+        }
+    }
+
+    c.JSON(http.StatusOK, SuccessResponse(gin.H{
+        "series": series,
+    }))
+}
+
+func (s *Server) handleAnalyticsBreakdown(c *gin.Context) {
+    var filter UsagePeriodFilter
+    if err := c.ShouldBindQuery(&filter); err != nil {
+        // ignore
+    }
+
+    clause, binds := filter.ToSQL()
+
+    var models []struct {
+        Label string `db:"label"`
+        Value int    `db:"value"`
+    }
+    err := s.database.Select(&models, fmt.Sprintf("SELECT model as label, COUNT(*) as value FROM llm_requests WHERE 1=1 %s GROUP BY model ORDER BY value DESC", clause), binds...)
+    if err != nil {
+        c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
+        return
+    }
+
+    var clients []struct {
+        Label string `db:"label"`
+        Value int    `db:"value"`
+    }
+    err = s.database.Select(&clients, fmt.Sprintf("SELECT client_id as label, COUNT(*) as value FROM llm_requests WHERE 1=1 %s GROUP BY client_id ORDER BY value DESC", clause), binds...)
+    if err != nil {
+        c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
+        return
+    }
+
+    c.JSON(http.StatusOK, SuccessResponse(gin.H{
+        "models":  models,
+        "clients": clients,
+    }))
+}
+
+func (s *Server) handleGetClients(c *gin.Context) {
+    var clients []db.Client
+    err := s.database.Select(&clients, "SELECT * FROM clients ORDER BY created_at DESC")
+    if err != nil {
+        c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
+        return
+    }
+    c.JSON(http.StatusOK, SuccessResponse(clients))
+}
+
+type CreateClientRequest struct {
+    Name     string  `json:"name" binding:"required"`
+    ClientID *string `json:"client_id"`
+}
+
+func (s *Server) handleCreateClient(c *gin.Context) {
+    var req CreateClientRequest
+    if err := c.ShouldBindJSON(&req); err != nil {
+        c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
+        return
+    }
+
+    clientID := ""
+    if req.ClientID != nil {
+        clientID = *req.ClientID
+    } else {
+        clientID = "client-" + uuid.New().String()[:8]
+    }
+
+    _, err := s.database.Exec("INSERT INTO clients (client_id, name, is_active) VALUES (?, ?, 1)", clientID, req.Name)
+    if err != nil {
+        c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
+        return
+    }
+
+    token := "sk-" + uuid.New().String() + uuid.New().String()
+    token = token[:51]
+
+    _, err = s.database.Exec("INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, 'default')", clientID, token)
+    if err != nil {
+        // Log error
+    }
+
+    c.JSON(http.StatusOK, SuccessResponse(gin.H{
+        "id":         clientID,
+        "name":       req.Name,
+        "status":     "active",
+        "token":      token,
+        "created_at": time.Now(),
+    }))
+}
+
+func (s *Server) handleDeleteClient(c *gin.Context) {
+    id := c.Param("id")
+    if id == "default" {
+        c.JSON(http.StatusBadRequest, ErrorResponse("Cannot delete default client"))
+        return
+    }
+
+    _, err := s.database.Exec("DELETE FROM clients WHERE client_id = ?", id)
+    if err != nil {
+        c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
+        return
+    }
+
+    c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Client deleted"}))
+}
+
+func (s *Server) handleGetClientTokens(c *gin.Context) {
+    id := c.Param("id")
+    var tokens []db.ClientToken
+    err := s.database.Select(&tokens, "SELECT * FROM client_tokens WHERE client_id = ? ORDER BY created_at DESC", id)
+    if err != nil {
+        c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
+        return
+    }
+
+    type MaskedToken struct {
+        ID          int        `json:"id"`
+        TokenMasked string     `json:"token_masked"`
+        Name        string     `json:"name"`
+        IsActive    bool       `json:"is_active"`
+        CreatedAt   time.Time  `json:"created_at"`
+        LastUsedAt  *time.Time `json:"last_used_at"`
+    }
+
+    masked := make([]MaskedToken, len(tokens))
+    for i, t := range tokens {
+        maskedToken := "••••"
+        if len(t.Token) > 8 {
+            maskedToken = t.Token[:3] + "••••" + t.Token[len(t.Token)-8:]
+        }
+        masked[i] = MaskedToken{
+            ID:          t.ID,
+            TokenMasked: maskedToken,
+            Name:        t.Name,
+            IsActive:    t.IsActive,
+            CreatedAt:   t.CreatedAt,
+            LastUsedAt:  t.LastUsedAt,
+        }
+    }
+
+    c.JSON(http.StatusOK, SuccessResponse(masked))
+}
+
+type CreateTokenRequest struct {
+    Name string `json:"name"`
+}
+
+func (s *Server) handleCreateClientToken(c *gin.Context) {
+    clientID := c.Param("id")
+    var req CreateTokenRequest
+    if err := c.ShouldBindJSON(&req); err != nil {
+        // optional name
+    }
+
+    name := "default"
+    if req.Name != "" {
+        name = req.Name
+    }
+
+    token := "sk-" + uuid.New().String() + uuid.New().String()
+    token = token[:51]
+
+    _, err := s.database.Exec("INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, ?)", clientID, token, name)
+    if err != nil {
+        c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
+        return
+    }
+
+    c.JSON(http.StatusOK, SuccessResponse(gin.H{
+        "token":      token,
+        "name":       name,
+        "created_at": time.Now(),
+    }))
+}
+
+func (s *Server) handleDeleteClientToken(c *gin.Context) {
+    tokenID := c.Param("token_id")
+
+    _, err := s.database.Exec("DELETE FROM client_tokens WHERE id = ?", tokenID)
+    if err != nil {
+        c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
+        return
+    }
+
+    c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Token revoked"}))
+}
+
+func (s *Server) handleGetProviders(c *gin.Context) {
+    var dbConfigs []db.ProviderConfig
+    err := s.database.Select(&dbConfigs, "SELECT id, enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs")
+    if err != nil {
+        // Log error
+    }
+
+    dbMap := make(map[string]db.ProviderConfig)
+    for _, cfg := range dbConfigs {
+        dbMap[cfg.ID] = cfg
+    }
+
+    providerIDs := []string{"openai", "gemini", "deepseek", "grok", "ollama"}
+    var result []gin.H
+
+    for _, id := range providerIDs {
+        var name string
+        var enabled bool
+        var baseURL string
+
+        switch id {
+        case "openai":
+            name = "OpenAI"
+            enabled = s.cfg.Providers.OpenAI.Enabled
+            baseURL = s.cfg.Providers.OpenAI.BaseURL
+        case "gemini":
+            name = "Google Gemini"
+            enabled = s.cfg.Providers.Gemini.Enabled
+            baseURL = s.cfg.Providers.Gemini.BaseURL
+        case "deepseek":
+            name = "DeepSeek"
+            enabled = s.cfg.Providers.DeepSeek.Enabled
+            baseURL = s.cfg.Providers.DeepSeek.BaseURL
+        case "grok":
+            name = "xAI Grok"
+            enabled = s.cfg.Providers.Grok.Enabled
+            baseURL = s.cfg.Providers.Grok.BaseURL
+        case "ollama":
+            name = "Ollama"
+            enabled = s.cfg.Providers.Ollama.Enabled
+            baseURL = s.cfg.Providers.Ollama.BaseURL
+        }
+
+        var balance float64
+        var threshold float64 = 5.0
+        var billingMode string
+
+        if dbCfg, ok := dbMap[id]; ok {
+            enabled = dbCfg.Enabled
+            if dbCfg.BaseURL != nil {
+                baseURL = *dbCfg.BaseURL
+            }
+            balance = dbCfg.CreditBalance
+            threshold = dbCfg.LowCreditThreshold
+            if dbCfg.BillingMode != nil {
+                billingMode = *dbCfg.BillingMode
+            }
+        }
+
+        status := "disabled"
+        if enabled {
+            if _, ok := s.providers[id]; ok {
+                status = "online"
+            } else {
+                status = "error"
+            }
+        }
+
+        result = append(result, gin.H{
+            "id":                   id,
+            "name":                 name,
+            "enabled":              enabled,
+            "status":               status,
+            "base_url":             baseURL,
+            "credit_balance":       balance,
+            "low_credit_threshold": threshold,
+            "billing_mode":         billingMode,
+        })
+    }
+
+    c.JSON(http.StatusOK, SuccessResponse(result))
+}
+
+type UpdateProviderRequest struct {
+    Enabled            bool     `json:"enabled"`
+    BaseURL            *string  `json:"base_url"`
+    APIKey             *string  `json:"api_key"`
+    CreditBalance      *float64 `json:"credit_balance"`
+    LowCreditThreshold *float64 `json:"low_credit_threshold"`
+    BillingMode        *string  `json:"billing_mode"`
+}
+
+func (s *Server) handleUpdateProvider(c *gin.Context) {
+    name := c.Param("name")
+    var req UpdateProviderRequest
+    if err := c.ShouldBindJSON(&req); err != nil {
+        c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
+        return
+    }
+
+    _, err := s.database.Exec(`
+        INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode)
+        VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+        ON CONFLICT(id) DO UPDATE SET
+            enabled = excluded.enabled,
+            base_url = COALESCE(excluded.base_url, provider_configs.base_url),
+            api_key = COALESCE(excluded.api_key, provider_configs.api_key),
+            credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
+            low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
+            billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
+            updated_at = CURRENT_TIMESTAMP
+    `, name, strings.ToUpper(name), req.Enabled, req.BaseURL, req.APIKey, req.CreditBalance, req.LowCreditThreshold, req.BillingMode)
+
+    if err != nil {
+        c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
+        return
+    }
+
+    c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Provider updated"}))
+}
+
+func (s *Server) handleGetModels(c *gin.Context) {
+    var models []db.ModelConfig
+    err := s.database.Select(&models, "SELECT * FROM model_configs")
+    if err != nil {
+        c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
+        return
+    }
+    c.JSON(http.StatusOK, SuccessResponse(models))
+}
+
+func (s *Server) handleGetUsers(c *gin.Context) {
+    var users []db.User
+    err := s.database.Select(&users, "SELECT id, username, display_name, role, must_change_password, created_at FROM users")
+    if err != nil {
+        c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
+        return
+    }
+    c.JSON(http.StatusOK, SuccessResponse(users))
+}
+
+type CreateUserRequest struct {
+    Username    string  `json:"username" binding:"required"`
+    Password    string  `json:"password" binding:"required"`
+    DisplayName *string `json:"display_name"`
+    Role        *string `json:"role"`
+}
+
+func (s *Server) handleCreateUser(c *gin.Context) {
+    var req CreateUserRequest
+    if err := c.ShouldBindJSON(&req); err != nil {
+        c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
+        return
+    }
+
+    hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), 12)
+    if err != nil {
+        c.JSON(http.StatusInternalServerError, ErrorResponse("Failed to hash password"))
+        return
+    }
+
+    role := "viewer"
+    if req.Role != nil {
+        role = *req.Role
+    }
+
+    _, err = s.database.Exec("INSERT INTO users (username, password_hash, display_name, role, must_change_password) VALUES (?, ?, ?, ?, 1)",
+        req.Username, string(hash), req.DisplayName, role)
+    if err != nil {
+        c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
+        return
+    }
+
+    c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User created"}))
+}
+
+type UpdateUserRequest struct {
+    DisplayName        *string `json:"display_name"`
+    Role               *string `json:"role"`
+    Password           *string `json:"password"`
+    MustChangePassword *bool   `json:"must_change_password"`
+}
+
+func (s *Server) handleUpdateUser(c *gin.Context) {
+    id := c.Param("id")
+    var req UpdateUserRequest
+    if err := c.ShouldBindJSON(&req); err != nil {
+        c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
+        return
+    }
+
+    if req.DisplayName != nil {
+        s.database.Exec("UPDATE users SET display_name = ? WHERE id = ?", req.DisplayName, id)
+    }
+    if req.Role != nil {
+        s.database.Exec("UPDATE users SET role = ? WHERE id = ?", req.Role, id)
+    }
+    if req.MustChangePassword != nil {
+        s.database.Exec("UPDATE users SET must_change_password = ? WHERE id = ?", req.MustChangePassword, id)
+    }
+    if req.Password != nil {
+        hash, _ := bcrypt.GenerateFromPassword([]byte(*req.Password), 12)
+        s.database.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hash), id)
+    }
+
+    c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User updated"}))
+}
+
+func (s *Server) handleDeleteUser(c *gin.Context) {
+    id := c.Param("id")
+
+    session, _ := c.Get("session")
+    if sess, ok := session.(*Session); ok {
+        var username string
+        s.database.Get(&username, "SELECT username FROM users WHERE id = ?", id)
+        if username == sess.Username {
+            c.JSON(http.StatusBadRequest, ErrorResponse("Cannot delete your own account"))
+            return
+        }
+    }
+
+    _, err := s.database.Exec("DELETE FROM users WHERE id = ?", id)
+    if err != nil {
+        c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
+        return
+    }
+
+    c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User deleted"}))
+}
+
+func (s *Server) handleSystemHealth(c *gin.Context) {
+    c.JSON(http.StatusOK, SuccessResponse(gin.H{
+        "status": "ok",
+        "db":     "connected",
+    }))
+}
diff --git a/internal/server/logging.go b/internal/server/logging.go
new file mode 100644
index 00000000..53d502aa
--- /dev/null
+++ b/internal/server/logging.go
@@ -0,0 +1,113 @@
+package server
+
+import (
+    "log"
+    "time"
+
+    "llm-proxy/internal/db"
+)
+
+type RequestLog struct {
+    Timestamp        time.Time `json:"timestamp"`
+    ClientID         string    `json:"client_id"`
+    Provider         string    `json:"provider"`
+    Model            string    `json:"model"`
+    PromptTokens     uint32    `json:"prompt_tokens"`
+    CompletionTokens uint32    `json:"completion_tokens"`
+    ReasoningTokens  uint32    `json:"reasoning_tokens"`
+    TotalTokens      uint32    `json:"total_tokens"`
+    CacheReadTokens  uint32    `json:"cache_read_tokens"`
+    CacheWriteTokens uint32    `json:"cache_write_tokens"`
+    Cost             float64   `json:"cost"`
+    HasImages        bool      `json:"has_images"`
+    Status           string    `json:"status"`
+    ErrorMessage     string    `json:"error_message,omitempty"`
+    DurationMS       int64     `json:"duration_ms"`
+}
+
+type RequestLogger struct {
+    database *db.DB
+    hub      *Hub
+    logChan  chan RequestLog
+}
+
+func NewRequestLogger(database *db.DB, hub *Hub) *RequestLogger {
+    return &RequestLogger{
+        database: database,
+        hub:      hub,
+        logChan:  make(chan RequestLog, 100),
+    }
+}
+
+func (l *RequestLogger) Start() {
+    go func() {
+        for entry := range l.logChan {
+            l.processLog(entry)
+        }
+    }()
+}
+
+func (l *RequestLogger) LogRequest(entry RequestLog) {
+    select {
+    case l.logChan <- entry:
+    default:
+        log.Println("Request log channel full, dropping log entry")
+    }
+}
+
+func (l *RequestLogger) processLog(entry RequestLog) {
+    // Broadcast to dashboard
+    l.hub.broadcast <- map[string]interface{}{
+        "type":    "request",
+        "channel": "requests",
+        "payload": entry,
+    }
+
+    // Insert into DB
+    tx, err := l.database.Begin()
+    if err != nil {
+        log.Printf("Failed to begin transaction for logging: %v", err)
+        return
+    }
+    defer tx.Rollback()
+
+    // Ensure client exists
+    _, _ = tx.Exec("INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')",
+        entry.ClientID, entry.ClientID)
+
+    // Insert log
+    _, err = tx.Exec(`
+        INSERT INTO llm_requests
+        (timestamp, client_id, provider, model, prompt_tokens, completion_tokens, reasoning_tokens, total_tokens, cache_read_tokens, cache_write_tokens, cost, has_images, status, error_message, duration_ms)
+        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+    `, entry.Timestamp, entry.ClientID, entry.Provider, entry.Model,
+        entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.TotalTokens,
+        entry.CacheReadTokens, entry.CacheWriteTokens, entry.Cost, entry.HasImages,
+        entry.Status, entry.ErrorMessage, entry.DurationMS)
+
+    if err != nil {
+        log.Printf("Failed to insert request log: %v", err)
+        return
+    }
+
+    // Update client stats
+    _, _ = tx.Exec(`
+        UPDATE clients SET
+            total_requests = total_requests + 1,
+            total_tokens = total_tokens + ?,
+            total_cost = total_cost + ?,
+            updated_at = CURRENT_TIMESTAMP
+        WHERE client_id = ?
+    `, entry.TotalTokens, entry.Cost, entry.ClientID)
+
+    // Update provider balance
+    if entry.Cost > 0 {
+        _, _ = tx.Exec("UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')",
+            entry.Cost, entry.Provider)
+    }
+
+    err = tx.Commit()
+    if err != nil {
+        log.Printf("Failed to commit logging transaction: %v", err)
+    }
+}
diff --git a/internal/server/server.go b/internal/server/server.go
new file mode 100644
index 00000000..8fa75e4e
--- /dev/null
+++ b/internal/server/server.go
@@ -0,0 +1,326 @@
+package server
+
+import (
+    "encoding/json"
+    "fmt"
+    "io"
+    "net/http"
+    "strings"
+    "time"
+
+    "llm-proxy/internal/config"
+    "llm-proxy/internal/db"
+    "llm-proxy/internal/middleware"
+    "llm-proxy/internal/models"
+    "llm-proxy/internal/providers"
+    "llm-proxy/internal/utils"
+
+    "github.com/gin-gonic/gin"
+)
+
+type Server struct {
+    router    *gin.Engine
+    cfg       *config.Config
+    database  *db.DB
+    providers map[string]providers.Provider
+    sessions  *SessionManager
+    hub       *Hub
+    logger    *RequestLogger
+}
+
+func NewServer(cfg *config.Config, database *db.DB) *Server {
+    router := gin.Default()
+    hub := NewHub()
+
+    s := &Server{
+        router:    router,
+        cfg:       cfg,
+        database:  database,
+        providers: make(map[string]providers.Provider),
+        sessions:  NewSessionManager(cfg.KeyBytes, 24*time.Hour),
+        hub:       hub,
+        logger:    NewRequestLogger(database, hub),
+    }
+
+    // Initialize providers
+    if cfg.Providers.OpenAI.Enabled {
+        apiKey, _ := cfg.GetAPIKey("openai")
+        s.providers["openai"] = providers.NewOpenAIProvider(cfg.Providers.OpenAI, apiKey)
+    }
+    if cfg.Providers.Gemini.Enabled {
+        apiKey, _ := cfg.GetAPIKey("gemini")
+        s.providers["gemini"] = providers.NewGeminiProvider(cfg.Providers.Gemini, apiKey)
+    }
+    if cfg.Providers.DeepSeek.Enabled {
+        apiKey, _ := cfg.GetAPIKey("deepseek")
+        s.providers["deepseek"] = providers.NewDeepSeekProvider(cfg.Providers.DeepSeek, apiKey)
+    }
+    if cfg.Providers.Grok.Enabled {
+        apiKey, _ := cfg.GetAPIKey("grok")
+        s.providers["grok"] = providers.NewGrokProvider(cfg.Providers.Grok, apiKey)
+    }
+
+    s.setupRoutes()
+    return s
+}
+
+func (s *Server) setupRoutes() {
+    s.router.Use(middleware.AuthMiddleware(s.database))
+
+    // Static files
+    s.router.Static("/static", "./static")
+    s.router.StaticFile("/", "./static/index.html")
+    s.router.StaticFile("/favicon.ico", "./static/favicon.ico")
+
+    // WebSocket
+    s.router.GET("/ws", s.handleWebSocket)
+
+    v1 := s.router.Group("/v1")
+    {
+        v1.POST("/chat/completions", s.handleChatCompletions)
+    }
+
+    // Dashboard API Group
+    api := s.router.Group("/api")
+    {
+        api.POST("/auth/login", s.handleLogin)
+        api.GET("/auth/status", s.handleAuthStatus)
+        api.POST("/auth/logout", s.handleLogout)
+
+        // Protected dashboard routes (need admin session)
+        admin := api.Group("/")
+        admin.Use(s.adminAuthMiddleware())
+        {
+            admin.GET("/usage/summary", s.handleUsageSummary)
+            admin.GET("/usage/time-series", s.handleTimeSeries)
+            admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)
+
+            admin.GET("/clients", s.handleGetClients)
+            admin.POST("/clients", s.handleCreateClient)
+            admin.DELETE("/clients/:id", s.handleDeleteClient)
+
+            admin.GET("/clients/:id/tokens", s.handleGetClientTokens)
+            admin.POST("/clients/:id/tokens", s.handleCreateClientToken)
+            admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)
+
+            admin.GET("/providers", s.handleGetProviders)
+            admin.PUT("/providers/:name", s.handleUpdateProvider)
+            admin.GET("/models", s.handleGetModels)
+
+            admin.GET("/users", s.handleGetUsers)
+            admin.POST("/users", s.handleCreateUser)
+            admin.PUT("/users/:id", s.handleUpdateUser)
+            admin.DELETE("/users/:id", s.handleDeleteUser)
+
+            admin.GET("/system/health", s.handleSystemHealth)
+        }
+    }
+
+    s.router.GET("/health", func(c *gin.Context) {
+        c.JSON(http.StatusOK, gin.H{"status": "ok"})
+    })
+}
+
+func (s *Server) handleChatCompletions(c *gin.Context) {
+    startTime := time.Now()
+    var req models.ChatCompletionRequest
+    if err := c.ShouldBindJSON(&req); err != nil {
+        c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+        return
+    }
+
+    // Select provider based on model name
+    providerName := "openai" // default
+    if strings.Contains(req.Model, "gemini") {
+        providerName = "gemini"
+    } else if strings.Contains(req.Model, "deepseek") {
+        providerName = "deepseek"
+    } else if strings.Contains(req.Model, "grok") {
+        providerName = "grok"
+    }
+
+    provider, ok := s.providers[providerName]
+    if !ok {
+        c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
+        return
+    }
+
+    // Convert ChatCompletionRequest to UnifiedRequest
+    unifiedReq := &models.UnifiedRequest{
+        Model:            req.Model,
+        Messages:         []models.UnifiedMessage{},
+        Temperature:      req.Temperature,
+        TopP:             req.TopP,
+        TopK:             req.TopK,
+        N:                req.N,
+        MaxTokens:        req.MaxTokens,
+        PresencePenalty:  req.PresencePenalty,
+        FrequencyPenalty: req.FrequencyPenalty,
+        Stream:           req.Stream != nil && *req.Stream,
+        Tools:            req.Tools,
+        ToolChoice:       req.ToolChoice,
+    }
+
+    // Handle Stop sequences
+    if req.Stop != nil {
+        var stop []string
+        if err := json.Unmarshal(req.Stop, &stop); err == nil {
+            unifiedReq.Stop = stop
+        } else {
+            var singleStop string
+            if err := json.Unmarshal(req.Stop, &singleStop); err == nil {
+                unifiedReq.Stop = []string{singleStop}
+            }
+        }
+    }
+ + // Convert messages + for _, msg := range req.Messages { + unifiedMsg := models.UnifiedMessage{ + Role: msg.Role, + Content: []models.UnifiedContentPart{}, + ReasoningContent: msg.ReasoningContent, + ToolCalls: msg.ToolCalls, + Name: msg.Name, + ToolCallID: msg.ToolCallID, + } + + // Handle multimodal content + if strContent, ok := msg.Content.(string); ok { + unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{ + Type: "text", + Text: strContent, + }) + } else if parts, ok := msg.Content.([]interface{}); ok { + for _, part := range parts { + if partMap, ok := part.(map[string]interface{}); ok { + partType, _ := partMap["type"].(string) + if partType == "text" { + text, _ := partMap["text"].(string) + unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{ + Type: "text", + Text: text, + }) + } else if partType == "image_url" { + if imgURLMap, ok := partMap["image_url"].(map[string]interface{}); ok { + url, _ := imgURLMap["url"].(string) + imageInput := &models.ImageInput{} + if strings.HasPrefix(url, "data:") { + mime, data, err := utils.ParseDataURL(url) + if err == nil { + imageInput.Base64 = data + imageInput.MimeType = mime + } + } else { + imageInput.URL = url + } + unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{ + Type: "image", + Image: imageInput, + }) + unifiedReq.HasImages = true + } + } + } + } + } + + unifiedReq.Messages = append(unifiedReq.Messages, unifiedMsg) + } + + clientID := "default" + if auth, ok := c.Get("auth"); ok { + if authInfo, ok := auth.(models.AuthInfo); ok { + unifiedReq.ClientID = authInfo.ClientID + clientID = authInfo.ClientID + } + } else { + unifiedReq.ClientID = clientID + } + + if unifiedReq.Stream { + ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq) + if err != nil { + s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + 
return + } + + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + + var lastUsage *models.Usage + c.Stream(func(w io.Writer) bool { + chunk, ok := <-ch + if !ok { + fmt.Fprintf(w, "data: [DONE]\n\n") + s.logRequest(startTime, clientID, providerName, req.Model, lastUsage, nil, unifiedReq.HasImages) + return false + } + if chunk.Usage != nil { + lastUsage = chunk.Usage + } + data, err := json.Marshal(chunk) + if err != nil { + return false + } + fmt.Fprintf(w, "data: %s\n\n", data) + return true + }) + return + } + + resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq) + if err != nil { + s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage, nil, unifiedReq.HasImages) + c.JSON(http.StatusOK, resp) +} + +func (s *Server) logRequest(start time.Time, clientID, provider, model string, usage *models.Usage, err error, hasImages bool) { + entry := RequestLog{ + Timestamp: start, + ClientID: clientID, + Provider: provider, + Model: model, + Status: "success", + DurationMS: time.Since(start).Milliseconds(), + HasImages: hasImages, + } + + if err != nil { + entry.Status = "error" + entry.ErrorMessage = err.Error() + } + + if usage != nil { + entry.PromptTokens = usage.PromptTokens + entry.CompletionTokens = usage.CompletionTokens + entry.TotalTokens = usage.TotalTokens + if usage.ReasoningTokens != nil { + entry.ReasoningTokens = *usage.ReasoningTokens + } + if usage.CacheReadTokens != nil { + entry.CacheReadTokens = *usage.CacheReadTokens + } + if usage.CacheWriteTokens != nil { + entry.CacheWriteTokens = *usage.CacheWriteTokens + } + // TODO: Calculate cost properly based on pricing + entry.Cost = 0.0 + } + + s.logger.LogRequest(entry) +} + +func (s *Server) Run() error { + go s.hub.Run() + 
s.logger.Start() + addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port) + return s.router.Run(addr) +} diff --git a/internal/server/sessions.go b/internal/server/sessions.go new file mode 100644 index 00000000..0d4065da --- /dev/null +++ b/internal/server/sessions.go @@ -0,0 +1,151 @@ +package server + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "strings" + "sync" + "time" + + "github.com/google/uuid" +) + +type Session struct { + Username string `json:"username"` + Role string `json:"role"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` + SessionID string `json:"session_id"` +} + +type SessionManager struct { + sessions map[string]Session + mu sync.RWMutex + secret []byte + ttl time.Duration +} + +type sessionPayload struct { + SessionID string `json:"session_id"` + Username string `json:"username"` + Role string `json:"role"` + Exp int64 `json:"exp"` +} + +func NewSessionManager(secret []byte, ttl time.Duration) *SessionManager { + return &SessionManager{ + sessions: make(map[string]Session), + secret: secret, + ttl: ttl, + } +} + +func (m *SessionManager) CreateSession(username, role string) (string, error) { + sessionID := uuid.New().String() + now := time.Now() + expiresAt := now.Add(m.ttl) + + m.mu.Lock() + m.sessions[sessionID] = Session{ + Username: username, + Role: role, + CreatedAt: now, + ExpiresAt: expiresAt, + SessionID: sessionID, + } + m.mu.Unlock() + + return m.createSignedToken(sessionID, username, role, expiresAt.Unix()) +} + +func (m *SessionManager) createSignedToken(sessionID, username, role string, exp int64) (string, error) { + payload := sessionPayload{ + SessionID: sessionID, + Username: username, + Role: role, + Exp: exp, + } + + payloadJSON, err := json.Marshal(payload) + if err != nil { + return "", err + } + + payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON) + + h := hmac.New(sha256.New, m.secret) + h.Write(payloadJSON) + 
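// Note: the MAC is computed over the raw JSON payload, not its
+	// base64url-encoded form; ValidateSession must therefore decode the
+	// payload segment back to JSON before recomputing and comparing it.
+	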
signature := h.Sum(nil) + signatureB64 := base64.RawURLEncoding.EncodeToString(signature) + + return fmt.Sprintf("%s.%s", payloadB64, signatureB64), nil +} + +func (m *SessionManager) ValidateSession(token string) (*Session, string, error) { + parts := strings.Split(token, ".") + if len(parts) != 2 { + return nil, "", fmt.Errorf("invalid token format") + } + + payloadB64 := parts[0] + signatureB64 := parts[1] + + payloadJSON, err := base64.RawURLEncoding.DecodeString(payloadB64) + if err != nil { + return nil, "", err + } + + signature, err := base64.RawURLEncoding.DecodeString(signatureB64) + if err != nil { + return nil, "", err + } + + h := hmac.New(sha256.New, m.secret) + h.Write(payloadJSON) + if !hmac.Equal(signature, h.Sum(nil)) { + return nil, "", fmt.Errorf("invalid signature") + } + + var payload sessionPayload + if err := json.Unmarshal(payloadJSON, &payload); err != nil { + return nil, "", err + } + + if time.Now().Unix() > payload.Exp { + return nil, "", fmt.Errorf("token expired") + } + + m.mu.RLock() + session, ok := m.sessions[payload.SessionID] + m.mu.RUnlock() + + if !ok { + return nil, "", fmt.Errorf("session not found") + } + + return &session, "", nil +} + +func (m *SessionManager) RevokeSession(token string) { + parts := strings.Split(token, ".") + if len(parts) != 2 { + return + } + + payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + return + } + + var payload sessionPayload + if err := json.Unmarshal(payloadJSON, &payload); err != nil { + return + } + + m.mu.Lock() + delete(m.sessions, payload.SessionID) + m.mu.Unlock() +} diff --git a/internal/server/websocket.go b/internal/server/websocket.go new file mode 100644 index 00000000..25be7dc7 --- /dev/null +++ b/internal/server/websocket.go @@ -0,0 +1,98 @@ +package server + +import ( + "log" + "net/http" + "sync" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 
1024,
+	CheckOrigin: func(r *http.Request) bool {
+		return true // Allow all origins; restrict to the dashboard origin in production.
+	},
+}
+
+type Hub struct {
+	clients    map[*websocket.Conn]bool
+	broadcast  chan interface{}
+	register   chan *websocket.Conn
+	unregister chan *websocket.Conn
+	mu         sync.Mutex
+}
+
+func NewHub() *Hub {
+	return &Hub{
+		clients:    make(map[*websocket.Conn]bool),
+		broadcast:  make(chan interface{}),
+		register:   make(chan *websocket.Conn),
+		unregister: make(chan *websocket.Conn),
+	}
+}
+
+func (h *Hub) Run() {
+	for {
+		select {
+		case client := <-h.register:
+			h.mu.Lock()
+			h.clients[client] = true
+			h.mu.Unlock()
+			log.Println("WebSocket client registered")
+		case client := <-h.unregister:
+			h.mu.Lock()
+			if _, ok := h.clients[client]; ok {
+				delete(h.clients, client)
+				client.Close()
+			}
+			h.mu.Unlock()
+			log.Println("WebSocket client unregistered")
+		case message := <-h.broadcast:
+			h.mu.Lock()
+			for client := range h.clients {
+				err := client.WriteJSON(message)
+				if err != nil {
+					log.Printf("WebSocket error: %v", err)
+					client.Close()
+					delete(h.clients, client)
+				}
+			}
+			h.mu.Unlock()
+		}
+	}
+}
+
+func (s *Server) handleWebSocket(c *gin.Context) {
+	conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
+	if err != nil {
+		log.Printf("Failed to set websocket upgrade: %v", err)
+		return
+	}
+
+	s.hub.register <- conn
+
+	defer func() {
+		s.hub.unregister <- conn
+	}()
+
+	// Initial message. All writes to the connection are serialized through
+	// the hub mutex: gorilla/websocket allows at most one concurrent writer
+	// per connection, and the broadcast loop also writes while holding mu.
+	s.hub.mu.Lock()
+	conn.WriteJSON(gin.H{
+		"type":    "connected",
+		"message": "Connected to LLM Proxy Dashboard",
+	})
+	s.hub.mu.Unlock()
+
+	for {
+		var msg map[string]interface{}
+		err := conn.ReadJSON(&msg)
+		if err != nil {
+			break
+		}
+
+		if msg["type"] == "ping" {
+			s.hub.mu.Lock()
+			conn.WriteJSON(gin.H{"type": "pong", "payload": gin.H{}})
+			s.hub.mu.Unlock()
+		}
+	}
+}
diff --git a/internal/utils/utils.go b/internal/utils/utils.go
new file mode 100644
index 00000000..2482d940
--- /dev/null
+++ b/internal/utils/utils.go
@@ -0,0 +1,20 @@
+package utils
+
+import (
+	"fmt"
+	"strings"
+)
+
+// ParseDataURL splits a base64 data URL into its MIME type and payload.
+func ParseDataURL(dataURL string) (string, string, error) {
+ if !strings.HasPrefix(dataURL, "data:") { + return "", "", fmt.Errorf("not a data URL") + } + + parts := strings.Split(dataURL[5:], ";base64,") + if len(parts) != 2 { + return "", "", fmt.Errorf("invalid data URL format") + } + + return parts[0], parts[1], nil +} diff --git a/rustfmt.toml b/rustfmt.toml deleted file mode 100644 index 53f860ab..00000000 --- a/rustfmt.toml +++ /dev/null @@ -1,2 +0,0 @@ -max_width = 120 -use_field_init_shorthand = true diff --git a/src/auth/mod.rs b/src/auth/mod.rs deleted file mode 100644 index e370cec1..00000000 --- a/src/auth/mod.rs +++ /dev/null @@ -1,45 +0,0 @@ -use axum::{extract::FromRequestParts, http::request::Parts}; - -use crate::errors::AppError; - -#[derive(Debug, Clone)] -pub struct AuthInfo { - pub token: String, - pub client_id: String, -} - -pub struct AuthenticatedClient { - pub info: AuthInfo, -} - -impl FromRequestParts for AuthenticatedClient -where - S: Send + Sync, -{ - type Rejection = AppError; - - async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { - // Retrieve AuthInfo from request extensions, where it was placed by rate_limit_middleware - let info = parts - .extensions - .get::() - .cloned() - .ok_or_else(|| AppError::AuthError("Authentication info not found in request".to_string()))?; - - Ok(AuthenticatedClient { info }) - } -} - -impl std::ops::Deref for AuthenticatedClient { - type Target = AuthInfo; - - fn deref(&self) -> &Self::Target { - &self.info - } -} - -pub fn validate_token(token: &str, valid_tokens: &[String]) -> bool { - // Simple validation against list of tokens - // In production, use proper token validation (JWT, database lookup, etc.) - valid_tokens.contains(&token.to_string()) -} diff --git a/src/client/mod.rs b/src/client/mod.rs deleted file mode 100644 index 51b51f2b..00000000 --- a/src/client/mod.rs +++ /dev/null @@ -1,304 +0,0 @@ -//! Client management for LLM proxy -//! -//! This module handles: -//! 1. Client registration and management -//! 2. 
Client usage tracking -//! 3. Client rate limit configuration - -use anyhow::Result; -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; -use sqlx::{Row, SqlitePool}; -use tracing::{info, warn}; - -/// Client information -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Client { - pub id: i64, - pub client_id: String, - pub name: String, - pub description: String, - pub created_at: DateTime, - pub updated_at: DateTime, - pub is_active: bool, - pub rate_limit_per_minute: i64, - pub total_requests: i64, - pub total_tokens: i64, - pub total_cost: f64, -} - -/// Client creation request -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CreateClientRequest { - pub client_id: String, - pub name: String, - pub description: String, - pub rate_limit_per_minute: Option, -} - -/// Client update request -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct UpdateClientRequest { - pub name: Option, - pub description: Option, - pub is_active: Option, - pub rate_limit_per_minute: Option, -} - -/// Client manager for database operations -pub struct ClientManager { - db_pool: SqlitePool, -} - -impl ClientManager { - pub fn new(db_pool: SqlitePool) -> Self { - Self { db_pool } - } - - /// Create a new client - pub async fn create_client(&self, request: CreateClientRequest) -> Result { - let rate_limit = request.rate_limit_per_minute.unwrap_or(60); - - // First insert the client - sqlx::query( - r#" - INSERT INTO clients (client_id, name, description, rate_limit_per_minute) - VALUES (?, ?, ?, ?) - "#, - ) - .bind(&request.client_id) - .bind(&request.name) - .bind(&request.description) - .bind(rate_limit) - .execute(&self.db_pool) - .await?; - - // Then fetch the created client - let client = self - .get_client(&request.client_id) - .await? 
- .ok_or_else(|| anyhow::anyhow!("Failed to retrieve created client"))?; - - info!("Created client: {} ({})", client.name, client.client_id); - Ok(client) - } - - /// Get a client by ID - pub async fn get_client(&self, client_id: &str) -> Result> { - let row = sqlx::query( - r#" - SELECT - id, client_id, name, description, - created_at, updated_at, is_active, - rate_limit_per_minute, total_requests, total_tokens, total_cost - FROM clients - WHERE client_id = ? - "#, - ) - .bind(client_id) - .fetch_optional(&self.db_pool) - .await?; - - if let Some(row) = row { - let client = Client { - id: row.get("id"), - client_id: row.get("client_id"), - name: row.get("name"), - description: row.get("description"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), - is_active: row.get("is_active"), - rate_limit_per_minute: row.get("rate_limit_per_minute"), - total_requests: row.get("total_requests"), - total_tokens: row.get("total_tokens"), - total_cost: row.get("total_cost"), - }; - Ok(Some(client)) - } else { - Ok(None) - } - } - - /// Update a client - pub async fn update_client(&self, client_id: &str, request: UpdateClientRequest) -> Result> { - // First, get the current client to check if it exists - let current_client = self.get_client(client_id).await?; - if current_client.is_none() { - return Ok(None); - } - - // Build update query dynamically based on provided fields - let mut query_builder = sqlx::QueryBuilder::new("UPDATE clients SET "); - let mut has_updates = false; - - if let Some(name) = &request.name { - query_builder.push("name = "); - query_builder.push_bind(name); - has_updates = true; - } - - if let Some(description) = &request.description { - if has_updates { - query_builder.push(", "); - } - query_builder.push("description = "); - query_builder.push_bind(description); - has_updates = true; - } - - if let Some(is_active) = request.is_active { - if has_updates { - query_builder.push(", "); - } - query_builder.push("is_active = "); - 
query_builder.push_bind(is_active); - has_updates = true; - } - - if let Some(rate_limit) = request.rate_limit_per_minute { - if has_updates { - query_builder.push(", "); - } - query_builder.push("rate_limit_per_minute = "); - query_builder.push_bind(rate_limit); - has_updates = true; - } - - // Always update the updated_at timestamp - if has_updates { - query_builder.push(", "); - } - query_builder.push("updated_at = CURRENT_TIMESTAMP"); - - if !has_updates { - // No updates to make - return self.get_client(client_id).await; - } - - query_builder.push(" WHERE client_id = "); - query_builder.push_bind(client_id); - - let query = query_builder.build(); - query.execute(&self.db_pool).await?; - - // Fetch the updated client - let updated_client = self.get_client(client_id).await?; - - if updated_client.is_some() { - info!("Updated client: {}", client_id); - } - - Ok(updated_client) - } - - /// List all clients - pub async fn list_clients(&self, limit: Option, offset: Option) -> Result> { - let limit = limit.unwrap_or(100); - let offset = offset.unwrap_or(0); - - let rows = sqlx::query( - r#" - SELECT - id, client_id, name, description, - created_at, updated_at, is_active, - rate_limit_per_minute, total_requests, total_tokens, total_cost - FROM clients - ORDER BY created_at DESC - LIMIT ? OFFSET ? 
- "#, - ) - .bind(limit) - .bind(offset) - .fetch_all(&self.db_pool) - .await?; - - let mut clients = Vec::new(); - for row in rows { - let client = Client { - id: row.get("id"), - client_id: row.get("client_id"), - name: row.get("name"), - description: row.get("description"), - created_at: row.get("created_at"), - updated_at: row.get("updated_at"), - is_active: row.get("is_active"), - rate_limit_per_minute: row.get("rate_limit_per_minute"), - total_requests: row.get("total_requests"), - total_tokens: row.get("total_tokens"), - total_cost: row.get("total_cost"), - }; - clients.push(client); - } - - Ok(clients) - } - - /// Delete a client - pub async fn delete_client(&self, client_id: &str) -> Result { - let result = sqlx::query("DELETE FROM clients WHERE client_id = ?") - .bind(client_id) - .execute(&self.db_pool) - .await?; - - let deleted = result.rows_affected() > 0; - - if deleted { - info!("Deleted client: {}", client_id); - } else { - warn!("Client not found for deletion: {}", client_id); - } - - Ok(deleted) - } - - /// Update client usage statistics after a request - pub async fn update_client_usage(&self, client_id: &str, tokens: i64, cost: f64) -> Result<()> { - sqlx::query( - r#" - UPDATE clients - SET - total_requests = total_requests + 1, - total_tokens = total_tokens + ?, - total_cost = total_cost + ?, - updated_at = CURRENT_TIMESTAMP - WHERE client_id = ? - "#, - ) - .bind(tokens) - .bind(cost) - .bind(client_id) - .execute(&self.db_pool) - .await?; - - Ok(()) - } - - /// Get client usage statistics - pub async fn get_client_usage(&self, client_id: &str) -> Result> { - let row = sqlx::query( - r#" - SELECT total_requests, total_tokens, total_cost - FROM clients - WHERE client_id = ? 
- "#, - ) - .bind(client_id) - .fetch_optional(&self.db_pool) - .await?; - - if let Some(row) = row { - let total_requests: i64 = row.get("total_requests"); - let total_tokens: i64 = row.get("total_tokens"); - let total_cost: f64 = row.get("total_cost"); - Ok(Some((total_requests, total_tokens, total_cost))) - } else { - Ok(None) - } - } - - /// Check if a client exists and is active - pub async fn validate_client(&self, client_id: &str) -> Result { - let client = self.get_client(client_id).await?; - Ok(client.map(|c| c.is_active).unwrap_or(false)) - } -} diff --git a/src/config/mod.rs b/src/config/mod.rs deleted file mode 100644 index aa0a651c..00000000 --- a/src/config/mod.rs +++ /dev/null @@ -1,260 +0,0 @@ -use anyhow::Result; -use base64::{Engine as _}; -use config::{Config, File, FileFormat}; -use hex; -use serde::{Deserialize, Serialize}; -use std::path::PathBuf; -use std::sync::Arc; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ServerConfig { - pub port: u16, - pub host: String, - #[serde(deserialize_with = "deserialize_vec_or_string")] - pub auth_tokens: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DatabaseConfig { - pub path: PathBuf, - pub max_connections: u32, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ProviderConfig { - pub openai: OpenAIConfig, - pub gemini: GeminiConfig, - pub deepseek: DeepSeekConfig, - pub grok: GrokConfig, - pub ollama: OllamaConfig, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct OpenAIConfig { - pub api_key_env: String, - pub base_url: String, - pub default_model: String, - pub enabled: bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GeminiConfig { - pub api_key_env: String, - pub base_url: String, - pub default_model: String, - pub enabled: bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DeepSeekConfig { - pub api_key_env: String, - pub base_url: String, - pub default_model: String, - pub enabled: 
bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GrokConfig { - pub api_key_env: String, - pub base_url: String, - pub default_model: String, - pub enabled: bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct OllamaConfig { - pub base_url: String, - pub enabled: bool, - #[serde(deserialize_with = "deserialize_vec_or_string")] - pub models: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ModelMappingConfig { - pub patterns: Vec<(String, String)>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PricingConfig { - pub openai: Vec, - pub gemini: Vec, - pub deepseek: Vec, - pub grok: Vec, - pub ollama: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ModelPricing { - pub model: String, - pub prompt_tokens_per_million: f64, - pub completion_tokens_per_million: f64, -} - -#[derive(Debug, Clone)] -pub struct AppConfig { - pub server: ServerConfig, - pub database: DatabaseConfig, - pub providers: ProviderConfig, - pub model_mapping: ModelMappingConfig, - pub pricing: PricingConfig, - pub config_path: Option, - pub encryption_key: String, -} - -impl AppConfig { - pub async fn load() -> Result> { - Self::load_from_path(None).await - } - - /// Load configuration from a specific path (for testing) - pub async fn load_from_path(config_path: Option) -> Result> { - // Load configuration from multiple sources - let mut config_builder = Config::builder(); - - // Default configuration - config_builder = config_builder - .set_default("server.port", 8080)? - .set_default("server.host", "0.0.0.0")? - .set_default("server.auth_tokens", Vec::::new())? - .set_default("database.path", "./data/llm_proxy.db")? - .set_default("database.max_connections", 10)? - .set_default("providers.openai.api_key_env", "OPENAI_API_KEY")? - .set_default("providers.openai.base_url", "https://api.openai.com/v1")? - .set_default("providers.openai.default_model", "gpt-4o")? 
- .set_default("providers.openai.enabled", true)? - .set_default("providers.gemini.api_key_env", "GEMINI_API_KEY")? - .set_default( - "providers.gemini.base_url", - "https://generativelanguage.googleapis.com/v1", - )? - .set_default("providers.gemini.default_model", "gemini-2.0-flash")? - .set_default("providers.gemini.enabled", true)? - .set_default("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")? - .set_default("providers.deepseek.base_url", "https://api.deepseek.com")? - .set_default("providers.deepseek.default_model", "deepseek-reasoner")? - .set_default("providers.deepseek.enabled", true)? - .set_default("providers.grok.api_key_env", "GROK_API_KEY")? - .set_default("providers.grok.base_url", "https://api.x.ai/v1")? - .set_default("providers.grok.default_model", "grok-beta")? - .set_default("providers.grok.enabled", true)? - .set_default("providers.ollama.base_url", "http://localhost:11434/v1")? - .set_default("providers.ollama.enabled", false)? - .set_default("providers.ollama.models", Vec::::new())? 
- .set_default("encryption_key", "")?; - - // Load from config file if exists - // Priority: explicit path arg > LLM_PROXY__CONFIG_PATH env var > ./config.toml - let config_path = config_path - .or_else(|| std::env::var("LLM_PROXY__CONFIG_PATH").ok().map(PathBuf::from)) - .unwrap_or_else(|| { - std::env::current_dir() - .unwrap_or_else(|_| PathBuf::from(".")) - .join("config.toml") - }); - if config_path.exists() { - config_builder = config_builder.add_source(File::from(config_path.clone()).format(FileFormat::Toml)); - } - - // Load from .env file - dotenvy::dotenv().ok(); - - // Load from environment variables (with prefix "LLM_PROXY_") - config_builder = config_builder.add_source( - config::Environment::with_prefix("LLM_PROXY") - .separator("__") - .try_parsing(true), - ); - - let config = config_builder.build()?; - - // Deserialize configuration - let server: ServerConfig = config.get("server")?; - let database: DatabaseConfig = config.get("database")?; - let providers: ProviderConfig = config.get("providers")?; - let encryption_key: String = config.get("encryption_key")?; - - // Validate encryption key length (must be 32 bytes after hex or base64 decoding) - if encryption_key.is_empty() { - anyhow::bail!("Encryption key is required (LLM_PROXY__ENCRYPTION_KEY environment variable)"); - } - // Try hex decode first, then base64 - let key_bytes = hex::decode(&encryption_key) - .or_else(|_| base64::engine::general_purpose::STANDARD.decode(&encryption_key)) - .map_err(|e| anyhow::anyhow!("Encryption key must be hex or base64 encoded: {}", e))?; - if key_bytes.len() != 32 { - anyhow::bail!("Encryption key must be 32 bytes (256 bits), got {} bytes", key_bytes.len()); - } - - // For now, use empty model mapping and pricing (will be populated later) - let model_mapping = ModelMappingConfig { patterns: vec![] }; - let pricing = PricingConfig { - openai: vec![], - gemini: vec![], - deepseek: vec![], - grok: vec![], - ollama: vec![], - }; - - Ok(Arc::new(AppConfig { - 
server, - database, - providers, - model_mapping, - pricing, - config_path: Some(config_path), - encryption_key, - })) - } - - pub fn get_api_key(&self, provider: &str) -> Result { - let env_var = match provider { - "openai" => &self.providers.openai.api_key_env, - "gemini" => &self.providers.gemini.api_key_env, - "deepseek" => &self.providers.deepseek.api_key_env, - "grok" => &self.providers.grok.api_key_env, - _ => return Err(anyhow::anyhow!("Unknown provider: {}", provider)), - }; - - std::env::var(env_var).map_err(|_| anyhow::anyhow!("Environment variable {} not set for {}", env_var, provider)) - } -} - -/// Helper function to deserialize a Vec from either a sequence or a comma-separated string -fn deserialize_vec_or_string<'de, D>(deserializer: D) -> Result, D::Error> -where - D: serde::Deserializer<'de>, -{ - struct VecOrString; - - impl<'de> serde::de::Visitor<'de> for VecOrString { - type Value = Vec; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a sequence or a comma-separated string") - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - Ok(value - .split(',') - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()) - .collect()) - } - - fn visit_seq(self, mut seq: S) -> Result - where - S: serde::de::SeqAccess<'de>, - { - let mut vec = Vec::new(); - while let Some(element) = seq.next_element()? 
{ - vec.push(element); - } - Ok(vec) - } - } - - deserializer.deserialize_any(VecOrString) -} diff --git a/src/dashboard/auth.rs b/src/dashboard/auth.rs deleted file mode 100644 index 10a0df84..00000000 --- a/src/dashboard/auth.rs +++ /dev/null @@ -1,229 +0,0 @@ -use axum::{extract::State, http::{HeaderMap, HeaderValue}, response::{Json, IntoResponse}}; -use bcrypt; -use serde::Deserialize; -use sqlx::Row; -use tracing::warn; - -use super::{ApiResponse, DashboardState}; - -// Authentication handlers -#[derive(Deserialize)] -pub(super) struct LoginRequest { - pub(super) username: String, - pub(super) password: String, -} - -pub(super) async fn handle_login( - State(state): State, - Json(payload): Json, -) -> Json> { - let pool = &state.app_state.db_pool; - - let user_result = sqlx::query( - "SELECT username, password_hash, display_name, role, must_change_password FROM users WHERE username = ?", - ) - .bind(&payload.username) - .fetch_optional(pool) - .await; - - match user_result { - Ok(Some(row)) => { - let hash = row.get::("password_hash"); - if bcrypt::verify(&payload.password, &hash).unwrap_or(false) { - let username = row.get::("username"); - let role = row.get::("role"); - let display_name = row - .get::, _>("display_name") - .unwrap_or_else(|| username.clone()); - let must_change_password = row.get::("must_change_password"); - let token = state - .session_manager - .create_session(username.clone(), role.clone()) - .await; - Json(ApiResponse::success(serde_json::json!({ - "token": token, - "must_change_password": must_change_password, - "user": { - "username": username, - "name": display_name, - "role": role - } - }))) - } else { - Json(ApiResponse::error("Invalid username or password".to_string())) - } - } - Ok(None) => Json(ApiResponse::error("Invalid username or password".to_string())), - Err(e) => { - warn!("Database error during login: {}", e); - Json(ApiResponse::error("Login failed due to system error".to_string())) - } - } -} - -pub(super) async fn 
handle_auth_status( - State(state): State, - headers: axum::http::HeaderMap, -) -> impl IntoResponse { - let token = headers - .get("Authorization") - .and_then(|v| v.to_str().ok()) - .and_then(|v| v.strip_prefix("Bearer ")); - - if let Some(token) = token - && let Some((session, new_token)) = state.session_manager.validate_session_with_refresh(token).await - { - // Look up display_name from DB - let display_name = sqlx::query_scalar::<_, Option>( - "SELECT display_name FROM users WHERE username = ?", - ) - .bind(&session.username) - .fetch_optional(&state.app_state.db_pool) - .await - .ok() - .flatten() - .flatten() - .unwrap_or_else(|| session.username.clone()); - - let mut headers = HeaderMap::new(); - if let Some(refreshed_token) = new_token { - if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) { - headers.insert("X-Refreshed-Token", header_value); - } - } - return (headers, Json(ApiResponse::success(serde_json::json!({ - "authenticated": true, - "user": { - "username": session.username, - "name": display_name, - "role": session.role - } - })))); - } - - (HeaderMap::new(), Json(ApiResponse::error("Not authenticated".to_string()))) -} - -#[derive(Deserialize)] -pub(super) struct ChangePasswordRequest { - pub(super) current_password: String, - pub(super) new_password: String, -} - -pub(super) async fn handle_change_password( - State(state): State, - headers: axum::http::HeaderMap, - Json(payload): Json, -) -> impl IntoResponse { - let pool = &state.app_state.db_pool; - - // Extract the authenticated user from the session token - let token = headers - .get("Authorization") - .and_then(|v| v.to_str().ok()) - .and_then(|v| v.strip_prefix("Bearer ")); - - let (session, new_token) = match token { - Some(t) => match state.session_manager.validate_session_with_refresh(t).await { - Some((session, new_token)) => (Some(session), new_token), - None => (None, None), - }, - None => (None, None), - }; - - let mut response_headers = HeaderMap::new(); - if let 
Some(refreshed_token) = new_token { - if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) { - response_headers.insert("X-Refreshed-Token", header_value); - } - } - - let username = match session { - Some(s) => s.username, - None => return (response_headers, Json(ApiResponse::error("Not authenticated".to_string()))), - }; - - let user_result = sqlx::query("SELECT password_hash FROM users WHERE username = ?") - .bind(&username) - .fetch_one(pool) - .await; - - match user_result { - Ok(row) => { - let hash = row.get::("password_hash"); - if bcrypt::verify(&payload.current_password, &hash).unwrap_or(false) { - let new_hash = match bcrypt::hash(&payload.new_password, 12) { - Ok(h) => h, - Err(_) => return (response_headers, Json(ApiResponse::error("Failed to hash new password".to_string()))), - }; - - let update_result = sqlx::query( - "UPDATE users SET password_hash = ?, must_change_password = FALSE WHERE username = ?", - ) - .bind(new_hash) - .bind(&username) - .execute(pool) - .await; - - match update_result { - Ok(_) => (response_headers, Json(ApiResponse::success( - serde_json::json!({ "message": "Password updated successfully" }), - ))), - Err(e) => (response_headers, Json(ApiResponse::error(format!("Failed to update database: {}", e)))), - } - } else { - (response_headers, Json(ApiResponse::error("Current password incorrect".to_string()))) - } - } - Err(e) => (response_headers, Json(ApiResponse::error(format!("User not found: {}", e)))), - } -} - -pub(super) async fn handle_logout( - State(state): State, - headers: axum::http::HeaderMap, -) -> Json> { - let token = headers - .get("Authorization") - .and_then(|v| v.to_str().ok()) - .and_then(|v| v.strip_prefix("Bearer ")); - - if let Some(token) = token { - state.session_manager.revoke_session(token).await; - } - - Json(ApiResponse::success(serde_json::json!({ "message": "Logged out" }))) -} - -/// Helper: Extract and validate a session from the Authorization header. 
-/// Returns the Session and optional new token if refreshed, or an error response. -pub(super) async fn extract_session( - state: &DashboardState, - headers: &axum::http::HeaderMap, -) -> Result<(super::sessions::Session, Option), Json>> { - let token = headers - .get("Authorization") - .and_then(|v| v.to_str().ok()) - .and_then(|v| v.strip_prefix("Bearer ")); - - match token { - Some(t) => match state.session_manager.validate_session_with_refresh(t).await { - Some((session, new_token)) => Ok((session, new_token)), - None => Err(Json(ApiResponse::error("Session expired or invalid".to_string()))), - }, - None => Err(Json(ApiResponse::error("Not authenticated".to_string()))), - } -} - -/// Helper: Extract session and require admin role. -/// Returns session and optional new token if refreshed. -pub(super) async fn require_admin( - state: &DashboardState, - headers: &axum::http::HeaderMap, -) -> Result<(super::sessions::Session, Option), Json>> { - let (session, new_token) = extract_session(state, headers).await?; - if session.role != "admin" { - return Err(Json(ApiResponse::error("Admin access required".to_string()))); - } - Ok((session, new_token)) -} diff --git a/src/dashboard/clients.rs b/src/dashboard/clients.rs deleted file mode 100644 index ca5c1709..00000000 --- a/src/dashboard/clients.rs +++ /dev/null @@ -1,538 +0,0 @@ -use axum::{ - extract::{Path, State}, - response::Json, -}; -use chrono; -use rand::Rng; -use serde::Deserialize; -use serde_json; -use sqlx::Row; -use tracing::warn; -use uuid; - -use super::{ApiResponse, DashboardState}; - -/// Generate a random API token: sk-{48 hex chars} -fn generate_token() -> String { - let mut rng = rand::rng(); - let bytes: Vec = (0..24).map(|_| rng.random::()).collect(); - format!("sk-{}", hex::encode(bytes)) -} - -#[derive(Deserialize)] -pub(super) struct CreateClientRequest { - pub(super) name: String, - pub(super) client_id: Option, -} - -#[derive(Deserialize)] -pub(super) struct UpdateClientPayload { - 
pub(super) name: Option<String>,
-    pub(super) description: Option<String>,
-    pub(super) is_active: Option<bool>,
-    pub(super) rate_limit_per_minute: Option<i64>,
-}
-
-pub(super) async fn handle_get_clients(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let pool = &state.app_state.db_pool;
-
-    let result = sqlx::query(
-        r#"
-        SELECT
-            client_id as id,
-            name,
-            description,
-            created_at,
-            total_requests,
-            total_tokens,
-            total_cost,
-            is_active,
-            rate_limit_per_minute
-        FROM clients
-        ORDER BY created_at DESC
-        "#,
-    )
-    .fetch_all(pool)
-    .await;
-
-    match result {
-        Ok(rows) => {
-            let clients: Vec<serde_json::Value> = rows
-                .into_iter()
-                .map(|row| {
-                    serde_json::json!({
-                        "id": row.get::<String, _>("id"),
-                        "name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
-                        "description": row.get::<Option<String>, _>("description"),
-                        "created_at": row.get::<Option<String>, _>("created_at"),
-                        "requests_count": row.get::<i64, _>("total_requests"),
-                        "total_tokens": row.get::<i64, _>("total_tokens"),
-                        "total_cost": row.get::<f64, _>("total_cost"),
-                        "status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
-                        "rate_limit_per_minute": row.get::<Option<i64>, _>("rate_limit_per_minute"),
-                    })
-                })
-                .collect();
-
-            Json(ApiResponse::success(serde_json::json!(clients)))
-        }
-        Err(e) => {
-            warn!("Failed to fetch clients: {}", e);
-            Json(ApiResponse::error("Failed to fetch clients".to_string()))
-        }
-    }
-}
-
-pub(super) async fn handle_create_client(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Json(payload): Json<CreateClientRequest>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let pool = &state.app_state.db_pool;
-
-    let client_id = payload
-        .client_id
-        .unwrap_or_else(|| format!("client-{}", &uuid::Uuid::new_v4().to_string()[..8]));
-
-    let result = sqlx::query(
-        r#"
-        INSERT INTO clients (client_id, name, is_active)
-        VALUES (?, ?, TRUE)
-        RETURNING *
-        "#,
-    )
-    .bind(&client_id)
-    .bind(&payload.name)
-    .fetch_one(pool)
-    .await;
-
-    match result {
-        Ok(row) => {
-            // Auto-generate a token for the new client
-            let token = generate_token();
-            let token_result = sqlx::query(
-                "INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, 'default')",
-            )
-            .bind(&client_id)
-            .bind(&token)
-            .execute(pool)
-            .await;
-
-            if let Err(e) = token_result {
-                warn!("Client created but failed to generate token: {}", e);
-            }
-
-            Json(ApiResponse::success(serde_json::json!({
-                "id": row.get::<String, _>("client_id"),
-                "name": row.get::<Option<String>, _>("name"),
-                "created_at": row.get::<Option<String>, _>("created_at"),
-                "status": "active",
-                "token": token,
-            })))
-        }
-        Err(e) => {
-            warn!("Failed to create client: {}", e);
-            Json(ApiResponse::error(format!("Failed to create client: {}", e)))
-        }
-    }
-}
-
-pub(super) async fn handle_get_client(
-    State(state): State<DashboardState>,
-    Path(id): Path<String>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let pool = &state.app_state.db_pool;
-
-    let result = sqlx::query(
-        r#"
-        SELECT
-            c.client_id as id,
-            c.name,
-            c.description,
-            c.is_active,
-            c.rate_limit_per_minute,
-            c.created_at,
-            COALESCE(c.total_tokens, 0) as total_tokens,
-            COALESCE(c.total_cost, 0.0) as total_cost,
-            COUNT(r.id) as total_requests,
-            MAX(r.timestamp) as last_request
-        FROM clients c
-        LEFT JOIN llm_requests r ON c.client_id = r.client_id
-        WHERE c.client_id = ?
-        GROUP BY c.client_id
-        "#,
-    )
-    .bind(&id)
-    .fetch_optional(pool)
-    .await;
-
-    match result {
-        Ok(Some(row)) => Json(ApiResponse::success(serde_json::json!({
-            "id": row.get::<String, _>("id"),
-            "name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
-            "description": row.get::<Option<String>, _>("description"),
-            "is_active": row.get::<bool, _>("is_active"),
-            "rate_limit_per_minute": row.get::<Option<i64>, _>("rate_limit_per_minute"),
-            "created_at": row.get::<Option<String>, _>("created_at"),
-            "total_tokens": row.get::<i64, _>("total_tokens"),
-            "total_cost": row.get::<f64, _>("total_cost"),
-            "total_requests": row.get::<i64, _>("total_requests"),
-            "last_request": row.get::<Option<String>, _>("last_request"),
-            "status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
-        }))),
-        Ok(None) => Json(ApiResponse::error(format!("Client '{}' not found", id))),
-        Err(e) => {
-            warn!("Failed to fetch client: {}", e);
-            Json(ApiResponse::error(format!("Failed to fetch client: {}", e)))
-        }
-    }
-}
-
-pub(super) async fn handle_update_client(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Path(id): Path<String>,
-    Json(payload): Json<UpdateClientRequest>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let pool = &state.app_state.db_pool;
-
-    // Build dynamic UPDATE query from provided fields
-    let mut sets = Vec::new();
-    let mut binds: Vec<String> = Vec::new();
-
-    if let Some(ref name) = payload.name {
-        sets.push("name = ?");
-        binds.push(name.clone());
-    }
-    if let Some(ref desc) = payload.description {
-        sets.push("description = ?");
-        binds.push(desc.clone());
-    }
-    if payload.is_active.is_some() {
-        sets.push("is_active = ?");
-    }
-    if payload.rate_limit_per_minute.is_some() {
-        sets.push("rate_limit_per_minute = ?");
-    }
-
-    if sets.is_empty() {
-        return Json(ApiResponse::error("No fields to update".to_string()));
-    }
-
-    // Always update updated_at
-    sets.push("updated_at = CURRENT_TIMESTAMP");
-
-    let sql = format!("UPDATE clients SET {} WHERE client_id = ?", sets.join(", "));
-    let mut query = sqlx::query(&sql);
-
-    // Bind in the same order as sets
-    for b in &binds {
-        query = query.bind(b);
-    }
-    if let Some(active) = payload.is_active {
-        query = query.bind(active);
-    }
-    if let Some(rate) = payload.rate_limit_per_minute {
-        query = query.bind(rate);
-    }
-    query = query.bind(&id);
-
-    match query.execute(pool).await {
-        Ok(result) => {
-            if result.rows_affected() == 0 {
-                return Json(ApiResponse::error(format!("Client '{}' not found", id)));
-            }
-            // Return the updated client
-            let row = sqlx::query(
-                r#"
-                SELECT client_id as id, name, description, is_active, rate_limit_per_minute,
-                       created_at, total_requests, total_tokens, total_cost
-                FROM clients WHERE client_id = ?
-                "#,
-            )
-            .bind(&id)
-            .fetch_one(pool)
-            .await;
-
-            match row {
-                Ok(row) => Json(ApiResponse::success(serde_json::json!({
-                    "id": row.get::<String, _>("id"),
-                    "name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
-                    "description": row.get::<Option<String>, _>("description"),
-                    "is_active": row.get::<bool, _>("is_active"),
-                    "rate_limit_per_minute": row.get::<Option<i64>, _>("rate_limit_per_minute"),
-                    "created_at": row.get::<Option<String>, _>("created_at"),
-                    "total_requests": row.get::<i64, _>("total_requests"),
-                    "total_tokens": row.get::<i64, _>("total_tokens"),
-                    "total_cost": row.get::<f64, _>("total_cost"),
-                    "status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
-                }))),
-                Err(e) => {
-                    warn!("Failed to fetch updated client: {}", e);
-                    // Update succeeded but fetch failed — still report success
-                    Json(ApiResponse::success(serde_json::json!({ "message": "Client updated" })))
-                }
-            }
-        }
-        Err(e) => {
-            warn!("Failed to update client: {}", e);
-            Json(ApiResponse::error(format!("Failed to update client: {}", e)))
-        }
-    }
-}
-
-pub(super) async fn handle_delete_client(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Path(id): Path<String>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let pool = &state.app_state.db_pool;
-
-    // Don't allow deleting the default client
-    if id == "default" {
-        return Json(ApiResponse::error("Cannot delete default client".to_string()));
-    }
-
-    let result = sqlx::query("DELETE FROM clients WHERE client_id = ?")
-        .bind(id)
-        .execute(pool)
-        .await;
-
-    match result {
-        Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Client deleted" }))),
-        Err(e) => Json(ApiResponse::error(format!("Failed to delete client: {}", e))),
-    }
-}
-
-pub(super) async fn handle_client_usage(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Path(id): Path<String>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let pool = &state.app_state.db_pool;
-
-    // Get per-model breakdown for this client
-    let result = sqlx::query(
-        r#"
-        SELECT
-            model,
-            provider,
-            COUNT(*) as request_count,
-            SUM(prompt_tokens) as prompt_tokens,
-            SUM(completion_tokens) as completion_tokens,
-            SUM(total_tokens) as total_tokens,
-            SUM(cost) as total_cost,
-            AVG(duration_ms) as avg_duration_ms
-        FROM llm_requests
-        WHERE client_id = ?
-        GROUP BY model, provider
-        ORDER BY total_cost DESC
-        "#,
-    )
-    .bind(&id)
-    .fetch_all(pool)
-    .await;
-
-    match result {
-        Ok(rows) => {
-            let breakdown: Vec<serde_json::Value> = rows
-                .into_iter()
-                .map(|row| {
-                    serde_json::json!({
-                        "model": row.get::<String, _>("model"),
-                        "provider": row.get::<String, _>("provider"),
-                        "request_count": row.get::<i64, _>("request_count"),
-                        "prompt_tokens": row.get::<i64, _>("prompt_tokens"),
-                        "completion_tokens": row.get::<i64, _>("completion_tokens"),
-                        "total_tokens": row.get::<i64, _>("total_tokens"),
-                        "total_cost": row.get::<f64, _>("total_cost"),
-                        "avg_duration_ms": row.get::<f64, _>("avg_duration_ms"),
-                    })
-                })
-                .collect();
-
-            Json(ApiResponse::success(serde_json::json!({
-                "client_id": id,
-                "breakdown": breakdown,
-            })))
-        }
-        Err(e) => {
-            warn!("Failed to fetch client usage: {}", e);
-            Json(ApiResponse::error(format!("Failed to fetch client usage: {}", e)))
-        }
-    }
-}
-
-// ── Token management endpoints ──────────────────────────────────────
-
-pub(super) async fn handle_get_client_tokens(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Path(id): Path<String>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let pool = &state.app_state.db_pool;
-
-    let result = sqlx::query(
-        r#"
-        SELECT id, token, name, is_active, created_at, last_used_at
-        FROM client_tokens
-        WHERE client_id = ?
-        ORDER BY created_at DESC
-        "#,
-    )
-    .bind(&id)
-    .fetch_all(pool)
-    .await;
-
-    match result {
-        Ok(rows) => {
-            let tokens: Vec<serde_json::Value> = rows
-                .into_iter()
-                .map(|row| {
-                    let token: String = row.get("token");
-                    // Mask all but last 8 chars: sk-••••abcd1234
-                    let masked = if token.len() > 8 {
-                        format!("{}••••{}", &token[..3], &token[token.len() - 8..])
-                    } else {
-                        "••••".to_string()
-                    };
-                    serde_json::json!({
-                        "id": row.get::<i64, _>("id"),
-                        "token_masked": masked,
-                        "name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "default".to_string()),
-                        "is_active": row.get::<bool, _>("is_active"),
-                        "created_at": row.get::<Option<String>, _>("created_at"),
-                        "last_used_at": row.get::<Option<String>, _>("last_used_at"),
-                    })
-                })
-                .collect();
-
-            Json(ApiResponse::success(serde_json::json!(tokens)))
-        }
-        Err(e) => {
-            warn!("Failed to fetch client tokens: {}", e);
-            Json(ApiResponse::error(format!("Failed to fetch client tokens: {}", e)))
-        }
-    }
-}
-
-#[derive(Deserialize)]
-pub(super) struct CreateTokenRequest {
-    pub(super) name: Option<String>,
-}
-
-pub(super) async fn handle_create_client_token(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Path(id): Path<String>,
-    Json(payload): Json<CreateTokenRequest>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let pool = &state.app_state.db_pool;
-
-    // Verify client exists
-    let exists: Option<(i64,)> = sqlx::query_as("SELECT 1 as x FROM clients WHERE client_id = ?")
-        .bind(&id)
-        .fetch_optional(pool)
-        .await
-        .unwrap_or(None);
-
-    if exists.is_none() {
-        return Json(ApiResponse::error(format!("Client '{}' not found", id)));
-    }
-
-    let token = generate_token();
-    let token_name = payload.name.unwrap_or_else(|| "default".to_string());
-
-    let result = sqlx::query(
-        "INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, ?)
-         RETURNING id, created_at",
-    )
-    .bind(&id)
-    .bind(&token)
-    .bind(&token_name)
-    .fetch_one(pool)
-    .await;
-
-    match result {
-        Ok(row) => Json(ApiResponse::success(serde_json::json!({
-            "id": row.get::<i64, _>("id"),
-            "token": token,
-            "name": token_name,
-            "created_at": row.get::<Option<String>, _>("created_at"),
-        }))),
-        Err(e) => {
-            warn!("Failed to create client token: {}", e);
-            Json(ApiResponse::error(format!("Failed to create token: {}", e)))
-        }
-    }
-}
-
-pub(super) async fn handle_delete_client_token(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Path((client_id, token_id)): Path<(String, i64)>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let pool = &state.app_state.db_pool;
-
-    let result = sqlx::query("DELETE FROM client_tokens WHERE id = ? AND client_id = ?")
-        .bind(token_id)
-        .bind(&client_id)
-        .execute(pool)
-        .await;
-
-    match result {
-        Ok(r) => {
-            if r.rows_affected() == 0 {
-                Json(ApiResponse::error("Token not found".to_string()))
-            } else {
-                Json(ApiResponse::success(serde_json::json!({ "message": "Token revoked" })))
-            }
-        }
-        Err(e) => {
-            warn!("Failed to delete client token: {}", e);
-            Json(ApiResponse::error(format!("Failed to revoke token: {}", e)))
-        }
-    }
-}
diff --git a/src/dashboard/mod.rs b/src/dashboard/mod.rs
deleted file mode 100644
index 460e5f71..00000000
--- a/src/dashboard/mod.rs
+++ /dev/null
@@ -1,174 +0,0 @@
-// Dashboard module for LLM Proxy Gateway
-
-mod auth;
-mod clients;
-mod models;
-mod providers;
-pub mod sessions;
-mod system;
-mod usage;
-mod users;
-mod websocket;
-
-use axum::{
-    extract::{Request, State},
-    middleware::Next,
-    response::Response,
-    Router,
-    routing::{delete, get, post, put},
-};
-use axum::http::{header, HeaderValue};
-use serde::Serialize;
-use tower_http::{
-    limit::RequestBodyLimitLayer,
-    set_header::SetResponseHeaderLayer,
-};
-
-use crate::state::AppState;
-use sessions::SessionManager;
-
-// Dashboard state
-#[derive(Clone)]
-struct DashboardState {
-    app_state: AppState,
-    session_manager: SessionManager,
-}
-
-// API Response types
-#[derive(Serialize)]
-struct ApiResponse<T> {
-    success: bool,
-    data: Option<T>,
-    error: Option<String>,
-}
-
-impl<T> ApiResponse<T> {
-    fn success(data: T) -> Self {
-        Self {
-            success: true,
-            data: Some(data),
-            error: None,
-        }
-    }
-
-    fn error(error: String) -> Self {
-        Self {
-            success: false,
-            data: None,
-            error: Some(error),
-        }
-    }
-}
-
-/// Rate limiting middleware for dashboard routes
-async fn dashboard_rate_limit_middleware(
-    State(_dashboard_state): State<DashboardState>,
-    request: Request,
-    next: Next,
-) -> Result<Response, axum::http::StatusCode> {
-    // Bypass rate limiting for dashboard routes to prevent "Failed to load statistics"
-    // when the UI makes many concurrent requests on load.
-    // Dashboard endpoints are already secured via auth::require_admin.
-    Ok(next.run(request).await)
-}
-
-// Dashboard routes
-pub fn router(state: AppState) -> Router {
-    let session_manager = SessionManager::new(24); // 24-hour session TTL
-    let dashboard_state = DashboardState {
-        app_state: state,
-        session_manager,
-    };
-
-    // Security headers
-    let csp_header: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
-        header::CONTENT_SECURITY_POLICY,
-        "default-src 'self'; script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; style-src 'self' 'unsafe-inline' https://cdnjs.cloudflare.com https://fonts.googleapis.com; font-src 'self' https://cdnjs.cloudflare.com https://fonts.gstatic.com; img-src 'self' data:; connect-src 'self' ws:;"
-            .parse()
-            .unwrap(),
-    );
-    let x_frame_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
-        header::X_FRAME_OPTIONS,
-        "DENY".parse().unwrap(),
-    );
-    let x_content_type_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
-        header::X_CONTENT_TYPE_OPTIONS,
-        "nosniff".parse().unwrap(),
-    );
-    let strict_transport_security: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
-        header::STRICT_TRANSPORT_SECURITY,
-        "max-age=31536000; includeSubDomains".parse().unwrap(),
-    );
-
-    Router::new()
-        // Static file serving
-        .fallback_service(tower_http::services::ServeDir::new("static"))
-        // WebSocket endpoint
-        .route("/ws", get(websocket::handle_websocket))
-        // API endpoints
-        .route("/api/auth/login", post(auth::handle_login))
-        .route("/api/auth/status", get(auth::handle_auth_status))
-        .route("/api/auth/logout", post(auth::handle_logout))
-        .route("/api/auth/change-password", post(auth::handle_change_password))
-        .route(
-            "/api/users",
-            get(users::handle_get_users).post(users::handle_create_user),
-        )
-        .route(
-            "/api/users/{id}",
-            put(users::handle_update_user).delete(users::handle_delete_user),
-        )
-        .route("/api/usage/summary", get(usage::handle_usage_summary))
-        .route("/api/usage/time-series", get(usage::handle_time_series))
-        .route("/api/usage/clients", get(usage::handle_clients_usage))
-        .route("/api/usage/providers", get(usage::handle_providers_usage))
-        .route("/api/usage/detailed", get(usage::handle_detailed_usage))
-        .route("/api/analytics/breakdown", get(usage::handle_analytics_breakdown))
-        .route("/api/models", get(models::handle_get_models))
-        .route("/api/models/{id}", put(models::handle_update_model))
-        .route(
-            "/api/clients",
-            get(clients::handle_get_clients).post(clients::handle_create_client),
-        )
-        .route(
-            "/api/clients/{id}",
-            get(clients::handle_get_client)
-                .put(clients::handle_update_client)
-                .delete(clients::handle_delete_client),
-        )
-        .route("/api/clients/{id}/usage", get(clients::handle_client_usage))
-        .route(
-            "/api/clients/{id}/tokens",
-            get(clients::handle_get_client_tokens).post(clients::handle_create_client_token),
-        )
-        .route(
-            "/api/clients/{id}/tokens/{token_id}",
-            delete(clients::handle_delete_client_token),
-        )
-        .route("/api/providers", get(providers::handle_get_providers))
-        .route(
-            "/api/providers/{name}",
-            get(providers::handle_get_provider).put(providers::handle_update_provider),
-        )
-        .route("/api/providers/{name}/test", post(providers::handle_test_provider))
-        .route("/api/system/health", get(system::handle_system_health))
-        .route("/api/system/metrics", get(system::handle_system_metrics))
-        .route("/api/system/logs", get(system::handle_system_logs))
-        .route("/api/system/backup", post(system::handle_system_backup))
-        .route(
-            "/api/system/settings",
-            get(system::handle_get_settings).post(system::handle_update_settings),
-        )
-        // Security layers
-        .layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit
-        .layer(csp_header)
-        .layer(x_frame_options)
-        .layer(x_content_type_options)
-        .layer(strict_transport_security)
-        // Rate limiting middleware
-        .layer(axum::middleware::from_fn_with_state(
-            dashboard_state.clone(),
-            dashboard_rate_limit_middleware,
-        ))
-        .with_state(dashboard_state)
-}
diff --git a/src/dashboard/models.rs b/src/dashboard/models.rs
deleted file mode 100644
index ddcb5eaf..00000000
--- a/src/dashboard/models.rs
+++ /dev/null
@@ -1,254 +0,0 @@
-use axum::{
-    extract::{Path, Query, State},
-    response::Json,
-};
-use serde::Deserialize;
-use serde_json;
-use sqlx::Row;
-use std::collections::HashMap;
-
-use super::{ApiResponse, DashboardState};
-use crate::models::registry::{ModelFilter, ModelSortBy, SortOrder};
-
-#[derive(Deserialize)]
-pub(super) struct UpdateModelRequest {
-    pub(super) enabled: bool,
-    pub(super) prompt_cost: Option<f64>,
-    pub(super) completion_cost: Option<f64>,
-    pub(super) mapping: Option<String>,
-}
-
-/// Query parameters for `GET /api/models`.
-#[derive(Debug, Deserialize, Default)]
-pub(super) struct ModelListParams {
-    /// Filter by provider ID.
-    pub provider: Option<String>,
-    /// Text search on model ID or name.
-    pub search: Option<String>,
-    /// Filter by input modality (e.g. "image").
-    pub modality: Option<String>,
-    /// Only models that support tool calling.
-    pub tool_call: Option<bool>,
-    /// Only models that support reasoning.
-    pub reasoning: Option<bool>,
-    /// Only models that have pricing data.
-    pub has_cost: Option<bool>,
-    /// Only models that have been used in requests.
-    pub used_only: Option<bool>,
-    /// Sort field (name, id, provider, context_limit, input_cost, output_cost).
-    pub sort_by: Option<ModelSortBy>,
-    /// Sort direction (asc, desc).
-    pub sort_order: Option<SortOrder>,
-}
-
-pub(super) async fn handle_get_models(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Query(params): Query<ModelListParams>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let registry = &state.app_state.model_registry;
-    let pool = &state.app_state.db_pool;
-
-    // Load overrides from database
-    let db_models_result =
-        sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
-            .fetch_all(pool)
-            .await;
-
-    let mut db_models = HashMap::new();
-    if let Ok(rows) = db_models_result {
-        for row in rows {
-            let id: String = row.get("id");
-            db_models.insert(id, row);
-        }
-    }
-
-    let mut models_json = Vec::new();
-
-    if params.used_only.unwrap_or(false) {
-        // EXACT USED MODELS LOGIC
-        let used_pairs_result = sqlx::query(
-            "SELECT DISTINCT provider, model FROM llm_requests",
-        )
-        .fetch_all(pool)
-        .await;
-
-        if let Ok(rows) = used_pairs_result {
-            for row in rows {
-                let provider: String = row.get("provider");
-                let m_key: String = row.get("model");
-
-                let provider_name = match provider.as_str() {
-                    "openai" => "OpenAI",
-                    "gemini" => "Google Gemini",
-                    "deepseek" => "DeepSeek",
-                    "grok" => "xAI Grok",
-                    "ollama" => "Ollama",
-                    _ => provider.as_str(),
-                }.to_string();
-
-                let m_meta = registry.find_model(&m_key);
-
-                let mut enabled = true;
-                let mut prompt_cost = m_meta.and_then(|m| m.cost.as_ref().map(|c| c.input)).unwrap_or(0.0);
-                let mut completion_cost = m_meta.and_then(|m| m.cost.as_ref().map(|c| c.output)).unwrap_or(0.0);
-                let cache_read_cost = m_meta.and_then(|m| m.cost.as_ref().and_then(|c| c.cache_read));
-                let cache_write_cost = m_meta.and_then(|m| m.cost.as_ref().and_then(|c| c.cache_write));
-                let mut mapping = None::<String>;
-
-                if let Some(db_row) = db_models.get(&m_key) {
-                    enabled = db_row.get("enabled");
-                    if let Some(p) = db_row.get::<Option<f64>, _>("prompt_cost_per_m") {
-                        prompt_cost = p;
-                    }
-                    if let Some(c) = db_row.get::<Option<f64>, _>("completion_cost_per_m") {
-                        completion_cost = c;
-                    }
-                    mapping = db_row.get("mapping");
-                }
-
-                models_json.push(serde_json::json!({
-                    "id": m_key,
-                    "provider": provider,
-                    "provider_name": provider_name,
-                    "name": m_meta.map(|m| m.name.clone()).unwrap_or_else(|| m_key.clone()),
-                    "enabled": enabled,
-                    "prompt_cost": prompt_cost,
-                    "completion_cost": completion_cost,
-                    "cache_read_cost": cache_read_cost,
-                    "cache_write_cost": cache_write_cost,
-                    "mapping": mapping,
-                    "context_limit": m_meta.and_then(|m| m.limit.as_ref().map(|l| l.context)).unwrap_or(0),
-                    "output_limit": m_meta.and_then(|m| m.limit.as_ref().map(|l| l.output)).unwrap_or(0),
-                    "modalities": m_meta.and_then(|m| m.modalities.as_ref().map(|mo| serde_json::json!({
-                        "input": mo.input,
-                        "output": mo.output,
-                    }))),
-                    "tool_call": m_meta.and_then(|m| m.tool_call),
-                    "reasoning": m_meta.and_then(|m| m.reasoning),
-                }));
-            }
-        }
-    } else {
-        // REGISTRY LISTING LOGIC
-        // Build filter from query params
-        let filter = ModelFilter {
-            provider: params.provider,
-            search: params.search,
-            modality: params.modality,
-            tool_call: params.tool_call,
-            reasoning: params.reasoning,
-            has_cost: params.has_cost,
-        };
-        let sort_by = params.sort_by.unwrap_or_default();
-        let sort_order = params.sort_order.unwrap_or_default();
-
-        // Get filtered and sorted model entries
-        let entries = registry.list_models(&filter, &sort_by, &sort_order);
-
-        for entry in &entries {
-            let m_key = entry.model_key;
-            let m_meta = entry.metadata;
-
-            let mut enabled = true;
-            let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
-            let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
-            let cache_read_cost = m_meta.cost.as_ref().and_then(|c| c.cache_read);
-            let cache_write_cost = m_meta.cost.as_ref().and_then(|c| c.cache_write);
-            let mut mapping = None::<String>;
-
-            if let Some(row) = db_models.get(m_key) {
-                enabled = row.get("enabled");
-                if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
-                    prompt_cost = p;
-                }
-                if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
-                    completion_cost = c;
-                }
-                mapping = row.get("mapping");
-            }
-
-            models_json.push(serde_json::json!({
-                "id": m_key,
-                "provider": entry.provider_id,
-                "provider_name": entry.provider_name,
-                "name": m_meta.name,
-                "enabled": enabled,
-                "prompt_cost": prompt_cost,
-                "completion_cost": completion_cost,
-                "cache_read_cost": cache_read_cost,
-                "cache_write_cost": cache_write_cost,
-                "mapping": mapping,
-                "context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
-                "output_limit": m_meta.limit.as_ref().map(|l| l.output).unwrap_or(0),
-                "modalities": m_meta.modalities.as_ref().map(|m| serde_json::json!({
-                    "input": m.input,
-                    "output": m.output,
-                })),
-                "tool_call": m_meta.tool_call,
-                "reasoning": m_meta.reasoning,
-            }));
-        }
-    }
-
-    Json(ApiResponse::success(serde_json::json!(models_json)))
-}
-
-pub(super) async fn handle_update_model(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Path(id): Path<String>,
-    Json(payload): Json<UpdateModelRequest>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let pool = &state.app_state.db_pool;
-
-    // Find provider_id for this model in registry
-    let provider_id = state
-        .app_state
-        .model_registry
-        .providers
-        .iter()
-        .find(|(_, p)| p.models.contains_key(&id))
-        .map(|(id, _)| id.clone())
-        .unwrap_or_else(|| "unknown".to_string());
-
-    let result = sqlx::query(
-        r#"
-        INSERT INTO model_configs (id, provider_id, enabled,
-            prompt_cost_per_m, completion_cost_per_m, mapping)
-        VALUES (?, ?, ?, ?, ?, ?)
-        ON CONFLICT(id) DO UPDATE SET
-            enabled = excluded.enabled,
-            prompt_cost_per_m = excluded.prompt_cost_per_m,
-            completion_cost_per_m = excluded.completion_cost_per_m,
-            mapping = excluded.mapping,
-            updated_at = CURRENT_TIMESTAMP
-        "#,
-    )
-    .bind(&id)
-    .bind(provider_id)
-    .bind(payload.enabled)
-    .bind(payload.prompt_cost)
-    .bind(payload.completion_cost)
-    .bind(payload.mapping)
-    .execute(pool)
-    .await;
-
-    match result {
-        Ok(_) => {
-            // Invalidate the in-memory cache so the proxy picks up the change immediately
-            state.app_state.model_config_cache.invalidate().await;
-            Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" })))
-        }
-        Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))),
-    }
-}
diff --git a/src/dashboard/providers.rs b/src/dashboard/providers.rs
deleted file mode 100644
index 7f9f204a..00000000
--- a/src/dashboard/providers.rs
+++ /dev/null
@@ -1,440 +0,0 @@
-use axum::{
-    extract::{Path, State},
-    response::Json,
-};
-use serde::Deserialize;
-use serde_json;
-use sqlx::Row;
-use std::collections::HashMap;
-use tracing::warn;
-
-use super::{ApiResponse, DashboardState};
-use crate::utils::crypto;
-
-#[derive(Deserialize)]
-pub(super) struct UpdateProviderRequest {
-    pub(super) enabled: bool,
-    pub(super) base_url: Option<String>,
-    pub(super) api_key: Option<String>,
-    pub(super) credit_balance: Option<f64>,
-    pub(super) low_credit_threshold: Option<f64>,
-    pub(super) billing_mode: Option<String>,
-}
-
-pub(super) async fn handle_get_providers(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let registry = &state.app_state.model_registry;
-    let config = &state.app_state.config;
-    let pool = &state.app_state.db_pool;
-
-    // Load all overrides from database (including billing_mode)
-    let db_configs_result = sqlx::query(
-        "SELECT id, enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs",
-    )
-    .fetch_all(pool)
-    .await;
-
-    let mut db_configs = HashMap::new();
-    if let Ok(rows) = db_configs_result {
-        for row in rows {
-            let id: String = row.get("id");
-            let enabled: bool = row.get("enabled");
-            let base_url: Option<String> = row.get("base_url");
-            let balance: f64 = row.get("credit_balance");
-            let threshold: f64 = row.get("low_credit_threshold");
-            let billing_mode: Option<String> = row.get("billing_mode");
-            db_configs.insert(id, (enabled, base_url, balance, threshold, billing_mode));
-        }
-    }
-
-    let mut providers_json = Vec::new();
-
-    // Define the list of providers we support
-    let provider_ids = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
-
-    for id in provider_ids {
-        // Get base config
-        let (mut enabled, mut base_url, display_name) = match id {
-            "openai" => (
-                config.providers.openai.enabled,
-                config.providers.openai.base_url.clone(),
-                "OpenAI",
-            ),
-            "gemini" => (
-                config.providers.gemini.enabled,
-                config.providers.gemini.base_url.clone(),
-                "Google Gemini",
-            ),
-            "deepseek" => (
-                config.providers.deepseek.enabled,
-                config.providers.deepseek.base_url.clone(),
-                "DeepSeek",
-            ),
-            "grok" => (
-                config.providers.grok.enabled,
-                config.providers.grok.base_url.clone(),
-                "xAI Grok",
-            ),
-            "ollama" => (
-                config.providers.ollama.enabled,
-                config.providers.ollama.base_url.clone(),
-                "Ollama",
-            ),
-            _ => (false, "".to_string(), "Unknown"),
-        };
-
-        let mut balance = 0.0;
-        let mut threshold = 5.0;
-        let mut billing_mode: Option<String> = None;
-
-        // Apply database overrides
-        if let Some((db_enabled, db_url, db_balance, db_threshold, db_billing)) = db_configs.get(id) {
-            enabled = *db_enabled;
-            if let Some(url) = db_url {
-                base_url = url.clone();
-            }
-            balance = *db_balance;
-            threshold = *db_threshold;
-            billing_mode = db_billing.clone();
-        }
-
-        // Find models for this provider in registry
-        // NOTE: registry provider IDs differ from internal IDs for some providers.
-        let registry_key = match id {
-            "gemini" => "google",
-            "grok" => "xai",
-            _ => id,
-        };
-
-        let mut models = Vec::new();
-        if let Some(p_info) = registry.providers.get(registry_key) {
-            models = p_info.models.keys().cloned().collect();
-        } else if id == "ollama" {
-            models = config.providers.ollama.models.clone();
-        }
-
-        // Determine status
-        let status = if !enabled {
-            "disabled"
-        } else {
-            // Check if it's actually initialized in the provider manager
-            if state.app_state.provider_manager.get_provider(id).await.is_some() {
-                // Check circuit breaker
-                if state
-                    .app_state
-                    .rate_limit_manager
-                    .check_provider_request(id)
-                    .await
-                    .unwrap_or(true)
-                {
-                    "online"
-                } else {
-                    "degraded"
-                }
-            } else {
-                "error" // Enabled but failed to initialize (e.g. missing API key)
-            }
-        };
-
-        providers_json.push(serde_json::json!({
-            "id": id,
-            "name": display_name,
-            "enabled": enabled,
-            "status": status,
-            "models": models,
-            "base_url": base_url,
-            "credit_balance": balance,
-            "low_credit_threshold": threshold,
-            "billing_mode": billing_mode,
-            "last_used": None::<String>,
-        }));
-    }
-
-    Json(ApiResponse::success(serde_json::json!(providers_json)))
-}
-
-pub(super) async fn handle_get_provider(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Path(name): Path<String>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let registry = &state.app_state.model_registry;
-    let config = &state.app_state.config;
-    let pool = &state.app_state.db_pool;
-
-    // Validate provider name
-    let (mut enabled, mut base_url, display_name) = match name.as_str() {
-        "openai" => (
-            config.providers.openai.enabled,
-            config.providers.openai.base_url.clone(),
-            "OpenAI",
-        ),
-        "gemini" => (
-            config.providers.gemini.enabled,
-            config.providers.gemini.base_url.clone(),
-            "Google Gemini",
-        ),
-        "deepseek" => (
-            config.providers.deepseek.enabled,
-            config.providers.deepseek.base_url.clone(),
-            "DeepSeek",
-        ),
-        "grok" => (
-            config.providers.grok.enabled,
-            config.providers.grok.base_url.clone(),
-            "xAI Grok",
-        ),
-        "ollama" => (
-            config.providers.ollama.enabled,
-            config.providers.ollama.base_url.clone(),
-            "Ollama",
-        ),
-        _ => return Json(ApiResponse::error(format!("Unknown provider '{}'", name))),
-    };
-
-    let mut balance = 0.0;
-    let mut threshold = 5.0;
-    let mut billing_mode: Option<String> = None;
-
-    // Apply database overrides
-    let db_config = sqlx::query(
-        "SELECT enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs WHERE id = ?",
-    )
-    .bind(&name)
-    .fetch_optional(pool)
-    .await;
-
-    if let Ok(Some(row)) = db_config {
-        enabled = row.get::<bool, _>("enabled");
-        if let Some(url) = row.get::<Option<String>, _>("base_url") {
-            base_url = url;
-        }
-        balance = row.get::<f64, _>("credit_balance");
-        threshold = row.get::<f64, _>("low_credit_threshold");
-        billing_mode = row.get::<Option<String>, _>("billing_mode");
-    }
-
-    // Find models for this provider
-    // NOTE: registry provider IDs differ from internal IDs for some providers.
-    let registry_key = match name.as_str() {
-        "gemini" => "google",
-        "grok" => "xai",
-        _ => name.as_str(),
-    };
-
-    let mut models = Vec::new();
-    if let Some(p_info) = registry.providers.get(registry_key) {
-        models = p_info.models.keys().cloned().collect();
-    } else if name == "ollama" {
-        models = config.providers.ollama.models.clone();
-    }
-
-    // Determine status
-    let status = if !enabled {
-        "disabled"
-    } else if state.app_state.provider_manager.get_provider(&name).await.is_some() {
-        if state
-            .app_state
-            .rate_limit_manager
-            .check_provider_request(&name)
-            .await
-            .unwrap_or(true)
-        {
-            "online"
-        } else {
-            "degraded"
-        }
-    } else {
-        "error"
-    };
-
-    Json(ApiResponse::success(serde_json::json!({
-        "id": name,
-        "name": display_name,
-        "enabled": enabled,
-        "status": status,
-        "models": models,
-        "base_url": base_url,
-        "credit_balance": balance,
-        "low_credit_threshold": threshold,
-        "billing_mode": billing_mode,
-        "last_used": None::<String>,
-    })))
-}
-
-pub(super) async fn handle_update_provider(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Path(name): Path<String>,
-    Json(payload): Json<UpdateProviderRequest>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let pool = &state.app_state.db_pool;
-
-    // Prepare API key encryption if provided
-    let (api_key_to_store, api_key_encrypted_flag) = match &payload.api_key {
-        Some(key) if !key.is_empty() => {
-            match crypto::encrypt(key) {
-                Ok(encrypted) => (Some(encrypted), Some(true)),
-                Err(e) => {
-                    warn!("Failed to encrypt API key for provider {}: {}", name, e);
-                    return Json(ApiResponse::error(format!("Failed to encrypt API key: {}", e)));
-                }
-            }
-        }
-        Some(_) => {
-            // Empty string means clear the key
-            (None, Some(false))
-        }
-        None => {
-            // Keep existing key, we'll rely on COALESCE in SQL
-            (None, None)
-        }
-    };
-
-    // Update or insert into database (include billing_mode and api_key_encrypted)
-    let result = sqlx::query(
-        r#"
-        INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold, billing_mode)
-        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
-        ON CONFLICT(id) DO UPDATE SET
-            enabled = excluded.enabled,
-            base_url = excluded.base_url,
-            api_key = COALESCE(excluded.api_key, provider_configs.api_key),
-            api_key_encrypted = COALESCE(excluded.api_key_encrypted, provider_configs.api_key_encrypted),
-            credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
-            low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
-            billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
-            updated_at = CURRENT_TIMESTAMP
-        "#,
-    )
-    .bind(&name)
-    .bind(name.to_uppercase())
-    .bind(payload.enabled)
-    .bind(&payload.base_url)
-    .bind(&api_key_to_store)
-    .bind(api_key_encrypted_flag)
-    .bind(payload.credit_balance)
-    .bind(payload.low_credit_threshold)
-    .bind(payload.billing_mode)
-    .execute(pool)
-    .await;
-
-    match result {
-        Ok(_) => {
-            // Re-initialize provider in manager
-            if let Err(e) = state
-                .app_state
-                .provider_manager
-                .initialize_provider(&name, &state.app_state.config, &state.app_state.db_pool)
-                .await
-            {
-                warn!("Failed to re-initialize provider {}: {}", name, e);
-                return Json(ApiResponse::error(format!(
-                    "Provider settings saved but initialization failed: {}",
-                    e
-                )));
-            }
-
-            Json(ApiResponse::success(
-                serde_json::json!({ "message": "Provider updated and re-initialized" }),
-            ))
-        }
-        Err(e) => {
-            warn!("Failed to update provider config: {}", e);
-            Json(ApiResponse::error(format!("Failed to update provider: {}", e)))
-        }
-    }
-}
-
-pub(super) async fn handle_test_provider(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Path(name): Path<String>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token))
=> (session, new_token), - Err(e) => return e, - }; - - let start = std::time::Instant::now(); - - let provider = match state.app_state.provider_manager.get_provider(&name).await { - Some(p) => p, - None => { - return Json(ApiResponse::error(format!( - "Provider '{}' not found or not enabled", - name - ))); - } - }; - - // Pick a real model for this provider from the registry - // NOTE: registry provider IDs differ from internal IDs for some providers. - let registry_key = match name.as_str() { - "gemini" => "google", - "grok" => "xai", - _ => name.as_str(), - }; - - let test_model = state - .app_state - .model_registry - .providers - .get(registry_key) - .and_then(|p| p.models.keys().next().cloned()) - .unwrap_or_else(|| name.clone()); - - let test_request = crate::models::UnifiedRequest { - client_id: "system-test".to_string(), - model: test_model, - messages: vec![crate::models::UnifiedMessage { - role: "user".to_string(), - content: vec![crate::models::ContentPart::Text { text: "Hi".to_string() }], - reasoning_content: None, - tool_calls: None, - name: None, - tool_call_id: None, - }], - temperature: None, - top_p: None, - top_k: None, - n: None, - stop: None, - max_tokens: Some(5), - presence_penalty: None, - frequency_penalty: None, - stream: false, - has_images: false, - tools: None, - tool_choice: None, - }; - - match provider.chat_completion(test_request).await { - Ok(_) => { - let latency = start.elapsed().as_millis(); - Json(ApiResponse::success(serde_json::json!({ - "success": true, - "latency": latency, - "message": "Connection test successful" - }))) - } - Err(e) => Json(ApiResponse::error(format!("Provider test failed: {}", e))), - } -} diff --git a/src/dashboard/sessions.rs b/src/dashboard/sessions.rs deleted file mode 100644 index 0e011efc..00000000 --- a/src/dashboard/sessions.rs +++ /dev/null @@ -1,311 +0,0 @@ -use chrono::{DateTime, Duration, Utc}; -use hmac::{Hmac, Mac}; -use serde::{Deserialize, Serialize}; -use sha2::{Sha256, 
digest::generic_array::GenericArray};
-use std::collections::HashMap;
-use std::env;
-use std::sync::Arc;
-use tokio::sync::RwLock;
-use uuid::Uuid;
-
-use base64::{engine::general_purpose::URL_SAFE, Engine as _};
-
-const TOKEN_VERSION: &str = "v2";
-const REFRESH_WINDOW_MINUTES: i64 = 15; // refresh if token expires within 15 minutes
-
-#[derive(Clone, Debug)]
-pub struct Session {
-    pub username: String,
-    pub role: String,
-    pub created_at: DateTime<Utc>,
-    pub expires_at: DateTime<Utc>,
-    pub session_id: String, // unique identifier for the session (UUID)
-}
-
-#[derive(Clone)]
-pub struct SessionManager {
-    sessions: Arc<RwLock<HashMap<String, Session>>>, // key = session_id
-    ttl_hours: i64,
-    secret: Vec<u8>,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-struct SessionPayload {
-    session_id: String,
-    username: String,
-    role: String,
-    iat: i64, // issued at (Unix timestamp)
-    exp: i64, // expiry (Unix timestamp)
-    version: String,
-}
-
-impl SessionManager {
-    pub fn new(ttl_hours: i64) -> Self {
-        let secret = load_session_secret();
-        Self {
-            sessions: Arc::new(RwLock::new(HashMap::new())),
-            ttl_hours,
-            secret,
-        }
-    }
-
-    /// Create a new session and return a signed session token.
-    pub async fn create_session(&self, username: String, role: String) -> String {
-        let session_id = Uuid::new_v4().to_string();
-        let now = Utc::now();
-        let expires_at = now + Duration::hours(self.ttl_hours);
-        let session = Session {
-            username: username.clone(),
-            role: role.clone(),
-            created_at: now,
-            expires_at,
-            session_id: session_id.clone(),
-        };
-        // Store session by session_id
-        self.sessions.write().await.insert(session_id.clone(), session);
-        // Create signed token
-        self.create_signed_token(&session_id, &username, &role, now.timestamp(), expires_at.timestamp())
-    }
-
-    /// Validate a session token and return the session if valid and not expired.
-    /// If the token is within the refresh window, returns a new token as well.
- pub async fn validate_session(&self, token: &str) -> Option { - self.validate_session_with_refresh(token).await.map(|(session, _)| session) - } - - /// Validate a session token and return (session, optional new token if refreshed). - pub async fn validate_session_with_refresh(&self, token: &str) -> Option<(Session, Option)> { - // Legacy token format (UUID) - if token.starts_with("session-") { - let sessions = self.sessions.read().await; - return sessions.get(token).and_then(|s| { - if s.expires_at > Utc::now() { - Some((s.clone(), None)) - } else { - None - } - }); - } - - // Signed token format - let payload = match verify_signed_token(token, &self.secret) { - Ok(p) => p, - Err(_) => return None, - }; - - // Check expiry - let now = Utc::now().timestamp(); - if payload.exp <= now { - return None; - } - - // Look up session by session_id - let sessions = self.sessions.read().await; - let session = match sessions.get(&payload.session_id) { - Some(s) => s.clone(), - None => return None, // session revoked or not found - }; - - // Ensure session username/role matches (should always match) - if session.username != payload.username || session.role != payload.role { - return None; - } - - // Check if token is within refresh window (last REFRESH_WINDOW_MINUTES of validity) - let refresh_threshold = payload.exp - REFRESH_WINDOW_MINUTES * 60; - let new_token = if now >= refresh_threshold { - // Generate a new token with same session data but updated iat/exp? - // According to activity-based refresh, we should extend the session expiry. - // We'll extend from now by ttl_hours (or keep original expiry?). - // Let's extend from now by ttl_hours (sliding window). 
- let new_exp = Utc::now() + Duration::hours(self.ttl_hours); - // Update session expiry in store - drop(sessions); // release read lock before acquiring write lock - self.update_session_expiry(&payload.session_id, new_exp).await; - // Create new token with updated iat/exp - let new_token = self.create_signed_token( - &payload.session_id, - &payload.username, - &payload.role, - now, - new_exp.timestamp(), - ); - Some(new_token) - } else { - None - }; - - Some((session, new_token)) - } - - /// Revoke (delete) a session by token. - /// Supports both legacy tokens (token is key) and signed tokens (extract session_id). - pub async fn revoke_session(&self, token: &str) { - if token.starts_with("session-") { - self.sessions.write().await.remove(token); - return; - } - // For signed token, try to extract session_id - if let Ok(payload) = verify_signed_token(token, &self.secret) { - self.sessions.write().await.remove(&payload.session_id); - } - } - - /// Remove all expired sessions from the store. - pub async fn cleanup_expired(&self) { - let now = Utc::now(); - self.sessions.write().await.retain(|_, s| s.expires_at > now); - } - - // --- Private helpers --- - - fn create_signed_token(&self, session_id: &str, username: &str, role: &str, iat: i64, exp: i64) -> String { - let payload = SessionPayload { - session_id: session_id.to_string(), - username: username.to_string(), - role: role.to_string(), - iat, - exp, - version: TOKEN_VERSION.to_string(), - }; - sign_token(&payload, &self.secret) - } - - async fn update_session_expiry(&self, session_id: &str, new_expires_at: DateTime) { - let mut sessions = self.sessions.write().await; - if let Some(session) = sessions.get_mut(session_id) { - session.expires_at = new_expires_at; - } - } -} - -/// Load session secret from environment variable SESSION_SECRET (hex or base64 encoded). -/// If not set, generates a random 32-byte secret and logs a warning. 
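The sliding-window refresh above re-issues a token only when it is close to expiry. A minimal, self-contained sketch of that check using plain Unix timestamps (the `should_refresh` helper is illustrative, not part of the original module):

```rust
// Assumption: simplified from validate_session_with_refresh above. A token is
// refreshed when `now` falls inside the last REFRESH_WINDOW_MINUTES of its
// validity; an already-expired token fails validation and is never refreshed.

const REFRESH_WINDOW_MINUTES: i64 = 15;

/// True when the token is still valid but within the refresh window before `exp`.
fn should_refresh(now: i64, exp: i64) -> bool {
    let refresh_threshold = exp - REFRESH_WINDOW_MINUTES * 60;
    now < exp && now >= refresh_threshold
}

fn main() {
    let exp = 10_000;
    // Well before the window: keep the existing token.
    assert!(!should_refresh(exp - 3600, exp));
    // Inside the last 15 minutes: issue a new token (sliding window).
    assert!(should_refresh(exp - 60, exp));
    // Already expired: validation fails first, so no refresh.
    assert!(!should_refresh(exp + 1, exp));
}
```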
-fn load_session_secret() -> Vec<u8> {
-    let secret_str = env::var("SESSION_SECRET").unwrap_or_else(|_| {
-        // Also check LLM_PROXY__SESSION_SECRET for consistency with config prefix
-        env::var("LLM_PROXY__SESSION_SECRET").unwrap_or_else(|_| {
-            // Generate a random secret (32 bytes) and encode as hex
-            use rand::RngCore;
-            let mut bytes = [0u8; 32];
-            rand::rng().fill_bytes(&mut bytes);
-            let hex_secret = hex::encode(bytes);
-            tracing::warn!(
-                "SESSION_SECRET environment variable not set. Using a randomly generated secret. \
-                 This will invalidate all sessions on restart. Set SESSION_SECRET to a fixed hex or base64 encoded 32-byte value."
-            );
-            hex_secret
-        })
-    });
-
-    // Decode hex or base64
-    hex::decode(&secret_str)
-        .or_else(|_| URL_SAFE.decode(&secret_str))
-        .or_else(|_| base64::engine::general_purpose::STANDARD.decode(&secret_str))
-        .unwrap_or_else(|_| {
-            panic!("SESSION_SECRET must be hex or base64 encoded (32 bytes)");
-        })
-}
-
-/// Sign a session payload and return a token string in format base64_url(payload).base64_url(signature).
-fn sign_token(payload: &SessionPayload, secret: &[u8]) -> String {
-    let json = serde_json::to_vec(payload).expect("Failed to serialize payload");
-    let payload_b64 = URL_SAFE.encode(&json);
-    let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
-    mac.update(&json);
-    let signature = mac.finalize().into_bytes();
-    let signature_b64 = URL_SAFE.encode(signature);
-    format!("{}.{}", payload_b64, signature_b64)
-}
-
-/// Verify a signed token and return the decoded payload if valid.
-fn verify_signed_token(token: &str, secret: &[u8]) -> Result<SessionPayload, TokenError> {
-    let parts: Vec<&str> = token.split('.').collect();
-    if parts.len() != 2 {
-        return Err(TokenError::InvalidFormat);
-    }
-    let payload_b64 = parts[0];
-    let signature_b64 = parts[1];
-
-    let json = URL_SAFE.decode(payload_b64).map_err(|_| TokenError::InvalidFormat)?;
-    let signature = URL_SAFE.decode(signature_b64).map_err(|_| TokenError::InvalidFormat)?;
-
-    // Verify HMAC
-    let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
-    mac.update(&json);
-    // Convert signature slice to GenericArray
-    let tag = GenericArray::from_slice(&signature);
-    mac.verify(tag).map_err(|_| TokenError::InvalidSignature)?;
-
-    // Deserialize payload
-    let payload: SessionPayload = serde_json::from_slice(&json).map_err(|_| TokenError::InvalidPayload)?;
-    Ok(payload)
-}
-
-#[derive(Debug)]
-enum TokenError {
-    InvalidFormat,
-    InvalidSignature,
-    InvalidPayload,
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use std::env;
-
-    #[test]
-    fn test_sign_and_verify_token() {
-        let secret = b"test-secret-must-be-32-bytes-long!";
-        let payload = SessionPayload {
-            session_id: "test-session".to_string(),
-            username: "testuser".to_string(),
-            role: "user".to_string(),
-            iat: 1000,
-            exp: 2000,
-            version: TOKEN_VERSION.to_string(),
-        };
-        let token = sign_token(&payload, secret);
-        let verified = verify_signed_token(&token, secret).unwrap();
-        assert_eq!(verified.session_id, payload.session_id);
-        assert_eq!(verified.username, payload.username);
-        assert_eq!(verified.role, payload.role);
-        assert_eq!(verified.iat, payload.iat);
-        assert_eq!(verified.exp, payload.exp);
-        assert_eq!(verified.version, payload.version);
-    }
-
-    #[test]
-    fn test_tampered_token() {
-        let secret = b"test-secret-must-be-32-bytes-long!";
-        let payload = SessionPayload {
-            session_id: "test-session".to_string(),
-            username: "testuser".to_string(),
-            role: "user".to_string(),
-            iat: 1000,
-            exp: 2000,
-            version: TOKEN_VERSION.to_string(),
-        };
-        let mut token = sign_token(&payload, secret);
-        // Tamper with payload part
-        let mut parts: Vec<&str> = token.split('.').collect();
-        let mut payload_bytes = URL_SAFE.decode(parts[0]).unwrap();
-        payload_bytes[0] ^= 0xFF; // flip some bits
-        let tampered_payload = URL_SAFE.encode(payload_bytes);
-        parts[0] = &tampered_payload;
-        token = parts.join(".");
-        assert!(verify_signed_token(&token, secret).is_err());
-    }
-
-    #[test]
-    fn test_load_session_secret_from_env() {
-        unsafe {
-            env::set_var("SESSION_SECRET", hex::encode([0xAA; 32]));
-        }
-        let secret = load_session_secret();
-        assert_eq!(secret, vec![0xAA; 32]);
-        unsafe {
-            env::remove_var("SESSION_SECRET");
-        }
-    }
-}
\ No newline at end of file
diff --git a/src/dashboard/system.rs b/src/dashboard/system.rs
deleted file mode 100644
index 2cfdc503..00000000
--- a/src/dashboard/system.rs
+++ /dev/null
@@ -1,405 +0,0 @@
-use axum::{extract::State, response::Json};
-use chrono;
-use serde_json;
-use sqlx::Row;
-use std::collections::HashMap;
-use tracing::warn;
-
-use super::{ApiResponse, DashboardState};
-
-/// Read a value from /proc files, returning None on any failure.
-fn read_proc_file(path: &str) -> Option<String> {
-    std::fs::read_to_string(path).ok()
-}
-
-pub(super) async fn handle_system_health(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let mut components = HashMap::new();
-    components.insert("api_server".to_string(), "online".to_string());
-    components.insert("database".to_string(), "online".to_string());
-
-    // Check provider health via circuit breakers
-    let provider_ids: Vec<String> = state
-        .app_state
-        .provider_manager
-        .get_all_providers()
-        .await
-        .iter()
-        .map(|p| p.name().to_string())
-        .collect();
-
-    for p_id in provider_ids {
-        if state
-            .app_state
-            .rate_limit_manager
-            .check_provider_request(&p_id)
-            .await
-            .unwrap_or(true)
-        {
-            components.insert(p_id, "online".to_string());
-        } else {
-            components.insert(p_id, "degraded".to_string());
-        }
-    }
-
-    // Read real memory usage from /proc/self/status
-    let memory_mb = read_proc_file("/proc/self/status")
-        .and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string()))
-        .and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::<f64>().ok()))
-        .map(|kb| kb / 1024.0)
-        .unwrap_or(0.0);
-
-    // Get real database pool stats
-    let db_pool_size = state.app_state.db_pool.size();
-    let db_pool_idle = state.app_state.db_pool.num_idle();
-
-    Json(ApiResponse::success(serde_json::json!({
-        "status": "healthy",
-        "timestamp": chrono::Utc::now().to_rfc3339(),
-        "components": components,
-        "metrics": {
-            "memory_usage_mb": (memory_mb * 10.0).round() / 10.0,
-            "db_connections_active": db_pool_size - db_pool_idle as u32,
-            "db_connections_idle": db_pool_idle,
-        }
-    })))
-}
-
-/// Real system metrics from /proc (Linux only).
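The VmRSS extraction used by the health and metrics handlers boils down to a small, testable string parse. A sketch under the assumption that the caller supplies the `/proc/self/status` text, so it runs on any platform (the `parse_vmrss_mb` name is illustrative):

```rust
/// Parse the resident set size in MiB out of /proc/self/status-style text.
/// The kernel reports VmRSS in kB, e.g. "VmRSS:\t  123456 kB".
fn parse_vmrss_mb(status: &str) -> Option<f64> {
    status
        .lines()
        .find(|l| l.starts_with("VmRSS:"))?   // locate the VmRSS line
        .split_whitespace()
        .nth(1)?                              // the numeric field
        .parse::<f64>()
        .ok()
        .map(|kb| kb / 1024.0)                // kB -> MiB
}

fn main() {
    let sample = "Name:\tllm-proxy\nVmPeak:\t 204800 kB\nVmRSS:\t 102400 kB\n";
    assert_eq!(parse_vmrss_mb(sample), Some(100.0));
    // Missing field (e.g. non-Linux): gracefully yields None.
    assert_eq!(parse_vmrss_mb("no such field"), None);
}
```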
-pub(super) async fn handle_system_metrics( - State(state): State, - headers: axum::http::HeaderMap, -) -> Json> { - let (_session, _) = match super::auth::require_admin(&state, &headers).await { - Ok((session, new_token)) => (session, new_token), - Err(e) => return e, - }; - - // --- CPU usage (aggregate across all cores) --- - // /proc/stat first line: cpu user nice system idle iowait irq softirq steal guest guest_nice - let cpu_percent = read_proc_file("/proc/stat") - .and_then(|s| { - let line = s.lines().find(|l| l.starts_with("cpu "))?.to_string(); - let fields: Vec = line - .split_whitespace() - .skip(1) - .filter_map(|v| v.parse().ok()) - .collect(); - if fields.len() >= 4 { - let idle = fields[3]; - let total: u64 = fields.iter().sum(); - if total > 0 { - Some(((total - idle) as f64 / total as f64 * 100.0 * 10.0).round() / 10.0) - } else { - None - } - } else { - None - } - }) - .unwrap_or(0.0); - - // --- Memory (system-wide from /proc/meminfo) --- - let meminfo = read_proc_file("/proc/meminfo").unwrap_or_default(); - let parse_meminfo = |key: &str| -> u64 { - meminfo - .lines() - .find(|l| l.starts_with(key)) - .and_then(|l| l.split_whitespace().nth(1)) - .and_then(|v| v.parse::().ok()) - .unwrap_or(0) - }; - let mem_total_kb = parse_meminfo("MemTotal:"); - let mem_available_kb = parse_meminfo("MemAvailable:"); - let mem_used_kb = mem_total_kb.saturating_sub(mem_available_kb); - let mem_total_mb = mem_total_kb as f64 / 1024.0; - let mem_used_mb = mem_used_kb as f64 / 1024.0; - let mem_percent = if mem_total_kb > 0 { - (mem_used_kb as f64 / mem_total_kb as f64 * 100.0 * 10.0).round() / 10.0 - } else { - 0.0 - }; - - // --- Process-specific memory (VmRSS) --- - let process_rss_mb = read_proc_file("/proc/self/status") - .and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string())) - .and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::().ok())) - .map(|kb| (kb / 1024.0 * 10.0).round() / 10.0) - .unwrap_or(0.0); - - // --- Disk 
usage of the data directory --- - let (disk_total_gb, disk_used_gb, disk_percent) = { - // statvfs via libc would be ideal; use df as a simple fallback - std::process::Command::new("df") - .args(["-BM", "--output=size,used,pcent", "."]) - .output() - .ok() - .and_then(|o| { - let out = String::from_utf8_lossy(&o.stdout); - let line = out.lines().nth(1)?.to_string(); - let parts: Vec<&str> = line.split_whitespace().collect(); - if parts.len() >= 3 { - let total = parts[0].trim_end_matches('M').parse::().unwrap_or(0.0) / 1024.0; - let used = parts[1].trim_end_matches('M').parse::().unwrap_or(0.0) / 1024.0; - let pct = parts[2].trim_end_matches('%').parse::().unwrap_or(0.0); - Some(((total * 10.0).round() / 10.0, (used * 10.0).round() / 10.0, pct)) - } else { - None - } - }) - .unwrap_or((0.0, 0.0, 0.0)) - }; - - // --- Uptime --- - let uptime_seconds = read_proc_file("/proc/uptime") - .and_then(|s| s.split_whitespace().next().and_then(|v| v.parse::().ok())) - .unwrap_or(0.0) as u64; - - // --- Load average --- - let (load_1, load_5, load_15) = read_proc_file("/proc/loadavg") - .and_then(|s| { - let parts: Vec<&str> = s.split_whitespace().collect(); - if parts.len() >= 3 { - Some(( - parts[0].parse::().unwrap_or(0.0), - parts[1].parse::().unwrap_or(0.0), - parts[2].parse::().unwrap_or(0.0), - )) - } else { - None - } - }) - .unwrap_or((0.0, 0.0, 0.0)); - - // --- Network (from /proc/net/dev, aggregate non-lo interfaces) --- - let (net_rx_bytes, net_tx_bytes) = read_proc_file("/proc/net/dev") - .map(|s| { - s.lines() - .skip(2) // skip header lines - .filter(|l| !l.trim().starts_with("lo:")) - .fold((0u64, 0u64), |(rx, tx), line| { - let parts: Vec<&str> = line.split_whitespace().collect(); - if parts.len() >= 10 { - let r = parts[1].parse::().unwrap_or(0); - let t = parts[9].parse::().unwrap_or(0); - (rx + r, tx + t) - } else { - (rx, tx) - } - }) - }) - .unwrap_or((0, 0)); - - // --- Database pool --- - let db_pool_size = state.app_state.db_pool.size(); - let 
db_pool_idle = state.app_state.db_pool.num_idle(); - - // --- Active WebSocket listeners --- - let ws_listeners = state.app_state.dashboard_tx.receiver_count(); - - Json(ApiResponse::success(serde_json::json!({ - "cpu": { - "usage_percent": cpu_percent, - "load_average": [load_1, load_5, load_15], - }, - "memory": { - "total_mb": (mem_total_mb * 10.0).round() / 10.0, - "used_mb": (mem_used_mb * 10.0).round() / 10.0, - "usage_percent": mem_percent, - "process_rss_mb": process_rss_mb, - }, - "disk": { - "total_gb": disk_total_gb, - "used_gb": disk_used_gb, - "usage_percent": disk_percent, - }, - "network": { - "rx_bytes": net_rx_bytes, - "tx_bytes": net_tx_bytes, - }, - "uptime_seconds": uptime_seconds, - "connections": { - "db_active": db_pool_size - db_pool_idle as u32, - "db_idle": db_pool_idle, - "websocket_listeners": ws_listeners, - }, - "timestamp": chrono::Utc::now().to_rfc3339(), - }))) -} - -pub(super) async fn handle_system_logs( - State(state): State, - headers: axum::http::HeaderMap, -) -> Json> { - let (_session, _) = match super::auth::require_admin(&state, &headers).await { - Ok((session, new_token)) => (session, new_token), - Err(e) => return e, - }; - - let pool = &state.app_state.db_pool; - - let result = sqlx::query( - r#" - SELECT - id, - timestamp, - client_id, - provider, - model, - prompt_tokens, - completion_tokens, - reasoning_tokens, - total_tokens, - cache_read_tokens, - cache_write_tokens, - cost, - status, - error_message, - duration_ms - FROM llm_requests - ORDER BY timestamp DESC - LIMIT 100 - "#, - ) - .fetch_all(pool) - .await; - - match result { - Ok(rows) => { - let logs: Vec = rows - .into_iter() - .map(|row| { - serde_json::json!({ - "id": row.get::("id"), - "timestamp": row.get::, _>("timestamp"), - "client_id": row.get::("client_id"), - "provider": row.get::("provider"), - "model": row.get::("model"), - "prompt_tokens": row.get::("prompt_tokens"), - "completion_tokens": row.get::("completion_tokens"), - "reasoning_tokens": 
row.get::("reasoning_tokens"), - "cache_read_tokens": row.get::("cache_read_tokens"), - "cache_write_tokens": row.get::("cache_write_tokens"), - "tokens": row.get::("total_tokens"), - "cost": row.get::("cost"), - "status": row.get::("status"), - "error": row.get::, _>("error_message"), - "duration": row.get::("duration_ms"), - }) - }) - .collect(); - - Json(ApiResponse::success(serde_json::json!(logs))) - } - Err(e) => { - warn!("Failed to fetch system logs: {}", e); - Json(ApiResponse::error("Failed to fetch system logs".to_string())) - } - } -} - -pub(super) async fn handle_system_backup( - State(state): State, - headers: axum::http::HeaderMap, -) -> Json> { - let (_session, _) = match super::auth::require_admin(&state, &headers).await { - Ok((session, new_token)) => (session, new_token), - Err(e) => return e, - }; - - let pool = &state.app_state.db_pool; - let backup_id = format!("backup-{}", chrono::Utc::now().timestamp()); - let backup_path = format!("data/{}.db", backup_id); - - // Ensure the data directory exists - if let Err(e) = std::fs::create_dir_all("data") { - return Json(ApiResponse::error(format!("Failed to create backup directory: {}", e))); - } - - // Use SQLite VACUUM INTO for a consistent backup - let result = sqlx::query(&format!("VACUUM INTO '{}'", backup_path)) - .execute(pool) - .await; - - match result { - Ok(_) => { - // Get backup file size - let size_bytes = std::fs::metadata(&backup_path).map(|m| m.len()).unwrap_or(0); - - Json(ApiResponse::success(serde_json::json!({ - "success": true, - "message": "Backup completed successfully", - "backup_id": backup_id, - "backup_path": backup_path, - "size_bytes": size_bytes, - }))) - } - Err(e) => { - warn!("Database backup failed: {}", e); - Json(ApiResponse::error(format!("Backup failed: {}", e))) - } - } -} - -pub(super) async fn handle_get_settings( - State(state): State, - headers: axum::http::HeaderMap, -) -> Json> { - let (_session, _) = match super::auth::require_admin(&state, 
&headers).await { - Ok((session, new_token)) => (session, new_token), - Err(e) => return e, - }; - - let registry = &state.app_state.model_registry; - let provider_count = registry.providers.len(); - let model_count: usize = registry.providers.values().map(|p| p.models.len()).sum(); - - Json(ApiResponse::success(serde_json::json!({ - "server": { - "auth_tokens": state.app_state.auth_tokens.iter().map(|t| mask_token(t)).collect::>(), - "version": env!("CARGO_PKG_VERSION"), - }, - "registry": { - "provider_count": provider_count, - "model_count": model_count, - }, - "database": { - "type": "SQLite", - } - }))) -} - -pub(super) async fn handle_update_settings( - State(state): State, - headers: axum::http::HeaderMap, -) -> Json> { - let (_session, _) = match super::auth::require_admin(&state, &headers).await { - Ok((session, new_token)) => (session, new_token), - Err(e) => return e, - }; - - Json(ApiResponse::error( - "Changing settings at runtime is not yet supported. Please update your config file and restart the server." - .to_string(), - )) -} - -// Helper functions -fn mask_token(token: &str) -> String { - if token.len() <= 8 { - return "*****".to_string(); - } - - let masked_len = token.len().min(12); - let visible_len = 4; - let mask_len = masked_len - visible_len; - - format!("{}{}", "*".repeat(mask_len), &token[token.len() - visible_len..]) -} diff --git a/src/dashboard/usage.rs b/src/dashboard/usage.rs deleted file mode 100644 index de7a532c..00000000 --- a/src/dashboard/usage.rs +++ /dev/null @@ -1,526 +0,0 @@ -use axum::{ - extract::{Query, State}, - response::Json, -}; -use chrono; -use serde::Deserialize; -use serde_json; -use sqlx::Row; -use tracing::warn; - -use super::{ApiResponse, DashboardState}; - -/// Query parameters for time-based filtering on usage endpoints. 
-#[derive(Debug, Deserialize, Default)]
-pub(super) struct UsagePeriodFilter {
-    /// Preset period: "today", "24h", "7d", "30d", "all" (default: "all")
-    pub period: Option<String>,
-    /// Custom range start (ISO 8601, e.g. "2025-06-01T00:00:00Z")
-    pub from: Option<String>,
-    /// Custom range end (ISO 8601)
-    pub to: Option<String>,
-}
-
-impl UsagePeriodFilter {
-    /// Returns `(sql_fragment, bind_values)` for a WHERE clause.
-    /// The fragment is either empty (no filter) or " AND timestamp >= ? [AND timestamp <= ?]".
-    fn to_sql(&self) -> (String, Vec<String>) {
-        let period = self.period.as_deref().unwrap_or("all");
-
-        if period == "custom" {
-            let mut clause = String::new();
-            let mut binds = Vec::new();
-            if let Some(ref from) = self.from {
-                clause.push_str(" AND timestamp >= ?");
-                binds.push(from.clone());
-            }
-            if let Some(ref to) = self.to {
-                clause.push_str(" AND timestamp <= ?");
-                binds.push(to.clone());
-            }
-            return (clause, binds);
-        }
-
-        let now = chrono::Utc::now();
-        let cutoff = match period {
-            "today" => {
-                // Start of today (UTC)
-                let today = now.format("%Y-%m-%dT00:00:00Z").to_string();
-                Some(today)
-            }
-            "24h" => Some((now - chrono::Duration::hours(24)).to_rfc3339()),
-            "7d" => Some((now - chrono::Duration::days(7)).to_rfc3339()),
-            "30d" => Some((now - chrono::Duration::days(30)).to_rfc3339()),
-            _ => None, // "all" or unrecognized
-        };
-
-        match cutoff {
-            Some(ts) => (" AND timestamp >= ?".to_string(), vec![ts]),
-            None => (String::new(), vec![]),
-        }
-    }
-
-    /// Determine the time-series granularity label for grouping.
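The `custom` branch of `to_sql` above returns a parameterized SQL fragment plus its bind values instead of interpolating timestamps into the query. A condensed, clock-free sketch of just that branch (the `custom_range_sql` helper is illustrative, not part of the original file):

```rust
/// Build an SQL fragment and bind values for an optional [from, to] range.
/// The fragment is appended to a query that already contains "WHERE 1=1",
/// and the caller binds each value in order — never string-interpolated.
fn custom_range_sql(from: Option<&str>, to: Option<&str>) -> (String, Vec<String>) {
    let mut clause = String::new();
    let mut binds = Vec::new();
    if let Some(from) = from {
        clause.push_str(" AND timestamp >= ?");
        binds.push(from.to_string());
    }
    if let Some(to) = to {
        clause.push_str(" AND timestamp <= ?");
        binds.push(to.to_string());
    }
    (clause, binds)
}

fn main() {
    let (clause, binds) = custom_range_sql(Some("2025-06-01T00:00:00Z"), None);
    assert_eq!(clause, " AND timestamp >= ?");
    assert_eq!(binds, vec!["2025-06-01T00:00:00Z".to_string()]);
    // No bounds given: empty fragment, nothing to bind.
    assert_eq!(custom_range_sql(None, None), (String::new(), Vec::new()));
}
```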
- fn granularity(&self) -> &'static str { - match self.period.as_deref().unwrap_or("all") { - "today" | "24h" => "hour", - _ => "day", - } - } -} - -pub(super) async fn handle_usage_summary( - State(state): State, - headers: axum::http::HeaderMap, - Query(filter): Query, -) -> Json> { - let (_session, _) = match super::auth::require_admin(&state, &headers).await { - Ok((session, new_token)) => (session, new_token), - Err(e) => return e, - }; - - let pool = &state.app_state.db_pool; - let (period_clause, period_binds) = filter.to_sql(); - - // Total stats (filtered by period) - let period_sql = format!( - r#" - SELECT - COUNT(*) as total_requests, - COALESCE(SUM(total_tokens), 0) as total_tokens, - COALESCE(SUM(cost), 0.0) as total_cost, - COUNT(DISTINCT client_id) as active_clients, - COALESCE(SUM(cache_read_tokens), 0) as total_cache_read, - COALESCE(SUM(cache_write_tokens), 0) as total_cache_write - FROM llm_requests - WHERE 1=1 {} - "#, - period_clause - ); - let mut q = sqlx::query(&period_sql); - for b in &period_binds { - q = q.bind(b); - } - let total_stats = q.fetch_one(pool); - - // Today's stats - let today = chrono::Utc::now().format("%Y-%m-%d").to_string(); - let today_stats = sqlx::query( - r#" - SELECT - COUNT(*) as today_requests, - COALESCE(SUM(total_tokens), 0) as today_tokens, - COALESCE(SUM(cost), 0.0) as today_cost - FROM llm_requests - WHERE strftime('%Y-%m-%d', timestamp) = ? 
-        "#,
-    )
-    .bind(today)
-    .fetch_one(pool);
-
-    // Error stats
-    let error_stats = sqlx::query(
-        r#"
-        SELECT
-            COUNT(*) as total,
-            SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as errors
-        FROM llm_requests
-        "#,
-    )
-    .fetch_one(pool);
-
-    // Average response time
-    let avg_response = sqlx::query(
-        r#"
-        SELECT COALESCE(AVG(duration_ms), 0.0) as avg_duration
-        FROM llm_requests
-        WHERE status = 'success'
-        "#,
-    )
-    .fetch_one(pool);
-
-    let results = tokio::join!(total_stats, today_stats, error_stats, avg_response);
-
-    match results {
-        (Ok(t), Ok(d), Ok(e), Ok(a)) => {
-            let total_requests: i64 = t.get("total_requests");
-            let total_tokens: i64 = t.get("total_tokens");
-            let total_cost: f64 = t.get("total_cost");
-            let active_clients: i64 = t.get("active_clients");
-            let total_cache_read: i64 = t.get("total_cache_read");
-            let total_cache_write: i64 = t.get("total_cache_write");
-
-            let today_requests: i64 = d.get("today_requests");
-            let today_cost: f64 = d.get("today_cost");
-
-            let total_count: i64 = e.get("total");
-            let error_count: i64 = e.get("errors");
-            let error_rate = if total_count > 0 {
-                (error_count as f64 / total_count as f64) * 100.0
-            } else {
-                0.0
-            };
-
-            let avg_response_time: f64 = a.get("avg_duration");
-
-            Json(ApiResponse::success(serde_json::json!({
-                "total_requests": total_requests,
-                "total_tokens": total_tokens,
-                "total_cost": total_cost,
-                "active_clients": active_clients,
-                "today_requests": today_requests,
-                "today_cost": today_cost,
-                "error_rate": error_rate,
-                "avg_response_time": avg_response_time,
-                "total_cache_read_tokens": total_cache_read,
-                "total_cache_write_tokens": total_cache_write,
-            })))
-        }
-        (t_res, d_res, e_res, a_res) => {
-            if let Err(e) = t_res { warn!("Total stats query failed: {}", e); }
-            if let Err(e) = d_res { warn!("Today stats query failed: {}", e); }
-            if let Err(e) = e_res { warn!("Error stats query failed: {}", e); }
-            if let Err(e) = a_res { warn!("Avg response query failed: {}", e); }
-            Json(ApiResponse::error("Failed to fetch usage statistics".to_string()))
-        }
-    }
-}
-
-pub(super) async fn handle_time_series(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Query(filter): Query<UsageFilter>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let pool = &state.app_state.db_pool;
-    let (period_clause, period_binds) = filter.to_sql();
-    let granularity = filter.granularity();
-
-    // Determine the strftime format and default lookback
-    let (strftime_fmt, _label_key, default_lookback) = match granularity {
-        "hour" => ("%H:00", "hour", chrono::Duration::hours(24)),
-        _ => ("%Y-%m-%d", "day", chrono::Duration::days(30)),
-    };
-
-    // If no period filter, apply a sensible default lookback
-    let (clause, binds) = if period_clause.is_empty() {
-        let cutoff = (chrono::Utc::now() - default_lookback).to_rfc3339();
-        (" AND timestamp >= ?".to_string(), vec![cutoff])
-    } else {
-        (period_clause, period_binds)
-    };
-
-    let sql = format!(
-        r#"
-        SELECT
-            strftime('{strftime_fmt}', timestamp) as bucket,
-            COUNT(*) as requests,
-            COALESCE(SUM(total_tokens), 0) as tokens,
-            COALESCE(SUM(cost), 0.0) as cost
-        FROM llm_requests
-        WHERE 1=1 {clause}
-        GROUP BY bucket
-        ORDER BY bucket
-        "#,
-    );
-
-    let mut q = sqlx::query(&sql);
-    for b in &binds {
-        q = q.bind(b);
-    }
-
-    let result = q.fetch_all(pool).await;
-
-    match result {
-        Ok(rows) => {
-            let mut series = Vec::new();
-
-            for row in rows {
-                let bucket: String = row.get("bucket");
-                let requests: i64 = row.get("requests");
-                let tokens: i64 = row.get("tokens");
-                let cost: f64 = row.get("cost");
-
-                series.push(serde_json::json!({
-                    "time": bucket,
-                    "requests": requests,
-                    "tokens": tokens,
-                    "cost": cost,
-                }));
-            }
-
-            Json(ApiResponse::success(serde_json::json!({
-                "series": series,
-                "period": filter.period.as_deref().unwrap_or("all"),
-                "granularity": granularity,
-            })))
-        }
-        Err(e) => {
-            warn!("Failed to fetch time series data: {}", e);
-            Json(ApiResponse::error("Failed to fetch time series data".to_string()))
-        }
-    }
-}
-
-pub(super) async fn handle_clients_usage(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Query(filter): Query<UsageFilter>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let pool = &state.app_state.db_pool;
-    let (period_clause, period_binds) = filter.to_sql();
-
-    let sql = format!(
-        r#"
-        SELECT
-            client_id,
-            COUNT(*) as requests,
-            COALESCE(SUM(total_tokens), 0) as tokens,
-            COALESCE(SUM(cost), 0.0) as cost,
-            MAX(timestamp) as last_request
-        FROM llm_requests
-        WHERE 1=1 {}
-        GROUP BY client_id
-        ORDER BY requests DESC
-        "#,
-        period_clause
-    );
-
-    let mut q = sqlx::query(&sql);
-    for b in &period_binds {
-        q = q.bind(b);
-    }
-
-    let result = q.fetch_all(pool).await;
-
-    match result {
-        Ok(rows) => {
-            let mut client_usage = Vec::new();
-
-            for row in rows {
-                let client_id: String = row.get("client_id");
-                let requests: i64 = row.get("requests");
-                let tokens: i64 = row.get("tokens");
-                let cost: f64 = row.get("cost");
-                let last_request: Option<chrono::DateTime<chrono::Utc>> = row.get("last_request");
-
-                client_usage.push(serde_json::json!({
-                    "client_id": client_id,
-                    "client_name": client_id,
-                    "requests": requests,
-                    "tokens": tokens,
-                    "cost": cost,
-                    "last_request": last_request,
-                }));
-            }
-
-            Json(ApiResponse::success(serde_json::json!(client_usage)))
-        }
-        Err(e) => {
-            warn!("Failed to fetch client usage data: {}", e);
-            Json(ApiResponse::error("Failed to fetch client usage data".to_string()))
-        }
-    }
-}
-
-pub(super) async fn handle_providers_usage(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Query(filter): Query<UsageFilter>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let pool = &state.app_state.db_pool;
-    let (period_clause, period_binds) = filter.to_sql();
-
-    let sql = format!(
-        r#"
-        SELECT
-            provider,
-            COUNT(*) as requests,
-            COALESCE(SUM(total_tokens), 0) as tokens,
-            COALESCE(SUM(cost), 0.0) as cost,
-            COALESCE(SUM(cache_read_tokens), 0) as cache_read,
-            COALESCE(SUM(cache_write_tokens), 0) as cache_write
-        FROM llm_requests
-        WHERE 1=1 {}
-        GROUP BY provider
-        ORDER BY requests DESC
-        "#,
-        period_clause
-    );
-
-    let mut q = sqlx::query(&sql);
-    for b in &period_binds {
-        q = q.bind(b);
-    }
-
-    let result = q.fetch_all(pool).await;
-
-    match result {
-        Ok(rows) => {
-            let mut provider_usage = Vec::new();
-
-            for row in rows {
-                let provider: String = row.get("provider");
-                let requests: i64 = row.get("requests");
-                let tokens: i64 = row.get("tokens");
-                let cost: f64 = row.get("cost");
-                let cache_read: i64 = row.get("cache_read");
-                let cache_write: i64 = row.get("cache_write");
-
-                provider_usage.push(serde_json::json!({
-                    "provider": provider,
-                    "requests": requests,
-                    "tokens": tokens,
-                    "cost": cost,
-                    "cache_read_tokens": cache_read,
-                    "cache_write_tokens": cache_write,
-                }));
-            }
-
-            Json(ApiResponse::success(serde_json::json!(provider_usage)))
-        }
-        Err(e) => {
-            warn!("Failed to fetch provider usage data: {}", e);
-            Json(ApiResponse::error("Failed to fetch provider usage data".to_string()))
-        }
-    }
-}
-
-pub(super) async fn handle_detailed_usage(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Query(filter): Query<UsageFilter>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let pool = &state.app_state.db_pool;
-    let (period_clause, period_binds) = filter.to_sql();
-
-    let sql = format!(
-        r#"
-        SELECT
-            strftime('%Y-%m-%d', timestamp) as date,
-            client_id,
-            provider,
-            model,
-            COUNT(*) as requests,
-            COALESCE(SUM(total_tokens), 0) as tokens,
-            COALESCE(SUM(cost), 0.0) as cost,
-            COALESCE(SUM(cache_read_tokens), 0) as cache_read,
-            COALESCE(SUM(cache_write_tokens), 0) as cache_write
-        FROM llm_requests
-        WHERE 1=1 {}
-        GROUP BY date, client_id, provider, model
-        ORDER BY date DESC
-        LIMIT 200
-        "#,
-        period_clause
-    );
-
-    let mut q = sqlx::query(&sql);
-    for b in &period_binds {
-        q = q.bind(b);
-    }
-
-    let result = q.fetch_all(pool).await;
-
-    match result {
-        Ok(rows) => {
-            let usage: Vec<serde_json::Value> = rows
-                .into_iter()
-                .map(|row| {
-                    serde_json::json!({
-                        "date": row.get::<String, _>("date"),
-                        "client": row.get::<String, _>("client_id"),
-                        "provider": row.get::<String, _>("provider"),
-                        "model": row.get::<String, _>("model"),
-                        "requests": row.get::<i64, _>("requests"),
-                        "tokens": row.get::<i64, _>("tokens"),
-                        "cost": row.get::<f64, _>("cost"),
-                        "cache_read_tokens": row.get::<i64, _>("cache_read"),
-                        "cache_write_tokens": row.get::<i64, _>("cache_write"),
-                    })
-                })
-                .collect();
-
-            Json(ApiResponse::success(serde_json::json!(usage)))
-        }
-        Err(e) => {
-            warn!("Failed to fetch detailed usage: {}", e);
-            Json(ApiResponse::error("Failed to fetch detailed usage".to_string()))
-        }
-    }
-}
-
-pub(super) async fn handle_analytics_breakdown(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Query(filter): Query<UsageFilter>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let pool = &state.app_state.db_pool;
-    let (period_clause, period_binds) = filter.to_sql();
-
-    // Model breakdown
-    let model_sql = format!(
-        "SELECT model as label, COUNT(*) as value FROM llm_requests WHERE 1=1 {} GROUP BY model ORDER BY value DESC",
-        period_clause
-    );
-    let mut mq = sqlx::query(&model_sql);
-    for b in &period_binds {
-        mq = mq.bind(b);
-    }
-    let models = mq.fetch_all(pool);
-
-    // Client breakdown
-    let client_sql = format!(
-        "SELECT client_id as label, COUNT(*) as value FROM llm_requests WHERE 1=1 {} GROUP BY client_id ORDER BY value DESC",
-        period_clause
-    );
-    let mut cq = sqlx::query(&client_sql);
-    for b in &period_binds {
-        cq = cq.bind(b);
-    }
-    let clients = cq.fetch_all(pool);
-
-    match tokio::join!(models, clients) {
-        (Ok(m_rows), Ok(c_rows)) => {
-            let model_breakdown: Vec<serde_json::Value> = m_rows
-                .into_iter()
-                .map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
-                .collect();
-
-            let client_breakdown: Vec<serde_json::Value> = c_rows
-                .into_iter()
-                .map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
-                .collect();
-
-            Json(ApiResponse::success(serde_json::json!({
-                "models": model_breakdown,
-                "clients": client_breakdown
-            })))
-        }
-        _ => Json(ApiResponse::error("Failed to fetch analytics breakdown".to_string())),
-    }
-}
diff --git a/src/dashboard/users.rs b/src/dashboard/users.rs
deleted file mode 100644
index d71a7899..00000000
--- a/src/dashboard/users.rs
+++ /dev/null
@@ -1,290 +0,0 @@
-use axum::{
-    extract::{Path, State},
-    response::Json,
-};
-use serde::Deserialize;
-use sqlx::Row;
-use tracing::warn;
-
-use super::{ApiResponse, DashboardState, auth};
-
-// ── User management endpoints (admin-only) ──────────────────────────
-
-pub(super) async fn handle_get_users(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let pool = &state.app_state.db_pool;
-
-    let result = sqlx::query(
-        "SELECT id, username, display_name, role, must_change_password, created_at FROM users ORDER BY created_at ASC",
-    )
-    .fetch_all(pool)
-    .await;
-
-    match result {
-        Ok(rows) => {
-            let users: Vec<serde_json::Value> = rows
-                .into_iter()
-                .map(|row| {
-                    let username: String = row.get("username");
-                    let display_name: Option<String> = row.get("display_name");
-                    serde_json::json!({
-                        "id": row.get::<i64, _>("id"),
-                        "username": &username,
-                        "display_name": display_name.as_deref().unwrap_or(&username),
-                        "role": row.get::<String, _>("role"),
-                        "must_change_password": row.get::<bool, _>("must_change_password"),
-                        "created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
-                    })
-                })
-                .collect();
-
-            Json(ApiResponse::success(serde_json::json!(users)))
-        }
-        Err(e) => {
-            warn!("Failed to fetch users: {}", e);
-            Json(ApiResponse::error("Failed to fetch users".to_string()))
-        }
-    }
-}
-
-#[derive(Deserialize)]
-pub(super) struct CreateUserRequest {
-    pub(super) username: String,
-    pub(super) password: String,
-    pub(super) display_name: Option<String>,
-    pub(super) role: Option<String>,
-}
-
-pub(super) async fn handle_create_user(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Json(payload): Json<CreateUserRequest>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let pool = &state.app_state.db_pool;
-
-    // Validate role
-    let role = payload.role.as_deref().unwrap_or("viewer");
-    if role != "admin" && role != "viewer" {
-        return Json(ApiResponse::error("Role must be 'admin' or 'viewer'".to_string()));
-    }
-
-    // Validate username
-    let username = payload.username.trim();
-    if username.is_empty() || username.len() > 64 {
-        return Json(ApiResponse::error("Username must be 1-64 characters".to_string()));
-    }
-
-    // Validate password
-    if payload.password.len() < 4 {
-        return Json(ApiResponse::error("Password must be at least 4 characters".to_string()));
-    }
-
-    let password_hash = match bcrypt::hash(&payload.password, 12) {
-        Ok(h) => h,
-        Err(_) => return Json(ApiResponse::error("Failed to hash password".to_string())),
-    };
-
-    let result = sqlx::query(
-        r#"
-        INSERT INTO users (username, password_hash, display_name, role, must_change_password)
-        VALUES (?, ?, ?, ?, TRUE)
-        RETURNING id, username, display_name, role, must_change_password, created_at
-        "#,
-    )
-    .bind(username)
-    .bind(&password_hash)
-    .bind(&payload.display_name)
-    .bind(role)
-    .fetch_one(pool)
-    .await;
-
-    match result {
-        Ok(row) => {
-            let uname: String = row.get("username");
-            let display_name: Option<String> = row.get("display_name");
-            Json(ApiResponse::success(serde_json::json!({
-                "id": row.get::<i64, _>("id"),
-                "username": &uname,
-                "display_name": display_name.as_deref().unwrap_or(&uname),
-                "role": row.get::<String, _>("role"),
-                "must_change_password": row.get::<bool, _>("must_change_password"),
-                "created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
-            })))
-        }
-        Err(e) => {
-            let msg = if e.to_string().contains("UNIQUE") {
-                format!("Username '{}' already exists", username)
-            } else {
-                format!("Failed to create user: {}", e)
-            };
-            warn!("Failed to create user: {}", e);
-            Json(ApiResponse::error(msg))
-        }
-    }
-}
-
-#[derive(Deserialize)]
-pub(super) struct UpdateUserRequest {
-    pub(super) display_name: Option<String>,
-    pub(super) role: Option<String>,
-    pub(super) password: Option<String>,
-    pub(super) must_change_password: Option<bool>,
-}
-
-pub(super) async fn handle_update_user(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Path(id): Path<i64>,
-    Json(payload): Json<UpdateUserRequest>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (_session, _) = match auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let pool = &state.app_state.db_pool;
-
-    // Validate role if provided
-    if let Some(ref role) = payload.role {
-        if role != "admin" && role != "viewer" {
-            return Json(ApiResponse::error("Role must be 'admin' or 'viewer'".to_string()));
-        }
-    }
-
-    // Build dynamic update
-    let mut sets = Vec::new();
-    let mut string_binds: Vec<String> = Vec::new();
-    let mut has_password = false;
-    let mut has_must_change = false;
-
-    if let Some(ref display_name) = payload.display_name {
-        sets.push("display_name = ?");
-        string_binds.push(display_name.clone());
-    }
-    if let Some(ref role) = payload.role {
-        sets.push("role = ?");
-        string_binds.push(role.clone());
-    }
-    if let Some(ref password) = payload.password {
-        if password.len() < 4 {
-            return Json(ApiResponse::error("Password must be at least 4 characters".to_string()));
-        }
-        let hash = match bcrypt::hash(password, 12) {
-            Ok(h) => h,
-            Err(_) => return Json(ApiResponse::error("Failed to hash password".to_string())),
-        };
-        sets.push("password_hash = ?");
-        string_binds.push(hash);
-        has_password = true;
-    }
-    if let Some(mcp) = payload.must_change_password {
-        sets.push("must_change_password = ?");
-        has_must_change = true;
-        let _ = mcp; // used below in bind
-    }
-
-    if sets.is_empty() {
-        return Json(ApiResponse::error("No fields to update".to_string()));
-    }
-
-    let sql = format!("UPDATE users SET {} WHERE id = ?", sets.join(", "));
-    let mut query = sqlx::query(&sql);
-
-    for b in &string_binds {
-        query = query.bind(b);
-    }
-    if has_must_change {
-        query = query.bind(payload.must_change_password.unwrap());
-    }
-    let _ = has_password; // consumed above via string_binds
-    query = query.bind(id);
-
-    match query.execute(pool).await {
-        Ok(result) => {
-            if result.rows_affected() == 0 {
-                return Json(ApiResponse::error("User not found".to_string()));
-            }
-            // Fetch updated user
-            let row = sqlx::query(
-                "SELECT id, username, display_name, role, must_change_password, created_at FROM users WHERE id = ?",
-            )
-            .bind(id)
-            .fetch_one(pool)
-            .await;
-
-            match row {
-                Ok(row) => {
-                    let uname: String = row.get("username");
-                    let display_name: Option<String> = row.get("display_name");
-                    Json(ApiResponse::success(serde_json::json!({
-                        "id": row.get::<i64, _>("id"),
-                        "username": &uname,
-                        "display_name": display_name.as_deref().unwrap_or(&uname),
-                        "role": row.get::<String, _>("role"),
-                        "must_change_password": row.get::<bool, _>("must_change_password"),
-                        "created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
-                    })))
-                }
-                Err(_) => Json(ApiResponse::success(serde_json::json!({ "message": "User updated" }))),
-            }
-        }
-        Err(e) => {
-            warn!("Failed to update user: {}", e);
-            Json(ApiResponse::error(format!("Failed to update user: {}", e)))
-        }
-    }
-}
-
-pub(super) async fn handle_delete_user(
-    State(state): State<DashboardState>,
-    headers: axum::http::HeaderMap,
-    Path(id): Path<i64>,
-) -> Json<ApiResponse<serde_json::Value>> {
-    let (session, _) = match auth::require_admin(&state, &headers).await {
-        Ok((session, new_token)) => (session, new_token),
-        Err(e) => return e,
-    };
-
-    let pool = &state.app_state.db_pool;
-
-    // Don't allow deleting yourself
-    let target_username: Option<String> =
-        sqlx::query_scalar::<_, String>("SELECT username FROM users WHERE id = ?")
-            .bind(id)
-            .fetch_optional(pool)
-            .await
-            .unwrap_or(None);
-
-    match target_username {
-        None => return Json(ApiResponse::error("User not found".to_string())),
-        Some(ref uname) if uname == &session.username => {
-            return Json(ApiResponse::error("Cannot delete your own account".to_string()));
-        }
-        _ => {}
-    }
-
-    let result = sqlx::query("DELETE FROM users WHERE id = ?")
-        .bind(id)
-        .execute(pool)
-        .await;
-
-    match result {
-        Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "User deleted" }))),
-        Err(e) => {
-            warn!("Failed to delete user: {}", e);
-            Json(ApiResponse::error(format!("Failed to delete user: {}", e)))
-        }
-    }
-}
diff --git a/src/dashboard/websocket.rs b/src/dashboard/websocket.rs
deleted file mode 100644
index 846ced4b..00000000
--- a/src/dashboard/websocket.rs
+++ /dev/null
@@ -1,75 +0,0 @@
-use axum::{
-    extract::{
-        State,
-        ws::{Message, WebSocket, WebSocketUpgrade},
-    },
-    response::IntoResponse,
-};
-use serde_json;
-use tracing::info;
-
-use super::DashboardState;
-
-// WebSocket handler
-pub(super) async fn handle_websocket(ws: WebSocketUpgrade, State(state): State<DashboardState>) -> impl IntoResponse {
-    ws.on_upgrade(|socket| handle_websocket_connection(socket, state))
-}
-
-pub(super) async fn handle_websocket_connection(mut socket: WebSocket, state: DashboardState) {
-    info!("WebSocket connection established");
-
-    // Subscribe to events from the global bus
-    let mut rx = state.app_state.dashboard_tx.subscribe();
-
-    // Send initial connection message
-    let _ = socket
-        .send(Message::Text(
-            serde_json::json!({
-                "type": "connected",
-                "message": "Connected to LLM Proxy Dashboard"
-            })
-            .to_string()
-            .into(),
-        ))
-        .await;
-
-    // Handle incoming messages and broadcast events
-    loop {
-        tokio::select! {
-            // Receive broadcast events
-            Ok(event) = rx.recv() => {
-                let Ok(json_str) = serde_json::to_string(&event) else {
-                    continue;
-                };
-                let message = Message::Text(json_str.into());
-                if socket.send(message).await.is_err() {
-                    break;
-                }
-            }
-
-            // Receive WebSocket messages
-            result = socket.recv() => {
-                match result {
-                    Some(Ok(Message::Text(text))) => {
-                        handle_websocket_message(&text, &state).await;
-                    }
-                    _ => break,
-                }
-            }
-        }
-    }
-
-    info!("WebSocket connection closed");
-}
-
-pub(super) async fn handle_websocket_message(text: &str, state: &DashboardState) {
-    // Parse and handle WebSocket messages
-    if let Ok(data) = serde_json::from_str::<serde_json::Value>(text)
-        && data.get("type").and_then(|v| v.as_str()) == Some("ping")
-    {
-        let _ = state.app_state.dashboard_tx.send(serde_json::json!({
-            "type": "pong",
-            "payload": {}
-        }));
-    }
-}
diff --git a/src/database/mod.rs b/src/database/mod.rs
deleted file mode 100644
index 374599df..00000000
--- a/src/database/mod.rs
+++ /dev/null
@@ -1,275 +0,0 @@
-use anyhow::Result;
-use sqlx::sqlite::{SqliteConnectOptions, SqlitePool};
-use std::str::FromStr;
-use tracing::info;
-
-use crate::config::DatabaseConfig;
-
-pub type DbPool = SqlitePool;
-
-pub async fn init(config: &DatabaseConfig) -> Result<DbPool> {
-    // Ensure the database directory exists
-    if let Some(parent) = config.path.parent()
-        && !parent.as_os_str().is_empty()
-    {
-        tokio::fs::create_dir_all(parent).await?;
-    }
-
-    let database_path = config.path.to_string_lossy().to_string();
-    info!("Connecting to database at {}", database_path);
-
-    let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?
-        .create_if_missing(true)
-        .pragma("foreign_keys", "ON");
-
-    let pool = SqlitePool::connect_with(options).await?;
-
-    // Run migrations
-    run_migrations(&pool).await?;
-    info!("Database migrations completed");
-
-    Ok(pool)
-}
-
-pub async fn run_migrations(pool: &DbPool) -> Result<()> {
-    // Create clients table if it doesn't exist
-    sqlx::query(
-        r#"
-        CREATE TABLE IF NOT EXISTS clients (
-            id INTEGER PRIMARY KEY AUTOINCREMENT,
-            client_id TEXT UNIQUE NOT NULL,
-            name TEXT,
-            description TEXT,
-            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
-            updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
-            is_active BOOLEAN DEFAULT TRUE,
-            rate_limit_per_minute INTEGER DEFAULT 60,
-            total_requests INTEGER DEFAULT 0,
-            total_tokens INTEGER DEFAULT 0,
-            total_cost REAL DEFAULT 0.0
-        )
-        "#,
-    )
-    .execute(pool)
-    .await?;
-
-    // Create llm_requests table if it doesn't exist
-    sqlx::query(
-        r#"
-        CREATE TABLE IF NOT EXISTS llm_requests (
-            id INTEGER PRIMARY KEY AUTOINCREMENT,
-            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
-            client_id TEXT,
-            provider TEXT,
-            model TEXT,
-            prompt_tokens INTEGER,
-            completion_tokens INTEGER,
-            reasoning_tokens INTEGER DEFAULT 0,
-            total_tokens INTEGER,
-            cost REAL,
-            has_images BOOLEAN DEFAULT FALSE,
-            status TEXT DEFAULT 'success',
-            error_message TEXT,
-            duration_ms INTEGER,
-            request_body TEXT,
-            response_body TEXT,
-            FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE SET NULL
-        )
-        "#,
-    )
-    .execute(pool)
-    .await?;
-
-    // Create provider_configs table
-    sqlx::query(
-        r#"
-        CREATE TABLE IF NOT EXISTS provider_configs (
-            id TEXT PRIMARY KEY,
-            display_name TEXT NOT NULL,
-            enabled BOOLEAN DEFAULT TRUE,
-            base_url TEXT,
-            api_key TEXT,
-            credit_balance REAL DEFAULT 0.0,
-            low_credit_threshold REAL DEFAULT 5.0,
-            billing_mode TEXT,
-            api_key_encrypted BOOLEAN DEFAULT FALSE,
-            updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
-        )
-        "#,
-    )
-    .execute(pool)
-    .await?;
-
-    // Create model_configs table
-    sqlx::query(
-        r#"
-        CREATE TABLE IF NOT EXISTS model_configs (
-            id TEXT PRIMARY KEY,
-            provider_id TEXT NOT NULL,
-            display_name TEXT,
-            enabled BOOLEAN DEFAULT TRUE,
-            prompt_cost_per_m REAL,
-            completion_cost_per_m REAL,
-            cache_read_cost_per_m REAL,
-            cache_write_cost_per_m REAL,
-            mapping TEXT,
-            updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
-            FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE
-        )
-        "#,
-    )
-    .execute(pool)
-    .await?;
-
-    // Create users table for dashboard access
-    sqlx::query(
-        r#"
-        CREATE TABLE IF NOT EXISTS users (
-            id INTEGER PRIMARY KEY AUTOINCREMENT,
-            username TEXT UNIQUE NOT NULL,
-            password_hash TEXT NOT NULL,
-            display_name TEXT,
-            role TEXT DEFAULT 'admin',
-            must_change_password BOOLEAN DEFAULT FALSE,
-            created_at DATETIME DEFAULT CURRENT_TIMESTAMP
-        )
-        "#,
-    )
-    .execute(pool)
-    .await?;
-
-    // Create client_tokens table for DB-based token auth
-    sqlx::query(
-        r#"
-        CREATE TABLE IF NOT EXISTS client_tokens (
-            id INTEGER PRIMARY KEY AUTOINCREMENT,
-            client_id TEXT NOT NULL,
-            token TEXT NOT NULL UNIQUE,
-            name TEXT DEFAULT 'default',
-            is_active BOOLEAN DEFAULT TRUE,
-            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
-            last_used_at DATETIME,
-            FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE CASCADE
-        )
-        "#,
-    )
-    .execute(pool)
-    .await?;
-
-    // Add must_change_password column if it doesn't exist (migration for existing DBs)
-    let _ = sqlx::query("ALTER TABLE users ADD COLUMN must_change_password BOOLEAN DEFAULT FALSE")
-        .execute(pool)
-        .await;
-
-    // Add display_name column if it doesn't exist (migration for existing DBs)
-    let _ = sqlx::query("ALTER TABLE users ADD COLUMN display_name TEXT")
-        .execute(pool)
-        .await;
-
-    // Add cache token columns if they don't exist (migration for existing DBs)
-    let _ = sqlx::query("ALTER TABLE llm_requests ADD COLUMN cache_read_tokens INTEGER DEFAULT 0")
-        .execute(pool)
-        .await;
-    let _ = sqlx::query("ALTER TABLE llm_requests ADD COLUMN cache_write_tokens INTEGER DEFAULT 0")
-        .execute(pool)
-        .await;
-    let _ = sqlx::query("ALTER TABLE llm_requests ADD COLUMN reasoning_tokens INTEGER DEFAULT 0")
-        .execute(pool)
-        .await;
-
-    // Add billing_mode column if it doesn't exist (migration for existing DBs)
-    let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN billing_mode TEXT")
-        .execute(pool)
-        .await;
-    // Add api_key_encrypted column if it doesn't exist (migration for existing DBs)
-    let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN api_key_encrypted BOOLEAN DEFAULT FALSE")
-        .execute(pool)
-        .await;
-
-    // Add manual cache cost columns to model_configs if they don't exist
-    let _ = sqlx::query("ALTER TABLE model_configs ADD COLUMN cache_read_cost_per_m REAL")
-        .execute(pool)
-        .await;
-    let _ = sqlx::query("ALTER TABLE model_configs ADD COLUMN cache_write_cost_per_m REAL")
-        .execute(pool)
-        .await;
-
-    // Insert default admin user if none exists (default password: admin)
-    let user_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users").fetch_one(pool).await?;
-
-    if user_count.0 == 0 {
-        // 'admin' hashed with default cost (12)
-        let default_admin_hash =
-            bcrypt::hash("admin", 12).map_err(|e| anyhow::anyhow!("Failed to hash default password: {}", e))?;
-        sqlx::query(
-            "INSERT INTO users (username, password_hash, role, must_change_password) VALUES ('admin', ?, 'admin', TRUE)"
-        )
-        .bind(default_admin_hash)
-        .execute(pool)
-        .await?;
-        info!("Created default admin user with password 'admin' (must change on first login)");
-    }
-
-    // Create indices
-    sqlx::query("CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)")
-        .execute(pool)
-        .await?;
-
-    sqlx::query("CREATE INDEX IF NOT EXISTS idx_clients_created_at ON clients(created_at)")
-        .execute(pool)
-        .await?;
-
-    sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_timestamp ON llm_requests(timestamp)")
-        .execute(pool)
-        .await?;
-
-    sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_client_id ON llm_requests(client_id)")
-        .execute(pool)
-        .await?;
-
-    sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_provider ON llm_requests(provider)")
-        .execute(pool)
-        .await?;
-
-    sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_status ON llm_requests(status)")
-        .execute(pool)
-        .await?;
-
-    sqlx::query("CREATE UNIQUE INDEX IF NOT EXISTS idx_client_tokens_token ON client_tokens(token)")
-        .execute(pool)
-        .await?;
-
-    sqlx::query("CREATE INDEX IF NOT EXISTS idx_client_tokens_client_id ON client_tokens(client_id)")
-        .execute(pool)
-        .await?;
-
-    // Composite indexes for performance
-    sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp)")
-        .execute(pool)
-        .await?;
-
-    sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp)")
-        .execute(pool)
-        .await?;
-
-    sqlx::query("CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id)")
-        .execute(pool)
-        .await?;
-
-    // Insert default client if none exists
-    sqlx::query(
-        r#"
-        INSERT OR IGNORE INTO clients (client_id, name, description)
-        VALUES ('default', 'Default Client', 'Default client for anonymous requests')
-        "#,
-    )
-    .execute(pool)
-    .await?;
-
-    Ok(())
-}
-
-pub async fn test_connection(pool: &DbPool) -> Result<()> {
-    sqlx::query("SELECT 1").execute(pool).await?;
-    Ok(())
-}
diff --git a/src/errors/mod.rs b/src/errors/mod.rs
deleted file mode 100644
index 6f4b8c03..00000000
--- a/src/errors/mod.rs
+++ /dev/null
@@ -1,58 +0,0 @@
-use thiserror::Error;
-
-#[derive(Error, Debug, Clone)]
-pub enum AppError {
-    #[error("Authentication failed: {0}")]
-    AuthError(String),
-
-    #[error("Configuration error: {0}")]
-    ConfigError(String),
-
-    #[error("Database error: {0}")]
-    DatabaseError(String),
-
-    #[error("Provider error: {0}")]
-    ProviderError(String),
-
-    #[error("Validation error: {0}")]
-    ValidationError(String),
-
-    #[error("Multimodal processing error: {0}")]
-    MultimodalError(String),
-
-    #[error("Rate limit exceeded: {0}")]
-    RateLimitError(String),
-
-    #[error("Internal server error: {0}")]
-    InternalError(String),
-}
-
-impl From<sqlx::Error> for AppError {
-    fn from(err: sqlx::Error) -> Self {
-        AppError::DatabaseError(err.to_string())
-    }
-}
-
-impl From<anyhow::Error> for AppError {
-    fn from(err: anyhow::Error) -> Self {
-        AppError::InternalError(err.to_string())
-    }
-}
-
-impl axum::response::IntoResponse for AppError {
-    fn into_response(self) -> axum::response::Response {
-        let status = match self {
-            AppError::AuthError(_) => axum::http::StatusCode::UNAUTHORIZED,
-            AppError::RateLimitError(_) => axum::http::StatusCode::TOO_MANY_REQUESTS,
-            AppError::ValidationError(_) => axum::http::StatusCode::BAD_REQUEST,
-            _ => axum::http::StatusCode::INTERNAL_SERVER_ERROR,
-        };
-
-        let body = axum::Json(serde_json::json!({
-            "error": self.to_string(),
-            "type": format!("{:?}", self)
-        }));
-
-        (status, body).into_response()
-    }
-}
diff --git a/src/lib.rs b/src/lib.rs
deleted file mode 100644
index ab114519..00000000
--- a/src/lib.rs
+++ /dev/null
@@ -1,328 +0,0 @@
-//! LLM Proxy Library
-//!
-//! This library provides the core functionality for the LLM proxy gateway,
-//! including provider integration, token tracking, and API endpoints.
-
-pub mod auth;
-pub mod client;
-pub mod config;
-pub mod dashboard;
-pub mod database;
-pub mod errors;
-pub mod logging;
-pub mod models;
-pub mod multimodal;
-pub mod providers;
-pub mod rate_limiting;
-pub mod server;
-pub mod state;
-pub mod utils;
-
-// Re-exports for convenience
-pub use auth::{AuthenticatedClient, validate_token};
-pub use config::{
-    AppConfig, DatabaseConfig, DeepSeekConfig, GeminiConfig, GrokConfig, ModelMappingConfig, ModelPricing,
-    OllamaConfig, OpenAIConfig, PricingConfig, ProviderConfig, ServerConfig,
-};
-pub use database::{DbPool, init as init_db, test_connection};
-pub use errors::AppError;
-pub use logging::{LoggingContext, RequestLog, RequestLogger};
-pub use models::{
-    ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage,
-    ChatStreamChoice, ChatStreamDelta, ContentPart, ContentPartValue, FromOpenAI, ImageUrl, MessageContent,
-    OpenAIContentPart, OpenAIMessage, OpenAIRequest, ToOpenAI, UnifiedMessage, UnifiedRequest, Usage,
-};
-pub use providers::{Provider, ProviderManager, ProviderResponse, ProviderStreamChunk};
-pub use server::router;
-pub use state::AppState;
-
-/// Test utilities for integration testing
-#[cfg(test)]
-pub mod test_utils {
-    use std::sync::Arc;
-
-    use crate::{client::ClientManager, providers::ProviderManager, rate_limiting::RateLimitManager, state::AppState, utils::crypto, database::run_migrations};
-    use sqlx::sqlite::SqlitePool;
-
-    /// Create a test application state
-    pub async fn create_test_state() -> AppState {
-        // Create in-memory database
-        let pool = SqlitePool::connect("sqlite::memory:")
-            .await
-            .expect("Failed to create test database");
-
-        // Run migrations on the pool
-        run_migrations(&pool).await.expect("Failed to run migrations");
-
-        let rate_limit_manager = RateLimitManager::new(
-            crate::rate_limiting::RateLimiterConfig::default(),
-            crate::rate_limiting::CircuitBreakerConfig::default(),
-        );
-
-        let client_manager = Arc::new(ClientManager::new(pool.clone()));
-
-        // Create provider manager
-        let provider_manager = ProviderManager::new();
-
-        let model_registry = crate::models::registry::ModelRegistry {
-            providers: std::collections::HashMap::new(),
-        };
-
-        let (dashboard_tx, _) = tokio::sync::broadcast::channel::<serde_json::Value>(100);
-
-        let config = Arc::new(crate::config::AppConfig {
-            server: crate::config::ServerConfig {
-                port: 8080,
-                host: "127.0.0.1".to_string(),
-                auth_tokens: vec![],
-            },
-            database: crate::config::DatabaseConfig {
-                path: std::path::PathBuf::from(":memory:"),
-                max_connections: 5,
-            },
-            providers: crate::config::ProviderConfig {
-                openai: crate::config::OpenAIConfig {
-                    api_key_env: "OPENAI_API_KEY".to_string(),
-                    base_url: "".to_string(),
-                    default_model: "".to_string(),
-                    enabled: true,
-                },
-                gemini: crate::config::GeminiConfig {
-                    api_key_env: "GEMINI_API_KEY".to_string(),
-                    base_url: "".to_string(),
-                    default_model: "".to_string(),
-                    enabled: true,
-                },
-                deepseek: crate::config::DeepSeekConfig {
-                    api_key_env: "DEEPSEEK_API_KEY".to_string(),
-                    base_url: "".to_string(),
-                    default_model: "".to_string(),
-                    enabled: true,
-                },
-                grok: crate::config::GrokConfig {
-                    api_key_env: "GROK_API_KEY".to_string(),
-                    base_url: "".to_string(),
-                    default_model: "".to_string(),
-                    enabled: true,
-                },
-                ollama: crate::config::OllamaConfig {
-                    base_url: "".to_string(),
-                    enabled: true,
-                    models: vec![],
-                },
-            },
-            model_mapping: crate::config::ModelMappingConfig { patterns: vec![] },
-            pricing: crate::config::PricingConfig {
-                openai: vec![],
-                gemini: vec![],
-                deepseek: vec![],
-                grok: vec![],
-                ollama: vec![],
-            },
-            config_path: None,
-            encryption_key: "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f".to_string(),
-        });
-
-        // Initialize encryption with the test key
-        crypto::init_with_key(&config.encryption_key).expect("failed to initialize crypto");
-
-        AppState::new(
-            config,
-            provider_manager,
-            pool,
-            rate_limit_manager,
-            model_registry,
-            vec![], // auth_tokens
-        )
-    }
-
-    /// Create a test HTTP client
-    pub fn create_test_client() -> reqwest::Client {
-        reqwest::Client::builder()
-            .timeout(std::time::Duration::from_secs(30))
-            .build()
-            .expect("Failed to create test HTTP client")
-    }
-}
-
-#[cfg(test)]
-mod integration_tests {
-    use super::test_utils::*;
-    use crate::{
-        models::{ChatCompletionRequest, ChatMessage},
-        server::router,
-        utils::crypto,
-    };
-    use axum::{
-        body::Body,
-        http::{Request, StatusCode},
-    };
-    use mockito::Server;
-    use serde_json::json;
-    use sqlx::Row;
-    use tower::util::ServiceExt;
-
-    #[tokio::test]
-    async fn test_encrypted_provider_key_integration() {
-        // Step 1: Setup test database and state
-        let state = create_test_state().await;
-        let pool = state.db_pool.clone();
-
-        // Step 2: Insert provider with encrypted API key
-        let test_api_key = "test-openai-key-12345";
-        let encrypted_key = crypto::encrypt(test_api_key).expect("Failed to encrypt test key");
-
-        sqlx::query(
-            r#"
-            INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold)
-            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
-            "#,
-        )
-        .bind("openai")
-        .bind("OpenAI")
-        .bind(true)
-        .bind("http://localhost:1234") // Mock server URL
-        .bind(&encrypted_key)
-        .bind(true) // api_key_encrypted flag
-        .bind(100.0)
-        .bind(5.0)
-        .execute(&pool)
-        .await
-        .expect("Failed to update provider URL");
-
-        // Re-initialize provider with new URL
-        state
-            .provider_manager
-            .initialize_provider("openai", &state.config, &pool)
-            .await
-            .expect("Failed to re-initialize provider");
-
-        // Step 4: Mock OpenAI API server
-        let mut server = Server::new_async().await;
-        let mock = server
-            .mock("POST", "/chat/completions")
-            .match_header("authorization", format!("Bearer {}", test_api_key).as_str())
-            .with_status(200)
-            .with_header("content-type", "application/json")
-            .with_body(
-                json!({
-                    "id": "chatcmpl-test",
-                    "object": "chat.completion",
-                    "created": 1234567890,
-                    "model": "gpt-3.5-turbo",
-                    "choices": [{
-                        "index": 0,
-                        "message": {
-                            "role": "assistant",
-                            "content": "Hello, world!"
-                        },
-                        "finish_reason": "stop"
-                    }],
-                    "usage": {
-                        "prompt_tokens": 10,
-                        "completion_tokens": 5,
-                        "total_tokens": 15
-                    }
-                })
-                .to_string(),
-            )
-            .create_async()
-            .await;
-
-        // Update provider base URL to use mock server
-        sqlx::query("UPDATE provider_configs SET base_url = ? WHERE id = 'openai'")
-            .bind(&server.url())
-            .execute(&pool)
-            .await
-            .expect("Failed to update provider URL");
-
-        // Re-initialize provider with new URL
-        state
-            .provider_manager
-            .initialize_provider("openai", &state.config, &pool)
-            .await
-            .expect("Failed to re-initialize provider");
-
-        // Step 5: Create test router and make request
-        let app = router(state);
-
-        let request_body = ChatCompletionRequest {
-            model: "gpt-3.5-turbo".to_string(),
-            messages: vec![ChatMessage {
-                role: "user".to_string(),
-                content: crate::models::MessageContent::Text {
-                    content: "Hello".to_string(),
-                },
-                reasoning_content: None,
-                tool_calls: None,
-                name: None,
-                tool_call_id: None,
-            }],
-            temperature: None,
-            top_p: None,
-            top_k: None,
-            n: None,
-            stop: None,
-            max_tokens: Some(100),
-            presence_penalty: None,
-            frequency_penalty: None,
-            stream: Some(false),
-            tools: None,
-            tool_choice: None,
-        };
-
-        let request = Request::builder()
-            .method("POST")
-            .uri("/v1/chat/completions")
-            .header("content-type", "application/json")
-            .header("authorization", "Bearer test-token")
-            .body(Body::from(serde_json::to_string(&request_body).unwrap()))
-            .unwrap();
-
-        // Step 6: Execute request through proxy
-        let response = app
-            .oneshot(request)
-            .await
-            .expect("Failed to execute request");
-
-        let status = response.status();
-        println!("Response status: {}", status);
-
-        if status != StatusCode::OK {
-            let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
-            let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
-            println!("Response body: {}", body_str);
-            panic!("Response status is not OK: {}", status);
-        }
-
-        assert_eq!(status, StatusCode::OK);
-
-        // Verify the mock was called
-        mock.assert_async().await;
-
-        // Give the async logging task time to complete
-        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
-
-        // Step 7: Verify usage was logged in database
-        let log_row = sqlx::query("SELECT * FROM
llm_requests WHERE client_id = 'client_test-tok' ORDER BY id DESC LIMIT 1") - .fetch_one(&pool) - .await - .expect("Request log not found"); - - assert_eq!(log_row.get::<String, _>("provider"), "openai"); - assert_eq!(log_row.get::<String, _>("model"), "gpt-3.5-turbo"); - assert_eq!(log_row.get::<i64, _>("prompt_tokens"), 10); - assert_eq!(log_row.get::<i64, _>("completion_tokens"), 5); - assert_eq!(log_row.get::<i64, _>("total_tokens"), 15); - assert_eq!(log_row.get::<String, _>("status"), "success"); - - // Verify client usage was updated - let client_row = sqlx::query("SELECT total_requests, total_tokens, total_cost FROM clients WHERE client_id = 'client_test-tok'") - .fetch_one(&pool) - .await - .expect("Client not found"); - - assert_eq!(client_row.get::<i64, _>("total_requests"), 1); - assert_eq!(client_row.get::<i64, _>("total_tokens"), 15); - } -} diff --git a/src/logging/mod.rs b/src/logging/mod.rs deleted file mode 100644 index ab5df368..00000000 --- a/src/logging/mod.rs +++ /dev/null @@ -1,242 +0,0 @@ -use chrono::{DateTime, Utc}; -use serde::Serialize; -use sqlx::SqlitePool; -use tokio::sync::broadcast; -use tracing::{info, warn}; - -use crate::errors::AppError; - -/// Request log entry for database storage -#[derive(Debug, Clone, Serialize)] -pub struct RequestLog { - pub timestamp: DateTime<Utc>, - pub client_id: String, - pub provider: String, - pub model: String, - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub reasoning_tokens: u32, - pub total_tokens: u32, - pub cache_read_tokens: u32, - pub cache_write_tokens: u32, - pub cost: f64, - pub has_images: bool, - pub status: String, // "success", "error" - pub error_message: Option<String>, - pub duration_ms: u64, -} - -/// Database operations for request logging -pub struct RequestLogger { - db_pool: SqlitePool, - dashboard_tx: broadcast::Sender<serde_json::Value>, -} - -impl RequestLogger { - pub fn new(db_pool: SqlitePool, dashboard_tx: broadcast::Sender<serde_json::Value>) -> Self { - Self { db_pool, dashboard_tx } - } - - /// Log a request to the database (async, spawns a task) - pub fn log_request(&self,
log: RequestLog) { - let pool = self.db_pool.clone(); - let tx = self.dashboard_tx.clone(); - - // Spawn async task to log without blocking response - tokio::spawn(async move { - // Broadcast to dashboard - let broadcast_result = tx.send(serde_json::json!({ - "type": "request", - "channel": "requests", - "payload": log - })); - match broadcast_result { - Ok(receivers) => info!("Broadcast request log to {} dashboard listeners", receivers), - Err(_) => {} // No active WebSocket clients — expected when dashboard isn't open - } - - match Self::insert_log(&pool, log).await { - Ok(()) => info!("Request logged to database successfully"), - Err(e) => warn!("Failed to log request to database: {}", e), - } - }); - } - - /// Insert a log entry into the database - async fn insert_log(pool: &SqlitePool, log: RequestLog) -> Result<(), sqlx::Error> { - let mut tx = pool.begin().await?; - - // Ensure the client row exists (FK constraint requires it) - sqlx::query( - "INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')", - ) - .bind(&log.client_id) - .bind(&log.client_id) - .execute(&mut *tx) - .await?; - - sqlx::query( - r#" - INSERT INTO llm_requests - (timestamp, client_id, provider, model, prompt_tokens, completion_tokens, reasoning_tokens, total_tokens, cache_read_tokens, cache_write_tokens, cost, has_images, status, error_message, duration_ms, request_body, response_body) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- "#, - ) - .bind(log.timestamp) - .bind(&log.client_id) - .bind(&log.provider) - .bind(&log.model) - .bind(log.prompt_tokens as i64) - .bind(log.completion_tokens as i64) - .bind(log.reasoning_tokens as i64) - .bind(log.total_tokens as i64) - .bind(log.cache_read_tokens as i64) - .bind(log.cache_write_tokens as i64) - .bind(log.cost) - .bind(log.has_images) - .bind(&log.status) - .bind(log.error_message) - .bind(log.duration_ms as i64) - .bind(None::) // request_body - optional, not stored to save disk space - .bind(None::) // response_body - optional, not stored to save disk space - .execute(&mut *tx) - .await?; - - // Update client usage statistics - sqlx::query( - r#" - UPDATE clients SET - total_requests = total_requests + 1, - total_tokens = total_tokens + ?, - total_cost = total_cost + ?, - updated_at = CURRENT_TIMESTAMP - WHERE client_id = ? - "#, - ) - .bind(log.total_tokens as i64) - .bind(log.cost) - .bind(&log.client_id) - .execute(&mut *tx) - .await?; - - // Deduct from provider balance if successful. - // Providers configured with billing_mode = 'postpaid' will not have their - // credit_balance decremented. Use a conditional UPDATE so we don't need - // a prior SELECT and avoid extra round-trips. - if log.cost > 0.0 { - sqlx::query( - "UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? 
AND (billing_mode IS NULL OR billing_mode != 'postpaid')", - ) - .bind(log.cost) - .bind(&log.provider) - .execute(&mut *tx) - .await?; - } - - tx.commit().await?; - - Ok(()) - } -} - -// /// Middleware to log LLM API requests -// /// TODO: Implement proper middleware that can extract response body details -// pub async fn request_logging_middleware( -// // Extract the authenticated client (if available) -// auth_result: Result, -// request: Request, -// next: Next, -// ) -> Response { -// let start_time = std::time::Instant::now(); -// -// // Extract client_id from auth or use "unknown" -// let client_id = match auth_result { -// Ok(auth) => auth.client_id, -// Err(_) => "unknown".to_string(), -// }; -// -// // Try to extract request details -// let (request_parts, request_body) = request.into_parts(); -// -// // Clone request parts for logging -// let path = request_parts.uri.path().to_string(); -// -// // Check if this is a chat completion request -// let is_chat_completion = path == "/v1/chat/completions"; -// -// // Reconstruct request for downstream handlers -// let request = Request::from_parts(request_parts, request_body); -// -// // Process request and get response -// let response = next.run(request).await; -// -// // Calculate duration -// let duration = start_time.elapsed(); -// let duration_ms = duration.as_millis() as u64; -// -// // Log basic request info -// info!( -// "Request from {} to {} - Status: {} - Duration: {}ms", -// client_id, -// path, -// response.status().as_u16(), -// duration_ms -// ); -// -// // TODO: Extract more details from request/response for logging -// // For now, we'll need to modify the server handler to pass additional context -// -// response -// } - -/// Context for request logging that can be passed through extensions -#[derive(Clone)] -pub struct LoggingContext { - pub client_id: String, - pub provider_name: String, - pub model: String, - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, - 
pub cost: f64, - pub has_images: bool, - pub error: Option, -} - -impl LoggingContext { - pub fn new(client_id: String, provider_name: String, model: String) -> Self { - Self { - client_id, - provider_name, - model, - prompt_tokens: 0, - completion_tokens: 0, - total_tokens: 0, - cost: 0.0, - has_images: false, - error: None, - } - } - - pub fn with_token_counts(mut self, prompt_tokens: u32, completion_tokens: u32) -> Self { - self.prompt_tokens = prompt_tokens; - self.completion_tokens = completion_tokens; - self.total_tokens = prompt_tokens + completion_tokens; - self - } - - pub fn with_cost(mut self, cost: f64) -> Self { - self.cost = cost; - self - } - - pub fn with_images(mut self, has_images: bool) -> Self { - self.has_images = has_images; - self - } - - pub fn with_error(mut self, error: AppError) -> Self { - self.error = Some(error); - self - } -} diff --git a/src/main.rs b/src/main.rs deleted file mode 100644 index b12a3615..00000000 --- a/src/main.rs +++ /dev/null @@ -1,96 +0,0 @@ -use anyhow::Result; -use axum::{Router, routing::get}; -use std::net::SocketAddr; -use tracing::{error, info}; - -use llm_proxy::{ - config::AppConfig, - dashboard, database, - providers::ProviderManager, - rate_limiting::{CircuitBreakerConfig, RateLimitManager, RateLimiterConfig}, - server, - state::AppState, - utils::crypto, -}; - -#[tokio::main] -async fn main() -> Result<()> { - // Initialize tracing (logging) - tracing_subscriber::fmt() - .with_max_level(tracing::Level::INFO) - .with_target(false) - .init(); - - info!("Starting LLM Proxy Gateway v{}", env!("CARGO_PKG_VERSION")); - - // Load configuration - let config = AppConfig::load().await?; - info!("Configuration loaded from {:?}", config.config_path); - - // Initialize encryption - crypto::init_with_key(&config.encryption_key)?; - info!("Encryption initialized"); - - // Initialize database connection pool - let db_pool = database::init(&config.database).await?; - info!("Database initialized at {:?}", 
config.database.path); - - // Initialize provider manager with configured providers - let provider_manager = ProviderManager::new(); - - // Initialize all supported providers (they handle their own enabled check) - let supported_providers = vec!["openai", "gemini", "deepseek", "grok", "ollama"]; - for name in supported_providers { - if let Err(e) = provider_manager.initialize_provider(name, &config, &db_pool).await { - error!("Failed to initialize provider {}: {}", name, e); - } - } - - // Create rate limit manager - let rate_limit_manager = RateLimitManager::new(RateLimiterConfig::default(), CircuitBreakerConfig::default()); - - // Fetch model registry from models.dev - let model_registry = match llm_proxy::utils::registry::fetch_registry().await { - Ok(registry) => registry, - Err(e) => { - error!("Failed to fetch model registry: {}. Using empty registry.", e); - llm_proxy::models::registry::ModelRegistry { - providers: std::collections::HashMap::new(), - } - } - }; - - // Create application state - let state = AppState::new( - config.clone(), - provider_manager, - db_pool, - rate_limit_manager, - model_registry, - config.server.auth_tokens.clone(), - ); - - // Initialize model config cache and start background refresh (every 30s) - state.model_config_cache.refresh().await; - state.model_config_cache.clone().start_refresh_task(30); - info!("Model config cache initialized"); - - // Create application router - let app = Router::new() - .route("/health", get(health_check)) - .merge(server::router(state.clone())) - .merge(dashboard::router(state.clone())); - - // Start server - let addr = SocketAddr::from(([0, 0, 0, 0], config.server.port)); - info!("Server listening on http://{}", addr); - - let listener = tokio::net::TcpListener::bind(&addr).await?; - axum::serve(listener, app).await?; - - Ok(()) -} - -async fn health_check() -> &'static str { - "OK" -} diff --git a/src/models/mod.rs b/src/models/mod.rs deleted file mode 100644 index 5b44bf7a..00000000 --- 
a/src/models/mod.rs +++ /dev/null @@ -1,381 +0,0 @@ -use serde::{Deserialize, Serialize}; -use serde_json::Value; - -pub mod registry; - -// ========== OpenAI-compatible Request/Response Structs ========== - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChatCompletionRequest { - pub model: String, - pub messages: Vec<ChatMessage>, - #[serde(default)] - pub temperature: Option<f32>, - #[serde(default)] - pub top_p: Option<f32>, - #[serde(default)] - pub top_k: Option<u32>, - #[serde(default)] - pub n: Option<u32>, - #[serde(default)] - pub stop: Option<Value>, // Can be string or array of strings - #[serde(default)] - pub max_tokens: Option<u32>, - #[serde(default)] - pub presence_penalty: Option<f32>, - #[serde(default)] - pub frequency_penalty: Option<f32>, - #[serde(default)] - pub stream: Option<bool>, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub tools: Option<Vec<Tool>>, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub tool_choice: Option<ToolChoice>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChatMessage { - pub role: String, // "system", "user", "assistant", "tool" - #[serde(flatten)] - pub content: MessageContent, - #[serde(alias = "reasoning", alias = "thought", skip_serializing_if = "Option::is_none")] - pub reasoning_content: Option<String>, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option<Vec<ToolCall>>, - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option<String>, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_call_id: Option<String>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(untagged)] -pub enum MessageContent { - Text { content: String }, - Parts { content: Vec<ContentPartValue> }, - None, // Handle cases where content might be null but reasoning is present -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum ContentPartValue { - Text { text: String }, - ImageUrl { image_url: ImageUrl }, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ImageUrl { - pub
url: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub detail: Option, -} - -// ========== Tool-Calling Types ========== - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Tool { - #[serde(rename = "type")] - pub tool_type: String, - pub function: FunctionDef, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct FunctionDef { - pub name: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub parameters: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(untagged)] -pub enum ToolChoice { - Mode(String), // "auto", "none", "required" - Specific(ToolChoiceSpecific), -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolChoiceSpecific { - #[serde(rename = "type")] - pub choice_type: String, - pub function: ToolChoiceFunction, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolChoiceFunction { - pub name: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolCall { - pub id: String, - #[serde(rename = "type")] - pub call_type: String, - pub function: FunctionCall, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct FunctionCall { - pub name: String, - pub arguments: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolCallDelta { - pub index: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub id: Option, - #[serde(rename = "type", skip_serializing_if = "Option::is_none")] - pub call_type: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub function: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct FunctionCallDelta { - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub arguments: Option, -} - -// ========== OpenAI-compatible Response Structs ========== - -#[derive(Debug, Clone, Serialize, 
Deserialize)] -pub struct ChatCompletionResponse { - pub id: String, - pub object: String, - pub created: u64, - pub model: String, - pub choices: Vec, - pub usage: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChatChoice { - pub index: u32, - pub message: ChatMessage, - pub finish_reason: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Usage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, - #[serde(skip_serializing_if = "Option::is_none")] - pub reasoning_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub cache_read_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub cache_write_tokens: Option, -} - -// ========== Streaming Response Structs ========== - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChatCompletionStreamResponse { - pub id: String, - pub object: String, - pub created: u64, - pub model: String, - pub choices: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub usage: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChatStreamChoice { - pub index: u32, - pub delta: ChatStreamDelta, - pub finish_reason: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChatStreamDelta { - pub role: Option, - pub content: Option, - #[serde(alias = "reasoning", alias = "thought", skip_serializing_if = "Option::is_none")] - pub reasoning_content: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option>, -} - -// ========== Unified Request Format (for internal use) ========== - -#[derive(Debug, Clone)] -pub struct UnifiedRequest { - pub client_id: String, - pub model: String, - pub messages: Vec, - pub temperature: Option, - pub top_p: Option, - pub top_k: Option, - pub n: Option, - pub stop: Option>, - pub max_tokens: Option, - pub presence_penalty: Option, - pub frequency_penalty: Option, - pub stream: bool, - pub has_images: bool, - pub 
tools: Option>, - pub tool_choice: Option, -} - -#[derive(Debug, Clone)] -pub struct UnifiedMessage { - pub role: String, - pub content: Vec, - pub reasoning_content: Option, - pub tool_calls: Option>, - pub name: Option, - pub tool_call_id: Option, -} - -#[derive(Debug, Clone)] -pub enum ContentPart { - Text { text: String }, - Image(crate::multimodal::ImageInput), -} - -// ========== Provider-specific Structs ========== - -#[derive(Debug, Clone, Serialize)] -pub struct OpenAIRequest { - pub model: String, - pub messages: Vec, - pub temperature: Option, - pub max_tokens: Option, - pub stream: Option, -} - -#[derive(Debug, Clone, Serialize)] -pub struct OpenAIMessage { - pub role: String, - pub content: Vec, -} - -#[derive(Debug, Clone, Serialize)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum OpenAIContentPart { - Text { text: String }, - ImageUrl { image_url: ImageUrl }, -} - -// Note: ImageUrl struct is defined earlier in the file - -// ========== Conversion Traits ========== - -pub trait ToOpenAI { - fn to_openai(&self) -> Result; -} - -pub trait FromOpenAI { - fn from_openai(request: &OpenAIRequest) -> Result - where - Self: Sized; -} - -impl UnifiedRequest { - /// Hydrate all image content by fetching URLs and converting to base64/bytes - pub async fn hydrate_images(&mut self) -> anyhow::Result<()> { - if !self.has_images { - return Ok(()); - } - - for msg in &mut self.messages { - for part in &mut msg.content { - if let ContentPart::Image(image_input) = part { - // Pre-fetch and validate if it's a URL - if let crate::multimodal::ImageInput::Url(_url) = image_input { - let (base64_data, mime_type) = image_input.to_base64().await?; - *image_input = crate::multimodal::ImageInput::Base64 { - data: base64_data, - mime_type, - }; - } - } - } - } - Ok(()) - } -} - -impl TryFrom for UnifiedRequest { - type Error = anyhow::Error; - - fn try_from(req: ChatCompletionRequest) -> Result { - let mut has_images = false; - - // Convert OpenAI-compatible 
request to unified format - let messages = req - .messages - .into_iter() - .map(|msg| { - let (content, _images_in_message) = match msg.content { - MessageContent::Text { content } => (vec![ContentPart::Text { text: content }], false), - MessageContent::Parts { content } => { - let mut unified_content = Vec::new(); - let mut has_images_in_msg = false; - - for part in content { - match part { - ContentPartValue::Text { text } => { - unified_content.push(ContentPart::Text { text }); - } - ContentPartValue::ImageUrl { image_url } => { - has_images_in_msg = true; - has_images = true; - unified_content.push(ContentPart::Image(crate::multimodal::ImageInput::from_url( - image_url.url, - ))); - } - } - } - - (unified_content, has_images_in_msg) - } - MessageContent::None => (vec![], false), - }; - - UnifiedMessage { - role: msg.role, - content, - reasoning_content: msg.reasoning_content, - tool_calls: msg.tool_calls, - name: msg.name, - tool_call_id: msg.tool_call_id, - } - }) - .collect(); - - let stop = match req.stop { - Some(Value::String(s)) => Some(vec![s]), - Some(Value::Array(a)) => Some( - a.into_iter() - .filter_map(|v| v.as_str().map(|s| s.to_string())) - .collect(), - ), - _ => None, - }; - - Ok(UnifiedRequest { - client_id: String::new(), // Will be populated by auth middleware - model: req.model, - messages, - temperature: req.temperature, - top_p: req.top_p, - top_k: req.top_k, - n: req.n, - stop, - max_tokens: req.max_tokens, - presence_penalty: req.presence_penalty, - frequency_penalty: req.frequency_penalty, - stream: req.stream.unwrap_or(false), - has_images, - tools: req.tools, - tool_choice: req.tool_choice, - }) - } -} diff --git a/src/models/registry.rs b/src/models/registry.rs deleted file mode 100644 index 40f1430d..00000000 --- a/src/models/registry.rs +++ /dev/null @@ -1,219 +0,0 @@ -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ModelRegistry { - 
#[serde(flatten)] - pub providers: HashMap<String, ProviderInfo>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ProviderInfo { - pub id: String, - pub name: String, - pub models: HashMap<String, ModelMetadata>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ModelMetadata { - pub id: String, - pub name: String, - pub cost: Option<ModelCost>, - pub limit: Option<ModelLimit>, - pub modalities: Option<ModelModalities>, - pub tool_call: Option<bool>, - pub reasoning: Option<bool>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ModelCost { - pub input: f64, - pub output: f64, - pub cache_read: Option<f64>, - pub cache_write: Option<f64>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ModelLimit { - pub context: u32, - pub output: u32, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ModelModalities { - pub input: Vec<String>, - pub output: Vec<String>, -} - -/// A model entry paired with its provider ID, returned by listing/filtering methods. -#[derive(Debug, Clone)] -pub struct ModelEntry<'a> { - pub model_key: &'a str, - pub provider_id: &'a str, - pub provider_name: &'a str, - pub metadata: &'a ModelMetadata, -} - -/// Filter criteria for listing models. All fields are optional; `None` means no filter. -#[derive(Debug, Default, Clone, Deserialize)] -pub struct ModelFilter { - /// Filter by provider ID (exact match). - pub provider: Option<String>, - /// Text search on model ID or name (case-insensitive substring). - pub search: Option<String>, - /// Filter by input modality (e.g. "image", "text"). - pub modality: Option<String>, - /// Only models that support tool calling. - pub tool_call: Option<bool>, - /// Only models that support reasoning. - pub reasoning: Option<bool>, - /// Only models that have pricing data. - pub has_cost: Option<bool>, -} - -/// Sort field for model listings. -#[derive(Debug, Clone, Deserialize, Default, PartialEq)] -#[serde(rename_all = "snake_case")] -pub enum ModelSortBy { - #[default] - Name, - Id, - Provider, - ContextLimit, - InputCost, - OutputCost, -} - -/// Sort direction.
-#[derive(Debug, Clone, Deserialize, Default, PartialEq)] -#[serde(rename_all = "snake_case")] -pub enum SortOrder { - #[default] - Asc, - Desc, -} - -impl ModelRegistry { - /// Find a model by its ID (searching across all providers) - pub fn find_model(&self, model_id: &str) -> Option<&ModelMetadata> { - // First try exact match if the key in models map matches the ID - for provider in self.providers.values() { - if let Some(model) = provider.models.get(model_id) { - return Some(model); - } - } - - // Try searching for the model ID inside the metadata if the key was different - for provider in self.providers.values() { - for model in provider.models.values() { - if model.id == model_id { - return Some(model); - } - } - } - - None - } - - /// List all models with optional filtering and sorting. - pub fn list_models( - &self, - filter: &ModelFilter, - sort_by: &ModelSortBy, - sort_order: &SortOrder, - ) -> Vec> { - let mut entries: Vec> = Vec::new(); - - for (p_id, p_info) in &self.providers { - // Provider filter - if let Some(ref prov) = filter.provider { - if p_id != prov { - continue; - } - } - - for (m_key, m_meta) in &p_info.models { - // Text search filter - if let Some(ref search) = filter.search { - let search_lower = search.to_lowercase(); - if !m_meta.id.to_lowercase().contains(&search_lower) - && !m_meta.name.to_lowercase().contains(&search_lower) - && !m_key.to_lowercase().contains(&search_lower) - { - continue; - } - } - - // Modality filter - if let Some(ref modality) = filter.modality { - let has_modality = m_meta - .modalities - .as_ref() - .is_some_and(|m| m.input.iter().any(|i| i.eq_ignore_ascii_case(modality))); - if !has_modality { - continue; - } - } - - // Tool call filter - if let Some(tc) = filter.tool_call { - if m_meta.tool_call.unwrap_or(false) != tc { - continue; - } - } - - // Reasoning filter - if let Some(r) = filter.reasoning { - if m_meta.reasoning.unwrap_or(false) != r { - continue; - } - } - - // Has cost filter - if let Some(hc) 
= filter.has_cost { - if hc != m_meta.cost.is_some() { - continue; - } - } - - entries.push(ModelEntry { - model_key: m_key, - provider_id: p_id, - provider_name: &p_info.name, - metadata: m_meta, - }); - } - } - - // Sort - entries.sort_by(|a, b| { - let cmp = match sort_by { - ModelSortBy::Name => a.metadata.name.to_lowercase().cmp(&b.metadata.name.to_lowercase()), - ModelSortBy::Id => a.model_key.cmp(b.model_key), - ModelSortBy::Provider => a.provider_id.cmp(b.provider_id), - ModelSortBy::ContextLimit => { - let a_ctx = a.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0); - let b_ctx = b.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0); - a_ctx.cmp(&b_ctx) - } - ModelSortBy::InputCost => { - let a_cost = a.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0); - let b_cost = b.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0); - a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal) - } - ModelSortBy::OutputCost => { - let a_cost = a.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0); - let b_cost = b.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0); - a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal) - } - }; - - match sort_order { - SortOrder::Asc => cmp, - SortOrder::Desc => cmp.reverse(), - } - }); - - entries - } -} diff --git a/src/multimodal/mod.rs b/src/multimodal/mod.rs deleted file mode 100644 index 94fc09fc..00000000 --- a/src/multimodal/mod.rs +++ /dev/null @@ -1,299 +0,0 @@ -//! Multimodal support for image processing and conversion -//! -//! This module handles: -//! 1. Image format detection and conversion -//! 2. Base64 encoding/decoding -//! 3. URL fetching for images -//! 4. Provider-specific image format conversion - -use anyhow::{Context, Result}; -use base64::{Engine as _, engine::general_purpose}; -use std::sync::LazyLock; -use tracing::{info, warn}; - -/// Shared HTTP client for image fetching — avoids creating a new TCP+TLS -/// connection for every image URL. 
-static IMAGE_CLIENT: LazyLock = LazyLock::new(|| { - reqwest::Client::builder() - .connect_timeout(std::time::Duration::from_secs(5)) - .timeout(std::time::Duration::from_secs(30)) - .pool_idle_timeout(std::time::Duration::from_secs(60)) - .build() - .expect("Failed to build image HTTP client") -}); - -/// Supported image formats for multimodal input -#[derive(Debug, Clone)] -pub enum ImageInput { - /// Base64-encoded image data with MIME type - Base64 { data: String, mime_type: String }, - /// URL to fetch image from - Url(String), - /// Raw bytes with MIME type - Bytes { data: Vec, mime_type: String }, -} - -impl ImageInput { - /// Create ImageInput from base64 string - pub fn from_base64(data: String, mime_type: String) -> Self { - Self::Base64 { data, mime_type } - } - - /// Create ImageInput from URL - pub fn from_url(url: String) -> Self { - Self::Url(url) - } - - /// Create ImageInput from raw bytes - pub fn from_bytes(data: Vec, mime_type: String) -> Self { - Self::Bytes { data, mime_type } - } - - /// Get MIME type if available - pub fn mime_type(&self) -> Option<&str> { - match self { - Self::Base64 { mime_type, .. } => Some(mime_type), - Self::Bytes { mime_type, .. 
} => Some(mime_type), - Self::Url(_) => None, - } - } - - /// Convert to base64 if not already - pub async fn to_base64(&self) -> Result<(String, String)> { - match self { - Self::Base64 { data, mime_type } => Ok((data.clone(), mime_type.clone())), - Self::Bytes { data, mime_type } => { - let base64_data = general_purpose::STANDARD.encode(data); - Ok((base64_data, mime_type.clone())) - } - Self::Url(url) => { - // Fetch image from URL using shared client - info!("Fetching image from URL: {}", url); - let response = IMAGE_CLIENT - .get(url) - .send() - .await - .context("Failed to fetch image from URL")?; - - if !response.status().is_success() { - anyhow::bail!("Failed to fetch image: HTTP {}", response.status()); - } - - let mime_type = response - .headers() - .get(reqwest::header::CONTENT_TYPE) - .and_then(|h| h.to_str().ok()) - .unwrap_or("image/jpeg") - .to_string(); - - let bytes = response.bytes().await.context("Failed to read image bytes")?; - - let base64_data = general_purpose::STANDARD.encode(&bytes); - Ok((base64_data, mime_type)) - } - } - } - - /// Get image dimensions (width, height) - pub async fn get_dimensions(&self) -> Result<(u32, u32)> { - let bytes = match self { - Self::Base64 { data, .. } => general_purpose::STANDARD - .decode(data) - .context("Failed to decode base64")?, - Self::Bytes { data, .. } => data.clone(), - Self::Url(_) => { - let (base64_data, _) = self.to_base64().await?; - general_purpose::STANDARD - .decode(&base64_data) - .context("Failed to decode base64")? 
- } - }; - - let img = image::load_from_memory(&bytes).context("Failed to load image from bytes")?; - Ok((img.width(), img.height())) - } - - /// Validate image size and format - pub async fn validate(&self, max_size_mb: f64) -> Result<()> { - let (width, height) = self.get_dimensions().await?; - - // Check dimensions - if width > 4096 || height > 4096 { - warn!("Image dimensions too large: {}x{}", width, height); - // Continue anyway, but log warning - } - - // Check file size - let size_bytes = match self { - Self::Base64 { data, .. } => { - // Base64 size is ~4/3 of original - (data.len() as f64 * 0.75) as usize - } - Self::Bytes { data, .. } => data.len(), - Self::Url(_) => { - // For URLs, we'd need to fetch to check size - // Skip size check for URLs for now - return Ok(()); - } - }; - - let size_mb = size_bytes as f64 / (1024.0 * 1024.0); - if size_mb > max_size_mb { - anyhow::bail!("Image too large: {:.2}MB > {:.2}MB limit", size_mb, max_size_mb); - } - - Ok(()) - } -} - -/// Provider-specific image format conversion -pub struct ImageConverter; - -impl ImageConverter { - /// Convert image to OpenAI-compatible format - pub async fn to_openai_format(image: &ImageInput) -> Result { - let (base64_data, mime_type) = image.to_base64().await?; - - // OpenAI expects data URL format: "data:image/jpeg;base64,{data}" - let data_url = format!("data:{};base64,{}", mime_type, base64_data); - - Ok(serde_json::json!({ - "type": "image_url", - "image_url": { - "url": data_url, - "detail": "auto" // Can be "low", "high", or "auto" - } - })) - } - - /// Convert image to Gemini-compatible format - pub async fn to_gemini_format(image: &ImageInput) -> Result { - let (base64_data, mime_type) = image.to_base64().await?; - - // Gemini expects inline data format - Ok(serde_json::json!({ - "inline_data": { - "mime_type": mime_type, - "data": base64_data - } - })) - } - - /// Convert image to DeepSeek-compatible format - pub async fn to_deepseek_format(image: &ImageInput) -> Result { 
- // DeepSeek uses OpenAI-compatible format for vision models - Self::to_openai_format(image).await - } - - /// Detect if a model supports multimodal input - pub fn model_supports_multimodal(model: &str) -> bool { - // OpenAI vision models - if (model.starts_with("gpt-4") && (model.contains("vision") || model.contains("-v") || model.contains("4o"))) - || model.starts_with("o1-") - || model.starts_with("o3-") - { - return true; - } - - // Gemini vision models - if model.starts_with("gemini") { - // Most Gemini models support vision - return true; - } - - // DeepSeek vision models - if model.starts_with("deepseek-vl") { - return true; - } - - false - } -} - -/// Parse OpenAI-compatible multimodal message content -pub fn parse_openai_content(content: &serde_json::Value) -> Result)>> { - let mut parts = Vec::new(); - - if let Some(content_str) = content.as_str() { - // Simple text content - parts.push((content_str.to_string(), None)); - } else if let Some(content_array) = content.as_array() { - // Array of content parts (text and/or images) - for part in content_array { - if let Some(part_obj) = part.as_object() - && let Some(part_type) = part_obj.get("type").and_then(|t| t.as_str()) - { - match part_type { - "text" => { - if let Some(text) = part_obj.get("text").and_then(|t| t.as_str()) { - parts.push((text.to_string(), None)); - } - } - "image_url" => { - if let Some(image_url_obj) = part_obj.get("image_url").and_then(|o| o.as_object()) - && let Some(url) = image_url_obj.get("url").and_then(|u| u.as_str()) - { - if url.starts_with("data:") { - // Parse data URL - if let Some((mime_type, data)) = parse_data_url(url) { - let image_input = ImageInput::from_base64(data, mime_type); - parts.push(("".to_string(), Some(image_input))); - } - } else { - // Regular URL - let image_input = ImageInput::from_url(url.to_string()); - parts.push(("".to_string(), Some(image_input))); - } - } - } - _ => { - warn!("Unknown content part type: {}", part_type); - } - } - } - } - } - - 
Ok(parts) -} - -/// Parse data URL (data:image/jpeg;base64,{data}) -fn parse_data_url(data_url: &str) -> Option<(String, String)> { - if !data_url.starts_with("data:") { - return None; - } - - let parts: Vec<&str> = data_url[5..].split(";base64,").collect(); - if parts.len() != 2 { - return None; - } - - let mime_type = parts[0].to_string(); - let data = parts[1].to_string(); - - Some((mime_type, data)) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_parse_data_url() { - let test_url = "data:image/jpeg;base64,SGVsbG8gV29ybGQ="; // "Hello World" in base64 - let (mime_type, data) = parse_data_url(test_url).unwrap(); - - assert_eq!(mime_type, "image/jpeg"); - assert_eq!(data, "SGVsbG8gV29ybGQ="); - } - - #[tokio::test] - async fn test_model_supports_multimodal() { - assert!(ImageConverter::model_supports_multimodal("gpt-4-vision-preview")); - assert!(ImageConverter::model_supports_multimodal("gpt-4o")); - assert!(ImageConverter::model_supports_multimodal("gemini-pro-vision")); - assert!(ImageConverter::model_supports_multimodal("gemini-pro")); - assert!(!ImageConverter::model_supports_multimodal("gpt-3.5-turbo")); - assert!(!ImageConverter::model_supports_multimodal("claude-3-opus")); - } -} diff --git a/src/providers/deepseek.rs b/src/providers/deepseek.rs deleted file mode 100644 index 12b94efb..00000000 --- a/src/providers/deepseek.rs +++ /dev/null @@ -1,268 +0,0 @@ -use anyhow::Result; -use async_trait::async_trait; -use futures::stream::BoxStream; -use futures::StreamExt; - -use super::helpers; -use super::{ProviderResponse, ProviderStreamChunk}; -use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest}; - -pub struct DeepSeekProvider { - client: reqwest::Client, - config: crate::config::DeepSeekConfig, - api_key: String, - pricing: Vec, -} - -impl DeepSeekProvider { - pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result { - let api_key = app_config.get_api_key("deepseek")?; - 
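The deleted `parse_data_url` helper above splits a `data:` URL at the `;base64,` marker into a MIME type and payload. A std-only sketch of the same contract (the free-function form here is illustrative; it uses `split_once` rather than the original's `split`/`collect`):

```rust
/// Split a data URL of the form "data:<mime>;base64,<payload>" into
/// (mime_type, payload). Returns None for anything that is not a
/// base64 data URL, mirroring the deleted parse_data_url helper.
fn parse_data_url(data_url: &str) -> Option<(String, String)> {
    let rest = data_url.strip_prefix("data:")?;
    let (mime_type, data) = rest.split_once(";base64,")?;
    Some((mime_type.to_string(), data.to_string()))
}
```

One behavioral difference worth noting: `split_once` keeps everything after the first `;base64,` marker, whereas the original's two-part `split` check rejected payloads that happened to contain the marker again.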
Self::new_with_key(config, app_config, api_key) - } - - pub fn new_with_key( - config: &crate::config::DeepSeekConfig, - app_config: &AppConfig, - api_key: String, - ) -> Result { - let client = reqwest::Client::builder() - .connect_timeout(std::time::Duration::from_secs(5)) - .timeout(std::time::Duration::from_secs(300)) - .pool_idle_timeout(std::time::Duration::from_secs(90)) - .pool_max_idle_per_host(4) - .tcp_keepalive(std::time::Duration::from_secs(30)) - .build()?; - - Ok(Self { - client, - config: config.clone(), - api_key, - pricing: app_config.pricing.deepseek.clone(), - }) - } -} - -#[async_trait] -impl super::Provider for DeepSeekProvider { - fn name(&self) -> &str { - "deepseek" - } - - fn supports_model(&self, model: &str) -> bool { - model.starts_with("deepseek-") || model.contains("deepseek") - } - - fn supports_multimodal(&self) -> bool { - false - } - - async fn chat_completion(&self, request: UnifiedRequest) -> Result { - let messages_json = helpers::messages_to_openai_json(&request.messages).await?; - let mut body = helpers::build_openai_body(&request, messages_json, false); - - // Sanitize and fix for deepseek-reasoner (R1) - if request.model == "deepseek-reasoner" { - if let Some(obj) = body.as_object_mut() { - // Remove unsupported parameters - obj.remove("temperature"); - obj.remove("top_p"); - obj.remove("presence_penalty"); - obj.remove("frequency_penalty"); - obj.remove("logit_bias"); - obj.remove("logprobs"); - obj.remove("top_logprobs"); - - // ENSURE: EVERY assistant message must have reasoning_content and valid content - if let Some(messages) = obj.get_mut("messages").and_then(|m| m.as_array_mut()) { - for m in messages { - if m["role"].as_str() == Some("assistant") { - // DeepSeek R1 requires reasoning_content for consistency in history. 
- if m.get("reasoning_content").is_none() || m["reasoning_content"].is_null() { - m["reasoning_content"] = serde_json::json!(" "); - } - // DeepSeek R1 often requires content to be a string, not null/array - if m.get("content").is_none() || m["content"].is_null() || m["content"].is_array() { - m["content"] = serde_json::json!(""); - } - } - } - } - } - } - - let response = self - .client - .post(format!("{}/chat/completions", self.config.base_url)) - .header("Authorization", format!("Bearer {}", self.api_key)) - .json(&body) - .send() - .await - .map_err(|e| AppError::ProviderError(e.to_string()))?; - - if !response.status().is_success() { - let status = response.status(); - let error_text = response.text().await.unwrap_or_default(); - tracing::error!("DeepSeek API error ({}): {}", status, error_text); - tracing::error!("Offending DeepSeek Request Body: {}", serde_json::to_string(&body).unwrap_or_default()); - return Err(AppError::ProviderError(format!("DeepSeek API error ({}): {}", status, error_text))); - } - - let resp_json: serde_json::Value = response - .json() - .await - .map_err(|e| AppError::ProviderError(e.to_string()))?; - - helpers::parse_openai_response(&resp_json, request.model) - } - - fn estimate_tokens(&self, request: &UnifiedRequest) -> Result { - Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request)) - } - - fn calculate_cost( - &self, - model: &str, - prompt_tokens: u32, - completion_tokens: u32, - cache_read_tokens: u32, - cache_write_tokens: u32, - registry: &crate::models::registry::ModelRegistry, - ) -> f64 { - if let Some(metadata) = registry.find_model(model) { - if metadata.cost.is_some() { - return helpers::calculate_cost_with_registry( - model, - prompt_tokens, - completion_tokens, - cache_read_tokens, - cache_write_tokens, - registry, - &self.pricing, - 0.28, - 0.42, - ); - } - } - - // Custom DeepSeek fallback that correctly handles cache hits - let (prompt_rate, completion_rate) = self - .pricing - .iter() - 
.find(|p| model.contains(&p.model)) - .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million)) - .unwrap_or((0.28, 0.42)); // Default to DeepSeek's current API pricing - - let cache_hit_rate = prompt_rate / 10.0; - let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens); - - (non_cached_prompt as f64 * prompt_rate / 1_000_000.0) - + (cache_read_tokens as f64 * cache_hit_rate / 1_000_000.0) - + (completion_tokens as f64 * completion_rate / 1_000_000.0) - } - - async fn chat_completion_stream( - &self, - request: UnifiedRequest, - ) -> Result>, AppError> { - // DeepSeek doesn't support images in streaming, use text-only - let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?; - let mut body = helpers::build_openai_body(&request, messages_json, true); - - // Sanitize and fix for deepseek-reasoner (R1) - if request.model == "deepseek-reasoner" { - if let Some(obj) = body.as_object_mut() { - // Keep stream_options if present (DeepSeek supports include_usage) - - // Remove unsupported parameters - obj.remove("temperature"); - - obj.remove("top_p"); - obj.remove("presence_penalty"); - obj.remove("frequency_penalty"); - obj.remove("logit_bias"); - obj.remove("logprobs"); - obj.remove("top_logprobs"); - - // ENSURE: EVERY assistant message must have reasoning_content and valid content - if let Some(messages) = obj.get_mut("messages").and_then(|m| m.as_array_mut()) { - for m in messages { - if m["role"].as_str() == Some("assistant") { - // DeepSeek R1 requires reasoning_content for consistency in history. 
- if m.get("reasoning_content").is_none() || m["reasoning_content"].is_null() { - m["reasoning_content"] = serde_json::json!(" "); - } - // DeepSeek R1 often requires content to be a string, not null/array - if m.get("content").is_none() || m["content"].is_null() || m["content"].is_array() { - m["content"] = serde_json::json!(""); - } - } - } - } - } - } - - let url = format!("{}/chat/completions", self.config.base_url); - let api_key = self.api_key.clone(); - let probe_client = self.client.clone(); - let probe_body = body.clone(); - let model = request.model.clone(); - - let es = reqwest_eventsource::EventSource::new( - self.client - .post(&url) - .header("Authorization", format!("Bearer {}", self.api_key)) - .json(&body), - ) - .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; - - let stream = async_stream::try_stream! { - let mut es = es; - while let Some(event) = es.next().await { - match event { - Ok(reqwest_eventsource::Event::Message(msg)) => { - if msg.data == "[DONE]" { - break; - } - - let chunk: serde_json::Value = serde_json::from_str(&msg.data) - .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?; - - if let Some(p_chunk) = helpers::parse_openai_stream_chunk(&chunk, &model, None) { - yield p_chunk?; - } - } - Ok(_) => continue, - Err(e) => { - // Attempt to probe for the actual error body - let probe_resp = probe_client - .post(&url) - .header("Authorization", format!("Bearer {}", api_key)) - .json(&probe_body) - .send() - .await; - - match probe_resp { - Ok(r) if !r.status().is_success() => { - let status = r.status(); - let error_body = r.text().await.unwrap_or_default(); - tracing::error!("DeepSeek Stream Error Probe ({}): {}", status, error_body); - // Log the offending request body at ERROR level so it shows up in standard logs - tracing::error!("Offending DeepSeek Request Body: {}", serde_json::to_string(&probe_body).unwrap_or_default()); - 
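The fallback pricing path in `calculate_cost` above bills cached prompt tokens at one tenth of the prompt rate. A worked, std-only sketch of that arithmetic (the function name is illustrative; the 0.28/0.42 per-million rates are the defaults from the deleted code):

```rust
/// Cache-aware cost fallback: cached prompt tokens are billed at
/// prompt_rate / 10, mirroring the deleted DeepSeek calculate_cost path.
fn deepseek_fallback_cost(
    prompt_tokens: u32,
    completion_tokens: u32,
    cache_read_tokens: u32,
) -> f64 {
    let (prompt_rate, completion_rate) = (0.28_f64, 0.42_f64); // $ per 1M tokens
    let cache_hit_rate = prompt_rate / 10.0;
    // Tokens served from cache are subtracted from the full-rate bucket.
    let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);

    (non_cached_prompt as f64 * prompt_rate / 1_000_000.0)
        + (cache_read_tokens as f64 * cache_hit_rate / 1_000_000.0)
        + (completion_tokens as f64 * completion_rate / 1_000_000.0)
}
```

For example, 1M prompt tokens with 400k cache hits and 100k completion tokens cost 0.6 × $0.28 + 0.4 × $0.028 + 0.1 × $0.42 = $0.2212.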
Err(AppError::ProviderError(format!("DeepSeek API error ({}): {}", status, error_body)))?; - } - Ok(_) => { - Err(AppError::ProviderError(format!("Stream error (probe returned 200): {}", e)))?; - } - Err(probe_err) => { - tracing::error!("DeepSeek Stream Error Probe failed: {}", probe_err); - Err(AppError::ProviderError(format!("Stream error (probe failed: {}): {}", probe_err, e)))?; - } - } - } - } - } - }; - - Ok(Box::pin(stream)) - } -} diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs deleted file mode 100644 index cc2ef10b..00000000 --- a/src/providers/gemini.rs +++ /dev/null @@ -1,1062 +0,0 @@ -use anyhow::Result; -use async_trait::async_trait; -use futures::stream::{BoxStream, StreamExt}; -use reqwest_eventsource::Event; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use uuid::Uuid; - -use super::{ProviderResponse, ProviderStreamChunk}; -use crate::{ - config::AppConfig, - errors::AppError, - models::{ContentPart, FunctionCall, FunctionCallDelta, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest}, -}; - -// ========== Gemini Request Structs ========== - -#[derive(Debug, Clone, Serialize)] -#[serde(rename_all = "camelCase")] -struct GeminiRequest { - contents: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - system_instruction: Option, - #[serde(skip_serializing_if = "Option::is_none")] - generation_config: Option, - #[serde(skip_serializing_if = "Option::is_none")] - tools: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - tool_config: Option, - #[serde(skip_serializing_if = "Option::is_none")] - safety_settings: Option>, -} - -#[derive(Debug, Clone, Serialize)] -struct GeminiSafetySetting { - category: String, - threshold: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct GeminiContent { - parts: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - role: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -struct GeminiPart { - 
#[serde(skip_serializing_if = "Option::is_none")] - text: Option, - #[serde(skip_serializing_if = "Option::is_none")] - inline_data: Option, - #[serde(skip_serializing_if = "Option::is_none")] - function_call: Option, - #[serde(skip_serializing_if = "Option::is_none")] - function_response: Option, - #[serde(skip_serializing_if = "Option::is_none")] - thought: Option, - #[serde(skip_serializing_if = "Option::is_none", rename = "thought_signature")] - thought_signature_snake: Option, - #[serde(skip_serializing_if = "Option::is_none", rename = "thoughtSignature")] - thought_signature_camel: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct GeminiInlineData { - mime_type: String, - data: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -struct GeminiFunctionCall { - name: String, - args: Value, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct GeminiFunctionResponse { - name: String, - response: Value, -} - - -#[derive(Debug, Clone, Serialize)] -#[serde(rename_all = "camelCase")] -struct GeminiGenerationConfig { - #[serde(skip_serializing_if = "Option::is_none")] - temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] - top_p: Option, - #[serde(skip_serializing_if = "Option::is_none")] - top_k: Option, - #[serde(skip_serializing_if = "Option::is_none")] - max_output_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - stop_sequences: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - candidate_count: Option, -} - -// ========== Gemini Tool Structs ========== - -#[derive(Debug, Clone, Serialize)] -#[serde(rename_all = "camelCase")] -struct GeminiTool { - function_declarations: Vec, -} - -#[derive(Debug, Clone, Serialize)] -struct GeminiFunctionDeclaration { - name: String, - #[serde(skip_serializing_if = "Option::is_none")] - description: Option, - #[serde(skip_serializing_if = "Option::is_none")] - parameters: Option, -} - -#[derive(Debug, 
Clone, Serialize)] -#[serde(rename_all = "camelCase")] -struct GeminiToolConfig { - function_calling_config: GeminiFunctionCallingConfig, -} - -#[derive(Debug, Clone, Serialize)] -struct GeminiFunctionCallingConfig { - mode: String, - #[serde(skip_serializing_if = "Option::is_none", rename = "allowedFunctionNames")] - allowed_function_names: Option>, -} - -// ========== Gemini Response Structs ========== - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -struct GeminiCandidate { - content: GeminiContent, - #[serde(default)] - #[allow(dead_code)] - finish_reason: Option, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -struct GeminiUsageMetadata { - #[serde(default)] - prompt_token_count: u32, - #[serde(default)] - candidates_token_count: u32, - #[serde(default)] - total_token_count: u32, - #[serde(default)] - cached_content_token_count: u32, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -struct GeminiResponse { - candidates: Vec, - usage_metadata: Option, -} - -// Streaming responses from Gemini may include messages without `candidates` (e.g. promptFeedback). -// Use a more permissive struct for streaming to avoid aborting the SSE prematurely. 
-#[derive(Debug, Deserialize, Default)] -#[serde(rename_all = "camelCase")] -struct GeminiStreamResponse { - #[serde(default)] - candidates: Vec, - #[serde(default)] - usage_metadata: Option, -} - -#[derive(Debug, Deserialize, Default)] -#[serde(rename_all = "camelCase")] -struct GeminiStreamCandidate { - #[serde(default)] - content: Option, - #[serde(default)] - finish_reason: Option, -} - -// ========== Provider Implementation ========== - -pub struct GeminiProvider { - client: reqwest::Client, - config: crate::config::GeminiConfig, - api_key: String, - pricing: Vec, -} - -impl GeminiProvider { - pub fn new(config: &crate::config::GeminiConfig, app_config: &AppConfig) -> Result { - let api_key = app_config.get_api_key("gemini")?; - Self::new_with_key(config, app_config, api_key) - } - - pub fn new_with_key(config: &crate::config::GeminiConfig, app_config: &AppConfig, api_key: String) -> Result { - let client = reqwest::Client::builder() - .connect_timeout(std::time::Duration::from_secs(5)) - .timeout(std::time::Duration::from_secs(300)) - .pool_idle_timeout(std::time::Duration::from_secs(90)) - .pool_max_idle_per_host(4) - .tcp_keepalive(std::time::Duration::from_secs(30)) - .build()?; - - Ok(Self { - client, - config: config.clone(), - api_key, - pricing: app_config.pricing.gemini.clone(), - }) - } - - /// Convert unified messages to Gemini content format. - /// Handles text, images, tool calls (assistant), and tool results. 
- /// Returns (contents, system_instruction) - async fn convert_messages( - messages: Vec, - ) -> Result<(Vec, Option), AppError> { - let mut contents: Vec = Vec::new(); - let mut system_parts = Vec::new(); - - // PRE-PASS: Build tool_id -> function_name mapping for tool responses - let mut tool_id_to_name = std::collections::HashMap::new(); - for msg in &messages { - if let Some(tool_calls) = &msg.tool_calls { - for tc in tool_calls { - tool_id_to_name.insert(tc.id.clone(), tc.function.name.clone()); - } - } - } - - for msg in messages { - if msg.role == "system" { - for part in msg.content { - if let ContentPart::Text { text } = part { - if !text.trim().is_empty() { - system_parts.push(GeminiPart { - text: Some(text), - inline_data: None, - function_call: None, - function_response: None, - thought: None, - thought_signature_snake: None, - thought_signature_camel: None, - }); - } - } - } - continue; - } - - let role = match msg.role.as_str() { - "assistant" => "model".to_string(), - "tool" => "user".to_string(), // Tool results are user-side in Gemini - _ => "user".to_string(), - }; - - let mut parts = Vec::new(); - - // Handle tool results (role "tool") - if msg.role == "tool" { - let text_content = msg - .content - .first() - .map(|p| match p { - ContentPart::Text { text } => text.clone(), - ContentPart::Image(_) => "[Image]".to_string(), - }) - .unwrap_or_default(); - - // RESOLVE: Use msg.name if present, otherwise look up by tool_call_id - let name = msg.name.clone() - .or_else(|| { - msg.tool_call_id.as_ref() - .and_then(|id| tool_id_to_name.get(id).cloned()) - }) - .or_else(|| msg.tool_call_id.clone()) - .unwrap_or_else(|| "unknown_function".to_string()); - - // Gemini API requires 'response' to be a JSON object (google.protobuf.Struct). - // If it is an array or primitive, wrap it in an object. 
- let mut response_value = serde_json::from_str::(&text_content) - .unwrap_or_else(|_| serde_json::json!({ "result": text_content })); - - if !response_value.is_object() { - response_value = serde_json::json!({ "result": response_value }); - } - - parts.push(GeminiPart { - text: None, - inline_data: None, - function_call: None, - function_response: Some(GeminiFunctionResponse { - name, - response: response_value, - }), - thought: None, - thought_signature_snake: None, - thought_signature_camel: None, - }); - } else if msg.role == "assistant" { - // Assistant messages: handle text, thought (reasoning), and tool_calls - for p in &msg.content { - if let ContentPart::Text { text } = p { - if !text.trim().is_empty() { - parts.push(GeminiPart { - text: Some(text.clone()), - inline_data: None, - function_call: None, - function_response: None, - thought: None, - thought_signature_snake: None, - thought_signature_camel: None, - }); - } - } - } - - // If reasoning_content is present, include it as a 'thought' part - if let Some(reasoning) = &msg.reasoning_content { - if !reasoning.trim().is_empty() { - parts.push(GeminiPart { - text: None, - inline_data: None, - function_call: None, - function_response: None, - thought: Some(reasoning.clone()), - thought_signature_snake: None, - thought_signature_camel: None, - }); - } - } - - if let Some(tool_calls) = &msg.tool_calls { - for tc in tool_calls { - let args = serde_json::from_str::(&tc.function.arguments) - .unwrap_or_else(|_| serde_json::json!({})); - - // RESTORE: Only use tc.id as thought_signature if it's NOT a synthetic ID. - // Synthetic IDs (starting with 'call_') cause 400 errors as they are not valid Base64 for the TYPE_BYTES field. 
- let thought_signature = if tc.id.starts_with("call_") { - None - } else { - Some(tc.id.clone()) - }; - - parts.push(GeminiPart { - text: None, - inline_data: None, - function_call: Some(GeminiFunctionCall { - name: tc.function.name.clone(), - args, - }), - function_response: None, - thought: None, - thought_signature_snake: thought_signature.clone(), - thought_signature_camel: thought_signature, - }); - } - } - } else { - // Regular text/image messages (mostly user) - for part in msg.content { - match part { - ContentPart::Text { text } => { - if !text.trim().is_empty() { - parts.push(GeminiPart { - text: Some(text), - inline_data: None, - function_call: None, - function_response: None, - thought: None, - thought_signature_snake: None, - thought_signature_camel: None, - }); - } - } - ContentPart::Image(image_input) => { - let (base64_data, mime_type) = image_input - .to_base64() - .await - .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?; - - parts.push(GeminiPart { - text: None, - inline_data: Some(GeminiInlineData { - mime_type, - data: base64_data, - }), - function_call: None, - function_response: None, - thought: None, - thought_signature_snake: None, - thought_signature_camel: None, - }); - } - } - } - } - - if parts.is_empty() { - continue; - } - - // STRATEGY: Strictly enforce alternating roles. - if let Some(last_content) = contents.last_mut() { - if last_content.role.as_ref() == Some(&role) { - last_content.parts.extend(parts); - continue; - } - } - - contents.push(GeminiContent { - parts, - role: Some(role), - }); - } - - // Gemini requires the first message to be from "user". 
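The merge step above exists because Gemini rejects histories where two consecutive `contents` entries share a role. A minimal std-only sketch of that merge (simplified to string parts; the `Content` type here is illustrative, standing in for `GeminiContent`):

```rust
#[derive(Debug, PartialEq)]
struct Content {
    role: String,
    parts: Vec<String>,
}

/// Collapse consecutive same-role messages into one entry, as the
/// deleted convert_messages loop does before sending to Gemini.
fn merge_alternating(messages: Vec<(String, String)>) -> Vec<Content> {
    let mut contents: Vec<Content> = Vec::new();
    for (role, part) in messages {
        // If the previous entry has the same role, append instead of pushing.
        if let Some(last) = contents.last_mut() {
            if last.role == role {
                last.parts.push(part);
                continue;
            }
        }
        contents.push(Content { role, parts: vec![part] });
    }
    contents
}
```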
- if let Some(first) = contents.first() { - if first.role.as_deref() == Some("model") { - contents.insert(0, GeminiContent { - role: Some("user".to_string()), - parts: vec![GeminiPart { - text: Some("Continue conversation.".to_string()), - inline_data: None, - function_call: None, - function_response: None, - thought: None, - thought_signature_snake: None, - thought_signature_camel: None, - }], - }); - } - } - - // Final check: ensure we don't have empty contents after filtering. - if contents.is_empty() && system_parts.is_empty() { - return Err(AppError::ProviderError("No valid content parts after filtering".to_string())); - } - - let system_instruction = if !system_parts.is_empty() { - Some(GeminiContent { - parts: system_parts, - role: None, - }) - } else { - None - }; - - Ok((contents, system_instruction)) - } - - /// Convert OpenAI tools to Gemini function declarations. - fn convert_tools(request: &UnifiedRequest) -> Option> { - request.tools.as_ref().map(|tools| { - let declarations: Vec = tools - .iter() - .map(|t| { - let mut parameters = t.function.parameters.clone().unwrap_or(serde_json::json!({ - "type": "object", - "properties": {} - })); - Self::sanitize_schema(&mut parameters); - - GeminiFunctionDeclaration { - name: t.function.name.clone(), - description: t.function.description.clone(), - parameters: Some(parameters), - } - }) - .collect(); - vec![GeminiTool { - function_declarations: declarations, - }] - }) - } - - /// Recursively remove unsupported JSON Schema fields that Gemini's API rejects. 
- fn sanitize_schema(value: &mut Value) { - if let Value::Object(map) = value { - // Remove unsupported fields at this level - map.remove("$schema"); - map.remove("additionalProperties"); - map.remove("exclusiveMaximum"); - map.remove("exclusiveMinimum"); - - // Recursively sanitize all object properties - if let Some(properties) = map.get_mut("properties") { - if let Value::Object(props_map) = properties { - for prop_value in props_map.values_mut() { - Self::sanitize_schema(prop_value); - } - } - } - - // Recursively sanitize array items - if let Some(items) = map.get_mut("items") { - Self::sanitize_schema(items); - } - - // Gemini 1.5/2.0+ supports anyOf in some contexts, but it's often - // the source of additionalProperties errors when nested. - if let Some(any_of) = map.get_mut("anyOf") { - if let Value::Array(arr) = any_of { - for item in arr { - Self::sanitize_schema(item); - } - } - } - if let Some(one_of) = map.get_mut("oneOf") { - if let Value::Array(arr) = one_of { - for item in arr { - Self::sanitize_schema(item); - } - } - } - if let Some(all_of) = map.get_mut("allOf") { - if let Value::Array(arr) = all_of { - for item in arr { - Self::sanitize_schema(item); - } - } - } - } - } - - /// Convert OpenAI tool_choice to Gemini tool_config. - fn convert_tool_config(request: &UnifiedRequest) -> Option { - request.tool_choice.as_ref().map(|tc| { - let (mode, allowed_names) = match tc { - crate::models::ToolChoice::Mode(mode) => { - let gemini_mode = match mode.as_str() { - "auto" => "AUTO", - "none" => "NONE", - "required" => "ANY", - _ => "AUTO", - }; - (gemini_mode.to_string(), None) - } - crate::models::ToolChoice::Specific(specific) => { - ("ANY".to_string(), Some(vec![specific.function.name.clone()])) - } - }; - GeminiToolConfig { - function_calling_config: GeminiFunctionCallingConfig { - mode, - allowed_function_names: allowed_names, - }, - } - }) - } - - /// Extract tool calls from Gemini response parts into OpenAI-format ToolCalls. 
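`sanitize_schema` above strips JSON Schema fields Gemini rejects and recurses through `properties` and `items`. A std-only sketch of the same idea; since `serde_json` is not pulled in here, a hand-rolled `Json` enum stands in for `serde_json::Value` (the enum and its variants are assumptions of this sketch):

```rust
use std::collections::BTreeMap;

/// Minimal JSON stand-in for serde_json::Value, for illustration only.
#[derive(Debug, Clone, PartialEq)]
enum Json {
    Object(BTreeMap<String, Json>),
    Array(Vec<Json>),
    Bool(bool),
    Str(String),
}

/// Remove schema fields Gemini's API rejects, recursing into nested
/// property and item schemas like the deleted sanitize_schema.
fn sanitize_schema(value: &mut Json) {
    if let Json::Object(map) = value {
        for key in ["$schema", "additionalProperties", "exclusiveMaximum", "exclusiveMinimum"] {
            map.remove(key);
        }
        if let Some(Json::Object(props)) = map.get_mut("properties") {
            for prop in props.values_mut() {
                sanitize_schema(prop);
            }
        }
        if let Some(items) = map.get_mut("items") {
            sanitize_schema(items);
        }
    }
}
```

The real implementation also descends into `anyOf`/`oneOf`/`allOf` arrays, which this sketch omits for brevity.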
- fn extract_tool_calls(parts: &[GeminiPart]) -> Option> { - let calls: Vec = parts - .iter() - .filter(|p| p.function_call.is_some()) - .map(|p| { - let fc = p.function_call.as_ref().unwrap(); - // CAPTURE: Try extracting thought_signature from sibling fields - let id = p.thought_signature_camel.clone() - .or_else(|| p.thought_signature_snake.clone()) - .unwrap_or_else(|| format!("call_{}", Uuid::new_v4().simple())); - - ToolCall { - id, - call_type: "function".to_string(), - function: FunctionCall { - name: fc.name.clone(), - arguments: serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string()), - }, - } - }) - .collect(); - - if calls.is_empty() { None } else { Some(calls) } - } - - /// Determine the appropriate base URL for the model. - /// "preview" models often require the v1beta endpoint, but newer promoted ones may be on v1. - fn get_base_url(&self, model: &str) -> String { - let base = &self.config.base_url; - - // If the model requires v1beta but the base is currently v1 - if (model.contains("preview") || model.contains("thinking") || model.contains("gemini-3")) && base.ends_with("/v1") { - return base.replace("/v1", "/v1beta"); - } - - // If the model is a standard model but the base is v1beta, we could downgrade it, - // but typically v1beta is a superset, so we just return the base as configured. - base.clone() - } - - /// Default safety settings to avoid blocking responses. 
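The `get_base_url` heuristic above routes preview, thinking, and gemini-3 models to the `v1beta` endpoint when the configured base ends in `/v1`. A std-only sketch of that check as a free function (the name `gemini_base_url` is illustrative):

```rust
/// Mirror of the deleted get_base_url heuristic: preview/thinking/gemini-3
/// models need v1beta, so upgrade a base URL that still ends in /v1.
fn gemini_base_url(base: &str, model: &str) -> String {
    let needs_beta =
        model.contains("preview") || model.contains("thinking") || model.contains("gemini-3");
    if needs_beta && base.ends_with("/v1") {
        // Faithful to the original: a plain substring replace, which is
        // safe here because the base is checked to end in "/v1".
        return base.replace("/v1", "/v1beta");
    }
    base.to_string()
}
```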
- fn get_safety_settings(&self, base_url: &str) -> Vec { - let mut categories = vec![ - "HARM_CATEGORY_HARASSMENT", - "HARM_CATEGORY_HATE_SPEECH", - "HARM_CATEGORY_SEXUALLY_EXPLICIT", - "HARM_CATEGORY_DANGEROUS_CONTENT", - ]; - - // Civic integrity is only available in v1beta - if base_url.contains("v1beta") { - categories.push("HARM_CATEGORY_CIVIC_INTEGRITY"); - } - - categories - .into_iter() - .map(|c| GeminiSafetySetting { - category: c.to_string(), - threshold: "BLOCK_NONE".to_string(), - }) - .collect() - } -} - -#[async_trait] -impl super::Provider for GeminiProvider { - fn name(&self) -> &str { - "gemini" - } - - fn supports_model(&self, model: &str) -> bool { - model.starts_with("gemini-") - } - - fn supports_multimodal(&self) -> bool { - true // Gemini supports vision - } - - async fn chat_completion(&self, request: UnifiedRequest) -> Result { - let mut model = request.model.clone(); - - // Normalize model name: If it's a known Gemini model version, use it; - // otherwise, if it starts with gemini- but is an unknown legacy version, - // fallback to the default model to avoid 400 errors. - // We now allow gemini-3+ as valid versions. 
-        let is_known_version = model.starts_with("gemini-1.5") ||
-            model.starts_with("gemini-2.0") ||
-            model.starts_with("gemini-2.5") ||
-            model.starts_with("gemini-3");
-
-        if !is_known_version && model.starts_with("gemini-") {
-            tracing::info!("Mapping unknown Gemini model {} to default {}", model, self.config.default_model);
-            model = self.config.default_model.clone();
-        }
-
-        let tools = Self::convert_tools(&request);
-        let tool_config = Self::convert_tool_config(&request);
-        let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
-
-        if contents.is_empty() && system_instruction.is_none() {
-            return Err(AppError::ProviderError("No valid messages to send".to_string()));
-        }
-
-        let base_url = self.get_base_url(&model);
-
-        // Sanitize stop sequences: Gemini rejects empty strings
-        let stop_sequences = request.stop.map(|s| {
-            s.into_iter()
-                .filter(|seq| !seq.is_empty())
-                .collect::<Vec<String>>()
-        }).filter(|s| !s.is_empty());
-
-        let generation_config = Some(GeminiGenerationConfig {
-            temperature: request.temperature,
-            top_p: request.top_p,
-            top_k: request.top_k,
-            max_output_tokens: request.max_tokens.map(|t| t.min(65536)),
-            stop_sequences,
-            candidate_count: request.n,
-        });
-
-        let gemini_request = GeminiRequest {
-            contents,
-            system_instruction,
-            generation_config,
-            tools,
-            tool_config,
-            safety_settings: Some(self.get_safety_settings(&base_url)),
-        };
-
-        let url = format!("{}/models/{}:generateContent", base_url, model);
-        tracing::info!("Calling Gemini API: {}", url);
-
-        let response = self
-            .client
-            .post(&url)
-            .header("x-goog-api-key", &self.api_key)
-            .json(&gemini_request)
-            .send()
-            .await
-            .map_err(|e| AppError::ProviderError(format!("HTTP request failed: {}", e)))?;
-
-        let status = response.status();
-        if !status.is_success() {
-            let error_text = response.text().await.unwrap_or_default();
-            return Err(AppError::ProviderError(format!(
-                "Gemini API error ({}): {}",
-                status, error_text
-            )));
-        }
-
-        let gemini_response: GeminiResponse = response
-            .json()
-            .await
-            .map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?;
-
-        let candidate = gemini_response.candidates.first();
-
-        // Extract text content
-        let content = candidate
-            .and_then(|c| c.content.parts.iter().find_map(|p| p.text.clone()))
-            .unwrap_or_default();
-
-        // Extract reasoning (Gemini 3 'thought' parts)
-        let reasoning_content = candidate
-            .and_then(|c| c.content.parts.iter().find_map(|p| p.thought.clone()));
-
-        let reasoning_tokens = reasoning_content.as_ref()
-            .map(|r| crate::utils::tokens::estimate_completion_tokens(r, &model))
-            .unwrap_or(0);
-
-        // Extract function calls → OpenAI tool_calls
-        let tool_calls = candidate.and_then(|c| Self::extract_tool_calls(&c.content.parts));
-
-        let prompt_tokens = gemini_response
-            .usage_metadata
-            .as_ref()
-            .map(|u| u.prompt_token_count)
-            .unwrap_or(0);
-        let completion_tokens = gemini_response
-            .usage_metadata
-            .as_ref()
-            .map(|u| u.candidates_token_count)
-            .unwrap_or(0);
-        let total_tokens = gemini_response
-            .usage_metadata
-            .as_ref()
-            .map(|u| u.total_token_count)
-            .unwrap_or(0);
-        let cache_read_tokens = gemini_response
-            .usage_metadata
-            .as_ref()
-            .map(|u| u.cached_content_token_count)
-            .unwrap_or(0);
-
-        Ok(ProviderResponse {
-            content,
-            reasoning_content,
-            tool_calls,
-            prompt_tokens,
-            completion_tokens,
-            reasoning_tokens,
-            total_tokens,
-            cache_read_tokens,
-            cache_write_tokens: 0, // Gemini doesn't report cache writes separately
-            model,
-        })
-    }
-
-    fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32, AppError> {
-        Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
-    }
-
-    fn calculate_cost(
-        &self,
-        model: &str,
-        prompt_tokens: u32,
-        completion_tokens: u32,
-        cache_read_tokens: u32,
-        cache_write_tokens: u32,
-        registry: &crate::models::registry::ModelRegistry,
-    ) -> f64 {
-        if let Some(metadata) = registry.find_model(model) {
-            if metadata.cost.is_some() {
-                return super::helpers::calculate_cost_with_registry(
-                    model,
-                    prompt_tokens,
-                    completion_tokens,
-                    cache_read_tokens,
-                    cache_write_tokens,
-                    registry,
-                    &self.pricing,
-                    0.075,
-                    0.30,
-                );
-            }
-        }
-
-        // Custom Gemini fallback that correctly handles cache hits (25% of input cost)
-        let (prompt_rate, completion_rate) = self
-            .pricing
-            .iter()
-            .find(|p| model.contains(&p.model))
-            .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
-            .unwrap_or((0.075, 0.30)); // Default to Gemini 1.5 Flash current API pricing
-
-        let cache_hit_rate = prompt_rate * 0.25;
-        let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
-
-        (non_cached_prompt as f64 * prompt_rate / 1_000_000.0)
-            + (cache_read_tokens as f64 * cache_hit_rate / 1_000_000.0)
-            + (completion_tokens as f64 * completion_rate / 1_000_000.0)
-    }
-
-    async fn chat_completion_stream(
-        &self,
-        request: UnifiedRequest,
-    ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
-        let mut model = request.model.clone();
-
-        // Normalize model name: fall back to default if unknown Gemini model is requested
-        let is_known_version = model.starts_with("gemini-1.5") ||
-            model.starts_with("gemini-2.0") ||
-            model.starts_with("gemini-2.5") ||
-            model.starts_with("gemini-3");
-
-        if !is_known_version && model.starts_with("gemini-") {
-            tracing::info!("Mapping unknown Gemini model {} to default {}", model, self.config.default_model);
-            model = self.config.default_model.clone();
-        }
-
-        let tools = Self::convert_tools(&request);
-        let tool_config = Self::convert_tool_config(&request);
-        let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
-
-        if contents.is_empty() && system_instruction.is_none() {
-            return Err(AppError::ProviderError("No valid messages to send".to_string()));
-        }
-
-        let base_url = self.get_base_url(&model);
-
-        // Sanitize stop sequences: Gemini rejects empty strings
-        let stop_sequences = request.stop.map(|s| {
-            s.into_iter()
-                .filter(|seq| !seq.is_empty())
-                .collect::<Vec<String>>()
-        }).filter(|s| !s.is_empty());
-
-        let generation_config = Some(GeminiGenerationConfig {
-            temperature: request.temperature,
-            top_p: request.top_p,
-            top_k: request.top_k,
-            max_output_tokens: request.max_tokens.map(|t| t.min(65536)),
-            stop_sequences,
-            candidate_count: request.n,
-        });
-
-        let gemini_request = GeminiRequest {
-            contents,
-            system_instruction,
-            generation_config,
-            tools,
-            tool_config,
-            safety_settings: Some(self.get_safety_settings(&base_url)),
-        };
-
-        let url = format!(
-            "{}/models/{}:streamGenerateContent?alt=sse",
-            base_url, model,
-        );
-        tracing::info!("Calling Gemini Stream API: {}", url);
-
-        // Capture a clone of the request to probe for errors (Gemini 400s are common)
-        let probe_request = gemini_request.clone();
-        let probe_client = self.client.clone();
-        // Use non-streaming URL for probing to get a valid JSON error body
-        let probe_url = format!("{}/models/{}:generateContent", base_url, model);
-        let probe_api_key = self.api_key.clone();
-
-        // Create the EventSource first (it doesn't send until polled)
-        let es = reqwest_eventsource::EventSource::new(
-            self.client
-                .post(&url)
-                .header("x-goog-api-key", &self.api_key)
-                .header("Accept", "text/event-stream")
-                .json(&gemini_request),
-        ).map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
-
-        let stream = async_stream::try_stream! {
-            let mut es = es;
-            // Track tool call IDs by their part index to ensure stability during streaming.
-            // Gemini doesn't always include the thoughtSignature in every chunk for the same part.
-            let mut tool_call_ids: std::collections::HashMap<u32, String> = std::collections::HashMap::new();
-            let mut seen_tool_calls = false;
-
-            while let Some(event) = es.next().await {
-                match event {
-                    Ok(Event::Message(msg)) => {
-                        let gemini_response: GeminiStreamResponse = serde_json::from_str(&msg.data)
-                            .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
-
-                        tracing::info!("Received Gemini stream chunk (candidates: {}, has_usage: {}, finish_reason: {:?})",
-                            gemini_response.candidates.len(),
-                            gemini_response.usage_metadata.is_some(),
-                            gemini_response.candidates.first().and_then(|c| c.finish_reason.as_deref())
-                        );
-
-                        // Extract usage from usageMetadata if present (reported on every/last chunk)
-                        let stream_usage = gemini_response.usage_metadata.as_ref().map(|u| {
-                            super::StreamUsage {
-                                prompt_tokens: u.prompt_token_count,
-                                completion_tokens: u.candidates_token_count,
-                                reasoning_tokens: 0,
-                                total_tokens: u.total_token_count,
-                                cache_read_tokens: u.cached_content_token_count,
-                                cache_write_tokens: 0,
-                            }
-                        });
-
-                        // Some streaming events may not contain candidates (e.g. promptFeedback).
-                        // Only emit chunks when we have candidate content or tool calls.
-                        if let Some(candidate) = gemini_response.candidates.first() {
-                            if let Some(content_obj) = &candidate.content {
-                                let content = content_obj
-                                    .parts
-                                    .iter()
-                                    .find_map(|p| p.text.clone())
-                                    .unwrap_or_default();
-
-                                let reasoning_content = content_obj
-                                    .parts
-                                    .iter()
-                                    .find_map(|p| p.thought.clone());
-
-                                // Extract tool calls with index and ID stability
-                                let mut deltas = Vec::new();
-                                for (p_idx, p) in content_obj.parts.iter().enumerate() {
-                                    if let Some(fc) = &p.function_call {
-                                        seen_tool_calls = true;
-                                        let tool_call_idx = p_idx as u32;
-
-                                        // Attempt to find a signature in sibling fields
-                                        let signature = p.thought_signature_camel.clone()
-                                            .or_else(|| p.thought_signature_snake.clone());
-
-                                        // Ensure the ID remains stable for this tool call index.
-                                        // If we found a real signature now, we update it; otherwise use the existing or new random ID.
-                                        let entry = tool_call_ids.entry(tool_call_idx);
-                                        let current_id = match entry {
-                                            std::collections::hash_map::Entry::Occupied(mut e) => {
-                                                if let Some(sig) = signature {
-                                                    // If we previously had a 'call_' ID but now found a real signature, upgrade it.
-                                                    if e.get().starts_with("call_") {
-                                                        e.insert(sig);
-                                                    }
-                                                }
-                                                e.get().clone()
-                                            }
-                                            std::collections::hash_map::Entry::Vacant(e) => {
-                                                let id = signature.unwrap_or_else(|| format!("call_{}", Uuid::new_v4().simple()));
-                                                e.insert(id.clone());
-                                                id
-                                            }
-                                        };
-
-                                        deltas.push(ToolCallDelta {
-                                            index: tool_call_idx,
-                                            id: Some(current_id),
-                                            call_type: Some("function".to_string()),
-                                            function: Some(FunctionCallDelta {
-                                                name: Some(fc.name.clone()),
-                                                arguments: Some(serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string())),
-                                            }),
-                                        });
-                                    }
-                                }
-                                let tool_calls = if deltas.is_empty() { None } else { Some(deltas) };
-
-                                // Determine finish_reason
-                                // STRATEGY: If we have tool calls in this chunk, OR if we have seen them
-                                // previously in the stream, the finish_reason MUST be "tool_calls"
-                                // if the provider signals a stop. This ensures the client executes tools.
-                                let mut finish_reason = candidate.finish_reason.as_ref().map(|fr| {
-                                    match fr.as_str() {
-                                        "STOP" => "stop".to_string(),
-                                        _ => fr.to_lowercase(),
-                                    }
-                                });
-
-                                if seen_tool_calls && finish_reason.as_deref() == Some("stop") {
-                                    finish_reason = Some("tool_calls".to_string());
-                                } else if tool_calls.is_some() && finish_reason.is_none() {
-                                    // Optional: Could signal tool_calls here too, but OpenAI often waits until EOF.
-                                    // For now we only override it at the actual stop signal.
-                                }
-
-                                // Avoid emitting completely empty chunks unless they carry usage.
-                                if !content.is_empty() || reasoning_content.is_some() || tool_calls.is_some() || stream_usage.is_some() {
-                                    yield ProviderStreamChunk {
-                                        content,
-                                        reasoning_content,
-                                        finish_reason,
-                                        tool_calls,
-                                        model: model.clone(),
-                                        usage: stream_usage,
-                                    };
-                                }
-                            } else if stream_usage.is_some() {
-                                // Usage-only update
-                                yield ProviderStreamChunk {
-                                    content: String::new(),
-                                    reasoning_content: None,
-                                    finish_reason: None,
-                                    tool_calls: None,
-                                    model: model.clone(),
-                                    usage: stream_usage,
-                                };
-                            }
-                        } else if stream_usage.is_some() {
-                            // No candidates but usage present
-                            yield ProviderStreamChunk {
-                                content: String::new(),
-                                reasoning_content: None,
-                                finish_reason: None,
-                                tool_calls: None,
-                                model: model.clone(),
-                                usage: stream_usage,
-                            };
-                        }
-                    }
-                    Ok(_) => continue,
-                    Err(e) => {
-                        // "Stream ended" is usually a normal EOF signal in reqwest-eventsource.
-                        // We check the string representation to avoid returning it as an error.
-                        if e.to_string().contains("Stream ended") {
-                            break;
-                        }
-
-                        // On stream error, attempt to probe for the actual error body from the provider
-                        let probe_resp = probe_client
-                            .post(&probe_url)
-                            .header("x-goog-api-key", &probe_api_key)
-                            .json(&probe_request)
-                            .send()
-                            .await;
-
-                        match probe_resp {
-                            Ok(r) if !r.status().is_success() => {
-                                let status = r.status();
-                                let body = r.text().await.unwrap_or_default();
-                                tracing::error!("Gemini Stream Error Probe ({}): {}", status, body);
-                                Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, body)))?;
-                            }
-                            _ => {
-                                Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
-                            }
-                        }
-                    }
-                }
-            }
-        };
-
-        Ok(Box::pin(stream))
-    }
-}
diff --git a/src/providers/grok.rs b/src/providers/grok.rs
deleted file mode 100644
index 3da7fa89..00000000
--- a/src/providers/grok.rs
+++ /dev/null
@@ -1,123 +0,0 @@
-use anyhow::Result;
-use async_trait::async_trait;
-use futures::stream::BoxStream;
-
-use super::helpers;
-use super::{ProviderResponse, ProviderStreamChunk};
-use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
-
-pub struct GrokProvider {
-    client: reqwest::Client,
-    config: crate::config::GrokConfig,
-    api_key: String,
-    pricing: Vec<crate::config::ModelPricing>,
-}
-
-impl GrokProvider {
-    pub fn new(config: &crate::config::GrokConfig, app_config: &AppConfig) -> Result<Self> {
-        let api_key = app_config.get_api_key("grok")?;
-        Self::new_with_key(config, app_config, api_key)
-    }
-
-    pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
-        let client = reqwest::Client::builder()
-            .connect_timeout(std::time::Duration::from_secs(5))
-            .timeout(std::time::Duration::from_secs(300))
-            .pool_idle_timeout(std::time::Duration::from_secs(90))
-            .pool_max_idle_per_host(4)
-            .tcp_keepalive(std::time::Duration::from_secs(30))
-            .build()?;
-
-        Ok(Self {
-            client,
-            config: config.clone(),
-            api_key,
-            pricing: app_config.pricing.grok.clone(),
-        })
-    }
-}
-
-#[async_trait]
-impl super::Provider for GrokProvider {
-    fn name(&self) -> &str {
-        "grok"
-    }
-
-    fn supports_model(&self, model: &str) -> bool {
-        model.starts_with("grok-")
-    }
-
-    fn supports_multimodal(&self) -> bool {
-        true
-    }
-
-    async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
-        let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
-        let body = helpers::build_openai_body(&request, messages_json, false);
-
-        let response = self
-            .client
-            .post(format!("{}/chat/completions", self.config.base_url))
-            .header("Authorization", format!("Bearer {}", self.api_key))
-            .json(&body)
-            .send()
-            .await
-            .map_err(|e| AppError::ProviderError(e.to_string()))?;
-
-        if !response.status().is_success() {
-            let error_text = response.text().await.unwrap_or_default();
-            return Err(AppError::ProviderError(format!("Grok API error: {}", error_text)));
-        }
-
-        let resp_json: serde_json::Value = response
-            .json()
-            .await
-            .map_err(|e| AppError::ProviderError(e.to_string()))?;
-
-        helpers::parse_openai_response(&resp_json, request.model)
-    }
-
-    fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32, AppError> {
-        Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
-    }
-
-    fn calculate_cost(
-        &self,
-        model: &str,
-        prompt_tokens: u32,
-        completion_tokens: u32,
-        cache_read_tokens: u32,
-        cache_write_tokens: u32,
-        registry: &crate::models::registry::ModelRegistry,
-    ) -> f64 {
-        helpers::calculate_cost_with_registry(
-            model,
-            prompt_tokens,
-            completion_tokens,
-            cache_read_tokens,
-            cache_write_tokens,
-            registry,
-            &self.pricing,
-            5.0,
-            15.0,
-        )
-    }
-
-    async fn chat_completion_stream(
-        &self,
-        request: UnifiedRequest,
-    ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
-        let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
-        let body = helpers::build_openai_body(&request, messages_json, true);
-
-        let es = reqwest_eventsource::EventSource::new(
-            self.client
-                .post(format!("{}/chat/completions", self.config.base_url))
-                .header("Authorization", format!("Bearer {}", self.api_key))
-                .json(&body),
-        )
-        .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
-
-        Ok(helpers::create_openai_stream(es, request.model, None))
-    }
-}
diff --git a/src/providers/helpers.rs b/src/providers/helpers.rs
deleted file mode 100644
index a845d065..00000000
--- a/src/providers/helpers.rs
+++ /dev/null
@@ -1,454 +0,0 @@
-use super::{ProviderResponse, ProviderStreamChunk, StreamUsage};
-use crate::errors::AppError;
-use crate::models::{ContentPart, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest};
-use futures::stream::{BoxStream, StreamExt};
-use serde_json::Value;
-
-/// Convert messages to OpenAI-compatible JSON, resolving images asynchronously.
-///
-/// This avoids the deadlock caused by `futures::executor::block_on` inside a
-/// Tokio async context. All image base64 conversions are awaited properly.
-/// Handles tool-calling messages: assistant messages with tool_calls, and
-/// tool-role messages with tool_call_id/name.
-pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result<Vec<serde_json::Value>, AppError> {
-    let mut result = Vec::new();
-    for m in messages {
-        // Tool-role messages: { role: "tool", content: "...", tool_call_id: "...", name: "..." }
-        if m.role == "tool" {
-            let text_content = m
-                .content
-                .first()
-                .map(|p| match p {
-                    ContentPart::Text { text } => text.clone(),
-                    ContentPart::Image(_) => "[Image]".to_string(),
-                })
-                .unwrap_or_default();
-
-            let mut msg = serde_json::json!({
-                "role": "tool",
-                "content": text_content
-            });
-            if let Some(tool_call_id) = &m.tool_call_id {
-                // OpenAI and others have a 40-char limit for tool_call_id.
-                // Gemini signatures (56 chars) must be shortened for compatibility.
-                let id = if tool_call_id.len() > 40 {
-                    &tool_call_id[..40]
-                } else {
-                    tool_call_id
-                };
-                msg["tool_call_id"] = serde_json::json!(id);
-            }
-            if let Some(name) = &m.name {
-                msg["name"] = serde_json::json!(name);
-            }
-            result.push(msg);
-            continue;
-        }
-
-        // Build content parts for non-tool messages
-        let mut parts = Vec::new();
-        for p in &m.content {
-            match p {
-                ContentPart::Text { text } => {
-                    parts.push(serde_json::json!({ "type": "text", "text": text }));
-                }
-                ContentPart::Image(image_input) => {
-                    let (base64_data, mime_type) = image_input
-                        .to_base64()
-                        .await
-                        .map_err(|e| AppError::MultimodalError(e.to_string()))?;
-                    parts.push(serde_json::json!({
-                        "type": "image_url",
-                        "image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
-                    }));
-                }
-            }
-        }
-
-        let mut msg = serde_json::json!({ "role": m.role });
-
-        // Include reasoning_content if present (DeepSeek R1/reasoner requires this in history)
-        if let Some(reasoning) = &m.reasoning_content {
-            msg["reasoning_content"] = serde_json::json!(reasoning);
-        }
-
-        // For assistant messages with tool_calls, content can be empty string
-        if let Some(tool_calls) = &m.tool_calls {
-            // Sanitize tool call IDs for OpenAI compatibility (max 40 chars)
-            let sanitized_calls: Vec<_> = tool_calls.iter().map(|tc| {
-                let mut sanitized = tc.clone();
-                if sanitized.id.len() > 40 {
-                    sanitized.id = sanitized.id[..40].to_string();
-                }
-                sanitized
-            }).collect();
-
-            if parts.is_empty() {
-                msg["content"] = serde_json::json!("");
-            } else {
-                msg["content"] = serde_json::json!(parts);
-            }
-            msg["tool_calls"] = serde_json::json!(sanitized_calls);
-        } else {
-            msg["content"] = serde_json::json!(parts);
-        }
-
-        if let Some(name) = &m.name {
-            msg["name"] = serde_json::json!(name);
-        }
-
-        result.push(msg);
-    }
-    Ok(result)
-}
-
-/// Convert messages to OpenAI-compatible JSON, but replace images with a
-/// text placeholder "[Image]". Useful for providers that don't support
-/// multimodal in streaming mode or at all.
-///
-/// Handles tool-calling messages identically to `messages_to_openai_json`:
-/// assistant messages with `tool_calls`, and tool-role messages with
-/// `tool_call_id`/`name`.
-pub async fn messages_to_openai_json_text_only(
-    messages: &[UnifiedMessage],
-) -> Result<Vec<serde_json::Value>, AppError> {
-    let mut result = Vec::new();
-    for m in messages {
-        // Tool-role messages: { role: "tool", content: "...", tool_call_id: "...", name: "..." }
-        if m.role == "tool" {
-            let text_content = m
-                .content
-                .first()
-                .map(|p| match p {
-                    ContentPart::Text { text } => text.clone(),
-                    ContentPart::Image(_) => "[Image]".to_string(),
-                })
-                .unwrap_or_default();
-
-            let mut msg = serde_json::json!({
-                "role": "tool",
-                "content": text_content
-            });
-            if let Some(tool_call_id) = &m.tool_call_id {
-                // OpenAI and others have a 40-char limit for tool_call_id.
-                let id = if tool_call_id.len() > 40 {
-                    &tool_call_id[..40]
-                } else {
-                    tool_call_id
-                };
-                msg["tool_call_id"] = serde_json::json!(id);
-            }
-            if let Some(name) = &m.name {
-                msg["name"] = serde_json::json!(name);
-            }
-            result.push(msg);
-            continue;
-        }
-
-        // Build content parts for non-tool messages (images become "[Image]" text)
-        let mut parts = Vec::new();
-        for p in &m.content {
-            match p {
-                ContentPart::Text { text } => {
-                    parts.push(serde_json::json!({ "type": "text", "text": text }));
-                }
-                ContentPart::Image(_) => {
-                    parts.push(serde_json::json!({ "type": "text", "text": "[Image]" }));
-                }
-            }
-        }
-
-        let mut msg = serde_json::json!({ "role": m.role });
-
-        // Include reasoning_content if present (DeepSeek R1/reasoner requires this in history)
-        if let Some(reasoning) = &m.reasoning_content {
-            msg["reasoning_content"] = serde_json::json!(reasoning);
-        }
-
-        // For assistant messages with tool_calls, content can be empty string
-        if let Some(tool_calls) = &m.tool_calls {
-            // Sanitize tool call IDs for OpenAI compatibility (max 40 chars)
-            let sanitized_calls: Vec<_> = tool_calls.iter().map(|tc| {
-                let mut sanitized = tc.clone();
-                if sanitized.id.len() > 40 {
-                    sanitized.id = sanitized.id[..40].to_string();
-                }
-                sanitized
-            }).collect();
-
-            if parts.is_empty() {
-                msg["content"] = serde_json::json!("");
-            } else {
-                msg["content"] = serde_json::json!(parts);
-            }
-            msg["tool_calls"] = serde_json::json!(sanitized_calls);
-        } else {
-            msg["content"] = serde_json::json!(parts);
-        }
-
-        if let Some(name) = &m.name {
-            msg["name"] = serde_json::json!(name);
-        }
-
-        result.push(msg);
-    }
-    Ok(result)
-}
-
-/// Build an OpenAI-compatible request body from a UnifiedRequest and pre-converted messages.
-/// Includes tools and tool_choice when present.
-/// When streaming, adds `stream_options.include_usage: true` so providers report
-/// token counts in the final SSE chunk.
-pub fn build_openai_body(
-    request: &UnifiedRequest,
-    messages_json: Vec<serde_json::Value>,
-    stream: bool,
-) -> serde_json::Value {
-    let mut body = serde_json::json!({
-        "model": request.model,
-        "messages": messages_json,
-        "stream": stream,
-    });
-
-    if stream {
-        body["stream_options"] = serde_json::json!({ "include_usage": true });
-    }
-
-    if let Some(temp) = request.temperature {
-        body["temperature"] = serde_json::json!(temp);
-    }
-    if let Some(max_tokens) = request.max_tokens {
-        body["max_tokens"] = serde_json::json!(max_tokens);
-    }
-    if let Some(tools) = &request.tools {
-        body["tools"] = serde_json::json!(tools);
-    }
-    if let Some(tool_choice) = &request.tool_choice {
-        body["tool_choice"] = serde_json::json!(tool_choice);
-    }
-
-    body
-}
-
-/// Parse an OpenAI-compatible chat completion response JSON into a ProviderResponse.
-/// Extracts tool_calls from the message when present.
-/// Extracts cache token counts from:
-/// - OpenAI/Grok: `usage.prompt_tokens_details.cached_tokens`
-/// - DeepSeek: `usage.prompt_cache_hit_tokens` / `usage.prompt_cache_miss_tokens`
-pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<ProviderResponse, AppError> {
-    let choice = resp_json["choices"]
-        .get(0)
-        .ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
-    let message = &choice["message"];
-
-    let content = message["content"].as_str().unwrap_or_default().to_string();
-    let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
-
-    // Parse tool_calls from the response message
-    let tool_calls: Option<Vec<ToolCall>> = message
-        .get("tool_calls")
-        .and_then(|tc| serde_json::from_value(tc.clone()).ok());
-
-    let usage = &resp_json["usage"];
-    let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
-    let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
-    let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
-
-    // Extract reasoning tokens
-    let reasoning_tokens = usage["completion_tokens_details"]["reasoning_tokens"]
-        .as_u64()
-        .unwrap_or(0) as u32;
-
-    // Extract cache tokens — try OpenAI/Grok format first, then DeepSeek format
-    let cache_read_tokens = usage["prompt_tokens_details"]["cached_tokens"]
-        .as_u64()
-        // DeepSeek uses a different field name
-        .or_else(|| usage["prompt_cache_hit_tokens"].as_u64())
-        .unwrap_or(0) as u32;
-
-    // DeepSeek reports prompt_cache_miss_tokens which are just regular non-cached tokens.
-    // They do not incur a separate cache_write fee, so we don't map them here to avoid double-charging.
-    let cache_write_tokens = 0;
-
-    Ok(ProviderResponse {
-        content,
-        reasoning_content,
-        tool_calls,
-        prompt_tokens,
-        completion_tokens,
-        reasoning_tokens,
-        total_tokens,
-        cache_read_tokens,
-        cache_write_tokens,
-        model,
-    })
-}
-
-/// Parse a single OpenAI-compatible stream chunk into a ProviderStreamChunk.
-/// Returns None if the chunk should be skipped (e.g. promptFeedback).
-pub fn parse_openai_stream_chunk(
-    chunk: &Value,
-    model: &str,
-    reasoning_field: Option<&'static str>,
-) -> Option<Result<ProviderStreamChunk, AppError>> {
-    // Parse usage from the final chunk (sent when stream_options.include_usage is true).
-    // This chunk may have an empty `choices` array.
-    let stream_usage = chunk.get("usage").and_then(|u| {
-        if u.is_null() {
-            return None;
-        }
-        let prompt_tokens = u["prompt_tokens"].as_u64().unwrap_or(0) as u32;
-        let completion_tokens = u["completion_tokens"].as_u64().unwrap_or(0) as u32;
-        let total_tokens = u["total_tokens"].as_u64().unwrap_or(0) as u32;
-
-        let reasoning_tokens = u["completion_tokens_details"]["reasoning_tokens"]
-            .as_u64()
-            .unwrap_or(0) as u32;
-
-        let cache_read_tokens = u["prompt_tokens_details"]["cached_tokens"]
-            .as_u64()
-            .or_else(|| u["prompt_cache_hit_tokens"].as_u64())
-            .unwrap_or(0) as u32;
-
-        let cache_write_tokens = 0;
-
-        Some(StreamUsage {
-            prompt_tokens,
-            completion_tokens,
-            reasoning_tokens,
-            total_tokens,
-            cache_read_tokens,
-            cache_write_tokens,
-        })
-    });
-
-    if let Some(choice) = chunk["choices"].get(0) {
-        let delta = &choice["delta"];
-        let content = delta["content"].as_str().unwrap_or_default().to_string();
-        let reasoning_content = delta["reasoning_content"]
-            .as_str()
-            .or_else(|| reasoning_field.and_then(|f| delta[f].as_str()))
-            .map(|s| s.to_string());
-        let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
-
-        // Parse tool_calls deltas from the stream chunk
-        let tool_calls: Option<Vec<ToolCallDelta>> = delta
-            .get("tool_calls")
-            .and_then(|tc| serde_json::from_value(tc.clone()).ok());
-
-        Some(Ok(ProviderStreamChunk {
-            content,
-            reasoning_content,
-            finish_reason,
-            tool_calls,
-            model: model.to_string(),
-            usage: stream_usage,
-        }))
-    } else if stream_usage.is_some() {
-        // Final usage-only chunk (empty choices array) — yield it so
-        // AggregatingStream can capture the real token counts.
-        Some(Ok(ProviderStreamChunk {
-            content: String::new(),
-            reasoning_content: None,
-            finish_reason: None,
-            tool_calls: None,
-            model: model.to_string(),
-            usage: stream_usage,
-        }))
-    } else {
-        None
-    }
-}
-
-/// Create an SSE stream that parses OpenAI-compatible streaming chunks.
-///
-/// The optional `reasoning_field` allows overriding the field name for
-/// reasoning content (e.g., "thought" for Ollama).
-/// Parses tool_calls deltas from streaming chunks when present.
-/// When `stream_options.include_usage: true` was sent, the provider sends a
-/// final chunk with `usage` data — this is parsed into `StreamUsage` and
-/// attached to the yielded `ProviderStreamChunk`.
-pub fn create_openai_stream(
-    es: reqwest_eventsource::EventSource,
-    model: String,
-    reasoning_field: Option<&'static str>,
-) -> BoxStream<'static, Result<ProviderStreamChunk, AppError>> {
-    use reqwest_eventsource::Event;
-
-    let stream = async_stream::try_stream! {
-        let mut es = es;
-        while let Some(event) = es.next().await {
-            match event {
-                Ok(Event::Message(msg)) => {
-                    if msg.data == "[DONE]" {
-                        break;
-                    }
-
-                    let chunk: Value = serde_json::from_str(&msg.data)
-                        .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
-
-                    if let Some(p_chunk) = parse_openai_stream_chunk(&chunk, &model, reasoning_field) {
-                        yield p_chunk?;
-                    }
-                }
-                Ok(_) => continue,
-                Err(e) => {
-                    Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
-                }
-            }
-        }
-    };
-
-    Box::pin(stream)
-}
-
-/// Calculate cost using the model registry first, then falling back to provider pricing config.
-///
-/// When the registry provides `cache_read` / `cache_write` rates, the formula is:
-///     (prompt_tokens - cache_read_tokens) * input_rate
-///     + cache_read_tokens * cache_read_rate
-///     + cache_write_tokens * cache_write_rate (if applicable)
-///     + completion_tokens * output_rate
-///
-/// All rates are per-token (the registry stores per-million-token rates).
-pub fn calculate_cost_with_registry(
-    model: &str,
-    prompt_tokens: u32,
-    completion_tokens: u32,
-    cache_read_tokens: u32,
-    cache_write_tokens: u32,
-    registry: &crate::models::registry::ModelRegistry,
-    pricing: &[crate::config::ModelPricing],
-    default_prompt_rate: f64,
-    default_completion_rate: f64,
-) -> f64 {
-    if let Some(metadata) = registry.find_model(model)
-        && let Some(cost) = &metadata.cost
-    {
-        let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
-        let mut total = (non_cached_prompt as f64 * cost.input / 1_000_000.0)
-            + (completion_tokens as f64 * cost.output / 1_000_000.0);
-
-        if let Some(cache_read_rate) = cost.cache_read {
-            total += cache_read_tokens as f64 * cache_read_rate / 1_000_000.0;
-        } else {
-            // No cache_read rate — charge cached tokens at full input rate
-            total += cache_read_tokens as f64 * cost.input / 1_000_000.0;
-        }
-
-        if let Some(cache_write_rate) = cost.cache_write {
-            total += cache_write_tokens as f64 * cache_write_rate / 1_000_000.0;
-        }
-
-        return total;
-    }
-
-    // Fallback: no registry entry — use provider pricing config (no cache awareness)
-    let (prompt_rate, completion_rate) = pricing
-        .iter()
-        .find(|p| model.contains(&p.model))
-        .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
-        .unwrap_or((default_prompt_rate, default_completion_rate));
-
-    (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
-}
diff --git a/src/providers/mod.rs b/src/providers/mod.rs
deleted file mode 100644
index 975342e5..00000000
--- a/src/providers/mod.rs
+++ /dev/null
@@ -1,365 +0,0 @@
-use anyhow::Result;
-use async_trait::async_trait;
-use futures::stream::BoxStream;
-use sqlx::Row;
-use std::sync::Arc;
-
-use crate::errors::AppError;
-use crate::models::UnifiedRequest;
-
-
-pub mod deepseek;
-pub mod gemini;
-pub mod grok;
-pub mod helpers;
-pub mod ollama;
-pub mod openai;
-
-#[async_trait]
-pub trait Provider: Send + Sync {
-    /// Get provider name (e.g., "openai", "gemini")
-    fn name(&self) -> &str;
-
-    /// Check if provider supports a specific model
-    fn supports_model(&self, model: &str) -> bool;
-
-    /// Check if provider supports multimodal (images, etc.)
-    fn supports_multimodal(&self) -> bool;
-
-    /// Process a chat completion request
-    async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError>;
-
-    /// Process a chat request using provider-specific "responses" style endpoint
-    /// Default implementation falls back to `chat_completion` for providers
-    /// that do not implement a dedicated responses endpoint.
-    async fn chat_responses(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
-        self.chat_completion(request).await
-    }
-
-    /// Process a streaming chat completion request
-    async fn chat_completion_stream(
-        &self,
-        request: UnifiedRequest,
-    ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError>;
-
-    /// Process a streaming chat request using provider-specific "responses" style endpoint
-    /// Default implementation falls back to `chat_completion_stream` for providers
-    /// that do not implement a dedicated responses endpoint.
-    async fn chat_responses_stream(
-        &self,
-        request: UnifiedRequest,
-    ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
-        self.chat_completion_stream(request).await
-    }
-
-    /// Estimate token count for a request (for cost calculation)
-    fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32, AppError>;
-
-    /// Calculate cost based on token usage and model using the registry.
-    /// `cache_read_tokens` / `cache_write_tokens` allow cache-aware pricing
-    /// when the registry provides `cache_read` / `cache_write` rates.
-    fn calculate_cost(
-        &self,
-        model: &str,
-        prompt_tokens: u32,
-        completion_tokens: u32,
-        cache_read_tokens: u32,
-        cache_write_tokens: u32,
-        registry: &crate::models::registry::ModelRegistry,
-    ) -> f64;
-}
-
-pub struct ProviderResponse {
-    pub content: String,
-    pub reasoning_content: Option<String>,
-    pub tool_calls: Option<Vec<crate::models::ToolCall>>,
-    pub prompt_tokens: u32,
-    pub completion_tokens: u32,
-    pub reasoning_tokens: u32,
-    pub total_tokens: u32,
-    pub cache_read_tokens: u32,
-    pub cache_write_tokens: u32,
-    pub model: String,
-}
-
-/// Usage data from the final streaming chunk (when providers report real token counts).
-#[derive(Debug, Clone, Default)]
-pub struct StreamUsage {
-    pub prompt_tokens: u32,
-    pub completion_tokens: u32,
-    pub reasoning_tokens: u32,
-    pub total_tokens: u32,
-    pub cache_read_tokens: u32,
-    pub cache_write_tokens: u32,
-}
-
-#[derive(Debug, Clone)]
-pub struct ProviderStreamChunk {
-    pub content: String,
-    pub reasoning_content: Option<String>,
-    pub finish_reason: Option<String>,
-    pub tool_calls: Option<Vec<crate::models::ToolCallDelta>>,
-    pub model: String,
-    /// Populated only on the final chunk when providers report usage (e.g. stream_options.include_usage).
-    pub usage: Option<StreamUsage>,
-}
-
-use tokio::sync::RwLock;
-
-use crate::config::AppConfig;
-use crate::providers::{
-    deepseek::DeepSeekProvider, gemini::GeminiProvider, grok::GrokProvider, ollama::OllamaProvider,
-    openai::OpenAIProvider,
-};
-
-#[derive(Clone)]
-pub struct ProviderManager {
-    providers: Arc<RwLock<Vec<Arc<dyn Provider>>>>,
-}
-
-impl Default for ProviderManager {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
-impl ProviderManager {
-    pub fn new() -> Self {
-        Self {
-            providers: Arc::new(RwLock::new(Vec::new())),
-        }
-    }
-
-    /// Initialize a provider by name using config and database overrides
-    pub async fn initialize_provider(
-        &self,
-        name: &str,
-        app_config: &AppConfig,
-        db_pool: &crate::database::DbPool,
-    ) -> Result<()> {
-        // Load override from database
-        let db_config = sqlx::query("SELECT enabled, base_url, api_key, api_key_encrypted FROM provider_configs WHERE id = ?")
-            .bind(name)
-            .fetch_optional(db_pool)
-            .await?;
-
-        let (enabled, base_url, api_key) = if let Some(row) = db_config {
-            let enabled = row.get::<bool, _>("enabled");
-            let base_url = row.get::<Option<String>, _>("base_url");
-            let api_key_encrypted = row.get::<bool, _>("api_key_encrypted");
-            let api_key = row.get::<Option<String>, _>("api_key");
-            // Decrypt API key if encrypted
-            let api_key = match (api_key, api_key_encrypted) {
-                (Some(key), true) => {
-                    match crate::utils::crypto::decrypt(&key) {
-                        Ok(decrypted) => Some(decrypted),
-                        Err(e) => {
-                            tracing::error!("Failed to decrypt API key for provider {}: {}", name, e);
-                            None
-                        }
-                    }
-                }
-                (Some(key), false) => {
-                    // Plaintext key - optionally encrypt and update database (lazy migration)
-                    // For now, just use plaintext
-                    Some(key)
-                }
-                (None, _) => None,
-            };
-            (enabled, base_url, api_key)
-        } else {
-            // No database override, use defaults from AppConfig
-            match name {
-                "openai" => (
-                    app_config.providers.openai.enabled,
-                    Some(app_config.providers.openai.base_url.clone()),
-                    None,
-                ),
-                "gemini" => (
-                    app_config.providers.gemini.enabled,
Some(app_config.providers.gemini.base_url.clone()), - None, - ), - "deepseek" => ( - app_config.providers.deepseek.enabled, - Some(app_config.providers.deepseek.base_url.clone()), - None, - ), - "grok" => ( - app_config.providers.grok.enabled, - Some(app_config.providers.grok.base_url.clone()), - None, - ), - "ollama" => ( - app_config.providers.ollama.enabled, - Some(app_config.providers.ollama.base_url.clone()), - None, - ), - _ => (false, None, None), - } - }; - - if !enabled { - self.remove_provider(name).await; - return Ok(()); - } - - // Create provider instance with merged config - let provider: Arc = match name { - "openai" => { - let mut cfg = app_config.providers.openai.clone(); - if let Some(url) = base_url { - cfg.base_url = url; - } - // Handle API key override if present - let p = if let Some(key) = api_key { - // We need a way to create a provider with an explicit key - // Let's modify the providers to allow this - OpenAIProvider::new_with_key(&cfg, app_config, key)? - } else { - OpenAIProvider::new(&cfg, app_config)? - }; - Arc::new(p) - } - "ollama" => { - let mut cfg = app_config.providers.ollama.clone(); - if let Some(url) = base_url { - cfg.base_url = url; - } - Arc::new(OllamaProvider::new(&cfg, app_config)?) - } - "gemini" => { - let mut cfg = app_config.providers.gemini.clone(); - if let Some(url) = base_url { - cfg.base_url = url; - } - let p = if let Some(key) = api_key { - GeminiProvider::new_with_key(&cfg, app_config, key)? - } else { - GeminiProvider::new(&cfg, app_config)? - }; - Arc::new(p) - } - "deepseek" => { - let mut cfg = app_config.providers.deepseek.clone(); - if let Some(url) = base_url { - cfg.base_url = url; - } - let p = if let Some(key) = api_key { - DeepSeekProvider::new_with_key(&cfg, app_config, key)? - } else { - DeepSeekProvider::new(&cfg, app_config)? 
-                };
-                Arc::new(p)
-            }
-            "grok" => {
-                let mut cfg = app_config.providers.grok.clone();
-                if let Some(url) = base_url {
-                    cfg.base_url = url;
-                }
-                let p = if let Some(key) = api_key {
-                    GrokProvider::new_with_key(&cfg, app_config, key)?
-                } else {
-                    GrokProvider::new(&cfg, app_config)?
-                };
-                Arc::new(p)
-            }
-            _ => return Err(anyhow::anyhow!("Unknown provider: {}", name)),
-        };
-
-        self.add_provider(provider).await;
-        Ok(())
-    }
-
-    pub async fn add_provider(&self, provider: Arc<dyn Provider>) {
-        let mut providers = self.providers.write().await;
-        // If provider with same name exists, replace it
-        if let Some(index) = providers.iter().position(|p| p.name() == provider.name()) {
-            providers[index] = provider;
-        } else {
-            providers.push(provider);
-        }
-    }
-
-    pub async fn remove_provider(&self, name: &str) {
-        let mut providers = self.providers.write().await;
-        providers.retain(|p| p.name() != name);
-    }
-
-    pub async fn get_provider_for_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
-        let providers = self.providers.read().await;
-        providers.iter().find(|p| p.supports_model(model)).map(Arc::clone)
-    }
-
-    pub async fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
-        let providers = self.providers.read().await;
-        providers.iter().find(|p| p.name() == name).map(Arc::clone)
-    }
-
-    pub async fn get_all_providers(&self) -> Vec<Arc<dyn Provider>> {
-        let providers = self.providers.read().await;
-        providers.clone()
-    }
-}
-
-// Create placeholder provider implementations
-pub mod placeholder {
-    use super::*;
-
-    pub struct PlaceholderProvider {
-        name: String,
-    }
-
-    impl PlaceholderProvider {
-        pub fn new(name: &str) -> Self {
-            Self { name: name.to_string() }
-        }
-    }
-
-    #[async_trait]
-    impl Provider for PlaceholderProvider {
-        fn name(&self) -> &str {
-            &self.name
-        }
-
-        fn supports_model(&self, _model: &str) -> bool {
-            false
-        }
-
-        fn supports_multimodal(&self) -> bool {
-            false
-        }
-
-        async fn chat_completion_stream(
-            &self,
-            _request: UnifiedRequest,
-        ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
Err(AppError::ProviderError( - "Streaming not supported for placeholder provider".to_string(), - )) - } - - async fn chat_completion(&self, _request: UnifiedRequest) -> Result { - Err(AppError::ProviderError(format!( - "Provider {} not implemented", - self.name - ))) - } - - fn estimate_tokens(&self, _request: &UnifiedRequest) -> Result { - Ok(0) - } - - fn calculate_cost( - &self, - _model: &str, - _prompt_tokens: u32, - _completion_tokens: u32, - _cache_read_tokens: u32, - _cache_write_tokens: u32, - _registry: &crate::models::registry::ModelRegistry, - ) -> f64 { - 0.0 - } - } -} diff --git a/src/providers/ollama.rs b/src/providers/ollama.rs deleted file mode 100644 index 90b63d10..00000000 --- a/src/providers/ollama.rs +++ /dev/null @@ -1,140 +0,0 @@ -use anyhow::Result; -use async_trait::async_trait; -use futures::stream::BoxStream; - -use super::helpers; -use super::{ProviderResponse, ProviderStreamChunk}; -use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest}; - -pub struct OllamaProvider { - client: reqwest::Client, - config: crate::config::OllamaConfig, - pricing: Vec, -} - -impl OllamaProvider { - pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result { - let client = reqwest::Client::builder() - .connect_timeout(std::time::Duration::from_secs(5)) - .timeout(std::time::Duration::from_secs(300)) - .pool_idle_timeout(std::time::Duration::from_secs(90)) - .pool_max_idle_per_host(4) - .tcp_keepalive(std::time::Duration::from_secs(30)) - .build()?; - - Ok(Self { - client, - config: config.clone(), - pricing: app_config.pricing.ollama.clone(), - }) - } -} - -#[async_trait] -impl super::Provider for OllamaProvider { - fn name(&self) -> &str { - "ollama" - } - - fn supports_model(&self, model: &str) -> bool { - self.config.models.iter().any(|m| m == model) || model.starts_with("ollama/") - } - - fn supports_multimodal(&self) -> bool { - true - } - - async fn chat_completion(&self, mut request: UnifiedRequest) -> 
Result { - // Strip "ollama/" prefix if present for the API call - let api_model = request - .model - .strip_prefix("ollama/") - .unwrap_or(&request.model) - .to_string(); - let original_model = request.model.clone(); - request.model = api_model; - - let messages_json = helpers::messages_to_openai_json(&request.messages).await?; - let body = helpers::build_openai_body(&request, messages_json, false); - - let response = self - .client - .post(format!("{}/chat/completions", self.config.base_url)) - .json(&body) - .send() - .await - .map_err(|e| AppError::ProviderError(e.to_string()))?; - - if !response.status().is_success() { - let error_text = response.text().await.unwrap_or_default(); - return Err(AppError::ProviderError(format!("Ollama API error: {}", error_text))); - } - - let resp_json: serde_json::Value = response - .json() - .await - .map_err(|e| AppError::ProviderError(e.to_string()))?; - - // Ollama also supports "thought" as an alias for reasoning_content - let mut result = helpers::parse_openai_response(&resp_json, original_model)?; - if result.reasoning_content.is_none() { - result.reasoning_content = resp_json["choices"] - .get(0) - .and_then(|c| c["message"]["thought"].as_str()) - .map(|s| s.to_string()); - } - Ok(result) - } - - fn estimate_tokens(&self, request: &UnifiedRequest) -> Result { - Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request)) - } - - fn calculate_cost( - &self, - model: &str, - prompt_tokens: u32, - completion_tokens: u32, - cache_read_tokens: u32, - cache_write_tokens: u32, - registry: &crate::models::registry::ModelRegistry, - ) -> f64 { - helpers::calculate_cost_with_registry( - model, - prompt_tokens, - completion_tokens, - cache_read_tokens, - cache_write_tokens, - registry, - &self.pricing, - 0.0, - 0.0, - ) - } - - async fn chat_completion_stream( - &self, - mut request: UnifiedRequest, - ) -> Result>, AppError> { - let api_model = request - .model - .strip_prefix("ollama/") - .unwrap_or(&request.model) 
- .to_string(); - let original_model = request.model.clone(); - request.model = api_model; - - let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?; - let body = helpers::build_openai_body(&request, messages_json, true); - - let es = reqwest_eventsource::EventSource::new( - self.client - .post(format!("{}/chat/completions", self.config.base_url)) - .json(&body), - ) - .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; - - // Ollama uses "thought" as an alternative field for reasoning content - Ok(helpers::create_openai_stream(es, original_model, Some("thought"))) - } -} diff --git a/src/providers/openai.rs b/src/providers/openai.rs deleted file mode 100644 index 7bc5d86c..00000000 --- a/src/providers/openai.rs +++ /dev/null @@ -1,564 +0,0 @@ -use anyhow::Result; -use async_trait::async_trait; -use futures::stream::BoxStream; -use futures::StreamExt; - -use super::helpers; -use super::{ProviderResponse, ProviderStreamChunk}; -use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest}; - -pub struct OpenAIProvider { - client: reqwest::Client, - config: crate::config::OpenAIConfig, - api_key: String, - pricing: Vec, -} - -impl OpenAIProvider { - pub fn new(config: &crate::config::OpenAIConfig, app_config: &AppConfig) -> Result { - let api_key = app_config.get_api_key("openai")?; - Self::new_with_key(config, app_config, api_key) - } - - pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result { - let client = reqwest::Client::builder() - .connect_timeout(std::time::Duration::from_secs(5)) - .timeout(std::time::Duration::from_secs(300)) - .pool_idle_timeout(std::time::Duration::from_secs(90)) - .pool_max_idle_per_host(4) - .tcp_keepalive(std::time::Duration::from_secs(15)) - .build()?; - - Ok(Self { - client, - config: config.clone(), - api_key, - pricing: app_config.pricing.openai.clone(), - }) - } -} - -#[async_trait] -impl 
super::Provider for OpenAIProvider { - fn name(&self) -> &str { - "openai" - } - - fn supports_model(&self, model: &str) -> bool { - model.starts_with("gpt-") || - model.starts_with("o1-") || - model.starts_with("o2-") || - model.starts_with("o3-") || - model.starts_with("o4-") || - model.starts_with("o5-") || - model.contains("gpt-5") - } - - fn supports_multimodal(&self) -> bool { - true - } - - async fn chat_completion(&self, request: UnifiedRequest) -> Result { - // Allow proactive routing to Responses API based on heuristic - let model_lc = request.model.to_lowercase(); - if model_lc.contains("gpt-5") || model_lc.contains("codex") { - return self.chat_responses(request).await; - } - - let messages_json = helpers::messages_to_openai_json(&request.messages).await?; - let mut body = helpers::build_openai_body(&request, messages_json, false); - - // Transition: Newer OpenAI models (o1, o3, gpt-5) require max_completion_tokens - // instead of the legacy max_tokens parameter. - if request.model.starts_with("o1-") || request.model.starts_with("o3-") || request.model.contains("gpt-5") { - if let Some(max_tokens) = body.as_object_mut().and_then(|obj| obj.remove("max_tokens")) { - body["max_completion_tokens"] = max_tokens; - } - } - - let response = self - .client - .post(format!("{}/chat/completions", self.config.base_url)) - .header("Authorization", format!("Bearer {}", self.api_key)) - .json(&body) - .send() - .await - .map_err(|e| AppError::ProviderError(e.to_string()))?; - - if !response.status().is_success() { - let status = response.status(); - let error_text = response.text().await.unwrap_or_default(); - - // Read error body to diagnose. If the model requires the Responses - // API (v1/responses), retry against that endpoint. 
- if error_text.to_lowercase().contains("v1/responses") || error_text.to_lowercase().contains("only supported in v1/responses") { - return self.chat_responses(request).await; - } - - tracing::error!("OpenAI API error ({}): {}", status, error_text); - return Err(AppError::ProviderError(format!("OpenAI API error ({}): {}", status, error_text))); - } - - let resp_json: serde_json::Value = response - .json() - .await - .map_err(|e| AppError::ProviderError(e.to_string()))?; - - helpers::parse_openai_response(&resp_json, request.model) - } - - async fn chat_responses(&self, request: UnifiedRequest) -> Result { - // Build a structured input for the Responses API. - let messages_json = helpers::messages_to_openai_json(&request.messages).await?; - let mut input_parts = Vec::new(); - for m in &messages_json { - let mut role = m["role"].as_str().unwrap_or("user").to_string(); - // Newer models (gpt-5, o1) prefer "developer" over "system" - if role == "system" { - role = "developer".to_string(); - } - - let mut content = m.get("content").cloned().unwrap_or(serde_json::json!([])); - - // Map content types based on role for Responses API - if let Some(content_array) = content.as_array_mut() { - for part in content_array { - if let Some(part_obj) = part.as_object_mut() { - if let Some(t) = part_obj.get("type").and_then(|v| v.as_str()) { - match t { - "text" => { - let new_type = if role == "assistant" { "output_text" } else { "input_text" }; - part_obj.insert("type".to_string(), serde_json::json!(new_type)); - } - "image_url" => { - // Assistant typically doesn't have image_url in history this way, but for safety: - let new_type = if role == "assistant" { "output_image" } else { "input_image" }; - part_obj.insert("type".to_string(), serde_json::json!(new_type)); - if let Some(img_url) = part_obj.remove("image_url") { - part_obj.insert("image".to_string(), img_url); - } - } - _ => {} - } - } - } - } - } else if let Some(text) = content.as_str() { - let new_type = if role == 
"assistant" { "output_text" } else { "input_text" }; - content = serde_json::json!([{ "type": new_type, "text": text }]); - } - - input_parts.push(serde_json::json!({ - "role": role, - "content": content - })); - } - - let mut body = serde_json::json!({ - "model": request.model, - "input": input_parts, - }); - - // Add standard parameters - if let Some(temp) = request.temperature { - body["temperature"] = serde_json::json!(temp); - } - - // Newer models (gpt-5, o1) in Responses API use max_output_tokens - if let Some(max_tokens) = request.max_tokens { - if request.model.contains("gpt-5") || request.model.starts_with("o1-") || request.model.starts_with("o3-") { - body["max_output_tokens"] = serde_json::json!(max_tokens); - } else { - body["max_tokens"] = serde_json::json!(max_tokens); - } - } - - if let Some(tools) = &request.tools { - body["tools"] = serde_json::json!(tools); - } - - let resp = self - .client - .post(format!("{}/responses", self.config.base_url)) - .header("Authorization", format!("Bearer {}", self.api_key)) - .json(&body) - .send() - .await - .map_err(|e| AppError::ProviderError(e.to_string()))?; - - if !resp.status().is_success() { - let err = resp.text().await.unwrap_or_default(); - return Err(AppError::ProviderError(format!("OpenAI Responses API error: {}", err))); - } - - let resp_json: serde_json::Value = resp.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?; - - // Try to normalize: if it's chat-style, use existing parser - if resp_json.get("choices").is_some() { - return helpers::parse_openai_response(&resp_json, request.model); - } - - // Normalize Responses API output into ProviderResponse - let mut content_text = String::new(); - if let Some(output) = resp_json.get("output").and_then(|o| o.as_array()) { - for out in output { - if let Some(contents) = out.get("content").and_then(|c| c.as_array()) { - for item in contents { - if let Some(text) = item.get("text").and_then(|t| t.as_str()) { - if !content_text.is_empty() { 
content_text.push_str("\n"); } - content_text.push_str(text); - } else if let Some(parts) = item.get("parts").and_then(|p| p.as_array()) { - for p in parts { - if let Some(t) = p.as_str() { - if !content_text.is_empty() { content_text.push_str("\n"); } - content_text.push_str(t); - } - } - } - } - } - } - } - - if content_text.is_empty() { - if let Some(cands) = resp_json.get("candidates").and_then(|c| c.as_array()) { - if let Some(c0) = cands.get(0) { - if let Some(content) = c0.get("content") { - if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) { - for p in parts { - if let Some(t) = p.get("text").and_then(|v| v.as_str()) { - if !content_text.is_empty() { content_text.push_str("\n"); } - content_text.push_str(t); - } - } - } - } - } - } - } - - let prompt_tokens = resp_json.get("usage").and_then(|u| u.get("prompt_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32; - let completion_tokens = resp_json.get("usage").and_then(|u| u.get("completion_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32; - let total_tokens = resp_json.get("usage").and_then(|u| u.get("total_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32; - - Ok(ProviderResponse { - content: content_text, - reasoning_content: None, - tool_calls: None, - prompt_tokens, - completion_tokens, - reasoning_tokens: 0, - total_tokens, - cache_read_tokens: 0, - cache_write_tokens: 0, - model: request.model, - }) - } - - fn estimate_tokens(&self, request: &UnifiedRequest) -> Result { - Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request)) - } - - fn calculate_cost( - &self, - model: &str, - prompt_tokens: u32, - completion_tokens: u32, - cache_read_tokens: u32, - cache_write_tokens: u32, - registry: &crate::models::registry::ModelRegistry, - ) -> f64 { - helpers::calculate_cost_with_registry( - model, - prompt_tokens, - completion_tokens, - cache_read_tokens, - cache_write_tokens, - registry, - &self.pricing, - 0.15, - 0.60, - ) - } - - async fn 
chat_completion_stream( - &self, - request: UnifiedRequest, - ) -> Result>, AppError> { - // Allow proactive routing to Responses API based on heuristic - let model_lc = request.model.to_lowercase(); - if model_lc.contains("gpt-5") || model_lc.contains("codex") { - return self.chat_responses_stream(request).await; - } - - let messages_json = helpers::messages_to_openai_json(&request.messages).await?; - let mut body = helpers::build_openai_body(&request, messages_json, true); - - // Standard OpenAI cleanup - if let Some(obj) = body.as_object_mut() { - // stream_options.include_usage is supported by OpenAI for token usage in streaming - // Transition: Newer OpenAI models (o1, o3, gpt-5) require max_completion_tokens - if request.model.starts_with("o1-") || request.model.starts_with("o3-") || request.model.contains("gpt-5") { - if let Some(max_tokens) = obj.remove("max_tokens") { - obj.insert("max_completion_tokens".to_string(), max_tokens); - } - } - } - - let url = format!("{}/chat/completions", self.config.base_url); - let api_key = self.api_key.clone(); - let probe_client = self.client.clone(); - let probe_body = body.clone(); - let model = request.model.clone(); - - let es = reqwest_eventsource::EventSource::new( - self.client - .post(&url) - .header("Authorization", format!("Bearer {}", self.api_key)) - .json(&body), - ) - .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?; - - let stream = async_stream::try_stream! 
{ - let mut es = es; - while let Some(event) = es.next().await { - match event { - Ok(reqwest_eventsource::Event::Message(msg)) => { - if msg.data == "[DONE]" { - break; - } - - let chunk: serde_json::Value = serde_json::from_str(&msg.data) - .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?; - - if let Some(p_chunk) = helpers::parse_openai_stream_chunk(&chunk, &model, None) { - yield p_chunk?; - } - } - Ok(_) => continue, - Err(e) => { - // Attempt to probe for the actual error body - let probe_resp = probe_client - .post(&url) - .header("Authorization", format!("Bearer {}", api_key)) - .json(&probe_body) - .send() - .await; - - match probe_resp { - Ok(r) if !r.status().is_success() => { - let status = r.status(); - let error_body = r.text().await.unwrap_or_default(); - tracing::error!("OpenAI Stream Error Probe ({}): {}", status, error_body); - tracing::debug!("Offending OpenAI Request Body: {}", serde_json::to_string(&probe_body).unwrap_or_default()); - Err(AppError::ProviderError(format!("OpenAI API error ({}): {}", status, error_body)))?; - } - Ok(_) => { - // Probe returned success? This is unexpected if the original stream failed. - Err(AppError::ProviderError(format!("Stream error (probe returned 200): {}", e)))?; - } - Err(probe_err) => { - // Probe itself failed - tracing::error!("OpenAI Stream Error Probe failed: {}", probe_err); - Err(AppError::ProviderError(format!("Stream error (probe failed: {}): {}", probe_err, e)))?; - } - } - } - } - } - }; - - Ok(Box::pin(stream)) - } - - async fn chat_responses_stream( - &self, - request: UnifiedRequest, - ) -> Result>, AppError> { - // Build a structured input for the Responses API. 
- let messages_json = helpers::messages_to_openai_json(&request.messages).await?; - let mut input_parts = Vec::new(); - for m in &messages_json { - let mut role = m["role"].as_str().unwrap_or("user").to_string(); - // Newer models (gpt-5, o1) prefer "developer" over "system" - if role == "system" { - role = "developer".to_string(); - } - - let mut content = m.get("content").cloned().unwrap_or(serde_json::json!([])); - - // Map content types based on role for Responses API - if let Some(content_array) = content.as_array_mut() { - for part in content_array { - if let Some(part_obj) = part.as_object_mut() { - if let Some(t) = part_obj.get("type").and_then(|v| v.as_str()) { - match t { - "text" => { - let new_type = if role == "assistant" { "output_text" } else { "input_text" }; - part_obj.insert("type".to_string(), serde_json::json!(new_type)); - } - "image_url" => { - // Assistant typically doesn't have image_url in history this way, but for safety: - let new_type = if role == "assistant" { "output_image" } else { "input_image" }; - part_obj.insert("type".to_string(), serde_json::json!(new_type)); - if let Some(img_url) = part_obj.remove("image_url") { - part_obj.insert("image".to_string(), img_url); - } - } - _ => {} - } - } - } - } - } else if let Some(text) = content.as_str() { - let new_type = if role == "assistant" { "output_text" } else { "input_text" }; - content = serde_json::json!([{ "type": new_type, "text": text }]); - } - - input_parts.push(serde_json::json!({ - "role": role, - "content": content - })); - } - - let mut body = serde_json::json!({ - "model": request.model, - "input": input_parts, - "stream": true, - }); - - // Add standard parameters - if let Some(temp) = request.temperature { - body["temperature"] = serde_json::json!(temp); - } - - // Newer models (gpt-5, o1) in Responses API use max_output_tokens - if let Some(max_tokens) = request.max_tokens { - if request.model.contains("gpt-5") || request.model.starts_with("o1-") || 
request.model.starts_with("o3-") { - body["max_output_tokens"] = serde_json::json!(max_tokens); - } else { - body["max_tokens"] = serde_json::json!(max_tokens); - } - } - - let url = format!("{}/responses", self.config.base_url); - let api_key = self.api_key.clone(); - let model = request.model.clone(); - let probe_client = self.client.clone(); - let probe_body = body.clone(); - - let es = reqwest_eventsource::EventSource::new( - self.client - .post(&url) - .header("Authorization", format!("Bearer {}", api_key)) - .header("Accept", "text/event-stream") - .json(&body), - ) - .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource for Responses API: {}", e)))?; - - let stream = async_stream::try_stream! { - let mut es = es; - while let Some(event) = es.next().await { - match event { - Ok(reqwest_eventsource::Event::Message(msg)) => { - if msg.data == "[DONE]" { - break; - } - - let chunk: serde_json::Value = serde_json::from_str(&msg.data) - .map_err(|e| AppError::ProviderError(format!("Failed to parse Responses stream chunk: {}", e)))?; - - // Try standard OpenAI parsing first (choices/usage) - if let Some(p_chunk) = helpers::parse_openai_stream_chunk(&chunk, &model, None) { - yield p_chunk?; - } else { - // Responses API specific parsing for streaming - let mut content = String::new(); - let mut finish_reason = None; - - let event_type = chunk.get("type").and_then(|v| v.as_str()).unwrap_or(""); - - match event_type { - "response.output_text.delta" => { - if let Some(delta) = chunk.get("delta").and_then(|v| v.as_str()) { - content.push_str(delta); - } - } - "response.output_text.done" => { - if let Some(text) = chunk.get("text").and_then(|v| v.as_str()) { - // Some implementations send the full text at the end - // We usually prefer deltas, but if we haven't seen them, this is the fallback. - // However, if we're already yielding deltas, we might not want this. - // For now, let's just use it as a signal that we're done. 
- finish_reason = Some("stop".to_string()); - } - } - "response.done" => { - finish_reason = Some("stop".to_string()); - } - _ => { - // Fallback to older nested structure if present - if let Some(output) = chunk.get("output").and_then(|o| o.as_array()) { - for out in output { - if let Some(contents) = out.get("content").and_then(|c| c.as_array()) { - for item in contents { - if let Some(text) = item.get("text").and_then(|t| t.as_str()) { - content.push_str(text); - } else if let Some(delta) = item.get("delta").and_then(|d| d.get("text")).and_then(|t| t.as_str()) { - content.push_str(delta); - } - } - } - } - } - } - } - - if !content.is_empty() || finish_reason.is_some() { - yield ProviderStreamChunk { - content, - reasoning_content: None, - finish_reason, - tool_calls: None, - model: model.clone(), - usage: None, - }; - } - } - } - Ok(_) => continue, - Err(e) => { - // Attempt to probe for the actual error body - let probe_resp = probe_client - .post(&url) - .header("Authorization", format!("Bearer {}", api_key)) - .header("Accept", "application/json") // Ask for JSON during probe - .json(&probe_body) - .send() - .await; - - match probe_resp { - Ok(r) => { - let status = r.status(); - let body = r.text().await.unwrap_or_default(); - if status.is_success() { - tracing::warn!("Responses stream ended prematurely but probe returned 200 OK. 
Body: {}", body); - Err(AppError::ProviderError(format!("Responses stream ended (server sent 200 OK with body: {})", body)))?; - } else { - tracing::error!("OpenAI Responses Stream Error Probe ({}): {}", status, body); - Err(AppError::ProviderError(format!("OpenAI Responses API error ({}): {}", status, body)))?; - } - } - Err(probe_err) => { - tracing::error!("OpenAI Responses Stream Error Probe failed: {}", probe_err); - Err(AppError::ProviderError(format!("Responses stream error (probe failed: {}): {}", probe_err, e)))?; - } - } - } - } - } - }; - - Ok(Box::pin(stream)) - } -} diff --git a/src/rate_limiting/mod.rs b/src/rate_limiting/mod.rs deleted file mode 100644 index 20ab75c9..00000000 --- a/src/rate_limiting/mod.rs +++ /dev/null @@ -1,353 +0,0 @@ -//! Rate limiting and circuit breaking for LLM proxy -//! -//! This module provides: -//! 1. Per-client rate limiting using governor crate -//! 2. Provider circuit breaking to handle API failures -//! 3. Global rate limiting for overall system protection - -use anyhow::Result; -use governor::{Quota, RateLimiter, DefaultDirectRateLimiter}; -use std::collections::HashMap; -use std::num::NonZeroU32; -use std::sync::Arc; -use tokio::sync::RwLock; -use tracing::{info, warn}; - -type GovRateLimiter = DefaultDirectRateLimiter; - -/// Rate limiter configuration -#[derive(Debug, Clone)] -pub struct RateLimiterConfig { - /// Requests per minute per client - pub requests_per_minute: u32, - /// Burst size (maximum burst capacity) - pub burst_size: u32, - /// Global requests per minute (across all clients) - pub global_requests_per_minute: u32, -} - -impl Default for RateLimiterConfig { - fn default() -> Self { - Self { - requests_per_minute: 60, // 1 request per second per client - burst_size: 10, // Allow bursts of up to 10 requests - global_requests_per_minute: 600, // 10 requests per second globally - } - } -} - -/// Circuit breaker state -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum CircuitState { - Closed, // Normal 
operation
-    Open,     // Circuit is open, requests fail fast
-    HalfOpen, // Testing if service has recovered
-}
-
-/// Circuit breaker configuration
-#[derive(Debug, Clone)]
-pub struct CircuitBreakerConfig {
-    /// Failure threshold to open circuit
-    pub failure_threshold: u32,
-    /// Time window for failure counting (seconds)
-    pub failure_window_secs: u64,
-    /// Time to wait before trying half-open state (seconds)
-    pub reset_timeout_secs: u64,
-    /// Success threshold to close circuit
-    pub success_threshold: u32,
-}
-
-impl Default for CircuitBreakerConfig {
-    fn default() -> Self {
-        Self {
-            failure_threshold: 5,    // 5 failures
-            failure_window_secs: 60, // within 60 seconds
-            reset_timeout_secs: 30,  // wait 30 seconds before half-open
-            success_threshold: 3,    // 3 successes to close circuit
-        }
-    }
-}
-
-
-
-/// Circuit breaker for a provider
-#[derive(Debug)]
-pub struct ProviderCircuitBreaker {
-    state: CircuitState,
-    failure_count: u32,
-    success_count: u32,
-    last_failure_time: Option<std::time::Instant>,
-    last_state_change: std::time::Instant,
-    config: CircuitBreakerConfig,
-}
-
-impl ProviderCircuitBreaker {
-    pub fn new(config: CircuitBreakerConfig) -> Self {
-        Self {
-            state: CircuitState::Closed,
-            failure_count: 0,
-            success_count: 0,
-            last_failure_time: None,
-            last_state_change: std::time::Instant::now(),
-            config,
-        }
-    }
-
-    /// Check if request is allowed
-    pub fn allow_request(&mut self) -> bool {
-        match self.state {
-            CircuitState::Closed => true,
-            CircuitState::Open => {
-                // Check if reset timeout has passed
-                let elapsed = self.last_state_change.elapsed();
-                if elapsed.as_secs() >= self.config.reset_timeout_secs {
-                    self.state = CircuitState::HalfOpen;
-                    self.last_state_change = std::time::Instant::now();
-                    info!("Circuit breaker transitioning to half-open state");
-                    true
-                } else {
-                    false
-                }
-            }
-            CircuitState::HalfOpen => true,
-        }
-    }
-
-    /// Record a successful request
-    pub fn record_success(&mut self) {
-        match self.state {
-            CircuitState::Closed => {
-                // Reset failure count on success
-                self.failure_count = 0;
-                self.last_failure_time = None;
-            }
-            CircuitState::HalfOpen => {
-                self.success_count += 1;
-                if self.success_count >= self.config.success_threshold {
-                    self.state = CircuitState::Closed;
-                    self.success_count = 0;
-                    self.failure_count = 0;
-                    self.last_state_change = std::time::Instant::now();
-                    info!("Circuit breaker closed after successful requests");
-                }
-            }
-            CircuitState::Open => {
-                // Should not happen, but handle gracefully
-            }
-        }
-    }
-
-    /// Record a failed request
-    pub fn record_failure(&mut self) {
-        let now = std::time::Instant::now();
-
-        // Check if failure window has expired
-        if let Some(last_failure) = self.last_failure_time
-            && now.duration_since(last_failure).as_secs() > self.config.failure_window_secs
-        {
-            // Reset failure count if window expired
-            self.failure_count = 0;
-        }
-
-        self.failure_count += 1;
-        self.last_failure_time = Some(now);
-
-        if self.failure_count >= self.config.failure_threshold && self.state == CircuitState::Closed {
-            self.state = CircuitState::Open;
-            self.last_state_change = now;
-            warn!("Circuit breaker opened due to {} failures", self.failure_count);
-        } else if self.state == CircuitState::HalfOpen {
-            // Failure in half-open state, go back to open
-            self.state = CircuitState::Open;
-            self.success_count = 0;
-            self.last_state_change = now;
-            warn!("Circuit breaker re-opened after failure in half-open state");
-        }
-    }
-
-    /// Get current state
-    pub fn state(&self) -> CircuitState {
-        self.state
-    }
-}
-
-/// Rate limiting and circuit breaking manager
-#[derive(Debug)]
-pub struct RateLimitManager {
-    client_buckets: Arc<RwLock<HashMap<String, governor::DefaultDirectRateLimiter>>>,
-    global_bucket: Arc<governor::DefaultDirectRateLimiter>,
-    circuit_breakers: Arc<RwLock<HashMap<String, ProviderCircuitBreaker>>>,
-    config: RateLimiterConfig,
-    circuit_config: CircuitBreakerConfig,
-}
-
-impl RateLimitManager {
-    pub fn new(config: RateLimiterConfig, circuit_config: CircuitBreakerConfig) -> Self {
-        // Create global rate limiter quota
-        // Use a much larger burst size for the global bucket to
-        // handle concurrent dashboard load
-        let global_burst = config.global_requests_per_minute / 6; // e.g., 100 for 600 req/min
-        let global_quota = Quota::per_minute(
-            NonZeroU32::new(config.global_requests_per_minute).expect("global_requests_per_minute must be positive")
-        )
-        .allow_burst(NonZeroU32::new(global_burst).expect("global_burst must be positive"));
-        let global_bucket = RateLimiter::direct(global_quota);
-
-        Self {
-            client_buckets: Arc::new(RwLock::new(HashMap::new())),
-            global_bucket: Arc::new(global_bucket),
-            circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
-            config,
-            circuit_config,
-        }
-    }
-
-    /// Check if a client request is allowed
-    pub async fn check_client_request(&self, client_id: &str) -> Result<bool, AppError> {
-        // Check global rate limit first (1 token per request)
-        if self.global_bucket.check().is_err() {
-            warn!("Global rate limit exceeded");
-            return Ok(false);
-        }
-
-        // Check client-specific rate limit
-        let mut buckets = self.client_buckets.write().await;
-        let bucket = buckets.entry(client_id.to_string()).or_insert_with(|| {
-            let quota = Quota::per_minute(
-                NonZeroU32::new(self.config.requests_per_minute).expect("requests_per_minute must be positive")
-            )
-            .allow_burst(NonZeroU32::new(self.config.burst_size).expect("burst_size must be positive"));
-            RateLimiter::direct(quota)
-        });
-
-        Ok(bucket.check().is_ok())
-    }
-
-    /// Check if provider requests are allowed (circuit breaker)
-    pub async fn check_provider_request(&self, provider_name: &str) -> Result<bool, AppError> {
-        let mut breakers = self.circuit_breakers.write().await;
-        let breaker = breakers
-            .entry(provider_name.to_string())
-            .or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone()));
-
-        Ok(breaker.allow_request())
-    }
-
-    /// Record provider success
-    pub async fn record_provider_success(&self, provider_name: &str) {
-        let mut breakers = self.circuit_breakers.write().await;
-        if let Some(breaker) = breakers.get_mut(provider_name) {
-            breaker.record_success();
-        }
-    }
-
-    /// Record provider failure
-    pub async fn record_provider_failure(&self, provider_name: &str) {
-        let mut breakers = self.circuit_breakers.write().await;
-        let breaker = breakers
-            .entry(provider_name.to_string())
-            .or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone()));
-
-        breaker.record_failure();
-    }
-
-    /// Get provider circuit state
-    pub async fn get_provider_state(&self, provider_name: &str) -> CircuitState {
-        let breakers = self.circuit_breakers.read().await;
-        breakers
-            .get(provider_name)
-            .map(|b| b.state())
-            .unwrap_or(CircuitState::Closed)
-    }
-}
-
-/// Axum middleware for rate limiting
-pub mod middleware {
-    use super::*;
-    use crate::errors::AppError;
-    use crate::state::AppState;
-    use crate::auth::AuthInfo;
-    use axum::{
-        extract::{Request, State},
-        middleware::Next,
-        response::Response,
-    };
-    use sqlx;
-
-    /// Rate limiting middleware
-    pub async fn rate_limit_middleware(
-        State(state): State<AppState>,
-        mut request: Request,
-        next: Next,
-    ) -> Result<Response, AppError> {
-        // Extract token synchronously from headers (avoids holding &Request across await)
-        let token = extract_bearer_token(&request);
-
-        // Resolve client_id and populate AuthInfo: DB token lookup, then prefix fallback
-        let auth_info = resolve_auth_info(token, &state).await;
-        let client_id = auth_info.client_id.clone();
-
-        // Check rate limits
-        if !state.rate_limit_manager.check_client_request(&client_id).await? {
-            return Err(AppError::RateLimitError("Rate limit exceeded".to_string()));
-        }
-
-        // Store AuthInfo in request extensions for extractors and downstream handlers
-        request.extensions_mut().insert(auth_info);
-
-        Ok(next.run(request).await)
-    }
-
-    /// Synchronously extract bearer token from request headers
-    fn extract_bearer_token(request: &Request) -> Option<String> {
-        request.headers().get("Authorization")
-            .and_then(|v| v.to_str().ok())
-            .and_then(|s| s.strip_prefix("Bearer "))
-            .map(|t| t.to_string())
-    }
-
-    /// Resolve auth info: try DB token first, then fall back to token-prefix derivation
-    async fn resolve_auth_info(token: Option<String>, state: &AppState) -> AuthInfo {
-        if let Some(token) = token {
-            // Try DB token lookup first
-            match sqlx::query_scalar::<_, String>(
-                "UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = TRUE RETURNING client_id",
-            )
-            .bind(&token)
-            .fetch_optional(&state.db_pool)
-            .await
-            {
-                Ok(Some(cid)) => {
-                    return AuthInfo {
-                        token,
-                        client_id: cid,
-                    };
-                }
-                Err(e) => {
-                    warn!("DB error during token lookup: {}", e);
-                }
-                _ => {}
-            }
-
-            // Fallback to token-prefix derivation (env tokens / permissive mode)
-            let client_id = format!("client_{}", &token[..8.min(token.len())]);
-            return AuthInfo { token, client_id };
-        }
-
-        // No token — anonymous
-        AuthInfo {
-            token: String::new(),
-            client_id: "anonymous".to_string(),
-        }
-    }
-
-    /// Circuit breaker middleware for provider requests
-    pub async fn circuit_breaker_middleware(provider_name: &str, state: &AppState) -> Result<(), AppError> {
-        if !state.rate_limit_manager.check_provider_request(provider_name).await?
-        {
-            return Err(AppError::ProviderError(format!(
-                "Provider {} is currently unavailable (circuit breaker open)",
-                provider_name
-            )));
-        }
-        Ok(())
-    }
-}
diff --git a/src/server/mod.rs b/src/server/mod.rs
deleted file mode 100644
index 11a692b8..00000000
--- a/src/server/mod.rs
+++ /dev/null
@@ -1,482 +0,0 @@
-use axum::{
-    Json, Router,
-    extract::State,
-    response::IntoResponse,
-    response::sse::{Event, Sse},
-    routing::{get, post},
-};
-use axum::http::{header, HeaderValue};
-use tower_http::{
-    limit::RequestBodyLimitLayer,
-    set_header::SetResponseHeaderLayer,
-};
-
-use futures::StreamExt;
-use std::sync::Arc;
-use uuid::Uuid;
-use tracing::{info, warn};
-
-use crate::{
-    auth::AuthenticatedClient,
-    errors::AppError,
-    models::{
-        ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage,
-        ChatStreamChoice, ChatStreamDelta, Usage,
-    },
-    rate_limiting,
-    state::AppState,
-};
-
-pub fn router(state: AppState) -> Router {
-    // Security headers
-    let csp_header: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
-        header::CONTENT_SECURITY_POLICY,
-        "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws:;"
-            .parse()
-            .unwrap(),
-    );
-    let x_frame_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
-        header::X_FRAME_OPTIONS,
-        "DENY".parse().unwrap(),
-    );
-    let x_content_type_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
-        header::X_CONTENT_TYPE_OPTIONS,
-        "nosniff".parse().unwrap(),
-    );
-    let strict_transport_security: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
-        header::STRICT_TRANSPORT_SECURITY,
-        "max-age=31536000; includeSubDomains".parse().unwrap(),
-    );
-
-    Router::new()
-        .route("/v1/chat/completions", post(chat_completions))
-        .route("/v1/models", get(list_models))
-        .layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit
-        .layer(csp_header)
-        .layer(x_frame_options)
-        .layer(x_content_type_options)
-        .layer(strict_transport_security)
-        .layer(axum::middleware::from_fn_with_state(
-            state.clone(),
-            rate_limiting::middleware::rate_limit_middleware,
-        ))
-        .with_state(state)
-}
-
-
-/// GET /v1/models — OpenAI-compatible model listing.
-/// Returns all models from enabled providers so clients like Open WebUI can
-/// discover which models are available through the proxy.
-async fn list_models(
-    State(state): State<AppState>,
-    _auth: AuthenticatedClient,
-) -> Result<Json<serde_json::Value>, AppError> {
-    let registry = &state.model_registry;
-    let providers = state.provider_manager.get_all_providers().await;
-
-    let mut models = Vec::new();
-
-    for provider in &providers {
-        let provider_name = provider.name();
-
-        // Map internal provider names to registry provider IDs
-        let registry_key = match provider_name {
-            "gemini" => "google",
-            "grok" => "xai",
-            _ => provider_name,
-        };
-
-        // Find this provider's models in the registry
-        if let Some(provider_info) = registry.providers.get(registry_key) {
-            for (model_id, meta) in &provider_info.models {
-                // Skip disabled models via the config cache
-                if let Some(cfg) = state.model_config_cache.get(model_id).await {
-                    if !cfg.enabled {
-                        continue;
-                    }
-                }
-
-                models.push(serde_json::json!({
-                    "id": model_id,
-                    "object": "model",
-                    "created": 0,
-                    "owned_by": provider_name,
-                    "name": meta.name,
-                }));
-            }
-        }
-
-        // For Ollama, models are configured in the TOML, not the registry
-        if provider_name == "ollama" {
-            for model_id in &state.config.providers.ollama.models {
-                models.push(serde_json::json!({
-                    "id": model_id,
-                    "object": "model",
-                    "created": 0,
-                    "owned_by": "ollama",
-                }));
-            }
-        }
-    }
-
-    Ok(Json(serde_json::json!({
-        "object": "list",
-        "data": models
-    })))
-}
-
-async fn get_model_cost(
-    model: &str,
-    prompt_tokens: u32,
-    completion_tokens: u32,
-    cache_read_tokens: u32,
-    cache_write_tokens: u32,
-    provider: &Arc<dyn crate::providers::Provider>,
-    state: &AppState,
-) -> f64 {
-    // Check in-memory cache for cost overrides (no SQLite hit)
-    if let Some(cached) = state.model_config_cache.get(model).await {
-        if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
-            // Manual overrides logic: if cache rates are provided, use cache-aware formula.
-            // Formula: (non_cached_prompt * input_rate) + (cache_read * read_rate) + (cache_write * write_rate) + (completion * output_rate)
-            let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
-            let mut total = (non_cached_prompt as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0);
-
-            if let Some(cr) = cached.cache_read_cost_per_m {
-                total += cache_read_tokens as f64 * cr / 1_000_000.0;
-            } else {
-                // No manual cache_read rate — charge cached tokens at full input rate (backwards compatibility)
-                total += cache_read_tokens as f64 * p / 1_000_000.0;
-            }
-
-            if let Some(cw) = cached.cache_write_cost_per_m {
-                total += cache_write_tokens as f64 * cw / 1_000_000.0;
-            }
-
-            return total;
-        }
-    }
-
-    // Fallback to provider's registry-based calculation (cache-aware)
-    provider.calculate_cost(model, prompt_tokens, completion_tokens, cache_read_tokens, cache_write_tokens, &state.model_registry)
-}
-
-async fn chat_completions(
-    State(state): State<AppState>,
-    auth: AuthenticatedClient,
-    Json(mut request): Json<ChatCompletionRequest>,
-) -> Result<axum::response::Response, AppError> {
-    let client_id = auth.client_id.clone();
-    let token = auth.token.clone();
-
-    // Verify token if env tokens are configured
-    if !state.auth_tokens.is_empty() && !state.auth_tokens.contains(&token) {
-        // If not in env tokens, check if it was a DB token (client_id wouldn't be client_XXXX prefix)
-        if client_id.starts_with("client_") {
-            return Err(AppError::AuthError("Invalid authentication token".to_string()));
-        }
-    }
-
-    let start_time = std::time::Instant::now();
-    let model = request.model.clone();
-
-    info!("Chat completion request from client {} for model {}", client_id, model);
-
-    // Check if model is enabled via in-memory cache (no SQLite hit)
-    let cached_config = state.model_config_cache.get(&model).await;
-
-    let (model_enabled, model_mapping) = match cached_config {
-        Some(cfg) => (cfg.enabled, cfg.mapping),
-        None => (true, None),
-    };
-
-    if !model_enabled {
-        return Err(AppError::ValidationError(format!(
-            "Model {} is currently disabled",
-            model
-        )));
-    }
-
-    // Apply mapping if present
-    if let Some(target_model) = model_mapping {
-        info!("Mapping model {} to {}", model, target_model);
-        request.model = target_model;
-    }
-
-    // Find appropriate provider for the model
-    let provider = state
-        .provider_manager
-        .get_provider_for_model(&request.model)
-        .await
-        .ok_or_else(|| AppError::ProviderError(format!("No provider found for model: {}", request.model)))?;
-
-    let provider_name = provider.name().to_string();
-
-    // Check circuit breaker for this provider
-    rate_limiting::middleware::circuit_breaker_middleware(&provider_name, &state).await?;
-
-    // Convert to unified request format
-    let mut unified_request =
-        crate::models::UnifiedRequest::try_from(request).map_err(|e| AppError::ValidationError(e.to_string()))?;
-
-    // Set client_id from authentication
-    unified_request.client_id = client_id.clone();
-
-    // Hydrate images if present
-    if unified_request.has_images {
-        unified_request
-            .hydrate_images()
-            .await
-            .map_err(|e| AppError::ValidationError(format!("Failed to process images: {}", e)))?;
-    }
-
-    let has_images = unified_request.has_images;
-
-    // Measure proxy overhead (time spent before sending to upstream provider)
-    let proxy_overhead = start_time.elapsed();
-
-    // Check if streaming is requested
-    if unified_request.stream {
-        // Estimate prompt tokens for logging later
-        let prompt_tokens = crate::utils::tokens::estimate_request_tokens(&model, &unified_request);
-
-        // Handle streaming response
-        // Allow provider-specific routing for streaming too
-        let use_responses = provider.name() == "openai" &&
-            crate::utils::registry::model_prefers_responses(&state.model_registry, &unified_request.model);
-
-        let stream_result = if use_responses {
-            provider.chat_responses_stream(unified_request).await
-        } else {
-            provider.chat_completion_stream(unified_request).await
-        };
-
-        match stream_result {
-            Ok(stream) => {
-                // Record provider success
-                state.rate_limit_manager.record_provider_success(&provider_name).await;
-
-                info!(
-                    "Streaming started for {} (proxy overhead: {}ms)",
-                    model,
-                    proxy_overhead.as_millis()
-                );
-
-                // Wrap with AggregatingStream for token counting and database logging
-                let aggregating_stream = crate::utils::streaming::AggregatingStream::new(
-                    stream,
-                    crate::utils::streaming::StreamConfig {
-                        client_id: client_id.clone(),
-                        provider: provider.clone(),
-                        model: model.clone(),
-                        prompt_tokens,
-                        has_images,
-                        logger: state.request_logger.clone(),
-                        model_registry: state.model_registry.clone(),
-                        model_config_cache: state.model_config_cache.clone(),
-                    },
-                );
-
-                // Create SSE stream - simpler approach that works
-                let stream_id = format!("chatcmpl-{}", Uuid::new_v4());
-                let stream_created = chrono::Utc::now().timestamp() as u64;
-                let stream_id_sse = stream_id.clone();
-
-                // Build stream that yields events wrapped in Result
-                let stream = async_stream::stream! {
-                    let mut aggregator = Box::pin(aggregating_stream);
-                    let mut first_chunk = true;
-
-                    while let Some(chunk_result) = aggregator.next().await {
-                        match chunk_result {
-                            Ok(chunk) => {
-                                let role = if first_chunk {
-                                    first_chunk = false;
-                                    Some("assistant".to_string())
-                                } else {
-                                    None
-                                };
-
-                                let response = ChatCompletionStreamResponse {
-                                    id: stream_id_sse.clone(),
-                                    object: "chat.completion.chunk".to_string(),
-                                    created: stream_created,
-                                    model: chunk.model.clone(),
-                                    choices: vec![ChatStreamChoice {
-                                        index: 0,
-                                        delta: ChatStreamDelta {
-                                            role,
-                                            content: Some(chunk.content),
-                                            reasoning_content: chunk.reasoning_content,
-                                            tool_calls: chunk.tool_calls,
-                                        },
-                                        finish_reason: chunk.finish_reason,
-                                    }],
-                                    usage: chunk.usage.as_ref().map(|u| crate::models::Usage {
-                                        prompt_tokens: u.prompt_tokens,
-                                        completion_tokens: u.completion_tokens,
-                                        total_tokens: u.total_tokens,
-                                        reasoning_tokens: if u.reasoning_tokens > 0 { Some(u.reasoning_tokens) } else { None },
-                                        cache_read_tokens: if u.cache_read_tokens > 0 { Some(u.cache_read_tokens) } else { None },
-                                        cache_write_tokens: if u.cache_write_tokens > 0 { Some(u.cache_write_tokens) } else { None },
-                                    }),
-                                };
-
-                                // Use axum's Event directly, wrap in Ok
-                                match Event::default().json_data(response) {
-                                    Ok(event) => yield Ok::<_, crate::errors::AppError>(event),
-                                    Err(e) => {
-                                        warn!("Failed to serialize SSE: {}", e);
-                                    }
-                                }
-                            }
-                            Err(e) => {
-                                warn!("Stream error: {}", e);
-                            }
-                        }
-                    }
-
-                    // Yield [DONE] at the end
-                    yield Ok::<_, crate::errors::AppError>(Event::default().data("[DONE]"));
-                };
-
-                Ok(Sse::new(stream).into_response())
-            }
-            Err(e) => {
-                // Record provider failure
-                state.rate_limit_manager.record_provider_failure(&provider_name).await;
-
-                // Log failed request
-                let duration = start_time.elapsed();
-                warn!("Streaming request failed after {:?}: {}", duration, e);
-
-                Err(e)
-            }
-        }
-    } else {
-        // Handle non-streaming response
-        // Allow provider-specific routing: for OpenAI, some models prefer the
-        // Responses API (/v1/responses). Use the model registry heuristic to
-        // choose chat_responses vs chat_completion automatically.
-        let use_responses = provider.name() == "openai"
-            && crate::utils::registry::model_prefers_responses(&state.model_registry, &unified_request.model);
-
-        let result = if use_responses {
-            provider.chat_responses(unified_request).await
-        } else {
-            provider.chat_completion(unified_request).await
-        };
-
-        match result {
-            Ok(response) => {
-                // Record provider success
-                state.rate_limit_manager.record_provider_success(&provider_name).await;
-
-                let duration = start_time.elapsed();
-                let cost = get_model_cost(
-                    &response.model,
-                    response.prompt_tokens,
-                    response.completion_tokens,
-                    response.cache_read_tokens,
-                    response.cache_write_tokens,
-                    &provider,
-                    &state,
-                )
-                .await;
-                // Log request to database
-                state.request_logger.log_request(crate::logging::RequestLog {
-                    timestamp: chrono::Utc::now(),
-                    client_id: client_id.clone(),
-                    provider: provider_name.clone(),
-                    model: response.model.clone(),
-                    prompt_tokens: response.prompt_tokens,
-                    completion_tokens: response.completion_tokens,
-                    reasoning_tokens: response.reasoning_tokens,
-                    total_tokens: response.total_tokens,
-                    cache_read_tokens: response.cache_read_tokens,
-                    cache_write_tokens: response.cache_write_tokens,
-                    cost,
-                    has_images,
-                    status: "success".to_string(),
-                    error_message: None,
-                    duration_ms: duration.as_millis() as u64,
-                });
-
-                // Convert ProviderResponse to ChatCompletionResponse
-                let finish_reason = if response.tool_calls.is_some() {
-                    "tool_calls".to_string()
-                } else {
-                    "stop".to_string()
-                };
-
-                let chat_response = ChatCompletionResponse {
-                    id: format!("chatcmpl-{}", Uuid::new_v4()),
-                    object: "chat.completion".to_string(),
-                    created: chrono::Utc::now().timestamp() as u64,
-                    model: response.model,
-                    choices: vec![ChatChoice {
-                        index: 0,
-                        message: ChatMessage {
-                            role: "assistant".to_string(),
-                            content: crate::models::MessageContent::Text {
-                                content: response.content,
-                            },
-                            reasoning_content: response.reasoning_content,
-                            tool_calls: response.tool_calls,
-                            name: None,
-                            tool_call_id: None,
-                        },
-                        finish_reason: Some(finish_reason),
-                    }],
-                    usage: Some(Usage {
-                        prompt_tokens: response.prompt_tokens,
-                        completion_tokens: response.completion_tokens,
-                        total_tokens: response.total_tokens,
-                        reasoning_tokens: if response.reasoning_tokens > 0 { Some(response.reasoning_tokens) } else { None },
-                        cache_read_tokens: if response.cache_read_tokens > 0 { Some(response.cache_read_tokens) } else { None },
-                        cache_write_tokens: if response.cache_write_tokens > 0 { Some(response.cache_write_tokens) } else { None },
-                    }),
-                };
-
-                // Log successful request with proxy overhead breakdown
-                let upstream_ms = duration.as_millis() as u64 - proxy_overhead.as_millis() as u64;
-                info!(
-                    "Request completed in {:?} (proxy: {}ms, upstream: {}ms)",
-                    duration,
-                    proxy_overhead.as_millis(),
-                    upstream_ms
-                );
-
-                Ok(Json(chat_response).into_response())
-            }
-            Err(e) => {
-                // Record provider failure
-                state.rate_limit_manager.record_provider_failure(&provider_name).await;
-
-                // Log failed request to database
-                let duration = start_time.elapsed();
-                state.request_logger.log_request(crate::logging::RequestLog {
-                    timestamp: chrono::Utc::now(),
-                    client_id: client_id.clone(),
-                    provider: provider_name.clone(),
-                    model: model.clone(),
-                    prompt_tokens: 0,
-                    completion_tokens: 0,
-                    reasoning_tokens: 0,
-                    total_tokens: 0,
-                    cache_read_tokens: 0,
-                    cache_write_tokens: 0,
-                    cost: 0.0,
-                    has_images: false,
-                    status: "error".to_string(),
-                    error_message: Some(e.to_string()),
-                    duration_ms: duration.as_millis() as u64,
-                });
-
-                warn!("Request failed after {:?}: {}", duration, e);
-
                Err(e)
-            }
-        }
-    }
-}
diff --git a/src/state/mod.rs b/src/state/mod.rs
deleted file mode 100644
index 690b528b..00000000
--- a/src/state/mod.rs
+++ /dev/null
@@ -1,133 +0,0 @@
-use std::collections::HashMap;
-use std::sync::Arc;
-use tokio::sync::{broadcast, RwLock};
-use tracing::warn; - -use crate::{ - client::ClientManager, config::AppConfig, database::DbPool, logging::RequestLogger, - models::registry::ModelRegistry, providers::ProviderManager, rate_limiting::RateLimitManager, -}; - -/// Cached model configuration entry -#[derive(Debug, Clone)] -pub struct CachedModelConfig { - pub enabled: bool, - pub mapping: Option, - pub prompt_cost_per_m: Option, - pub completion_cost_per_m: Option, - pub cache_read_cost_per_m: Option, - pub cache_write_cost_per_m: Option, -} - -/// In-memory cache for model_configs table. -/// Refreshes periodically to avoid hitting SQLite on every request. -#[derive(Clone)] -pub struct ModelConfigCache { - cache: Arc>>, - db_pool: DbPool, -} - -impl ModelConfigCache { - pub fn new(db_pool: DbPool) -> Self { - Self { - cache: Arc::new(RwLock::new(HashMap::new())), - db_pool, - } - } - - /// Load all model configs from the database into cache - pub async fn refresh(&self) { - match sqlx::query_as::<_, (String, bool, Option, Option, Option, Option, Option)>( - "SELECT id, enabled, mapping, prompt_cost_per_m, completion_cost_per_m, cache_read_cost_per_m, cache_write_cost_per_m FROM model_configs", - ) - .fetch_all(&self.db_pool) - .await - { - Ok(rows) => { - let mut map = HashMap::with_capacity(rows.len()); - for (id, enabled, mapping, prompt_cost, completion_cost, cache_read_cost, cache_write_cost) in rows { - map.insert( - id, - CachedModelConfig { - enabled, - mapping, - prompt_cost_per_m: prompt_cost, - completion_cost_per_m: completion_cost, - cache_read_cost_per_m: cache_read_cost, - cache_write_cost_per_m: cache_write_cost, - }, - ); - } - *self.cache.write().await = map; - } - Err(e) => { - warn!("Failed to refresh model config cache: {}", e); - } - } - } - - /// Get a cached model config. Returns None if not in cache (model is unconfigured). 
- pub async fn get(&self, model: &str) -> Option { - self.cache.read().await.get(model).cloned() - } - - /// Invalidate cache — call this after dashboard writes to model_configs - pub async fn invalidate(&self) { - self.refresh().await; - } - - /// Start a background task that refreshes the cache every `interval` seconds - pub fn start_refresh_task(self, interval_secs: u64) { - tokio::spawn(async move { - let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs)); - loop { - interval.tick().await; - self.refresh().await; - } - }); - } -} - -/// Shared application state -#[derive(Clone)] -pub struct AppState { - pub config: Arc, - pub provider_manager: ProviderManager, - pub db_pool: DbPool, - pub rate_limit_manager: Arc, - pub client_manager: Arc, - pub request_logger: Arc, - pub model_registry: Arc, - pub model_config_cache: ModelConfigCache, - pub dashboard_tx: broadcast::Sender, - pub auth_tokens: Vec, -} - -impl AppState { - pub fn new( - config: Arc, - provider_manager: ProviderManager, - db_pool: DbPool, - rate_limit_manager: RateLimitManager, - model_registry: ModelRegistry, - auth_tokens: Vec, - ) -> Self { - let client_manager = Arc::new(ClientManager::new(db_pool.clone())); - let (dashboard_tx, _) = broadcast::channel(100); - let request_logger = Arc::new(RequestLogger::new(db_pool.clone(), dashboard_tx.clone())); - let model_config_cache = ModelConfigCache::new(db_pool.clone()); - - Self { - config, - provider_manager, - db_pool, - rate_limit_manager: Arc::new(rate_limit_manager), - client_manager, - request_logger, - model_registry: Arc::new(model_registry), - model_config_cache, - dashboard_tx, - auth_tokens, - } - } -} diff --git a/src/utils/crypto.rs b/src/utils/crypto.rs deleted file mode 100644 index 7f7740c2..00000000 --- a/src/utils/crypto.rs +++ /dev/null @@ -1,171 +0,0 @@ -use aes_gcm::{ - aead::{Aead, AeadCore, KeyInit, OsRng}, - Aes256Gcm, Key, Nonce, -}; -use anyhow::{anyhow, Context, Result}; -use 
base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; -use std::env; -use std::sync::OnceLock; - -static RAW_KEY: OnceLock<[u8; 32]> = OnceLock::new(); - -/// Initialize the encryption key from a hex or base64 encoded string. -/// Must be called before any encryption/decryption operations. -/// Returns error if the key is invalid or already initialized with a different key. -pub fn init_with_key(key_str: &str) -> Result<()> { - let key_bytes = hex::decode(key_str) - .or_else(|_| BASE64.decode(key_str)) - .context("Encryption key must be hex or base64 encoded")?; - if key_bytes.len() != 32 { - anyhow::bail!( - "Encryption key must be 32 bytes (256 bits), got {} bytes", - key_bytes.len() - ); - } - let key_array: [u8; 32] = key_bytes.try_into().unwrap(); // safe due to length check - // Check if already initialized with same key - if let Some(existing) = RAW_KEY.get() { - if existing == &key_array { - // Same key already initialized, okay - return Ok(()); - } else { - anyhow::bail!("Encryption key already initialized with a different key"); - } - } - // Store raw key bytes - RAW_KEY - .set(key_array) - .map_err(|_| anyhow::anyhow!("Encryption key already initialized"))?; - Ok(()) -} - -/// Initialize the encryption key from the environment variable `LLM_PROXY__ENCRYPTION_KEY`. -/// Must be called before any encryption/decryption operations. -/// Panics if the environment variable is missing or invalid. -pub fn init_from_env() -> Result<()> { - let key_str = - env::var("LLM_PROXY__ENCRYPTION_KEY").context("LLM_PROXY__ENCRYPTION_KEY environment variable not set")?; - init_with_key(&key_str) -} - -/// Get the encryption key bytes, panicking if not initialized. -fn get_key() -> &'static [u8; 32] { - RAW_KEY - .get() - .expect("Encryption key not initialized. Call crypto::init_with_key() or crypto::init_from_env() first.") -} - -/// Encrypt a plaintext string and return a base64-encoded ciphertext (nonce || ciphertext || tag). 
-pub fn encrypt(plaintext: &str) -> Result<String> {
-    let key = Key::<Aes256Gcm>::from_slice(get_key());
-    let cipher = Aes256Gcm::new(key);
-    let nonce = Aes256Gcm::generate_nonce(&mut OsRng); // 12 bytes
-    let ciphertext = cipher
-        .encrypt(&nonce, plaintext.as_bytes())
-        .map_err(|e| anyhow!("Encryption failed: {}", e))?;
-    // Combine nonce and ciphertext (ciphertext already includes tag)
-    let mut combined = Vec::with_capacity(nonce.len() + ciphertext.len());
-    combined.extend_from_slice(&nonce);
-    combined.extend_from_slice(&ciphertext);
-    Ok(BASE64.encode(combined))
-}
-
-/// Decrypt a base64-encoded ciphertext (nonce || ciphertext || tag) to a plaintext string.
-pub fn decrypt(ciphertext_b64: &str) -> Result<String> {
-    let key = Key::<Aes256Gcm>::from_slice(get_key());
-    let cipher = Aes256Gcm::new(key);
-    let combined = BASE64.decode(ciphertext_b64).context("Invalid base64 ciphertext")?;
-    if combined.len() < 12 {
-        anyhow::bail!("Ciphertext too short");
-    }
-    let (nonce_bytes, ciphertext_and_tag) = combined.split_at(12);
-    let nonce = Nonce::from_slice(nonce_bytes);
-    let plaintext_bytes = cipher
-        .decrypt(nonce, ciphertext_and_tag)
-        .map_err(|e| anyhow!("Decryption failed (invalid key or corrupted ciphertext): {}", e))?;
-    String::from_utf8(plaintext_bytes).context("Decrypted bytes are not valid UTF-8")
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    const TEST_KEY_HEX: &str = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f";
-
-    #[test]
-    fn test_encrypt_decrypt() {
-        init_with_key(TEST_KEY_HEX).unwrap();
-        let plaintext = "super secret api key";
-        let ciphertext = encrypt(plaintext).unwrap();
-        assert_ne!(ciphertext, plaintext);
-        let decrypted = decrypt(&ciphertext).unwrap();
-        assert_eq!(decrypted, plaintext);
-    }
-
-    #[test]
-    fn test_different_inputs_produce_different_ciphertexts() {
-        init_with_key(TEST_KEY_HEX).unwrap();
-        let plaintext = "same";
-        let cipher1 = encrypt(plaintext).unwrap();
-        let cipher2 = encrypt(plaintext).unwrap();
-        assert_ne!(cipher1, cipher2, "Nonce should make ciphertexts differ");
-        assert_eq!(decrypt(&cipher1).unwrap(), plaintext);
-        assert_eq!(decrypt(&cipher2).unwrap(), plaintext);
-    }
-
-    #[test]
-    fn test_invalid_key_length() {
-        let result = init_with_key("tooshort");
-        assert!(result.is_err());
-    }
-
-    #[test]
-    fn test_init_from_env() {
-        unsafe { std::env::set_var("LLM_PROXY__ENCRYPTION_KEY", TEST_KEY_HEX) };
-        let result = init_from_env();
-        assert!(result.is_ok());
-        // Ensure encryption works
-        let ciphertext = encrypt("test").unwrap();
-        let decrypted = decrypt(&ciphertext).unwrap();
-        assert_eq!(decrypted, "test");
-    }
-
-    #[test]
-    fn test_missing_env_key() {
-        unsafe { std::env::remove_var("LLM_PROXY__ENCRYPTION_KEY") };
-        let result = init_from_env();
-        assert!(result.is_err());
-    }
-
-    #[test]
-    fn test_key_hex_and_base64() {
-        // Hex key works
-        init_with_key(TEST_KEY_HEX).unwrap();
-        // Base64 key (same bytes encoded as base64)
-        let base64_key = BASE64.encode(hex::decode(TEST_KEY_HEX).unwrap());
-        // Re-initialization with same key (different encoding) is allowed
-        let result = init_with_key(&base64_key);
-        assert!(result.is_ok());
-        // Encryption should still work
-        let ciphertext = encrypt("test").unwrap();
-        let decrypted = decrypt(&ciphertext).unwrap();
-        assert_eq!(decrypted, "test");
-    }
-
-    #[test]
-    #[ignore] // conflicts with global state from other tests
-    fn test_already_initialized() {
-        init_with_key(TEST_KEY_HEX).unwrap();
-        let result = init_with_key(TEST_KEY_HEX);
-        assert!(result.is_ok()); // same key allowed
-    }
-
-    #[test]
-    #[ignore] // conflicts with global state from other tests
-    fn test_already_initialized_different_key() {
-        init_with_key(TEST_KEY_HEX).unwrap();
-        let different_key = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e20";
-        let result = init_with_key(different_key);
-        assert!(result.is_err());
-    }
-}
diff --git a/src/utils/mod.rs b/src/utils/mod.rs
deleted file mode 100644
index 6e09e6c9..00000000
--- a/src/utils/mod.rs
+++ /dev/null
@@ -1,4 +0,0 @@
-pub mod crypto;
-pub mod registry;
-pub mod streaming;
-pub mod tokens;
diff --git a/src/utils/registry.rs b/src/utils/registry.rs
deleted file mode 100644
index 1bf9e1cf..00000000
--- a/src/utils/registry.rs
+++ /dev/null
@@ -1,49 +0,0 @@
-use crate::models::registry::ModelRegistry;
-use anyhow::Result;
-use tracing::info;
-
-const MODELS_DEV_URL: &str = "https://models.dev/api.json";
-
-pub async fn fetch_registry() -> Result<ModelRegistry> {
-    info!("Fetching model registry from {}", MODELS_DEV_URL);
-
-    let client = reqwest::Client::builder()
-        .timeout(std::time::Duration::from_secs(10))
-        .build()?;
-
-    let response = client.get(MODELS_DEV_URL).send().await?;
-
-    if !response.status().is_success() {
-        return Err(anyhow::anyhow!("Failed to fetch registry: HTTP {}", response.status()));
-    }
-
-    let registry: ModelRegistry = response.json().await?;
-    info!("Successfully loaded model registry");
-
-    Ok(registry)
-}
-
-/// Heuristic: decide whether a model should be routed to OpenAI Responses API
-/// instead of the legacy chat/completions endpoint.
-///
-/// Currently this uses simple patterns (codex, gpt-5 series) and also checks
-/// the loaded registry metadata name for the substring "codex" as a hint.
-pub fn model_prefers_responses(registry: &ModelRegistry, model: &str) -> bool {
-    let model_lc = model.to_lowercase();
-
-    if model_lc.contains("codex") {
-        return true;
-    }
-
-    if model_lc.starts_with("gpt-5") || model_lc.contains("gpt-5.") {
-        return true;
-    }
-
-    if let Some(meta) = registry.find_model(model) {
-        if meta.name.to_lowercase().contains("codex") {
-            return true;
-        }
-    }
-
-    false
-}
diff --git a/src/utils/streaming.rs b/src/utils/streaming.rs
deleted file mode 100644
index 2885c9ef..00000000
--- a/src/utils/streaming.rs
+++ /dev/null
@@ -1,340 +0,0 @@
-
-use crate::errors::AppError;
-use crate::logging::{RequestLog, RequestLogger};
-use crate::models::ToolCall;
-use crate::providers::{Provider, ProviderStreamChunk, StreamUsage};
-use crate::state::ModelConfigCache;
-use crate::utils::tokens::estimate_completion_tokens;
-use futures::stream::Stream;
-use std::pin::Pin;
-use std::sync::Arc;
-use std::task::{Context, Poll};
-
-/// Configuration for creating an AggregatingStream.
-pub struct StreamConfig {
-    pub client_id: String,
-    pub provider: Arc<dyn Provider>,
-    pub model: String,
-    pub prompt_tokens: u32,
-    pub has_images: bool,
-    pub logger: Arc<RequestLogger>,
-    pub model_registry: Arc<crate::models::registry::ModelRegistry>,
-    pub model_config_cache: ModelConfigCache,
-}
-
-pub struct AggregatingStream<S> {
-    inner: S,
-    client_id: String,
-    provider: Arc<dyn Provider>,
-    model: String,
-    prompt_tokens: u32,
-    has_images: bool,
-    accumulated_content: String,
-    accumulated_reasoning: String,
-    accumulated_tool_calls: Vec<ToolCall>,
-    /// Real usage data from the provider's final stream chunk (when available).
-    real_usage: Option<StreamUsage>,
-    logger: Arc<RequestLogger>,
-    model_registry: Arc<crate::models::registry::ModelRegistry>,
-    model_config_cache: ModelConfigCache,
-    start_time: std::time::Instant,
-    has_logged: bool,
-}
-
-impl<S> AggregatingStream<S>
-where
-    S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin,
-{
-    pub fn new(inner: S, config: StreamConfig) -> Self {
-        Self {
-            inner,
-            client_id: config.client_id,
-            provider: config.provider,
-            model: config.model,
-            prompt_tokens: config.prompt_tokens,
-            has_images: config.has_images,
-            accumulated_content: String::new(),
-            accumulated_reasoning: String::new(),
-            accumulated_tool_calls: Vec::new(),
-            real_usage: None,
-            logger: config.logger,
-            model_registry: config.model_registry,
-            model_config_cache: config.model_config_cache,
-            start_time: std::time::Instant::now(),
-            has_logged: false,
-        }
-    }
-
-    fn finalize(&mut self) {
-        if self.has_logged {
-            return;
-        }
-        self.has_logged = true;
-
-        let duration = self.start_time.elapsed();
-        let client_id = self.client_id.clone();
-        let provider_name = self.provider.name().to_string();
-        let model = self.model.clone();
-        let logger = self.logger.clone();
-        let provider = self.provider.clone();
-        let estimated_prompt_tokens = self.prompt_tokens;
-        let has_images = self.has_images;
-        let registry = self.model_registry.clone();
-        let config_cache = self.model_config_cache.clone();
-        let real_usage = self.real_usage.take();
-
-        // Estimate completion tokens (including reasoning if present)
-        let estimated_content_tokens = estimate_completion_tokens(&self.accumulated_content, &model);
-        let estimated_reasoning_tokens = if !self.accumulated_reasoning.is_empty() {
-            estimate_completion_tokens(&self.accumulated_reasoning, &model)
-        } else {
-            0
-        };
-
-        let estimated_completion = estimated_content_tokens + estimated_reasoning_tokens;
-
-        // Spawn a background task to log the completion
-        tokio::spawn(async move {
-            // Use real usage from the provider when available, otherwise fall back to estimates
-            let (prompt_tokens, completion_tokens, reasoning_tokens, total_tokens, cache_read_tokens, cache_write_tokens) =
-                if let Some(usage) = &real_usage {
-                    (
-                        usage.prompt_tokens,
-                        usage.completion_tokens,
-                        usage.reasoning_tokens,
-                        usage.total_tokens,
-                        usage.cache_read_tokens,
-                        usage.cache_write_tokens,
-                    )
-                } else {
-                    (
-                        estimated_prompt_tokens,
-                        estimated_completion,
-                        estimated_reasoning_tokens,
-                        estimated_prompt_tokens + estimated_completion,
-                        0u32,
-                        0u32,
-                    )
-                };
-
-            // Check in-memory cache for cost overrides (no SQLite hit)
-            let cost = if let Some(cached) = config_cache.get(&model).await {
-                if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
-                    // Manual overrides logic: if cache rates are provided, use cache-aware formula.
-                    let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
-                    let mut total = (non_cached_prompt as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0);
-
-                    if let Some(cr) = cached.cache_read_cost_per_m {
-                        total += cache_read_tokens as f64 * cr / 1_000_000.0;
-                    } else {
-                        // Charge cached tokens at full input rate if no specific rate provided
-                        total += cache_read_tokens as f64 * p / 1_000_000.0;
-                    }
-
-                    if let Some(cw) = cached.cache_write_cost_per_m {
-                        total += cache_write_tokens as f64 * cw / 1_000_000.0;
-                    }
-
-                    total
-                } else {
-                    provider.calculate_cost(
-                        &model,
-                        prompt_tokens,
-                        completion_tokens,
-                        cache_read_tokens,
-                        cache_write_tokens,
-                        &registry,
-                    )
-                }
-            } else {
-                provider.calculate_cost(
-                    &model,
-                    prompt_tokens,
-                    completion_tokens,
-                    cache_read_tokens,
-                    cache_write_tokens,
-                    &registry,
-                )
-            };
-
-            // Log to database
-            logger.log_request(RequestLog {
-                timestamp: chrono::Utc::now(),
-                client_id: client_id.clone(),
-                provider: provider_name,
-                model,
-                prompt_tokens,
-                completion_tokens,
-                reasoning_tokens,
-                total_tokens,
-                cache_read_tokens,
-                cache_write_tokens,
-                cost,
-                has_images,
-                status: "success".to_string(),
-                error_message: None,
-                duration_ms: duration.as_millis() as u64,
-            });
-        });
-    }
-}
-
-impl<S> Stream for AggregatingStream<S>
-where
-    S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin,
-{
-    type Item = Result<ProviderStreamChunk, AppError>;
-
-    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        let result = Pin::new(&mut self.inner).poll_next(cx);
-
-        match &result {
-            Poll::Ready(Some(Ok(chunk))) => {
-                self.accumulated_content.push_str(&chunk.content);
-                if let Some(reasoning) = &chunk.reasoning_content {
-                    self.accumulated_reasoning.push_str(reasoning);
-                }
-                // Capture real usage from the provider when present (typically on the final chunk)
-                if let Some(usage) = &chunk.usage {
-                    self.real_usage = Some(usage.clone());
-                }
-                // Accumulate tool call deltas into complete tool calls
-                if let Some(deltas) = &chunk.tool_calls {
-                    for delta in deltas {
-                        let idx = delta.index as usize;
-                        // Grow the accumulated_tool_calls vec if needed
-                        while self.accumulated_tool_calls.len() <= idx {
-                            self.accumulated_tool_calls.push(ToolCall {
-                                id: String::new(),
-                                call_type: "function".to_string(),
-                                function: crate::models::FunctionCall {
-                                    name: String::new(),
-                                    arguments: String::new(),
-                                },
-                            });
-                        }
-                        let tc = &mut self.accumulated_tool_calls[idx];
-                        if let Some(id) = &delta.id {
-                            tc.id.clone_from(id);
-                        }
-                        if let Some(ct) = &delta.call_type {
-                            tc.call_type.clone_from(ct);
-                        }
-                        if let Some(f) = &delta.function {
-                            if let Some(name) = &f.name {
-                                tc.function.name.push_str(name);
-                            }
-                            if let Some(args) = &f.arguments {
-                                tc.function.arguments.push_str(args);
-                            }
-                        }
-                    }
-                }
-            }
-            Poll::Ready(Some(Err(_))) => {
-                // If there's an error, we might still want to log what we got so far?
-                // For now, just finalize if we have content
-                if !self.accumulated_content.is_empty() {
-                    self.finalize();
-                }
-            }
-            Poll::Ready(None) => {
-                self.finalize();
-            }
-            Poll::Pending => {}
-        }
-
-        result
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use anyhow::Result;
-    use futures::stream::{self, StreamExt};
-
-    // Simple mock provider for testing
-    struct MockProvider;
-    #[async_trait::async_trait]
-    impl Provider for MockProvider {
-        fn name(&self) -> &str {
-            "mock"
-        }
-        fn supports_model(&self, _model: &str) -> bool {
-            true
-        }
-        fn supports_multimodal(&self) -> bool {
-            false
-        }
-        async fn chat_completion(
-            &self,
-            _req: crate::models::UnifiedRequest,
-        ) -> Result<crate::models::ChatCompletionResponse, AppError> {
-            unimplemented!()
-        }
-        async fn chat_completion_stream(
-            &self,
-            _req: crate::models::UnifiedRequest,
-        ) -> Result<Pin<Box<dyn Stream<Item = Result<ProviderStreamChunk, AppError>> + Send>>, AppError> {
-            unimplemented!()
-        }
-        fn estimate_tokens(&self, _req: &crate::models::UnifiedRequest) -> Result<u32> {
-            Ok(10)
-        }
-        fn calculate_cost(&self, _model: &str, _p: u32, _c: u32, _cr: u32, _cw: u32, _r: &crate::models::registry::ModelRegistry) -> f64 {
-            0.05
-        }
-    }
-
-    #[tokio::test]
-    async fn test_aggregating_stream() {
-        let chunks = vec![
-            Ok(ProviderStreamChunk {
-                content: "Hello".to_string(),
-                reasoning_content: None,
-                finish_reason: None,
-                tool_calls: None,
-                model: "test".to_string(),
-                usage: None,
-            }),
-            Ok(ProviderStreamChunk {
-                content: " World".to_string(),
-                reasoning_content: None,
-                finish_reason: Some("stop".to_string()),
-                tool_calls: None,
-                model: "test".to_string(),
-                usage: None,
-            }),
-        ];
-        let inner_stream = stream::iter(chunks);
-
-        let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
-        let (dashboard_tx, _) = tokio::sync::broadcast::channel(16);
-        let logger = Arc::new(RequestLogger::new(pool.clone(), dashboard_tx));
-        let registry = Arc::new(crate::models::registry::ModelRegistry {
-            providers: std::collections::HashMap::new(),
-        });
-
-        let mut agg_stream = AggregatingStream::new(
-            inner_stream,
-            StreamConfig {
-                client_id: "client_1".to_string(),
-                provider: Arc::new(MockProvider),
-                model: "test".to_string(),
-                prompt_tokens: 10,
-                has_images: false,
-                logger,
-                model_registry: registry,
-                model_config_cache: ModelConfigCache::new(pool.clone()),
-            },
-        );
-
-        while let Some(item) = agg_stream.next().await {
-            assert!(item.is_ok());
-        }
-
-        assert_eq!(agg_stream.accumulated_content, "Hello World");
-        assert!(agg_stream.has_logged);
-    }
-}
diff --git a/src/utils/tokens.rs b/src/utils/tokens.rs
deleted file mode 100644
index 469ab02a..00000000
--- a/src/utils/tokens.rs
+++ /dev/null
@@ -1,69 +0,0 @@
-use crate::models::UnifiedRequest;
-use tiktoken_rs::get_bpe_from_model;
-
-/// Count tokens for a given model and text
-pub fn count_tokens(model: &str, text: &str) -> u32 {
-    // If we can't get the bpe for the model, fallback to a safe default (cl100k_base for GPT-4/o1)
-    let bpe = get_bpe_from_model(model)
-        .unwrap_or_else(|_| tiktoken_rs::cl100k_base().expect("Failed to get cl100k_base encoding"));
-
-    bpe.encode_with_special_tokens(text).len() as u32
-}
-
-/// Estimate tokens for a unified request.
-/// Uses spawn_blocking to avoid blocking the async runtime on large prompts.
-pub fn estimate_request_tokens(model: &str, request: &UnifiedRequest) -> u32 {
-    let mut total_text = String::new();
-    let msg_count = request.messages.len();
-
-    // Base tokens per message for OpenAI (approximate)
-    let tokens_per_message: u32 = 3;
-
-    for msg in &request.messages {
-        for part in &msg.content {
-            match part {
-                crate::models::ContentPart::Text { text } => {
-                    total_text.push_str(text);
-                    total_text.push('\n');
-                }
-                crate::models::ContentPart::Image { .. } => {
-                    // Vision models usually have a fixed cost or calculation based on size
-                }
-            }
-        }
-    }
-
-    // Quick heuristic for small inputs (< 1KB) — avoid spawn_blocking overhead
-    if total_text.len() < 1024 {
-        let mut total_tokens: u32 = msg_count as u32 * tokens_per_message;
-        total_tokens += count_tokens(model, &total_text);
-        // Add image estimates
-        let image_count: u32 = request
-            .messages
-            .iter()
-            .flat_map(|m| m.content.iter())
-            .filter(|p| matches!(p, crate::models::ContentPart::Image { .. }))
-            .count() as u32;
-        total_tokens += image_count * 1000;
-        total_tokens += 3; // assistant reply header
-        return total_tokens;
-    }
-
-    // For large inputs, use the fast heuristic (chars / 4) to avoid blocking
-    // the async runtime. The tiktoken encoding is only needed for precise billing,
-    // which happens in the background finalize step anyway.
-    let estimated_text_tokens = (total_text.len() as u32) / 4;
-    let image_count: u32 = request
-        .messages
-        .iter()
-        .flat_map(|m| m.content.iter())
-        .filter(|p| matches!(p, crate::models::ContentPart::Image { .. }))
-        .count() as u32;
-
-    (msg_count as u32 * tokens_per_message) + estimated_text_tokens + (image_count * 1000) + 3
-}
-
-/// Estimate tokens for completion text
-pub fn estimate_completion_tokens(text: &str, model: &str) -> u32 {
-    count_tokens(model, text)
-}
diff --git a/test_dashboard.sh b/test_dashboard.sh
deleted file mode 100755
index 5a6afa94..00000000
--- a/test_dashboard.sh
+++ /dev/null
@@ -1,38 +0,0 @@
-#!/bin/bash
-
-# Test script for LLM Proxy Dashboard
-
-echo "Building LLM Proxy Gateway..."
-cargo build --release
-
-echo ""
-echo "Starting server in background..."
-./target/release/llm-proxy &
-SERVER_PID=$!
-
-# Wait for server to start
-sleep 3
-
-echo ""
-echo "Testing dashboard endpoints..."
-
-# Test health endpoint
-echo "1. Testing health endpoint:"
-curl -s http://localhost:8080/health
-
-echo ""
-echo "2. Testing dashboard static files:"
-curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/
-
-echo ""
-echo "3. Testing API endpoints:"
-curl -s http://localhost:8080/api/auth/status | jq . 2>/dev/null || echo "JSON response received"
-
-echo ""
-echo "Dashboard should be available at: http://localhost:8080"
-echo "Default login: admin / admin"
-echo ""
-echo "Press Ctrl+C to stop the server"
-
-# Keep script running
-wait $SERVER_PID
\ No newline at end of file
diff --git a/test_server.sh b/test_server.sh
deleted file mode 100755
index 8fe73f77..00000000
--- a/test_server.sh
+++ /dev/null
@@ -1,75 +0,0 @@
-#!/bin/bash
-
-# Test script for LLM Proxy Gateway
-
-echo "Building LLM Proxy Gateway..."
-cargo build --release
-
-if [ $? -ne 0 ]; then
-    echo "Build failed!"
-    exit 1
-fi
-
-echo "Build successful!"
-
-echo ""
-echo "Project Structure Summary:"
-echo "=========================="
-echo "Core Components:"
-echo "  - main.rs: Application entry point with server setup"
-echo "  - config/: Configuration management"
-echo "  - server/: API route handlers"
-echo "  - auth/: Bearer token authentication"
-echo "  - database/: SQLite database setup"
-echo "  - models/: Data structures (OpenAI-compatible)"
-echo "  - providers/: LLM provider implementations (OpenAI, Gemini, DeepSeek, Grok)"
-echo "  - errors/: Custom error types"
-echo "  - dashboard/: Admin dashboard with WebSocket support"
-echo "  - logging/: Request logging middleware"
-echo "  - state/: Shared application state"
-echo "  - multimodal/: Image processing support (basic structure)"
-
-echo ""
-echo "Key Features Implemented:"
-echo "=========================="
-echo "✓ OpenAI-compatible API endpoint (/v1/chat/completions)"
-echo "✓ Bearer token authentication"
-echo "✓ SQLite database for request tracking"
-echo "✓ Request logging with token/cost calculation"
-echo "✓ Provider abstraction layer"
-echo "✓ Admin dashboard with real-time monitoring"
-echo "✓ WebSocket support for live updates"
-echo "✓ Configuration management (config.toml, .env, env vars)"
-echo "✓ Multimodal support structure (images)"
-echo "✓ Error handling with proper HTTP status codes"
-
-echo ""
-echo "Next Steps Needed:"
-echo "=================="
-echo "1. Add API keys to .env file:"
-echo "   OPENAI_API_KEY=your_key_here"
-echo "   GEMINI_API_KEY=your_key_here"
-echo "   DEEPSEEK_API_KEY=your_key_here"
-echo "   GROK_API_KEY=your_key_here (optional)"
-echo ""
-echo "2. Create config.toml for custom configuration (optional)"
-echo ""
-echo "3. Run the server:"
-echo "   cargo run"
-echo ""
-echo "4. Access dashboard at: http://localhost:8080"
-echo ""
-echo "5. Test API with curl:"
-echo "   curl -X POST http://localhost:8080/v1/chat/completions \\"
-echo "     -H 'Authorization: Bearer your_token' \\"
-echo "     -H 'Content-Type: application/json' \\"
-echo "     -d '{\"model\": \"gpt-4\", \"messages\": [{\"role\": \"user\", \"content\": \"Hello\"}]}'"
-
-echo ""
-echo "Deployment Notes:"
-echo "================="
-echo "Memory: Designed for 512MB RAM (LXC container)"
-echo "Database: SQLite (./data/llm_proxy.db)"
-echo "Port: 8080 (configurable)"
-echo "Authentication: Single Bearer token (configurable)"
-echo "Providers: OpenAI, Gemini, DeepSeek, Grok (disabled by default)"
\ No newline at end of file