Compare commits

...

54 Commits

Author SHA1 Message Date
9375448087 fix(moonshot): resolve 401 Unauthorized errors by trimming API keys and improving request compatibility
Some checks failed
CI / Lint (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Build (push) Has been cancelled
2026-03-26 17:09:27 +00:00
5be2f6f7aa fix: use Moonshot test model
2026-03-26 10:12:44 -04:00
eebcadcba1 fix: surface moonshot on providers page
2026-03-25 09:35:41 -04:00
6b2bd13903 chore: remove tracked binary gophergate
2026-03-25 13:32:51 +00:00
5dfda0a10c merge: resolve conflicts in server.go and integrate moonshot support 2026-03-25 13:32:40 +00:00
a8a02d9e1c feat: add moonshot kimi k2.5 support
2026-03-25 09:28:52 -04:00
bd1d17cc4d feat: add moonshot kimi k2.5 support
2026-03-25 09:27:46 -04:00
9207a7231c chore: update all grok-2 references to grok-4
2026-03-25 13:17:06 +00:00
c6efff9034 fix: update grok test model to grok-4-1-fast-non-reasoning
2026-03-25 13:14:31 +00:00
27fbd8ed15 chore: cleanup repository and update gitignore
- Removed binary 'gophergate'
- Removed 'data/llm_proxy.db' from source control (kept locally)
- Removed old database backups in 'data/backups/'
- Updated .gitignore to exclude data directory and gophergate binary
2026-03-25 13:08:33 +00:00
348341f304 fix: prioritize database provider configs and implement API key encryption
- Added AES-GCM encryption/decryption for provider API keys in the database.
- Implemented RefreshProviders to load provider configs from the database with precedence over environment variables.
- Updated dashboard handlers to encrypt keys on save and trigger in-memory provider refresh.
- Updated Grok test model to grok-3-mini for better compatibility.
2026-03-25 13:04:26 +00:00
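The AES-GCM key handling this commit describes can be sketched as follows. This is an illustrative Go sketch only, not the repository's actual `internal` code; `encryptKey`/`decryptKey` are hypothetical names. The nonce is prepended to the ciphertext so decryption can recover it:

```go
package main

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"io"
)

// encryptKey seals a provider API key with AES-256-GCM under a 32-byte
// master key, using a fresh random nonce per encryption.
func encryptKey(masterKey []byte, plaintext string) (string, error) {
	block, err := aes.NewCipher(masterKey) // masterKey must be 32 bytes
	if err != nil {
		return "", err
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", err
	}
	nonce := make([]byte, gcm.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return "", err
	}
	// Seal appends ciphertext to the nonce slice, so nonce||ciphertext
	// is stored as one hex string.
	sealed := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
	return hex.EncodeToString(sealed), nil
}

// decryptKey reverses encryptKey: split off the nonce, then open.
func decryptKey(masterKey []byte, encoded string) (string, error) {
	raw, err := hex.DecodeString(encoded)
	if err != nil {
		return "", err
	}
	block, err := aes.NewCipher(masterKey)
	if err != nil {
		return "", err
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", err
	}
	if len(raw) < gcm.NonceSize() {
		return "", fmt.Errorf("ciphertext too short")
	}
	nonce, ct := raw[:gcm.NonceSize()], raw[gcm.NonceSize():]
	pt, err := gcm.Open(nil, nonce, ct, nil)
	if err != nil {
		return "", err
	}
	return string(pt), nil
}

func main() {
	key := make([]byte, 32) // demo key; real key comes from LLM_PROXY__ENCRYPTION_KEY
	enc, _ := encryptKey(key, "sk-demo")
	dec, _ := decryptKey(key, enc)
	fmt.Println(dec == "sk-demo") // true
}
```

GCM authenticates as well as encrypts, so a tampered ciphertext fails in `gcm.Open` rather than decrypting to garbage.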
9380580504 fix: resolve dashboard websocket 'disconnected' status
- Fixed status indicator UI mapping in websocket.js and index.html.
- Added missing CSS for connection status indicator and pulse animation.
- Made initial model registry fetch asynchronous to prevent blocking server startup.
- Improved configuration loading to correctly handle LLM_PROXY__SERVER__PORT from environment.
2026-03-19 14:32:34 -04:00
08cf5cc1d9 fix: improve cost tracking accuracy for modern models
- Added support for reasoning tokens in cost calculations.
- Fixed DeepSeek cache-write token mapping (PromptCacheMissTokens).
- Improved CalculateCost debug logging to trace all pricing variables.
2026-03-19 14:14:54 -04:00
0f0486d8d4 fix: resolve user dashboard field mapping and session consistency
Added JSON tags to the User struct to match frontend expectations and excluded sensitive fields.
Updated session management to include and persist DisplayName.
Unified user field names (using display_name) across backend, sessions, and frontend UI.
2026-03-19 14:01:59 -04:00
0ea2a3a985 fix: improve provider and model data accuracy in dashboard
Updated handleGetProviders to include available models and last-used timestamps. Refined Model Pricing table to strictly filter by core providers and actual usage.
2026-03-19 13:51:46 -04:00
21e5908c35 fix: resolve sidebar overlap and top-bar layout
Added padding-left to main-content and implemented missing top-bar and content-body styles to ensure correct layout with fixed sidebar.
2026-03-19 13:48:24 -04:00
6f0a159245 fix: resolve login visibility issues and improve sidebar layout
Corrected element ID mismatches between index.html and auth.js. Improved sidebar CSS to handle collapsed state and logo visibility correctly.
2026-03-19 13:45:55 -04:00
4120a83b67 fix: correct login button selector in auth.js
Changed querySelector('.login-btn') to getElementById('login-btn') to match index.html.
2026-03-19 13:43:02 -04:00
742cd9e921 fix: resolve login button TypeError and add favicon
Added id='login-btn' to index.html and created a placeholder favicon.ico.
2026-03-19 13:41:58 -04:00
593971ecb5 fix: resolve TypeError in login error display
Added missing span element to login-error div to ensure compatibility with auth.js.
2026-03-19 13:39:51 -04:00
03dca998df chore: rebrand project to GopherGate
Updated all naming from LLM Proxy to GopherGate. Implemented new CSS-based branding and updated Go module/binary naming.
2026-03-19 13:37:05 -04:00
0ce5f4f490 docs: finalize documentation for Go migration
Updated README, architecture, and TODO to reflect full feature parity, system metrics, and registry integration.
2026-03-19 13:26:31 -04:00
dec4b927dc feat: implement system metrics and fix monitoring charts
Added /api/system/metrics with CPU/Mem/Disk/Load data using gopsutil. Updated Hub to track active WebSocket listeners. Verified log format for monitoring charts.
2026-03-19 13:15:48 -04:00
3f1e6d3407 fix: restrict Model Pricing table to core providers and actual usage
Filtered registry iteration to only include openai, gemini, deepseek, and grok. Improved used_only logic to match specific (model, provider) pairs from logs.
2026-03-19 13:10:50 -04:00
f02fd6c249 fix: normalize provider names in model pricing table
Mapped registry provider IDs (google, xai) to proxy-internal names (gemini, grok) for better dashboard consistency.
2026-03-19 13:06:52 -04:00
f23796f0cc fix: restrict Model Pricing table to used models and fix cost stats
Implemented used_only filter for /api/models. Added missing cache token and cost fields to usage summary and provider usage endpoints.
2026-03-19 13:02:45 -04:00
3f76a544e0 fix: improve analytics accuracy and cost calculation
Refined CalculateCost to correctly handle cached token discounts. Added fuzzy matching to model lookup. Hardened SQL date extraction with SUBSTR and LIKE for better SQLite compatibility.
2026-03-19 12:58:08 -04:00
e474549940 fix: resolve zero-time dashboard display and improve SQL robustness
Fixed '2025 years ago' issue by correctly handling zero-value timestamps. Improved SQL scanning logic to handle NULL values more safely across all analytics handlers.
2026-03-19 12:42:41 -04:00
b7e37b0399 fix: resolve dashboard SQL scan errors and 401 noise
Hardened all analytics queries to handle empty datasets and NULL values. Restricted AuthMiddleware to the /v1 group only.
2026-03-19 12:39:48 -04:00
263c0f0dc9 fix: resolve dashboard 401 and 500 errors
Restricted AuthMiddleware to the /v1 group to prevent dashboard session interference. Hardened analytics SQL queries with COALESCE to handle empty datasets.
2026-03-19 12:35:14 -04:00
26d8431998 feat: implement /api/usage/clients endpoint
Added client-specific usage aggregation for the analytics dashboard.
2026-03-19 12:31:11 -04:00
1f3adceda4 fix: robustify analytics handlers and fix auth middleware scope
Moved AuthMiddleware to /v1 group only. Added COALESCE and empty result handling to analytics SQL queries to prevent 500 errors on empty databases.
2026-03-19 12:28:56 -04:00
9c64a8fe42 fix: restore analytics page functionality
Implemented missing /api/usage/detailed endpoint and ensured analytics breakdown and time-series return data in the expected format.
2026-03-19 12:24:58 -04:00
b04b794705 fix: restore clients page functionality
Updated handleGetClients to return UI-compatible data format and implemented handleGetClient/handleUpdateClient endpoints.
2026-03-19 12:06:52 -04:00
0f3c5b6eb4 feat: enhance usage and cost tracking accuracy
Improved extraction of reasoning and cached tokens from OpenAI and DeepSeek responses (including streams). Ensured accurate cost calculation using registry metadata.
2026-03-19 11:56:26 -04:00
66a1643bca chore: filter /v1/models to allowed providers
Restricted model listing to OpenAI, Google (Gemini), DeepSeek, and xAI (Grok) to match available access.
2026-03-19 11:33:47 -04:00
edc6445d70 feat: implement /v1/models endpoint
Added OpenAI-compatible model listing endpoint using the registry data.
2026-03-19 11:31:26 -04:00
2d8f1a1fd0 chore: use newest cheap models for provider tests
Updated OpenAI test model to gpt-4o-mini and verified Gemini is using gemini-2.0-flash.
2026-03-19 11:27:12 -04:00
cd1a1b45aa fix: restore models page functionality
Updated handleGetModels to merge registry data with DB overrides and implemented handleUpdateModel. Verified API response format matches frontend requirements.
2026-03-19 11:26:13 -04:00
246a6d88f0 fix: update grok default model to grok-2
Changed grok-beta to grok-2 across backend config, dashboard tests, and frontend monitoring.
2026-03-19 11:23:56 -04:00
7d43b2c31b fix: restore default admin password and add reset flag
Restored 'admin123' as the default password in db init and added a -reset-admin flag to main.go.
2026-03-19 11:22:11 -04:00
45c2d5e643 fix: implement provider test endpoint and fix static asset routing
Added handleTestProvider to dashboard and verified static file mapping for /css, /js, and /img.
2026-03-19 11:19:20 -04:00
1d032c6732 feat: complete dashboard API migration
Implemented missing system, analytics, and auth endpoints. Verified parity with frontend expectations.
2026-03-19 11:14:28 -04:00
2245cca67a fix: correct static file routing for dashboard assets
Mapped /css, /js, and /img to their respective subdirectories in ./static to resolve 404 errors.
2026-03-19 11:07:29 -04:00
c7c244992a fix: ensure LLM_PROXY__ENCRYPTION_KEY is correctly loaded from env
Explicitly bound the encryption_key to handle the double underscore convention in Viper.
2026-03-19 11:04:57 -04:00
4f5b55d40f chore: remove obsolete files and update CI to Go
Removed old Rust-era documentation, scripts, and migrations. Updated GitHub Actions workflow to use Go 1.22.
2026-03-19 10:46:23 -04:00
90874a6721 chore: consolidate env files and update gitignore
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled
Removed .env and .env.backup from git tracking and consolidated configuration into .env.example. Updated .gitignore to robustly prevent accidental inclusion of sensitive files.
2026-03-19 10:44:22 -04:00
6b10d4249c feat: migrate backend from rust to go
This commit replaces the Axum/Rust backend with a Gin/Go implementation. The original Rust code has been archived in the 'rust' branch.
2026-03-19 10:30:05 -04:00
57aa0aa70e fix(openai): unify tool call indexing for both standard and embedded calls
- Sequential next_tool_index is now used for both Responses API 'function_call' events and the proxy's 'tool_uses' JSON extraction.
- This ensures tool_calls arrays in the stream always start at index 0 and are dense, even if standard and embedded calls were somehow mixed.
- Fixed 'payload_idx' logic to correctly align argument chunks with their initialization chunks.
2026-03-18 18:31:24 +00:00
4de457cc5e fix(openai): correctly map tool_call indexes in Responses API stream
- The OpenAI Responses API uses 'output_index' to identify items in the response.
- If a response starts with text (output_index 0) followed by a tool call (output_index 1), the standard Chat Completions streaming format requires the first tool call to have index 0.
- Previously, the proxy was passing output_index (1) as the tool_call index, causing client-side SDKs to fail parsing the stream and silently drop the tool calls.
- Implemented a local mapping within the stream to ensure tool_call indexes are always dense and start at 0.
2026-03-18 18:26:27 +00:00
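The local mapping this commit describes — remapping sparse `output_index` values to dense, zero-based `tool_call` indexes — can be sketched in Go (illustrative only; the type and method names here are hypothetical, not the proxy's actual code):

```go
package main

import "fmt"

// indexMapper remaps Responses API output_index values to the dense,
// zero-based tool_call indexes the Chat Completions stream format expects.
type indexMapper struct {
	seen map[int]int // output_index -> dense tool_call index
}

func newIndexMapper() *indexMapper {
	return &indexMapper{seen: make(map[int]int)}
}

// dense returns the tool_call index for an output_index, assigning the
// next free slot the first time that output_index appears in the stream.
func (m *indexMapper) dense(outputIndex int) int {
	if idx, ok := m.seen[outputIndex]; ok {
		return idx
	}
	idx := len(m.seen)
	m.seen[outputIndex] = idx
	return idx
}

func main() {
	m := newIndexMapper()
	// A response whose first item is text (output_index 0, never a tool
	// call) followed by tool calls at output_index 1 and 2:
	fmt.Println(m.dense(1), m.dense(2), m.dense(1)) // 0 1 0
}
```

Because the map is scoped to one stream, the first tool call always lands at index 0 regardless of how many text items precede it.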
66e8b114b9 fix(openai): split embedded tool_calls into standard chunk format
- Standard OpenAI clients expect tool_calls to be streamed as two parts:
  1. Initialization chunk containing 'id', 'type', and 'name', with empty 'arguments'.
  2. Payload chunk(s) containing 'arguments', with 'id', 'type', and 'name' omitted.
- Previously, the proxy was yielding all fields in a single chunk when parsing the custom 'tool_uses' JSON from gpt-5.4, causing strict clients like opencode to fail silently when delegating parallel tasks.
- The proxy now splits the extracted JSON into the correct two-chunk sequence, restoring subagent compatibility.
2026-03-18 18:05:37 +00:00
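The two-chunk convention described above can be sketched as follows — a hedged illustration of the chunk shapes, not the proxy's actual code; `splitToolCall` is a hypothetical helper:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// splitToolCall converts one fully-formed tool call into the two-chunk
// sequence strict clients expect: an initialization chunk carrying id,
// type, and name with empty arguments, then a payload chunk carrying
// only the arguments.
func splitToolCall(index int, id, name, args string) []map[string]any {
	initChunk := map[string]any{
		"index": index,
		"id":    id,
		"type":  "function",
		"function": map[string]any{
			"name":      name,
			"arguments": "",
		},
	}
	payloadChunk := map[string]any{
		"index": index,
		"function": map[string]any{
			"arguments": args,
		},
	}
	return []map[string]any{initChunk, payloadChunk}
}

func main() {
	for _, c := range splitToolCall(0, "call_1", "get_weather", `{"city":"Oslo"}`) {
		b, _ := json.Marshal(c)
		fmt.Println(string(b))
	}
}
```

Yielding everything in one chunk is what broke strict SDKs: they initialize a tool call only when they see an id/name chunk, then accumulate argument-only chunks against it.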
1cac45502a fix(openai): fix stream whitespace loss and finish_reason for gpt-5.4
- Remove overzealous .trim() in strip_internal_metadata which destroyed whitespace between text stream chunks, causing client hangs
- Fix finish_reason logic to only yield once at the end of the stream
- Correctly yield finish_reason: 'tool_calls' instead of 'stop' when tool calls are generated
2026-03-18 17:48:55 +00:00
79dc8fe409 fix(openai): correctly parse Responses API tool call events
- The Responses API does not use 'response.item.delta' for tool calls.
- It uses 'response.output_item.added' to initialize the function call.
- It uses 'response.function_call_arguments.delta' for the payload stream.
- Updated the streaming parser to catch these events and correctly yield ToolCallDelta objects.
- This restores proper streaming of tool calls back to the client.
2026-03-18 16:13:13 +00:00
24a898c9a7 fix(openai): gracefully handle stream endings
- The Responses API ends streams without a final '[DONE]' message.
- This causes reqwest_eventsource to return Error::StreamEnded.
- Previously, this was treated as a premature termination, triggering an error probe.
- We now explicitly match and break on Err(StreamEnded) for normal completion.
2026-03-18 15:39:18 +00:00
96 changed files with 5244 additions and 17579 deletions

.env (deleted, 28 lines)

@@ -1,28 +0,0 @@
# LLM Proxy Gateway Environment Variables
# OpenAI
OPENAI_API_KEY=sk-demo-openai-key
# Google Gemini
GEMINI_API_KEY=AIza-demo-gemini-key
# DeepSeek
DEEPSEEK_API_KEY=sk-demo-deepseek-key
# xAI Grok (not yet available)
GROK_API_KEY=gk-demo-grok-key
# Authentication tokens (comma-separated list)
LLM_PROXY__SERVER__AUTH_TOKENS=demo-token-123456,another-token
# Database path (optional)
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
# Session Secret (for signed tokens)
SESSION_SECRET=ki9khXAk9usDkasMrD2UbK4LOgrDRJz0
# Encryption key (required)
LLM_PROXY__ENCRYPTION_KEY=69879f5b7913ba169982190526ae213e830b3f1f33e785ef2b68cf48c7853fcd
# Server port (optional)
LLM_PROXY__SERVER__PORT=8080


@@ -1,22 +0,0 @@
# LLM Proxy Gateway Environment Variables
# OpenAI
OPENAI_API_KEY=sk-demo-openai-key
# Google Gemini
GEMINI_API_KEY=AIza-demo-gemini-key
# DeepSeek
DEEPSEEK_API_KEY=sk-demo-deepseek-key
# xAI Grok (not yet available)
GROK_API_KEY=gk-demo-grok-key
# Authentication tokens (comma-separated list)
LLM_PROXY__SERVER__AUTH_TOKENS=demo-token-123456,another-token
# Server port (optional)
LLM_PROXY__SERVER__PORT=8080
# Database path (optional)
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db


@@ -1,31 +1,47 @@ (env example template rewritten wholesale; the old template's placeholder keys, AUTH_TOKENS, SESSION_SECRET, and duplicate server/database settings were removed; new content below)
# GopherGate Configuration Example
# Copy this file to .env and fill in your values

# ==============================================================================
# MANDATORY: Encryption & Security
# ==============================================================================
# A 32-byte hex or base64 encoded string used for session signing and
# database encryption.
# Generate one with: openssl rand -hex 32
LLM_PROXY__ENCRYPTION_KEY=your_secure_32_byte_key_here

# ==============================================================================
# LLM Provider API Keys
# ==============================================================================
OPENAI_API_KEY=sk-...
GEMINI_API_KEY=AIza...
DEEPSEEK_API_KEY=sk-...
MOONSHOT_API_KEY=sk-...
GROK_API_KEY=xai-...

# ==============================================================================
# Server Configuration
# ==============================================================================
LLM_PROXY__SERVER__PORT=8080
LLM_PROXY__SERVER__HOST=0.0.0.0

# Optional: Bearer tokens for client authentication (comma-separated)
# If not set, the proxy will look up tokens in the database.
# LLM_PROXY__SERVER__AUTH_TOKENS=token1,token2

# ==============================================================================
# Database Configuration
# ==============================================================================
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
LLM_PROXY__DATABASE__MAX_CONNECTIONS=10

# ==============================================================================
# Provider Overrides (Optional)
# ==============================================================================
# LLM_PROXY__PROVIDERS__OPENAI__BASE_URL=https://api.openai.com/v1
# LLM_PROXY__PROVIDERS__GEMINI__ENABLED=true
# LLM_PROXY__PROVIDERS__MOONSHOT__BASE_URL=https://api.moonshot.ai/v1
# LLM_PROXY__PROVIDERS__MOONSHOT__ENABLED=true
# LLM_PROXY__PROVIDERS__MOONSHOT__DEFAULT_MODEL=kimi-k2.5
# LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL=http://localhost:11434/v1
# LLM_PROXY__PROVIDERS__OLLAMA__ENABLED=true
# LLM_PROXY__PROVIDERS__OLLAMA__MODELS=llama3,mistral,llava


@@ -6,56 +6,44 @@ (the Rust jobs — Check, Clippy, Formatting, Test, Release Build — and the Cargo env vars were replaced by Go jobs; new workflow below)
on:
  pull_request:
    branches: [main]

jobs:
  lint:
    name: Lint
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Go
        uses: actions/setup-go@v5
        with:
          go-version: '1.22'
          cache: true
      - name: golangci-lint
        uses: golangci/golangci-lint-action@v4
        with:
          version: latest

  test:
    name: Test
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Go
        uses: actions/setup-go@v5
        with:
          go-version: '1.22'
          cache: true
      - name: Run Tests
        run: go test -v ./...

  build:
    name: Build
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Go
        uses: actions/setup-go@v5
        with:
          go-version: '1.22'
          cache: true
      - name: Build
        run: go build -v -o gophergate ./cmd/gophergate

.gitignore (vendored, 16 changed lines)

@@ -1,5 +1,13 @@ (removed: /.env, /*.db, /*.db-shm, /*.db-wal; /target kept; new file below)
.env
.env.*
!.env.example
/target
/llm-proxy
/llm-proxy-go
/gophergate
/data/
*.db
*.db-shm
*.db-wal
*.log
server.pid

BACKEND_ARCHITECTURE.md (new file, 62 lines)

@@ -0,0 +1,62 @@
# Backend Architecture (Go)
The GopherGate backend is implemented in Go, focusing on high performance, clear concurrency patterns, and maintainability.
## Core Technologies
- **Runtime:** Go 1.22+
- **Web Framework:** [Gin Gonic](https://github.com/gin-gonic/gin) - Fast and lightweight HTTP routing.
- **Database:** [sqlx](https://github.com/jmoiron/sqlx) - Lightweight wrapper for standard `database/sql`.
- **SQLite Driver:** [modernc.org/sqlite](https://modernc.org/sqlite) - CGO-free SQLite implementation for ease of cross-compilation.
- **Config:** [Viper](https://github.com/spf13/viper) - Robust configuration management supporting environment variables and files.
- **Metrics:** [gopsutil](https://github.com/shirou/gopsutil) - System-level resource monitoring.
## Project Structure
```text
├── cmd/
│ └── gophergate/ # Entry point (main.go)
├── internal/
│ ├── config/ # Configuration loading and validation
│ ├── db/ # Database schema, migrations, and models
│ ├── middleware/ # Auth and logging middleware
│ ├── models/ # Unified request/response structs
│ ├── providers/ # LLM provider implementations (OpenAI, Gemini, etc.)
│ ├── server/ # HTTP server, dashboard handlers, and WebSocket hub
│ └── utils/ # Common utilities (registry, pricing, etc.)
└── static/ # Frontend assets (served by the backend)
```
## Key Components
### 1. Provider Interface (`internal/providers/provider.go`)
Standardized interface for all LLM backends. Implementations handle mapping between the unified format and provider-specific APIs (OpenAI, Gemini, DeepSeek, Grok).
### 2. Model Registry & Pricing (`internal/utils/registry.go`)
Integrates with `models.dev/api.json` to provide real-time model metadata and pricing.
- **Fuzzy Matching:** Supports matching versioned model IDs (e.g., `gpt-4o-2024-08-06`) to base registry entries.
- **Automatic Refreshes:** The registry is fetched at startup and refreshed every 24 hours via a background goroutine.
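The fuzzy-matching idea above — mapping versioned model IDs to base registry entries — can be sketched in Go. This is a hedged illustration under assumed behavior, not the actual `internal/utils/registry.go` code; `resolveModel` and the registry shape are hypothetical:

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// dateSuffix matches trailing date qualifiers such as "-2024-08-06".
var dateSuffix = regexp.MustCompile(`-\d{4}-\d{2}-\d{2}$`)

// resolveModel looks up a model ID against registry keys, falling back
// to progressively looser matches for versioned IDs.
func resolveModel(registry map[string]float64, id string) (string, bool) {
	// 1. Exact match.
	if _, ok := registry[id]; ok {
		return id, true
	}
	// 2. Strip a trailing date: gpt-4o-2024-08-06 -> gpt-4o.
	base := dateSuffix.ReplaceAllString(id, "")
	if _, ok := registry[base]; ok {
		return base, true
	}
	// 3. Last resort: longest registry key that is a prefix of the ID.
	best := ""
	for k := range registry {
		if strings.HasPrefix(id, k) && len(k) > len(best) {
			best = k
		}
	}
	return best, best != ""
}

func main() {
	reg := map[string]float64{"gpt-4o": 2.50, "gpt-4o-mini": 0.15}
	id, _ := resolveModel(reg, "gpt-4o-2024-08-06")
	fmt.Println(id) // gpt-4o
}
```

Preferring the longest prefix matters: "gpt-4o-mini-high" should resolve to "gpt-4o-mini", not to the shorter "gpt-4o".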
### 3. Asynchronous Logging (`internal/server/logging.go`)
Uses a buffered channel and background worker to log every request to SQLite without blocking the client response. It also broadcasts logs to the WebSocket hub for real-time dashboard updates.
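The buffered-channel-plus-worker pattern described above can be sketched as follows (a minimal illustration, not the actual `internal/server/logging.go`; the slice stands in for the SQLite insert and WebSocket broadcast):

```go
package main

import (
	"fmt"
	"sync"
)

// requestLog is a minimal stand-in for the proxy's per-request record.
type requestLog struct {
	Model  string
	Tokens int
}

// logger decouples request handling from database writes: handlers push
// into a buffered channel and a single background worker drains it
// serially, so writes never block the client response.
type logger struct {
	ch   chan requestLog
	wg   sync.WaitGroup
	seen []requestLog // stands in for the SQLite table
}

func newLogger(buf int) *logger {
	l := &logger{ch: make(chan requestLog, buf)}
	l.wg.Add(1)
	go func() {
		defer l.wg.Done()
		for rec := range l.ch {
			// Real code would INSERT here and broadcast the
			// record to the WebSocket hub.
			l.seen = append(l.seen, rec)
		}
	}()
	return l
}

// Log never blocks: if the buffer is full, the record is dropped
// rather than stalling the response path.
func (l *logger) Log(rec requestLog) {
	select {
	case l.ch <- rec:
	default:
	}
}

// Close drains the channel and waits for the worker to finish.
func (l *logger) Close() {
	close(l.ch)
	l.wg.Wait()
}

func main() {
	l := newLogger(64)
	l.Log(requestLog{Model: "gpt-4o-mini", Tokens: 42})
	l.Close()
	fmt.Println(len(l.seen)) // 1
}
```

A single worker also serializes writes, which sidesteps SQLite's writer-contention issues without explicit locking.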
### 4. Session Management (`internal/server/sessions.go`)
Implements HMAC-SHA256 signed tokens for dashboard authentication. Tokens secure the management interface while standard Bearer tokens are used for LLM API access.
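HMAC-SHA256 token signing of this kind can be sketched in Go (illustrative only, assuming a `payload.signature` token shape that is not confirmed by the source; `sign`/`verify` are hypothetical names):

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"strings"
)

// sign produces a token of the form payload.signature, where the
// signature is an HMAC-SHA256 of the payload under the server secret.
func sign(secret []byte, payload string) string {
	mac := hmac.New(sha256.New, secret)
	mac.Write([]byte(payload))
	sig := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
	return payload + "." + sig
}

// verify recomputes the signature and compares in constant time,
// returning the payload only if the token is authentic.
func verify(secret []byte, token string) (string, bool) {
	i := strings.LastIndex(token, ".")
	if i < 0 {
		return "", false
	}
	payload, sig := token[:i], token[i+1:]
	mac := hmac.New(sha256.New, secret)
	mac.Write([]byte(payload))
	want := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
	if !hmac.Equal([]byte(sig), []byte(want)) {
		return "", false
	}
	return payload, true
}

func main() {
	secret := []byte("32-byte-demo-secret-not-for-prod")
	tok := sign(secret, "user=admin")
	payload, ok := verify(secret, tok)
	fmt.Println(payload, ok) // user=admin true
}
```

`hmac.Equal` rather than `==` avoids leaking signature prefixes through timing.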
### 5. WebSocket Hub (`internal/server/websocket.go`)
A centralized hub for managing WebSocket connections, allowing real-time broadcast of system events, system metrics, and request logs to the dashboard.
## Concurrency Model
Go's goroutines and channels are used extensively:
- **Streaming:** Each streaming request uses a goroutine to read and parse the provider's response, feeding chunks into a channel for SSE delivery.
- **Logging:** A single background worker processes the `logChan` to perform serial database writes.
- **WebSocket:** The `Hub` runs in a dedicated goroutine, handling registration and broadcasting.
- **Maintenance:** Background tasks handle registry refreshes and status monitoring.
## Security
- **Encryption Key:** A mandatory 32-byte key is used for both session signing and encryption of sensitive data.
- **Auth Middleware:** Scoped to `/v1` routes to verify client API keys against the database.
- **Bcrypt:** Passwords for dashboard users are hashed using Bcrypt with a work factor of 12.
- **Database Hardening:** Automatic migrations ensure the schema is always current with the code.


@@ -1,65 +0,0 @@
# LLM Proxy Code Review Plan
## Overview
The **LLM Proxy** project is a Rust-based middleware designed to provide a unified interface for multiple Large Language Models (LLMs). Based on the repository structure, the project aims to implement a high-performance proxy server (`src/`) that handles request routing, usage tracking, and billing logic. A static dashboard (`static/`) provides a management interface for monitoring consumption and managing API keys. The architecture leverages Rust's async capabilities for efficient request handling and SQLite for persistent state management.
## Review Phases
### Phase 1: Backend Architecture & Rust Logic (@code-reviewer)
- **Focus on:**
- **Core Proxy Logic:** Efficiency of the request/response pipeline and streaming support.
- **State Management:** Thread-safety and shared state patterns using `Arc` and `Mutex`/`RwLock`.
- **Error Handling:** Use of idiomatic Rust error types and propagation.
- **Async Performance:** Proper use of `tokio` or similar runtimes to avoid blocking the executor.
- **Rust Idioms:** Adherence to Clippy suggestions and standard Rust naming conventions.
### Phase 2: Security & Authentication Audit (@security-auditor)
- **Focus on:**
- **API Key Management:** Secure storage, masking in logs, and rotation mechanisms.
- **JWT Handling:** Validation logic, signature verification, and expiration checks.
- **Input Validation:** Sanitization of prompts and configuration parameters to prevent injection.
- **Dependency Audit:** Scanning for known vulnerabilities in the `Cargo.lock` using `cargo-audit`.
### Phase 3: Database & Data Integrity Review (@database-optimizer)
- **Focus on:**
- **Schema Design:** Efficiency of the SQLite schema for usage tracking and billing.
- **Migration Strategy:** Robustness of the migration scripts to prevent data loss.
- **Usage Tracking:** Accuracy of token counting and concurrency handling during increments.
- **Query Optimization:** Identifying potential bottlenecks in reporting queries.
### Phase 4: Frontend & Dashboard Review (@frontend-developer)
- **Focus on:**
- **Vanilla JS Patterns:** Review of Web Components and modular JS in `static/js`.
- **Security:** Protection against XSS in the dashboard and secure handling of local storage.
- **UI/UX Consistency:** Ensuring the management interface is intuitive and responsive.
- **API Integration:** Robustness of the frontend's communication with the Rust backend.
### Phase 5: Infrastructure & Deployment Review (@devops-engineer)
- **Focus on:**
- **Dockerfile Optimization:** Multi-stage builds to minimize image size and attack surface.
- **Resource Limits:** Configuration of CPU/Memory limits for the proxy container.
- **Deployment Docs:** Clarity of the setup process and environment variable documentation.
## Timeline (Gantt)
```mermaid
gantt
title LLM Proxy Code Review Timeline (March 2026)
dateFormat YYYY-MM-DD
section Backend & Security
Architecture & Rust Logic (Phase 1) :active, p1, 2026-03-06, 1d
Security & Auth Audit (Phase 2) :p2, 2026-03-07, 1d
section Data & Frontend
Database & Integrity (Phase 3) :p3, 2026-03-07, 1d
Frontend & Dashboard (Phase 4) :p4, 2026-03-08, 1d
section DevOps
Infra & Deployment (Phase 5) :p5, 2026-03-08, 1d
Final Review & Sign-off :2026-03-08, 4h
```
## Success Criteria
- **Security:** Zero high-priority vulnerabilities identified; all API keys masked in logs.
- **Performance:** Proxy overhead is minimal (<10ms latency addition); queries are indexed.
- **Maintainability:** Code passes all linting (`cargo clippy`) and formatting (`cargo fmt`) checks.
- **Documentation:** README and deployment guides are up-to-date and accurate.
- **Reliability:** Usage tracking matches actual API consumption with 99.9% accuracy.

Cargo.lock generated (diff suppressed because it is too large)

[package]
name = "llm-proxy"
version = "0.1.0"
edition = "2024"
rust-version = "1.87"
description = "Unified LLM proxy gateway supporting OpenAI, Gemini, DeepSeek, and Grok with token tracking and cost calculation"
authors = ["newkirk"]
license = "MIT OR Apache-2.0"
repository = ""
[dependencies]
# ========== Web Framework & Async Runtime ==========
axum = { version = "0.8", features = ["macros", "ws"] }
tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "net", "time", "signal", "fs"] }
tower = "0.5"
tower-http = { version = "0.6", features = ["trace", "cors", "compression-gzip", "fs", "set-header", "limit"] }
governor = "0.7"
# ========== HTTP Clients ==========
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
tiktoken-rs = "0.9"
# ========== Database & ORM ==========
sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite", "macros", "migrate", "chrono"] }
# ========== Authentication & Middleware ==========
axum-extra = { version = "0.12", features = ["typed-header"] }
headers = "0.4"
# ========== Configuration Management ==========
config = "0.13"
dotenvy = "0.15"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
toml = "0.8"
# ========== Logging & Monitoring ==========
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
# ========== Multimodal & Image Processing ==========
base64 = "0.21"
image = { version = "0.25", default-features = false, features = ["jpeg", "png", "webp"] }
mime = "0.3"
# ========== Error Handling & Utilities ==========
anyhow = "1.0"
thiserror = "1.0"
bcrypt = "0.15"
aes-gcm = "0.10"
hmac = "0.12"
sha2 = "0.10"
chrono = { version = "0.4", features = ["serde"] }
uuid = { version = "1.0", features = ["v4", "serde"] }
futures = "0.3"
async-trait = "0.1"
async-stream = "0.3"
reqwest-eventsource = "0.6"
rand = "0.9"
hex = "0.4"
[dev-dependencies]
tokio-test = "0.4"
mockito = "1.0"
tempfile = "3.10"
assert_cmd = "2.0"
insta = "1.39"
anyhow = "1.0"
[profile.release]
opt-level = 3
lto = true
codegen-units = 1
strip = true
panic = "abort"


# LLM Proxy Gateway - Admin Dashboard
## Overview
This is a comprehensive admin dashboard for the LLM Proxy Gateway, providing real-time monitoring, analytics, and management capabilities for the proxy service.
## Features
### 1. Dashboard Overview
- Real-time request counters and statistics
- System health indicators
- Provider status monitoring
- Recent requests stream
### 2. Usage Analytics
- Time series charts for requests, tokens, and costs
- Filter by date range, client, provider, and model
- Top clients and models analysis
- Export functionality to CSV/JSON
### 3. Cost Management
- Cost breakdown by provider, client, and model
- Budget tracking with alerts
- Cost projections
- Pricing configuration management
### 4. Client Management
- List, create, revoke, and rotate API tokens
- Client-specific rate limits
- Usage statistics per client
- Token management interface
### 5. Provider Configuration
- Enable/disable LLM providers
- Configure API keys (masked display)
- Test provider connections
- Model availability management
### 6. User Management (RBAC)
- **Admin Role:** Full access to all dashboard features, user management, system configuration
- **Viewer Role:** Read-only access to usage analytics, costs, and monitoring
- Create/manage dashboard users with role assignment
- Secure password management
### 7. Real-time Monitoring
- Live request stream via WebSocket
- System metrics dashboard
- Response time and error rate tracking
- Live system logs
### 8. System Settings
- General configuration
- Database management
- Logging settings
- Security settings
## Technology Stack
### Frontend
- **HTML5/CSS3**: Modern, responsive design with CSS Grid/Flexbox
- **JavaScript (ES6+)**: Vanilla JavaScript with modular architecture
- **Chart.js**: Interactive data visualizations
- **Luxon**: Date/time manipulation
- **WebSocket API**: Real-time updates
### Backend (Rust/Axum)
- **Axum**: Web framework with WebSocket support
- **Tokio**: Async runtime
- **Serde**: JSON serialization/deserialization
- **Broadcast channels**: Real-time event distribution
## Installation & Setup
### 1. Build and Run the Server
```bash
# Build the project
cargo build --release
# Run the server
cargo run --release
```
### 2. Access the Dashboard
Once the server is running, access the dashboard at:
```
http://localhost:8080
```
### 3. Default Login Credentials
- **Username**: `admin`
- **Password**: `admin123`
## API Endpoints
### Authentication
- `POST /api/auth/login` - Dashboard login
- `GET /api/auth/status` - Authentication status
### Analytics
- `GET /api/usage/summary` - Overall usage summary
- `GET /api/usage/time-series` - Time series data
- `GET /api/usage/clients` - Client breakdown
- `GET /api/usage/providers` - Provider breakdown
### Clients
- `GET /api/clients` - List all clients
- `POST /api/clients` - Create new client
- `PUT /api/clients/{id}` - Update client
- `DELETE /api/clients/{id}` - Revoke client
- `GET /api/clients/{id}/usage` - Client-specific usage
### Users (RBAC)
- `GET /api/users` - List all dashboard users
- `POST /api/users` - Create new user
- `PUT /api/users/{id}` - Update user (admin only)
- `DELETE /api/users/{id}` - Delete user (admin only)
### Providers
- `GET /api/providers` - List providers and status
- `PUT /api/providers/{name}` - Update provider config
- `POST /api/providers/{name}/test` - Test provider connection
### System
- `GET /api/system/health` - System health
- `GET /api/system/logs` - Recent logs
- `POST /api/system/backup` - Trigger backup
### WebSocket
- `GET /ws` - WebSocket endpoint for real-time updates
## Project Structure
```
llm-proxy/
├── src/
│ ├── dashboard/ # Dashboard backend module
│ │ └── mod.rs # Dashboard routes and handlers
│ ├── server/ # Main proxy server
│ ├── providers/ # LLM provider implementations
│ └── ... # Other modules
├── static/ # Frontend dashboard files
│ ├── index.html # Main dashboard HTML
│ ├── css/
│ │ └── dashboard.css # Dashboard styles
│ ├── js/
│ │ ├── auth.js # Authentication module
│ │ ├── dashboard.js # Main dashboard controller
│ │ ├── websocket.js # WebSocket manager
│ │ ├── charts.js # Chart.js utilities
│ │ └── pages/ # Page-specific modules
│ │ ├── overview.js
│ │ ├── analytics.js
│ │ ├── costs.js
│ │ ├── clients.js
│ │ ├── providers.js
│ │ ├── monitoring.js
│ │ ├── settings.js
│ │ └── logs.js
│ ├── img/ # Images and icons
│ └── fonts/ # Font files
└── Cargo.toml # Rust dependencies
```
## Development
### Adding New Pages
1. Create a new JavaScript module in `static/js/pages/`
2. Implement the page class with `init()` method
3. Register the page in `dashboard.js`
4. Add menu item in `index.html`
### Adding New API Endpoints
1. Add route in `src/dashboard/mod.rs`
2. Implement handler function
3. Update frontend JavaScript to call the endpoint
### Styling Guidelines
- Use CSS custom properties (variables) from `:root`
- Follow mobile-first responsive design
- Use BEM-like naming convention for CSS classes
- Maintain consistent spacing with CSS variables
## Security Considerations
1. **Authentication**: Simple password-based auth for demo; replace with proper auth in production
2. **API Keys**: Tokens are masked in the UI (only last 4 characters shown)
3. **CORS**: Configure appropriate CORS headers for production
4. **Rate Limiting**: Implement rate limiting for API endpoints
5. **HTTPS**: Always use HTTPS in production
## Performance Optimizations
1. **Code Splitting**: JavaScript modules are loaded on-demand
2. **Caching**: Static assets are served with cache headers
3. **WebSocket**: Real-time updates reduce polling overhead
4. **Lazy Loading**: Charts and tables load data as needed
5. **Compression**: Enable gzip/brotli compression for static files
## Browser Support
- Chrome 60+
- Firefox 55+
- Safari 11+
- Edge 79+
## License
MIT License - See LICENSE file for details.
## Contributing
1. Fork the repository
2. Create a feature branch
3. Make your changes
4. Add tests if applicable
5. Submit a pull request
## Support
For issues and feature requests, please use the GitHub issue tracker.


# Database Review Report for LLM-Proxy Repository
**Review Date:** 2026-03-06
**Reviewer:** Database Optimization Expert
**Repository:** llm-proxy
**Focus Areas:** Schema Design, Query Optimization, Migration Strategy, Data Integrity, Usage Tracking Accuracy
## Executive Summary
The llm-proxy database implementation demonstrates a solid foundation, with appropriate table structures and a clear separation of concerns. However, several areas require improvement to ensure scalability, data consistency, and performance as usage grows. Key findings include:
1. **Schema Design**: Generally normalized but missing foreign key enforcement and some critical indexes.
2. **Query Optimization**: Well-optimized for most queries but missing composite indexes for common filtering patterns.
3. **Migration Strategy**: Ad-hoc migration approach that may cause issues with schema evolution.
4. **Data Integrity**: Potential race conditions in usage tracking and missing transaction boundaries.
5. **Usage Tracking**: Generally accurate but risk of inconsistent state between related tables.
This report provides detailed analysis and actionable recommendations for each area.
## 1. Schema Design Review
### Tables Overview
The database consists of 6 main tables:
1. **clients**: Client management with usage aggregates
2. **llm_requests**: Request logging with token counts and costs
3. **provider_configs**: Provider configuration and credit balances
4. **model_configs**: Model-specific configuration and cost overrides
5. **users**: Dashboard user authentication
6. **client_tokens**: API token storage for client authentication
### Normalization Assessment
**Strengths:**
- Tables follow 3rd Normal Form (3NF) with appropriate separation
- Foreign key relationships properly defined
- No obvious data duplication across tables
**Areas for Improvement:**
- **Denormalized aggregates**: `clients.total_requests`, `total_tokens`, `total_cost` are derived from `llm_requests`. This introduces risk of inconsistency.
- **Provider credit balance**: Stored in `provider_configs` but also updated based on `llm_requests`. No audit trail for balance changes.
### Data Type Analysis
**Appropriate Choices:**
- INTEGER for token counts (cast from u32 to i64)
- REAL for monetary values
- DATETIME for timestamps using SQLite's CURRENT_TIMESTAMP
- TEXT for identifiers with appropriate length
**Potential Issues:**
- `llm_requests.request_body` and `response_body` defined as TEXT but always set to NULL - consider removing or making optional columns.
- `provider_configs.billing_mode` added via migration but default value not consistently applied to existing rows.
### Constraints and Foreign Keys
**Current Constraints:**
- Primary keys defined for all tables
- UNIQUE constraints on `clients.client_id`, `users.username`, `client_tokens.token`
- Foreign key definitions present but **not enforced** (SQLite default)
**Missing Constraints:**
- NOT NULL constraints missing on several columns where nullability not intended
- CHECK constraints for positive values (`credit_balance >= 0`)
- Foreign key enforcement not enabled
## 2. Query Optimization Analysis
### Indexing Strategy
**Existing Indexes:**
- `idx_clients_client_id` - Essential for client lookups
- `idx_clients_created_at` - Useful for chronological listing
- `idx_llm_requests_timestamp` - Critical for time-based queries
- `idx_llm_requests_client_id` - Supports client-specific queries
- `idx_llm_requests_provider` - Good for provider breakdowns
- `idx_llm_requests_status` - Low cardinality but acceptable
- `idx_client_tokens_token` UNIQUE - Essential for authentication
- `idx_client_tokens_client_id` - Supports token management
**Missing Critical Indexes:**
1. `model_configs.provider_id` - Foreign key column used in JOINs
2. `llm_requests(client_id, timestamp)` - Composite index for client time-series queries
3. `llm_requests(provider, timestamp)` - For provider performance analysis
4. `llm_requests(status, timestamp)` - For error trend analysis
### N+1 Query Detection
**Well-Optimized Areas:**
- Model configuration caching prevents repeated database hits
- Provider configs loaded in batch for dashboard display
- Client listing uses single efficient query
**Potential N+1 Patterns:**
- In the `list_models` function in `server/mod.rs`, there is a cache lookup per model, but the cache is in-memory, so no database round trips result
- No significant database N+1 issues identified
### Inefficient Query Patterns
**Query 1: Time-series aggregation with strftime()**
```sql
SELECT strftime('%Y-%m-%d', timestamp) as date, ...
FROM llm_requests
WHERE 1=1 {}  -- "{}" is filled with optional filter clauses by the Rust code
GROUP BY date, client_id, provider, model
ORDER BY date DESC
LIMIT 200
```
**Issue:** Function on indexed column prevents index utilization for the WHERE clause when filtering by timestamp range.
**Recommendation:** Store computed date column or use range queries on timestamp directly.
**Query 2: Today's stats using strftime()**
```sql
WHERE strftime('%Y-%m-%d', timestamp) = ?
```
**Issue:** Non-sargable query prevents index usage.
**Recommendation:** Use range query:
```sql
WHERE timestamp >= date(?) AND timestamp < date(?, '+1 day')
```
### Recommended Index Additions
```sql
-- Composite indexes for common query patterns
CREATE INDEX idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp);
CREATE INDEX idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp);
CREATE INDEX idx_llm_requests_status_timestamp ON llm_requests(status, timestamp);
-- Foreign key index
CREATE INDEX idx_model_configs_provider_id ON model_configs(provider_id);
-- Optional: Covering index for client usage queries
CREATE INDEX idx_clients_usage ON clients(client_id, total_requests, total_tokens, total_cost);
```
## 3. Migration Strategy Assessment
### Current Approach
The migration system uses a hybrid approach:
1. **Schema synchronization**: `CREATE TABLE IF NOT EXISTS` on startup
2. **Ad-hoc migrations**: `ALTER TABLE` statements with error suppression
3. **Single migration file**: `migrations/001-add-billing-mode.sql` with transaction wrapper
**Pros:**
- Simple to understand and maintain
- Automatic schema creation for new deployments
- Error suppression prevents crashes when a column already exists
**Cons:**
- No version tracking of applied migrations
- Potential for inconsistent schema across deployments
- `ALTER TABLE` error suppression hides genuine schema issues
- No rollback capability
### Risks and Limitations
1. **Schema Drift**: Different instances may have different schemas if migrations are applied out of order
2. **Data Loss Risk**: No backup/verification before schema changes
3. **Production Issues**: Error suppression could mask migration failures until runtime
### Recommendations
1. **Implement Proper Migration Tooling**: Use `sqlx migrate` or similar versioned migration system
2. **Add Migration Version Table**: Track applied migrations and checksum verification
3. **Separate Migration Scripts**: One file per migration with up/down directions
4. **Pre-deployment Validation**: Schema checks in CI/CD pipeline
5. **Backup Strategy**: Automatic backups before migration execution
## 4. Data Integrity Evaluation
### Foreign Key Enforcement
**Critical Issue:** Foreign key constraints are defined but **not enforced** in SQLite.
**Impact:** Orphaned records, inconsistent referential integrity.
**Solution:** Enable foreign key support in connection string:
```rust
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?
.create_if_missing(true)
.pragma("foreign_keys", "ON");
```
### Transaction Usage
**Good Patterns:**
- Request logging uses transactions for insert + provider balance update
- Atomic UPDATE for client usage statistics
**Problematic Areas:**
1. **Split Transactions**: Client usage update and request logging are in separate transactions
- In `logging/mod.rs`: `insert_log` transaction includes provider balance update
- In `utils/streaming.rs`: Client usage updated separately after logging
- **Risk**: Partial updates if one transaction fails
2. **No Transaction for Client Creation**: Client and token creation not atomic
**Recommendations:**
- Wrap client usage update within the same transaction as request logging
- Use transaction for client + token creation
- Consider using savepoints for complex operations
### Race Conditions and Consistency
**Potential Race Conditions:**
1. **Provider credit balance**: Concurrent requests may cause lost updates
- Current: `UPDATE provider_configs SET credit_balance = credit_balance - ?`
- SQLite provides serializable isolation, but negative balances not prevented
2. **Client usage aggregates**: Concurrent updates to `total_requests`, `total_tokens`, `total_cost`
- Similar UPDATE pattern, generally safe but consider idempotency
**Recommendations:**
- Add check constraint: `CHECK (credit_balance >= 0)`
- Implement idempotent request logging with unique request IDs
- Consider optimistic concurrency control for critical balances
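A guarded decrement is one lightweight alternative to full optimistic locking: the `WHERE` clause refuses any update that would drive the balance negative, and the caller treats zero affected rows as insufficient credit (a sketch against the schema described above):

```sql
UPDATE provider_configs
SET credit_balance = credit_balance - ?1
WHERE id = ?2
  AND credit_balance >= ?1;
-- rows_affected() == 0 means the balance would have gone negative;
-- the caller should reject the request or flag the provider.
```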
## 5. Usage Tracking Accuracy
### Token Counting Methodology
**Current Approach:**
- Prompt tokens: Estimated using provider-specific estimators
- Completion tokens: Estimated or from provider real usage data
- Cache tokens: Separately tracked for cache-aware pricing
**Strengths:**
- Fallback to estimation when provider doesn't report usage
- Cache token differentiation for accurate pricing
**Weaknesses:**
- Estimation may differ from actual provider counts
- No validation of provider-reported token counts
### Cost Calculation
**Well Implemented:**
- Model-specific cost overrides via `model_configs`
- Cache-aware pricing when supported by registry
- Provider fallback calculations
**Potential Issues:**
- Floating-point precision for monetary calculations
- No rounding strategy for fractional cents
### Update Consistency
**Inconsistency Risk:** Client aggregates updated separately from request logging.
**Example Flow:**
1. Request log inserted and provider balance updated (transaction)
2. Client usage updated (separate operation)
3. If step 2 fails, client stats undercount usage
**Solution:** Include client update in the same transaction:
```rust
// In insert_log function, add:
UPDATE clients
SET total_requests = total_requests + 1,
total_tokens = total_tokens + ?,
total_cost = total_cost + ?
WHERE client_id = ?;
```
### Financial Accuracy
**Good Practices:**
- Token-level granularity for cost calculation
- Separation of prompt/completion/cache pricing
- Database persistence for audit trail
**Recommendations:**
1. **Audit Trail**: Add `balance_transactions` table for provider credit changes
2. **Rounding Policy**: Define rounding strategy (e.g., to 6 decimal places)
3. **Validation**: Periodic reconciliation of aggregates vs. detail records
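Such a reconciliation could be a periodic query that flags drift between the stored aggregates and the detail records (column names here are assumptions based on the examples elsewhere in this report):

```sql
-- Flag clients whose stored aggregates drift from llm_requests sums.
SELECT c.client_id,
       c.total_requests,
       COALESCE(r.request_count, 0) AS actual_requests,
       c.total_cost,
       COALESCE(r.cost_sum, 0.0)    AS actual_cost
FROM clients c
LEFT JOIN (
    SELECT client_id,
           COUNT(*)  AS request_count,
           SUM(cost) AS cost_sum
    FROM llm_requests
    GROUP BY client_id
) r ON r.client_id = c.client_id
WHERE c.total_requests != COALESCE(r.request_count, 0)
   OR ABS(c.total_cost - COALESCE(r.cost_sum, 0.0)) > 1e-6;
```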
## 6. Performance Recommendations
### Schema Improvements
1. **Partitioning Strategy**: For high-volume `llm_requests`, consider:
- Monthly partitioning by timestamp
- Archive old data to separate tables
2. **Data Retention Policy**: Implement automatic cleanup of old request logs
```sql
DELETE FROM llm_requests WHERE timestamp < date('now', '-90 days');
```
3. **Column Optimization**: Remove unused `request_body`, `response_body` columns or implement compression
### Query Optimizations
1. **Avoid Functions on Indexed Columns**: Rewrite date queries as range queries
2. **Batch Updates**: Consider batch updates for client usage instead of per-request
3. **Read Replicas**: For dashboard queries, consider separate read connection
### Connection Pooling
**Current:** SQLx connection pool with default settings
**Recommendations:**
- Configure pool size based on expected concurrency
- Implement connection health checks
- Monitor pool utilization metrics
### Monitoring Setup
**Essential Metrics:**
- Query execution times (slow query logging)
- Index usage statistics
- Table growth trends
- Connection pool utilization
**Implementation:**
- Add `sqlx::metrics` integration
- Regular `ANALYZE` execution for query planner
- Dashboard for database health monitoring
## 7. Security Considerations
### Data Protection
**Sensitive Data:**
- `provider_configs.api_key` - Should be encrypted at rest
- `users.password_hash` - Already hashed with bcrypt
- `client_tokens.token` - Plain text storage
**Recommendations:**
- Encrypt API keys using libsodium or similar
- Implement token hashing (similar to password hashing)
- Regular security audits of authentication flows
### SQL Injection Prevention
**Good Practices:**
- Use sqlx query builder with parameter binding
- No raw SQL concatenation observed in code review
**Verification Needed:** Ensure all dynamic SQL uses parameterized queries
### Access Controls
**Database Level:**
- SQLite lacks built-in user management
- Consider file system permissions for database file
- Application-level authentication is primary control
## 8. Summary of Critical Issues
**Priority 1 (Critical):**
1. Foreign key constraints not enabled
2. Split transactions risking data inconsistency
3. Missing composite indexes for common queries
**Priority 2 (High):**
1. No proper migration versioning system
2. Potential race conditions in balance updates
3. Non-sargable date queries impacting performance
**Priority 3 (Medium):**
1. Denormalized aggregates without consistency guarantees
2. No data retention policy for request logs
3. Missing check constraints for data validation
## 9. Recommended Action Plan
### Phase 1: Immediate Fixes (1-2 weeks)
1. Enable foreign key constraints in database connection
2. Add composite indexes for common query patterns
3. Fix transaction boundaries for client usage updates
4. Rewrite non-sargable date queries
### Phase 2: Short-term Improvements (3-4 weeks)
1. Implement proper migration system with version tracking
2. Add check constraints for data validation
3. Implement connection pooling configuration
4. Create database monitoring dashboard
### Phase 3: Long-term Enhancements (2-3 months)
1. Implement data retention and archiving strategy
2. Add audit trail for provider balance changes
3. Consider partitioning for high-volume tables
4. Implement encryption for sensitive data
### Phase 4: Ongoing Maintenance
1. Regular index maintenance and query plan analysis
2. Periodic reconciliation of aggregate vs. detail data
3. Security audits and dependency updates
4. Performance benchmarking and optimization
---
## Appendices
### A. Sample Migration Implementation
```sql
-- NOTE: PRAGMA foreign_keys is per-connection in SQLite and cannot be
-- enabled from a migration file; set it in the connection options (see Section 4).
-- migrations/003-add-composite-indexes.sql
CREATE INDEX idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp);
CREATE INDEX idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp);
CREATE INDEX idx_model_configs_provider_id ON model_configs(provider_id);
```
### B. Transaction Fix Example
```rust
async fn insert_log(pool: &SqlitePool, log: RequestLog) -> Result<(), sqlx::Error> {
let mut tx = pool.begin().await?;
// Insert or ignore client
sqlx::query("INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')")
.bind(&log.client_id)
.bind(&log.client_id)
.execute(&mut *tx)
.await?;
// Insert request log
sqlx::query("INSERT INTO llm_requests ...")
.execute(&mut *tx)
.await?;
// Update provider balance
if log.cost > 0.0 {
sqlx::query("UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')")
.bind(log.cost)
.bind(&log.provider)
.execute(&mut *tx)
.await?;
}
// Update client aggregates within same transaction
sqlx::query("UPDATE clients SET total_requests = total_requests + 1, total_tokens = total_tokens + ?, total_cost = total_cost + ? WHERE client_id = ?")
.bind(log.total_tokens as i64)
.bind(log.cost)
.bind(&log.client_id)
.execute(&mut *tx)
.await?;
tx.commit().await?;
Ok(())
}
```
### C. Monitoring Query Examples
```sql
-- Identify indexes on llm_requests with no collected statistics
-- (sqlite_stat1 is only populated after ANALYZE has run)
SELECT name FROM sqlite_master
WHERE type = 'index'
  AND tbl_name = 'llm_requests'
  AND name NOT IN (
      SELECT DISTINCT idx
      FROM sqlite_stat1
      WHERE tbl = 'llm_requests'
  );
-- Table size analysis (requires SQLite built with SQLITE_ENABLE_DBSTAT_VTAB)
SELECT name, SUM(pgsize) / 1024.0 / 1024.0 AS size_mb
FROM dbstat
WHERE name = 'llm_requests'
GROUP BY name;
-- Query performance analysis (requires EXPLAIN QUERY PLAN)
EXPLAIN QUERY PLAN
SELECT * FROM llm_requests
WHERE client_id = ? AND timestamp >= ?;
```
---
*This report provides a comprehensive analysis of the current database implementation and actionable recommendations for improvement. Regular review and iteration will ensure the database continues to meet performance, consistency, and scalability requirements as the application grows.*


@@ -1,35 +1,34 @@

The two columns of this side-by-side Dockerfile diff were merged during extraction. Reconstructed, the Rust (`llm-proxy`) variant is:

```dockerfile
# ── Build stage ──────────────────────────────────────────────
FROM rust:1-bookworm AS builder
WORKDIR /app
# Cache dependency build
COPY Cargo.toml Cargo.lock ./
RUN mkdir src && echo 'fn main() {}' > src/main.rs && \
    cargo build --release && \
    rm -rf src
# Build the actual binary
COPY src/ src/
RUN touch src/main.rs && cargo build --release
# ── Runtime stage ────────────────────────────────────────────
FROM debian:bookworm-slim
RUN apt-get update && \
    apt-get install -y --no-install-recommends ca-certificates && \
    rm -rf /var/lib/apt/lists/*
WORKDIR /app
COPY --from=builder /app/target/release/llm-proxy /app/llm-proxy
COPY static/ /app/static/
# Default config location
VOLUME ["/app/config", "/app/data"]
EXPOSE 8080
ENV RUST_LOG=info
ENTRYPOINT ["/app/llm-proxy"]
```

and the Go (`gophergate`) variant is:

```dockerfile
# Build stage
FROM golang:1.22-alpine AS builder
WORKDIR /app
# Copy go mod and sum files
COPY go.mod go.sum ./
RUN go mod download
# Copy the source code
COPY . .
# Build the application
RUN CGO_ENABLED=0 GOOS=linux go build -o gophergate ./cmd/gophergate
# Final stage
FROM alpine:latest
RUN apk --no-cache add ca-certificates tzdata
WORKDIR /app
# Copy the binary from the builder stage
COPY --from=builder /app/gophergate .
COPY --from=builder /app/static ./static
# Create data directory
RUN mkdir -p /app/data
# Expose port
EXPOSE 8080
# Run the application
CMD ["./gophergate"]
```


# Optimization for 512MB RAM Environment
This document provides guidance for optimizing the LLM Proxy Gateway for deployment in resource-constrained environments (512MB RAM).
## Memory Optimization Strategies
### 1. Build Optimization
The project is already configured with optimized build settings in `Cargo.toml`:
```toml
[profile.release]
opt-level = 3 # Maximum optimization
lto = true # Link-time optimization
codegen-units = 1 # Single codegen unit for better optimization
strip = true # Strip debug symbols
```
**Additional optimizations you can apply:**
```bash
# Build with specific target for better optimization
cargo build --release --target x86_64-unknown-linux-musl
# Or for ARM (Raspberry Pi, etc.)
cargo build --release --target aarch64-unknown-linux-musl
```
### 2. Runtime Memory Management
#### Database Connection Pool
- Default: 10 connections
- Recommended for 512MB: 5 connections
Update `config.toml`:
```toml
[database]
max_connections = 5
```
#### Rate Limiting Memory Usage
- Client rate limit buckets: Store in memory
- Circuit breakers: Minimal memory usage
- Consider reducing burst capacity if memory is critical
#### Provider Management
- Only enable providers you actually use
- Disable unused providers in configuration
### 3. Configuration for Low Memory
Create a `config-low-memory.toml`:
```toml
[server]
port = 8080
host = "0.0.0.0"
[database]
path = "./data/llm_proxy.db"
max_connections = 3 # Reduced from default 10
[providers]
# Only enable providers you need
openai.enabled = true
gemini.enabled = false # Disable if not used
deepseek.enabled = false # Disable if not used
grok.enabled = false # Disable if not used
[rate_limiting]
# Reduce memory usage for rate limiting
client_requests_per_minute = 30 # Reduced from 60
client_burst_size = 5 # Reduced from 10
global_requests_per_minute = 300 # Reduced from 600
```
### 4. System-Level Optimizations
#### Linux Kernel Parameters
Add to `/etc/sysctl.conf`:
```bash
# Reduce TCP buffer sizes
net.ipv4.tcp_rmem = 4096 87380 174760
net.ipv4.tcp_wmem = 4096 65536 131072
# Reduce connection tracking
net.netfilter.nf_conntrack_max = 65536
net.netfilter.nf_conntrack_tcp_timeout_established = 1200
# Reduce socket buffer sizes
net.core.rmem_max = 131072
net.core.wmem_max = 131072
net.core.rmem_default = 65536
net.core.wmem_default = 65536
```
#### Systemd Service Configuration
Create `/etc/systemd/system/llm-proxy.service`:
```ini
[Unit]
Description=LLM Proxy Gateway
After=network.target
[Service]
Type=simple
User=llmproxy
Group=llmproxy
WorkingDirectory=/opt/llm-proxy
ExecStart=/opt/llm-proxy/llm-proxy
Restart=on-failure
RestartSec=5
# Memory limits
MemoryMax=400M
MemorySwapMax=100M
# CPU limits
CPUQuota=50%
# Process limits
LimitNOFILE=65536
LimitNPROC=512
Environment="RUST_LOG=info"
Environment="LLM_PROXY__DATABASE__MAX_CONNECTIONS=3"
[Install]
WantedBy=multi-user.target
```
### 5. Application-Specific Optimizations
#### Disable Unused Features
- **Multimodal support**: If not using images, disable image processing dependencies
- **Dashboard**: The dashboard uses WebSockets and additional memory. Consider disabling if not needed.
- **Detailed logging**: Reduce log verbosity in production
#### Memory Pool Sizes
The application uses several memory pools:
1. **Database connection pool**: Configured via `max_connections`
2. **HTTP client pool**: Reqwest client pool (defaults to reasonable values)
3. **Async runtime**: Tokio worker threads
Reduce Tokio worker threads for low-core systems:
```rust
// In main.rs, modify tokio runtime creation
#[tokio::main(flavor = "current_thread")] // Single-threaded runtime
async fn main() -> Result<()> {
// Or for multi-threaded with limited threads:
// #[tokio::main(worker_threads = 2)]
```
### 6. Monitoring and Profiling
#### Memory Usage Monitoring
```bash
# Install heaptrack for memory profiling (a system package, not a cargo crate)
sudo apt install heaptrack   # or your distro's equivalent
# Profile memory usage
heaptrack ./target/release/llm-proxy
# Monitor with ps
ps aux --sort=-%mem | head -10
# Monitor with top
top -p $(pgrep llm-proxy)
```
#### Performance Benchmarks
Test with different configurations:
```bash
# Test with 100 concurrent connections
wrk -t4 -c100 -d30s http://localhost:8080/health
# Test chat completion endpoint
ab -n 1000 -c 10 -p test_request.json -T application/json http://localhost:8080/v1/chat/completions
```
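The `ab` invocation above expects a `test_request.json` payload file. A minimal way to generate one (the model name is only an example; use any model your proxy routes):

```python
import json

# Minimal OpenAI-style chat payload for the `ab -p test_request.json` run above.
# "gpt-4o-mini" is an example model name, not a requirement.
payload = {
    "model": "gpt-4o-mini",
    "messages": [{"role": "user", "content": "Say hello in one word."}],
    "max_tokens": 16,
}
with open("test_request.json", "w") as f:
    json.dump(payload, f)
```

Keep the prompt short so the benchmark measures proxy overhead rather than model latency.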
### 7. Deployment Checklist for 512MB RAM
- [ ] Build with release profile: `cargo build --release`
- [ ] Configure database with `max_connections = 3`
- [ ] Disable unused providers in configuration
- [ ] Set appropriate rate limiting limits
- [ ] Configure systemd with memory limits
- [ ] Set up log rotation to prevent disk space issues
- [ ] Monitor memory usage during initial deployment
- [ ] Consider using swap space (512MB-1GB) for safety
### 8. Troubleshooting High Memory Usage
#### Common Issues and Solutions:
1. **Database connection leaks**: Ensure connections are properly closed
2. **Memory fragmentation**: Switch to jemalloc or mimalloc as the global allocator
3. **Unbounded queues**: Check WebSocket message queues
4. **Cache growth**: Implement cache limits or TTL
#### Add to Cargo.toml for an alternative allocator:
```toml
[dependencies]
mimalloc = { version = "0.1", default-features = false, optional = true }

[features]
default = ["mimalloc"]
```
#### In main.rs:
```rust
#[cfg(feature = "mimalloc")]
#[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
```
### 9. Expected Memory Usage
| Component | Baseline | With 10 clients | With 100 clients |
|-----------|----------|-----------------|------------------|
| Base executable | 15MB | 15MB | 15MB |
| Database connections | 5MB | 8MB | 15MB |
| Rate limiting | 2MB | 5MB | 20MB |
| HTTP clients | 3MB | 5MB | 10MB |
| **Total** | **25MB** | **33MB** | **60MB** |
**Note**: These are estimates. Actual usage depends on request volume, payload sizes, and configuration.
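For rough capacity planning, the table can be read as a linear model (a sketch; the linearity and the ~0.35 MB-per-client slope are assumptions interpolated from the table, not measurements):

```python
# Rough linear interpolation of the memory table above.
def estimated_memory_mb(clients: int) -> float:
    baseline = 25.0                   # MB: "Total" column with no clients
    per_client = (60.0 - 25.0) / 100  # ~0.35 MB per client at the 100-client point
    return baseline + per_client * clients

print(estimated_memory_mb(50))  # 42.5, halfway between the table's endpoints
```

On a 512MB host this suggests ample headroom even at several hundred clients, but validate against real monitoring data.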
### 10. Further Reading
- [Tokio performance guide](https://tokio.rs/tokio/topics/performance)
- [Rust performance book](https://nnethercote.github.io/perf-book/)
- [Linux memory management](https://www.kernel.org/doc/html/latest/admin-guide/mm/)
- [SQLite performance tips](https://www.sqlite.org/faq.html#q19)

PLAN.md

@@ -1,99 +0,0 @@
# Project Plan: LLM Proxy Enhancements & Security Upgrade
This document outlines the roadmap for standardizing frontend security, cleaning up the codebase, upgrading session management to HMAC-signed tokens, and extending integration testing.
## Phase 1: Frontend Security Standardization
**Primary Agent:** `frontend-developer`
- [x] Audit `static/js/pages/users.js` for manual HTML string concatenation.
- [x] Replace custom escaping or unescaped injections with `window.api.escapeHtml`.
- [x] Verify user list and user detail rendering for XSS vulnerabilities.
## Phase 2: Codebase Cleanup
**Primary Agent:** `backend-developer`
- [x] Identify and remove unused imports in `src/config/mod.rs`.
- [x] Identify and remove unused imports in `src/providers/mod.rs`.
- [x] Run `cargo clippy` and `cargo fmt` to ensure adherence to standards.
## Phase 3: HMAC Architectural Upgrade
**Primary Agents:** `fullstack-developer`, `security-auditor`, `backend-developer`
### 3.1 Design (Security Auditor)
- [x] Define Token Structure: `base64(payload).signature`.
- Payload: `{ "session_id": "...", "username": "...", "role": "...", "exp": ... }`
- [x] Select HMAC algorithm (HMAC-SHA256).
- [x] Define environment variable for secret key: `SESSION_SECRET`.
### 3.2 Implementation (Backend Developer)
- [x] Refactor `src/dashboard/sessions.rs`:
- Integrate `hmac` and `sha2` crates (or similar).
- Update `create_session` to return signed tokens.
- Update `validate_session` to verify signature before checking store.
- [x] Implement activity-based session refresh:
- If session is valid and >50% through its TTL, extend `expires_at` and issue new signed token.
### 3.3 Integration (Fullstack Developer)
- [x] Update dashboard API handlers to handle new token format.
- [x] Update frontend session storage/retrieval if necessary.
## Phase 4: Extended Integration Testing
**Primary Agent:** `qa-automation`
- [ ] Setup test environment with encrypted key storage enabled.
- [ ] Implement end-to-end flow:
1. Store encrypted provider key via API.
2. Authenticate through Proxy.
3. Make proxied LLM request (verifying decryption and usage).
- [ ] Validate HMAC token expiration and refresh logic in automated tests.
## Phase 5: Code Quality & Refactoring
**Primary Agent:** `fullstack-developer`
- [x] Refactor dashboard monolith into modular sub-modules (`auth.rs`, `usage.rs`, etc.).
- [x] Standardize error handling and remove `unwrap()` in production paths.
- [x] Implement system health metrics and backup functionality.
---
# Phase 6: Cache Cost & Provider Audit (ACTIVE)
**Primary Agents:** `frontend-developer`, `backend-developer`, `database-optimizer`, `lab-assistant`
## 6.1 Dashboard UI Updates (@frontend-developer)
- [ ] **Update Models Page Modal:** Add input fields for `Cache Read Cost` and `Cache Write Cost` in `static/js/pages/models.js`.
- [ ] **API Integration:** Ensure `window.api.put` includes these new cost fields in the request body.
- [ ] **Verify Costs Page:** Confirm `static/js/pages/costs.js` displays these rates correctly in the pricing table.
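For context, these fields feed a per-million-token cost calculation along these lines (a sketch: whether cache-read tokens are discounted from the normal input rate is an assumption about the billing model, and the example rates are made up):

```python
# Cache-aware cost sketch: all rates are USD per million tokens.
def request_cost(prompt_tokens, completion_tokens, cache_read_tokens,
                 cache_write_tokens, rates):
    return (
        (prompt_tokens - cache_read_tokens) * rates["input"]
        + completion_tokens * rates["output"]
        + cache_read_tokens * rates["cache_read_cost_per_m"]
        + cache_write_tokens * rates["cache_write_cost_per_m"]
    ) / 1_000_000

rates = {"input": 3.00, "output": 15.00,
         "cache_read_cost_per_m": 0.30, "cache_write_cost_per_m": 3.75}
print(request_cost(10_000, 500, 8_000, 0, rates))  # 0.0159
```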
## 6.2 Provider Audit & Stream Fixes (@backend-developer)
- [ ] **Standard DeepSeek Fix:** Modify `src/providers/deepseek.rs` to stop stripping `stream_options` for `deepseek-chat`.
- [ ] **Grok Audit:** Verify if Grok correctly returns usage in streaming; it uses `build_openai_body` and doesn't seem to strip it.
- [ ] **Gemini Audit:** Confirm Gemini returns `usage_metadata` reliably in the final chunk.
- [ ] **Anthropic Audit:** Check if Anthropic streaming requires `include_usage` or similar flags.
## 6.3 Database & Migration Validation (@database-optimizer)
- [ ] **Test Migrations:** Run the server to ensure `ALTER TABLE` logic in `src/database/mod.rs` applies the new columns correctly.
- [ ] **Schema Verification:** Verify `model_configs` has `cache_read_cost_per_m` and `cache_write_cost_per_m` columns.
## 6.4 Token Estimation Refinement (@lab-assistant)
- [ ] **Analyze Heuristic:** Review `chars / 4` in `src/utils/tokens.rs`.
- [ ] **Background Precise Recount:** Propose a mechanism for a precise token count (using Tiktoken) after the response is finalized.
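For reference, the heuristic under review is simply (the `max(1, ...)` floor is an illustrative assumption; the code in `src/utils/tokens.rs` may differ):

```python
# The chars/4 heuristic: roughly one token per four characters of English text.
def estimate_tokens(text: str) -> int:
    return max(1, len(text) // 4)

print(estimate_tokens("The quick brown fox jumps over the lazy dog."))  # 11
```

The ratio skews badly for code, CJK text, and long rare words, which is why a precise background recount is proposed.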
## Critical Path
Migration Validation → UI Fields → Provider Stream Usage Reporting.
```mermaid
gantt
title Phase 6 Timeline
dateFormat YYYY-MM-DD
section Frontend
Models Page UI :2026-03-06, 1d
Costs Table Update:after Models Page UI, 1d
section Backend
DeepSeek Fix :2026-03-06, 1d
Provider Audit (Grok/Gemini):after DeepSeek Fix, 2d
section Database
Migration Test :2026-03-06, 1d
section Optimization
Token Heuristic Review :2026-03-06, 1d
```


@@ -1,119 +1,125 @@
-# LLM Proxy Gateway
+# GopherGate
-A unified, high-performance LLM proxy gateway built in Rust. It provides a single OpenAI-compatible API to access multiple providers (OpenAI, Gemini, DeepSeek, Grok, Ollama) with built-in token tracking, real-time cost calculation, multi-user authentication, and a management dashboard.
+A unified, high-performance LLM proxy gateway built in Go. It provides a single OpenAI-compatible API to access multiple providers (OpenAI, Gemini, DeepSeek, Moonshot, Grok, Ollama) with built-in token tracking, real-time cost calculation, multi-user authentication, and a management dashboard.
 ## Features
 - **Unified API:** OpenAI-compatible `/v1/chat/completions` and `/v1/models` endpoints.
 - **Multi-Provider Support:**
   - **OpenAI:** GPT-4o, GPT-4o Mini, o1, o3 reasoning models.
-  - **Google Gemini:** Gemini 2.0 Flash, Pro, and vision models.
+  - **Google Gemini:** Gemini 2.0 Flash, Pro, and vision models (with native CoT support).
-  - **DeepSeek:** DeepSeek Chat and Reasoner models.
+  - **DeepSeek:** DeepSeek Chat and Reasoner (R1) models.
-  - **xAI Grok:** Grok-beta models.
+  - **Moonshot:** Kimi K2.5 and other Kimi models.
+  - **xAI Grok:** Grok-4 models.
   - **Ollama:** Local LLMs running on your network.
 - **Observability & Tracking:**
-  - **Real-time Costing:** Fetches live pricing and context specs from `models.dev` on startup.
+  - **Asynchronous Logging:** Non-blocking request logging to SQLite using background workers.
-  - **Token Counting:** Precise estimation using `tiktoken-rs`.
+  - **Token Counting:** Precise estimation and tracking of prompt, completion, and reasoning tokens.
-  - **Database Logging:** Every request logged to SQLite for historical analysis.
+  - **Database Persistence:** Every request logged to SQLite for historical analysis and dashboard analytics.
-  - **Streaming Support:** Full SSE (Server-Sent Events) with `[DONE]` termination for client compatibility.
+  - **Streaming Support:** Full SSE (Server-Sent Events) support for all providers.
   - **Multimodal (Vision):** Image processing (Base64 and remote URLs) across compatible providers.
 - **Multi-User Access Control:**
   - **Admin Role:** Full access to all dashboard features, user management, and system configuration.
   - **Viewer Role:** Read-only access to usage analytics, costs, and monitoring.
   - **Client API Keys:** Create and manage multiple client tokens for external integrations.
 - **Reliability:**
-  - **Circuit Breaking:** Automatically protects when providers are down.
+  - **Circuit Breaking:** Automatically protects when providers are down (coming soon).
-  - **Rate Limiting:** Per-client and global rate limits.
+  - **Rate Limiting:** Per-client and global rate limits (coming soon).
+  - **Cache-Aware Costing:** Tracks cache hit/miss tokens for accurate billing.
 ## Security
-LLM Proxy is designed with security in mind:
+GopherGate is designed with security in mind:
-- **HMAC Session Tokens:** Management dashboard sessions are secured using HMAC-SHA256 signed tokens.
+- **Signed Session Tokens:** Management dashboard sessions are secured using HMAC-SHA256 signed tokens.
-- **Encrypted Provider Keys:** Sensitive LLM provider API keys are stored encrypted (AES-256-GCM) in the database.
+- **Encrypted Storage:** Support for encrypted provider API keys in the database.
-- **Session Refresh:** Activity-based session extension prevents session hijacking while maintaining user convenience.
+- **Auth Middleware:** Secure client authentication via database-backed API keys.
-- **XSS Prevention:** Standardized frontend escaping using `window.api.escapeHtml`.
-**Note:** You must define a `SESSION_SECRET` in your `.env` file for secure session signing.
+**Note:** You must define an `LLM_PROXY__ENCRYPTION_KEY` in your `.env` file for secure session signing and encryption.
 ## Tech Stack
-- **Runtime:** Rust with Tokio.
+- **Runtime:** Go 1.22+
-- **Web Framework:** Axum.
+- **Web Framework:** Gin Gonic
-- **Database:** SQLx with SQLite.
+- **Database:** sqlx with SQLite (CGO-free via `modernc.org/sqlite`)
-- **Frontend:** Vanilla JS/CSS with Chart.js for visualizations.
+- **Frontend:** Vanilla JS/CSS with Chart.js for visualizations
 ## Getting Started
 ### Prerequisites
-- Rust (1.80+)
+- Go (1.22+)
-- SQLite3
+- SQLite3 (optional, driver is built-in)
 - Docker (optional, for containerized deployment)
 ### Quick Start
 1. Clone and build:
 ```bash
-git clone ssh://git.dustin.coffee:2222/hobokenchicken/llm-proxy.git
+git clone <repository-url>
-cd llm-proxy
+cd gophergate
-cargo build --release
+go build -o gophergate ./cmd/gophergate
 ```
 2. Configure environment:
 ```bash
 cp .env.example .env
-# Edit .env and add your API keys:
+# Edit .env and add your configuration:
-# SESSION_SECRET=... (Generate a strong random secret)
+# LLM_PROXY__ENCRYPTION_KEY=... (32-byte hex or base64 string)
 # OPENAI_API_KEY=sk-...
 # GEMINI_API_KEY=AIza...
+# MOONSHOT_API_KEY=...
 ```
 3. Run the proxy:
 ```bash
-cargo run --release
+./gophergate
 ```
-The server starts on `http://localhost:8080` by default.
+The server starts on `http://0.0.0.0:8080` by default.
 ### Deployment (Docker)
-A multi-stage `Dockerfile` is provided for efficient deployment:
 ```bash
 # Build the container
-docker build -t llm-proxy .
+docker build -t gophergate .
 # Run the container
 docker run -p 8080:8080 \
-  -e SESSION_SECRET=your-secure-secret \
+  -e LLM_PROXY__ENCRYPTION_KEY=your-secure-key \
   -v ./data:/app/data \
-  llm-proxy
+  gophergate
 ```
 ## Management Dashboard
-Access the dashboard at `http://localhost:8080`. The dashboard architecture has been refactored into modular sub-components for better maintainability:
+Access the dashboard at `http://localhost:8080`.
-- **Auth (`/api/auth`):** Login, session management, and password changes.
+- **Auth:** Login, session management, and status tracking.
-- **Usage (`/api/usage`):** Summary stats, time-series analytics, and provider breakdown.
+- **Usage:** Summary stats, time-series analytics, and provider breakdown.
-- **Clients (`/api/clients`):** API key management and per-client usage tracking.
+- **Clients:** API key management and per-client usage tracking.
-- **Providers (`/api/providers`):** Provider configuration, status monitoring, and connection testing.
+- **Providers:** Provider configuration and status monitoring.
-- **System (`/api/system`):** Health metrics, live logs, database backups, and global settings.
+- **Users:** Admin-only user management for dashboard access.
 - **Monitoring:** Live request stream via WebSocket.
 ### Default Credentials
 - **Username:** `admin`
-- **Password:** `admin123`
+- **Password:** `admin123` (You will be prompted to change this on first login)
-Change the admin password in the dashboard after first login!
+**Forgot Password?**
+You can reset the admin password to default by running:
+```bash
+./gophergate -reset-admin
+```
 ## API Usage
 The proxy is a drop-in replacement for OpenAI. Configure your client:
+Moonshot models are available through the same OpenAI-compatible endpoint. For
+example, use `kimi-k2.5` as the model name after setting `MOONSHOT_API_KEY` in
+your environment.
 ### Python
 ```python
 from openai import OpenAI
@@ -131,4 +137,4 @@ response = client.chat.completions.create(
 ## License
-MIT OR Apache-2.0
+MIT


@@ -1,58 +0,0 @@
# LLM Proxy Rust Backend Code Review Report
## Executive Summary
This code review examines the `llm-proxy` Rust backend, focusing on architectural soundness, performance characteristics, and adherence to Rust idioms. The codebase demonstrates solid engineering with well-structured modular design, comprehensive error handling, and thoughtful async patterns. However, several areas require attention for production readiness, particularly around thread safety, memory efficiency, and error recovery.
## 1. Core Proxy Logic Review
### Strengths
- Clean provider abstraction (`Provider` trait).
- Streaming support with `AggregatingStream` for token counting.
- Model mapping and caching system.
### Issues Found
- **Provider Manager Thread Safety Risk:** O(n) lookups using `Vec` with `RwLock`. Use `DashMap` instead.
- **Streaming Memory Inefficiency:** Accumulates complete response content in memory.
- **Model Registry Cache Invalidation:** No strategy when config changes via dashboard.
## 2. State Management Review
- **Token Bucket Algorithm Flaw:** Custom implementation lacks thread-safe refill sync. Use `governor` crate.
- **Broadcast Channel Unbounded Growth Risk:** Fixed-size (100) may drop messages.
- **Database Connection Pool Contention:** SQLite connections shared without differentiation.
## 3. Error Handling Review
- **Error Recovery Missing:** No circuit breaker or retry logic for provider calls.
- **Stream Error Logging Gap:** Stream errors swallowed without logging partial usage.
## 4. Rust Idioms and Performance
- **Unnecessary String Cloning:** Frequent cloning in authentication hot paths.
- **JSON Parsing Inefficiency:** Multiple passes with `serde_json::Value`. Use typed structs.
- **Missing `#[derive(Copy)]`:** For small enums like `CircuitState`.
## 5. Async Performance
- **Blocking Calls:** Token estimation may block async runtime.
- **Missing Connection Timeouts:** Only overall timeout, no separate read/write timeouts.
- **Unbounded Task Spawn:** For client usage updates under load.
## 6. Security Considerations
- **Token Leakage:** Redact tokens in `Debug` and `Display` impls.
- **No Request Size Limits:** Vulnerable to memory exhaustion.
## 7. Testability
- **Mocking Difficulty:** Tight coupling to concrete provider implementations.
- **Missing Integration Tests:** No E2E tests for streaming.
## 8. Summary of Critical Actions
### High Priority
1. Replace custom token bucket with `governor` crate.
2. Fix provider lookup O(n) scaling issue.
3. Implement proper error recovery with retries.
4. Add request size limits and timeout configurations.
### Medium Priority
1. Reduce string cloning in hot paths.
2. Implement cache invalidation for model configs.
3. Add connection pooling separation.
4. Improve streaming memory efficiency.
---
*Review conducted by: Senior Principal Engineer (Code Reviewer)*


@@ -1,58 +0,0 @@
# LLM Proxy Security Audit Report
## Executive Summary
A comprehensive security audit of the `llm-proxy` repository was conducted. The audit identified **2 critical vulnerabilities**, **3 high-risk issues**, **4 medium-risk issues**, and **3 low-risk issues**. The most severe findings include Cross-Site Scripting (XSS) in the dashboard interface and insecure storage of provider API keys in the database.
## Detailed Findings
### Critical Risk Vulnerabilities
#### **CRITICAL-01: Cross-Site Scripting (XSS) in Dashboard Interface**
- **Location**: `static/js/pages/clients.js` (multiple locations).
- **Description**: User-controlled data (e.g., `client.id`) inserted directly into HTML or `onclick` handlers without escaping.
- **Impact**: Arbitrary JavaScript execution in admin context, potentially stealing session tokens.
#### **CRITICAL-02: Insecure API Key Storage in Database**
- **Location**: `src/database/mod.rs`, `src/providers/mod.rs`, `src/dashboard/providers.rs`.
- **Description**: Provider API keys are stored in **plaintext** in the SQLite database.
- **Impact**: Compromised database file exposes all provider API keys.
### High Risk Vulnerabilities
#### **HIGH-01: Missing Input Validation and Size Limits**
- **Location**: `src/server/mod.rs`, `src/models/mod.rs`.
- **Impact**: Denial of Service via large payloads.
#### **HIGH-02: Sensitive Data Logging Without Encryption**
- **Location**: `src/database/mod.rs`, `src/logging/mod.rs`.
- **Description**: Full request and response bodies stored in `llm_requests` table without encryption or redaction.
#### **HIGH-03: Weak Default Credentials and Password Policy**
- **Description**: Default admin password is 'admin' with only 4-character minimum password length.
### Medium Risk Vulnerabilities
#### **MEDIUM-01: Missing CSRF Protection**
- No CSRF tokens or SameSite cookie attributes for state-changing dashboard endpoints.
#### **MEDIUM-02: Insecure Session Management**
- Session tokens stored in localStorage, which (unlike HttpOnly cookies) is readable by any script on the page.
- Tokens use simple `session-{uuid}` format.
#### **MEDIUM-03: Error Information Leakage**
- Internal error details exposed to clients in some cases.
#### **MEDIUM-04: Outdated Dependencies**
- Outdated versions of `chrono`, `tokio`, and `reqwest`.
### Low Risk Vulnerabilities
- Missing security headers (CSP, HSTS, X-Frame-Options).
- Insufficient rate limiting on dashboard authentication.
- No database encryption at rest.
## Recommendations
### Immediate Actions
1. **Fix XSS Vulnerabilities:** Implement proper HTML escaping for all user-controlled data.
2. **Secure API Key Storage:** Encrypt API keys in database using a library like `ring`.
3. **Implement Input Validation:** Add maximum payload size limits (e.g., 10MB).
4. **Improve Data Protection:** Add option to disable request/response body logging.
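The escaping called for in the first recommendation is standard HTML entity substitution; the dashboard performs it in JS via `window.api.escapeHtml`, whose behavior is comparable to Python's `html.escape`:

```python
import html

# Entity-escaping user-controlled data neutralizes the injected markup.
malicious = '<img src=x onerror="alert(1)">'
print(html.escape(malicious))
# &lt;img src=x onerror=&quot;alert(1)&quot;&gt;
```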
---
*Report generated by Security Auditor Agent on March 6, 2026*

TODO.md

@@ -0,0 +1,56 @@
# Migration TODO List
## Completed Tasks
- [x] Initial Go project setup
- [x] Database schema & migrations (hardcoded in `db.go`)
- [x] Configuration loader (Viper)
- [x] Auth Middleware (scoped to `/v1`)
- [x] Basic Provider implementations (OpenAI, Gemini, DeepSeek, Grok)
- [x] Streaming Support (SSE & Gemini custom streaming)
- [x] Archive Rust files to `rust` branch
- [x] Clean root and set Go version as `main`
- [x] Enhanced `helpers.go` for Multimodal & Tool Calling (OpenAI compatible)
- [x] Enhanced `server.go` for robust request conversion
- [x] Dashboard Management APIs (Clients, Tokens, Users, Providers)
- [x] Dashboard Analytics & Usage Summary (Fixed SQL robustness)
- [x] WebSocket for real-time dashboard updates (Hub with client counting)
- [x] Asynchronous Request Logging to SQLite
- [x] Update documentation (README, deployment, architecture)
- [x] Cost Tracking accuracy (Registry integration with `models.dev`)
- [x] Model Listing endpoint (`/v1/models`) with provider filtering
- [x] System Metrics endpoint (`/api/system/metrics` using `gopsutil`)
- [x] Fixed dashboard 404s and 500s
## Feature Parity Checklist (High Priority)
### OpenAI Provider
- [x] Tool Calling
- [x] Multimodal (Images) support
- [x] Accurate usage parsing (cached & reasoning tokens)
- [ ] Reasoning Content (CoT) support for `o1`, `o3` (need to ensure it's parsed in responses)
- [ ] Support for `/v1/responses` API (required for some gpt-5/o1 models)
### Gemini Provider
- [x] Tool Calling (mapping to Gemini format)
- [x] Multimodal (Images) support
- [x] Reasoning/Thought support
- [x] Handle Tool Response role in unified format
### DeepSeek Provider
- [x] Reasoning Content (CoT) support
- [x] Parameter sanitization for `deepseek-reasoner`
- [x] Tool Calling support
- [x] Accurate usage parsing (cache hits & reasoning)
### Grok Provider
- [x] Tool Calling support
- [x] Multimodal support
- [x] Accurate usage parsing (via OpenAI helper)
## Infrastructure & Middleware
- [ ] Implement Rate Limiting (`golang.org/x/time/rate`)
- [ ] Implement Circuit Breaker (`github.com/sony/gobreaker`)
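The planned limiter (`golang.org/x/time/rate`) is a token bucket; its core mechanics look like this (a Python sketch for illustration, not the Go code that will ship):

```python
import time

class TokenBucket:
    """Token-bucket sketch: refills `rate` tokens/second, up to `burst` capacity."""
    def __init__(self, rate: float, burst: float):
        self.rate, self.capacity = rate, burst
        self.tokens, self.last = burst, time.monotonic()

    def allow(self) -> bool:
        now = time.monotonic()
        # Refill proportionally to elapsed time, capped at capacity.
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1:
            self.tokens -= 1
            return True
        return False

bucket = TokenBucket(rate=1, burst=2)  # 1 request/s with a burst of 2
print(bucket.allow(), bucket.allow(), bucket.allow())  # True True False
```

The proxy would keep one bucket per client key plus a global bucket, rejecting with 429 when `allow()` fails.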
## Verification
- [ ] Unit tests for feature-specific mapping (CoT, Tools, Images)
- [ ] Integration tests with live LLM APIs


@@ -1 +0,0 @@
too-many-arguments-threshold = 8

cmd/gophergate/main.go

@@ -0,0 +1,55 @@
package main

import (
	"flag"
	"log"
	"os"

	"gophergate/internal/config"
	"gophergate/internal/db"
	"gophergate/internal/server"

	"github.com/joho/godotenv"
	"golang.org/x/crypto/bcrypt"
)

func main() {
	resetAdmin := flag.Bool("reset-admin", false, "Reset admin password to admin123")
	flag.Parse()

	// Load environment variables
	if err := godotenv.Load(); err != nil {
		log.Println("No .env file found")
	}

	// Load configuration
	cfg, err := config.Load()
	if err != nil {
		log.Fatalf("Failed to load configuration: %v", err)
	}

	// Initialize database
	database, err := db.Init(cfg.Database.Path)
	if err != nil {
		log.Fatalf("Failed to initialize database: %v", err)
	}

	if *resetAdmin {
		hash, _ := bcrypt.GenerateFromPassword([]byte("admin123"), 12)
		_, err = database.Exec("UPDATE users SET password_hash = ?, must_change_password = 1 WHERE username = 'admin'", string(hash))
		if err != nil {
			log.Fatalf("Failed to reset admin password: %v", err)
		}
		log.Println("Admin password has been reset to 'admin123'")
		os.Exit(0)
	}

	// Initialize server
	s := server.NewServer(cfg, database)

	// Run server
	log.Printf("Starting GopherGate on %s:%d", cfg.Server.Host, cfg.Server.Port)
	if err := s.Run(); err != nil {
		log.Fatalf("Server failed: %v", err)
	}
}

Binary file not shown.

deploy.sh

@@ -1,667 +0,0 @@
#!/bin/bash
# LLM Proxy Gateway Deployment Script
# This script automates the deployment of the LLM Proxy Gateway on a Linux server
set -e # Exit on error
set -u # Exit on undefined variable
# Configuration
APP_NAME="llm-proxy"
APP_USER="llmproxy"
APP_GROUP="llmproxy"
GIT_REPO="ssh://git.dustin.coffee:2222/hobokenchicken/llm-proxy.git"
INSTALL_DIR="/opt/$APP_NAME"
CONFIG_DIR="/etc/$APP_NAME"
DATA_DIR="/var/lib/$APP_NAME"
LOG_DIR="/var/log/$APP_NAME"
SERVICE_FILE="/etc/systemd/system/$APP_NAME.service"
ENV_FILE="$CONFIG_DIR/.env"
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color
# Logging functions
log_info() {
echo -e "${GREEN}[INFO]${NC} $1"
}
log_warn() {
echo -e "${YELLOW}[WARN]${NC} $1"
}
log_error() {
echo -e "${RED}[ERROR]${NC} $1"
}
# Check if running as root
check_root() {
if [[ $EUID -ne 0 ]]; then
log_error "This script must be run as root"
exit 1
fi
}
# Install system dependencies
install_dependencies() {
log_info "Installing system dependencies..."
# Detect package manager
if command -v apt-get &> /dev/null; then
# Debian/Ubuntu
apt-get update
apt-get install -y \
build-essential \
pkg-config \
libssl-dev \
sqlite3 \
curl \
git
elif command -v yum &> /dev/null; then
# RHEL/CentOS
yum groupinstall -y "Development Tools"
yum install -y \
openssl-devel \
sqlite \
curl \
git
elif command -v dnf &> /dev/null; then
# Fedora
dnf groupinstall -y "Development Tools"
dnf install -y \
openssl-devel \
sqlite \
curl \
git
elif command -v pacman &> /dev/null; then
# Arch Linux
pacman -Syu --noconfirm \
base-devel \
openssl \
sqlite \
curl \
git
else
log_warn "Could not detect package manager. Please install dependencies manually."
fi
}
# Install Rust if not present
install_rust() {
log_info "Checking for Rust installation..."
if ! command -v rustc &> /dev/null; then
log_info "Installing Rust..."
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
else
log_info "Rust is already installed"
fi
# Verify installation
rustc --version
cargo --version
}
# Create system user and directories
setup_directories() {
log_info "Creating system user and directories..."
# Create user and group if they don't exist
if ! id "$APP_USER" &>/dev/null; then
# Arch uses /usr/bin/nologin, Debian/Ubuntu use /usr/sbin/nologin
NOLOGIN=$(command -v nologin 2>/dev/null || echo "/usr/bin/nologin")
useradd -r -s "$NOLOGIN" -M "$APP_USER"
fi
# Create directories
mkdir -p "$INSTALL_DIR"
mkdir -p "$CONFIG_DIR"
mkdir -p "$DATA_DIR"
mkdir -p "$LOG_DIR"
# Set permissions
chown -R "$APP_USER:$APP_GROUP" "$INSTALL_DIR"
chown -R "$APP_USER:$APP_GROUP" "$CONFIG_DIR"
chown -R "$APP_USER:$APP_GROUP" "$DATA_DIR"
chown -R "$APP_USER:$APP_GROUP" "$LOG_DIR"
chmod 750 "$INSTALL_DIR"
chmod 750 "$CONFIG_DIR"
chmod 750 "$DATA_DIR"
chmod 750 "$LOG_DIR"
}
# Build the application
build_application() {
log_info "Building the application..."
# Clone or update repository
if [[ ! -d "$INSTALL_DIR/.git" ]]; then
log_info "Cloning repository..."
git clone "$GIT_REPO" "$INSTALL_DIR"
else
log_info "Updating repository..."
cd "$INSTALL_DIR"
git pull
fi
# Build in release mode
cd "$INSTALL_DIR"
log_info "Building release binary..."
cargo build --release
# Verify build
if [[ -f "target/release/$APP_NAME" ]]; then
log_info "Build successful"
else
log_error "Build failed"
exit 1
fi
}
# Create configuration files
create_configuration() {
log_info "Creating configuration files..."
# Create .env file with API keys
cat > "$ENV_FILE" << EOF
# LLM Proxy Gateway Environment Variables
# Add your API keys here
# OpenAI API Key
# OPENAI_API_KEY=sk-your-key-here
# Google Gemini API Key
# GEMINI_API_KEY=AIza-your-key-here
# DeepSeek API Key
# DEEPSEEK_API_KEY=sk-your-key-here
# xAI Grok API Key
# GROK_API_KEY=gk-your-key-here
# Authentication tokens (comma-separated)
# LLM_PROXY__SERVER__AUTH_TOKENS=token1,token2,token3
EOF
# Create config.toml
cat > "$CONFIG_DIR/config.toml" << EOF
# LLM Proxy Gateway Configuration
[server]
port = 8080
host = "0.0.0.0"
# auth_tokens = ["token1", "token2", "token3"] # Uncomment to enable authentication
[database]
path = "$DATA_DIR/llm_proxy.db"
max_connections = 5
[providers.openai]
enabled = true
api_key_env = "OPENAI_API_KEY"
base_url = "https://api.openai.com/v1"
default_model = "gpt-4o"
[providers.gemini]
enabled = true
api_key_env = "GEMINI_API_KEY"
base_url = "https://generativelanguage.googleapis.com/v1"
default_model = "gemini-2.0-flash"
[providers.deepseek]
enabled = true
api_key_env = "DEEPSEEK_API_KEY"
base_url = "https://api.deepseek.com"
default_model = "deepseek-reasoner"
[providers.grok]
enabled = false # Disabled by default until API is researched
api_key_env = "GROK_API_KEY"
base_url = "https://api.x.ai/v1"
default_model = "grok-beta"
[model_mapping]
"gpt-*" = "openai"
"gemini-*" = "gemini"
"deepseek-*" = "deepseek"
"grok-*" = "grok"
[pricing]
openai = { input = 0.01, output = 0.03 }
gemini = { input = 0.0005, output = 0.0015 }
deepseek = { input = 0.00014, output = 0.00028 }
grok = { input = 0.001, output = 0.003 }
EOF
# Set permissions
chown "$APP_USER:$APP_GROUP" "$ENV_FILE"
chown "$APP_USER:$APP_GROUP" "$CONFIG_DIR/config.toml"
chmod 640 "$ENV_FILE"
chmod 640 "$CONFIG_DIR/config.toml"
}
# Create systemd service
create_systemd_service() {
log_info "Creating systemd service..."
cat > "$SERVICE_FILE" << EOF
[Unit]
Description=LLM Proxy Gateway
Documentation=https://git.dustin.coffee/hobokenchicken/llm-proxy
After=network.target
Wants=network.target
[Service]
Type=simple
User=$APP_USER
Group=$APP_GROUP
WorkingDirectory=$INSTALL_DIR
EnvironmentFile=$ENV_FILE
Environment="RUST_LOG=info"
Environment="LLM_PROXY__CONFIG_PATH=$CONFIG_DIR/config.toml"
Environment="LLM_PROXY__DATABASE__PATH=$DATA_DIR/llm_proxy.db"
ExecStart=$INSTALL_DIR/target/release/$APP_NAME
Restart=on-failure
RestartSec=5
# Security hardening
NoNewPrivileges=true
PrivateTmp=true
ProtectSystem=strict
ProtectHome=true
ReadWritePaths=$DATA_DIR $LOG_DIR
# Resource limits (adjust based on your server)
MemoryMax=400M
MemorySwapMax=100M
CPUQuota=50%
LimitNOFILE=65536
# Logging
StandardOutput=journal
StandardError=journal
SyslogIdentifier=$APP_NAME
[Install]
WantedBy=multi-user.target
EOF
# Reload systemd
systemctl daemon-reload
}
# Setup nginx reverse proxy (optional)
setup_nginx_proxy() {
if ! command -v nginx &> /dev/null; then
log_warn "nginx not installed. Skipping reverse proxy setup."
return
fi
log_info "Setting up nginx reverse proxy..."
cat > "/etc/nginx/sites-available/$APP_NAME" << EOF
server {
listen 80;
server_name your-domain.com; # Change to your domain
# Redirect to HTTPS (recommended)
return 301 https://\$server_name\$request_uri;
}
server {
listen 443 ssl http2;
server_name your-domain.com; # Change to your domain
# SSL certificates (adjust paths)
ssl_certificate /etc/letsencrypt/live/your-domain.com/fullchain.pem;
ssl_certificate_key /etc/letsencrypt/live/your-domain.com/privkey.pem;
# SSL configuration
ssl_protocols TLSv1.2 TLSv1.3;
ssl_ciphers ECDHE-RSA-AES256-GCM-SHA512:DHE-RSA-AES256-GCM-SHA512:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384;
ssl_prefer_server_ciphers off;
# Proxy to LLM Proxy Gateway
location / {
proxy_pass http://127.0.0.1:8080;
proxy_http_version 1.1;
proxy_set_header Upgrade \$http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host \$host;
proxy_set_header X-Real-IP \$remote_addr;
proxy_set_header X-Forwarded-For \$proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto \$scheme;
# Timeouts
proxy_connect_timeout 60s;
proxy_send_timeout 60s;
proxy_read_timeout 60s;
}
# Health check endpoint
location /health {
proxy_pass http://127.0.0.1:8080/health;
access_log off;
}
# Dashboard
location /dashboard {
proxy_pass http://127.0.0.1:8080/dashboard;
}
}
EOF
# Enable site
ln -sf "/etc/nginx/sites-available/$APP_NAME" "/etc/nginx/sites-enabled/"
# Test nginx configuration
nginx -t
log_info "nginx configuration created. Please update the domain and SSL certificate paths."
}
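# Sketch (not called above): refuse to enable the site while the generated
# nginx config still contains the placeholder domain.
check_domain_configured() {
  local conf=$1
  if grep -q 'your-domain.com' "$conf"; then
    echo "placeholder domain still present in $conf; update it before enabling" >&2
    return 1
  fi
}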
# Setup firewall
setup_firewall() {
log_info "Configuring firewall..."
# Check for ufw (Ubuntu)
if command -v ufw &> /dev/null; then
ufw allow 22/tcp # SSH
ufw allow 80/tcp # HTTP
ufw allow 443/tcp # HTTPS
ufw --force enable
log_info "UFW firewall configured"
fi
# Check for firewalld (RHEL/CentOS)
if command -v firewall-cmd &> /dev/null; then
firewall-cmd --permanent --add-service=ssh
firewall-cmd --permanent --add-service=http
firewall-cmd --permanent --add-service=https
firewall-cmd --reload
log_info "Firewalld configured"
fi
}
# Initialize database
initialize_database() {
log_info "Initializing database..."
# Run the application once to create database
sudo -u "$APP_USER" "$INSTALL_DIR/target/release/$APP_NAME" --help &> /dev/null || true
log_info "Database initialized at $DATA_DIR/llm_proxy.db"
}
# Start and enable service
start_service() {
log_info "Starting $APP_NAME service..."
systemctl enable "$APP_NAME"
systemctl start "$APP_NAME"
# Check status
sleep 2
systemctl status "$APP_NAME" --no-pager
}
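# A fixed "sleep 2" can report a false failure on slow starts. A retry helper
# like this (a sketch; not wired into start_service above) polls until the
# given check passes or the retry budget runs out:
wait_for() {
  local tries=$1; shift
  local i
  for ((i = 1; i <= tries; i++)); do
    "$@" && return 0
    sleep 1
  done
  return 1
}
# Example: wait_for 10 systemctl is-active --quiet "$APP_NAME"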
# Verify installation
verify_installation() {
log_info "Verifying installation..."
# Check if service is running
if systemctl is-active --quiet "$APP_NAME"; then
log_info "Service is running"
else
log_error "Service is not running"
journalctl -u "$APP_NAME" -n 20 --no-pager
exit 1
fi
# Test health endpoint
if curl -s http://localhost:8080/health | grep -q "OK"; then
log_info "Health check passed"
else
log_error "Health check failed"
exit 1
fi
# Test dashboard
if curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/dashboard | grep -q "200"; then
log_info "Dashboard is accessible"
else
log_warn "Dashboard may not be accessible (this is normal if not configured)"
fi
log_info "Installation verified successfully!"
}
# Print next steps
print_next_steps() {
cat << EOF
${GREEN}=== LLM Proxy Gateway Installation Complete ===${NC}
${YELLOW}Next steps:${NC}
1. ${GREEN}Configure API keys${NC}
Edit: $ENV_FILE
Add your API keys for the providers you want to use
2. ${GREEN}Configure authentication${NC}
Edit: $CONFIG_DIR/config.toml
Uncomment and set auth_tokens for client authentication
3. ${GREEN}Configure nginx${NC}
Edit: /etc/nginx/sites-available/$APP_NAME
Update domain name and SSL certificate paths
4. ${GREEN}Test the API${NC}
curl -X POST http://localhost:8080/v1/chat/completions \\
-H "Content-Type: application/json" \\
-H "Authorization: Bearer your-token" \\
-d '{
"model": "gpt-4o",
"messages": [{"role": "user", "content": "Hello!"}]
}'
5. ${GREEN}Access the dashboard${NC}
Open: http://your-server-ip:8080/dashboard
Or: https://your-domain.com/dashboard (if nginx configured)
${YELLOW}Useful commands:${NC}
systemctl status $APP_NAME # Check service status
journalctl -u $APP_NAME -f # View logs
systemctl restart $APP_NAME # Restart service
${YELLOW}Configuration files:${NC}
Service: $SERVICE_FILE
Config: $CONFIG_DIR/config.toml
Environment: $ENV_FILE
Database: $DATA_DIR/llm_proxy.db
Logs: $LOG_DIR/
${GREEN}For more information, see:${NC}
https://git.dustin.coffee/hobokenchicken/llm-proxy
$INSTALL_DIR/README.md
$INSTALL_DIR/deployment.md
EOF
}
# Main deployment function
deploy() {
log_info "Starting LLM Proxy Gateway deployment..."
check_root
install_dependencies
install_rust
setup_directories
build_application
create_configuration
create_systemd_service
initialize_database
start_service
verify_installation
print_next_steps
# Optional steps (uncomment if needed)
# setup_nginx_proxy
# setup_firewall
log_info "Deployment completed successfully!"
}
# Update function
update() {
log_info "Updating LLM Proxy Gateway..."
check_root
# Pull latest changes (while service keeps running)
cd "$INSTALL_DIR"
log_info "Pulling latest changes..."
git pull
# Build new binary (service stays up on the old binary)
log_info "Building release binary (service still running)..."
if ! cargo build --release; then
log_error "Build failed — service was NOT interrupted. Fix the error and try again."
exit 1
fi
# Verify binary exists
if [[ ! -f "target/release/$APP_NAME" ]]; then
log_error "Binary not found after build — aborting."
exit 1
fi
# Restart service to pick up new binary
log_info "Build succeeded. Restarting service..."
systemctl restart "$APP_NAME"
sleep 2
if systemctl is-active --quiet "$APP_NAME"; then
log_info "Update completed successfully!"
systemctl status "$APP_NAME" --no-pager
else
log_error "Service failed to start after update. Check logs:"
journalctl -u "$APP_NAME" -n 20 --no-pager
exit 1
fi
}
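# Sketch (not wired in): keep the previous binary so a bad update can be
# rolled back without rebuilding.
backup_binary() {
  cp "$1" "$1.prev"
}
rollback_binary() {
  mv "$1.prev" "$1"
}
# Usage idea: backup_binary "target/release/$APP_NAME" before the restart,
# then rollback_binary and restart again if the health check fails.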
# Uninstall function
uninstall() {
log_info "Uninstalling LLM Proxy Gateway..."
check_root
# Stop and disable service
systemctl stop "$APP_NAME" 2>/dev/null || true
systemctl disable "$APP_NAME" 2>/dev/null || true
rm -f "$SERVICE_FILE"
systemctl daemon-reload
# Remove application files
rm -rf "$INSTALL_DIR"
rm -rf "$CONFIG_DIR"
# Data and logs are intentionally preserved; remove them manually if desired
log_warn "Data directory $DATA_DIR and logs $LOG_DIR have been preserved"
log_warn "Remove manually if desired:"
log_warn " rm -rf $DATA_DIR $LOG_DIR"
# Remove user (optional)
read -p "Remove user $APP_USER? [y/N]: " -n 1 -r
echo
if [[ $REPLY =~ ^[Yy]$ ]]; then
userdel "$APP_USER" 2>/dev/null || true
groupdel "$APP_GROUP" 2>/dev/null || true
fi
log_info "Uninstallation completed!"
}
# Show usage
usage() {
cat << EOF
LLM Proxy Gateway Deployment Script
Usage: $0 [command]
Commands:
deploy - Install and configure LLM Proxy Gateway
update - Pull latest changes, rebuild, and restart
status - Show service status and health check
logs - Tail the service logs (Ctrl+C to stop)
uninstall - Remove LLM Proxy Gateway
help - Show this help message
Examples:
$0 deploy # Full installation
$0 update # Update to latest version
$0 status # Check if service is healthy
$0 logs # Follow live logs
EOF
}
# Status function
status() {
echo ""
log_info "Service status:"
systemctl status "$APP_NAME" --no-pager 2>/dev/null || log_warn "Service not found"
echo ""
# Health check
if curl -sf http://localhost:8080/health &>/dev/null; then
log_info "Health check: OK"
else
log_warn "Health check: FAILED (service may not be running or port 8080 not responding)"
fi
# Show current git commit
if [[ -d "$INSTALL_DIR/.git" ]]; then
echo ""
log_info "Installed version:"
git -C "$INSTALL_DIR" log -1 --format=" %h %s (%cr)" 2>/dev/null
fi
}
# Logs function
logs() {
log_info "Tailing $APP_NAME logs (Ctrl+C to stop)..."
journalctl -u "$APP_NAME" -f
}
# Parse command line arguments
case "${1:-}" in
deploy)
deploy
;;
update)
update
;;
status)
status
;;
logs)
logs
;;
uninstall)
uninstall
;;
help|--help|-h)
usage
;;
*)
usage
exit 1
;;
esac


@@ -1,322 +1,52 @@
deployment.md: removed (the old Rust-based guide):

# LLM Proxy Gateway - Deployment Guide
## Overview
A unified LLM proxy gateway supporting OpenAI, Google Gemini, DeepSeek, and xAI Grok with token tracking, cost calculation, and admin dashboard.
## System Requirements
- **CPU**: 2 cores minimum
- **RAM**: 512MB minimum (1GB recommended)
- **Storage**: 10GB minimum
- **OS**: Linux (tested on Arch Linux, Ubuntu, Debian)
- **Runtime**: Rust 1.70+ with Cargo
## Deployment Options
### Option 1: Docker (Recommended)
```dockerfile
FROM rust:1.70-alpine as builder
WORKDIR /app
COPY . .
RUN cargo build --release
FROM alpine:latest
RUN apk add --no-cache libgcc
COPY --from=builder /app/target/release/llm-proxy /usr/local/bin/
COPY --from=builder /app/static /app/static
WORKDIR /app
EXPOSE 8080
CMD ["llm-proxy"]
```
### Option 2: Systemd Service (Bare Metal/LXC)
```ini
# /etc/systemd/system/llm-proxy.service
[Unit]
Description=LLM Proxy Gateway
After=network.target
[Service]
Type=simple
User=llmproxy
Group=llmproxy
WorkingDirectory=/opt/llm-proxy
ExecStart=/opt/llm-proxy/llm-proxy
Restart=always
RestartSec=10
Environment="RUST_LOG=info"
Environment="LLM_PROXY__SERVER__PORT=8080"
Environment="LLM_PROXY__SERVER__AUTH_TOKENS=sk-test-123,sk-test-456"
[Install]
WantedBy=multi-user.target
```
### Option 3: LXC Container (Proxmox)
1. Create Alpine Linux LXC container
2. Install Rust: `apk add rust cargo`
3. Copy application files
4. Build: `cargo build --release`
5. Run: `./target/release/llm-proxy`
## Configuration
### Environment Variables
```bash
# Required API Keys
OPENAI_API_KEY=sk-...
GEMINI_API_KEY=AIza...
DEEPSEEK_API_KEY=sk-...
GROK_API_KEY=gk-... # Optional
# Server Configuration (with LLM_PROXY__ prefix)
LLM_PROXY__SERVER__PORT=8080
LLM_PROXY__SERVER__HOST=0.0.0.0
LLM_PROXY__SERVER__AUTH_TOKENS=sk-test-123,sk-test-456
# Database Configuration
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
LLM_PROXY__DATABASE__MAX_CONNECTIONS=10
# Provider Configuration
LLM_PROXY__PROVIDERS__OPENAI__ENABLED=true
LLM_PROXY__PROVIDERS__GEMINI__ENABLED=true
LLM_PROXY__PROVIDERS__DEEPSEEK__ENABLED=true
LLM_PROXY__PROVIDERS__GROK__ENABLED=false
```
### Configuration File (config.toml)
Create `config.toml` in the application directory:
```toml
[server]
port = 8080
host = "0.0.0.0"
auth_tokens = ["sk-test-123", "sk-test-456"]
[database]
path = "./data/llm_proxy.db"
max_connections = 10
[providers.openai]
enabled = true
base_url = "https://api.openai.com/v1"
default_model = "gpt-4o"
[providers.gemini]
enabled = true
base_url = "https://generativelanguage.googleapis.com/v1"
default_model = "gemini-2.0-flash"
[providers.deepseek]
enabled = true
base_url = "https://api.deepseek.com"
default_model = "deepseek-reasoner"
[providers.grok]
enabled = false
base_url = "https://api.x.ai/v1"
default_model = "grok-beta"
```
## Nginx Reverse Proxy Configuration
**Important for SSE/Streaming:** Disable buffering and configure timeouts for proper SSE support.
```nginx
server {
    listen 80;
    server_name llm-proxy.yourdomain.com;
    location / {
        proxy_pass http://localhost:8080;
        proxy_http_version 1.1;
        # SSE/Streaming support
        proxy_buffering off;
        chunked_transfer_encoding on;
        proxy_set_header Connection '';
        # Timeouts for long-running streams
        proxy_connect_timeout 7200s;
        proxy_read_timeout 7200s;
        proxy_send_timeout 7200s;
        # Disable gzip for streaming
        gzip off;
        # Headers
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;
    }
    # SSL configuration (recommended)
    listen 443 ssl http2;
    ssl_certificate /etc/letsencrypt/live/llm-proxy.yourdomain.com/fullchain.pem;
    ssl_certificate_key /etc/letsencrypt/live/llm-proxy.yourdomain.com/privkey.pem;
}
```
### NGINX Proxy Manager
If using NGINX Proxy Manager, add this to **Advanced Settings**:
```nginx
proxy_buffering off;
proxy_http_version 1.1;
proxy_set_header Connection '';
chunked_transfer_encoding on;
proxy_connect_timeout 7200s;
proxy_read_timeout 7200s;
proxy_send_timeout 7200s;
gzip off;
```
## Security Considerations
### 1. Authentication
- Use strong Bearer tokens
- Rotate tokens regularly
- Consider implementing JWT for production
### 2. Rate Limiting
- Implement per-client rate limiting
- Consider using the `governor` crate for advanced rate limiting
### 3. Network Security
- Run behind a reverse proxy (nginx)
- Enable HTTPS
- Restrict access by IP if needed
- Use firewall rules
### 4. Data Security
- Database encryption (SQLCipher for SQLite)
- Secure API key storage
- Regular backups
## Monitoring & Maintenance
### Logging
- Application logs: `RUST_LOG=info` (or `debug` for troubleshooting)
- Access logs via nginx
- Database logs for audit trail
### Health Checks
```bash
# Health endpoint
curl http://localhost:8080/health
# Database check
sqlite3 ./data/llm_proxy.db "SELECT COUNT(*) FROM llm_requests;"
```
### Backup Strategy
```bash
#!/bin/bash
# backup.sh
BACKUP_DIR="/backups/llm-proxy"
DATE=$(date +%Y%m%d_%H%M%S)
# Backup database
sqlite3 ./data/llm_proxy.db ".backup $BACKUP_DIR/llm_proxy_$DATE.db"
# Backup configuration
cp config.toml $BACKUP_DIR/config_$DATE.toml
# Rotate old backups (keep 30 days)
find $BACKUP_DIR -name "*.db" -mtime +30 -delete
find $BACKUP_DIR -name "*.toml" -mtime +30 -delete
```
## Performance Tuning
### Database Optimization
```sql
-- Run these SQL commands periodically
VACUUM;
ANALYZE;
```
### Memory Management
- Monitor memory usage with `htop` or `ps aux`
- Adjust `max_connections` based on load
- Consider connection pooling for high traffic
### Scaling
1. **Vertical Scaling**: Increase container resources
2. **Horizontal Scaling**: Deploy multiple instances behind a load balancer
3. **Database**: Migrate to PostgreSQL for high-volume usage
## Troubleshooting
### Common Issues
1. **Port already in use**
   ```bash
   netstat -tulpn | grep :8080
   kill <PID> # or change port in config
   ```
2. **Database permissions**
   ```bash
   chown -R llmproxy:llmproxy /opt/llm-proxy/data
   chmod 600 /opt/llm-proxy/data/llm_proxy.db
   ```
3. **API key errors**
   - Verify environment variables are set
   - Check provider status (dashboard)
   - Test connectivity: `curl https://api.openai.com/v1/models`
4. **High memory usage**
   - Check for memory leaks
   - Reduce `max_connections`
   - Implement connection timeouts
### Debug Mode
```bash
# Run with debug logging
RUST_LOG=debug ./llm-proxy
# Check system logs
journalctl -u llm-proxy -f
```
## Integration
### Open-WebUI Compatibility
The proxy provides an OpenAI-compatible API, so configure Open-WebUI:
```
API Base URL: http://your-proxy-address:8080
API Key: sk-test-123 (or your configured token)
```
### Custom Clients
```python
import openai
client = openai.OpenAI(
    base_url="http://localhost:8080/v1",
    api_key="sk-test-123"
)
response = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "Hello"}]
)
```
## Updates & Upgrades
1. **Backup** current configuration and database
2. **Stop** the service: `systemctl stop llm-proxy`
3. **Update** code: `git pull` or copy new binaries
4. **Migrate** database if needed (check migrations/)
5. **Restart**: `systemctl start llm-proxy`
6. **Verify**: Check logs and test endpoints
## Support
- Check logs in `/var/log/llm-proxy/`
- Monitor dashboard at `http://your-server:8080`
- Review database metrics in dashboard
- Enable debug logging for troubleshooting

deployment.md: added (the new Go-based guide):

# Deployment Guide (Go)
This guide covers deploying the Go-based GopherGate.
## Environment Setup
1. **Mandatory Configuration:**
   Create a `.env` file from the example:
   ```bash
   cp .env.example .env
   ```
   Ensure `LLM_PROXY__ENCRYPTION_KEY` is set to a secure 32-byte string.
2. **Data Directory:**
   The proxy stores its database in `./data/llm_proxy.db` by default. Ensure this directory exists and is writable.
## Binary Deployment
### 1. Build
```bash
go build -o gophergate ./cmd/gophergate
```
### 2. Run
```bash
./gophergate
```
## Docker Deployment
The project includes a multi-stage `Dockerfile` for minimal image size.
### 1. Build Image
```bash
docker build -t gophergate .
```
### 2. Run Container
```bash
docker run -d \
  --name gophergate \
  -p 8080:8080 \
  -v $(pwd)/data:/app/data \
  --env-file .env \
  gophergate
```
## Production Considerations
- **SSL/TLS:** It is recommended to run the proxy behind a reverse proxy like Nginx or Caddy for SSL termination.
- **Backups:** Regularly back up the `data/llm_proxy.db` file.
- **Monitoring:** Monitor the `/health` endpoint for system status.
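The backup recommendation above can be automated with a small shell sketch (paths and the retention window are illustrative). Note that a plain `cp` of a live SQLite file is not crash-consistent; prefer `sqlite3 data/llm_proxy.db ".backup <dest>"` when the `sqlite3` CLI is available.

```shell
# Copy the database into a backup directory and prune old copies.
backup_db() {
  local db=$1 dir=$2 keep_days=${3:-30}
  mkdir -p "$dir"
  cp "$db" "$dir/llm_proxy_$(date +%Y%m%d_%H%M%S).db"
  # Drop copies older than the retention window.
  find "$dir" -name 'llm_proxy_*.db' -mtime +"$keep_days" -delete
}
# Example (e.g. from cron): backup_db ./data/llm_proxy.db /backups/gophergate
```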

go.mod Normal file

@@ -0,0 +1,70 @@
module gophergate
go 1.26.1
require (
github.com/gin-gonic/gin v1.12.0
github.com/go-resty/resty/v2 v2.17.2
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/jmoiron/sqlx v1.4.0
github.com/joho/godotenv v1.5.1
github.com/shirou/gopsutil/v3 v3.24.5
github.com/spf13/viper v1.21.0
golang.org/x/crypto v0.48.0
modernc.org/sqlite v1.47.0
)
require (
github.com/bytedance/gopkg v0.1.3 // indirect
github.com/bytedance/sonic v1.15.0 // indirect
github.com/bytedance/sonic/loader v0.5.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
github.com/gin-contrib/sse v1.1.0 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.30.1 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/goccy/go-yaml v1.19.2 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.59.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/sagikazarmark/locafero v0.11.0 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.3.1 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/arch v0.22.0 // indirect
golang.org/x/net v0.51.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.34.0 // indirect
golang.org/x/time v0.15.0 // indirect
google.golang.org/protobuf v1.36.10 // indirect
modernc.org/libc v1.70.0 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
)

go.sum Normal file

@@ -0,0 +1,207 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8=
github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w=
github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM=
github.com/go-resty/resty/v2 v2.17.2 h1:FQW5oHYcIlkCNrMD2lloGScxcHJ0gkjshV3qcQAyHQk=
github.com/go-resty/resty/v2 v2.17.2/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI=
github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk=
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY=
github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.32.0 h1:hjG66bI/kqIPX1b2yT6fr/jt+QedtP2fqojG2VrFuVw=
modernc.org/ccgo/v4 v4.32.0/go.mod h1:6F08EBCx5uQc38kMGl+0Nm0oWczoo1c7cgpzEry7Uc0=
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw=
modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.47.0 h1:R1XyaNpoW4Et9yly+I2EeX7pBza/w+pmYee/0HJDyKk=
modernc.org/sqlite v1.47.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=

207
internal/config/config.go Normal file

@@ -0,0 +1,207 @@
package config
import (
"encoding/base64"
"encoding/hex"
"fmt"
"os"
"strings"
"github.com/spf13/viper"
)
type Config struct {
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Providers ProviderConfig `mapstructure:"providers"`
EncryptionKey string `mapstructure:"encryption_key"`
KeyBytes []byte
}
type ServerConfig struct {
Port int `mapstructure:"port"`
Host string `mapstructure:"host"`
AuthTokens []string `mapstructure:"auth_tokens"`
}
type DatabaseConfig struct {
Path string `mapstructure:"path"`
MaxConnections int `mapstructure:"max_connections"`
}
type ProviderConfig struct {
OpenAI OpenAIConfig `mapstructure:"openai"`
Gemini GeminiConfig `mapstructure:"gemini"`
DeepSeek DeepSeekConfig `mapstructure:"deepseek"`
Moonshot MoonshotConfig `mapstructure:"moonshot"`
Grok GrokConfig `mapstructure:"grok"`
Ollama OllamaConfig `mapstructure:"ollama"`
}
type OpenAIConfig struct {
APIKeyEnv string `mapstructure:"api_key_env"`
BaseURL string `mapstructure:"base_url"`
DefaultModel string `mapstructure:"default_model"`
Enabled bool `mapstructure:"enabled"`
}
type GeminiConfig struct {
APIKeyEnv string `mapstructure:"api_key_env"`
BaseURL string `mapstructure:"base_url"`
DefaultModel string `mapstructure:"default_model"`
Enabled bool `mapstructure:"enabled"`
}
type DeepSeekConfig struct {
APIKeyEnv string `mapstructure:"api_key_env"`
BaseURL string `mapstructure:"base_url"`
DefaultModel string `mapstructure:"default_model"`
Enabled bool `mapstructure:"enabled"`
}
type MoonshotConfig struct {
APIKeyEnv string `mapstructure:"api_key_env"`
BaseURL string `mapstructure:"base_url"`
DefaultModel string `mapstructure:"default_model"`
Enabled bool `mapstructure:"enabled"`
}
type GrokConfig struct {
APIKeyEnv string `mapstructure:"api_key_env"`
BaseURL string `mapstructure:"base_url"`
DefaultModel string `mapstructure:"default_model"`
Enabled bool `mapstructure:"enabled"`
}
type OllamaConfig struct {
BaseURL string `mapstructure:"base_url"`
Enabled bool `mapstructure:"enabled"`
DefaultModel string `mapstructure:"default_model"`
Models []string `mapstructure:"models"`
}
func Load() (*Config, error) {
v := viper.New()
// Defaults
v.SetDefault("server.port", 8080)
v.SetDefault("server.host", "0.0.0.0")
v.SetDefault("server.auth_tokens", []string{})
v.SetDefault("database.path", "./data/llm_proxy.db")
v.SetDefault("database.max_connections", 10)
v.SetDefault("providers.openai.api_key_env", "OPENAI_API_KEY")
v.SetDefault("providers.openai.base_url", "https://api.openai.com/v1")
v.SetDefault("providers.openai.default_model", "gpt-4o")
v.SetDefault("providers.openai.enabled", true)
v.SetDefault("providers.gemini.api_key_env", "GEMINI_API_KEY")
v.SetDefault("providers.gemini.base_url", "https://generativelanguage.googleapis.com/v1")
v.SetDefault("providers.gemini.default_model", "gemini-2.0-flash")
v.SetDefault("providers.gemini.enabled", true)
v.SetDefault("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")
v.SetDefault("providers.deepseek.base_url", "https://api.deepseek.com")
v.SetDefault("providers.deepseek.default_model", "deepseek-reasoner")
v.SetDefault("providers.deepseek.enabled", true)
v.SetDefault("providers.moonshot.api_key_env", "MOONSHOT_API_KEY")
v.SetDefault("providers.moonshot.base_url", "https://api.moonshot.ai/v1")
v.SetDefault("providers.moonshot.default_model", "kimi-k2.5")
v.SetDefault("providers.moonshot.enabled", true)
v.SetDefault("providers.grok.api_key_env", "GROK_API_KEY")
v.SetDefault("providers.grok.base_url", "https://api.x.ai/v1")
v.SetDefault("providers.grok.default_model", "grok-4-1-fast-non-reasoning")
v.SetDefault("providers.grok.enabled", true)
v.SetDefault("providers.ollama.base_url", "http://localhost:11434/v1")
v.SetDefault("providers.ollama.enabled", false)
v.SetDefault("providers.ollama.models", []string{})
// Environment variables
v.SetEnvPrefix("LLM_PROXY")
v.SetEnvKeyReplacer(strings.NewReplacer(".", "__"))
v.AutomaticEnv()
// Explicitly bind keys that might use double underscores in .env
v.BindEnv("encryption_key", "LLM_PROXY__ENCRYPTION_KEY")
v.BindEnv("server.port", "LLM_PROXY__SERVER__PORT")
v.BindEnv("server.host", "LLM_PROXY__SERVER__HOST")
// Config file
v.SetConfigName("config")
v.SetConfigType("toml")
v.AddConfigPath(".")
if envPath := os.Getenv("LLM_PROXY__CONFIG_PATH"); envPath != "" {
v.SetConfigFile(envPath)
}
if err := v.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
}
var cfg Config
if err := v.Unmarshal(&cfg); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
fmt.Printf("Debug Config: port from viper=%d, host from viper=%s\n", cfg.Server.Port, cfg.Server.Host)
fmt.Printf("Debug Env: LLM_PROXY__SERVER__PORT=%s, LLM_PROXY__SERVER__HOST=%s\n", os.Getenv("LLM_PROXY__SERVER__PORT"), os.Getenv("LLM_PROXY__SERVER__HOST"))
// Manual overrides for nested keys which Viper doesn't always bind correctly with AutomaticEnv + SetEnvPrefix
if port := os.Getenv("LLM_PROXY__SERVER__PORT"); port != "" {
fmt.Sscanf(port, "%d", &cfg.Server.Port)
fmt.Printf("Overriding port to %d from env\n", cfg.Server.Port)
}
if host := os.Getenv("LLM_PROXY__SERVER__HOST"); host != "" {
cfg.Server.Host = host
fmt.Printf("Overriding host to %s from env\n", cfg.Server.Host)
}
// Validate encryption key
if cfg.EncryptionKey == "" {
return nil, fmt.Errorf("encryption key is required (LLM_PROXY__ENCRYPTION_KEY)")
}
keyBytes, err := hex.DecodeString(cfg.EncryptionKey)
if err != nil {
keyBytes, err = base64.StdEncoding.DecodeString(cfg.EncryptionKey)
if err != nil {
return nil, fmt.Errorf("encryption key must be hex or base64 encoded")
}
}
if len(keyBytes) != 32 {
return nil, fmt.Errorf("encryption key must be 32 bytes, got %d", len(keyBytes))
}
cfg.KeyBytes = keyBytes
return &cfg, nil
}
func (c *Config) GetAPIKey(provider string) (string, error) {
var envVar string
switch provider {
case "openai":
envVar = c.Providers.OpenAI.APIKeyEnv
case "gemini":
envVar = c.Providers.Gemini.APIKeyEnv
case "deepseek":
envVar = c.Providers.DeepSeek.APIKeyEnv
case "moonshot":
envVar = c.Providers.Moonshot.APIKeyEnv
case "grok":
envVar = c.Providers.Grok.APIKeyEnv
default:
return "", fmt.Errorf("unknown provider: %s", provider)
}
val := os.Getenv(envVar)
if val == "" {
return "", fmt.Errorf("environment variable %s not set for %s", envVar, provider)
}
return strings.TrimSpace(val), nil
}
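The key validation in `Load` accepts either a hex- or base64-encoded key, but it must decode to exactly 32 bytes. A minimal sketch of generating a suitable key and mirroring that validation (the `decodeKey` helper is illustrative; the real logic is inlined in `Load`):

```go
package main

import (
	"crypto/rand"
	"encoding/base64"
	"encoding/hex"
	"fmt"
)

// decodeKey mirrors the validation in config.Load: the key may be hex- or
// base64-encoded and must decode to exactly 32 bytes (AES-256 sized).
func decodeKey(s string) ([]byte, error) {
	keyBytes, err := hex.DecodeString(s)
	if err != nil {
		keyBytes, err = base64.StdEncoding.DecodeString(s)
		if err != nil {
			return nil, fmt.Errorf("encryption key must be hex or base64 encoded")
		}
	}
	if len(keyBytes) != 32 {
		return nil, fmt.Errorf("encryption key must be 32 bytes, got %d", len(keyBytes))
	}
	return keyBytes, nil
}

func main() {
	// Generate a fresh key and print it hex-encoded,
	// ready to export as LLM_PROXY__ENCRYPTION_KEY.
	raw := make([]byte, 32)
	if _, err := rand.Read(raw); err != nil {
		panic(err)
	}
	encoded := hex.EncodeToString(raw)
	if _, err := decodeKey(encoded); err != nil {
		panic(err)
	}
	fmt.Println(encoded)
}
```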

264
internal/db/db.go Normal file

@@ -0,0 +1,264 @@
package db
import (
"fmt"
"log"
"os"
"path/filepath"
"time"
"github.com/jmoiron/sqlx"
_ "modernc.org/sqlite"
"golang.org/x/crypto/bcrypt"
)
type DB struct {
*sqlx.DB
}
func Init(path string) (*DB, error) {
// Ensure directory exists
dir := filepath.Dir(path)
if dir != "." && dir != "/" {
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create database directory: %w", err)
}
}
// Connect to SQLite
dsn := fmt.Sprintf("file:%s?_pragma=foreign_keys(1)", path)
db, err := sqlx.Connect("sqlite", dsn)
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
instance := &DB{db}
// Run migrations
if err := instance.RunMigrations(); err != nil {
return nil, fmt.Errorf("failed to run migrations: %w", err)
}
return instance, nil
}
func (db *DB) RunMigrations() error {
// Tables creation
queries := []string{
`CREATE TABLE IF NOT EXISTS clients (
id INTEGER PRIMARY KEY AUTOINCREMENT,
client_id TEXT UNIQUE NOT NULL,
name TEXT,
description TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
is_active BOOLEAN DEFAULT TRUE,
rate_limit_per_minute INTEGER DEFAULT 60,
total_requests INTEGER DEFAULT 0,
total_tokens INTEGER DEFAULT 0,
total_cost REAL DEFAULT 0.0
)`,
`CREATE TABLE IF NOT EXISTS llm_requests (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
client_id TEXT,
provider TEXT,
model TEXT,
prompt_tokens INTEGER,
completion_tokens INTEGER,
reasoning_tokens INTEGER DEFAULT 0,
total_tokens INTEGER,
cost REAL,
has_images BOOLEAN DEFAULT FALSE,
status TEXT DEFAULT 'success',
error_message TEXT,
duration_ms INTEGER,
request_body TEXT,
response_body TEXT,
cache_read_tokens INTEGER DEFAULT 0,
cache_write_tokens INTEGER DEFAULT 0,
FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE SET NULL
)`,
`CREATE TABLE IF NOT EXISTS provider_configs (
id TEXT PRIMARY KEY,
display_name TEXT NOT NULL,
enabled BOOLEAN DEFAULT TRUE,
base_url TEXT,
api_key TEXT,
credit_balance REAL DEFAULT 0.0,
low_credit_threshold REAL DEFAULT 5.0,
billing_mode TEXT,
api_key_encrypted BOOLEAN DEFAULT FALSE,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`,
`CREATE TABLE IF NOT EXISTS model_configs (
id TEXT PRIMARY KEY,
provider_id TEXT NOT NULL,
display_name TEXT,
enabled BOOLEAN DEFAULT TRUE,
prompt_cost_per_m REAL,
completion_cost_per_m REAL,
cache_read_cost_per_m REAL,
cache_write_cost_per_m REAL,
mapping TEXT,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE
)`,
`CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL,
display_name TEXT,
role TEXT DEFAULT 'admin',
must_change_password BOOLEAN DEFAULT FALSE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`,
`CREATE TABLE IF NOT EXISTS client_tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
client_id TEXT NOT NULL,
token TEXT NOT NULL UNIQUE,
name TEXT DEFAULT 'default',
is_active BOOLEAN DEFAULT TRUE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
last_used_at DATETIME,
FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE CASCADE
)`,
}
for _, q := range queries {
if _, err := db.Exec(q); err != nil {
return fmt.Errorf("migration failed for query [%s]: %w", q, err)
}
}
// Add indices
indices := []string{
"CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)",
"CREATE INDEX IF NOT EXISTS idx_clients_created_at ON clients(created_at)",
"CREATE INDEX IF NOT EXISTS idx_llm_requests_timestamp ON llm_requests(timestamp)",
"CREATE INDEX IF NOT EXISTS idx_llm_requests_client_id ON llm_requests(client_id)",
"CREATE INDEX IF NOT EXISTS idx_llm_requests_provider ON llm_requests(provider)",
"CREATE INDEX IF NOT EXISTS idx_llm_requests_status ON llm_requests(status)",
"CREATE UNIQUE INDEX IF NOT EXISTS idx_client_tokens_token ON client_tokens(token)",
"CREATE INDEX IF NOT EXISTS idx_client_tokens_client_id ON client_tokens(client_id)",
"CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp)",
"CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp)",
"CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id)",
}
for _, idx := range indices {
if _, err := db.Exec(idx); err != nil {
return fmt.Errorf("failed to create index [%s]: %w", idx, err)
}
}
// Default admin user
var count int
if err := db.Get(&count, "SELECT COUNT(*) FROM users"); err != nil {
return fmt.Errorf("failed to count users: %w", err)
}
if count == 0 {
hash, err := bcrypt.GenerateFromPassword([]byte("admin123"), 12)
if err != nil {
return fmt.Errorf("failed to hash default password: %w", err)
}
_, err = db.Exec("INSERT INTO users (username, password_hash, role, must_change_password) VALUES ('admin', ?, 'admin', 1)", string(hash))
if err != nil {
return fmt.Errorf("failed to insert default admin: %w", err)
}
log.Println("Created default admin user with password 'admin123' (must change on first login)")
}
// Default client
_, err := db.Exec(`INSERT OR IGNORE INTO clients (client_id, name, description)
VALUES ('default', 'Default Client', 'Default client for anonymous requests')`)
if err != nil {
return fmt.Errorf("failed to insert default client: %w", err)
}
return nil
}
// Data models for DB tables
type Client struct {
ID int `db:"id"`
ClientID string `db:"client_id"`
Name *string `db:"name"`
Description *string `db:"description"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
IsActive bool `db:"is_active"`
RateLimitPerMinute int `db:"rate_limit_per_minute"`
TotalRequests int `db:"total_requests"`
TotalTokens int `db:"total_tokens"`
TotalCost float64 `db:"total_cost"`
}
type LLMRequest struct {
ID int `db:"id"`
Timestamp time.Time `db:"timestamp"`
ClientID *string `db:"client_id"`
Provider *string `db:"provider"`
Model *string `db:"model"`
PromptTokens *int `db:"prompt_tokens"`
CompletionTokens *int `db:"completion_tokens"`
ReasoningTokens int `db:"reasoning_tokens"`
TotalTokens *int `db:"total_tokens"`
Cost *float64 `db:"cost"`
HasImages bool `db:"has_images"`
Status string `db:"status"`
ErrorMessage *string `db:"error_message"`
DurationMS *int `db:"duration_ms"`
RequestBody *string `db:"request_body"`
ResponseBody *string `db:"response_body"`
CacheReadTokens int `db:"cache_read_tokens"`
CacheWriteTokens int `db:"cache_write_tokens"`
}
type ProviderConfig struct {
ID string `db:"id"`
DisplayName string `db:"display_name"`
Enabled bool `db:"enabled"`
BaseURL *string `db:"base_url"`
APIKey *string `db:"api_key"`
CreditBalance float64 `db:"credit_balance"`
LowCreditThreshold float64 `db:"low_credit_threshold"`
BillingMode *string `db:"billing_mode"`
APIKeyEncrypted bool `db:"api_key_encrypted"`
UpdatedAt time.Time `db:"updated_at"`
}
type ModelConfig struct {
ID string `db:"id"`
ProviderID string `db:"provider_id"`
DisplayName *string `db:"display_name"`
Enabled bool `db:"enabled"`
PromptCostPerM *float64 `db:"prompt_cost_per_m"`
CompletionCostPerM *float64 `db:"completion_cost_per_m"`
CacheReadCostPerM *float64 `db:"cache_read_cost_per_m"`
CacheWriteCostPerM *float64 `db:"cache_write_cost_per_m"`
Mapping *string `db:"mapping"`
UpdatedAt time.Time `db:"updated_at"`
}
type User struct {
ID int `db:"id" json:"id"`
Username string `db:"username" json:"username"`
PasswordHash string `db:"password_hash" json:"-"`
DisplayName *string `db:"display_name" json:"display_name"`
Role string `db:"role" json:"role"`
MustChangePassword bool `db:"must_change_password" json:"must_change_password"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
}
type ClientToken struct {
ID int `db:"id"`
ClientID string `db:"client_id"`
Token string `db:"token"`
Name string `db:"name"`
IsActive bool `db:"is_active"`
CreatedAt time.Time `db:"created_at"`
LastUsedAt *time.Time `db:"last_used_at"`
}
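`Init` enables foreign keys through modernc.org/sqlite's `_pragma` DSN syntax and only creates a parent directory when the path actually has one. A small sketch of both pieces (the helper names are illustrative; in `Init` this logic is inlined):

```go
package main

import (
	"fmt"
	"path/filepath"
)

// sqliteDSN mirrors how db.Init builds its connection string: a file: URI
// with the foreign_keys pragma enabled via modernc.org/sqlite's _pragma syntax.
func sqliteDSN(path string) string {
	return fmt.Sprintf("file:%s?_pragma=foreign_keys(1)", path)
}

// dataDir mirrors the directory check in db.Init: a directory is created
// only when the path has one ("." and "/" are skipped).
func dataDir(path string) (string, bool) {
	dir := filepath.Dir(path)
	return dir, dir != "." && dir != "/"
}

func main() {
	fmt.Println(sqliteDSN("./data/llm_proxy.db"))
	dir, needed := dataDir("./data/llm_proxy.db")
	fmt.Println(dir, needed)
}
```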


@@ -0,0 +1,52 @@
package middleware

import (
	"log"
	"strings"

	"gophergate/internal/db"
	"gophergate/internal/models"

	"github.com/gin-gonic/gin"
)

func AuthMiddleware(database *db.DB) gin.HandlerFunc {
	return func(c *gin.Context) {
		authHeader := c.GetHeader("Authorization")
		if authHeader == "" {
			c.Next()
			return
		}
		token := strings.TrimPrefix(authHeader, "Bearer ")
		if token == authHeader { // No "Bearer " prefix
			c.Next()
			return
		}
		// Try to resolve the client from the database, touching last_used_at
		// in the same statement via RETURNING.
		var clientID string
		err := database.Get(&clientID, "UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = 1 RETURNING client_id", token)
		if err == nil {
			c.Set("auth", models.AuthInfo{
				Token:    token,
				ClientID: clientID,
			})
		} else {
			// Fallback to token-prefix derivation (matches Rust behavior)
			prefixLen := len(token)
			if prefixLen > 8 {
				prefixLen = 8
			}
			clientID = "client_" + token[:prefixLen]
			c.Set("auth", models.AuthInfo{
				Token:    token,
				ClientID: clientID,
			})
			log.Printf("Token not found in DB, using fallback client ID: %s", clientID)
		}
		c.Next()
	}
}
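When a bearer token has no row in `client_tokens`, the middleware derives a client ID from the token prefix. A standalone sketch of that derivation (`fallbackClientID` is an illustrative name; the middleware inlines this):

```go
package main

import "fmt"

// fallbackClientID reproduces the middleware's fallback when a bearer token
// is not found in client_tokens: "client_" plus the first 8 characters
// of the token (or the whole token, if shorter).
func fallbackClientID(token string) string {
	prefixLen := len(token)
	if prefixLen > 8 {
		prefixLen = 8
	}
	return "client_" + token[:prefixLen]
}

func main() {
	fmt.Println(fallbackClientID("sk-1234567890abcdef"))
	fmt.Println(fallbackClientID("abc"))
}
```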

216
internal/models/models.go Normal file

@@ -0,0 +1,216 @@
package models
import (
"encoding/base64"
"encoding/json"
"fmt"
"github.com/go-resty/resty/v2"
)
// OpenAI-compatible Request/Response Structs
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *uint32 `json:"top_k,omitempty"`
N *uint32 `json:"n,omitempty"`
Stop json.RawMessage `json:"stop,omitempty"` // Can be string or array of strings
MaxTokens *uint32 `json:"max_tokens,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
Stream *bool `json:"stream,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
}
type ChatMessage struct {
Role string `json:"role"` // "system", "user", "assistant", "tool"
Content interface{} `json:"content"`
ReasoningContent *string `json:"reasoning_content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Name *string `json:"name,omitempty"`
ToolCallID *string `json:"tool_call_id,omitempty"`
}
type ContentPart struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ImageUrl *ImageUrl `json:"image_url,omitempty"`
}
type ImageUrl struct {
URL string `json:"url"`
Detail *string `json:"detail,omitempty"`
}
// Tool-Calling Types
type Tool struct {
Type string `json:"type"`
Function FunctionDef `json:"function"`
}
type FunctionDef struct {
Name string `json:"name"`
Description *string `json:"description,omitempty"`
Parameters json.RawMessage `json:"parameters,omitempty"`
}
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type"`
Function FunctionCall `json:"function"`
}
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
type ToolCallDelta struct {
Index uint32 `json:"index"`
ID *string `json:"id,omitempty"`
Type *string `json:"type,omitempty"`
Function *FunctionCallDelta `json:"function,omitempty"`
}
type FunctionCallDelta struct {
Name *string `json:"name,omitempty"`
Arguments *string `json:"arguments,omitempty"`
}
// OpenAI-compatible Response Structs
type ChatCompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatChoice `json:"choices"`
Usage *Usage `json:"usage,omitempty"`
}
type ChatChoice struct {
Index uint32 `json:"index"`
Message ChatMessage `json:"message"`
FinishReason *string `json:"finish_reason,omitempty"`
}
type Usage struct {
PromptTokens uint32 `json:"prompt_tokens"`
CompletionTokens uint32 `json:"completion_tokens"`
TotalTokens uint32 `json:"total_tokens"`
ReasoningTokens *uint32 `json:"reasoning_tokens,omitempty"`
CacheReadTokens *uint32 `json:"cache_read_tokens,omitempty"`
CacheWriteTokens *uint32 `json:"cache_write_tokens,omitempty"`
}
// Streaming Response Structs
type ChatCompletionStreamResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatStreamChoice `json:"choices"`
Usage *Usage `json:"usage,omitempty"`
}
type ChatStreamChoice struct {
Index uint32 `json:"index"`
Delta ChatStreamDelta `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
}
type ChatStreamDelta struct {
Role *string `json:"role,omitempty"`
Content *string `json:"content,omitempty"`
ReasoningContent *string `json:"reasoning_content,omitempty"`
ToolCalls []ToolCallDelta `json:"tool_calls,omitempty"`
}
type StreamUsage struct {
PromptTokens uint32 `json:"prompt_tokens"`
CompletionTokens uint32 `json:"completion_tokens"`
TotalTokens uint32 `json:"total_tokens"`
ReasoningTokens uint32 `json:"reasoning_tokens"`
CacheReadTokens uint32 `json:"cache_read_tokens"`
CacheWriteTokens uint32 `json:"cache_write_tokens"`
}
// Unified Request Format (for internal use)
type UnifiedRequest struct {
ClientID string
Model string
Messages []UnifiedMessage
Temperature *float64
TopP *float64
TopK *uint32
N *uint32
Stop []string
MaxTokens *uint32
PresencePenalty *float64
FrequencyPenalty *float64
Stream bool
HasImages bool
Tools []Tool
ToolChoice json.RawMessage
}
type UnifiedMessage struct {
Role string
Content []UnifiedContentPart
ReasoningContent *string
ToolCalls []ToolCall
Name *string
ToolCallID *string
}
type UnifiedContentPart struct {
Type string
Text string
Image *ImageInput
}
type ImageInput struct {
Base64 string `json:"base64,omitempty"`
URL string `json:"url,omitempty"`
MimeType string `json:"mime_type,omitempty"`
}
func (i *ImageInput) ToBase64() (string, string, error) {
if i.Base64 != "" {
return i.Base64, i.MimeType, nil
}
if i.URL != "" {
client := resty.New()
resp, err := client.R().Get(i.URL)
if err != nil {
return "", "", fmt.Errorf("failed to fetch image: %w", err)
}
if !resp.IsSuccess() {
return "", "", fmt.Errorf("failed to fetch image: HTTP %d", resp.StatusCode())
}
mimeType := resp.Header().Get("Content-Type")
if mimeType == "" {
mimeType = "image/jpeg"
}
encoded := base64.StdEncoding.EncodeToString(resp.Body())
return encoded, mimeType, nil
}
return "", "", fmt.Errorf("empty image input")
}
// AuthInfo for context
type AuthInfo struct {
Token string
ClientID string
}


@@ -0,0 +1,69 @@
package models

import "strings"

type ModelRegistry struct {
	Providers map[string]ProviderInfo `json:"-"`
}

type ProviderInfo struct {
	ID     string                   `json:"id"`
	Name   string                   `json:"name"`
	Models map[string]ModelMetadata `json:"models"`
}

type ModelMetadata struct {
	ID         string           `json:"id"`
	Name       string           `json:"name"`
	Cost       *ModelCost       `json:"cost,omitempty"`
	Limit      *ModelLimit      `json:"limit,omitempty"`
	Modalities *ModelModalities `json:"modalities,omitempty"`
	ToolCall   *bool            `json:"tool_call,omitempty"`
	Reasoning  *bool            `json:"reasoning,omitempty"`
}

type ModelCost struct {
	Input      float64  `json:"input"`
	Output     float64  `json:"output"`
	CacheRead  *float64 `json:"cache_read,omitempty"`
	CacheWrite *float64 `json:"cache_write,omitempty"`
}

type ModelLimit struct {
	Context uint32 `json:"context"`
	Output  uint32 `json:"output"`
}

type ModelModalities struct {
	Input  []string `json:"input"`
	Output []string `json:"output"`
}

func (r *ModelRegistry) FindModel(modelID string) *ModelMetadata {
	// First try an exact key match in each provider's models map
	for _, provider := range r.Providers {
		if model, ok := provider.Models[modelID]; ok {
			return &model
		}
	}
	// Then search by the ID field in the metadata itself
	for _, provider := range r.Providers {
		for _, model := range provider.Models {
			if model.ID == modelID {
				return &model
			}
		}
	}
	// Finally try prefix matching (e.g. gpt-4o-2024-05-13 matches gpt-4o)
	for _, provider := range r.Providers {
		for id, model := range provider.Models {
			if strings.HasPrefix(modelID, id) {
				return &model
			}
		}
	}
	return nil
}
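The prefix stage of `FindModel` is what lets dated snapshot IDs resolve to a base model entry. A reduced sketch of the exact-then-prefix lookup over a flat ID list (`matchModel` is illustrative; the registry walks nested provider maps instead):

```go
package main

import (
	"fmt"
	"strings"
)

// matchModel sketches FindModel's ordering: exact match first, then a
// prefix match so "gpt-4o-2024-05-13" still resolves to "gpt-4o".
func matchModel(known []string, modelID string) (string, bool) {
	for _, id := range known {
		if id == modelID {
			return id, true
		}
	}
	for _, id := range known {
		if strings.HasPrefix(modelID, id) {
			return id, true
		}
	}
	return "", false
}

func main() {
	known := []string{"gpt-4o", "deepseek-reasoner", "kimi-k2.5"}
	fmt.Println(matchModel(known, "gpt-4o-2024-05-13"))
	fmt.Println(matchModel(known, "unknown-model"))
}
```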


@@ -0,0 +1,220 @@
package providers
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"strings"
"gophergate/internal/config"
"gophergate/internal/models"
"github.com/go-resty/resty/v2"
)
type DeepSeekProvider struct {
client *resty.Client
config config.DeepSeekConfig
apiKey string
}
func NewDeepSeekProvider(cfg config.DeepSeekConfig, apiKey string) *DeepSeekProvider {
return &DeepSeekProvider{
client: resty.New(),
config: cfg,
apiKey: apiKey,
}
}
func (p *DeepSeekProvider) Name() string {
return "deepseek"
}
type deepSeekUsage struct {
PromptTokens uint32 `json:"prompt_tokens"`
CompletionTokens uint32 `json:"completion_tokens"`
TotalTokens uint32 `json:"total_tokens"`
PromptCacheHitTokens uint32 `json:"prompt_cache_hit_tokens"`
PromptCacheMissTokens uint32 `json:"prompt_cache_miss_tokens"`
CompletionTokensDetails *struct {
ReasoningTokens uint32 `json:"reasoning_tokens"`
} `json:"completion_tokens_details"`
}
func (u *deepSeekUsage) ToUnified() *models.Usage {
usage := &models.Usage{
PromptTokens: u.PromptTokens,
CompletionTokens: u.CompletionTokens,
TotalTokens: u.TotalTokens,
}
if u.PromptCacheHitTokens > 0 {
usage.CacheReadTokens = &u.PromptCacheHitTokens
}
if u.PromptCacheMissTokens > 0 {
usage.CacheWriteTokens = &u.PromptCacheMissTokens
}
if u.CompletionTokensDetails != nil && u.CompletionTokensDetails.ReasoningTokens > 0 {
usage.ReasoningTokens = &u.CompletionTokensDetails.ReasoningTokens
}
return usage
}
func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
}
body := BuildOpenAIBody(req, messagesJSON, false)
// Sanitize for deepseek-reasoner
if req.Model == "deepseek-reasoner" {
delete(body, "temperature")
delete(body, "top_p")
delete(body, "presence_penalty")
delete(body, "frequency_penalty")
if msgs, ok := body["messages"].([]interface{}); ok {
for _, m := range msgs {
if msg, ok := m.(map[string]interface{}); ok {
if msg["role"] == "assistant" {
if msg["reasoning_content"] == nil {
msg["reasoning_content"] = " "
}
if msg["content"] == nil || msg["content"] == "" {
msg["content"] = ""
}
}
}
}
}
}
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetBody(body).
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
}
var respJSON map[string]interface{}
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
result, err := ParseOpenAIResponse(respJSON, req.Model)
if err != nil {
return nil, err
}
// Fix usage for DeepSeek specifically if details were missing in ParseOpenAIResponse
if usageData, ok := respJSON["usage"]; ok {
var dUsage deepSeekUsage
usageBytes, _ := json.Marshal(usageData)
if err := json.Unmarshal(usageBytes, &dUsage); err == nil {
result.Usage = dUsage.ToUnified()
}
}
return result, nil
}
func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
}
body := BuildOpenAIBody(req, messagesJSON, true)
// Sanitize for deepseek-reasoner
if req.Model == "deepseek-reasoner" {
delete(body, "temperature")
delete(body, "top_p")
delete(body, "presence_penalty")
delete(body, "frequency_penalty")
if msgs, ok := body["messages"].([]interface{}); ok {
for _, m := range msgs {
if msg, ok := m.(map[string]interface{}); ok {
if msg["role"] == "assistant" {
if msg["reasoning_content"] == nil {
msg["reasoning_content"] = " "
}
if msg["content"] == nil || msg["content"] == "" {
msg["content"] = ""
}
}
}
}
}
}
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetBody(body).
SetDoNotParseResponse(true).
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
// Custom scanner loop to handle DeepSeek specific usage in chunks
err := StreamDeepSeek(resp.RawBody(), ch)
if err != nil {
fmt.Printf("DeepSeek Stream error: %v\n", err)
}
}()
return ch, nil
}
func StreamDeepSeek(body io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse) error {
defer body.Close()
scanner := bufio.NewScanner(body)
// Raise the default 64KB token limit; large SSE chunks can exceed it.
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
line := scanner.Text()
if line == "" || !strings.HasPrefix(line, "data: ") {
continue
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
break
}
var chunk models.ChatCompletionStreamResponse
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
continue
}
// Fix DeepSeek specific usage in stream
var rawChunk struct {
Usage *deepSeekUsage `json:"usage"`
}
if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
chunk.Usage = rawChunk.Usage.ToUnified()
}
ch <- &chunk
}
return scanner.Err()
}


@@ -0,0 +1,254 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"gophergate/internal/config"
"gophergate/internal/models"
"github.com/go-resty/resty/v2"
)
type GeminiProvider struct {
client *resty.Client
config config.GeminiConfig
apiKey string
}
func NewGeminiProvider(cfg config.GeminiConfig, apiKey string) *GeminiProvider {
return &GeminiProvider{
client: resty.New(),
config: cfg,
apiKey: apiKey,
}
}
func (p *GeminiProvider) Name() string {
return "gemini"
}
type GeminiRequest struct {
Contents []GeminiContent `json:"contents"`
}
type GeminiContent struct {
Role string `json:"role,omitempty"`
Parts []GeminiPart `json:"parts"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
}
type GeminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
type GeminiFunctionCall struct {
Name string `json:"name"`
Args json.RawMessage `json:"args"`
}
type GeminiFunctionResponse struct {
Name string `json:"name"`
Response json.RawMessage `json:"response"`
}
func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
// Gemini mapping
var contents []GeminiContent
for _, msg := range req.Messages {
role := "user"
if msg.Role == "assistant" {
role = "model"
} else if msg.Role == "tool" {
role = "user" // Tool results are user-side in Gemini
}
var parts []GeminiPart
// Handle tool responses
if msg.Role == "tool" {
text := ""
if len(msg.Content) > 0 {
text = msg.Content[0].Text
}
// Gemini expects functionResponse.response to be a JSON object,
// so wrap tool output that is not already one.
name := "unknown_function"
if msg.Name != nil {
name = *msg.Name
}
response := json.RawMessage(text)
var obj map[string]interface{}
if json.Unmarshal([]byte(text), &obj) != nil {
response, _ = json.Marshal(map[string]string{"result": text})
}
parts = append(parts, GeminiPart{
FunctionResponse: &GeminiFunctionResponse{
Name: name,
Response: response,
},
})
} else {
for _, cp := range msg.Content {
if cp.Type == "text" {
parts = append(parts, GeminiPart{Text: cp.Text})
} else if cp.Image != nil {
base64Data, mimeType, err := cp.Image.ToBase64()
if err != nil {
return nil, fmt.Errorf("failed to convert image to base64: %w", err)
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: base64Data,
},
})
}
}
// Handle assistant tool calls
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
for _, tc := range msg.ToolCalls {
parts = append(parts, GeminiPart{
FunctionCall: &GeminiFunctionCall{
Name: tc.Function.Name,
Args: json.RawMessage(tc.Function.Arguments),
},
})
}
}
}
contents = append(contents, GeminiContent{
Role: role,
Parts: parts,
})
}
body := GeminiRequest{
Contents: contents,
}
url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)
resp, err := p.client.R().
SetContext(ctx).
SetBody(body).
Post(url)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
}
// Parse Gemini response and convert to OpenAI format
var geminiResp struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
} `json:"content"`
FinishReason string `json:"finishReason"`
} `json:"candidates"`
UsageMetadata struct {
PromptTokenCount uint32 `json:"promptTokenCount"`
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
TotalTokenCount uint32 `json:"totalTokenCount"`
} `json:"usageMetadata"`
}
if err := json.Unmarshal(resp.Body(), &geminiResp); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
if len(geminiResp.Candidates) == 0 {
return nil, fmt.Errorf("no candidates in Gemini response")
}
content := ""
for _, p := range geminiResp.Candidates[0].Content.Parts {
content += p.Text
}
openAIResp := &models.ChatCompletionResponse{
ID: "gemini-" + req.Model,
Object: "chat.completion",
Created: 0, // Should be current timestamp
Model: req.Model,
Choices: []models.ChatChoice{
{
Index: 0,
Message: models.ChatMessage{
Role: "assistant",
Content: content,
},
FinishReason: &geminiResp.Candidates[0].FinishReason,
},
},
Usage: &models.Usage{
PromptTokens: geminiResp.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiResp.UsageMetadata.CandidatesTokenCount,
TotalTokens: geminiResp.UsageMetadata.TotalTokenCount,
},
}
return openAIResp, nil
}
func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
// Simplified Gemini mapping
var contents []GeminiContent
for _, msg := range req.Messages {
role := "user"
if msg.Role == "assistant" {
role = "model"
}
var parts []GeminiPart
for _, p := range msg.Content {
parts = append(parts, GeminiPart{Text: p.Text})
}
contents = append(contents, GeminiContent{
Role: role,
Parts: parts,
})
}
body := GeminiRequest{
Contents: contents,
}
// Use streamGenerateContent for streaming
url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)
resp, err := p.client.R().
SetContext(ctx).
SetBody(body).
SetDoNotParseResponse(true).
Post(url)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
err := StreamGemini(resp.RawBody(), ch, req.Model)
if err != nil {
fmt.Printf("Gemini Stream error: %v\n", err)
}
}()
return ch, nil
}


@@ -0,0 +1,95 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"gophergate/internal/config"
"gophergate/internal/models"
"github.com/go-resty/resty/v2"
)
type GrokProvider struct {
client *resty.Client
config config.GrokConfig
apiKey string
}
func NewGrokProvider(cfg config.GrokConfig, apiKey string) *GrokProvider {
return &GrokProvider{
client: resty.New(),
config: cfg,
apiKey: apiKey,
}
}
func (p *GrokProvider) Name() string {
return "grok"
}
func (p *GrokProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
}
body := BuildOpenAIBody(req, messagesJSON, false)
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetBody(body).
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
}
var respJSON map[string]interface{}
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return ParseOpenAIResponse(respJSON, req.Model)
}
func (p *GrokProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
}
body := BuildOpenAIBody(req, messagesJSON, true)
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetBody(body).
SetDoNotParseResponse(true).
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
err := StreamOpenAI(resp.RawBody(), ch)
if err != nil {
fmt.Printf("Grok Stream error: %v\n", err)
}
}()
return ch, nil
}


@@ -0,0 +1,318 @@
package providers
import (
"bufio"
"encoding/json"
"fmt"
"io"
"strings"
"gophergate/internal/models"
)
// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON, including tools and images.
func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) {
var result []interface{}
for _, m := range messages {
if m.Role == "tool" {
text := ""
if len(m.Content) > 0 {
text = m.Content[0].Text
}
msg := map[string]interface{}{
"role": "tool",
"content": text,
}
if m.ToolCallID != nil {
id := *m.ToolCallID
// Truncate to 40 chars; some OpenAI-compatible providers reject longer IDs.
if len(id) > 40 {
id = id[:40]
}
msg["tool_call_id"] = id
}
if m.Name != nil {
msg["name"] = *m.Name
}
result = append(result, msg)
continue
}
var parts []interface{}
for _, p := range m.Content {
if p.Type == "text" {
parts = append(parts, map[string]interface{}{
"type": "text",
"text": p.Text,
})
} else if p.Image != nil {
base64Data, mimeType, err := p.Image.ToBase64()
if err != nil {
return nil, fmt.Errorf("failed to convert image to base64: %w", err)
}
parts = append(parts, map[string]interface{}{
"type": "image_url",
"image_url": map[string]interface{}{
"url": fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data),
},
})
}
}
var finalContent interface{}
if len(parts) == 1 {
if p, ok := parts[0].(map[string]interface{}); ok && p["type"] == "text" {
finalContent = p["text"]
} else {
finalContent = parts
}
} else {
finalContent = parts
}
msg := map[string]interface{}{
"role": m.Role,
"content": finalContent,
}
if m.ReasoningContent != nil {
msg["reasoning_content"] = *m.ReasoningContent
}
if len(m.ToolCalls) > 0 {
sanitizedCalls := make([]models.ToolCall, len(m.ToolCalls))
copy(sanitizedCalls, m.ToolCalls)
for i := range sanitizedCalls {
if len(sanitizedCalls[i].ID) > 40 {
sanitizedCalls[i].ID = sanitizedCalls[i].ID[:40]
}
}
msg["tool_calls"] = sanitizedCalls
if len(parts) == 0 {
msg["content"] = ""
}
}
if m.Name != nil {
msg["name"] = *m.Name
}
result = append(result, msg)
}
return result, nil
}
func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{}, stream bool) map[string]interface{} {
body := map[string]interface{}{
"model": request.Model,
"messages": messagesJSON,
"stream": stream,
}
if stream {
body["stream_options"] = map[string]interface{}{
"include_usage": true,
}
}
if request.Temperature != nil {
body["temperature"] = *request.Temperature
}
if request.MaxTokens != nil {
body["max_tokens"] = *request.MaxTokens
}
if len(request.Tools) > 0 {
body["tools"] = request.Tools
}
if request.ToolChoice != nil {
var toolChoice interface{}
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
body["tool_choice"] = toolChoice
}
}
return body
}
type openAIUsage struct {
PromptTokens uint32 `json:"prompt_tokens"`
CompletionTokens uint32 `json:"completion_tokens"`
TotalTokens uint32 `json:"total_tokens"`
PromptTokensDetails *struct {
CachedTokens uint32 `json:"cached_tokens"`
} `json:"prompt_tokens_details"`
CompletionTokensDetails *struct {
ReasoningTokens uint32 `json:"reasoning_tokens"`
} `json:"completion_tokens_details"`
}
func (u *openAIUsage) ToUnified() *models.Usage {
usage := &models.Usage{
PromptTokens: u.PromptTokens,
CompletionTokens: u.CompletionTokens,
TotalTokens: u.TotalTokens,
}
if u.PromptTokensDetails != nil && u.PromptTokensDetails.CachedTokens > 0 {
usage.CacheReadTokens = &u.PromptTokensDetails.CachedTokens
}
if u.CompletionTokensDetails != nil && u.CompletionTokensDetails.ReasoningTokens > 0 {
usage.ReasoningTokens = &u.CompletionTokensDetails.ReasoningTokens
}
return usage
}
func ParseOpenAIResponse(respJSON map[string]interface{}, model string) (*models.ChatCompletionResponse, error) {
data, err := json.Marshal(respJSON)
if err != nil {
return nil, err
}
var resp models.ChatCompletionResponse
if err := json.Unmarshal(data, &resp); err != nil {
return nil, err
}
// Manually fix usage because ChatCompletionResponse uses the unified Usage struct
// but the provider might have returned more details.
if usageData, ok := respJSON["usage"]; ok {
var oUsage openAIUsage
usageBytes, _ := json.Marshal(usageData)
if err := json.Unmarshal(usageBytes, &oUsage); err == nil {
resp.Usage = oUsage.ToUnified()
}
}
return &resp, nil
}
// Streaming support
func ParseOpenAIStreamChunk(line string) (*models.ChatCompletionStreamResponse, bool, error) {
if line == "" {
return nil, false, nil
}
if !strings.HasPrefix(line, "data: ") {
return nil, false, nil
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
return nil, true, nil
}
var chunk models.ChatCompletionStreamResponse
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
return nil, false, fmt.Errorf("failed to unmarshal stream chunk: %w", err)
}
// Handle specialized usage in stream chunks
var rawChunk struct {
Usage *openAIUsage `json:"usage"`
}
if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
chunk.Usage = rawChunk.Usage.ToUnified()
}
return &chunk, false, nil
}
func StreamOpenAI(body io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse) error {
defer body.Close()
scanner := bufio.NewScanner(body)
// Raise the default 64KB token limit; large SSE chunks can exceed it.
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
line := scanner.Text()
chunk, done, err := ParseOpenAIStreamChunk(line)
if err != nil {
return err
}
if done {
break
}
if chunk != nil {
ch <- chunk
}
}
return scanner.Err()
}
func StreamGemini(body io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error {
defer body.Close()
dec := json.NewDecoder(body)
t, err := dec.Token()
if err != nil {
return err
}
if delim, ok := t.(json.Delim); ok && delim == '[' {
for dec.More() {
var geminiChunk struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text,omitempty"`
Thought string `json:"thought,omitempty"`
} `json:"parts"`
} `json:"content"`
FinishReason string `json:"finishReason"`
} `json:"candidates"`
UsageMetadata struct {
PromptTokenCount uint32 `json:"promptTokenCount"`
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
TotalTokenCount uint32 `json:"totalTokenCount"`
} `json:"usageMetadata"`
}
if err := dec.Decode(&geminiChunk); err != nil {
return err
}
if len(geminiChunk.Candidates) > 0 || geminiChunk.UsageMetadata.TotalTokenCount > 0 {
content := ""
var reasoning *string
if len(geminiChunk.Candidates) > 0 {
for _, p := range geminiChunk.Candidates[0].Content.Parts {
if p.Text != "" {
content += p.Text
}
if p.Thought != "" {
if reasoning == nil {
reasoning = new(string)
}
*reasoning += p.Thought
}
}
}
var finishReason *string
if len(geminiChunk.Candidates) > 0 && geminiChunk.Candidates[0].FinishReason != "" {
fr := strings.ToLower(geminiChunk.Candidates[0].FinishReason)
finishReason = &fr
}
ch <- &models.ChatCompletionStreamResponse{
ID: "gemini-stream",
Object: "chat.completion.chunk",
Created: 0,
Model: model,
Choices: []models.ChatStreamChoice{
{
Index: 0,
Delta: models.ChatStreamDelta{
Content: &content,
ReasoningContent: reasoning,
},
FinishReason: finishReason,
},
},
Usage: &models.Usage{
PromptTokens: geminiChunk.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiChunk.UsageMetadata.CandidatesTokenCount,
TotalTokens: geminiChunk.UsageMetadata.TotalTokenCount,
},
}
}
}
}
return nil
}


@@ -0,0 +1,114 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"strings"
"gophergate/internal/config"
"gophergate/internal/models"
"github.com/go-resty/resty/v2"
)
type MoonshotProvider struct {
client *resty.Client
config config.MoonshotConfig
apiKey string
}
func NewMoonshotProvider(cfg config.MoonshotConfig, apiKey string) *MoonshotProvider {
return &MoonshotProvider{
client: resty.New(),
config: cfg,
apiKey: strings.TrimSpace(apiKey),
}
}
func (p *MoonshotProvider) Name() string {
return "moonshot"
}
func (p *MoonshotProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
}
body := BuildOpenAIBody(req, messagesJSON, false)
if strings.Contains(strings.ToLower(req.Model), "kimi-k2.5") {
if maxTokens, ok := body["max_tokens"]; ok {
delete(body, "max_tokens")
body["max_completion_tokens"] = maxTokens
}
}
baseURL := strings.TrimRight(p.config.BaseURL, "/")
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetHeader("Content-Type", "application/json").
SetHeader("Accept", "application/json").
SetBody(body).
Post(fmt.Sprintf("%s/chat/completions", baseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), resp.String())
}
var respJSON map[string]interface{}
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return ParseOpenAIResponse(respJSON, req.Model)
}
func (p *MoonshotProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
}
body := BuildOpenAIBody(req, messagesJSON, true)
if strings.Contains(strings.ToLower(req.Model), "kimi-k2.5") {
if maxTokens, ok := body["max_tokens"]; ok {
delete(body, "max_tokens")
body["max_completion_tokens"] = maxTokens
}
}
baseURL := strings.TrimRight(p.config.BaseURL, "/")
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetHeader("Content-Type", "application/json").
SetHeader("Accept", "text/event-stream").
SetBody(body).
SetDoNotParseResponse(true).
Post(fmt.Sprintf("%s/chat/completions", baseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), resp.String())
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
if err := StreamOpenAI(resp.RawBody(), ch); err != nil {
fmt.Printf("Moonshot Stream error: %v\n", err)
}
}()
return ch, nil
}


@@ -0,0 +1,113 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"strings"
"gophergate/internal/config"
"gophergate/internal/models"
"github.com/go-resty/resty/v2"
)
type OpenAIProvider struct {
client *resty.Client
config config.OpenAIConfig
apiKey string
}
func NewOpenAIProvider(cfg config.OpenAIConfig, apiKey string) *OpenAIProvider {
return &OpenAIProvider{
client: resty.New(),
config: cfg,
apiKey: apiKey,
}
}
func (p *OpenAIProvider) Name() string {
return "openai"
}
func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
}
body := BuildOpenAIBody(req, messagesJSON, false)
// Transition: Newer models require max_completion_tokens
if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
if maxTokens, ok := body["max_tokens"]; ok {
delete(body, "max_tokens")
body["max_completion_tokens"] = maxTokens
}
}
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetBody(body).
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
}
var respJSON map[string]interface{}
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return ParseOpenAIResponse(respJSON, req.Model)
}
func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
}
body := BuildOpenAIBody(req, messagesJSON, true)
// Transition: Newer models require max_completion_tokens
if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
if maxTokens, ok := body["max_tokens"]; ok {
delete(body, "max_tokens")
body["max_completion_tokens"] = maxTokens
}
}
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetBody(body).
SetDoNotParseResponse(true).
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
err := StreamOpenAI(resp.RawBody(), ch)
if err != nil {
// In a real app, you might want to send an error chunk or log it
fmt.Printf("Stream error: %v\n", err)
}
}()
return ch, nil
}


@@ -0,0 +1,13 @@
package providers
import (
"context"
"gophergate/internal/models"
)
type Provider interface {
Name() string
ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error)
ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error)
}

internal/server/dashboard.go (new file, 1393 lines)

File diff suppressed because it is too large

internal/server/logging.go (new file, 113 lines)

@@ -0,0 +1,113 @@
package server
import (
"log"
"time"
"gophergate/internal/db"
)
type RequestLog struct {
Timestamp time.Time `json:"timestamp"`
ClientID string `json:"client_id"`
Provider string `json:"provider"`
Model string `json:"model"`
PromptTokens uint32 `json:"prompt_tokens"`
CompletionTokens uint32 `json:"completion_tokens"`
ReasoningTokens uint32 `json:"reasoning_tokens"`
TotalTokens uint32 `json:"total_tokens"`
CacheReadTokens uint32 `json:"cache_read_tokens"`
CacheWriteTokens uint32 `json:"cache_write_tokens"`
Cost float64 `json:"cost"`
HasImages bool `json:"has_images"`
Status string `json:"status"`
ErrorMessage string `json:"error_message,omitempty"`
DurationMS int64 `json:"duration_ms"`
}
type RequestLogger struct {
database *db.DB
hub *Hub
logChan chan RequestLog
}
func NewRequestLogger(database *db.DB, hub *Hub) *RequestLogger {
return &RequestLogger{
database: database,
hub: hub,
logChan: make(chan RequestLog, 100),
}
}
func (l *RequestLogger) Start() {
go func() {
for entry := range l.logChan {
l.processLog(entry)
}
}()
}
func (l *RequestLogger) LogRequest(entry RequestLog) {
select {
case l.logChan <- entry:
default:
log.Println("Request log channel full, dropping log entry")
}
}
func (l *RequestLogger) processLog(entry RequestLog) {
// Broadcast to dashboard
l.hub.broadcast <- map[string]interface{}{
"type": "request",
"channel": "requests",
"payload": entry,
}
// Insert into DB
tx, err := l.database.Begin()
if err != nil {
log.Printf("Failed to begin transaction for logging: %v", err)
return
}
defer tx.Rollback()
// Ensure client exists
_, _ = tx.Exec("INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')",
entry.ClientID, entry.ClientID)
// Insert log
_, err = tx.Exec(`
INSERT INTO llm_requests
(timestamp, client_id, provider, model, prompt_tokens, completion_tokens, reasoning_tokens, total_tokens, cache_read_tokens, cache_write_tokens, cost, has_images, status, error_message, duration_ms)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`, entry.Timestamp, entry.ClientID, entry.Provider, entry.Model,
entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.TotalTokens,
entry.CacheReadTokens, entry.CacheWriteTokens, entry.Cost, entry.HasImages,
entry.Status, entry.ErrorMessage, entry.DurationMS)
if err != nil {
log.Printf("Failed to insert request log: %v", err)
return
}
// Update client stats
_, _ = tx.Exec(`
UPDATE clients SET
total_requests = total_requests + 1,
total_tokens = total_tokens + ?,
total_cost = total_cost + ?,
updated_at = CURRENT_TIMESTAMP
WHERE client_id = ?
`, entry.TotalTokens, entry.Cost, entry.ClientID)
// Update provider balance
if entry.Cost > 0 {
_, _ = tx.Exec("UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')",
entry.Cost, entry.Provider)
}
err = tx.Commit()
if err != nil {
log.Printf("Failed to commit logging transaction: %v", err)
}
}

internal/server/server.go (new file, 494 lines)

@@ -0,0 +1,494 @@
package server
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"gophergate/internal/config"
"gophergate/internal/db"
"gophergate/internal/middleware"
"gophergate/internal/models"
"gophergate/internal/providers"
"gophergate/internal/utils"
"github.com/gin-gonic/gin"
)
type Server struct {
router *gin.Engine
cfg *config.Config
database *db.DB
providers map[string]providers.Provider
sessions *SessionManager
hub *Hub
logger *RequestLogger
registry *models.ModelRegistry
}
func NewServer(cfg *config.Config, database *db.DB) *Server {
router := gin.Default()
hub := NewHub()
s := &Server{
router: router,
cfg: cfg,
database: database,
providers: make(map[string]providers.Provider),
sessions: NewSessionManager(cfg.KeyBytes, 24*time.Hour),
hub: hub,
logger: NewRequestLogger(database, hub),
registry: &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)},
}
// Fetch registry in background
go func() {
registry, err := utils.FetchRegistry()
if err != nil {
fmt.Printf("Warning: Failed to fetch initial model registry: %v\n", err)
} else {
s.registry = registry
}
}()
// Initialize providers from DB and Config
if err := s.RefreshProviders(); err != nil {
fmt.Printf("Warning: Initial provider refresh failed: %v\n", err)
}
s.setupRoutes()
return s
}
func (s *Server) RefreshProviders() error {
var dbConfigs []db.ProviderConfig
err := s.database.Select(&dbConfigs, "SELECT * FROM provider_configs")
if err != nil {
return fmt.Errorf("failed to fetch provider configs from db: %w", err)
}
dbMap := make(map[string]db.ProviderConfig)
for _, cfg := range dbConfigs {
dbMap[cfg.ID] = cfg
}
providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok"}
for _, id := range providerIDs {
// Default values from config
enabled := false
baseURL := ""
apiKey := ""
switch id {
case "openai":
enabled = s.cfg.Providers.OpenAI.Enabled
baseURL = s.cfg.Providers.OpenAI.BaseURL
apiKey, _ = s.cfg.GetAPIKey("openai")
case "gemini":
enabled = s.cfg.Providers.Gemini.Enabled
baseURL = s.cfg.Providers.Gemini.BaseURL
apiKey, _ = s.cfg.GetAPIKey("gemini")
case "deepseek":
enabled = s.cfg.Providers.DeepSeek.Enabled
baseURL = s.cfg.Providers.DeepSeek.BaseURL
apiKey, _ = s.cfg.GetAPIKey("deepseek")
case "moonshot":
enabled = s.cfg.Providers.Moonshot.Enabled
baseURL = s.cfg.Providers.Moonshot.BaseURL
apiKey, _ = s.cfg.GetAPIKey("moonshot")
case "grok":
enabled = s.cfg.Providers.Grok.Enabled
baseURL = s.cfg.Providers.Grok.BaseURL
apiKey, _ = s.cfg.GetAPIKey("grok")
}
// Overrides from DB
if dbCfg, ok := dbMap[id]; ok {
enabled = dbCfg.Enabled
if dbCfg.BaseURL != nil && *dbCfg.BaseURL != "" {
baseURL = *dbCfg.BaseURL
}
if dbCfg.APIKey != nil && *dbCfg.APIKey != "" {
key := *dbCfg.APIKey
if dbCfg.APIKeyEncrypted {
decrypted, err := utils.Decrypt(key, s.cfg.KeyBytes)
if err == nil {
key = decrypted
} else {
fmt.Printf("Warning: Failed to decrypt API key for %s: %v\n", id, err)
}
}
apiKey = key
}
}
if !enabled {
delete(s.providers, id)
continue
}
// Initialize provider
switch id {
case "openai":
cfg := s.cfg.Providers.OpenAI
cfg.BaseURL = baseURL
s.providers["openai"] = providers.NewOpenAIProvider(cfg, apiKey)
case "gemini":
cfg := s.cfg.Providers.Gemini
cfg.BaseURL = baseURL
s.providers["gemini"] = providers.NewGeminiProvider(cfg, apiKey)
case "deepseek":
cfg := s.cfg.Providers.DeepSeek
cfg.BaseURL = baseURL
s.providers["deepseek"] = providers.NewDeepSeekProvider(cfg, apiKey)
case "moonshot":
cfg := s.cfg.Providers.Moonshot
cfg.BaseURL = baseURL
s.providers["moonshot"] = providers.NewMoonshotProvider(cfg, apiKey)
case "grok":
cfg := s.cfg.Providers.Grok
cfg.BaseURL = baseURL
s.providers["grok"] = providers.NewGrokProvider(cfg, apiKey)
}
}
return nil
}
func (s *Server) setupRoutes() {
// Auth is applied per route group below; static assets, /health, and the
// dashboard login endpoint must stay reachable without a token.
// Static files
s.router.StaticFile("/", "./static/index.html")
s.router.StaticFile("/favicon.ico", "./static/favicon.ico")
s.router.Static("/css", "./static/css")
s.router.Static("/js", "./static/js")
s.router.Static("/img", "./static/img")
// WebSocket
s.router.GET("/ws", s.handleWebSocket)
// API V1 (External LLM Access) - Secured with AuthMiddleware
v1 := s.router.Group("/v1")
v1.Use(middleware.AuthMiddleware(s.database))
{
v1.POST("/chat/completions", s.handleChatCompletions)
v1.GET("/models", s.handleListModels)
}
// Dashboard API Group
api := s.router.Group("/api")
{
api.POST("/auth/login", s.handleLogin)
api.GET("/auth/status", s.handleAuthStatus)
api.POST("/auth/logout", s.handleLogout)
api.POST("/auth/change-password", s.handleChangePassword)
// Protected dashboard routes (need admin session)
admin := api.Group("/")
admin.Use(s.adminAuthMiddleware())
{
admin.GET("/usage/summary", s.handleUsageSummary)
admin.GET("/usage/time-series", s.handleTimeSeries)
admin.GET("/usage/providers", s.handleProvidersUsage)
admin.GET("/usage/clients", s.handleClientsUsage)
admin.GET("/usage/detailed", s.handleDetailedUsage)
admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)
admin.GET("/clients", s.handleGetClients)
admin.POST("/clients", s.handleCreateClient)
admin.GET("/clients/:id", s.handleGetClient)
admin.PUT("/clients/:id", s.handleUpdateClient)
admin.DELETE("/clients/:id", s.handleDeleteClient)
admin.GET("/clients/:id/tokens", s.handleGetClientTokens)
admin.POST("/clients/:id/tokens", s.handleCreateClientToken)
admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)
admin.GET("/providers", s.handleGetProviders)
admin.PUT("/providers/:name", s.handleUpdateProvider)
admin.POST("/providers/:name/test", s.handleTestProvider)
admin.GET("/models", s.handleGetModels)
admin.PUT("/models/:id", s.handleUpdateModel)
admin.GET("/users", s.handleGetUsers)
admin.POST("/users", s.handleCreateUser)
admin.PUT("/users/:id", s.handleUpdateUser)
admin.DELETE("/users/:id", s.handleDeleteUser)
admin.GET("/system/health", s.handleSystemHealth)
admin.GET("/system/metrics", s.handleSystemMetrics)
admin.GET("/system/settings", s.handleGetSettings)
admin.POST("/system/backup", s.handleCreateBackup)
admin.GET("/system/logs", s.handleGetLogs)
}
}
s.router.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
}
func (s *Server) handleListModels(c *gin.Context) {
type OpenAIModel struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
}
var data []OpenAIModel
allowedProviders := map[string]bool{
"openai": true,
"google": true, // Models from models.dev use 'google' ID for Gemini
"deepseek": true,
"moonshot": true,
"xai": true, // Models from models.dev use 'xai' ID for Grok
}
if s.registry != nil {
for pID, pInfo := range s.registry.Providers {
if !allowedProviders[pID] {
continue
}
for mID := range pInfo.Models {
data = append(data, OpenAIModel{
ID: mID,
Object: "model",
Created: 1700000000,
OwnedBy: pID,
})
}
}
}
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": data,
})
}
// handleChatCompletions converts an OpenAI-style request to the unified format, routes it to a
// provider by model-name substring, and proxies the (optionally streaming) response.
func (s *Server) handleChatCompletions(c *gin.Context) {
startTime := time.Now()
var req models.ChatCompletionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Select provider based on model name
providerName := "openai" // default
if strings.Contains(req.Model, "gemini") {
providerName = "gemini"
} else if strings.Contains(req.Model, "deepseek") {
providerName = "deepseek"
} else if strings.Contains(req.Model, "kimi") || strings.Contains(req.Model, "moonshot") {
providerName = "moonshot"
} else if strings.Contains(req.Model, "grok") {
providerName = "grok"
}
provider, ok := s.providers[providerName]
if !ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
return
}
// Convert ChatCompletionRequest to UnifiedRequest
unifiedReq := &models.UnifiedRequest{
Model: req.Model,
Messages: []models.UnifiedMessage{},
Temperature: req.Temperature,
TopP: req.TopP,
TopK: req.TopK,
N: req.N,
MaxTokens: req.MaxTokens,
PresencePenalty: req.PresencePenalty,
FrequencyPenalty: req.FrequencyPenalty,
Stream: req.Stream != nil && *req.Stream,
Tools: req.Tools,
ToolChoice: req.ToolChoice,
}
// Handle Stop sequences
if req.Stop != nil {
var stop []string
if err := json.Unmarshal(req.Stop, &stop); err == nil {
unifiedReq.Stop = stop
} else {
var singleStop string
if err := json.Unmarshal(req.Stop, &singleStop); err == nil {
unifiedReq.Stop = []string{singleStop}
}
}
}
// Convert messages
for _, msg := range req.Messages {
unifiedMsg := models.UnifiedMessage{
Role: msg.Role,
Content: []models.UnifiedContentPart{},
ReasoningContent: msg.ReasoningContent,
ToolCalls: msg.ToolCalls,
Name: msg.Name,
ToolCallID: msg.ToolCallID,
}
// Handle multimodal content
if strContent, ok := msg.Content.(string); ok {
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
Type: "text",
Text: strContent,
})
} else if parts, ok := msg.Content.([]interface{}); ok {
for _, part := range parts {
if partMap, ok := part.(map[string]interface{}); ok {
partType, _ := partMap["type"].(string)
if partType == "text" {
text, _ := partMap["text"].(string)
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
Type: "text",
Text: text,
})
} else if partType == "image_url" {
if imgURLMap, ok := partMap["image_url"].(map[string]interface{}); ok {
url, _ := imgURLMap["url"].(string)
imageInput := &models.ImageInput{}
if strings.HasPrefix(url, "data:") {
mime, data, err := utils.ParseDataURL(url)
if err == nil {
imageInput.Base64 = data
imageInput.MimeType = mime
}
} else {
imageInput.URL = url
}
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
Type: "image",
Image: imageInput,
})
unifiedReq.HasImages = true
}
}
}
}
}
unifiedReq.Messages = append(unifiedReq.Messages, unifiedMsg)
}
clientID := "default"
if auth, ok := c.Get("auth"); ok {
if authInfo, ok := auth.(models.AuthInfo); ok {
clientID = authInfo.ClientID
}
}
// Always set the client ID, even when auth info is missing or of an unexpected type.
unifiedReq.ClientID = clientID
if unifiedReq.Stream {
ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq)
if err != nil {
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
var lastUsage *models.Usage
c.Stream(func(w io.Writer) bool {
chunk, ok := <-ch
if !ok {
fmt.Fprintf(w, "data: [DONE]\n\n")
s.logRequest(startTime, clientID, providerName, req.Model, lastUsage, nil, unifiedReq.HasImages)
return false
}
if chunk.Usage != nil {
lastUsage = chunk.Usage
}
data, err := json.Marshal(chunk)
if err != nil {
return false
}
fmt.Fprintf(w, "data: %s\n\n", data)
return true
})
return
}
resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq)
if err != nil {
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage, nil, unifiedReq.HasImages)
c.JSON(http.StatusOK, resp)
}
func (s *Server) logRequest(start time.Time, clientID, provider, model string, usage *models.Usage, err error, hasImages bool) {
entry := RequestLog{
Timestamp: start,
ClientID: clientID,
Provider: provider,
Model: model,
Status: "success",
DurationMS: time.Since(start).Milliseconds(),
HasImages: hasImages,
}
if err != nil {
entry.Status = "error"
entry.ErrorMessage = err.Error()
}
if usage != nil {
entry.PromptTokens = usage.PromptTokens
entry.CompletionTokens = usage.CompletionTokens
entry.TotalTokens = usage.TotalTokens
if usage.ReasoningTokens != nil {
entry.ReasoningTokens = *usage.ReasoningTokens
}
if usage.CacheReadTokens != nil {
entry.CacheReadTokens = *usage.CacheReadTokens
}
if usage.CacheWriteTokens != nil {
entry.CacheWriteTokens = *usage.CacheWriteTokens
}
// Calculate cost using registry
entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
fmt.Printf("[DEBUG] Request logged: model=%s, prompt=%d, completion=%d, reasoning=%d, cache_read=%d, cost=%f\n",
model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.Cost)
}
s.logger.LogRequest(entry)
}
func (s *Server) Run() error {
go s.hub.Run()
s.logger.Start()
// Start registry refresher
go func() {
ticker := time.NewTicker(24 * time.Hour)
for range ticker.C {
newRegistry, err := utils.FetchRegistry()
if err == nil {
s.registry = newRegistry
}
}
}()
addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
return s.router.Run(addr)
}

internal/server/sessions.go
@@ -0,0 +1,155 @@
package server
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
type Session struct {
Username string `json:"username"`
DisplayName string `json:"display_name"`
Role string `json:"role"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
SessionID string `json:"session_id"`
}
type SessionManager struct {
sessions map[string]Session
mu sync.RWMutex
secret []byte
ttl time.Duration
}
type sessionPayload struct {
SessionID string `json:"session_id"`
Username string `json:"username"`
DisplayName string `json:"display_name"`
Role string `json:"role"`
Exp int64 `json:"exp"`
}
func NewSessionManager(secret []byte, ttl time.Duration) *SessionManager {
return &SessionManager{
sessions: make(map[string]Session),
secret: secret,
ttl: ttl,
}
}
// CreateSession stores a new server-side session and returns a signed token for it.
func (m *SessionManager) CreateSession(username, displayName, role string) (string, error) {
sessionID := uuid.New().String()
now := time.Now()
expiresAt := now.Add(m.ttl)
m.mu.Lock()
m.sessions[sessionID] = Session{
Username: username,
DisplayName: displayName,
Role: role,
CreatedAt: now,
ExpiresAt: expiresAt,
SessionID: sessionID,
}
m.mu.Unlock()
return m.createSignedToken(sessionID, username, displayName, role, expiresAt.Unix())
}
// createSignedToken returns base64url(payloadJSON) + "." + base64url(HMAC-SHA256(payloadJSON)).
func (m *SessionManager) createSignedToken(sessionID, username, displayName, role string, exp int64) (string, error) {
payload := sessionPayload{
SessionID: sessionID,
Username: username,
DisplayName: displayName,
Role: role,
Exp: exp,
}
payloadJSON, err := json.Marshal(payload)
if err != nil {
return "", err
}
payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON)
h := hmac.New(sha256.New, m.secret)
h.Write(payloadJSON)
signature := h.Sum(nil)
signatureB64 := base64.RawURLEncoding.EncodeToString(signature)
return fmt.Sprintf("%s.%s", payloadB64, signatureB64), nil
}
// ValidateSession checks the token's format, HMAC signature, and expiry, then looks up the
// server-side session. The middle return value is currently always empty.
func (m *SessionManager) ValidateSession(token string) (*Session, string, error) {
parts := strings.Split(token, ".")
if len(parts) != 2 {
return nil, "", fmt.Errorf("invalid token format")
}
payloadB64 := parts[0]
signatureB64 := parts[1]
payloadJSON, err := base64.RawURLEncoding.DecodeString(payloadB64)
if err != nil {
return nil, "", err
}
signature, err := base64.RawURLEncoding.DecodeString(signatureB64)
if err != nil {
return nil, "", err
}
h := hmac.New(sha256.New, m.secret)
h.Write(payloadJSON)
if !hmac.Equal(signature, h.Sum(nil)) {
return nil, "", fmt.Errorf("invalid signature")
}
var payload sessionPayload
if err := json.Unmarshal(payloadJSON, &payload); err != nil {
return nil, "", err
}
if time.Now().Unix() > payload.Exp {
return nil, "", fmt.Errorf("token expired")
}
m.mu.RLock()
session, ok := m.sessions[payload.SessionID]
m.mu.RUnlock()
if !ok {
return nil, "", fmt.Errorf("session not found")
}
return &session, "", nil
}
// RevokeSession deletes the server-side session referenced by the token. The signature is not
// verified here: deletion is idempotent and revoking a session gains an attacker nothing.
func (m *SessionManager) RevokeSession(token string) {
parts := strings.Split(token, ".")
if len(parts) != 2 {
return
}
payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return
}
var payload sessionPayload
if err := json.Unmarshal(payloadJSON, &payload); err != nil {
return
}
m.mu.Lock()
delete(m.sessions, payload.SessionID)
m.mu.Unlock()
}


@@ -0,0 +1,107 @@
package server
import (
"log"
"net/http"
"sync"
"sync/atomic"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true // In production, restrict this to trusted dashboard origins
},
}
type Hub struct {
clients map[*websocket.Conn]bool
broadcast chan interface{}
register chan *websocket.Conn
unregister chan *websocket.Conn
mu sync.Mutex
clientCount int32
}
func NewHub() *Hub {
return &Hub{
clients: make(map[*websocket.Conn]bool),
broadcast: make(chan interface{}),
register: make(chan *websocket.Conn),
unregister: make(chan *websocket.Conn),
}
}
func (h *Hub) Run() {
for {
select {
case client := <-h.register:
h.mu.Lock()
h.clients[client] = true
atomic.AddInt32(&h.clientCount, 1)
h.mu.Unlock()
log.Println("WebSocket client registered")
case client := <-h.unregister:
h.mu.Lock()
if _, ok := h.clients[client]; ok {
delete(h.clients, client)
client.Close()
atomic.AddInt32(&h.clientCount, -1)
}
h.mu.Unlock()
log.Println("WebSocket client unregistered")
case message := <-h.broadcast:
h.mu.Lock()
for client := range h.clients {
err := client.WriteJSON(message)
if err != nil {
log.Printf("WebSocket error: %v", err)
client.Close()
delete(h.clients, client)
atomic.AddInt32(&h.clientCount, -1)
}
}
h.mu.Unlock()
}
}
}
func (h *Hub) GetClientCount() int {
return int(atomic.LoadInt32(&h.clientCount))
}
func (s *Server) handleWebSocket(c *gin.Context) {
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
log.Printf("Failed to set websocket upgrade: %v", err)
return
}
s.hub.register <- conn
defer func() {
s.hub.unregister <- conn
}()
// Initial message
conn.WriteJSON(gin.H{
"type": "connected",
"message": "Connected to GopherGate Dashboard",
})
for {
var msg map[string]interface{}
err := conn.ReadJSON(&msg)
if err != nil {
break
}
if msg["type"] == "ping" {
conn.WriteJSON(gin.H{"type": "pong", "payload": gin.H{}})
}
}
}

internal/utils/crypto.go

@@ -0,0 +1,71 @@
package utils
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
)
// Encrypt encrypts plain text using AES-GCM with the given 32-byte key.
func Encrypt(plainText string, key []byte) (string, error) {
if len(key) != 32 {
return "", fmt.Errorf("encryption key must be 32 bytes")
}
block, err := aes.NewCipher(key)
if err != nil {
return "", err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", err
}
// The nonce should be prepended to the ciphertext
cipherText := gcm.Seal(nonce, nonce, []byte(plainText), nil)
return base64.StdEncoding.EncodeToString(cipherText), nil
}
// Decrypt decrypts base64-encoded cipher text using AES-GCM with the given 32-byte key.
func Decrypt(encodedCipherText string, key []byte) (string, error) {
if len(key) != 32 {
return "", fmt.Errorf("encryption key must be 32 bytes")
}
cipherText, err := base64.StdEncoding.DecodeString(encodedCipherText)
if err != nil {
return "", err
}
block, err := aes.NewCipher(key)
if err != nil {
return "", err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
nonceSize := gcm.NonceSize()
if len(cipherText) < nonceSize {
return "", fmt.Errorf("cipher text too short")
}
nonce, actualCipherText := cipherText[:nonceSize], cipherText[nonceSize:]
plainText, err := gcm.Open(nil, nonce, actualCipherText, nil)
if err != nil {
return "", err
}
return string(plainText), nil
}


@@ -0,0 +1,69 @@
package utils
import (
"encoding/json"
"fmt"
"log"
"time"
"gophergate/internal/models"
"github.com/go-resty/resty/v2"
)
const ModelsDevURL = "https://models.dev/api.json"
// FetchRegistry downloads the models.dev catalog and parses it into a ModelRegistry.
func FetchRegistry() (*models.ModelRegistry, error) {
log.Printf("Fetching model registry from %s", ModelsDevURL)
client := resty.New().SetTimeout(10 * time.Second)
resp, err := client.R().Get(ModelsDevURL)
if err != nil {
return nil, fmt.Errorf("failed to fetch registry: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("failed to fetch registry: HTTP %d", resp.StatusCode())
}
var providers map[string]models.ProviderInfo
if err := json.Unmarshal(resp.Body(), &providers); err != nil {
return nil, fmt.Errorf("failed to unmarshal registry: %w", err)
}
log.Println("Successfully loaded model registry")
return &models.ModelRegistry{Providers: providers}, nil
}
// CalculateCost prices a request from the registry's per-million-token rates. Cached prompt
// tokens are billed at the cache-read rate instead of the input rate; reasoningTokens is
// logged but not priced separately.
func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens, completionTokens, reasoningTokens, cacheRead, cacheWrite uint32) float64 {
if registry == nil {
return 0.0
}
meta := registry.FindModel(modelID)
if meta == nil || meta.Cost == nil {
log.Printf("[DEBUG] CalculateCost: model %s not found or has no cost metadata", modelID)
return 0.0
}
// promptTokens is usually the TOTAL prompt size.
// We subtract cacheRead from it to get the uncached part.
uncachedTokens := promptTokens
if cacheRead > 0 {
if cacheRead > promptTokens {
uncachedTokens = 0
} else {
uncachedTokens = promptTokens - cacheRead
}
}
cost := (float64(uncachedTokens) * meta.Cost.Input / 1000000.0) +
(float64(completionTokens) * meta.Cost.Output / 1000000.0)
if meta.Cost.CacheRead != nil {
cost += float64(cacheRead) * (*meta.Cost.CacheRead) / 1000000.0
}
if meta.Cost.CacheWrite != nil {
cost += float64(cacheWrite) * (*meta.Cost.CacheWrite) / 1000000.0
}
log.Printf("[DEBUG] CalculateCost: model=%s, uncached=%d, completion=%d, reasoning=%d, cache_read=%d, cache_write=%d, cost=%f (input_rate=%f, output_rate=%f)",
modelID, uncachedTokens, completionTokens, reasoningTokens, cacheRead, cacheWrite, cost, meta.Cost.Input, meta.Cost.Output)
return cost
}

internal/utils/utils.go

@@ -0,0 +1,19 @@
package utils
import (
"fmt"
"strings"
)
// ParseDataURL splits a base64 data URL ("data:<mime>;base64,<payload>") into its MIME type
// and base64 payload. Data URLs without ";base64," are rejected.
func ParseDataURL(dataURL string) (string, string, error) {
if !strings.HasPrefix(dataURL, "data:") {
return "", "", fmt.Errorf("not a data URL")
}
parts := strings.Split(dataURL[5:], ";base64,")
if len(parts) != 2 {
return "", "", fmt.Errorf("invalid data URL format")
}
return parts[0], parts[1], nil
}


@@ -1,13 +0,0 @@
-- Migration: add billing_mode to provider_configs
-- Adds a billing_mode TEXT column with default 'prepaid'
-- After applying, set Gemini to postpaid with:
-- UPDATE provider_configs SET billing_mode = 'postpaid' WHERE id = 'gemini';
BEGIN TRANSACTION;
ALTER TABLE provider_configs ADD COLUMN billing_mode TEXT DEFAULT 'prepaid';
COMMIT;
-- NOTE: If you use a production SQLite file, run the following to set Gemini to postpaid:
-- sqlite3 /path/to/db.sqlite "UPDATE provider_configs SET billing_mode='postpaid' WHERE id='gemini';"


@@ -1,13 +0,0 @@
-- Migration: add composite indexes for query performance
-- Adds three composite indexes:
-- 1. idx_llm_requests_client_timestamp on llm_requests(client_id, timestamp)
-- 2. idx_llm_requests_provider_timestamp on llm_requests(provider, timestamp)
-- 3. idx_model_configs_provider_id on model_configs(provider_id)
BEGIN TRANSACTION;
CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp);
CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp);
CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id);
COMMIT;


@@ -1,2 +0,0 @@
max_width = 120
use_field_init_shorthand = true


@@ -1,11 +0,0 @@
2026-03-06T20:07:36.737914Z  INFO Starting LLM Proxy Gateway v0.1.0
2026-03-06T20:07:36.738903Z  INFO Configuration loaded from Some("/home/newkirk/Documents/projects/web_projects/llm-proxy/config.toml")
2026-03-06T20:07:36.738945Z  INFO Encryption initialized
2026-03-06T20:07:36.739124Z  INFO Connecting to database at ./data/llm_proxy.db
2026-03-06T20:07:36.753254Z  INFO Database migrations completed
2026-03-06T20:07:36.753294Z  INFO Database initialized at "./data/llm_proxy.db"
2026-03-06T20:07:36.755187Z  INFO Fetching model registry from https://models.dev/api.json
2026-03-06T20:07:37.000853Z  INFO Successfully loaded model registry
2026-03-06T20:07:37.001382Z  INFO Model config cache initialized
2026-03-06T20:07:37.001702Z  WARN SESSION_SECRET environment variable not set. Using a randomly generated secret. This will invalidate all sessions on restart. Set SESSION_SECRET to a fixed hex or base64 encoded 32-byte value.
2026-03-06T20:07:37.002898Z  INFO Server listening on http://0.0.0.0:8082


@@ -1 +0,0 @@
945904


@@ -1,45 +0,0 @@
use axum::{extract::FromRequestParts, http::request::Parts};
use crate::errors::AppError;
#[derive(Debug, Clone)]
pub struct AuthInfo {
pub token: String,
pub client_id: String,
}
pub struct AuthenticatedClient {
pub info: AuthInfo,
}
impl<S> FromRequestParts<S> for AuthenticatedClient
where
S: Send + Sync,
{
type Rejection = AppError;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
// Retrieve AuthInfo from request extensions, where it was placed by rate_limit_middleware
let info = parts
.extensions
.get::<AuthInfo>()
.cloned()
.ok_or_else(|| AppError::AuthError("Authentication info not found in request".to_string()))?;
Ok(AuthenticatedClient { info })
}
}
impl std::ops::Deref for AuthenticatedClient {
type Target = AuthInfo;
fn deref(&self) -> &Self::Target {
&self.info
}
}
pub fn validate_token(token: &str, valid_tokens: &[String]) -> bool {
// Simple validation against list of tokens
// In production, use proper token validation (JWT, database lookup, etc.)
valid_tokens.contains(&token.to_string())
}


@@ -1,304 +0,0 @@
//! Client management for LLM proxy
//!
//! This module handles:
//! 1. Client registration and management
//! 2. Client usage tracking
//! 3. Client rate limit configuration
use anyhow::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{Row, SqlitePool};
use tracing::{info, warn};
/// Client information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Client {
pub id: i64,
pub client_id: String,
pub name: String,
pub description: String,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
pub is_active: bool,
pub rate_limit_per_minute: i64,
pub total_requests: i64,
pub total_tokens: i64,
pub total_cost: f64,
}
/// Client creation request
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateClientRequest {
pub client_id: String,
pub name: String,
pub description: String,
pub rate_limit_per_minute: Option<i64>,
}
/// Client update request
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateClientRequest {
pub name: Option<String>,
pub description: Option<String>,
pub is_active: Option<bool>,
pub rate_limit_per_minute: Option<i64>,
}
/// Client manager for database operations
pub struct ClientManager {
db_pool: SqlitePool,
}
impl ClientManager {
pub fn new(db_pool: SqlitePool) -> Self {
Self { db_pool }
}
/// Create a new client
pub async fn create_client(&self, request: CreateClientRequest) -> Result<Client> {
let rate_limit = request.rate_limit_per_minute.unwrap_or(60);
// First insert the client
sqlx::query(
r#"
INSERT INTO clients (client_id, name, description, rate_limit_per_minute)
VALUES (?, ?, ?, ?)
"#,
)
.bind(&request.client_id)
.bind(&request.name)
.bind(&request.description)
.bind(rate_limit)
.execute(&self.db_pool)
.await?;
// Then fetch the created client
let client = self
.get_client(&request.client_id)
.await?
.ok_or_else(|| anyhow::anyhow!("Failed to retrieve created client"))?;
info!("Created client: {} ({})", client.name, client.client_id);
Ok(client)
}
/// Get a client by ID
pub async fn get_client(&self, client_id: &str) -> Result<Option<Client>> {
let row = sqlx::query(
r#"
SELECT
id, client_id, name, description,
created_at, updated_at, is_active,
rate_limit_per_minute, total_requests, total_tokens, total_cost
FROM clients
WHERE client_id = ?
"#,
)
.bind(client_id)
.fetch_optional(&self.db_pool)
.await?;
if let Some(row) = row {
let client = Client {
id: row.get("id"),
client_id: row.get("client_id"),
name: row.get("name"),
description: row.get("description"),
created_at: row.get("created_at"),
updated_at: row.get("updated_at"),
is_active: row.get("is_active"),
rate_limit_per_minute: row.get("rate_limit_per_minute"),
total_requests: row.get("total_requests"),
total_tokens: row.get("total_tokens"),
total_cost: row.get("total_cost"),
};
Ok(Some(client))
} else {
Ok(None)
}
}
/// Update a client
pub async fn update_client(&self, client_id: &str, request: UpdateClientRequest) -> Result<Option<Client>> {
// First, get the current client to check if it exists
let current_client = self.get_client(client_id).await?;
if current_client.is_none() {
return Ok(None);
}
// Build update query dynamically based on provided fields
let mut query_builder = sqlx::QueryBuilder::new("UPDATE clients SET ");
let mut has_updates = false;
if let Some(name) = &request.name {
query_builder.push("name = ");
query_builder.push_bind(name);
has_updates = true;
}
if let Some(description) = &request.description {
if has_updates {
query_builder.push(", ");
}
query_builder.push("description = ");
query_builder.push_bind(description);
has_updates = true;
}
if let Some(is_active) = request.is_active {
if has_updates {
query_builder.push(", ");
}
query_builder.push("is_active = ");
query_builder.push_bind(is_active);
has_updates = true;
}
if let Some(rate_limit) = request.rate_limit_per_minute {
if has_updates {
query_builder.push(", ");
}
query_builder.push("rate_limit_per_minute = ");
query_builder.push_bind(rate_limit);
has_updates = true;
}
// Always update the updated_at timestamp
if has_updates {
query_builder.push(", ");
}
query_builder.push("updated_at = CURRENT_TIMESTAMP");
if !has_updates {
// No updates to make
return self.get_client(client_id).await;
}
query_builder.push(" WHERE client_id = ");
query_builder.push_bind(client_id);
let query = query_builder.build();
query.execute(&self.db_pool).await?;
// Fetch the updated client
let updated_client = self.get_client(client_id).await?;
if updated_client.is_some() {
info!("Updated client: {}", client_id);
}
Ok(updated_client)
}
/// List all clients
pub async fn list_clients(&self, limit: Option<i64>, offset: Option<i64>) -> Result<Vec<Client>> {
let limit = limit.unwrap_or(100);
let offset = offset.unwrap_or(0);
let rows = sqlx::query(
r#"
SELECT
id, client_id, name, description,
created_at, updated_at, is_active,
rate_limit_per_minute, total_requests, total_tokens, total_cost
FROM clients
ORDER BY created_at DESC
LIMIT ? OFFSET ?
"#,
)
.bind(limit)
.bind(offset)
.fetch_all(&self.db_pool)
.await?;
let mut clients = Vec::new();
for row in rows {
let client = Client {
id: row.get("id"),
client_id: row.get("client_id"),
name: row.get("name"),
description: row.get("description"),
created_at: row.get("created_at"),
updated_at: row.get("updated_at"),
is_active: row.get("is_active"),
rate_limit_per_minute: row.get("rate_limit_per_minute"),
total_requests: row.get("total_requests"),
total_tokens: row.get("total_tokens"),
total_cost: row.get("total_cost"),
};
clients.push(client);
}
Ok(clients)
}
/// Delete a client
pub async fn delete_client(&self, client_id: &str) -> Result<bool> {
let result = sqlx::query("DELETE FROM clients WHERE client_id = ?")
.bind(client_id)
.execute(&self.db_pool)
.await?;
let deleted = result.rows_affected() > 0;
if deleted {
info!("Deleted client: {}", client_id);
} else {
warn!("Client not found for deletion: {}", client_id);
}
Ok(deleted)
}
/// Update client usage statistics after a request
pub async fn update_client_usage(&self, client_id: &str, tokens: i64, cost: f64) -> Result<()> {
sqlx::query(
r#"
UPDATE clients
SET
total_requests = total_requests + 1,
total_tokens = total_tokens + ?,
total_cost = total_cost + ?,
updated_at = CURRENT_TIMESTAMP
WHERE client_id = ?
"#,
)
.bind(tokens)
.bind(cost)
.bind(client_id)
.execute(&self.db_pool)
.await?;
Ok(())
}
/// Get client usage statistics
pub async fn get_client_usage(&self, client_id: &str) -> Result<Option<(i64, i64, f64)>> {
let row = sqlx::query(
r#"
SELECT total_requests, total_tokens, total_cost
FROM clients
WHERE client_id = ?
"#,
)
.bind(client_id)
.fetch_optional(&self.db_pool)
.await?;
if let Some(row) = row {
let total_requests: i64 = row.get("total_requests");
let total_tokens: i64 = row.get("total_tokens");
let total_cost: f64 = row.get("total_cost");
Ok(Some((total_requests, total_tokens, total_cost)))
} else {
Ok(None)
}
}
/// Check if a client exists and is active
pub async fn validate_client(&self, client_id: &str) -> Result<bool> {
let client = self.get_client(client_id).await?;
Ok(client.map(|c| c.is_active).unwrap_or(false))
}
}


@@ -1,260 +0,0 @@
use anyhow::Result;
use base64::{Engine as _};
use config::{Config, File, FileFormat};
use hex;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
pub port: u16,
pub host: String,
#[serde(deserialize_with = "deserialize_vec_or_string")]
pub auth_tokens: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
pub path: PathBuf,
pub max_connections: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
pub openai: OpenAIConfig,
pub gemini: GeminiConfig,
pub deepseek: DeepSeekConfig,
pub grok: GrokConfig,
pub ollama: OllamaConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIConfig {
pub api_key_env: String,
pub base_url: String,
pub default_model: String,
pub enabled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeminiConfig {
pub api_key_env: String,
pub base_url: String,
pub default_model: String,
pub enabled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeepSeekConfig {
pub api_key_env: String,
pub base_url: String,
pub default_model: String,
pub enabled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrokConfig {
pub api_key_env: String,
pub base_url: String,
pub default_model: String,
pub enabled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaConfig {
pub base_url: String,
pub enabled: bool,
#[serde(deserialize_with = "deserialize_vec_or_string")]
pub models: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMappingConfig {
pub patterns: Vec<(String, String)>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PricingConfig {
pub openai: Vec<ModelPricing>,
pub gemini: Vec<ModelPricing>,
pub deepseek: Vec<ModelPricing>,
pub grok: Vec<ModelPricing>,
pub ollama: Vec<ModelPricing>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing {
pub model: String,
pub prompt_tokens_per_million: f64,
pub completion_tokens_per_million: f64,
}
#[derive(Debug, Clone)]
pub struct AppConfig {
pub server: ServerConfig,
pub database: DatabaseConfig,
pub providers: ProviderConfig,
pub model_mapping: ModelMappingConfig,
pub pricing: PricingConfig,
pub config_path: Option<PathBuf>,
pub encryption_key: String,
}
impl AppConfig {
pub async fn load() -> Result<Arc<Self>> {
Self::load_from_path(None).await
}
/// Load configuration from a specific path (for testing)
pub async fn load_from_path(config_path: Option<PathBuf>) -> Result<Arc<Self>> {
// Load configuration from multiple sources
let mut config_builder = Config::builder();
// Default configuration
config_builder = config_builder
.set_default("server.port", 8080)?
.set_default("server.host", "0.0.0.0")?
.set_default("server.auth_tokens", Vec::<String>::new())?
.set_default("database.path", "./data/llm_proxy.db")?
.set_default("database.max_connections", 10)?
.set_default("providers.openai.api_key_env", "OPENAI_API_KEY")?
.set_default("providers.openai.base_url", "https://api.openai.com/v1")?
.set_default("providers.openai.default_model", "gpt-4o")?
.set_default("providers.openai.enabled", true)?
.set_default("providers.gemini.api_key_env", "GEMINI_API_KEY")?
.set_default(
"providers.gemini.base_url",
"https://generativelanguage.googleapis.com/v1",
)?
.set_default("providers.gemini.default_model", "gemini-2.0-flash")?
.set_default("providers.gemini.enabled", true)?
.set_default("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")?
.set_default("providers.deepseek.base_url", "https://api.deepseek.com")?
.set_default("providers.deepseek.default_model", "deepseek-reasoner")?
.set_default("providers.deepseek.enabled", true)?
.set_default("providers.grok.api_key_env", "GROK_API_KEY")?
.set_default("providers.grok.base_url", "https://api.x.ai/v1")?
.set_default("providers.grok.default_model", "grok-beta")?
.set_default("providers.grok.enabled", true)?
.set_default("providers.ollama.base_url", "http://localhost:11434/v1")?
.set_default("providers.ollama.enabled", false)?
.set_default("providers.ollama.models", Vec::<String>::new())?
.set_default("encryption_key", "")?;
// Load from config file if exists
// Priority: explicit path arg > LLM_PROXY__CONFIG_PATH env var > ./config.toml
let config_path = config_path
.or_else(|| std::env::var("LLM_PROXY__CONFIG_PATH").ok().map(PathBuf::from))
.unwrap_or_else(|| {
std::env::current_dir()
.unwrap_or_else(|_| PathBuf::from("."))
.join("config.toml")
});
if config_path.exists() {
config_builder = config_builder.add_source(File::from(config_path.clone()).format(FileFormat::Toml));
}
// Load from .env file
dotenvy::dotenv().ok();
// Load from environment variables (with prefix "LLM_PROXY_")
config_builder = config_builder.add_source(
config::Environment::with_prefix("LLM_PROXY")
.separator("__")
.try_parsing(true),
);
let config = config_builder.build()?;
// Deserialize configuration
let server: ServerConfig = config.get("server")?;
let database: DatabaseConfig = config.get("database")?;
let providers: ProviderConfig = config.get("providers")?;
let encryption_key: String = config.get("encryption_key")?;
// Validate encryption key length (must be 32 bytes after hex or base64 decoding)
if encryption_key.is_empty() {
anyhow::bail!("Encryption key is required (LLM_PROXY__ENCRYPTION_KEY environment variable)");
}
// Try hex decode first, then base64
let key_bytes = hex::decode(&encryption_key)
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(&encryption_key))
.map_err(|e| anyhow::anyhow!("Encryption key must be hex or base64 encoded: {}", e))?;
if key_bytes.len() != 32 {
anyhow::bail!("Encryption key must be 32 bytes (256 bits), got {} bytes", key_bytes.len());
}
// For now, use empty model mapping and pricing (will be populated later)
let model_mapping = ModelMappingConfig { patterns: vec![] };
let pricing = PricingConfig {
openai: vec![],
gemini: vec![],
deepseek: vec![],
grok: vec![],
ollama: vec![],
};
Ok(Arc::new(AppConfig {
server,
database,
providers,
model_mapping,
pricing,
config_path: Some(config_path),
encryption_key,
}))
}
pub fn get_api_key(&self, provider: &str) -> Result<String> {
let env_var = match provider {
"openai" => &self.providers.openai.api_key_env,
"gemini" => &self.providers.gemini.api_key_env,
"deepseek" => &self.providers.deepseek.api_key_env,
"grok" => &self.providers.grok.api_key_env,
_ => return Err(anyhow::anyhow!("Unknown provider: {}", provider)),
};
std::env::var(env_var).map_err(|_| anyhow::anyhow!("Environment variable {} not set for {}", env_var, provider))
}
}
/// Helper function to deserialize a Vec<String> from either a sequence or a comma-separated string
fn deserialize_vec_or_string<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
D: serde::Deserializer<'de>,
{
struct VecOrString;
impl<'de> serde::de::Visitor<'de> for VecOrString {
type Value = Vec<String>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a sequence or a comma-separated string")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(value
.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect())
}
fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
where
S: serde::de::SeqAccess<'de>,
{
let mut vec = Vec::new();
while let Some(element) = seq.next_element()? {
vec.push(element);
}
Ok(vec)
}
}
deserializer.deserialize_any(VecOrString)
}
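The `visit_str` branch of the deserializer above accepts a comma-separated string (as env vars like `LLM_PROXY__PROVIDERS__OLLAMA__MODELS` deliver it) and normalizes it into a `Vec<String>`. A minimal pure-std sketch of that transform (function name `split_csv` is hypothetical, for illustration only):

```rust
// Sketch of the visit_str logic: split on commas, trim whitespace,
// and drop empty segments so trailing commas are harmless.
fn split_csv(value: &str) -> Vec<String> {
    value
        .split(',')
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty())
        .collect()
}

fn main() {
    // Surrounding whitespace and empty segments are dropped.
    assert_eq!(
        split_csv(" llama3, qwen2 ,,mistral "),
        vec!["llama3", "qwen2", "mistral"]
    );
    println!("{:?}", split_csv(" llama3, qwen2 ,,mistral "));
}
```

The `visit_seq` branch handles the TOML-array case, so `models = ["llama3"]` in `config.toml` and `MODELS=llama3,qwen2` in the environment both deserialize to the same type.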


@@ -1,229 +0,0 @@
use axum::{extract::State, http::{HeaderMap, HeaderValue}, response::{Json, IntoResponse}};
use bcrypt;
use serde::Deserialize;
use sqlx::Row;
use tracing::warn;
use super::{ApiResponse, DashboardState};
// Authentication handlers
#[derive(Deserialize)]
pub(super) struct LoginRequest {
pub(super) username: String,
pub(super) password: String,
}
pub(super) async fn handle_login(
State(state): State<DashboardState>,
Json(payload): Json<LoginRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let user_result = sqlx::query(
"SELECT username, password_hash, display_name, role, must_change_password FROM users WHERE username = ?",
)
.bind(&payload.username)
.fetch_optional(pool)
.await;
match user_result {
Ok(Some(row)) => {
let hash = row.get::<String, _>("password_hash");
if bcrypt::verify(&payload.password, &hash).unwrap_or(false) {
let username = row.get::<String, _>("username");
let role = row.get::<String, _>("role");
let display_name = row
.get::<Option<String>, _>("display_name")
.unwrap_or_else(|| username.clone());
let must_change_password = row.get::<bool, _>("must_change_password");
let token = state
.session_manager
.create_session(username.clone(), role.clone())
.await;
Json(ApiResponse::success(serde_json::json!({
"token": token,
"must_change_password": must_change_password,
"user": {
"username": username,
"name": display_name,
"role": role
}
})))
} else {
Json(ApiResponse::error("Invalid username or password".to_string()))
}
}
Ok(None) => Json(ApiResponse::error("Invalid username or password".to_string())),
Err(e) => {
warn!("Database error during login: {}", e);
Json(ApiResponse::error("Login failed due to system error".to_string()))
}
}
}
pub(super) async fn handle_auth_status(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> impl IntoResponse {
let token = headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
if let Some(token) = token
&& let Some((session, new_token)) = state.session_manager.validate_session_with_refresh(token).await
{
// Look up display_name from DB
let display_name = sqlx::query_scalar::<_, Option<String>>(
"SELECT display_name FROM users WHERE username = ?",
)
.bind(&session.username)
.fetch_optional(&state.app_state.db_pool)
.await
.ok()
.flatten()
.flatten()
.unwrap_or_else(|| session.username.clone());
let mut headers = HeaderMap::new();
if let Some(refreshed_token) = new_token {
if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) {
headers.insert("X-Refreshed-Token", header_value);
}
}
return (headers, Json(ApiResponse::success(serde_json::json!({
"authenticated": true,
"user": {
"username": session.username,
"name": display_name,
"role": session.role
}
}))));
}
(HeaderMap::new(), Json(ApiResponse::error("Not authenticated".to_string())))
}
#[derive(Deserialize)]
pub(super) struct ChangePasswordRequest {
pub(super) current_password: String,
pub(super) new_password: String,
}
pub(super) async fn handle_change_password(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Json(payload): Json<ChangePasswordRequest>,
) -> impl IntoResponse {
let pool = &state.app_state.db_pool;
// Extract the authenticated user from the session token
let token = headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
let (session, new_token) = match token {
Some(t) => match state.session_manager.validate_session_with_refresh(t).await {
Some((session, new_token)) => (Some(session), new_token),
None => (None, None),
},
None => (None, None),
};
let mut response_headers = HeaderMap::new();
if let Some(refreshed_token) = new_token {
if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) {
response_headers.insert("X-Refreshed-Token", header_value);
}
}
let username = match session {
Some(s) => s.username,
None => return (response_headers, Json(ApiResponse::error("Not authenticated".to_string()))),
};
let user_result = sqlx::query("SELECT password_hash FROM users WHERE username = ?")
.bind(&username)
.fetch_one(pool)
.await;
match user_result {
Ok(row) => {
let hash = row.get::<String, _>("password_hash");
if bcrypt::verify(&payload.current_password, &hash).unwrap_or(false) {
let new_hash = match bcrypt::hash(&payload.new_password, 12) {
Ok(h) => h,
Err(_) => return (response_headers, Json(ApiResponse::error("Failed to hash new password".to_string()))),
};
let update_result = sqlx::query(
"UPDATE users SET password_hash = ?, must_change_password = FALSE WHERE username = ?",
)
.bind(new_hash)
.bind(&username)
.execute(pool)
.await;
match update_result {
Ok(_) => (response_headers, Json(ApiResponse::success(
serde_json::json!({ "message": "Password updated successfully" }),
))),
Err(e) => (response_headers, Json(ApiResponse::error(format!("Failed to update database: {}", e)))),
}
} else {
(response_headers, Json(ApiResponse::error("Current password incorrect".to_string())))
}
}
Err(e) => (response_headers, Json(ApiResponse::error(format!("User not found: {}", e)))),
}
}
pub(super) async fn handle_logout(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let token = headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
if let Some(token) = token {
state.session_manager.revoke_session(token).await;
}
Json(ApiResponse::success(serde_json::json!({ "message": "Logged out" })))
}
/// Helper: Extract and validate a session from the Authorization header.
/// Returns the Session and optional new token if refreshed, or an error response.
pub(super) async fn extract_session(
state: &DashboardState,
headers: &axum::http::HeaderMap,
) -> Result<(super::sessions::Session, Option<String>), Json<ApiResponse<serde_json::Value>>> {
let token = headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
match token {
Some(t) => match state.session_manager.validate_session_with_refresh(t).await {
Some((session, new_token)) => Ok((session, new_token)),
None => Err(Json(ApiResponse::error("Session expired or invalid".to_string()))),
},
None => Err(Json(ApiResponse::error("Not authenticated".to_string()))),
}
}
/// Helper: Extract session and require admin role.
/// Returns session and optional new token if refreshed.
pub(super) async fn require_admin(
state: &DashboardState,
headers: &axum::http::HeaderMap,
) -> Result<(super::sessions::Session, Option<String>), Json<ApiResponse<serde_json::Value>>> {
let (session, new_token) = extract_session(state, headers).await?;
if session.role != "admin" {
return Err(Json(ApiResponse::error("Admin access required".to_string())));
}
Ok((session, new_token))
}
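Each handler in this file repeats the same `Authorization` header parsing before calling `validate_session_with_refresh`. The core extraction reduces to the following pure-std sketch (function name `bearer_token` is hypothetical):

```rust
// Extract the token from an "Authorization: Bearer <token>" header value.
// Returns None for a missing header or a non-Bearer scheme.
fn bearer_token(authorization: Option<&str>) -> Option<&str> {
    authorization.and_then(|v| v.strip_prefix("Bearer "))
}

fn main() {
    assert_eq!(bearer_token(Some("Bearer abc123")), Some("abc123"));
    assert_eq!(bearer_token(Some("Basic abc123")), None);
    assert_eq!(bearer_token(None), None);
}
```

Note that `strip_prefix("Bearer ")` is case- and whitespace-sensitive; a client sending `bearer abc123` would not authenticate under this scheme.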


@@ -1,538 +0,0 @@
use axum::{
extract::{Path, State},
response::Json,
};
use chrono;
use rand::Rng;
use serde::Deserialize;
use serde_json;
use sqlx::Row;
use tracing::warn;
use uuid;
use super::{ApiResponse, DashboardState};
/// Generate a random API token: sk-{48 hex chars}
fn generate_token() -> String {
let mut rng = rand::rng();
let bytes: Vec<u8> = (0..24).map(|_| rng.random::<u8>()).collect();
format!("sk-{}", hex::encode(bytes))
}
#[derive(Deserialize)]
pub(super) struct CreateClientRequest {
pub(super) name: String,
pub(super) client_id: Option<String>,
}
#[derive(Deserialize)]
pub(super) struct UpdateClientPayload {
pub(super) name: Option<String>,
pub(super) description: Option<String>,
pub(super) is_active: Option<bool>,
pub(super) rate_limit_per_minute: Option<i64>,
}
pub(super) async fn handle_get_clients(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
client_id as id,
name,
description,
created_at,
total_requests,
total_tokens,
total_cost,
is_active,
rate_limit_per_minute
FROM clients
ORDER BY created_at DESC
"#,
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let clients: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
serde_json::json!({
"id": row.get::<String, _>("id"),
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
"description": row.get::<Option<String>, _>("description"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
"requests_count": row.get::<i64, _>("total_requests"),
"total_tokens": row.get::<i64, _>("total_tokens"),
"total_cost": row.get::<f64, _>("total_cost"),
"status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
"rate_limit_per_minute": row.get::<Option<i64>, _>("rate_limit_per_minute"),
})
})
.collect();
Json(ApiResponse::success(serde_json::json!(clients)))
}
Err(e) => {
warn!("Failed to fetch clients: {}", e);
Json(ApiResponse::error("Failed to fetch clients".to_string()))
}
}
}
pub(super) async fn handle_create_client(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Json(payload): Json<CreateClientRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let client_id = payload
.client_id
.unwrap_or_else(|| format!("client-{}", &uuid::Uuid::new_v4().to_string()[..8]));
let result = sqlx::query(
r#"
INSERT INTO clients (client_id, name, is_active)
VALUES (?, ?, TRUE)
RETURNING *
"#,
)
.bind(&client_id)
.bind(&payload.name)
.fetch_one(pool)
.await;
match result {
Ok(row) => {
// Auto-generate a token for the new client
let token = generate_token();
let token_result = sqlx::query(
"INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, 'default')",
)
.bind(&client_id)
.bind(&token)
.execute(pool)
.await;
if let Err(e) = token_result {
warn!("Client created but failed to generate token: {}", e);
}
Json(ApiResponse::success(serde_json::json!({
"id": row.get::<String, _>("client_id"),
"name": row.get::<Option<String>, _>("name"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
"status": "active",
"token": token,
})))
}
Err(e) => {
warn!("Failed to create client: {}", e);
Json(ApiResponse::error(format!("Failed to create client: {}", e)))
}
}
}
pub(super) async fn handle_get_client(
State(state): State<DashboardState>,
Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
c.client_id as id,
c.name,
c.description,
c.is_active,
c.rate_limit_per_minute,
c.created_at,
COALESCE(c.total_tokens, 0) as total_tokens,
COALESCE(c.total_cost, 0.0) as total_cost,
COUNT(r.id) as total_requests,
MAX(r.timestamp) as last_request
FROM clients c
LEFT JOIN llm_requests r ON c.client_id = r.client_id
WHERE c.client_id = ?
GROUP BY c.client_id
"#,
)
.bind(&id)
.fetch_optional(pool)
.await;
match result {
Ok(Some(row)) => Json(ApiResponse::success(serde_json::json!({
"id": row.get::<String, _>("id"),
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
"description": row.get::<Option<String>, _>("description"),
"is_active": row.get::<bool, _>("is_active"),
"rate_limit_per_minute": row.get::<Option<i64>, _>("rate_limit_per_minute"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
"total_tokens": row.get::<i64, _>("total_tokens"),
"total_cost": row.get::<f64, _>("total_cost"),
"total_requests": row.get::<i64, _>("total_requests"),
"last_request": row.get::<Option<chrono::DateTime<chrono::Utc>>, _>("last_request"),
"status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
}))),
Ok(None) => Json(ApiResponse::error(format!("Client '{}' not found", id))),
Err(e) => {
warn!("Failed to fetch client: {}", e);
Json(ApiResponse::error(format!("Failed to fetch client: {}", e)))
}
}
}
pub(super) async fn handle_update_client(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(id): Path<String>,
Json(payload): Json<UpdateClientPayload>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Build dynamic UPDATE query from provided fields
let mut sets = Vec::new();
let mut binds: Vec<String> = Vec::new();
if let Some(ref name) = payload.name {
sets.push("name = ?");
binds.push(name.clone());
}
if let Some(ref desc) = payload.description {
sets.push("description = ?");
binds.push(desc.clone());
}
if payload.is_active.is_some() {
sets.push("is_active = ?");
}
if payload.rate_limit_per_minute.is_some() {
sets.push("rate_limit_per_minute = ?");
}
if sets.is_empty() {
return Json(ApiResponse::error("No fields to update".to_string()));
}
// Always update updated_at
sets.push("updated_at = CURRENT_TIMESTAMP");
let sql = format!("UPDATE clients SET {} WHERE client_id = ?", sets.join(", "));
let mut query = sqlx::query(&sql);
// Bind in the same order as sets
for b in &binds {
query = query.bind(b);
}
if let Some(active) = payload.is_active {
query = query.bind(active);
}
if let Some(rate) = payload.rate_limit_per_minute {
query = query.bind(rate);
}
query = query.bind(&id);
match query.execute(pool).await {
Ok(result) => {
if result.rows_affected() == 0 {
return Json(ApiResponse::error(format!("Client '{}' not found", id)));
}
// Return the updated client
let row = sqlx::query(
r#"
SELECT client_id as id, name, description, is_active, rate_limit_per_minute,
created_at, total_requests, total_tokens, total_cost
FROM clients WHERE client_id = ?
"#,
)
.bind(&id)
.fetch_one(pool)
.await;
match row {
Ok(row) => Json(ApiResponse::success(serde_json::json!({
"id": row.get::<String, _>("id"),
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
"description": row.get::<Option<String>, _>("description"),
"is_active": row.get::<bool, _>("is_active"),
"rate_limit_per_minute": row.get::<Option<i64>, _>("rate_limit_per_minute"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
"total_requests": row.get::<i64, _>("total_requests"),
"total_tokens": row.get::<i64, _>("total_tokens"),
"total_cost": row.get::<f64, _>("total_cost"),
"status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
}))),
Err(e) => {
warn!("Failed to fetch updated client: {}", e);
// Update succeeded but fetch failed — still report success
Json(ApiResponse::success(serde_json::json!({ "message": "Client updated" })))
}
}
}
Err(e) => {
warn!("Failed to update client: {}", e);
Json(ApiResponse::error(format!("Failed to update client: {}", e)))
}
}
}
pub(super) async fn handle_delete_client(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Don't allow deleting the default client
if id == "default" {
return Json(ApiResponse::error("Cannot delete default client".to_string()));
}
let result = sqlx::query("DELETE FROM clients WHERE client_id = ?")
.bind(id)
.execute(pool)
.await;
match result {
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Client deleted" }))),
Err(e) => Json(ApiResponse::error(format!("Failed to delete client: {}", e))),
}
}
pub(super) async fn handle_client_usage(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Get per-model breakdown for this client
let result = sqlx::query(
r#"
SELECT
model,
provider,
COUNT(*) as request_count,
SUM(prompt_tokens) as prompt_tokens,
SUM(completion_tokens) as completion_tokens,
SUM(total_tokens) as total_tokens,
SUM(cost) as total_cost,
AVG(duration_ms) as avg_duration_ms
FROM llm_requests
WHERE client_id = ?
GROUP BY model, provider
ORDER BY total_cost DESC
"#,
)
.bind(&id)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let breakdown: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
serde_json::json!({
"model": row.get::<String, _>("model"),
"provider": row.get::<String, _>("provider"),
"request_count": row.get::<i64, _>("request_count"),
"prompt_tokens": row.get::<i64, _>("prompt_tokens"),
"completion_tokens": row.get::<i64, _>("completion_tokens"),
"total_tokens": row.get::<i64, _>("total_tokens"),
"total_cost": row.get::<f64, _>("total_cost"),
"avg_duration_ms": row.get::<f64, _>("avg_duration_ms"),
})
})
.collect();
Json(ApiResponse::success(serde_json::json!({
"client_id": id,
"breakdown": breakdown,
})))
}
Err(e) => {
warn!("Failed to fetch client usage: {}", e);
Json(ApiResponse::error(format!("Failed to fetch client usage: {}", e)))
}
}
}
// ── Token management endpoints ──────────────────────────────────────
pub(super) async fn handle_get_client_tokens(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT id, token, name, is_active, created_at, last_used_at
FROM client_tokens
WHERE client_id = ?
ORDER BY created_at DESC
"#,
)
.bind(&id)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let tokens: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
let token: String = row.get("token");
// Mask all but last 8 chars: sk-••••abcd1234
let masked = if token.len() > 8 {
format!("{}••••{}", &token[..3], &token[token.len() - 8..])
} else {
"••••".to_string()
};
serde_json::json!({
"id": row.get::<i64, _>("id"),
"token_masked": masked,
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "default".to_string()),
"is_active": row.get::<bool, _>("is_active"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
"last_used_at": row.get::<Option<chrono::DateTime<chrono::Utc>>, _>("last_used_at"),
})
})
.collect();
Json(ApiResponse::success(serde_json::json!(tokens)))
}
Err(e) => {
warn!("Failed to fetch client tokens: {}", e);
Json(ApiResponse::error(format!("Failed to fetch client tokens: {}", e)))
}
}
}
#[derive(Deserialize)]
pub(super) struct CreateTokenRequest {
pub(super) name: Option<String>,
}
pub(super) async fn handle_create_client_token(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(id): Path<String>,
Json(payload): Json<CreateTokenRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Verify client exists
let exists: Option<(i64,)> = sqlx::query_as("SELECT 1 as x FROM clients WHERE client_id = ?")
.bind(&id)
.fetch_optional(pool)
.await
.unwrap_or(None);
if exists.is_none() {
return Json(ApiResponse::error(format!("Client '{}' not found", id)));
}
let token = generate_token();
let token_name = payload.name.unwrap_or_else(|| "default".to_string());
let result = sqlx::query(
"INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, ?) RETURNING id, created_at",
)
.bind(&id)
.bind(&token)
.bind(&token_name)
.fetch_one(pool)
.await;
match result {
Ok(row) => Json(ApiResponse::success(serde_json::json!({
"id": row.get::<i64, _>("id"),
"token": token,
"name": token_name,
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
}))),
Err(e) => {
warn!("Failed to create client token: {}", e);
Json(ApiResponse::error(format!("Failed to create token: {}", e)))
}
}
}
pub(super) async fn handle_delete_client_token(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path((client_id, token_id)): Path<(String, i64)>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let result = sqlx::query("DELETE FROM client_tokens WHERE id = ? AND client_id = ?")
.bind(token_id)
.bind(&client_id)
.execute(pool)
.await;
match result {
Ok(r) => {
if r.rows_affected() == 0 {
Json(ApiResponse::error("Token not found".to_string()))
} else {
Json(ApiResponse::success(serde_json::json!({ "message": "Token revoked" })))
}
}
Err(e) => {
warn!("Failed to delete client token: {}", e);
Json(ApiResponse::error(format!("Failed to revoke token: {}", e)))
}
}
}
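The masking rule in `handle_get_client_tokens` keeps the `sk-` prefix and the last 8 characters so an admin can match a listed token against one they hold, without exposing the secret. A standalone sketch of that rule (function name `mask_token` is hypothetical; the slice at `[..3]` assumes the `sk-` prefix produced by `generate_token`):

```rust
// Mask all but the prefix and the last 8 chars: "sk-...89abcdef" style.
fn mask_token(token: &str) -> String {
    if token.len() > 8 {
        format!("{}••••{}", &token[..3], &token[token.len() - 8..])
    } else {
        "••••".to_string()
    }
}

fn main() {
    assert_eq!(mask_token("sk-0123456789abcdef"), "sk-••••89abcdef");
    // Anything too short to mask meaningfully is fully hidden.
    assert_eq!(mask_token("sk-0123"), "••••");
    println!("{}", mask_token("sk-0123456789abcdef"));
}
```

Since `generate_token` always emits `sk-` plus 48 hex characters, the byte-indexed slices never split a multi-byte character here; a token from another source would need a char-boundary check.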


@@ -1,174 +0,0 @@
// Dashboard module for LLM Proxy Gateway
mod auth;
mod clients;
mod models;
mod providers;
pub mod sessions;
mod system;
mod usage;
mod users;
mod websocket;
use axum::{
extract::{Request, State},
middleware::Next,
response::Response,
Router,
routing::{delete, get, post, put},
};
use axum::http::{header, HeaderValue};
use serde::Serialize;
use tower_http::{
limit::RequestBodyLimitLayer,
set_header::SetResponseHeaderLayer,
};
use crate::state::AppState;
use sessions::SessionManager;
// Dashboard state
#[derive(Clone)]
struct DashboardState {
app_state: AppState,
session_manager: SessionManager,
}
// API Response types
#[derive(Serialize)]
struct ApiResponse<T> {
success: bool,
data: Option<T>,
error: Option<String>,
}
impl<T> ApiResponse<T> {
fn success(data: T) -> Self {
Self {
success: true,
data: Some(data),
error: None,
}
}
fn error(error: String) -> Self {
Self {
success: false,
data: None,
error: Some(error),
}
}
}
/// Rate limiting middleware for dashboard routes
async fn dashboard_rate_limit_middleware(
State(_dashboard_state): State<DashboardState>,
request: Request,
next: Next,
) -> Result<Response, crate::errors::AppError> {
// Bypass rate limiting for dashboard routes to prevent "Failed to load statistics"
// when the UI makes many concurrent requests on load.
// Dashboard endpoints are already secured via auth::require_admin.
Ok(next.run(request).await)
}
// Dashboard routes
pub fn router(state: AppState) -> Router {
let session_manager = SessionManager::new(24); // 24-hour session TTL
let dashboard_state = DashboardState {
app_state: state,
session_manager,
};
// Security headers
let csp_header: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::CONTENT_SECURITY_POLICY,
"default-src 'self'; script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; style-src 'self' 'unsafe-inline' https://cdnjs.cloudflare.com https://fonts.googleapis.com; font-src 'self' https://cdnjs.cloudflare.com https://fonts.gstatic.com; img-src 'self' data:; connect-src 'self' ws:;"
.parse()
.unwrap(),
);
let x_frame_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_FRAME_OPTIONS,
"DENY".parse().unwrap(),
);
let x_content_type_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_CONTENT_TYPE_OPTIONS,
"nosniff".parse().unwrap(),
);
let strict_transport_security: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::STRICT_TRANSPORT_SECURITY,
"max-age=31536000; includeSubDomains".parse().unwrap(),
);
Router::new()
// Static file serving
.fallback_service(tower_http::services::ServeDir::new("static"))
// WebSocket endpoint
.route("/ws", get(websocket::handle_websocket))
// API endpoints
.route("/api/auth/login", post(auth::handle_login))
.route("/api/auth/status", get(auth::handle_auth_status))
.route("/api/auth/logout", post(auth::handle_logout))
.route("/api/auth/change-password", post(auth::handle_change_password))
.route(
"/api/users",
get(users::handle_get_users).post(users::handle_create_user),
)
.route(
"/api/users/{id}",
put(users::handle_update_user).delete(users::handle_delete_user),
)
.route("/api/usage/summary", get(usage::handle_usage_summary))
.route("/api/usage/time-series", get(usage::handle_time_series))
.route("/api/usage/clients", get(usage::handle_clients_usage))
.route("/api/usage/providers", get(usage::handle_providers_usage))
.route("/api/usage/detailed", get(usage::handle_detailed_usage))
.route("/api/analytics/breakdown", get(usage::handle_analytics_breakdown))
.route("/api/models", get(models::handle_get_models))
.route("/api/models/{id}", put(models::handle_update_model))
.route(
"/api/clients",
get(clients::handle_get_clients).post(clients::handle_create_client),
)
.route(
"/api/clients/{id}",
get(clients::handle_get_client)
.put(clients::handle_update_client)
.delete(clients::handle_delete_client),
)
.route("/api/clients/{id}/usage", get(clients::handle_client_usage))
.route(
"/api/clients/{id}/tokens",
get(clients::handle_get_client_tokens).post(clients::handle_create_client_token),
)
.route(
"/api/clients/{id}/tokens/{token_id}",
delete(clients::handle_delete_client_token),
)
.route("/api/providers", get(providers::handle_get_providers))
.route(
"/api/providers/{name}",
get(providers::handle_get_provider).put(providers::handle_update_provider),
)
.route("/api/providers/{name}/test", post(providers::handle_test_provider))
.route("/api/system/health", get(system::handle_system_health))
.route("/api/system/metrics", get(system::handle_system_metrics))
.route("/api/system/logs", get(system::handle_system_logs))
.route("/api/system/backup", post(system::handle_system_backup))
.route(
"/api/system/settings",
get(system::handle_get_settings).post(system::handle_update_settings),
)
// Security layers
.layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit
.layer(csp_header)
.layer(x_frame_options)
.layer(x_content_type_options)
.layer(strict_transport_security)
// Rate limiting middleware
.layer(axum::middleware::from_fn_with_state(
dashboard_state.clone(),
dashboard_rate_limit_middleware,
))
.with_state(dashboard_state)
}
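Every dashboard endpoint wraps its payload in the `ApiResponse<T>` envelope defined above, so the frontend can branch on a single `success` flag. A hand-rolled sketch of the two JSON shapes it serializes to (the real code derives this via serde; these helper names are hypothetical, and `error_json` skips string escaping for brevity):

```rust
// Shape produced by ApiResponse::success(data): success=true, error=null.
fn success_json(data: &str) -> String {
    format!("{{\"success\":true,\"data\":{},\"error\":null}}", data)
}

// Shape produced by ApiResponse::error(msg): success=false, data=null.
// NOTE: no escaping of `msg` here; serde handles that in the real code.
fn error_json(msg: &str) -> String {
    format!("{{\"success\":false,\"data\":null,\"error\":\"{}\"}}", msg)
}

fn main() {
    assert_eq!(
        error_json("Not authenticated"),
        "{\"success\":false,\"data\":null,\"error\":\"Not authenticated\"}"
    );
    println!("{}", success_json("{\"token\":\"t\"}"));
}
```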


@@ -1,254 +0,0 @@
use axum::{
extract::{Path, Query, State},
response::Json,
};
use serde::Deserialize;
use serde_json;
use sqlx::Row;
use std::collections::HashMap;
use super::{ApiResponse, DashboardState};
use crate::models::registry::{ModelFilter, ModelSortBy, SortOrder};
#[derive(Deserialize)]
pub(super) struct UpdateModelRequest {
pub(super) enabled: bool,
pub(super) prompt_cost: Option<f64>,
pub(super) completion_cost: Option<f64>,
pub(super) mapping: Option<String>,
}
/// Query parameters for `GET /api/models`.
#[derive(Debug, Deserialize, Default)]
pub(super) struct ModelListParams {
/// Filter by provider ID.
pub provider: Option<String>,
/// Text search on model ID or name.
pub search: Option<String>,
/// Filter by input modality (e.g. "image").
pub modality: Option<String>,
/// Only models that support tool calling.
pub tool_call: Option<bool>,
/// Only models that support reasoning.
pub reasoning: Option<bool>,
/// Only models that have pricing data.
pub has_cost: Option<bool>,
/// Only models that have been used in requests.
pub used_only: Option<bool>,
/// Sort field (name, id, provider, context_limit, input_cost, output_cost).
pub sort_by: Option<ModelSortBy>,
/// Sort direction (asc, desc).
pub sort_order: Option<SortOrder>,
}
pub(super) async fn handle_get_models(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Query(params): Query<ModelListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let registry = &state.app_state.model_registry;
let pool = &state.app_state.db_pool;
// Load overrides from database
let db_models_result =
sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
.fetch_all(pool)
.await;
let mut db_models = HashMap::new();
if let Ok(rows) = db_models_result {
for row in rows {
let id: String = row.get("id");
db_models.insert(id, row);
}
}
let mut models_json = Vec::new();
if params.used_only.unwrap_or(false) {
        // Used-models path: list only the (provider, model) pairs that appear in llm_requests
let used_pairs_result = sqlx::query(
"SELECT DISTINCT provider, model FROM llm_requests",
)
.fetch_all(pool)
.await;
if let Ok(rows) = used_pairs_result {
for row in rows {
let provider: String = row.get("provider");
let m_key: String = row.get("model");
let provider_name = match provider.as_str() {
"openai" => "OpenAI",
"gemini" => "Google Gemini",
"deepseek" => "DeepSeek",
"grok" => "xAI Grok",
"ollama" => "Ollama",
_ => provider.as_str(),
}.to_string();
let m_meta = registry.find_model(&m_key);
let mut enabled = true;
let mut prompt_cost = m_meta.and_then(|m| m.cost.as_ref().map(|c| c.input)).unwrap_or(0.0);
let mut completion_cost = m_meta.and_then(|m| m.cost.as_ref().map(|c| c.output)).unwrap_or(0.0);
let cache_read_cost = m_meta.and_then(|m| m.cost.as_ref().and_then(|c| c.cache_read));
let cache_write_cost = m_meta.and_then(|m| m.cost.as_ref().and_then(|c| c.cache_write));
let mut mapping = None::<String>;
if let Some(db_row) = db_models.get(&m_key) {
enabled = db_row.get("enabled");
if let Some(p) = db_row.get::<Option<f64>, _>("prompt_cost_per_m") {
prompt_cost = p;
}
if let Some(c) = db_row.get::<Option<f64>, _>("completion_cost_per_m") {
completion_cost = c;
}
mapping = db_row.get("mapping");
}
models_json.push(serde_json::json!({
"id": m_key,
"provider": provider,
"provider_name": provider_name,
"name": m_meta.map(|m| m.name.clone()).unwrap_or_else(|| m_key.clone()),
"enabled": enabled,
"prompt_cost": prompt_cost,
"completion_cost": completion_cost,
"cache_read_cost": cache_read_cost,
"cache_write_cost": cache_write_cost,
"mapping": mapping,
"context_limit": m_meta.and_then(|m| m.limit.as_ref().map(|l| l.context)).unwrap_or(0),
"output_limit": m_meta.and_then(|m| m.limit.as_ref().map(|l| l.output)).unwrap_or(0),
"modalities": m_meta.and_then(|m| m.modalities.as_ref().map(|mo| serde_json::json!({
"input": mo.input,
"output": mo.output,
}))),
"tool_call": m_meta.and_then(|m| m.tool_call),
"reasoning": m_meta.and_then(|m| m.reasoning),
}));
}
}
} else {
        // Registry path: list all registry models, applying the filter and sort from the query params
let filter = ModelFilter {
provider: params.provider,
search: params.search,
modality: params.modality,
tool_call: params.tool_call,
reasoning: params.reasoning,
has_cost: params.has_cost,
};
let sort_by = params.sort_by.unwrap_or_default();
let sort_order = params.sort_order.unwrap_or_default();
// Get filtered and sorted model entries
let entries = registry.list_models(&filter, &sort_by, &sort_order);
for entry in &entries {
let m_key = entry.model_key;
let m_meta = entry.metadata;
let mut enabled = true;
let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
let cache_read_cost = m_meta.cost.as_ref().and_then(|c| c.cache_read);
let cache_write_cost = m_meta.cost.as_ref().and_then(|c| c.cache_write);
let mut mapping = None::<String>;
if let Some(row) = db_models.get(m_key) {
enabled = row.get("enabled");
if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
prompt_cost = p;
}
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
completion_cost = c;
}
mapping = row.get("mapping");
}
models_json.push(serde_json::json!({
"id": m_key,
"provider": entry.provider_id,
"provider_name": entry.provider_name,
"name": m_meta.name,
"enabled": enabled,
"prompt_cost": prompt_cost,
"completion_cost": completion_cost,
"cache_read_cost": cache_read_cost,
"cache_write_cost": cache_write_cost,
"mapping": mapping,
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
"output_limit": m_meta.limit.as_ref().map(|l| l.output).unwrap_or(0),
"modalities": m_meta.modalities.as_ref().map(|m| serde_json::json!({
"input": m.input,
"output": m.output,
})),
"tool_call": m_meta.tool_call,
"reasoning": m_meta.reasoning,
}));
}
}
Json(ApiResponse::success(serde_json::json!(models_json)))
}
pub(super) async fn handle_update_model(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(id): Path<String>,
Json(payload): Json<UpdateModelRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Find provider_id for this model in registry
let provider_id = state
.app_state
.model_registry
.providers
.iter()
.find(|(_, p)| p.models.contains_key(&id))
.map(|(id, _)| id.clone())
.unwrap_or_else(|| "unknown".to_string());
let result = sqlx::query(
r#"
INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
enabled = excluded.enabled,
prompt_cost_per_m = excluded.prompt_cost_per_m,
completion_cost_per_m = excluded.completion_cost_per_m,
mapping = excluded.mapping,
updated_at = CURRENT_TIMESTAMP
"#,
)
.bind(&id)
.bind(provider_id)
.bind(payload.enabled)
.bind(payload.prompt_cost)
.bind(payload.completion_cost)
.bind(payload.mapping)
.execute(pool)
.await;
match result {
Ok(_) => {
// Invalidate the in-memory cache so the proxy picks up the change immediately
state.app_state.model_config_cache.invalidate().await;
Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" })))
}
Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))),
}
}
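The provider lookup in `handle_update_model` scans every registry provider for the one owning the model key. A minimal stdlib sketch of that reverse lookup (the `provider_for_model` helper and the map shape are illustrative stand-ins, not the actual registry types):

```rust
use std::collections::{HashMap, HashSet};

// Hypothetical registry shape: provider id -> set of model keys it owns.
// Mirrors the `.find(|(_, p)| p.models.contains_key(&id))` scan above.
fn provider_for_model(providers: &HashMap<String, HashSet<String>>, model: &str) -> String {
    providers
        .iter()
        .find(|(_, models)| models.contains(model))
        .map(|(id, _)| id.clone())
        .unwrap_or_else(|| "unknown".to_string())
}
```

Since `HashMap` iteration order is unspecified, if two providers ever claimed the same model key the winner would be arbitrary; the real registry presumably keeps model keys unique across providers.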


@@ -1,440 +0,0 @@
use axum::{
extract::{Path, State},
response::Json,
};
use serde::Deserialize;
use serde_json;
use sqlx::Row;
use std::collections::HashMap;
use tracing::warn;
use super::{ApiResponse, DashboardState};
use crate::utils::crypto;
#[derive(Deserialize)]
pub(super) struct UpdateProviderRequest {
pub(super) enabled: bool,
pub(super) base_url: Option<String>,
pub(super) api_key: Option<String>,
pub(super) credit_balance: Option<f64>,
pub(super) low_credit_threshold: Option<f64>,
pub(super) billing_mode: Option<String>,
}
pub(super) async fn handle_get_providers(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let registry = &state.app_state.model_registry;
let config = &state.app_state.config;
let pool = &state.app_state.db_pool;
// Load all overrides from database (including billing_mode)
let db_configs_result = sqlx::query(
"SELECT id, enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs",
)
.fetch_all(pool)
.await;
let mut db_configs = HashMap::new();
if let Ok(rows) = db_configs_result {
for row in rows {
let id: String = row.get("id");
let enabled: bool = row.get("enabled");
let base_url: Option<String> = row.get("base_url");
let balance: f64 = row.get("credit_balance");
let threshold: f64 = row.get("low_credit_threshold");
let billing_mode: Option<String> = row.get("billing_mode");
db_configs.insert(id, (enabled, base_url, balance, threshold, billing_mode));
}
}
let mut providers_json = Vec::new();
// Define the list of providers we support
let provider_ids = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
for id in provider_ids {
// Get base config
let (mut enabled, mut base_url, display_name) = match id {
"openai" => (
config.providers.openai.enabled,
config.providers.openai.base_url.clone(),
"OpenAI",
),
"gemini" => (
config.providers.gemini.enabled,
config.providers.gemini.base_url.clone(),
"Google Gemini",
),
"deepseek" => (
config.providers.deepseek.enabled,
config.providers.deepseek.base_url.clone(),
"DeepSeek",
),
"grok" => (
config.providers.grok.enabled,
config.providers.grok.base_url.clone(),
"xAI Grok",
),
"ollama" => (
config.providers.ollama.enabled,
config.providers.ollama.base_url.clone(),
"Ollama",
),
_ => (false, "".to_string(), "Unknown"),
};
let mut balance = 0.0;
let mut threshold = 5.0;
let mut billing_mode: Option<String> = None;
// Apply database overrides
if let Some((db_enabled, db_url, db_balance, db_threshold, db_billing)) = db_configs.get(id) {
enabled = *db_enabled;
if let Some(url) = db_url {
base_url = url.clone();
}
balance = *db_balance;
threshold = *db_threshold;
billing_mode = db_billing.clone();
}
// Find models for this provider in registry
// NOTE: registry provider IDs differ from internal IDs for some providers.
let registry_key = match id {
"gemini" => "google",
"grok" => "xai",
_ => id,
};
let mut models = Vec::new();
if let Some(p_info) = registry.providers.get(registry_key) {
models = p_info.models.keys().cloned().collect();
} else if id == "ollama" {
models = config.providers.ollama.models.clone();
}
// Determine status
let status = if !enabled {
"disabled"
} else {
// Check if it's actually initialized in the provider manager
if state.app_state.provider_manager.get_provider(id).await.is_some() {
// Check circuit breaker
if state
.app_state
.rate_limit_manager
.check_provider_request(id)
.await
.unwrap_or(true)
{
"online"
} else {
"degraded"
}
} else {
"error" // Enabled but failed to initialize (e.g. missing API key)
}
};
providers_json.push(serde_json::json!({
"id": id,
"name": display_name,
"enabled": enabled,
"status": status,
"models": models,
"base_url": base_url,
"credit_balance": balance,
"low_credit_threshold": threshold,
"billing_mode": billing_mode,
"last_used": None::<String>,
}));
}
Json(ApiResponse::success(serde_json::json!(providers_json)))
}
pub(super) async fn handle_get_provider(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(name): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let registry = &state.app_state.model_registry;
let config = &state.app_state.config;
let pool = &state.app_state.db_pool;
// Validate provider name
let (mut enabled, mut base_url, display_name) = match name.as_str() {
"openai" => (
config.providers.openai.enabled,
config.providers.openai.base_url.clone(),
"OpenAI",
),
"gemini" => (
config.providers.gemini.enabled,
config.providers.gemini.base_url.clone(),
"Google Gemini",
),
"deepseek" => (
config.providers.deepseek.enabled,
config.providers.deepseek.base_url.clone(),
"DeepSeek",
),
"grok" => (
config.providers.grok.enabled,
config.providers.grok.base_url.clone(),
"xAI Grok",
),
"ollama" => (
config.providers.ollama.enabled,
config.providers.ollama.base_url.clone(),
"Ollama",
),
_ => return Json(ApiResponse::error(format!("Unknown provider '{}'", name))),
};
let mut balance = 0.0;
let mut threshold = 5.0;
let mut billing_mode: Option<String> = None;
// Apply database overrides
let db_config = sqlx::query(
"SELECT enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs WHERE id = ?",
)
.bind(&name)
.fetch_optional(pool)
.await;
if let Ok(Some(row)) = db_config {
enabled = row.get::<bool, _>("enabled");
if let Some(url) = row.get::<Option<String>, _>("base_url") {
base_url = url;
}
balance = row.get::<f64, _>("credit_balance");
threshold = row.get::<f64, _>("low_credit_threshold");
billing_mode = row.get::<Option<String>, _>("billing_mode");
}
// Find models for this provider
// NOTE: registry provider IDs differ from internal IDs for some providers.
let registry_key = match name.as_str() {
"gemini" => "google",
"grok" => "xai",
_ => name.as_str(),
};
let mut models = Vec::new();
if let Some(p_info) = registry.providers.get(registry_key) {
models = p_info.models.keys().cloned().collect();
} else if name == "ollama" {
models = config.providers.ollama.models.clone();
}
// Determine status
let status = if !enabled {
"disabled"
} else if state.app_state.provider_manager.get_provider(&name).await.is_some() {
if state
.app_state
.rate_limit_manager
.check_provider_request(&name)
.await
.unwrap_or(true)
{
"online"
} else {
"degraded"
}
} else {
"error"
};
Json(ApiResponse::success(serde_json::json!({
"id": name,
"name": display_name,
"enabled": enabled,
"status": status,
"models": models,
"base_url": base_url,
"credit_balance": balance,
"low_credit_threshold": threshold,
"billing_mode": billing_mode,
"last_used": None::<String>,
})))
}
pub(super) async fn handle_update_provider(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(name): Path<String>,
Json(payload): Json<UpdateProviderRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Prepare API key encryption if provided
let (api_key_to_store, api_key_encrypted_flag) = match &payload.api_key {
Some(key) if !key.is_empty() => {
match crypto::encrypt(key) {
Ok(encrypted) => (Some(encrypted), Some(true)),
Err(e) => {
warn!("Failed to encrypt API key for provider {}: {}", name, e);
return Json(ApiResponse::error(format!("Failed to encrypt API key: {}", e)));
}
}
}
Some(_) => {
// Empty string means clear the key
(None, Some(false))
}
None => {
// Keep existing key, we'll rely on COALESCE in SQL
(None, None)
}
};
// Update or insert into database (include billing_mode and api_key_encrypted)
let result = sqlx::query(
r#"
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold, billing_mode)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
enabled = excluded.enabled,
base_url = excluded.base_url,
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
api_key_encrypted = COALESCE(excluded.api_key_encrypted, provider_configs.api_key_encrypted),
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
updated_at = CURRENT_TIMESTAMP
"#,
)
.bind(&name)
.bind(name.to_uppercase())
.bind(payload.enabled)
.bind(&payload.base_url)
.bind(&api_key_to_store)
.bind(api_key_encrypted_flag)
.bind(payload.credit_balance)
.bind(payload.low_credit_threshold)
.bind(payload.billing_mode)
.execute(pool)
.await;
match result {
Ok(_) => {
// Re-initialize provider in manager
if let Err(e) = state
.app_state
.provider_manager
.initialize_provider(&name, &state.app_state.config, &state.app_state.db_pool)
.await
{
warn!("Failed to re-initialize provider {}: {}", name, e);
return Json(ApiResponse::error(format!(
"Provider settings saved but initialization failed: {}",
e
)));
}
Json(ApiResponse::success(
serde_json::json!({ "message": "Provider updated and re-initialized" }),
))
}
Err(e) => {
warn!("Failed to update provider config: {}", e);
Json(ApiResponse::error(format!("Failed to update provider: {}", e)))
}
}
}
pub(super) async fn handle_test_provider(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(name): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let start = std::time::Instant::now();
let provider = match state.app_state.provider_manager.get_provider(&name).await {
Some(p) => p,
None => {
return Json(ApiResponse::error(format!(
"Provider '{}' not found or not enabled",
name
)));
}
};
// Pick a real model for this provider from the registry
// NOTE: registry provider IDs differ from internal IDs for some providers.
let registry_key = match name.as_str() {
"gemini" => "google",
"grok" => "xai",
_ => name.as_str(),
};
let test_model = state
.app_state
.model_registry
.providers
.get(registry_key)
.and_then(|p| p.models.keys().next().cloned())
.unwrap_or_else(|| name.clone());
let test_request = crate::models::UnifiedRequest {
client_id: "system-test".to_string(),
model: test_model,
messages: vec![crate::models::UnifiedMessage {
role: "user".to_string(),
content: vec![crate::models::ContentPart::Text { text: "Hi".to_string() }],
reasoning_content: None,
tool_calls: None,
name: None,
tool_call_id: None,
}],
temperature: None,
top_p: None,
top_k: None,
n: None,
stop: None,
max_tokens: Some(5),
presence_penalty: None,
frequency_penalty: None,
stream: false,
has_images: false,
tools: None,
tool_choice: None,
};
match provider.chat_completion(test_request).await {
Ok(_) => {
let latency = start.elapsed().as_millis();
Json(ApiResponse::success(serde_json::json!({
"success": true,
"latency": latency,
"message": "Connection test successful"
})))
}
Err(e) => Json(ApiResponse::error(format!("Provider test failed: {}", e))),
}
}
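The gemini-to-google / grok-to-xai remapping recurs in three handlers above; a small helper could centralize it (the name `registry_key` is a suggestion, not existing code):

```rust
/// Hypothetical helper: map internal provider ids to the registry's ids.
/// The registry uses "google" and "xai" where the proxy config uses
/// "gemini" and "grok"; everything else passes through unchanged.
fn registry_key(internal: &str) -> &str {
    match internal {
        "gemini" => "google",
        "grok" => "xai",
        other => other,
    }
}
```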


@@ -1,311 +0,0 @@
use chrono::{DateTime, Duration, Utc};
use hmac::{Hmac, Mac};
use serde::{Deserialize, Serialize};
use sha2::{Sha256, digest::generic_array::GenericArray};
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
use base64::{engine::general_purpose::URL_SAFE, Engine as _};
const TOKEN_VERSION: &str = "v2";
const REFRESH_WINDOW_MINUTES: i64 = 15; // refresh if token expires within 15 minutes
#[derive(Clone, Debug)]
pub struct Session {
pub username: String,
pub role: String,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
pub session_id: String, // unique identifier for the session (UUID)
}
#[derive(Clone)]
pub struct SessionManager {
sessions: Arc<RwLock<HashMap<String, Session>>>, // key = session_id
ttl_hours: i64,
secret: Vec<u8>,
}
#[derive(Debug, Serialize, Deserialize)]
struct SessionPayload {
session_id: String,
username: String,
role: String,
iat: i64, // issued at (Unix timestamp)
exp: i64, // expiry (Unix timestamp)
version: String,
}
impl SessionManager {
pub fn new(ttl_hours: i64) -> Self {
let secret = load_session_secret();
Self {
sessions: Arc::new(RwLock::new(HashMap::new())),
ttl_hours,
secret,
}
}
/// Create a new session and return a signed session token.
pub async fn create_session(&self, username: String, role: String) -> String {
let session_id = Uuid::new_v4().to_string();
let now = Utc::now();
let expires_at = now + Duration::hours(self.ttl_hours);
let session = Session {
username: username.clone(),
role: role.clone(),
created_at: now,
expires_at,
session_id: session_id.clone(),
};
// Store session by session_id
self.sessions.write().await.insert(session_id.clone(), session);
// Create signed token
self.create_signed_token(&session_id, &username, &role, now.timestamp(), expires_at.timestamp())
}
/// Validate a session token and return the session if valid and not expired.
/// If the token is within the refresh window, returns a new token as well.
pub async fn validate_session(&self, token: &str) -> Option<Session> {
self.validate_session_with_refresh(token).await.map(|(session, _)| session)
}
/// Validate a session token and return (session, optional new token if refreshed).
pub async fn validate_session_with_refresh(&self, token: &str) -> Option<(Session, Option<String>)> {
// Legacy token format (UUID)
if token.starts_with("session-") {
let sessions = self.sessions.read().await;
return sessions.get(token).and_then(|s| {
if s.expires_at > Utc::now() {
Some((s.clone(), None))
} else {
None
}
});
}
// Signed token format
let payload = match verify_signed_token(token, &self.secret) {
Ok(p) => p,
Err(_) => return None,
};
// Check expiry
let now = Utc::now().timestamp();
if payload.exp <= now {
return None;
}
// Look up session by session_id
let sessions = self.sessions.read().await;
let session = match sessions.get(&payload.session_id) {
Some(s) => s.clone(),
None => return None, // session revoked or not found
};
// Ensure session username/role matches (should always match)
if session.username != payload.username || session.role != payload.role {
return None;
}
// Check if token is within refresh window (last REFRESH_WINDOW_MINUTES of validity)
let refresh_threshold = payload.exp - REFRESH_WINDOW_MINUTES * 60;
let new_token = if now >= refresh_threshold {
// Activity-based (sliding-window) refresh: extend the session expiry
// from now by ttl_hours and issue a new token with updated iat/exp.
let new_exp = Utc::now() + Duration::hours(self.ttl_hours);
// Update session expiry in store
drop(sessions); // release read lock before acquiring write lock
self.update_session_expiry(&payload.session_id, new_exp).await;
// Create new token with updated iat/exp
let new_token = self.create_signed_token(
&payload.session_id,
&payload.username,
&payload.role,
now,
new_exp.timestamp(),
);
Some(new_token)
} else {
None
};
Some((session, new_token))
}
/// Revoke (delete) a session by token.
/// Supports both legacy tokens (token is key) and signed tokens (extract session_id).
pub async fn revoke_session(&self, token: &str) {
if token.starts_with("session-") {
self.sessions.write().await.remove(token);
return;
}
// For signed token, try to extract session_id
if let Ok(payload) = verify_signed_token(token, &self.secret) {
self.sessions.write().await.remove(&payload.session_id);
}
}
/// Remove all expired sessions from the store.
pub async fn cleanup_expired(&self) {
let now = Utc::now();
self.sessions.write().await.retain(|_, s| s.expires_at > now);
}
// --- Private helpers ---
fn create_signed_token(&self, session_id: &str, username: &str, role: &str, iat: i64, exp: i64) -> String {
let payload = SessionPayload {
session_id: session_id.to_string(),
username: username.to_string(),
role: role.to_string(),
iat,
exp,
version: TOKEN_VERSION.to_string(),
};
sign_token(&payload, &self.secret)
}
async fn update_session_expiry(&self, session_id: &str, new_expires_at: DateTime<Utc>) {
let mut sessions = self.sessions.write().await;
if let Some(session) = sessions.get_mut(session_id) {
session.expires_at = new_expires_at;
}
}
}
/// Load session secret from environment variable SESSION_SECRET (hex or base64 encoded).
/// If not set, generates a random 32-byte secret and logs a warning.
fn load_session_secret() -> Vec<u8> {
let secret_str = env::var("SESSION_SECRET").unwrap_or_else(|_| {
// Also check LLM_PROXY__SESSION_SECRET for consistency with config prefix
env::var("LLM_PROXY__SESSION_SECRET").unwrap_or_else(|_| {
// Generate a random secret (32 bytes) and encode as hex
use rand::RngCore;
let mut bytes = [0u8; 32];
rand::rng().fill_bytes(&mut bytes);
let hex_secret = hex::encode(bytes);
tracing::warn!(
"SESSION_SECRET environment variable not set. Using a randomly generated secret. \
This will invalidate all sessions on restart. Set SESSION_SECRET to a fixed hex or base64 encoded 32-byte value."
);
hex_secret
})
});
// Decode hex or base64
hex::decode(&secret_str)
.or_else(|_| URL_SAFE.decode(&secret_str))
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(&secret_str))
.unwrap_or_else(|_| {
panic!("SESSION_SECRET must be hex or base64 encoded (32 bytes)");
})
}
/// Sign a session payload and return a token string in format base64_url(payload).base64_url(signature).
fn sign_token(payload: &SessionPayload, secret: &[u8]) -> String {
let json = serde_json::to_vec(payload).expect("Failed to serialize payload");
let payload_b64 = URL_SAFE.encode(&json);
let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
mac.update(&json);
let signature = mac.finalize().into_bytes();
let signature_b64 = URL_SAFE.encode(signature);
format!("{}.{}", payload_b64, signature_b64)
}
/// Verify a signed token and return the decoded payload if valid.
fn verify_signed_token(token: &str, secret: &[u8]) -> Result<SessionPayload, TokenError> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 2 {
return Err(TokenError::InvalidFormat);
}
let payload_b64 = parts[0];
let signature_b64 = parts[1];
let json = URL_SAFE.decode(payload_b64).map_err(|_| TokenError::InvalidFormat)?;
let signature = URL_SAFE.decode(signature_b64).map_err(|_| TokenError::InvalidFormat)?;
// Verify HMAC
let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
mac.update(&json);
// Convert signature slice to GenericArray; reject wrong-length signatures
// up front, since GenericArray::from_slice panics on a length mismatch.
if signature.len() != 32 {
    return Err(TokenError::InvalidSignature);
}
let tag = GenericArray::from_slice(&signature);
mac.verify(tag).map_err(|_| TokenError::InvalidSignature)?;
// Deserialize payload
let payload: SessionPayload = serde_json::from_slice(&json).map_err(|_| TokenError::InvalidPayload)?;
Ok(payload)
}
#[derive(Debug)]
enum TokenError {
InvalidFormat,
InvalidSignature,
InvalidPayload,
}
#[cfg(test)]
mod tests {
use super::*;
use std::env;
#[test]
fn test_sign_and_verify_token() {
let secret = b"test-secret-must-be-32-bytes-long!";
let payload = SessionPayload {
session_id: "test-session".to_string(),
username: "testuser".to_string(),
role: "user".to_string(),
iat: 1000,
exp: 2000,
version: TOKEN_VERSION.to_string(),
};
let token = sign_token(&payload, secret);
let verified = verify_signed_token(&token, secret).unwrap();
assert_eq!(verified.session_id, payload.session_id);
assert_eq!(verified.username, payload.username);
assert_eq!(verified.role, payload.role);
assert_eq!(verified.iat, payload.iat);
assert_eq!(verified.exp, payload.exp);
assert_eq!(verified.version, payload.version);
}
#[test]
fn test_tampered_token() {
let secret = b"test-secret-must-be-32-bytes-long!";
let payload = SessionPayload {
session_id: "test-session".to_string(),
username: "testuser".to_string(),
role: "user".to_string(),
iat: 1000,
exp: 2000,
version: TOKEN_VERSION.to_string(),
};
let mut token = sign_token(&payload, secret);
// Tamper with payload part
let mut parts: Vec<&str> = token.split('.').collect();
let mut payload_bytes = URL_SAFE.decode(parts[0]).unwrap();
payload_bytes[0] ^= 0xFF; // flip some bits
let tampered_payload = URL_SAFE.encode(payload_bytes);
parts[0] = &tampered_payload;
token = parts.join(".");
assert!(verify_signed_token(&token, secret).is_err());
}
#[test]
fn test_load_session_secret_from_env() {
unsafe {
env::set_var("SESSION_SECRET", hex::encode([0xAA; 32]));
}
let secret = load_session_secret();
assert_eq!(secret, vec![0xAA; 32]);
unsafe {
env::remove_var("SESSION_SECRET");
}
}
}
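The sliding-window refresh in `validate_session_with_refresh` reduces to a single timestamp comparison: a token is refreshed only inside its final `REFRESH_WINDOW_MINUTES` of validity. A minimal sketch of that check (standalone names, not the module's API):

```rust
const REFRESH_WINDOW_SECS: i64 = 15 * 60; // mirrors REFRESH_WINDOW_MINUTES above

// True when `now` has entered the token's final refresh window before `exp`
// (both are Unix timestamps in seconds).
fn needs_refresh(now: i64, exp: i64) -> bool {
    now >= exp - REFRESH_WINDOW_SECS
}
```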


@@ -1,405 +0,0 @@
use axum::{extract::State, response::Json};
use chrono;
use serde_json;
use sqlx::Row;
use std::collections::HashMap;
use tracing::warn;
use super::{ApiResponse, DashboardState};
/// Read a value from /proc files, returning None on any failure.
fn read_proc_file(path: &str) -> Option<String> {
std::fs::read_to_string(path).ok()
}
pub(super) async fn handle_system_health(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let mut components = HashMap::new();
components.insert("api_server".to_string(), "online".to_string());
components.insert("database".to_string(), "online".to_string());
// Check provider health via circuit breakers
let provider_ids: Vec<String> = state
.app_state
.provider_manager
.get_all_providers()
.await
.iter()
.map(|p| p.name().to_string())
.collect();
for p_id in provider_ids {
if state
.app_state
.rate_limit_manager
.check_provider_request(&p_id)
.await
.unwrap_or(true)
{
components.insert(p_id, "online".to_string());
} else {
components.insert(p_id, "degraded".to_string());
}
}
// Read real memory usage from /proc/self/status
let memory_mb = read_proc_file("/proc/self/status")
.and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string()))
.and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::<f64>().ok()))
.map(|kb| kb / 1024.0)
.unwrap_or(0.0);
// Get real database pool stats
let db_pool_size = state.app_state.db_pool.size();
let db_pool_idle = state.app_state.db_pool.num_idle();
Json(ApiResponse::success(serde_json::json!({
"status": "healthy",
"timestamp": chrono::Utc::now().to_rfc3339(),
"components": components,
"metrics": {
"memory_usage_mb": (memory_mb * 10.0).round() / 10.0,
"db_connections_active": db_pool_size - db_pool_idle as u32,
"db_connections_idle": db_pool_idle,
}
})))
}
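The VmRSS extraction used by both the health and metrics handlers is the same three-step string parse; factored as a standalone function (hypothetical name, same logic) it would look like:

```rust
// Extract VmRSS (resident set size) from /proc/self/status content,
// converting the kB value Linux reports into MiB.
fn vm_rss_mb(status: &str) -> Option<f64> {
    let line = status.lines().find(|l| l.starts_with("VmRSS:"))?;
    let kb: f64 = line.split_whitespace().nth(1)?.parse().ok()?;
    Some(kb / 1024.0)
}
```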
/// Real system metrics from /proc (Linux only).
pub(super) async fn handle_system_metrics(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
// --- CPU usage (aggregate across all cores) ---
// /proc/stat first line: cpu user nice system idle iowait irq softirq steal guest guest_nice
let cpu_percent = read_proc_file("/proc/stat")
.and_then(|s| {
let line = s.lines().find(|l| l.starts_with("cpu "))?.to_string();
let fields: Vec<u64> = line
.split_whitespace()
.skip(1)
.filter_map(|v| v.parse().ok())
.collect();
if fields.len() >= 4 {
let idle = fields[3];
let total: u64 = fields.iter().sum();
if total > 0 {
Some(((total - idle) as f64 / total as f64 * 100.0 * 10.0).round() / 10.0)
} else {
None
}
} else {
None
}
})
.unwrap_or(0.0);
// --- Memory (system-wide from /proc/meminfo) ---
let meminfo = read_proc_file("/proc/meminfo").unwrap_or_default();
let parse_meminfo = |key: &str| -> u64 {
meminfo
.lines()
.find(|l| l.starts_with(key))
.and_then(|l| l.split_whitespace().nth(1))
.and_then(|v| v.parse::<u64>().ok())
.unwrap_or(0)
};
let mem_total_kb = parse_meminfo("MemTotal:");
let mem_available_kb = parse_meminfo("MemAvailable:");
let mem_used_kb = mem_total_kb.saturating_sub(mem_available_kb);
let mem_total_mb = mem_total_kb as f64 / 1024.0;
let mem_used_mb = mem_used_kb as f64 / 1024.0;
let mem_percent = if mem_total_kb > 0 {
(mem_used_kb as f64 / mem_total_kb as f64 * 100.0 * 10.0).round() / 10.0
} else {
0.0
};
// --- Process-specific memory (VmRSS) ---
let process_rss_mb = read_proc_file("/proc/self/status")
.and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string()))
.and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::<f64>().ok()))
.map(|kb| (kb / 1024.0 * 10.0).round() / 10.0)
.unwrap_or(0.0);
// --- Disk usage of the data directory ---
let (disk_total_gb, disk_used_gb, disk_percent) = {
// statvfs via libc would be ideal; use df as a simple fallback
std::process::Command::new("df")
.args(["-BM", "--output=size,used,pcent", "."])
.output()
.ok()
.and_then(|o| {
let out = String::from_utf8_lossy(&o.stdout);
let line = out.lines().nth(1)?.to_string();
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() >= 3 {
let total = parts[0].trim_end_matches('M').parse::<f64>().unwrap_or(0.0) / 1024.0;
let used = parts[1].trim_end_matches('M').parse::<f64>().unwrap_or(0.0) / 1024.0;
let pct = parts[2].trim_end_matches('%').parse::<f64>().unwrap_or(0.0);
Some(((total * 10.0).round() / 10.0, (used * 10.0).round() / 10.0, pct))
} else {
None
}
})
.unwrap_or((0.0, 0.0, 0.0))
};
// --- Uptime ---
let uptime_seconds = read_proc_file("/proc/uptime")
.and_then(|s| s.split_whitespace().next().and_then(|v| v.parse::<f64>().ok()))
.unwrap_or(0.0) as u64;
// --- Load average ---
let (load_1, load_5, load_15) = read_proc_file("/proc/loadavg")
.and_then(|s| {
let parts: Vec<&str> = s.split_whitespace().collect();
if parts.len() >= 3 {
Some((
parts[0].parse::<f64>().unwrap_or(0.0),
parts[1].parse::<f64>().unwrap_or(0.0),
parts[2].parse::<f64>().unwrap_or(0.0),
))
} else {
None
}
})
.unwrap_or((0.0, 0.0, 0.0));
// --- Network (from /proc/net/dev, aggregate non-lo interfaces) ---
let (net_rx_bytes, net_tx_bytes) = read_proc_file("/proc/net/dev")
.map(|s| {
s.lines()
.skip(2) // skip header lines
.filter(|l| !l.trim().starts_with("lo:"))
.fold((0u64, 0u64), |(rx, tx), line| {
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() >= 10 {
let r = parts[1].parse::<u64>().unwrap_or(0);
let t = parts[9].parse::<u64>().unwrap_or(0);
(rx + r, tx + t)
} else {
(rx, tx)
}
})
})
.unwrap_or((0, 0));
// --- Database pool ---
let db_pool_size = state.app_state.db_pool.size();
let db_pool_idle = state.app_state.db_pool.num_idle();
// --- Active WebSocket listeners ---
let ws_listeners = state.app_state.dashboard_tx.receiver_count();
Json(ApiResponse::success(serde_json::json!({
"cpu": {
"usage_percent": cpu_percent,
"load_average": [load_1, load_5, load_15],
},
"memory": {
"total_mb": (mem_total_mb * 10.0).round() / 10.0,
"used_mb": (mem_used_mb * 10.0).round() / 10.0,
"usage_percent": mem_percent,
"process_rss_mb": process_rss_mb,
},
"disk": {
"total_gb": disk_total_gb,
"used_gb": disk_used_gb,
"usage_percent": disk_percent,
},
"network": {
"rx_bytes": net_rx_bytes,
"tx_bytes": net_tx_bytes,
},
"uptime_seconds": uptime_seconds,
"connections": {
"db_active": db_pool_size - db_pool_idle as u32,
"db_idle": db_pool_idle,
"websocket_listeners": ws_listeners,
},
"timestamp": chrono::Utc::now().to_rfc3339(),
})))
}
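The /proc/stat arithmetic above can be isolated into a pure function for testing (hypothetical `cpu_busy_percent`, same formula). Note that a single snapshot yields the busy fraction since boot; an instantaneous rate would need two samples and a delta of the counters:

```rust
// Parse the aggregate "cpu " line of /proc/stat and return the busy
// percentage since boot, rounded to one decimal place.
fn cpu_busy_percent(stat: &str) -> Option<f64> {
    let line = stat.lines().find(|l| l.starts_with("cpu "))?;
    let fields: Vec<u64> = line
        .split_whitespace()
        .skip(1)
        .filter_map(|v| v.parse().ok())
        .collect();
    let idle = *fields.get(3)?; // 4th field is idle jiffies
    let total: u64 = fields.iter().sum();
    if total == 0 {
        return None;
    }
    Some(((total - idle) as f64 / total as f64 * 100.0 * 10.0).round() / 10.0)
}
```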
pub(super) async fn handle_system_logs(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
id,
timestamp,
client_id,
provider,
model,
prompt_tokens,
completion_tokens,
reasoning_tokens,
total_tokens,
cache_read_tokens,
cache_write_tokens,
cost,
status,
error_message,
duration_ms
FROM llm_requests
ORDER BY timestamp DESC
LIMIT 100
"#,
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let logs: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
serde_json::json!({
"id": row.get::<i64, _>("id"),
"timestamp": row.get::<chrono::DateTime<chrono::Utc>, _>("timestamp"),
"client_id": row.get::<String, _>("client_id"),
"provider": row.get::<String, _>("provider"),
"model": row.get::<String, _>("model"),
"prompt_tokens": row.get::<i64, _>("prompt_tokens"),
"completion_tokens": row.get::<i64, _>("completion_tokens"),
"reasoning_tokens": row.get::<i64, _>("reasoning_tokens"),
"cache_read_tokens": row.get::<i64, _>("cache_read_tokens"),
"cache_write_tokens": row.get::<i64, _>("cache_write_tokens"),
"tokens": row.get::<i64, _>("total_tokens"),
"cost": row.get::<f64, _>("cost"),
"status": row.get::<String, _>("status"),
"error": row.get::<Option<String>, _>("error_message"),
"duration": row.get::<i64, _>("duration_ms"),
})
})
.collect();
Json(ApiResponse::success(serde_json::json!(logs)))
}
Err(e) => {
warn!("Failed to fetch system logs: {}", e);
Json(ApiResponse::error("Failed to fetch system logs".to_string()))
}
}
}
pub(super) async fn handle_system_backup(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let backup_id = format!("backup-{}", chrono::Utc::now().timestamp());
let backup_path = format!("data/{}.db", backup_id);
// Ensure the data directory exists
if let Err(e) = std::fs::create_dir_all("data") {
return Json(ApiResponse::error(format!("Failed to create backup directory: {}", e)));
}
// Use SQLite VACUUM INTO for a consistent backup
let result = sqlx::query(&format!("VACUUM INTO '{}'", backup_path))
.execute(pool)
.await;
match result {
Ok(_) => {
// Get backup file size
let size_bytes = std::fs::metadata(&backup_path).map(|m| m.len()).unwrap_or(0);
Json(ApiResponse::success(serde_json::json!({
"success": true,
"message": "Backup completed successfully",
"backup_id": backup_id,
"backup_path": backup_path,
"size_bytes": size_bytes,
})))
}
Err(e) => {
warn!("Database backup failed: {}", e);
Json(ApiResponse::error(format!("Backup failed: {}", e)))
}
}
}
pub(super) async fn handle_get_settings(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let registry = &state.app_state.model_registry;
let provider_count = registry.providers.len();
let model_count: usize = registry.providers.values().map(|p| p.models.len()).sum();
Json(ApiResponse::success(serde_json::json!({
"server": {
"auth_tokens": state.app_state.auth_tokens.iter().map(|t| mask_token(t)).collect::<Vec<_>>(),
"version": env!("CARGO_PKG_VERSION"),
},
"registry": {
"provider_count": provider_count,
"model_count": model_count,
},
"database": {
"type": "SQLite",
}
})))
}
pub(super) async fn handle_update_settings(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
Json(ApiResponse::error(
"Changing settings at runtime is not yet supported. Please update your config file and restart the server."
.to_string(),
))
}
// Helper functions
fn mask_token(token: &str) -> String {
if token.len() <= 8 {
return "*****".to_string();
}
let masked_len = token.len().min(12);
let visible_len = 4;
let mask_len = masked_len - visible_len;
format!("{}{}", "*".repeat(mask_len), &token[token.len() - visible_len..])
}
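As a sanity check, the masking rule in `mask_token` can be exercised in isolation; the following is a standalone re-sketch of the helper above (a copy for illustration, not an addition to the file):

```rust
// Re-sketch of the mask_token helper above: tokens of 8 chars or fewer
// are fully masked; longer tokens show their last 4 characters behind
// at most 8 asterisks (displayed length is capped at 12).
fn mask_token(token: &str) -> String {
    if token.len() <= 8 {
        return "*****".to_string();
    }
    let masked_len = token.len().min(12);
    let visible_len = 4;
    let mask_len = masked_len - visible_len;
    format!("{}{}", "*".repeat(mask_len), &token[token.len() - visible_len..])
}
```

For example, `mask_token("sk-test-abcdef123456")` yields `"********3456"`.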


@@ -1,526 +0,0 @@
use axum::{
extract::{Query, State},
response::Json,
};
use chrono;
use serde::Deserialize;
use serde_json;
use sqlx::Row;
use tracing::warn;
use super::{ApiResponse, DashboardState};
/// Query parameters for time-based filtering on usage endpoints.
#[derive(Debug, Deserialize, Default)]
pub(super) struct UsagePeriodFilter {
/// Preset period: "today", "24h", "7d", "30d", "all" (default: "all")
pub period: Option<String>,
/// Custom range start (ISO 8601, e.g. "2025-06-01T00:00:00Z")
pub from: Option<String>,
/// Custom range end (ISO 8601)
pub to: Option<String>,
}
impl UsagePeriodFilter {
/// Returns `(sql_fragment, bind_values)` for a WHERE clause.
/// The fragment is either empty (no filter) or " AND timestamp >= ? [AND timestamp <= ?]".
fn to_sql(&self) -> (String, Vec<String>) {
let period = self.period.as_deref().unwrap_or("all");
if period == "custom" {
let mut clause = String::new();
let mut binds = Vec::new();
if let Some(ref from) = self.from {
clause.push_str(" AND timestamp >= ?");
binds.push(from.clone());
}
if let Some(ref to) = self.to {
clause.push_str(" AND timestamp <= ?");
binds.push(to.clone());
}
return (clause, binds);
}
let now = chrono::Utc::now();
let cutoff = match period {
"today" => {
// Start of today (UTC)
let today = now.format("%Y-%m-%dT00:00:00Z").to_string();
Some(today)
}
"24h" => Some((now - chrono::Duration::hours(24)).to_rfc3339()),
"7d" => Some((now - chrono::Duration::days(7)).to_rfc3339()),
"30d" => Some((now - chrono::Duration::days(30)).to_rfc3339()),
_ => None, // "all" or unrecognized
};
match cutoff {
Some(ts) => (" AND timestamp >= ?".to_string(), vec![ts]),
None => (String::new(), vec![]),
}
}
/// Determine the time-series granularity label for grouping.
fn granularity(&self) -> &'static str {
match self.period.as_deref().unwrap_or("all") {
"today" | "24h" => "hour",
_ => "day",
}
}
}
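The clause shapes `to_sql` can produce are easiest to see in condensed form; the sketch below mirrors the branching above (a simplification that omits the chrono cutoff computation and counts binds instead of building them — illustrative only, not part of the file):

```rust
// Condensed mirror of UsagePeriodFilter::to_sql's branching: presets map
// to one "timestamp >= ?" bind, "custom" yields zero to two binds, and
// "all"/unknown periods add no filter at all.
fn clause_for(period: &str, from: Option<&str>, to: Option<&str>) -> (String, usize) {
    match period {
        "custom" => {
            let mut clause = String::new();
            let mut binds = 0;
            if from.is_some() {
                clause.push_str(" AND timestamp >= ?");
                binds += 1;
            }
            if to.is_some() {
                clause.push_str(" AND timestamp <= ?");
                binds += 1;
            }
            (clause, binds)
        }
        "today" | "24h" | "7d" | "30d" => (" AND timestamp >= ?".to_string(), 1),
        _ => (String::new(), 0),
    }
}
```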
pub(super) async fn handle_usage_summary(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let (period_clause, period_binds) = filter.to_sql();
// Total stats (filtered by period)
let period_sql = format!(
r#"
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(total_tokens), 0) as total_tokens,
COALESCE(SUM(cost), 0.0) as total_cost,
COUNT(DISTINCT client_id) as active_clients,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read,
COALESCE(SUM(cache_write_tokens), 0) as total_cache_write
FROM llm_requests
WHERE 1=1 {}
"#,
period_clause
);
let mut q = sqlx::query(&period_sql);
for b in &period_binds {
q = q.bind(b);
}
let total_stats = q.fetch_one(pool);
// Today's stats
let today = chrono::Utc::now().format("%Y-%m-%d").to_string();
let today_stats = sqlx::query(
r#"
SELECT
COUNT(*) as today_requests,
COALESCE(SUM(total_tokens), 0) as today_tokens,
COALESCE(SUM(cost), 0.0) as today_cost
FROM llm_requests
WHERE strftime('%Y-%m-%d', timestamp) = ?
"#,
)
.bind(today)
.fetch_one(pool);
// Error stats
let error_stats = sqlx::query(
r#"
SELECT
COUNT(*) as total,
SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as errors
FROM llm_requests
"#,
)
.fetch_one(pool);
// Average response time
let avg_response = sqlx::query(
r#"
SELECT COALESCE(AVG(duration_ms), 0.0) as avg_duration
FROM llm_requests
WHERE status = 'success'
"#,
)
.fetch_one(pool);
let results = tokio::join!(total_stats, today_stats, error_stats, avg_response);
match results {
(Ok(t), Ok(d), Ok(e), Ok(a)) => {
let total_requests: i64 = t.get("total_requests");
let total_tokens: i64 = t.get("total_tokens");
let total_cost: f64 = t.get("total_cost");
let active_clients: i64 = t.get("active_clients");
let total_cache_read: i64 = t.get("total_cache_read");
let total_cache_write: i64 = t.get("total_cache_write");
let today_requests: i64 = d.get("today_requests");
let today_cost: f64 = d.get("today_cost");
let total_count: i64 = e.get("total");
let error_count: i64 = e.get("errors");
let error_rate = if total_count > 0 {
(error_count as f64 / total_count as f64) * 100.0
} else {
0.0
};
let avg_response_time: f64 = a.get("avg_duration");
Json(ApiResponse::success(serde_json::json!({
"total_requests": total_requests,
"total_tokens": total_tokens,
"total_cost": total_cost,
"active_clients": active_clients,
"today_requests": today_requests,
"today_cost": today_cost,
"error_rate": error_rate,
"avg_response_time": avg_response_time,
"total_cache_read_tokens": total_cache_read,
"total_cache_write_tokens": total_cache_write,
})))
}
(t_res, d_res, e_res, a_res) => {
if let Err(e) = t_res { warn!("Total stats query failed: {}", e); }
if let Err(e) = d_res { warn!("Today stats query failed: {}", e); }
if let Err(e) = e_res { warn!("Error stats query failed: {}", e); }
if let Err(e) = a_res { warn!("Avg response query failed: {}", e); }
Json(ApiResponse::error("Failed to fetch usage statistics".to_string()))
}
}
}
pub(super) async fn handle_time_series(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let (period_clause, period_binds) = filter.to_sql();
let granularity = filter.granularity();
// Determine the strftime format and default lookback
let (strftime_fmt, default_lookback) = match granularity {
"hour" => ("%H:00", chrono::Duration::hours(24)),
_ => ("%Y-%m-%d", chrono::Duration::days(30)),
};
// If no period filter, apply a sensible default lookback
let (clause, binds) = if period_clause.is_empty() {
let cutoff = (chrono::Utc::now() - default_lookback).to_rfc3339();
(" AND timestamp >= ?".to_string(), vec![cutoff])
} else {
(period_clause, period_binds)
};
let sql = format!(
r#"
SELECT
strftime('{strftime_fmt}', timestamp) as bucket,
COUNT(*) as requests,
COALESCE(SUM(total_tokens), 0) as tokens,
COALESCE(SUM(cost), 0.0) as cost
FROM llm_requests
WHERE 1=1 {clause}
GROUP BY bucket
ORDER BY bucket
"#,
);
let mut q = sqlx::query(&sql);
for b in &binds {
q = q.bind(b);
}
let result = q.fetch_all(pool).await;
match result {
Ok(rows) => {
let mut series = Vec::new();
for row in rows {
let bucket: String = row.get("bucket");
let requests: i64 = row.get("requests");
let tokens: i64 = row.get("tokens");
let cost: f64 = row.get("cost");
series.push(serde_json::json!({
"time": bucket,
"requests": requests,
"tokens": tokens,
"cost": cost,
}));
}
Json(ApiResponse::success(serde_json::json!({
"series": series,
"period": filter.period.as_deref().unwrap_or("all"),
"granularity": granularity,
})))
}
Err(e) => {
warn!("Failed to fetch time series data: {}", e);
Json(ApiResponse::error("Failed to fetch time series data".to_string()))
}
}
}
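The bucketing rule used above reduces to a tiny lookup; a standalone re-sketch (illustrative only):

```rust
// Bucketing rule mirrored from handle_time_series above: short windows
// ("today", "24h") get hourly "%H:00" buckets, everything else gets
// daily "%Y-%m-%d" buckets.
fn bucket_format(period: &str) -> (&'static str, &'static str) {
    match period {
        "today" | "24h" => ("%H:00", "hour"),
        _ => ("%Y-%m-%d", "day"),
    }
}
```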
pub(super) async fn handle_clients_usage(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let (period_clause, period_binds) = filter.to_sql();
let sql = format!(
r#"
SELECT
client_id,
COUNT(*) as requests,
COALESCE(SUM(total_tokens), 0) as tokens,
COALESCE(SUM(cost), 0.0) as cost,
MAX(timestamp) as last_request
FROM llm_requests
WHERE 1=1 {}
GROUP BY client_id
ORDER BY requests DESC
"#,
period_clause
);
let mut q = sqlx::query(&sql);
for b in &period_binds {
q = q.bind(b);
}
let result = q.fetch_all(pool).await;
match result {
Ok(rows) => {
let mut client_usage = Vec::new();
for row in rows {
let client_id: String = row.get("client_id");
let requests: i64 = row.get("requests");
let tokens: i64 = row.get("tokens");
let cost: f64 = row.get("cost");
let last_request: Option<chrono::DateTime<chrono::Utc>> = row.get("last_request");
client_usage.push(serde_json::json!({
"client_id": client_id,
"client_name": client_id,
"requests": requests,
"tokens": tokens,
"cost": cost,
"last_request": last_request,
}));
}
Json(ApiResponse::success(serde_json::json!(client_usage)))
}
Err(e) => {
warn!("Failed to fetch client usage data: {}", e);
Json(ApiResponse::error("Failed to fetch client usage data".to_string()))
}
}
}
pub(super) async fn handle_providers_usage(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let (period_clause, period_binds) = filter.to_sql();
let sql = format!(
r#"
SELECT
provider,
COUNT(*) as requests,
COALESCE(SUM(total_tokens), 0) as tokens,
COALESCE(SUM(cost), 0.0) as cost,
COALESCE(SUM(cache_read_tokens), 0) as cache_read,
COALESCE(SUM(cache_write_tokens), 0) as cache_write
FROM llm_requests
WHERE 1=1 {}
GROUP BY provider
ORDER BY requests DESC
"#,
period_clause
);
let mut q = sqlx::query(&sql);
for b in &period_binds {
q = q.bind(b);
}
let result = q.fetch_all(pool).await;
match result {
Ok(rows) => {
let mut provider_usage = Vec::new();
for row in rows {
let provider: String = row.get("provider");
let requests: i64 = row.get("requests");
let tokens: i64 = row.get("tokens");
let cost: f64 = row.get("cost");
let cache_read: i64 = row.get("cache_read");
let cache_write: i64 = row.get("cache_write");
provider_usage.push(serde_json::json!({
"provider": provider,
"requests": requests,
"tokens": tokens,
"cost": cost,
"cache_read_tokens": cache_read,
"cache_write_tokens": cache_write,
}));
}
Json(ApiResponse::success(serde_json::json!(provider_usage)))
}
Err(e) => {
warn!("Failed to fetch provider usage data: {}", e);
Json(ApiResponse::error("Failed to fetch provider usage data".to_string()))
}
}
}
pub(super) async fn handle_detailed_usage(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let (period_clause, period_binds) = filter.to_sql();
let sql = format!(
r#"
SELECT
strftime('%Y-%m-%d', timestamp) as date,
client_id,
provider,
model,
COUNT(*) as requests,
COALESCE(SUM(total_tokens), 0) as tokens,
COALESCE(SUM(cost), 0.0) as cost,
COALESCE(SUM(cache_read_tokens), 0) as cache_read,
COALESCE(SUM(cache_write_tokens), 0) as cache_write
FROM llm_requests
WHERE 1=1 {}
GROUP BY date, client_id, provider, model
ORDER BY date DESC
LIMIT 200
"#,
period_clause
);
let mut q = sqlx::query(&sql);
for b in &period_binds {
q = q.bind(b);
}
let result = q.fetch_all(pool).await;
match result {
Ok(rows) => {
let usage: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
serde_json::json!({
"date": row.get::<String, _>("date"),
"client": row.get::<String, _>("client_id"),
"provider": row.get::<String, _>("provider"),
"model": row.get::<String, _>("model"),
"requests": row.get::<i64, _>("requests"),
"tokens": row.get::<i64, _>("tokens"),
"cost": row.get::<f64, _>("cost"),
"cache_read_tokens": row.get::<i64, _>("cache_read"),
"cache_write_tokens": row.get::<i64, _>("cache_write"),
})
})
.collect();
Json(ApiResponse::success(serde_json::json!(usage)))
}
Err(e) => {
warn!("Failed to fetch detailed usage: {}", e);
Json(ApiResponse::error("Failed to fetch detailed usage".to_string()))
}
}
}
pub(super) async fn handle_analytics_breakdown(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let (period_clause, period_binds) = filter.to_sql();
// Model breakdown
let model_sql = format!(
"SELECT model as label, COUNT(*) as value FROM llm_requests WHERE 1=1 {} GROUP BY model ORDER BY value DESC",
period_clause
);
let mut mq = sqlx::query(&model_sql);
for b in &period_binds {
mq = mq.bind(b);
}
let models = mq.fetch_all(pool);
// Client breakdown
let client_sql = format!(
"SELECT client_id as label, COUNT(*) as value FROM llm_requests WHERE 1=1 {} GROUP BY client_id ORDER BY value DESC",
period_clause
);
let mut cq = sqlx::query(&client_sql);
for b in &period_binds {
cq = cq.bind(b);
}
let clients = cq.fetch_all(pool);
match tokio::join!(models, clients) {
(Ok(m_rows), Ok(c_rows)) => {
let model_breakdown: Vec<serde_json::Value> = m_rows
.into_iter()
.map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
.collect();
let client_breakdown: Vec<serde_json::Value> = c_rows
.into_iter()
.map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
.collect();
Json(ApiResponse::success(serde_json::json!({
"models": model_breakdown,
"clients": client_breakdown
})))
}
_ => Json(ApiResponse::error("Failed to fetch analytics breakdown".to_string())),
}
}


@@ -1,290 +0,0 @@
use axum::{
extract::{Path, State},
response::Json,
};
use serde::Deserialize;
use sqlx::Row;
use tracing::warn;
use super::{ApiResponse, DashboardState, auth};
// ── User management endpoints (admin-only) ──────────────────────────
pub(super) async fn handle_get_users(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let result = sqlx::query(
"SELECT id, username, display_name, role, must_change_password, created_at FROM users ORDER BY created_at ASC",
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let users: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
let username: String = row.get("username");
let display_name: Option<String> = row.get("display_name");
serde_json::json!({
"id": row.get::<i64, _>("id"),
"username": &username,
"display_name": display_name.as_deref().unwrap_or(&username),
"role": row.get::<String, _>("role"),
"must_change_password": row.get::<bool, _>("must_change_password"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
})
})
.collect();
Json(ApiResponse::success(serde_json::json!(users)))
}
Err(e) => {
warn!("Failed to fetch users: {}", e);
Json(ApiResponse::error("Failed to fetch users".to_string()))
}
}
}
#[derive(Deserialize)]
pub(super) struct CreateUserRequest {
pub(super) username: String,
pub(super) password: String,
pub(super) display_name: Option<String>,
pub(super) role: Option<String>,
}
pub(super) async fn handle_create_user(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Json(payload): Json<CreateUserRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Validate role
let role = payload.role.as_deref().unwrap_or("viewer");
if role != "admin" && role != "viewer" {
return Json(ApiResponse::error("Role must be 'admin' or 'viewer'".to_string()));
}
// Validate username
let username = payload.username.trim();
if username.is_empty() || username.len() > 64 {
return Json(ApiResponse::error("Username must be 1-64 characters".to_string()));
}
// Validate password
if payload.password.len() < 4 {
return Json(ApiResponse::error("Password must be at least 4 characters".to_string()));
}
let password_hash = match bcrypt::hash(&payload.password, 12) {
Ok(h) => h,
Err(_) => return Json(ApiResponse::error("Failed to hash password".to_string())),
};
let result = sqlx::query(
r#"
INSERT INTO users (username, password_hash, display_name, role, must_change_password)
VALUES (?, ?, ?, ?, TRUE)
RETURNING id, username, display_name, role, must_change_password, created_at
"#,
)
.bind(username)
.bind(&password_hash)
.bind(&payload.display_name)
.bind(role)
.fetch_one(pool)
.await;
match result {
Ok(row) => {
let uname: String = row.get("username");
let display_name: Option<String> = row.get("display_name");
Json(ApiResponse::success(serde_json::json!({
"id": row.get::<i64, _>("id"),
"username": &uname,
"display_name": display_name.as_deref().unwrap_or(&uname),
"role": row.get::<String, _>("role"),
"must_change_password": row.get::<bool, _>("must_change_password"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
})))
}
Err(e) => {
let msg = if e.to_string().contains("UNIQUE") {
format!("Username '{}' already exists", username)
} else {
format!("Failed to create user: {}", e)
};
warn!("Failed to create user: {}", e);
Json(ApiResponse::error(msg))
}
}
}
#[derive(Deserialize)]
pub(super) struct UpdateUserRequest {
pub(super) display_name: Option<String>,
pub(super) role: Option<String>,
pub(super) password: Option<String>,
pub(super) must_change_password: Option<bool>,
}
pub(super) async fn handle_update_user(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(id): Path<i64>,
Json(payload): Json<UpdateUserRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Validate role if provided
if let Some(ref role) = payload.role {
if role != "admin" && role != "viewer" {
return Json(ApiResponse::error("Role must be 'admin' or 'viewer'".to_string()));
}
}
// Build dynamic update
let mut sets = Vec::new();
let mut string_binds: Vec<String> = Vec::new();
let mut has_password = false;
let mut has_must_change = false;
if let Some(ref display_name) = payload.display_name {
sets.push("display_name = ?");
string_binds.push(display_name.clone());
}
if let Some(ref role) = payload.role {
sets.push("role = ?");
string_binds.push(role.clone());
}
if let Some(ref password) = payload.password {
if password.len() < 4 {
return Json(ApiResponse::error("Password must be at least 4 characters".to_string()));
}
let hash = match bcrypt::hash(password, 12) {
Ok(h) => h,
Err(_) => return Json(ApiResponse::error("Failed to hash password".to_string())),
};
sets.push("password_hash = ?");
string_binds.push(hash);
has_password = true;
}
if payload.must_change_password.is_some() {
sets.push("must_change_password = ?");
has_must_change = true;
}
if sets.is_empty() {
return Json(ApiResponse::error("No fields to update".to_string()));
}
let sql = format!("UPDATE users SET {} WHERE id = ?", sets.join(", "));
let mut query = sqlx::query(&sql);
for b in &string_binds {
query = query.bind(b);
}
if has_must_change {
query = query.bind(payload.must_change_password.unwrap());
}
let _ = has_password; // the hash itself was already pushed into string_binds
query = query.bind(id);
match query.execute(pool).await {
Ok(result) => {
if result.rows_affected() == 0 {
return Json(ApiResponse::error("User not found".to_string()));
}
// Fetch updated user
let row = sqlx::query(
"SELECT id, username, display_name, role, must_change_password, created_at FROM users WHERE id = ?",
)
.bind(id)
.fetch_one(pool)
.await;
match row {
Ok(row) => {
let uname: String = row.get("username");
let display_name: Option<String> = row.get("display_name");
Json(ApiResponse::success(serde_json::json!({
"id": row.get::<i64, _>("id"),
"username": &uname,
"display_name": display_name.as_deref().unwrap_or(&uname),
"role": row.get::<String, _>("role"),
"must_change_password": row.get::<bool, _>("must_change_password"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
})))
}
Err(_) => Json(ApiResponse::success(serde_json::json!({ "message": "User updated" }))),
}
}
Err(e) => {
warn!("Failed to update user: {}", e);
Json(ApiResponse::error(format!("Failed to update user: {}", e)))
}
}
}
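The dynamic UPDATE statement built above follows a simple join-on-comma pattern; a standalone sketch of just the SQL assembly (a hypothetical helper, not in the file):

```rust
// Sketch of the dynamic-UPDATE SQL assembly used in handle_update_user:
// each provided field contributes one "col = ?" fragment, joined by
// commas, with the id bound last for the WHERE clause.
fn build_update_sql(fields: &[&str]) -> String {
    let sets: Vec<String> = fields.iter().map(|f| format!("{} = ?", f)).collect();
    format!("UPDATE users SET {} WHERE id = ?", sets.join(", "))
}
```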
pub(super) async fn handle_delete_user(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(id): Path<i64>,
) -> Json<ApiResponse<serde_json::Value>> {
let (session, _) = match auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Look up the target user; don't allow deleting yourself
let target_username: Option<String> =
sqlx::query_scalar::<_, String>("SELECT username FROM users WHERE id = ?")
.bind(id)
.fetch_optional(pool)
.await
.unwrap_or(None); // a lookup failure is treated like a missing user
match target_username {
None => return Json(ApiResponse::error("User not found".to_string())),
Some(ref uname) if uname == &session.username => {
return Json(ApiResponse::error("Cannot delete your own account".to_string()));
}
_ => {}
}
let result = sqlx::query("DELETE FROM users WHERE id = ?")
.bind(id)
.execute(pool)
.await;
match result {
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "User deleted" }))),
Err(e) => {
warn!("Failed to delete user: {}", e);
Json(ApiResponse::error(format!("Failed to delete user: {}", e)))
}
}
}


@@ -1,75 +0,0 @@
use axum::{
extract::{
State,
ws::{Message, WebSocket, WebSocketUpgrade},
},
response::IntoResponse,
};
use serde_json;
use tracing::info;
use super::DashboardState;
// WebSocket handler
pub(super) async fn handle_websocket(ws: WebSocketUpgrade, State(state): State<DashboardState>) -> impl IntoResponse {
ws.on_upgrade(|socket| handle_websocket_connection(socket, state))
}
pub(super) async fn handle_websocket_connection(mut socket: WebSocket, state: DashboardState) {
info!("WebSocket connection established");
// Subscribe to events from the global bus
let mut rx = state.app_state.dashboard_tx.subscribe();
// Send initial connection message
let _ = socket
.send(Message::Text(
serde_json::json!({
"type": "connected",
"message": "Connected to LLM Proxy Dashboard"
})
.to_string()
.into(),
))
.await;
// Handle incoming messages and broadcast events
loop {
tokio::select! {
// Receive broadcast events
Ok(event) = rx.recv() => {
let Ok(json_str) = serde_json::to_string(&event) else {
continue;
};
let message = Message::Text(json_str.into());
if socket.send(message).await.is_err() {
break;
}
}
// Receive WebSocket messages
result = socket.recv() => {
match result {
Some(Ok(Message::Text(text))) => {
handle_websocket_message(&text, &state).await;
}
// Keep the connection alive on ping/pong/binary frames;
// close on an explicit Close frame, error, or disconnect.
Some(Ok(Message::Close(_))) | Some(Err(_)) | None => break,
_ => {}
}
}
}
}
info!("WebSocket connection closed");
}
pub(super) async fn handle_websocket_message(text: &str, state: &DashboardState) {
// Parse and handle WebSocket messages
if let Ok(data) = serde_json::from_str::<serde_json::Value>(text)
&& data.get("type").and_then(|v| v.as_str()) == Some("ping")
{
// Note: the pong goes out on the broadcast bus to every dashboard
// listener, since this handler has no handle on the individual socket.
let _ = state.app_state.dashboard_tx.send(serde_json::json!({
"type": "pong",
"payload": {}
}));
}
}


@@ -1,275 +0,0 @@
use anyhow::Result;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool};
use std::str::FromStr;
use tracing::info;
use crate::config::DatabaseConfig;
pub type DbPool = SqlitePool;
pub async fn init(config: &DatabaseConfig) -> Result<DbPool> {
// Ensure the database directory exists
if let Some(parent) = config.path.parent()
&& !parent.as_os_str().is_empty()
{
tokio::fs::create_dir_all(parent).await?;
}
let database_path = config.path.to_string_lossy().to_string();
info!("Connecting to database at {}", database_path);
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?
.create_if_missing(true)
.pragma("foreign_keys", "ON");
let pool = SqlitePool::connect_with(options).await?;
// Run migrations
run_migrations(&pool).await?;
info!("Database migrations completed");
Ok(pool)
}
pub async fn run_migrations(pool: &DbPool) -> Result<()> {
// Create clients table if it doesn't exist
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS clients (
id INTEGER PRIMARY KEY AUTOINCREMENT,
client_id TEXT UNIQUE NOT NULL,
name TEXT,
description TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
is_active BOOLEAN DEFAULT TRUE,
rate_limit_per_minute INTEGER DEFAULT 60,
total_requests INTEGER DEFAULT 0,
total_tokens INTEGER DEFAULT 0,
total_cost REAL DEFAULT 0.0
)
"#,
)
.execute(pool)
.await?;
// Create llm_requests table if it doesn't exist
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS llm_requests (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
client_id TEXT,
provider TEXT,
model TEXT,
prompt_tokens INTEGER,
completion_tokens INTEGER,
reasoning_tokens INTEGER DEFAULT 0,
total_tokens INTEGER,
cost REAL,
has_images BOOLEAN DEFAULT FALSE,
status TEXT DEFAULT 'success',
error_message TEXT,
duration_ms INTEGER,
request_body TEXT,
response_body TEXT,
FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE SET NULL
)
"#,
)
.execute(pool)
.await?;
// Create provider_configs table
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS provider_configs (
id TEXT PRIMARY KEY,
display_name TEXT NOT NULL,
enabled BOOLEAN DEFAULT TRUE,
base_url TEXT,
api_key TEXT,
credit_balance REAL DEFAULT 0.0,
low_credit_threshold REAL DEFAULT 5.0,
billing_mode TEXT,
api_key_encrypted BOOLEAN DEFAULT FALSE,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
"#,
)
.execute(pool)
.await?;
// Create model_configs table
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS model_configs (
id TEXT PRIMARY KEY,
provider_id TEXT NOT NULL,
display_name TEXT,
enabled BOOLEAN DEFAULT TRUE,
prompt_cost_per_m REAL,
completion_cost_per_m REAL,
cache_read_cost_per_m REAL,
cache_write_cost_per_m REAL,
mapping TEXT,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE
)
"#,
)
.execute(pool)
.await?;
// Create users table for dashboard access
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL,
display_name TEXT,
role TEXT DEFAULT 'admin',
must_change_password BOOLEAN DEFAULT FALSE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
"#,
)
.execute(pool)
.await?;
// Create client_tokens table for DB-based token auth
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS client_tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
client_id TEXT NOT NULL,
token TEXT NOT NULL UNIQUE,
name TEXT DEFAULT 'default',
is_active BOOLEAN DEFAULT TRUE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
last_used_at DATETIME,
FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE CASCADE
)
"#,
)
.execute(pool)
.await?;
// Add must_change_password column if it doesn't exist (migration for existing DBs)
let _ = sqlx::query("ALTER TABLE users ADD COLUMN must_change_password BOOLEAN DEFAULT FALSE")
.execute(pool)
.await;
// Add display_name column if it doesn't exist (migration for existing DBs)
let _ = sqlx::query("ALTER TABLE users ADD COLUMN display_name TEXT")
.execute(pool)
.await;
// Add cache token columns if they don't exist (migration for existing DBs)
let _ = sqlx::query("ALTER TABLE llm_requests ADD COLUMN cache_read_tokens INTEGER DEFAULT 0")
.execute(pool)
.await;
let _ = sqlx::query("ALTER TABLE llm_requests ADD COLUMN cache_write_tokens INTEGER DEFAULT 0")
.execute(pool)
.await;
let _ = sqlx::query("ALTER TABLE llm_requests ADD COLUMN reasoning_tokens INTEGER DEFAULT 0")
.execute(pool)
.await;
// Add billing_mode column if it doesn't exist (migration for existing DBs)
let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN billing_mode TEXT")
.execute(pool)
.await;
// Add api_key_encrypted column if it doesn't exist (migration for existing DBs)
let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN api_key_encrypted BOOLEAN DEFAULT FALSE")
.execute(pool)
.await;
// Add manual cache cost columns to model_configs if they don't exist
let _ = sqlx::query("ALTER TABLE model_configs ADD COLUMN cache_read_cost_per_m REAL")
.execute(pool)
.await;
let _ = sqlx::query("ALTER TABLE model_configs ADD COLUMN cache_write_cost_per_m REAL")
.execute(pool)
.await;
// Insert default admin user if none exists (default password: admin)
let user_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users").fetch_one(pool).await?;
if user_count.0 == 0 {
// 'admin' hashed with default cost (12)
let default_admin_hash =
bcrypt::hash("admin", 12).map_err(|e| anyhow::anyhow!("Failed to hash default password: {}", e))?;
sqlx::query(
"INSERT INTO users (username, password_hash, role, must_change_password) VALUES ('admin', ?, 'admin', TRUE)"
)
.bind(default_admin_hash)
.execute(pool)
.await?;
info!("Created default admin user with password 'admin' (must change on first login)");
}
// Create indices
sqlx::query("CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_clients_created_at ON clients(created_at)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_timestamp ON llm_requests(timestamp)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_client_id ON llm_requests(client_id)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_provider ON llm_requests(provider)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_status ON llm_requests(status)")
.execute(pool)
.await?;
sqlx::query("CREATE UNIQUE INDEX IF NOT EXISTS idx_client_tokens_token ON client_tokens(token)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_client_tokens_client_id ON client_tokens(client_id)")
.execute(pool)
.await?;
// Composite indexes for performance
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id)")
.execute(pool)
.await?;
// Insert default client if none exists
sqlx::query(
r#"
INSERT OR IGNORE INTO clients (client_id, name, description)
VALUES ('default', 'Default Client', 'Default client for anonymous requests')
"#,
)
.execute(pool)
.await?;
Ok(())
}
pub async fn test_connection(pool: &DbPool) -> Result<()> {
sqlx::query("SELECT 1").execute(pool).await?;
Ok(())
}



@@ -1,58 +0,0 @@
use thiserror::Error;
#[derive(Error, Debug, Clone)]
pub enum AppError {
#[error("Authentication failed: {0}")]
AuthError(String),
#[error("Configuration error: {0}")]
ConfigError(String),
#[error("Database error: {0}")]
DatabaseError(String),
#[error("Provider error: {0}")]
ProviderError(String),
#[error("Validation error: {0}")]
ValidationError(String),
#[error("Multimodal processing error: {0}")]
MultimodalError(String),
#[error("Rate limit exceeded: {0}")]
RateLimitError(String),
#[error("Internal server error: {0}")]
InternalError(String),
}
impl From<sqlx::Error> for AppError {
fn from(err: sqlx::Error) -> Self {
AppError::DatabaseError(err.to_string())
}
}
impl From<anyhow::Error> for AppError {
fn from(err: anyhow::Error) -> Self {
AppError::InternalError(err.to_string())
}
}
impl axum::response::IntoResponse for AppError {
fn into_response(self) -> axum::response::Response {
let status = match self {
AppError::AuthError(_) => axum::http::StatusCode::UNAUTHORIZED,
AppError::RateLimitError(_) => axum::http::StatusCode::TOO_MANY_REQUESTS,
AppError::ValidationError(_) => axum::http::StatusCode::BAD_REQUEST,
_ => axum::http::StatusCode::INTERNAL_SERVER_ERROR,
};
let body = axum::Json(serde_json::json!({
"error": self.to_string(),
"type": format!("{:?}", self)
}));
(status, body).into_response()
}
}


@@ -1,328 +0,0 @@
//! LLM Proxy Library
//!
//! This library provides the core functionality for the LLM proxy gateway,
//! including provider integration, token tracking, and API endpoints.
pub mod auth;
pub mod client;
pub mod config;
pub mod dashboard;
pub mod database;
pub mod errors;
pub mod logging;
pub mod models;
pub mod multimodal;
pub mod providers;
pub mod rate_limiting;
pub mod server;
pub mod state;
pub mod utils;
// Re-exports for convenience
pub use auth::{AuthenticatedClient, validate_token};
pub use config::{
AppConfig, DatabaseConfig, DeepSeekConfig, GeminiConfig, GrokConfig, ModelMappingConfig, ModelPricing,
OllamaConfig, OpenAIConfig, PricingConfig, ProviderConfig, ServerConfig,
};
pub use database::{DbPool, init as init_db, test_connection};
pub use errors::AppError;
pub use logging::{LoggingContext, RequestLog, RequestLogger};
pub use models::{
ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage,
ChatStreamChoice, ChatStreamDelta, ContentPart, ContentPartValue, FromOpenAI, ImageUrl, MessageContent,
OpenAIContentPart, OpenAIMessage, OpenAIRequest, ToOpenAI, UnifiedMessage, UnifiedRequest, Usage,
};
pub use providers::{Provider, ProviderManager, ProviderResponse, ProviderStreamChunk};
pub use server::router;
pub use state::AppState;
/// Test utilities for integration testing
#[cfg(test)]
pub mod test_utils {
use std::sync::Arc;
use crate::{client::ClientManager, providers::ProviderManager, rate_limiting::RateLimitManager, state::AppState, utils::crypto, database::run_migrations};
use sqlx::sqlite::SqlitePool;
/// Create a test application state
pub async fn create_test_state() -> AppState {
// Create in-memory database
let pool = SqlitePool::connect("sqlite::memory:")
.await
.expect("Failed to create test database");
// Run migrations on the pool
run_migrations(&pool).await.expect("Failed to run migrations");
let rate_limit_manager = RateLimitManager::new(
crate::rate_limiting::RateLimiterConfig::default(),
crate::rate_limiting::CircuitBreakerConfig::default(),
);
let client_manager = Arc::new(ClientManager::new(pool.clone()));
// Create provider manager
let provider_manager = ProviderManager::new();
let model_registry = crate::models::registry::ModelRegistry {
providers: std::collections::HashMap::new(),
};
let (dashboard_tx, _) = tokio::sync::broadcast::channel::<serde_json::Value>(100);
let config = Arc::new(crate::config::AppConfig {
server: crate::config::ServerConfig {
port: 8080,
host: "127.0.0.1".to_string(),
auth_tokens: vec![],
},
database: crate::config::DatabaseConfig {
path: std::path::PathBuf::from(":memory:"),
max_connections: 5,
},
providers: crate::config::ProviderConfig {
openai: crate::config::OpenAIConfig {
api_key_env: "OPENAI_API_KEY".to_string(),
base_url: "".to_string(),
default_model: "".to_string(),
enabled: true,
},
gemini: crate::config::GeminiConfig {
api_key_env: "GEMINI_API_KEY".to_string(),
base_url: "".to_string(),
default_model: "".to_string(),
enabled: true,
},
deepseek: crate::config::DeepSeekConfig {
api_key_env: "DEEPSEEK_API_KEY".to_string(),
base_url: "".to_string(),
default_model: "".to_string(),
enabled: true,
},
grok: crate::config::GrokConfig {
api_key_env: "GROK_API_KEY".to_string(),
base_url: "".to_string(),
default_model: "".to_string(),
enabled: true,
},
ollama: crate::config::OllamaConfig {
base_url: "".to_string(),
enabled: true,
models: vec![],
},
},
model_mapping: crate::config::ModelMappingConfig { patterns: vec![] },
pricing: crate::config::PricingConfig {
openai: vec![],
gemini: vec![],
deepseek: vec![],
grok: vec![],
ollama: vec![],
},
config_path: None,
encryption_key: "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f".to_string(),
});
// Initialize encryption with the test key
crypto::init_with_key(&config.encryption_key).expect("failed to initialize crypto");
AppState::new(
config,
provider_manager,
pool,
rate_limit_manager,
model_registry,
vec![], // auth_tokens
)
}
/// Create a test HTTP client
pub fn create_test_client() -> reqwest::Client {
reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.expect("Failed to create test HTTP client")
}
}
#[cfg(test)]
mod integration_tests {
use super::test_utils::*;
use crate::{
models::{ChatCompletionRequest, ChatMessage},
server::router,
utils::crypto,
};
use axum::{
body::Body,
http::{Request, StatusCode},
};
use mockito::Server;
use serde_json::json;
use sqlx::Row;
use tower::util::ServiceExt;
#[tokio::test]
async fn test_encrypted_provider_key_integration() {
// Step 1: Setup test database and state
let state = create_test_state().await;
let pool = state.db_pool.clone();
// Step 2: Insert provider with encrypted API key
let test_api_key = "test-openai-key-12345";
let encrypted_key = crypto::encrypt(test_api_key).expect("Failed to encrypt test key");
sqlx::query(
r#"
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind("openai")
.bind("OpenAI")
.bind(true)
.bind("http://localhost:1234") // Placeholder base URL; replaced with the mock server URL below
.bind(&encrypted_key)
.bind(true) // api_key_encrypted flag
.bind(100.0)
.bind(5.0)
.execute(&pool)
.await
.expect("Failed to insert provider config");
// Step 3: Initialize provider with the encrypted key
state
.provider_manager
.initialize_provider("openai", &state.config, &pool)
.await
.expect("Failed to initialize provider");
// Step 4: Mock OpenAI API server
let mut server = Server::new_async().await;
let mock = server
.mock("POST", "/chat/completions")
.match_header("authorization", format!("Bearer {}", test_api_key).as_str())
.with_status(200)
.with_header("content-type", "application/json")
.with_body(
json!({
"id": "chatcmpl-test",
"object": "chat.completion",
"created": 1234567890,
"model": "gpt-3.5-turbo",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello, world!"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15
}
})
.to_string(),
)
.create_async()
.await;
// Update provider base URL to use mock server
sqlx::query("UPDATE provider_configs SET base_url = ? WHERE id = 'openai'")
.bind(&server.url())
.execute(&pool)
.await
.expect("Failed to update provider URL");
// Re-initialize provider with new URL
state
.provider_manager
.initialize_provider("openai", &state.config, &pool)
.await
.expect("Failed to re-initialize provider");
// Step 5: Create test router and make request
let app = router(state);
let request_body = ChatCompletionRequest {
model: "gpt-3.5-turbo".to_string(),
messages: vec![ChatMessage {
role: "user".to_string(),
content: crate::models::MessageContent::Text {
content: "Hello".to_string(),
},
reasoning_content: None,
tool_calls: None,
name: None,
tool_call_id: None,
}],
temperature: None,
top_p: None,
top_k: None,
n: None,
stop: None,
max_tokens: Some(100),
presence_penalty: None,
frequency_penalty: None,
stream: Some(false),
tools: None,
tool_choice: None,
};
let request = Request::builder()
.method("POST")
.uri("/v1/chat/completions")
.header("content-type", "application/json")
.header("authorization", "Bearer test-token")
.body(Body::from(serde_json::to_string(&request_body).unwrap()))
.unwrap();
// Step 6: Execute request through proxy
let response = app
.oneshot(request)
.await
.expect("Failed to execute request");
let status = response.status();
println!("Response status: {}", status);
if status != StatusCode::OK {
let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
println!("Response body: {}", body_str);
panic!("Response status is not OK: {}", status);
}
assert_eq!(status, StatusCode::OK);
// Verify the mock was called
mock.assert_async().await;
// Give the async logging task time to complete
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Step 7: Verify usage was logged in database
let log_row = sqlx::query("SELECT * FROM llm_requests WHERE client_id = 'client_test-tok' ORDER BY id DESC LIMIT 1")
.fetch_one(&pool)
.await
.expect("Request log not found");
assert_eq!(log_row.get::<String, _>("provider"), "openai");
assert_eq!(log_row.get::<String, _>("model"), "gpt-3.5-turbo");
assert_eq!(log_row.get::<i64, _>("prompt_tokens"), 10);
assert_eq!(log_row.get::<i64, _>("completion_tokens"), 5);
assert_eq!(log_row.get::<i64, _>("total_tokens"), 15);
assert_eq!(log_row.get::<String, _>("status"), "success");
// Verify client usage was updated
let client_row = sqlx::query("SELECT total_requests, total_tokens, total_cost FROM clients WHERE client_id = 'client_test-tok'")
.fetch_one(&pool)
.await
.expect("Client not found");
assert_eq!(client_row.get::<i64, _>("total_requests"), 1);
assert_eq!(client_row.get::<i64, _>("total_tokens"), 15);
}
}


@@ -1,242 +0,0 @@
use chrono::{DateTime, Utc};
use serde::Serialize;
use sqlx::SqlitePool;
use tokio::sync::broadcast;
use tracing::{info, warn};
use crate::errors::AppError;
/// Request log entry for database storage
#[derive(Debug, Clone, Serialize)]
pub struct RequestLog {
pub timestamp: DateTime<Utc>,
pub client_id: String,
pub provider: String,
pub model: String,
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub reasoning_tokens: u32,
pub total_tokens: u32,
pub cache_read_tokens: u32,
pub cache_write_tokens: u32,
pub cost: f64,
pub has_images: bool,
pub status: String, // "success", "error"
pub error_message: Option<String>,
pub duration_ms: u64,
}
/// Database operations for request logging
pub struct RequestLogger {
db_pool: SqlitePool,
dashboard_tx: broadcast::Sender<serde_json::Value>,
}
impl RequestLogger {
pub fn new(db_pool: SqlitePool, dashboard_tx: broadcast::Sender<serde_json::Value>) -> Self {
Self { db_pool, dashboard_tx }
}
/// Log a request to the database (async, spawns a task)
pub fn log_request(&self, log: RequestLog) {
let pool = self.db_pool.clone();
let tx = self.dashboard_tx.clone();
// Spawn async task to log without blocking response
tokio::spawn(async move {
// Broadcast to dashboard
let broadcast_result = tx.send(serde_json::json!({
"type": "request",
"channel": "requests",
"payload": log
}));
// Err means no active WebSocket clients — expected when the dashboard isn't open.
if let Ok(receivers) = broadcast_result {
    info!("Broadcast request log to {} dashboard listeners", receivers);
}
match Self::insert_log(&pool, log).await {
Ok(()) => info!("Request logged to database successfully"),
Err(e) => warn!("Failed to log request to database: {}", e),
}
});
}
/// Insert a log entry into the database
async fn insert_log(pool: &SqlitePool, log: RequestLog) -> Result<(), sqlx::Error> {
let mut tx = pool.begin().await?;
// Ensure the client row exists (FK constraint requires it)
sqlx::query(
"INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')",
)
.bind(&log.client_id)
.bind(&log.client_id)
.execute(&mut *tx)
.await?;
sqlx::query(
r#"
INSERT INTO llm_requests
(timestamp, client_id, provider, model, prompt_tokens, completion_tokens, reasoning_tokens, total_tokens, cache_read_tokens, cache_write_tokens, cost, has_images, status, error_message, duration_ms, request_body, response_body)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind(log.timestamp)
.bind(&log.client_id)
.bind(&log.provider)
.bind(&log.model)
.bind(log.prompt_tokens as i64)
.bind(log.completion_tokens as i64)
.bind(log.reasoning_tokens as i64)
.bind(log.total_tokens as i64)
.bind(log.cache_read_tokens as i64)
.bind(log.cache_write_tokens as i64)
.bind(log.cost)
.bind(log.has_images)
.bind(&log.status)
.bind(log.error_message)
.bind(log.duration_ms as i64)
.bind(None::<String>) // request_body - optional, not stored to save disk space
.bind(None::<String>) // response_body - optional, not stored to save disk space
.execute(&mut *tx)
.await?;
// Update client usage statistics
sqlx::query(
r#"
UPDATE clients SET
total_requests = total_requests + 1,
total_tokens = total_tokens + ?,
total_cost = total_cost + ?,
updated_at = CURRENT_TIMESTAMP
WHERE client_id = ?
"#,
)
.bind(log.total_tokens as i64)
.bind(log.cost)
.bind(&log.client_id)
.execute(&mut *tx)
.await?;
// Deduct from provider balance if successful.
// Providers configured with billing_mode = 'postpaid' will not have their
// credit_balance decremented. Use a conditional UPDATE so we don't need
// a prior SELECT and avoid extra round-trips.
if log.cost > 0.0 {
sqlx::query(
"UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')",
)
.bind(log.cost)
.bind(&log.provider)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(())
}
}
// /// Middleware to log LLM API requests
// /// TODO: Implement proper middleware that can extract response body details
// pub async fn request_logging_middleware(
// // Extract the authenticated client (if available)
// auth_result: Result<AuthenticatedClient, AppError>,
// request: Request,
// next: Next,
// ) -> Response {
// let start_time = std::time::Instant::now();
//
// // Extract client_id from auth or use "unknown"
// let client_id = match auth_result {
// Ok(auth) => auth.client_id,
// Err(_) => "unknown".to_string(),
// };
//
// // Try to extract request details
// let (request_parts, request_body) = request.into_parts();
//
// // Clone request parts for logging
// let path = request_parts.uri.path().to_string();
//
// // Check if this is a chat completion request
// let is_chat_completion = path == "/v1/chat/completions";
//
// // Reconstruct request for downstream handlers
// let request = Request::from_parts(request_parts, request_body);
//
// // Process request and get response
// let response = next.run(request).await;
//
// // Calculate duration
// let duration = start_time.elapsed();
// let duration_ms = duration.as_millis() as u64;
//
// // Log basic request info
// info!(
// "Request from {} to {} - Status: {} - Duration: {}ms",
// client_id,
// path,
// response.status().as_u16(),
// duration_ms
// );
//
// // TODO: Extract more details from request/response for logging
// // For now, we'll need to modify the server handler to pass additional context
//
// response
// }
/// Context for request logging that can be passed through extensions
#[derive(Clone)]
pub struct LoggingContext {
pub client_id: String,
pub provider_name: String,
pub model: String,
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
pub cost: f64,
pub has_images: bool,
pub error: Option<AppError>,
}
impl LoggingContext {
pub fn new(client_id: String, provider_name: String, model: String) -> Self {
Self {
client_id,
provider_name,
model,
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
cost: 0.0,
has_images: false,
error: None,
}
}
pub fn with_token_counts(mut self, prompt_tokens: u32, completion_tokens: u32) -> Self {
self.prompt_tokens = prompt_tokens;
self.completion_tokens = completion_tokens;
self.total_tokens = prompt_tokens + completion_tokens;
self
}
pub fn with_cost(mut self, cost: f64) -> Self {
self.cost = cost;
self
}
pub fn with_images(mut self, has_images: bool) -> Self {
self.has_images = has_images;
self
}
pub fn with_error(mut self, error: AppError) -> Self {
self.error = Some(error);
self
}
}


@@ -1,96 +0,0 @@
use anyhow::Result;
use axum::{Router, routing::get};
use std::net::SocketAddr;
use tracing::{error, info};
use llm_proxy::{
config::AppConfig,
dashboard, database,
providers::ProviderManager,
rate_limiting::{CircuitBreakerConfig, RateLimitManager, RateLimiterConfig},
server,
state::AppState,
utils::crypto,
};
#[tokio::main]
async fn main() -> Result<()> {
// Initialize tracing (logging)
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.with_target(false)
.init();
info!("Starting LLM Proxy Gateway v{}", env!("CARGO_PKG_VERSION"));
// Load configuration
let config = AppConfig::load().await?;
info!("Configuration loaded from {:?}", config.config_path);
// Initialize encryption
crypto::init_with_key(&config.encryption_key)?;
info!("Encryption initialized");
// Initialize database connection pool
let db_pool = database::init(&config.database).await?;
info!("Database initialized at {:?}", config.database.path);
// Initialize provider manager with configured providers
let provider_manager = ProviderManager::new();
// Initialize all supported providers (they handle their own enabled check)
let supported_providers = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
for name in supported_providers {
if let Err(e) = provider_manager.initialize_provider(name, &config, &db_pool).await {
error!("Failed to initialize provider {}: {}", name, e);
}
}
// Create rate limit manager
let rate_limit_manager = RateLimitManager::new(RateLimiterConfig::default(), CircuitBreakerConfig::default());
// Fetch model registry from models.dev
let model_registry = match llm_proxy::utils::registry::fetch_registry().await {
Ok(registry) => registry,
Err(e) => {
error!("Failed to fetch model registry: {}. Using empty registry.", e);
llm_proxy::models::registry::ModelRegistry {
providers: std::collections::HashMap::new(),
}
}
};
// Create application state
let state = AppState::new(
config.clone(),
provider_manager,
db_pool,
rate_limit_manager,
model_registry,
config.server.auth_tokens.clone(),
);
// Initialize model config cache and start background refresh (every 30s)
state.model_config_cache.refresh().await;
state.model_config_cache.clone().start_refresh_task(30);
info!("Model config cache initialized");
// Create application router
let app = Router::new()
.route("/health", get(health_check))
.merge(server::router(state.clone()))
.merge(dashboard::router(state.clone()));
// Start server
let addr = SocketAddr::from(([0, 0, 0, 0], config.server.port));
info!("Server listening on http://{}", addr);
let listener = tokio::net::TcpListener::bind(&addr).await?;
axum::serve(listener, app).await?;
Ok(())
}
async fn health_check() -> &'static str {
"OK"
}


@@ -1,381 +0,0 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub mod registry;
// ========== OpenAI-compatible Request/Response Structs ==========
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
#[serde(default)]
pub temperature: Option<f64>,
#[serde(default)]
pub top_p: Option<f64>,
#[serde(default)]
pub top_k: Option<u32>,
#[serde(default)]
pub n: Option<u32>,
#[serde(default)]
pub stop: Option<Value>, // Can be string or array of strings
#[serde(default)]
pub max_tokens: Option<u32>,
#[serde(default)]
pub presence_penalty: Option<f64>,
#[serde(default)]
pub frequency_penalty: Option<f64>,
#[serde(default)]
pub stream: Option<bool>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: String, // "system", "user", "assistant", "tool"
#[serde(flatten)]
pub content: MessageContent,
#[serde(alias = "reasoning", alias = "thought", skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
Text { content: String },
Parts { content: Vec<ContentPartValue> },
None, // Handle cases where content might be null but reasoning is present
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPartValue {
Text { text: String },
ImageUrl { image_url: ImageUrl },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrl {
pub url: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<String>,
}
// ========== Tool-Calling Types ==========
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
#[serde(rename = "type")]
pub tool_type: String,
pub function: FunctionDef,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDef {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub parameters: Option<Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
Mode(String), // "auto", "none", "required"
Specific(ToolChoiceSpecific),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceSpecific {
#[serde(rename = "type")]
pub choice_type: String,
pub function: ToolChoiceFunction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceFunction {
pub name: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
pub id: String,
#[serde(rename = "type")]
pub call_type: String,
pub function: FunctionCall,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
pub name: String,
pub arguments: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallDelta {
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
pub call_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function: Option<FunctionCallDelta>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDelta {
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<String>,
}
// ========== OpenAI-compatible Response Structs ==========
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<ChatChoice>,
pub usage: Option<Usage>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatChoice {
pub index: u32,
pub message: ChatMessage,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_read_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_write_tokens: Option<u32>,
}
// ========== Streaming Response Structs ==========
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionStreamResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<ChatStreamChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatStreamChoice {
pub index: u32,
pub delta: ChatStreamDelta,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatStreamDelta {
pub role: Option<String>,
pub content: Option<String>,
#[serde(alias = "reasoning", alias = "thought", skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallDelta>>,
}
// ========== Unified Request Format (for internal use) ==========
#[derive(Debug, Clone)]
pub struct UnifiedRequest {
pub client_id: String,
pub model: String,
pub messages: Vec<UnifiedMessage>,
pub temperature: Option<f64>,
pub top_p: Option<f64>,
pub top_k: Option<u32>,
pub n: Option<u32>,
pub stop: Option<Vec<String>>,
pub max_tokens: Option<u32>,
pub presence_penalty: Option<f64>,
pub frequency_penalty: Option<f64>,
pub stream: bool,
pub has_images: bool,
pub tools: Option<Vec<Tool>>,
pub tool_choice: Option<ToolChoice>,
}
#[derive(Debug, Clone)]
pub struct UnifiedMessage {
pub role: String,
pub content: Vec<ContentPart>,
pub reasoning_content: Option<String>,
pub tool_calls: Option<Vec<ToolCall>>,
pub name: Option<String>,
pub tool_call_id: Option<String>,
}
#[derive(Debug, Clone)]
pub enum ContentPart {
Text { text: String },
Image(crate::multimodal::ImageInput),
}
// ========== Provider-specific Structs ==========
#[derive(Debug, Clone, Serialize)]
pub struct OpenAIRequest {
pub model: String,
pub messages: Vec<OpenAIMessage>,
pub temperature: Option<f64>,
pub max_tokens: Option<u32>,
pub stream: Option<bool>,
}
#[derive(Debug, Clone, Serialize)]
pub struct OpenAIMessage {
pub role: String,
pub content: Vec<OpenAIContentPart>,
}
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum OpenAIContentPart {
Text { text: String },
ImageUrl { image_url: ImageUrl },
}
// Note: ImageUrl struct is defined earlier in the file
// ========== Conversion Traits ==========
pub trait ToOpenAI {
fn to_openai(&self) -> Result<OpenAIRequest, anyhow::Error>;
}
pub trait FromOpenAI {
fn from_openai(request: &OpenAIRequest) -> Result<Self, anyhow::Error>
where
Self: Sized;
}
impl UnifiedRequest {
/// Hydrate all image content by fetching URLs and converting to base64/bytes
pub async fn hydrate_images(&mut self) -> anyhow::Result<()> {
if !self.has_images {
return Ok(());
}
for msg in &mut self.messages {
for part in &mut msg.content {
if let ContentPart::Image(image_input) = part {
// Pre-fetch and validate if it's a URL
if let crate::multimodal::ImageInput::Url(_url) = image_input {
let (base64_data, mime_type) = image_input.to_base64().await?;
*image_input = crate::multimodal::ImageInput::Base64 {
data: base64_data,
mime_type,
};
}
}
}
}
Ok(())
}
}
impl TryFrom<ChatCompletionRequest> for UnifiedRequest {
type Error = anyhow::Error;
fn try_from(req: ChatCompletionRequest) -> Result<Self, Self::Error> {
let mut has_images = false;
// Convert OpenAI-compatible request to unified format
let messages = req
.messages
.into_iter()
.map(|msg| {
let (content, _images_in_message) = match msg.content {
MessageContent::Text { content } => (vec![ContentPart::Text { text: content }], false),
MessageContent::Parts { content } => {
let mut unified_content = Vec::new();
let mut has_images_in_msg = false;
for part in content {
match part {
ContentPartValue::Text { text } => {
unified_content.push(ContentPart::Text { text });
}
ContentPartValue::ImageUrl { image_url } => {
has_images_in_msg = true;
has_images = true;
unified_content.push(ContentPart::Image(crate::multimodal::ImageInput::from_url(
image_url.url,
)));
}
}
}
(unified_content, has_images_in_msg)
}
MessageContent::None => (vec![], false),
};
UnifiedMessage {
role: msg.role,
content,
reasoning_content: msg.reasoning_content,
tool_calls: msg.tool_calls,
name: msg.name,
tool_call_id: msg.tool_call_id,
}
})
.collect();
let stop = match req.stop {
Some(Value::String(s)) => Some(vec![s]),
Some(Value::Array(a)) => Some(
a.into_iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect(),
),
_ => None,
};
Ok(UnifiedRequest {
client_id: String::new(), // Will be populated by auth middleware
model: req.model,
messages,
temperature: req.temperature,
top_p: req.top_p,
top_k: req.top_k,
n: req.n,
stop,
max_tokens: req.max_tokens,
presence_penalty: req.presence_penalty,
frequency_penalty: req.frequency_penalty,
stream: req.stream.unwrap_or(false),
has_images,
tools: req.tools,
tool_choice: req.tool_choice,
})
}
}


@@ -1,219 +0,0 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRegistry {
#[serde(flatten)]
pub providers: HashMap<String, ProviderInfo>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderInfo {
pub id: String,
pub name: String,
pub models: HashMap<String, ModelMetadata>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
pub id: String,
pub name: String,
pub cost: Option<ModelCost>,
pub limit: Option<ModelLimit>,
pub modalities: Option<ModelModalities>,
pub tool_call: Option<bool>,
pub reasoning: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCost {
pub input: f64,
pub output: f64,
pub cache_read: Option<f64>,
pub cache_write: Option<f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelLimit {
pub context: u32,
pub output: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelModalities {
pub input: Vec<String>,
pub output: Vec<String>,
}
/// A model entry paired with its provider ID, returned by listing/filtering methods.
#[derive(Debug, Clone)]
pub struct ModelEntry<'a> {
pub model_key: &'a str,
pub provider_id: &'a str,
pub provider_name: &'a str,
pub metadata: &'a ModelMetadata,
}
/// Filter criteria for listing models. All fields are optional; `None` means no filter.
#[derive(Debug, Default, Clone, Deserialize)]
pub struct ModelFilter {
/// Filter by provider ID (exact match).
pub provider: Option<String>,
/// Text search on model ID or name (case-insensitive substring).
pub search: Option<String>,
/// Filter by input modality (e.g. "image", "text").
pub modality: Option<String>,
/// Only models that support tool calling.
pub tool_call: Option<bool>,
/// Only models that support reasoning.
pub reasoning: Option<bool>,
/// Only models that have pricing data.
pub has_cost: Option<bool>,
}
/// Sort field for model listings.
#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ModelSortBy {
#[default]
Name,
Id,
Provider,
ContextLimit,
InputCost,
OutputCost,
}
/// Sort direction.
#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum SortOrder {
#[default]
Asc,
Desc,
}
impl ModelRegistry {
/// Find a model by its ID (searching across all providers)
pub fn find_model(&self, model_id: &str) -> Option<&ModelMetadata> {
// First try an exact match on the models map key
for provider in self.providers.values() {
if let Some(model) = provider.models.get(model_id) {
return Some(model);
}
}
// Fall back to matching the metadata `id` field when the map key differs
for provider in self.providers.values() {
for model in provider.models.values() {
if model.id == model_id {
return Some(model);
}
}
}
None
}
/// List all models with optional filtering and sorting.
pub fn list_models(
&self,
filter: &ModelFilter,
sort_by: &ModelSortBy,
sort_order: &SortOrder,
) -> Vec<ModelEntry<'_>> {
let mut entries: Vec<ModelEntry<'_>> = Vec::new();
for (p_id, p_info) in &self.providers {
// Provider filter
if let Some(ref prov) = filter.provider {
if p_id != prov {
continue;
}
}
for (m_key, m_meta) in &p_info.models {
// Text search filter
if let Some(ref search) = filter.search {
let search_lower = search.to_lowercase();
if !m_meta.id.to_lowercase().contains(&search_lower)
&& !m_meta.name.to_lowercase().contains(&search_lower)
&& !m_key.to_lowercase().contains(&search_lower)
{
continue;
}
}
// Modality filter
if let Some(ref modality) = filter.modality {
let has_modality = m_meta
.modalities
.as_ref()
.is_some_and(|m| m.input.iter().any(|i| i.eq_ignore_ascii_case(modality)));
if !has_modality {
continue;
}
}
// Tool call filter
if let Some(tc) = filter.tool_call {
if m_meta.tool_call.unwrap_or(false) != tc {
continue;
}
}
// Reasoning filter
if let Some(r) = filter.reasoning {
if m_meta.reasoning.unwrap_or(false) != r {
continue;
}
}
// Has cost filter
if let Some(hc) = filter.has_cost {
if hc != m_meta.cost.is_some() {
continue;
}
}
entries.push(ModelEntry {
model_key: m_key,
provider_id: p_id,
provider_name: &p_info.name,
metadata: m_meta,
});
}
}
// Sort
entries.sort_by(|a, b| {
let cmp = match sort_by {
ModelSortBy::Name => a.metadata.name.to_lowercase().cmp(&b.metadata.name.to_lowercase()),
ModelSortBy::Id => a.model_key.cmp(b.model_key),
ModelSortBy::Provider => a.provider_id.cmp(b.provider_id),
ModelSortBy::ContextLimit => {
let a_ctx = a.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0);
let b_ctx = b.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0);
a_ctx.cmp(&b_ctx)
}
ModelSortBy::InputCost => {
let a_cost = a.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
let b_cost = b.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
}
ModelSortBy::OutputCost => {
let a_cost = a.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
let b_cost = b.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
}
};
match sort_order {
SortOrder::Asc => cmp,
SortOrder::Desc => cmp.reverse(),
}
});
entries
}
}
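The `InputCost`/`OutputCost` comparators above default missing prices to `0.0` and fall back to `Ordering::Equal` when `partial_cmp` fails (e.g. NaN). A std-only sketch of that pattern, using a hypothetical `sort_by_cost` helper rather than the registry types:

```rust
use std::cmp::Ordering;

// NaN-safe float sort mirroring the registry's cost comparators:
// unpriced entries sort as 0.0, incomparable pairs compare as Equal.
fn sort_by_cost(mut entries: Vec<(&'static str, Option<f64>)>) -> Vec<&'static str> {
    entries.sort_by(|a, b| {
        let a_cost = a.1.unwrap_or(0.0); // no pricing data -> treated as 0.0
        let b_cost = b.1.unwrap_or(0.0);
        a_cost.partial_cmp(&b_cost).unwrap_or(Ordering::Equal)
    });
    entries.into_iter().map(|e| e.0).collect()
}

fn main() {
    let order = sort_by_cost(vec![
        ("b-model", Some(0.42)),
        ("a-model", None), // unpriced, sorts first ascending
        ("c-model", Some(0.28)),
    ]);
    assert_eq!(order, vec!["a-model", "c-model", "b-model"]);
}
```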


@@ -1,299 +0,0 @@
//! Multimodal support for image processing and conversion
//!
//! This module handles:
//! 1. Image format detection and conversion
//! 2. Base64 encoding/decoding
//! 3. URL fetching for images
//! 4. Provider-specific image format conversion
use anyhow::{Context, Result};
use base64::{Engine as _, engine::general_purpose};
use std::sync::LazyLock;
use tracing::{info, warn};
/// Shared HTTP client for image fetching — avoids creating a new TCP+TLS
/// connection for every image URL.
static IMAGE_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.timeout(std::time::Duration::from_secs(30))
.pool_idle_timeout(std::time::Duration::from_secs(60))
.build()
.expect("Failed to build image HTTP client")
});
/// Supported image formats for multimodal input
#[derive(Debug, Clone)]
pub enum ImageInput {
/// Base64-encoded image data with MIME type
Base64 { data: String, mime_type: String },
/// URL to fetch image from
Url(String),
/// Raw bytes with MIME type
Bytes { data: Vec<u8>, mime_type: String },
}
impl ImageInput {
/// Create ImageInput from base64 string
pub fn from_base64(data: String, mime_type: String) -> Self {
Self::Base64 { data, mime_type }
}
/// Create ImageInput from URL
pub fn from_url(url: String) -> Self {
Self::Url(url)
}
/// Create ImageInput from raw bytes
pub fn from_bytes(data: Vec<u8>, mime_type: String) -> Self {
Self::Bytes { data, mime_type }
}
/// Get MIME type if available
pub fn mime_type(&self) -> Option<&str> {
match self {
Self::Base64 { mime_type, .. } => Some(mime_type),
Self::Bytes { mime_type, .. } => Some(mime_type),
Self::Url(_) => None,
}
}
/// Convert to base64 if not already
pub async fn to_base64(&self) -> Result<(String, String)> {
match self {
Self::Base64 { data, mime_type } => Ok((data.clone(), mime_type.clone())),
Self::Bytes { data, mime_type } => {
let base64_data = general_purpose::STANDARD.encode(data);
Ok((base64_data, mime_type.clone()))
}
Self::Url(url) => {
// Fetch image from URL using shared client
info!("Fetching image from URL: {}", url);
let response = IMAGE_CLIENT
.get(url)
.send()
.await
.context("Failed to fetch image from URL")?;
if !response.status().is_success() {
anyhow::bail!("Failed to fetch image: HTTP {}", response.status());
}
let mime_type = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|h| h.to_str().ok())
.unwrap_or("image/jpeg")
.to_string();
let bytes = response.bytes().await.context("Failed to read image bytes")?;
let base64_data = general_purpose::STANDARD.encode(&bytes);
Ok((base64_data, mime_type))
}
}
}
/// Get image dimensions (width, height)
pub async fn get_dimensions(&self) -> Result<(u32, u32)> {
let bytes = match self {
Self::Base64 { data, .. } => general_purpose::STANDARD
.decode(data)
.context("Failed to decode base64")?,
Self::Bytes { data, .. } => data.clone(),
Self::Url(_) => {
let (base64_data, _) = self.to_base64().await?;
general_purpose::STANDARD
.decode(&base64_data)
.context("Failed to decode base64")?
}
};
let img = image::load_from_memory(&bytes).context("Failed to load image from bytes")?;
Ok((img.width(), img.height()))
}
/// Validate image size and format
pub async fn validate(&self, max_size_mb: f64) -> Result<()> {
let (width, height) = self.get_dimensions().await?;
// Check dimensions
if width > 4096 || height > 4096 {
warn!("Image dimensions too large: {}x{}", width, height);
// Continue anyway, but log warning
}
// Check file size
let size_bytes = match self {
Self::Base64 { data, .. } => {
// Base64 inflates data by ~4/3, so the decoded size is ~3/4 of the encoded length
(data.len() as f64 * 0.75) as usize
}
Self::Bytes { data, .. } => data.len(),
Self::Url(_) => {
// For URLs, we'd need to fetch to check size
// Skip size check for URLs for now
return Ok(());
}
};
let size_mb = size_bytes as f64 / (1024.0 * 1024.0);
if size_mb > max_size_mb {
anyhow::bail!("Image too large: {:.2}MB > {:.2}MB limit", size_mb, max_size_mb);
}
Ok(())
}
}
/// Provider-specific image format conversion
pub struct ImageConverter;
impl ImageConverter {
/// Convert image to OpenAI-compatible format
pub async fn to_openai_format(image: &ImageInput) -> Result<serde_json::Value> {
let (base64_data, mime_type) = image.to_base64().await?;
// OpenAI expects data URL format: "data:image/jpeg;base64,{data}"
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
Ok(serde_json::json!({
"type": "image_url",
"image_url": {
"url": data_url,
"detail": "auto" // Can be "low", "high", or "auto"
}
}))
}
/// Convert image to Gemini-compatible format
pub async fn to_gemini_format(image: &ImageInput) -> Result<serde_json::Value> {
let (base64_data, mime_type) = image.to_base64().await?;
// Gemini expects inline data format
Ok(serde_json::json!({
"inline_data": {
"mime_type": mime_type,
"data": base64_data
}
}))
}
/// Convert image to DeepSeek-compatible format
pub async fn to_deepseek_format(image: &ImageInput) -> Result<serde_json::Value> {
// DeepSeek uses OpenAI-compatible format for vision models
Self::to_openai_format(image).await
}
/// Detect if a model supports multimodal input
pub fn model_supports_multimodal(model: &str) -> bool {
// OpenAI vision models
if (model.starts_with("gpt-4") && (model.contains("vision") || model.contains("-v") || model.contains("4o")))
|| model.starts_with("o1-")
|| model.starts_with("o3-")
{
return true;
}
// Gemini vision models
if model.starts_with("gemini") {
// Most Gemini models support vision
return true;
}
// DeepSeek vision models
if model.starts_with("deepseek-vl") {
return true;
}
false
}
}
/// Parse OpenAI-compatible multimodal message content
pub fn parse_openai_content(content: &serde_json::Value) -> Result<Vec<(String, Option<ImageInput>)>> {
let mut parts = Vec::new();
if let Some(content_str) = content.as_str() {
// Simple text content
parts.push((content_str.to_string(), None));
} else if let Some(content_array) = content.as_array() {
// Array of content parts (text and/or images)
for part in content_array {
if let Some(part_obj) = part.as_object()
&& let Some(part_type) = part_obj.get("type").and_then(|t| t.as_str())
{
match part_type {
"text" => {
if let Some(text) = part_obj.get("text").and_then(|t| t.as_str()) {
parts.push((text.to_string(), None));
}
}
"image_url" => {
if let Some(image_url_obj) = part_obj.get("image_url").and_then(|o| o.as_object())
&& let Some(url) = image_url_obj.get("url").and_then(|u| u.as_str())
{
if url.starts_with("data:") {
// Parse data URL
if let Some((mime_type, data)) = parse_data_url(url) {
let image_input = ImageInput::from_base64(data, mime_type);
parts.push(("".to_string(), Some(image_input)));
}
} else {
// Regular URL
let image_input = ImageInput::from_url(url.to_string());
parts.push(("".to_string(), Some(image_input)));
}
}
}
_ => {
warn!("Unknown content part type: {}", part_type);
}
}
}
}
}
Ok(parts)
}
/// Parse data URL (data:image/jpeg;base64,{data})
fn parse_data_url(data_url: &str) -> Option<(String, String)> {
if !data_url.starts_with("data:") {
return None;
}
let parts: Vec<&str> = data_url[5..].split(";base64,").collect();
if parts.len() != 2 {
return None;
}
let mime_type = parts[0].to_string();
let data = parts[1].to_string();
Some((mime_type, data))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_data_url() {
let test_url = "data:image/jpeg;base64,SGVsbG8gV29ybGQ="; // "Hello World" in base64
let (mime_type, data) = parse_data_url(test_url).unwrap();
assert_eq!(mime_type, "image/jpeg");
assert_eq!(data, "SGVsbG8gV29ybGQ=");
}
#[test]
fn test_model_supports_multimodal() {
assert!(ImageConverter::model_supports_multimodal("gpt-4-vision-preview"));
assert!(ImageConverter::model_supports_multimodal("gpt-4o"));
assert!(ImageConverter::model_supports_multimodal("gemini-pro-vision"));
assert!(ImageConverter::model_supports_multimodal("gemini-pro"));
assert!(!ImageConverter::model_supports_multimodal("gpt-3.5-turbo"));
assert!(!ImageConverter::model_supports_multimodal("claude-3-opus"));
}
}
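The `data:<mime>;base64,<data>` split in `parse_data_url` can also be expressed with std's `strip_prefix`/`split_once`; a minimal std-only sketch of the same parsing, not the module's exact implementation:

```rust
// Parse "data:<mime>;base64,<data>" into (mime, data), rejecting
// anything without the "data:" scheme or the ";base64," separator.
fn parse_data_url(data_url: &str) -> Option<(String, String)> {
    let rest = data_url.strip_prefix("data:")?;
    let (mime_type, data) = rest.split_once(";base64,")?;
    Some((mime_type.to_string(), data.to_string()))
}

fn main() {
    let (mime, data) = parse_data_url("data:image/png;base64,aGVsbG8=").unwrap();
    assert_eq!(mime, "image/png");
    assert_eq!(data, "aGVsbG8=");
    // Plain URLs are not data URLs
    assert!(parse_data_url("https://example.com/a.png").is_none());
}
```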


@@ -1,268 +0,0 @@
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use futures::StreamExt;
use super::helpers;
use super::{ProviderResponse, ProviderStreamChunk};
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
pub struct DeepSeekProvider {
client: reqwest::Client,
config: crate::config::DeepSeekConfig,
api_key: String,
pricing: Vec<crate::config::ModelPricing>,
}
impl DeepSeekProvider {
pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("deepseek")?;
Self::new_with_key(config, app_config, api_key)
}
pub fn new_with_key(
config: &crate::config::DeepSeekConfig,
app_config: &AppConfig,
api_key: String,
) -> Result<Self> {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.timeout(std::time::Duration::from_secs(300))
.pool_idle_timeout(std::time::Duration::from_secs(90))
.pool_max_idle_per_host(4)
.tcp_keepalive(std::time::Duration::from_secs(30))
.build()?;
Ok(Self {
client,
config: config.clone(),
api_key,
pricing: app_config.pricing.deepseek.clone(),
})
}
}
#[async_trait]
impl super::Provider for DeepSeekProvider {
fn name(&self) -> &str {
"deepseek"
}
fn supports_model(&self, model: &str) -> bool {
model.starts_with("deepseek-") || model.contains("deepseek")
}
fn supports_multimodal(&self) -> bool {
false
}
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
let mut body = helpers::build_openai_body(&request, messages_json, false);
// Sanitize and fix for deepseek-reasoner (R1)
if request.model == "deepseek-reasoner" {
if let Some(obj) = body.as_object_mut() {
// Remove unsupported parameters
obj.remove("temperature");
obj.remove("top_p");
obj.remove("presence_penalty");
obj.remove("frequency_penalty");
obj.remove("logit_bias");
obj.remove("logprobs");
obj.remove("top_logprobs");
// Ensure every assistant message has reasoning_content and string content
if let Some(messages) = obj.get_mut("messages").and_then(|m| m.as_array_mut()) {
for m in messages {
if m["role"].as_str() == Some("assistant") {
// DeepSeek R1 requires reasoning_content for consistency in history.
if m.get("reasoning_content").is_none() || m["reasoning_content"].is_null() {
m["reasoning_content"] = serde_json::json!(" ");
}
// DeepSeek R1 often requires content to be a string, not null/array
if m.get("content").is_none() || m["content"].is_null() || m["content"].is_array() {
m["content"] = serde_json::json!("");
}
}
}
}
}
}
let response = self
.client
.post(format!("{}/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
tracing::error!("DeepSeek API error ({}): {}", status, error_text);
tracing::error!("Offending DeepSeek Request Body: {}", serde_json::to_string(&body).unwrap_or_default());
return Err(AppError::ProviderError(format!("DeepSeek API error ({}): {}", status, error_text)));
}
let resp_json: serde_json::Value = response
.json()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
helpers::parse_openai_response(&resp_json, request.model)
}
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
fn calculate_cost(
&self,
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
cache_read_tokens: u32,
cache_write_tokens: u32,
registry: &crate::models::registry::ModelRegistry,
) -> f64 {
if let Some(metadata) = registry.find_model(model) {
if metadata.cost.is_some() {
return helpers::calculate_cost_with_registry(
model,
prompt_tokens,
completion_tokens,
cache_read_tokens,
cache_write_tokens,
registry,
&self.pricing,
0.28,
0.42,
);
}
}
// Custom DeepSeek fallback that correctly handles cache hits
let (prompt_rate, completion_rate) = self
.pricing
.iter()
.find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
.unwrap_or((0.28, 0.42)); // Default to DeepSeek's current API pricing
let cache_hit_rate = prompt_rate / 10.0; // DeepSeek bills cache hits at ~1/10 the prompt rate
let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
(non_cached_prompt as f64 * prompt_rate / 1_000_000.0)
+ (cache_read_tokens as f64 * cache_hit_rate / 1_000_000.0)
+ (completion_tokens as f64 * completion_rate / 1_000_000.0)
}
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
// DeepSeek doesn't support images in streaming, use text-only
let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?;
let mut body = helpers::build_openai_body(&request, messages_json, true);
// Sanitize and fix for deepseek-reasoner (R1)
if request.model == "deepseek-reasoner" {
if let Some(obj) = body.as_object_mut() {
// Keep stream_options if present (DeepSeek supports include_usage)
// Remove unsupported parameters
obj.remove("temperature");
obj.remove("top_p");
obj.remove("presence_penalty");
obj.remove("frequency_penalty");
obj.remove("logit_bias");
obj.remove("logprobs");
obj.remove("top_logprobs");
// Ensure every assistant message has reasoning_content and string content
if let Some(messages) = obj.get_mut("messages").and_then(|m| m.as_array_mut()) {
for m in messages {
if m["role"].as_str() == Some("assistant") {
// DeepSeek R1 requires reasoning_content for consistency in history.
if m.get("reasoning_content").is_none() || m["reasoning_content"].is_null() {
m["reasoning_content"] = serde_json::json!(" ");
}
// DeepSeek R1 often requires content to be a string, not null/array
if m.get("content").is_none() || m["content"].is_null() || m["content"].is_array() {
m["content"] = serde_json::json!("");
}
}
}
}
}
}
let url = format!("{}/chat/completions", self.config.base_url);
let api_key = self.api_key.clone();
let probe_client = self.client.clone();
let probe_body = body.clone();
let model = request.model.clone();
let es = reqwest_eventsource::EventSource::new(
self.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body),
)
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(reqwest_eventsource::Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: serde_json::Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(p_chunk) = helpers::parse_openai_stream_chunk(&chunk, &model, None) {
yield p_chunk?;
}
}
Ok(_) => continue,
Err(e) => {
// Attempt to probe for the actual error body
let probe_resp = probe_client
.post(&url)
.header("Authorization", format!("Bearer {}", api_key))
.json(&probe_body)
.send()
.await;
match probe_resp {
Ok(r) if !r.status().is_success() => {
let status = r.status();
let error_body = r.text().await.unwrap_or_default();
tracing::error!("DeepSeek Stream Error Probe ({}): {}", status, error_body);
// Log the offending request body at ERROR level so it shows up in standard logs
tracing::error!("Offending DeepSeek Request Body: {}", serde_json::to_string(&probe_body).unwrap_or_default());
Err(AppError::ProviderError(format!("DeepSeek API error ({}): {}", status, error_body)))?;
}
Ok(_) => {
Err(AppError::ProviderError(format!("Stream error (probe returned 200): {}", e)))?;
}
Err(probe_err) => {
tracing::error!("DeepSeek Stream Error Probe failed: {}", probe_err);
Err(AppError::ProviderError(format!("Stream error (probe failed: {}): {}", probe_err, e)))?;
}
}
}
}
}
};
Ok(Box::pin(stream))
}
}
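The fallback branch of `calculate_cost` above bills cached prompt tokens at one tenth of the prompt rate. A worked std-only sketch using the same 0.28/0.42 default rates (illustrative defaults from the provider, not live pricing):

```rust
// DeepSeek fallback cost: non-cached prompt tokens at the full rate,
// cache hits at 1/10 of it, completions at the completion rate.
// Rates are USD per million tokens.
fn deepseek_cost(prompt_tokens: u32, completion_tokens: u32, cache_read_tokens: u32) -> f64 {
    let (prompt_rate, completion_rate) = (0.28, 0.42);
    let cache_hit_rate = prompt_rate / 10.0;
    let non_cached = prompt_tokens.saturating_sub(cache_read_tokens);
    (non_cached as f64 * prompt_rate
        + cache_read_tokens as f64 * cache_hit_rate
        + completion_tokens as f64 * completion_rate)
        / 1_000_000.0
}

fn main() {
    // 1M prompt tokens, 400k of them cache hits, 100k completion tokens:
    // 600k * 0.28 + 400k * 0.028 + 100k * 0.42 = 221_200 micro-dollars
    let cost = deepseek_cost(1_000_000, 100_000, 400_000);
    assert!((cost - 0.2212).abs() < 1e-9);
}
```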

File diff suppressed because it is too large


@@ -1,123 +0,0 @@
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use super::helpers;
use super::{ProviderResponse, ProviderStreamChunk};
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
pub struct GrokProvider {
client: reqwest::Client,
config: crate::config::GrokConfig,
api_key: String,
pricing: Vec<crate::config::ModelPricing>,
}
impl GrokProvider {
pub fn new(config: &crate::config::GrokConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("grok")?;
Self::new_with_key(config, app_config, api_key)
}
pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.timeout(std::time::Duration::from_secs(300))
.pool_idle_timeout(std::time::Duration::from_secs(90))
.pool_max_idle_per_host(4)
.tcp_keepalive(std::time::Duration::from_secs(30))
.build()?;
Ok(Self {
client,
config: config.clone(),
api_key,
pricing: app_config.pricing.grok.clone(),
})
}
}
#[async_trait]
impl super::Provider for GrokProvider {
fn name(&self) -> &str {
"grok"
}
fn supports_model(&self, model: &str) -> bool {
model.starts_with("grok-")
}
fn supports_multimodal(&self) -> bool {
true
}
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
let body = helpers::build_openai_body(&request, messages_json, false);
let response = self
.client
.post(format!("{}/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("Grok API error: {}", error_text)));
}
let resp_json: serde_json::Value = response
.json()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
helpers::parse_openai_response(&resp_json, request.model)
}
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
fn calculate_cost(
&self,
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
cache_read_tokens: u32,
cache_write_tokens: u32,
registry: &crate::models::registry::ModelRegistry,
) -> f64 {
helpers::calculate_cost_with_registry(
model,
prompt_tokens,
completion_tokens,
cache_read_tokens,
cache_write_tokens,
registry,
&self.pricing,
5.0,
15.0,
)
}
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
let body = helpers::build_openai_body(&request, messages_json, true);
let es = reqwest_eventsource::EventSource::new(
self.client
.post(format!("{}/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body),
)
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
Ok(helpers::create_openai_stream(es, request.model, None))
}
}
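The OpenAI-compatible helpers cap `tool_call_id` at 40 characters with a raw byte slice (`&tool_call_id[..40]`), which would panic if byte 40 falls inside a multi-byte character. A std-only sketch of a char-boundary-safe variant (an assumption about hardening, not the helpers' current behavior):

```rust
// Truncate an ID to at most `max` bytes, flooring the cut to a char
// boundary so multi-byte UTF-8 IDs cannot cause a slice panic.
fn truncate_tool_call_id(id: &str, max: usize) -> &str {
    if id.len() <= max {
        return id;
    }
    let mut end = max;
    while !id.is_char_boundary(end) {
        end -= 1;
    }
    &id[..end]
}

fn main() {
    let long = "g".repeat(56); // Gemini-style 56-char signature
    assert_eq!(truncate_tool_call_id(&long, 40).len(), 40);
    assert_eq!(truncate_tool_call_id("short", 40), "short");
    // Multi-byte input: cut lands on a boundary, never mid-character
    let multi = "é".repeat(30); // 60 bytes, 2 bytes per char
    assert!(truncate_tool_call_id(&multi, 41).len() <= 41);
}
```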


@@ -1,454 +0,0 @@
use super::{ProviderResponse, ProviderStreamChunk, StreamUsage};
use crate::errors::AppError;
use crate::models::{ContentPart, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest};
use futures::stream::{BoxStream, StreamExt};
use serde_json::Value;
/// Convert messages to OpenAI-compatible JSON, resolving images asynchronously.
///
/// This avoids the deadlock caused by `futures::executor::block_on` inside a
/// Tokio async context. All image base64 conversions are awaited properly.
/// Handles tool-calling messages: assistant messages with tool_calls, and
/// tool-role messages with tool_call_id/name.
pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result<Vec<serde_json::Value>, AppError> {
let mut result = Vec::new();
for m in messages {
// Tool-role messages: { role: "tool", content: "...", tool_call_id: "...", name: "..." }
if m.role == "tool" {
let text_content = m
.content
.first()
.map(|p| match p {
ContentPart::Text { text } => text.clone(),
ContentPart::Image(_) => "[Image]".to_string(),
})
.unwrap_or_default();
let mut msg = serde_json::json!({
"role": "tool",
"content": text_content
});
if let Some(tool_call_id) = &m.tool_call_id {
// OpenAI and others have a 40-char limit for tool_call_id.
// Gemini signatures (56 chars) must be shortened for compatibility.
let id = if tool_call_id.len() > 40 {
&tool_call_id[..40]
} else {
tool_call_id
};
msg["tool_call_id"] = serde_json::json!(id);
}
if let Some(name) = &m.name {
msg["name"] = serde_json::json!(name);
}
result.push(msg);
continue;
}
// Build content parts for non-tool messages
let mut parts = Vec::new();
for p in &m.content {
match p {
ContentPart::Text { text } => {
parts.push(serde_json::json!({ "type": "text", "text": text }));
}
ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input
.to_base64()
.await
.map_err(|e| AppError::MultimodalError(e.to_string()))?;
parts.push(serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
}));
}
}
}
let mut msg = serde_json::json!({ "role": m.role });
// Include reasoning_content if present (DeepSeek R1/reasoner requires this in history)
if let Some(reasoning) = &m.reasoning_content {
msg["reasoning_content"] = serde_json::json!(reasoning);
}
// For assistant messages with tool_calls, content can be empty string
if let Some(tool_calls) = &m.tool_calls {
// Sanitize tool call IDs for OpenAI compatibility (max 40 chars)
let sanitized_calls: Vec<_> = tool_calls.iter().map(|tc| {
let mut sanitized = tc.clone();
if sanitized.id.len() > 40 {
sanitized.id = sanitized.id[..40].to_string();
}
sanitized
}).collect();
if parts.is_empty() {
msg["content"] = serde_json::json!("");
} else {
msg["content"] = serde_json::json!(parts);
}
msg["tool_calls"] = serde_json::json!(sanitized_calls);
} else {
msg["content"] = serde_json::json!(parts);
}
if let Some(name) = &m.name {
msg["name"] = serde_json::json!(name);
}
result.push(msg);
}
Ok(result)
}
/// Convert messages to OpenAI-compatible JSON, but replace images with a
/// text placeholder "[Image]". Useful for providers that don't support
/// multimodal in streaming mode or at all.
///
/// Handles tool-calling messages identically to `messages_to_openai_json`:
/// assistant messages with `tool_calls`, and tool-role messages with
/// `tool_call_id`/`name`.
pub async fn messages_to_openai_json_text_only(
messages: &[UnifiedMessage],
) -> Result<Vec<serde_json::Value>, AppError> {
let mut result = Vec::new();
for m in messages {
// Tool-role messages: { role: "tool", content: "...", tool_call_id: "...", name: "..." }
if m.role == "tool" {
let text_content = m
.content
.first()
.map(|p| match p {
ContentPart::Text { text } => text.clone(),
ContentPart::Image(_) => "[Image]".to_string(),
})
.unwrap_or_default();
let mut msg = serde_json::json!({
"role": "tool",
"content": text_content
});
if let Some(tool_call_id) = &m.tool_call_id {
// OpenAI and others have a 40-char limit for tool_call_id.
let id = if tool_call_id.len() > 40 {
&tool_call_id[..40]
} else {
tool_call_id
};
msg["tool_call_id"] = serde_json::json!(id);
}
if let Some(name) = &m.name {
msg["name"] = serde_json::json!(name);
}
result.push(msg);
continue;
}
// Build content parts for non-tool messages (images become "[Image]" text)
let mut parts = Vec::new();
for p in &m.content {
match p {
ContentPart::Text { text } => {
parts.push(serde_json::json!({ "type": "text", "text": text }));
}
ContentPart::Image(_) => {
parts.push(serde_json::json!({ "type": "text", "text": "[Image]" }));
}
}
}
let mut msg = serde_json::json!({ "role": m.role });
// Include reasoning_content if present (DeepSeek R1/reasoner requires this in history)
if let Some(reasoning) = &m.reasoning_content {
msg["reasoning_content"] = serde_json::json!(reasoning);
}
// For assistant messages with tool_calls, content can be empty string
if let Some(tool_calls) = &m.tool_calls {
// Sanitize tool call IDs for OpenAI compatibility (max 40 chars)
let sanitized_calls: Vec<_> = tool_calls.iter().map(|tc| {
let mut sanitized = tc.clone();
if sanitized.id.len() > 40 {
sanitized.id = sanitized.id[..40].to_string();
}
sanitized
}).collect();
if parts.is_empty() {
msg["content"] = serde_json::json!("");
} else {
msg["content"] = serde_json::json!(parts);
}
msg["tool_calls"] = serde_json::json!(sanitized_calls);
} else {
msg["content"] = serde_json::json!(parts);
}
if let Some(name) = &m.name {
msg["name"] = serde_json::json!(name);
}
result.push(msg);
}
Ok(result)
}
/// Build an OpenAI-compatible request body from a UnifiedRequest and pre-converted messages.
/// Includes tools and tool_choice when present.
/// When streaming, adds `stream_options.include_usage: true` so providers report
/// token counts in the final SSE chunk.
pub fn build_openai_body(
request: &UnifiedRequest,
messages_json: Vec<serde_json::Value>,
stream: bool,
) -> serde_json::Value {
let mut body = serde_json::json!({
"model": request.model,
"messages": messages_json,
"stream": stream,
});
if stream {
body["stream_options"] = serde_json::json!({ "include_usage": true });
}
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
if let Some(tools) = &request.tools {
body["tools"] = serde_json::json!(tools);
}
if let Some(tool_choice) = &request.tool_choice {
body["tool_choice"] = serde_json::json!(tool_choice);
}
body
}
/// Parse an OpenAI-compatible chat completion response JSON into a ProviderResponse.
/// Extracts tool_calls from the message when present.
/// Extracts cache token counts from:
/// - OpenAI/Grok: `usage.prompt_tokens_details.cached_tokens`
/// - DeepSeek: `usage.prompt_cache_hit_tokens` / `usage.prompt_cache_miss_tokens`
pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<ProviderResponse, AppError> {
let choice = resp_json["choices"]
.get(0)
.ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
let message = &choice["message"];
let content = message["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
// Parse tool_calls from the response message
let tool_calls: Option<Vec<ToolCall>> = message
.get("tool_calls")
.and_then(|tc| serde_json::from_value(tc.clone()).ok());
let usage = &resp_json["usage"];
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
// Extract reasoning tokens
let reasoning_tokens = usage["completion_tokens_details"]["reasoning_tokens"]
.as_u64()
.unwrap_or(0) as u32;
// Extract cache tokens — try OpenAI/Grok format first, then DeepSeek format
let cache_read_tokens = usage["prompt_tokens_details"]["cached_tokens"]
.as_u64()
// DeepSeek uses a different field name
.or_else(|| usage["prompt_cache_hit_tokens"].as_u64())
.unwrap_or(0) as u32;
// DeepSeek reports prompt_cache_miss_tokens which are just regular non-cached tokens.
// They do not incur a separate cache_write fee, so we don't map them here to avoid double-charging.
let cache_write_tokens = 0;
Ok(ProviderResponse {
content,
reasoning_content,
tool_calls,
prompt_tokens,
completion_tokens,
reasoning_tokens,
total_tokens,
cache_read_tokens,
cache_write_tokens,
model,
})
}
/// Parse a single OpenAI-compatible stream chunk into a ProviderStreamChunk.
/// Returns None if the chunk carries nothing to yield (no choices and no usage data).
pub fn parse_openai_stream_chunk(
chunk: &Value,
model: &str,
reasoning_field: Option<&'static str>,
) -> Option<Result<ProviderStreamChunk, AppError>> {
// Parse usage from the final chunk (sent when stream_options.include_usage is true).
// This chunk may have an empty `choices` array.
let stream_usage = chunk.get("usage").and_then(|u| {
if u.is_null() {
return None;
}
let prompt_tokens = u["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = u["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = u["total_tokens"].as_u64().unwrap_or(0) as u32;
let reasoning_tokens = u["completion_tokens_details"]["reasoning_tokens"]
.as_u64()
.unwrap_or(0) as u32;
let cache_read_tokens = u["prompt_tokens_details"]["cached_tokens"]
.as_u64()
.or_else(|| u["prompt_cache_hit_tokens"].as_u64())
.unwrap_or(0) as u32;
let cache_write_tokens = 0;
Some(StreamUsage {
prompt_tokens,
completion_tokens,
reasoning_tokens,
total_tokens,
cache_read_tokens,
cache_write_tokens,
})
});
if let Some(choice) = chunk["choices"].get(0) {
let delta = &choice["delta"];
let content = delta["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = delta["reasoning_content"]
.as_str()
.or_else(|| reasoning_field.and_then(|f| delta[f].as_str()))
.map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
// Parse tool_calls deltas from the stream chunk
let tool_calls: Option<Vec<ToolCallDelta>> = delta
.get("tool_calls")
.and_then(|tc| serde_json::from_value(tc.clone()).ok());
Some(Ok(ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
tool_calls,
model: model.to_string(),
usage: stream_usage,
}))
} else if stream_usage.is_some() {
// Final usage-only chunk (empty choices array) — yield it so
// AggregatingStream can capture the real token counts.
Some(Ok(ProviderStreamChunk {
content: String::new(),
reasoning_content: None,
finish_reason: None,
tool_calls: None,
model: model.to_string(),
usage: stream_usage,
}))
} else {
None
}
}
/// Create an SSE stream that parses OpenAI-compatible streaming chunks.
///
/// The optional `reasoning_field` allows overriding the field name for
/// reasoning content (e.g., "thought" for Ollama).
/// Parses tool_calls deltas from streaming chunks when present.
/// When `stream_options.include_usage: true` was sent, the provider sends a
/// final chunk with `usage` data — this is parsed into `StreamUsage` and
/// attached to the yielded `ProviderStreamChunk`.
pub fn create_openai_stream(
es: reqwest_eventsource::EventSource,
model: String,
reasoning_field: Option<&'static str>,
) -> BoxStream<'static, Result<ProviderStreamChunk, AppError>> {
use reqwest_eventsource::Event;
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(p_chunk) = parse_openai_stream_chunk(&chunk, &model, reasoning_field) {
yield p_chunk?;
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Box::pin(stream)
}
/// Calculate cost using the model registry first, then falling back to provider pricing config.
///
/// When the registry provides `cache_read` / `cache_write` rates, the formula is:
/// (prompt_tokens - cache_read_tokens) * input_rate
/// + cache_read_tokens * cache_read_rate
/// + cache_write_tokens * cache_write_rate (if applicable)
/// + completion_tokens * output_rate
///
/// All rates are per-token (the registry stores per-million-token rates).
pub fn calculate_cost_with_registry(
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
cache_read_tokens: u32,
cache_write_tokens: u32,
registry: &crate::models::registry::ModelRegistry,
pricing: &[crate::config::ModelPricing],
default_prompt_rate: f64,
default_completion_rate: f64,
) -> f64 {
if let Some(metadata) = registry.find_model(model)
&& let Some(cost) = &metadata.cost
{
let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
let mut total = (non_cached_prompt as f64 * cost.input / 1_000_000.0)
+ (completion_tokens as f64 * cost.output / 1_000_000.0);
if let Some(cache_read_rate) = cost.cache_read {
total += cache_read_tokens as f64 * cache_read_rate / 1_000_000.0;
} else {
// No cache_read rate — charge cached tokens at full input rate
total += cache_read_tokens as f64 * cost.input / 1_000_000.0;
}
if let Some(cache_write_rate) = cost.cache_write {
total += cache_write_tokens as f64 * cache_write_rate / 1_000_000.0;
}
return total;
}
// Fallback: no registry entry — use provider pricing config (no cache awareness)
let (prompt_rate, completion_rate) = pricing
.iter()
.find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
.unwrap_or((default_prompt_rate, default_completion_rate));
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
}
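The cache-aware formula above can be sketched as a standalone function. This is a minimal illustration of the registry-backed branch only (the non-cached/cached split), with made-up per-million rates; it is not the project's actual pricing.

```rust
/// Sketch of the cache-aware pricing formula: non-cached prompt tokens at the
/// input rate, cached reads at a discounted rate, completions at the output
/// rate. All rates are illustrative, expressed per million tokens.
fn cache_aware_cost(
    prompt_tokens: u32,
    completion_tokens: u32,
    cache_read_tokens: u32,
    input_rate: f64,      // $/1M non-cached prompt tokens
    output_rate: f64,     // $/1M completion tokens
    cache_read_rate: f64, // $/1M cached prompt tokens
) -> f64 {
    let non_cached = prompt_tokens.saturating_sub(cache_read_tokens) as f64;
    non_cached * input_rate / 1_000_000.0
        + cache_read_tokens as f64 * cache_read_rate / 1_000_000.0
        + completion_tokens as f64 * output_rate / 1_000_000.0
}

fn main() {
    // 10k prompt tokens, 4k of them served from cache, 500 completion tokens.
    let cost = cache_aware_cost(10_000, 500, 4_000, 0.15, 0.60, 0.075);
    // 6000*0.15/1M + 4000*0.075/1M + 500*0.60/1M = 0.0009 + 0.0003 + 0.0003
    assert!((cost - 0.0015).abs() < 1e-12);
}
```

Note that `saturating_sub` also guards against providers that report `cached_tokens` larger than `prompt_tokens`, which would otherwise underflow.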


@@ -1,365 +0,0 @@
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use sqlx::Row;
use std::sync::Arc;
use crate::errors::AppError;
use crate::models::UnifiedRequest;
pub mod deepseek;
pub mod gemini;
pub mod grok;
pub mod helpers;
pub mod ollama;
pub mod openai;
#[async_trait]
pub trait Provider: Send + Sync {
/// Get provider name (e.g., "openai", "gemini")
fn name(&self) -> &str;
/// Check if provider supports a specific model
fn supports_model(&self, model: &str) -> bool;
/// Check if provider supports multimodal (images, etc.)
fn supports_multimodal(&self) -> bool;
/// Process a chat completion request
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError>;
/// Process a chat request using provider-specific "responses" style endpoint
/// Default implementation falls back to `chat_completion` for providers
/// that do not implement a dedicated responses endpoint.
async fn chat_responses(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
self.chat_completion(request).await
}
/// Process a streaming chat completion request
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError>;
/// Process a streaming chat request using provider-specific "responses" style endpoint
/// Default implementation falls back to `chat_completion_stream` for providers
/// that do not implement a dedicated responses endpoint.
async fn chat_responses_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
self.chat_completion_stream(request).await
}
/// Estimate token count for a request (for cost calculation)
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32>;
/// Calculate cost based on token usage and model using the registry.
/// `cache_read_tokens` / `cache_write_tokens` allow cache-aware pricing
/// when the registry provides `cache_read` / `cache_write` rates.
fn calculate_cost(
&self,
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
cache_read_tokens: u32,
cache_write_tokens: u32,
registry: &crate::models::registry::ModelRegistry,
) -> f64;
}
pub struct ProviderResponse {
pub content: String,
pub reasoning_content: Option<String>,
pub tool_calls: Option<Vec<crate::models::ToolCall>>,
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub reasoning_tokens: u32,
pub total_tokens: u32,
pub cache_read_tokens: u32,
pub cache_write_tokens: u32,
pub model: String,
}
/// Usage data from the final streaming chunk (when providers report real token counts).
#[derive(Debug, Clone, Default)]
pub struct StreamUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub reasoning_tokens: u32,
pub total_tokens: u32,
pub cache_read_tokens: u32,
pub cache_write_tokens: u32,
}
#[derive(Debug, Clone)]
pub struct ProviderStreamChunk {
pub content: String,
pub reasoning_content: Option<String>,
pub finish_reason: Option<String>,
pub tool_calls: Option<Vec<crate::models::ToolCallDelta>>,
pub model: String,
/// Populated only on the final chunk when providers report usage (e.g. stream_options.include_usage).
pub usage: Option<StreamUsage>,
}
use tokio::sync::RwLock;
use crate::config::AppConfig;
use crate::providers::{
deepseek::DeepSeekProvider, gemini::GeminiProvider, grok::GrokProvider, ollama::OllamaProvider,
openai::OpenAIProvider,
};
#[derive(Clone)]
pub struct ProviderManager {
providers: Arc<RwLock<Vec<Arc<dyn Provider>>>>,
}
impl Default for ProviderManager {
fn default() -> Self {
Self::new()
}
}
impl ProviderManager {
pub fn new() -> Self {
Self {
providers: Arc::new(RwLock::new(Vec::new())),
}
}
/// Initialize a provider by name using config and database overrides
pub async fn initialize_provider(
&self,
name: &str,
app_config: &AppConfig,
db_pool: &crate::database::DbPool,
) -> Result<()> {
// Load override from database
let db_config = sqlx::query("SELECT enabled, base_url, api_key, api_key_encrypted FROM provider_configs WHERE id = ?")
.bind(name)
.fetch_optional(db_pool)
.await?;
let (enabled, base_url, api_key) = if let Some(row) = db_config {
let enabled = row.get::<bool, _>("enabled");
let base_url = row.get::<Option<String>, _>("base_url");
let api_key_encrypted = row.get::<bool, _>("api_key_encrypted");
let api_key = row.get::<Option<String>, _>("api_key");
// Decrypt API key if encrypted
let api_key = match (api_key, api_key_encrypted) {
(Some(key), true) => {
match crate::utils::crypto::decrypt(&key) {
Ok(decrypted) => Some(decrypted),
Err(e) => {
tracing::error!("Failed to decrypt API key for provider {}: {}", name, e);
None
}
}
}
(Some(key), false) => {
// Plaintext key - optionally encrypt and update database (lazy migration)
// For now, just use plaintext
Some(key)
}
(None, _) => None,
};
(enabled, base_url, api_key)
} else {
// No database override, use defaults from AppConfig
match name {
"openai" => (
app_config.providers.openai.enabled,
Some(app_config.providers.openai.base_url.clone()),
None,
),
"gemini" => (
app_config.providers.gemini.enabled,
Some(app_config.providers.gemini.base_url.clone()),
None,
),
"deepseek" => (
app_config.providers.deepseek.enabled,
Some(app_config.providers.deepseek.base_url.clone()),
None,
),
"grok" => (
app_config.providers.grok.enabled,
Some(app_config.providers.grok.base_url.clone()),
None,
),
"ollama" => (
app_config.providers.ollama.enabled,
Some(app_config.providers.ollama.base_url.clone()),
None,
),
_ => (false, None, None),
}
};
if !enabled {
self.remove_provider(name).await;
return Ok(());
}
// Create provider instance with merged config
let provider: Arc<dyn Provider> = match name {
"openai" => {
let mut cfg = app_config.providers.openai.clone();
if let Some(url) = base_url {
cfg.base_url = url;
}
// Handle API key override if present
let p = if let Some(key) = api_key {
// A database-stored key overrides the config/environment key
OpenAIProvider::new_with_key(&cfg, app_config, key)?
} else {
OpenAIProvider::new(&cfg, app_config)?
};
Arc::new(p)
}
"ollama" => {
let mut cfg = app_config.providers.ollama.clone();
if let Some(url) = base_url {
cfg.base_url = url;
}
Arc::new(OllamaProvider::new(&cfg, app_config)?)
}
"gemini" => {
let mut cfg = app_config.providers.gemini.clone();
if let Some(url) = base_url {
cfg.base_url = url;
}
let p = if let Some(key) = api_key {
GeminiProvider::new_with_key(&cfg, app_config, key)?
} else {
GeminiProvider::new(&cfg, app_config)?
};
Arc::new(p)
}
"deepseek" => {
let mut cfg = app_config.providers.deepseek.clone();
if let Some(url) = base_url {
cfg.base_url = url;
}
let p = if let Some(key) = api_key {
DeepSeekProvider::new_with_key(&cfg, app_config, key)?
} else {
DeepSeekProvider::new(&cfg, app_config)?
};
Arc::new(p)
}
"grok" => {
let mut cfg = app_config.providers.grok.clone();
if let Some(url) = base_url {
cfg.base_url = url;
}
let p = if let Some(key) = api_key {
GrokProvider::new_with_key(&cfg, app_config, key)?
} else {
GrokProvider::new(&cfg, app_config)?
};
Arc::new(p)
}
_ => return Err(anyhow::anyhow!("Unknown provider: {}", name)),
};
self.add_provider(provider).await;
Ok(())
}
pub async fn add_provider(&self, provider: Arc<dyn Provider>) {
let mut providers = self.providers.write().await;
// If provider with same name exists, replace it
if let Some(index) = providers.iter().position(|p| p.name() == provider.name()) {
providers[index] = provider;
} else {
providers.push(provider);
}
}
pub async fn remove_provider(&self, name: &str) {
let mut providers = self.providers.write().await;
providers.retain(|p| p.name() != name);
}
pub async fn get_provider_for_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
let providers = self.providers.read().await;
providers.iter().find(|p| p.supports_model(model)).map(Arc::clone)
}
pub async fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
let providers = self.providers.read().await;
providers.iter().find(|p| p.name() == name).map(Arc::clone)
}
pub async fn get_all_providers(&self) -> Vec<Arc<dyn Provider>> {
let providers = self.providers.read().await;
providers.clone()
}
}
// Create placeholder provider implementations
pub mod placeholder {
use super::*;
pub struct PlaceholderProvider {
name: String,
}
impl PlaceholderProvider {
pub fn new(name: &str) -> Self {
Self { name: name.to_string() }
}
}
#[async_trait]
impl Provider for PlaceholderProvider {
fn name(&self) -> &str {
&self.name
}
fn supports_model(&self, _model: &str) -> bool {
false
}
fn supports_multimodal(&self) -> bool {
false
}
async fn chat_completion_stream(
&self,
_request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
Err(AppError::ProviderError(
"Streaming not supported for placeholder provider".to_string(),
))
}
async fn chat_completion(&self, _request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
Err(AppError::ProviderError(format!(
"Provider {} not implemented",
self.name
)))
}
fn estimate_tokens(&self, _request: &UnifiedRequest) -> Result<u32> {
Ok(0)
}
fn calculate_cost(
&self,
_model: &str,
_prompt_tokens: u32,
_completion_tokens: u32,
_cache_read_tokens: u32,
_cache_write_tokens: u32,
_registry: &crate::models::registry::ModelRegistry,
) -> f64 {
0.0
}
}
}
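The replace-or-push behavior of `ProviderManager::add_provider` can be sketched without the trait objects or the `RwLock`. This is a simplified model using plain `(name, base_url)` pairs, not the real types, to show that re-adding a provider keeps its position in the lookup order.

```rust
/// Upsert-by-name semantics modeled on ProviderManager::add_provider:
/// an existing entry with the same name is replaced in place (preserving
/// lookup order); otherwise the new entry is appended.
fn upsert_by_name(providers: &mut Vec<(String, String)>, entry: (String, String)) {
    if let Some(i) = providers.iter().position(|(name, _)| *name == entry.0) {
        providers[i] = entry; // same name: replace, order preserved
    } else {
        providers.push(entry); // new name: append
    }
}

fn main() {
    let mut ps = vec![("openai".to_string(), "https://a".to_string())];
    upsert_by_name(&mut ps, ("ollama".to_string(), "http://local".to_string()));
    upsert_by_name(&mut ps, ("openai".to_string(), "https://b".to_string()));
    assert_eq!(ps.len(), 2); // re-adding "openai" replaced, not duplicated
    assert_eq!(ps[0], ("openai".to_string(), "https://b".to_string()));
}
```

Keeping the original position matters because `get_provider_for_model` returns the first provider whose `supports_model` matches.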


@@ -1,140 +0,0 @@
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use super::helpers;
use super::{ProviderResponse, ProviderStreamChunk};
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
pub struct OllamaProvider {
client: reqwest::Client,
config: crate::config::OllamaConfig,
pricing: Vec<crate::config::ModelPricing>,
}
impl OllamaProvider {
pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.timeout(std::time::Duration::from_secs(300))
.pool_idle_timeout(std::time::Duration::from_secs(90))
.pool_max_idle_per_host(4)
.tcp_keepalive(std::time::Duration::from_secs(30))
.build()?;
Ok(Self {
client,
config: config.clone(),
pricing: app_config.pricing.ollama.clone(),
})
}
}
#[async_trait]
impl super::Provider for OllamaProvider {
fn name(&self) -> &str {
"ollama"
}
fn supports_model(&self, model: &str) -> bool {
self.config.models.iter().any(|m| m == model) || model.starts_with("ollama/")
}
fn supports_multimodal(&self) -> bool {
true
}
async fn chat_completion(&self, mut request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
// Strip "ollama/" prefix if present for the API call
let api_model = request
.model
.strip_prefix("ollama/")
.unwrap_or(&request.model)
.to_string();
let original_model = request.model.clone();
request.model = api_model;
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
let body = helpers::build_openai_body(&request, messages_json, false);
let response = self
.client
.post(format!("{}/chat/completions", self.config.base_url))
.json(&body)
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("Ollama API error: {}", error_text)));
}
let resp_json: serde_json::Value = response
.json()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Ollama also supports "thought" as an alias for reasoning_content
let mut result = helpers::parse_openai_response(&resp_json, original_model)?;
if result.reasoning_content.is_none() {
result.reasoning_content = resp_json["choices"]
.get(0)
.and_then(|c| c["message"]["thought"].as_str())
.map(|s| s.to_string());
}
Ok(result)
}
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
fn calculate_cost(
&self,
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
cache_read_tokens: u32,
cache_write_tokens: u32,
registry: &crate::models::registry::ModelRegistry,
) -> f64 {
helpers::calculate_cost_with_registry(
model,
prompt_tokens,
completion_tokens,
cache_read_tokens,
cache_write_tokens,
registry,
&self.pricing,
0.0,
0.0,
)
}
async fn chat_completion_stream(
&self,
mut request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
let api_model = request
.model
.strip_prefix("ollama/")
.unwrap_or(&request.model)
.to_string();
let original_model = request.model.clone();
request.model = api_model;
let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?;
let body = helpers::build_openai_body(&request, messages_json, true);
let es = reqwest_eventsource::EventSource::new(
self.client
.post(format!("{}/chat/completions", self.config.base_url))
.json(&body),
)
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
// Ollama uses "thought" as an alternative field for reasoning content
Ok(helpers::create_openai_stream(es, original_model, Some("thought")))
}
}
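The prefix handling used in both `chat_completion` and `chat_completion_stream` above can be shown in isolation: the `ollama/` routing prefix is stripped for the upstream API call, while the original name is retained for response attribution. A minimal sketch (the helper name is illustrative, not part of the codebase):

```rust
/// Returns (api_model, original_model): the name sent to the Ollama API with
/// the "ollama/" routing prefix removed, and the caller-facing name unchanged.
fn split_ollama_model(model: &str) -> (String, String) {
    let api_model = model.strip_prefix("ollama/").unwrap_or(model).to_string();
    (api_model, model.to_string())
}

fn main() {
    assert_eq!(
        split_ollama_model("ollama/llama3"),
        ("llama3".to_string(), "ollama/llama3".to_string())
    );
    // Unprefixed names (matched via the config model list) pass through as-is.
    assert_eq!(split_ollama_model("llama3").0, "llama3");
}
```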


@@ -1,940 +0,0 @@
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use futures::StreamExt;
use super::helpers;
use super::{ProviderResponse, ProviderStreamChunk};
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
pub struct OpenAIProvider {
client: reqwest::Client,
config: crate::config::OpenAIConfig,
api_key: String,
pricing: Vec<crate::config::ModelPricing>,
}
impl OpenAIProvider {
pub fn new(config: &crate::config::OpenAIConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("openai")?;
Self::new_with_key(config, app_config, api_key)
}
pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.timeout(std::time::Duration::from_secs(300))
.pool_idle_timeout(std::time::Duration::from_secs(90))
.pool_max_idle_per_host(4)
.tcp_keepalive(std::time::Duration::from_secs(15))
.build()?;
Ok(Self {
client,
config: config.clone(),
api_key,
pricing: app_config.pricing.openai.clone(),
})
}
/// GPT-5.4 models sometimes emit parallel tool calls as a JSON block starting with
/// '{"tool_uses":' inside a text message instead of discrete function_call items.
/// This method attempts to extract and parse such tool calls.
pub fn parse_tool_uses_json(text: &str) -> Vec<crate::models::ToolCall> {
let mut calls = Vec::new();
if let Some(start) = text.find("{\"tool_uses\":") {
// Find the end of the JSON block by matching braces
let sub = &text[start..];
let mut brace_count = 0;
let mut end_idx = 0;
let mut found = false;
for (i, c) in sub.char_indices() {
if c == '{' { brace_count += 1; }
else if c == '}' {
brace_count -= 1;
if brace_count == 0 {
end_idx = i + 1;
found = true;
break;
}
}
}
if found {
let json_str = &sub[..end_idx];
if let Ok(val) = serde_json::from_str::<serde_json::Value>(json_str) {
if let Some(uses) = val.get("tool_uses").and_then(|u| u.as_array()) {
for (idx, u) in uses.iter().enumerate() {
let name = u.get("recipient_name")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
// Strip "functions." prefix if present
.replace("functions.", "");
let arguments = u.get("parameters")
.map(|v| v.to_string())
.unwrap_or_else(|| "{}".to_string());
calls.push(crate::models::ToolCall {
id: format!("call_tu_{}_{}", &uuid::Uuid::new_v4().to_string()[..8], idx),
call_type: "function".to_string(),
function: crate::models::FunctionCall { name, arguments },
});
}
}
}
}
}
calls
}
/// Strips internal metadata prefixes like 'to=multi_tool_use.parallel' from model responses.
pub fn strip_internal_metadata(text: &str) -> String {
let mut result = text.to_string();
// Patterns to strip
let patterns = [
"to=multi_tool_use.parallel",
"to=functions.multi_tool_use",
"ส่งเงินบาทไทยjson", // User reported Thai text preamble
];
for p in patterns {
if let Some(start) = result.find(p) {
// Remove the pattern itself; surrounding whitespace is handled by the final trim()
result.replace_range(start..start + p.len(), "");
}
}
result.trim().to_string()
}
}
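The brace-matching scan inside `parse_tool_uses_json` can be sketched on its own, without the `serde_json` parsing step. This is an illustrative standalone version (the function name is hypothetical); like the original, it counts raw braces and does not account for braces inside JSON string literals.

```rust
/// Find `needle` in `text` and return the balanced `{...}` block starting
/// there, by counting opening and closing braces until depth returns to zero.
/// Mirrors the extraction loop in parse_tool_uses_json; braces embedded in
/// string values are not skipped.
fn balanced_json_block<'a>(text: &'a str, needle: &str) -> Option<&'a str> {
    let start = text.find(needle)?;
    let sub = &text[start..];
    let mut depth = 0i32;
    for (i, c) in sub.char_indices() {
        match c {
            '{' => depth += 1,
            '}' => {
                depth -= 1;
                if depth == 0 {
                    return Some(&sub[..=i]); // include the closing brace
                }
            }
            _ => {}
        }
    }
    None // unbalanced: no complete block found
}

fn main() {
    let text = r#"prefix {"tool_uses":[{"recipient_name":"functions.get_weather"}]} suffix"#;
    let block = balanced_json_block(text, r#"{"tool_uses":"#).unwrap();
    assert_eq!(block, r#"{"tool_uses":[{"recipient_name":"functions.get_weather"}]}"#);
}
```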
#[async_trait]
impl super::Provider for OpenAIProvider {
fn name(&self) -> &str {
"openai"
}
fn supports_model(&self, model: &str) -> bool {
model.starts_with("gpt-") ||
model.starts_with("o1-") ||
model.starts_with("o2-") ||
model.starts_with("o3-") ||
model.starts_with("o4-") ||
model.starts_with("o5-") ||
model.contains("gpt-5")
}
fn supports_multimodal(&self) -> bool {
true
}
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
// Allow proactive routing to Responses API based on heuristic
let model_lc = request.model.to_lowercase();
if model_lc.contains("gpt-5") || model_lc.contains("codex") {
return self.chat_responses(request).await;
}
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
let mut body = helpers::build_openai_body(&request, messages_json, false);
// Transition: Newer OpenAI models (o1, o3, gpt-5) require max_completion_tokens
// instead of the legacy max_tokens parameter.
if request.model.starts_with("o1-") || request.model.starts_with("o3-") || request.model.contains("gpt-5") {
if let Some(max_tokens) = body.as_object_mut().and_then(|obj| obj.remove("max_tokens")) {
body["max_completion_tokens"] = max_tokens;
}
}
let response = self
.client
.post(format!("{}/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
// Read error body to diagnose. If the model requires the Responses
// API (v1/responses), retry against that endpoint.
if error_text.to_lowercase().contains("v1/responses") {
return self.chat_responses(request).await;
}
tracing::error!("OpenAI API error ({}): {}", status, error_text);
return Err(AppError::ProviderError(format!("OpenAI API error ({}): {}", status, error_text)));
}
let resp_json: serde_json::Value = response
.json()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
helpers::parse_openai_response(&resp_json, request.model)
}
async fn chat_responses(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
// Build a structured input for the Responses API.
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
let mut input_parts = Vec::new();
for m in &messages_json {
let role = m["role"].as_str().unwrap_or("user");
if role == "tool" {
input_parts.push(serde_json::json!({
"type": "function_call_output",
"call_id": m.get("tool_call_id").and_then(|v| v.as_str()).unwrap_or(""),
"output": m.get("content").and_then(|v| v.as_str()).unwrap_or("")
}));
continue;
}
if role == "assistant" && m.get("tool_calls").is_some() {
// Push message part if it exists
let content_val = m.get("content").cloned().unwrap_or(serde_json::json!(""));
let has_content = match &content_val {
serde_json::Value::String(s) => !s.is_empty(),
serde_json::Value::Array(a) => !a.is_empty(),
_ => false,
};
if has_content {
let mut content = content_val.clone();
if let Some(text) = content.as_str() {
content = serde_json::json!([{ "type": "output_text", "text": text }]);
} else if let Some(arr) = content.as_array_mut() {
for part in arr {
if let Some(obj) = part.as_object_mut() {
if obj.get("type").and_then(|v| v.as_str()) == Some("text") {
obj.insert("type".to_string(), serde_json::json!("output_text"));
}
}
}
}
input_parts.push(serde_json::json!({
"type": "message",
"role": "assistant",
"content": content
}));
}
// Push tool calls as separate items
if let Some(tcs) = m.get("tool_calls").and_then(|v| v.as_array()) {
for tc in tcs {
input_parts.push(serde_json::json!({
"type": "function_call",
"call_id": tc["id"],
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"]
}));
}
}
continue;
}
let mut mapped_role = role.to_string();
if mapped_role == "system" {
mapped_role = "developer".to_string();
}
let mut content = m.get("content").cloned().unwrap_or(serde_json::json!([]));
// Map content types based on role for Responses API
if let Some(content_array) = content.as_array_mut() {
for part in content_array {
if let Some(part_obj) = part.as_object_mut() {
if let Some(t) = part_obj.get("type").and_then(|v| v.as_str()) {
match t {
"text" => {
let new_type = if mapped_role == "assistant" { "output_text" } else { "input_text" };
part_obj.insert("type".to_string(), serde_json::json!(new_type));
}
"image_url" => {
let new_type = if mapped_role == "assistant" { "output_image" } else { "input_image" };
part_obj.insert("type".to_string(), serde_json::json!(new_type));
if let Some(img_url) = part_obj.remove("image_url") {
part_obj.insert("image".to_string(), img_url);
}
}
_ => {}
}
}
}
}
} else if let Some(text) = content.as_str() {
// If it's just a string, send it as a string instead of an array of objects
// as it's safer for standard conversational messages.
content = serde_json::json!(text);
}
let mut msg_item = serde_json::json!({
"type": "message",
"role": mapped_role,
"content": content
});
if let Some(name) = m.get("name") {
msg_item["name"] = name.clone();
}
input_parts.push(msg_item);
}
let mut body = serde_json::json!({
"model": request.model,
"input": input_parts,
});
// Add standard parameters
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
// Newer models (gpt-5, o1) in Responses API use max_output_tokens
if let Some(max_tokens) = request.max_tokens {
if request.model.contains("gpt-5") || request.model.starts_with("o1-") || request.model.starts_with("o3-") {
body["max_output_tokens"] = serde_json::json!(max_tokens);
} else {
body["max_tokens"] = serde_json::json!(max_tokens);
}
}
if let Some(tools) = &request.tools {
let flattened: Vec<serde_json::Value> = tools.iter().map(|t| {
let mut obj = serde_json::json!({
"type": t.tool_type,
"name": t.function.name,
});
if let Some(desc) = &t.function.description {
obj["description"] = serde_json::json!(desc);
}
if let Some(params) = &t.function.parameters {
obj["parameters"] = params.clone();
}
obj
}).collect();
body["tools"] = serde_json::json!(flattened);
}
if let Some(tool_choice) = &request.tool_choice {
match tool_choice {
crate::models::ToolChoice::Mode(mode) => {
body["tool_choice"] = serde_json::json!(mode);
}
crate::models::ToolChoice::Specific(specific) => {
body["tool_choice"] = serde_json::json!({
"type": specific.choice_type,
"name": specific.function.name,
});
}
}
}
let resp = self
.client
.post(format!("{}/responses", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
if !resp.status().is_success() {
let err = resp.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("OpenAI Responses API error: {}", err)));
}
let resp_json: serde_json::Value = resp.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Try to normalize: if it's chat-style, use existing parser
if resp_json.get("choices").is_some() {
return helpers::parse_openai_response(&resp_json, request.model);
}
// Normalize Responses API output into ProviderResponse
let mut content_text = String::new();
let mut tool_calls = Vec::new();
if let Some(output) = resp_json.get("output").and_then(|o| o.as_array()) {
for out in output {
let item_type = out.get("type").and_then(|v| v.as_str()).unwrap_or("");
match item_type {
"message" => {
if let Some(contents) = out.get("content").and_then(|c| c.as_array()) {
for item in contents {
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
if !content_text.is_empty() { content_text.push_str("\n"); }
content_text.push_str(text);
}
}
}
}
"function_call" => {
let id = out.get("call_id")
.or_else(|| out.get("item_id"))
.or_else(|| out.get("id"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let name = out.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string();
let arguments = out.get("arguments").and_then(|v| v.as_str()).unwrap_or("").to_string();
tool_calls.push(crate::models::ToolCall {
id,
call_type: "function".to_string(),
function: crate::models::FunctionCall { name, arguments },
});
}
_ => {
// Fallback for older/nested structure
if let Some(contents) = out.get("content").and_then(|c| c.as_array()) {
for item in contents {
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
if !content_text.is_empty() { content_text.push_str("\n"); }
content_text.push_str(text);
}
}
}
}
}
}
}
if content_text.is_empty() {
if let Some(cands) = resp_json.get("candidates").and_then(|c| c.as_array()) {
if let Some(c0) = cands.get(0) {
if let Some(content) = c0.get("content") {
if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) {
for p in parts {
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
if !content_text.is_empty() { content_text.push_str("\n"); }
content_text.push_str(t);
}
}
}
}
}
}
}
let prompt_tokens = resp_json.get("usage").and_then(|u| u.get("prompt_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
let completion_tokens = resp_json.get("usage").and_then(|u| u.get("completion_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
let total_tokens = resp_json.get("usage").and_then(|u| u.get("total_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
// GPT-5.4 parallel tool calls might be embedded in content_text as a JSON block
let embedded_calls = Self::parse_tool_uses_json(&content_text);
if !embedded_calls.is_empty() {
// Strip the JSON part from content_text to keep it clean
if let Some(start) = content_text.find("{\"tool_uses\":") {
content_text = content_text[..start].to_string();
}
tool_calls.extend(embedded_calls);
}
content_text = Self::strip_internal_metadata(&content_text);
Ok(ProviderResponse {
content: content_text,
reasoning_content: None,
tool_calls: if tool_calls.is_empty() { None } else { Some(tool_calls) },
prompt_tokens,
completion_tokens,
reasoning_tokens: 0,
total_tokens,
cache_read_tokens: 0,
cache_write_tokens: 0,
model: request.model,
})
}
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
fn calculate_cost(
&self,
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
cache_read_tokens: u32,
cache_write_tokens: u32,
registry: &crate::models::registry::ModelRegistry,
) -> f64 {
helpers::calculate_cost_with_registry(
model,
prompt_tokens,
completion_tokens,
cache_read_tokens,
cache_write_tokens,
registry,
&self.pricing,
0.15,
0.60,
)
}
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
// Allow proactive routing to Responses API based on heuristic
let model_lc = request.model.to_lowercase();
if model_lc.contains("gpt-5") || model_lc.contains("codex") {
return self.chat_responses_stream(request).await;
}
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
let mut body = helpers::build_openai_body(&request, messages_json, true);
// Standard OpenAI cleanup
if let Some(obj) = body.as_object_mut() {
// stream_options.include_usage is supported by OpenAI for token usage in streaming
// Transition: Newer OpenAI models (o1, o3, gpt-5) require max_completion_tokens
if request.model.starts_with("o1-") || request.model.starts_with("o3-") || request.model.contains("gpt-5") {
if let Some(max_tokens) = obj.remove("max_tokens") {
obj.insert("max_completion_tokens".to_string(), max_tokens);
}
}
}
let url = format!("{}/chat/completions", self.config.base_url);
let api_key = self.api_key.clone();
let probe_client = self.client.clone();
let probe_body = body.clone();
let model = request.model.clone();
let es = reqwest_eventsource::EventSource::new(
self.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body),
)
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(reqwest_eventsource::Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: serde_json::Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(p_chunk) = helpers::parse_openai_stream_chunk(&chunk, &model, None) {
yield p_chunk?;
}
}
Ok(_) => continue,
Err(e) => {
// Attempt to probe for the actual error body
let probe_resp = probe_client
.post(&url)
.header("Authorization", format!("Bearer {}", api_key))
.json(&probe_body)
.send()
.await;
match probe_resp {
Ok(r) if !r.status().is_success() => {
let status = r.status();
let error_body = r.text().await.unwrap_or_default();
tracing::error!("OpenAI Stream Error Probe ({}): {}", status, error_body);
tracing::debug!("Offending OpenAI Request Body: {}", serde_json::to_string(&probe_body).unwrap_or_default());
Err(AppError::ProviderError(format!("OpenAI API error ({}): {}", status, error_body)))?;
}
Ok(_) => {
// Probe returned success? This is unexpected if the original stream failed.
Err(AppError::ProviderError(format!("Stream error (probe returned 200): {}", e)))?;
}
Err(probe_err) => {
// Probe itself failed
tracing::error!("OpenAI Stream Error Probe failed: {}", probe_err);
Err(AppError::ProviderError(format!("Stream error (probe failed: {}): {}", probe_err, e)))?;
}
}
}
}
}
};
Ok(Box::pin(stream))
}
async fn chat_responses_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
// Build a structured input for the Responses API.
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
let mut input_parts = Vec::new();
for m in &messages_json {
let role = m["role"].as_str().unwrap_or("user");
if role == "tool" {
input_parts.push(serde_json::json!({
"type": "function_call_output",
"call_id": m.get("tool_call_id").and_then(|v| v.as_str()).unwrap_or(""),
"output": m.get("content").and_then(|v| v.as_str()).unwrap_or("")
}));
continue;
}
if role == "assistant" && m.get("tool_calls").is_some() {
// Push message part if it exists
let content_val = m.get("content").cloned().unwrap_or(serde_json::json!(""));
let has_content = match &content_val {
    serde_json::Value::Array(arr) => !arr.is_empty(),
    serde_json::Value::String(s) => !s.is_empty(),
    _ => false,
};
if has_content {
let mut content = content_val.clone();
if let Some(text) = content.as_str() {
content = serde_json::json!([{ "type": "output_text", "text": text }]);
} else if let Some(arr) = content.as_array_mut() {
for part in arr {
if let Some(obj) = part.as_object_mut() {
if obj.get("type").and_then(|v| v.as_str()) == Some("text") {
obj.insert("type".to_string(), serde_json::json!("output_text"));
}
}
}
}
input_parts.push(serde_json::json!({
"type": "message",
"role": "assistant",
"content": content
}));
}
// Push tool calls as separate items
if let Some(tcs) = m.get("tool_calls").and_then(|v| v.as_array()) {
for tc in tcs {
input_parts.push(serde_json::json!({
"type": "function_call",
"call_id": tc["id"],
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"]
}));
}
}
continue;
}
let mut mapped_role = role.to_string();
if mapped_role == "system" {
mapped_role = "developer".to_string();
}
let mut content = m.get("content").cloned().unwrap_or(serde_json::json!([]));
// Map content types based on role for Responses API
if let Some(content_array) = content.as_array_mut() {
for part in content_array {
if let Some(part_obj) = part.as_object_mut() {
if let Some(t) = part_obj.get("type").and_then(|v| v.as_str()) {
match t {
"text" => {
let new_type = if mapped_role == "assistant" { "output_text" } else { "input_text" };
part_obj.insert("type".to_string(), serde_json::json!(new_type));
}
"image_url" => {
let new_type = if mapped_role == "assistant" { "output_image" } else { "input_image" };
part_obj.insert("type".to_string(), serde_json::json!(new_type));
if let Some(img_url) = part_obj.remove("image_url") {
part_obj.insert("image".to_string(), img_url);
}
}
_ => {}
}
}
}
}
} else if let Some(text) = content.as_str() {
// If it's just a string, send it as a string instead of an array of objects
// as it's safer for standard conversational messages.
content = serde_json::json!(text);
}
let mut msg_item = serde_json::json!({
"type": "message",
"role": mapped_role,
"content": content
});
if let Some(name) = m.get("name") {
msg_item["name"] = name.clone();
}
input_parts.push(msg_item);
}
let mut body = serde_json::json!({
"model": request.model,
"input": input_parts,
});
// Add standard parameters
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
// Newer models (gpt-5, o1) in Responses API use max_output_tokens
if let Some(max_tokens) = request.max_tokens {
if request.model.contains("gpt-5") || request.model.starts_with("o1-") || request.model.starts_with("o3-") {
body["max_output_tokens"] = serde_json::json!(max_tokens);
} else {
body["max_tokens"] = serde_json::json!(max_tokens);
}
}
if let Some(tools) = &request.tools {
let flattened: Vec<serde_json::Value> = tools.iter().map(|t| {
let mut obj = serde_json::json!({
"type": t.tool_type,
"name": t.function.name,
});
if let Some(desc) = &t.function.description {
obj["description"] = serde_json::json!(desc);
}
if let Some(params) = &t.function.parameters {
obj["parameters"] = params.clone();
}
obj
}).collect();
body["tools"] = serde_json::json!(flattened);
}
if let Some(tool_choice) = &request.tool_choice {
match tool_choice {
crate::models::ToolChoice::Mode(mode) => {
body["tool_choice"] = serde_json::json!(mode);
}
crate::models::ToolChoice::Specific(specific) => {
body["tool_choice"] = serde_json::json!({
"type": specific.choice_type,
"name": specific.function.name,
});
}
}
}
body["stream"] = serde_json::json!(true);
let url = format!("{}/responses", self.config.base_url);
let api_key = self.api_key.clone();
let model = request.model.clone();
let probe_client = self.client.clone();
let probe_body = body.clone();
let es = reqwest_eventsource::EventSource::new(
self.client
.post(&url)
.header("Authorization", format!("Bearer {}", api_key))
.header("Accept", "text/event-stream")
.json(&body),
)
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource for Responses API: {}", e)))?;
let stream = async_stream::try_stream! {
let mut es = es;
let mut content_buffer = String::new();
while let Some(event) = es.next().await {
match event {
Ok(reqwest_eventsource::Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: serde_json::Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse Responses stream chunk: {}", e)))?;
// Try standard OpenAI parsing first (choices/usage)
if let Some(p_chunk) = helpers::parse_openai_stream_chunk(&chunk, &model, None) {
yield p_chunk?;
} else {
// Responses API specific parsing for streaming
let mut finish_reason = None;
let mut tool_calls = None;
let event_type = chunk.get("type").and_then(|v| v.as_str()).unwrap_or("");
match event_type {
"response.output_text.delta" => {
if let Some(delta) = chunk.get("delta").and_then(|v| v.as_str()) {
content_buffer.push_str(delta);
}
}
"response.item.delta" => {
if let Some(delta) = chunk.get("delta") {
let t = delta.get("type").and_then(|v| v.as_str()).unwrap_or("");
if t == "function_call" {
let call_id = delta.get("call_id")
.or_else(|| chunk.get("item_id"))
.and_then(|v| v.as_str());
let name = delta.get("name").and_then(|v| v.as_str());
let arguments = delta.get("arguments").and_then(|v| v.as_str());
tool_calls = Some(vec![crate::models::ToolCallDelta {
index: chunk.get("output_index").and_then(|v| v.as_u64()).unwrap_or(0) as u32,
id: call_id.map(|s| s.to_string()),
call_type: Some("function".to_string()),
function: Some(crate::models::FunctionCallDelta {
name: name.map(|s| s.to_string()),
arguments: arguments.map(|s| s.to_string()),
}),
}]);
} else if t == "message" {
if let Some(text) = delta.get("text").and_then(|v| v.as_str()) {
content_buffer.push_str(text);
}
}
}
}
"response.output_text.done" | "response.item.done" | "response.done" => {
finish_reason = Some("stop".to_string());
}
_ => {}
}
// Process content_buffer to extract embedded tool calls or yield text
if !content_buffer.is_empty() {
// If we see the start of a tool call block, we wait for the full block
if content_buffer.contains("{\"tool_uses\":") {
let embedded_calls = Self::parse_tool_uses_json(&content_buffer);
if !embedded_calls.is_empty() {
if let Some(start) = content_buffer.find("{\"tool_uses\":") {
// Yield text before the JSON block
let preamble = content_buffer[..start].to_string();
let stripped_preamble = Self::strip_internal_metadata(&preamble);
if !stripped_preamble.is_empty() {
yield ProviderStreamChunk {
content: stripped_preamble,
reasoning_content: None,
finish_reason: None,
tool_calls: None,
model: model.clone(),
usage: None,
};
}
// Yield the tool calls
let deltas: Vec<crate::models::ToolCallDelta> = embedded_calls.into_iter().enumerate().map(|(idx, tc)| {
crate::models::ToolCallDelta {
index: idx as u32,
id: Some(tc.id),
call_type: Some("function".to_string()),
function: Some(crate::models::FunctionCallDelta {
name: Some(tc.function.name),
arguments: Some(tc.function.arguments),
}),
}
}).collect();
yield ProviderStreamChunk {
content: String::new(),
reasoning_content: None,
finish_reason: None,
tool_calls: Some(deltas),
model: model.clone(),
usage: None,
};
// Remove the processed part from buffer
// We need to find the end index correctly
let sub = &content_buffer[start..];
let mut brace_count = 0;
let mut end_idx = 0;
for (i, c) in sub.char_indices() {
if c == '{' { brace_count += 1; }
else if c == '}' {
brace_count -= 1;
if brace_count == 0 {
end_idx = start + i + 1;
break;
}
}
}
if end_idx > 0 {
content_buffer = content_buffer[end_idx..].to_string();
} else {
content_buffer.clear();
}
}
}
// If we saw `{"tool_uses":` but no complete block yet, keep buffering (don't yield)
} else if content_buffer.contains("to=multi_tool_use.parallel") {
// Wait for the JSON block that usually follows
} else {
// Standard text, yield and clear buffer
let content = std::mem::take(&mut content_buffer);
let stripped_content = Self::strip_internal_metadata(&content);
if !stripped_content.is_empty() {
yield ProviderStreamChunk {
content: stripped_content,
reasoning_content: None,
finish_reason: None,
tool_calls: None,
model: model.clone(),
usage: None,
};
}
}
}
if finish_reason.is_some() || tool_calls.is_some() {
yield ProviderStreamChunk {
content: String::new(),
reasoning_content: None,
finish_reason,
tool_calls,
model: model.clone(),
usage: None,
};
}
}
}
Ok(_) => continue,
Err(e) => {
// Attempt to probe for the actual error body
let probe_resp = probe_client
.post(&url)
.header("Authorization", format!("Bearer {}", api_key))
.header("Accept", "application/json")
.json(&probe_body)
.send()
.await;
match probe_resp {
Ok(r) => {
let status = r.status();
let body = r.text().await.unwrap_or_default();
if status.is_success() {
// Truncate by characters, not bytes, to avoid panicking on a UTF-8 boundary
let preview: String = body.chars().take(500).collect();
let preview = if body.len() > preview.len() { format!("{}...", preview) } else { preview };
tracing::warn!("Responses stream ended prematurely but probe returned 200 OK. Body: {}", preview);
Err(AppError::ProviderError(format!("Responses stream ended (server sent 200 OK with body: {})", preview)))?;
} else {
tracing::error!("OpenAI Responses Stream Error Probe ({}): {}", status, body);
Err(AppError::ProviderError(format!("OpenAI Responses API error ({}): {}", status, body)))?;
}
}
Err(probe_err) => {
tracing::error!("OpenAI Responses Stream Error Probe failed: {}", probe_err);
Err(AppError::ProviderError(format!("Responses stream error (probe failed: {}): {}", probe_err, e)))?;
}
}
}
}
}
// Final flush of content_buffer if not empty
if !content_buffer.is_empty() {
let stripped = Self::strip_internal_metadata(&content_buffer);
if !stripped.is_empty() {
yield ProviderStreamChunk {
content: stripped,
reasoning_content: None,
finish_reason: None,
tool_calls: None,
model: model.clone(),
usage: None,
};
}
}
};
Ok(Box::pin(stream))
}
}
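The streaming path above scans `content_buffer` for a balanced `{"tool_uses": ...}` block by counting braces before yielding anything. A minimal standalone sketch of that scan (the helper name `json_block_end` is hypothetical; like the inline version, it does not account for braces inside JSON string literals):

```rust
// Find the index one past the closing brace of a JSON object starting at `start`.
// Returns None while the block is still incomplete, so the caller keeps buffering.
fn json_block_end(buf: &str, start: usize) -> Option<usize> {
    let mut depth = 0i32;
    for (i, c) in buf[start..].char_indices() {
        match c {
            '{' => depth += 1,
            '}' => {
                depth -= 1;
                if depth == 0 {
                    return Some(start + i + 1);
                }
            }
            _ => {}
        }
    }
    None
}

fn main() {
    let buf = "preamble {\"tool_uses\": [{\"name\": \"f\"}]} trailing";
    let start = buf.find("{\"tool_uses\":").unwrap();
    let end = json_block_end(buf, start).unwrap();
    // Everything before `start` is preamble text; everything after `end` is trailing text.
    assert_eq!(&buf[end..], " trailing");
}
```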

@@ -1,353 +0,0 @@
//! Rate limiting and circuit breaking for LLM proxy
//!
//! This module provides:
//! 1. Per-client rate limiting using governor crate
//! 2. Provider circuit breaking to handle API failures
//! 3. Global rate limiting for overall system protection
use anyhow::Result;
use governor::{Quota, RateLimiter, DefaultDirectRateLimiter};
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{info, warn};
type GovRateLimiter = DefaultDirectRateLimiter;
/// Rate limiter configuration
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
/// Requests per minute per client
pub requests_per_minute: u32,
/// Burst size (maximum burst capacity)
pub burst_size: u32,
/// Global requests per minute (across all clients)
pub global_requests_per_minute: u32,
}
impl Default for RateLimiterConfig {
fn default() -> Self {
Self {
requests_per_minute: 60, // 1 request per second per client
burst_size: 10, // Allow bursts of up to 10 requests
global_requests_per_minute: 600, // 10 requests per second globally
}
}
}
/// Circuit breaker state
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CircuitState {
Closed, // Normal operation
Open, // Circuit is open, requests fail fast
HalfOpen, // Testing if service has recovered
}
/// Circuit breaker configuration
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
/// Failure threshold to open circuit
pub failure_threshold: u32,
/// Time window for failure counting (seconds)
pub failure_window_secs: u64,
/// Time to wait before trying half-open state (seconds)
pub reset_timeout_secs: u64,
/// Success threshold to close circuit
pub success_threshold: u32,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5, // 5 failures
failure_window_secs: 60, // within 60 seconds
reset_timeout_secs: 30, // wait 30 seconds before half-open
success_threshold: 3, // 3 successes to close circuit
}
}
}
/// Circuit breaker for a provider
#[derive(Debug)]
pub struct ProviderCircuitBreaker {
state: CircuitState,
failure_count: u32,
success_count: u32,
last_failure_time: Option<std::time::Instant>,
last_state_change: std::time::Instant,
config: CircuitBreakerConfig,
}
impl ProviderCircuitBreaker {
pub fn new(config: CircuitBreakerConfig) -> Self {
Self {
state: CircuitState::Closed,
failure_count: 0,
success_count: 0,
last_failure_time: None,
last_state_change: std::time::Instant::now(),
config,
}
}
/// Check if request is allowed
pub fn allow_request(&mut self) -> bool {
match self.state {
CircuitState::Closed => true,
CircuitState::Open => {
// Check if reset timeout has passed
let elapsed = self.last_state_change.elapsed();
if elapsed.as_secs() >= self.config.reset_timeout_secs {
self.state = CircuitState::HalfOpen;
self.last_state_change = std::time::Instant::now();
info!("Circuit breaker transitioning to half-open state");
true
} else {
false
}
}
CircuitState::HalfOpen => true,
}
}
/// Record a successful request
pub fn record_success(&mut self) {
match self.state {
CircuitState::Closed => {
// Reset failure count on success
self.failure_count = 0;
self.last_failure_time = None;
}
CircuitState::HalfOpen => {
self.success_count += 1;
if self.success_count >= self.config.success_threshold {
self.state = CircuitState::Closed;
self.success_count = 0;
self.failure_count = 0;
self.last_state_change = std::time::Instant::now();
info!("Circuit breaker closed after successful requests");
}
}
CircuitState::Open => {
// Should not happen, but handle gracefully
}
}
}
/// Record a failed request
pub fn record_failure(&mut self) {
let now = std::time::Instant::now();
// Check if failure window has expired
if let Some(last_failure) = self.last_failure_time
&& now.duration_since(last_failure).as_secs() > self.config.failure_window_secs
{
// Reset failure count if window expired
self.failure_count = 0;
}
self.failure_count += 1;
self.last_failure_time = Some(now);
if self.failure_count >= self.config.failure_threshold && self.state == CircuitState::Closed {
self.state = CircuitState::Open;
self.last_state_change = now;
warn!("Circuit breaker opened due to {} failures", self.failure_count);
} else if self.state == CircuitState::HalfOpen {
// Failure in half-open state, go back to open
self.state = CircuitState::Open;
self.success_count = 0;
self.last_state_change = now;
warn!("Circuit breaker re-opened after failure in half-open state");
}
}
/// Get current state
pub fn state(&self) -> CircuitState {
self.state
}
}
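The transitions above can be condensed into a dependency-free sketch of the failure-threshold behavior (timing windows and the half-open probe are omitted for brevity; `MiniBreaker` is illustrative only, the real type is `ProviderCircuitBreaker`):

```rust
#[derive(Clone, Copy, PartialEq, Debug)]
enum MiniState { Closed, Open }

struct MiniBreaker {
    state: MiniState,
    failures: u32,
    threshold: u32,
}

impl MiniBreaker {
    fn new(threshold: u32) -> Self {
        Self { state: MiniState::Closed, failures: 0, threshold }
    }
    fn record_failure(&mut self) {
        self.failures += 1;
        if self.failures >= self.threshold && self.state == MiniState::Closed {
            self.state = MiniState::Open; // fail fast from now on
        }
    }
    fn record_success(&mut self) {
        self.failures = 0; // success resets the failure count
    }
    fn allow_request(&self) -> bool {
        self.state == MiniState::Closed
    }
}

fn main() {
    let mut b = MiniBreaker::new(5);
    for _ in 0..4 { b.record_failure(); }
    assert!(b.allow_request()); // below threshold: still closed
    b.record_failure();
    assert!(!b.allow_request()); // fifth failure opens the circuit
}
```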
/// Rate limiting and circuit breaking manager
#[derive(Debug)]
pub struct RateLimitManager {
client_buckets: Arc<RwLock<HashMap<String, GovRateLimiter>>>,
global_bucket: Arc<GovRateLimiter>,
circuit_breakers: Arc<RwLock<HashMap<String, ProviderCircuitBreaker>>>,
config: RateLimiterConfig,
circuit_config: CircuitBreakerConfig,
}
impl RateLimitManager {
pub fn new(config: RateLimiterConfig, circuit_config: CircuitBreakerConfig) -> Self {
// Create global rate limiter quota
// Use a much larger burst size for the global bucket to handle concurrent dashboard load
let global_burst = config.global_requests_per_minute / 6; // e.g., 100 for 600 req/min
let global_quota = Quota::per_minute(
NonZeroU32::new(config.global_requests_per_minute).expect("global_requests_per_minute must be positive")
)
.allow_burst(NonZeroU32::new(global_burst).expect("global_burst must be positive"));
let global_bucket = RateLimiter::direct(global_quota);
Self {
client_buckets: Arc::new(RwLock::new(HashMap::new())),
global_bucket: Arc::new(global_bucket),
circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
config,
circuit_config,
}
}
/// Check if a client request is allowed
pub async fn check_client_request(&self, client_id: &str) -> Result<bool> {
// Check global rate limit first (1 token per request)
if self.global_bucket.check().is_err() {
warn!("Global rate limit exceeded");
return Ok(false);
}
// Check client-specific rate limit
let mut buckets = self.client_buckets.write().await;
let bucket = buckets.entry(client_id.to_string()).or_insert_with(|| {
let quota = Quota::per_minute(
NonZeroU32::new(self.config.requests_per_minute).expect("requests_per_minute must be positive")
)
.allow_burst(NonZeroU32::new(self.config.burst_size).expect("burst_size must be positive"));
RateLimiter::direct(quota)
});
Ok(bucket.check().is_ok())
}
/// Check if provider requests are allowed (circuit breaker)
pub async fn check_provider_request(&self, provider_name: &str) -> Result<bool> {
let mut breakers = self.circuit_breakers.write().await;
let breaker = breakers
.entry(provider_name.to_string())
.or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone()));
Ok(breaker.allow_request())
}
/// Record provider success
pub async fn record_provider_success(&self, provider_name: &str) {
let mut breakers = self.circuit_breakers.write().await;
if let Some(breaker) = breakers.get_mut(provider_name) {
breaker.record_success();
}
}
/// Record provider failure
pub async fn record_provider_failure(&self, provider_name: &str) {
let mut breakers = self.circuit_breakers.write().await;
let breaker = breakers
.entry(provider_name.to_string())
.or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone()));
breaker.record_failure();
}
/// Get provider circuit state
pub async fn get_provider_state(&self, provider_name: &str) -> CircuitState {
let breakers = self.circuit_breakers.read().await;
breakers
.get(provider_name)
.map(|b| b.state())
.unwrap_or(CircuitState::Closed)
}
}
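The per-minute quota with a burst cap that `RateLimitManager` delegates to the `governor` crate can be illustrated with a simplified stdlib-only token bucket. This is a sketch of the semantics only: `governor` actually implements GCRA, not an explicit counter:

```rust
struct Bucket {
    tokens: f64,
    burst: f64,
    refill_per_sec: f64,
}

impl Bucket {
    fn new(per_minute: u32, burst: u32) -> Self {
        Self {
            tokens: burst as f64, // start full, allowing an initial burst
            burst: burst as f64,
            refill_per_sec: per_minute as f64 / 60.0,
        }
    }
    // Advance time by `secs`, refill (capped at burst), then try to take one token.
    fn check_after(&mut self, secs: f64) -> bool {
        self.tokens = (self.tokens + secs * self.refill_per_sec).min(self.burst);
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}

fn main() {
    // 60 req/min with a burst of 2: two immediate requests pass,
    // the third must wait ~1s for a refill.
    let mut b = Bucket::new(60, 2);
    assert!(b.check_after(0.0));
    assert!(b.check_after(0.0));
    assert!(!b.check_after(0.0));
    assert!(b.check_after(1.0));
}
```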
/// Axum middleware for rate limiting
pub mod middleware {
use super::*;
use crate::errors::AppError;
use crate::state::AppState;
use crate::auth::AuthInfo;
use axum::{
extract::{Request, State},
middleware::Next,
response::Response,
};
use sqlx;
/// Rate limiting middleware
pub async fn rate_limit_middleware(
State(state): State<AppState>,
mut request: Request,
next: Next,
) -> Result<Response, AppError> {
// Extract token synchronously from headers (avoids holding &Request across await)
let token = extract_bearer_token(&request);
// Resolve client_id and populate AuthInfo: DB token lookup, then prefix fallback
let auth_info = resolve_auth_info(token, &state).await;
let client_id = auth_info.client_id.clone();
// Check rate limits
if !state.rate_limit_manager.check_client_request(&client_id).await? {
return Err(AppError::RateLimitError("Rate limit exceeded".to_string()));
}
// Store AuthInfo in request extensions for extractors and downstream handlers
request.extensions_mut().insert(auth_info);
Ok(next.run(request).await)
}
/// Synchronously extract bearer token from request headers
fn extract_bearer_token(request: &Request) -> Option<String> {
request.headers().get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.strip_prefix("Bearer "))
.map(|t| t.to_string())
}
/// Resolve auth info: try DB token first, then fall back to token-prefix derivation
async fn resolve_auth_info(token: Option<String>, state: &AppState) -> AuthInfo {
if let Some(token) = token {
// Try DB token lookup first
match sqlx::query_scalar::<_, String>(
"UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = TRUE RETURNING client_id",
)
.bind(&token)
.fetch_optional(&state.db_pool)
.await
{
Ok(Some(cid)) => {
return AuthInfo {
token,
client_id: cid,
};
}
Err(e) => {
warn!("DB error during token lookup: {}", e);
}
_ => {}
}
// Fallback to token-prefix derivation (env tokens / permissive mode)
let client_id = format!("client_{}", &token[..8.min(token.len())]);
return AuthInfo { token, client_id };
}
// No token — anonymous
AuthInfo {
token: String::new(),
client_id: "anonymous".to_string(),
}
}
/// Circuit breaker middleware for provider requests
pub async fn circuit_breaker_middleware(provider_name: &str, state: &AppState) -> Result<(), AppError> {
if !state.rate_limit_manager.check_provider_request(provider_name).await? {
return Err(AppError::ProviderError(format!(
"Provider {} is currently unavailable (circuit breaker open)",
provider_name
)));
}
Ok(())
}
}
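The header parsing and prefix-derived fallback `client_id` in the middleware above can be exercised in isolation (values here are made up for illustration):

```rust
// Extract the bearer token from an Authorization header value, if present.
fn bearer_token(header: Option<&str>) -> Option<String> {
    header?.strip_prefix("Bearer ").map(str::to_string)
}

// Fallback client id from the first 8 bytes of the token; byte indexing is
// safe here only because the proxy's tokens are ASCII.
fn fallback_client_id(token: &str) -> String {
    format!("client_{}", &token[..8.min(token.len())])
}

fn main() {
    let t = bearer_token(Some("Bearer sk-abcdef1234")).unwrap();
    assert_eq!(t, "sk-abcdef1234");
    assert_eq!(fallback_client_id(&t), "client_sk-abcde");
    assert_eq!(bearer_token(None), None);
}
```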

@@ -1,482 +0,0 @@
use axum::{
Json, Router,
extract::State,
response::IntoResponse,
response::sse::{Event, Sse},
routing::{get, post},
};
use axum::http::{header, HeaderValue};
use tower_http::{
limit::RequestBodyLimitLayer,
set_header::SetResponseHeaderLayer,
};
use futures::StreamExt;
use std::sync::Arc;
use uuid::Uuid;
use tracing::{info, warn};
use crate::{
auth::AuthenticatedClient,
errors::AppError,
models::{
ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage,
ChatStreamChoice, ChatStreamDelta, Usage,
},
rate_limiting,
state::AppState,
};
pub fn router(state: AppState) -> Router {
// Security headers
let csp_header: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::CONTENT_SECURITY_POLICY,
"default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws:;"
.parse()
.unwrap(),
);
let x_frame_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_FRAME_OPTIONS,
"DENY".parse().unwrap(),
);
let x_content_type_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_CONTENT_TYPE_OPTIONS,
"nosniff".parse().unwrap(),
);
let strict_transport_security: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::STRICT_TRANSPORT_SECURITY,
"max-age=31536000; includeSubDomains".parse().unwrap(),
);
Router::new()
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/models", get(list_models))
.layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit
.layer(csp_header)
.layer(x_frame_options)
.layer(x_content_type_options)
.layer(strict_transport_security)
.layer(axum::middleware::from_fn_with_state(
state.clone(),
rate_limiting::middleware::rate_limit_middleware,
))
.with_state(state)
}
/// GET /v1/models — OpenAI-compatible model listing.
/// Returns all models from enabled providers so clients like Open WebUI can
/// discover which models are available through the proxy.
async fn list_models(
State(state): State<AppState>,
_auth: AuthenticatedClient,
) -> Result<Json<serde_json::Value>, AppError> {
let registry = &state.model_registry;
let providers = state.provider_manager.get_all_providers().await;
let mut models = Vec::new();
for provider in &providers {
let provider_name = provider.name();
// Map internal provider names to registry provider IDs
let registry_key = match provider_name {
"gemini" => "google",
"grok" => "xai",
_ => provider_name,
};
// Find this provider's models in the registry
if let Some(provider_info) = registry.providers.get(registry_key) {
for (model_id, meta) in &provider_info.models {
// Skip disabled models via the config cache
if let Some(cfg) = state.model_config_cache.get(model_id).await {
if !cfg.enabled {
continue;
}
}
models.push(serde_json::json!({
"id": model_id,
"object": "model",
"created": 0,
"owned_by": provider_name,
"name": meta.name,
}));
}
}
// For Ollama, models are configured in the TOML, not the registry
if provider_name == "ollama" {
for model_id in &state.config.providers.ollama.models {
models.push(serde_json::json!({
"id": model_id,
"object": "model",
"created": 0,
"owned_by": "ollama",
}));
}
}
}
Ok(Json(serde_json::json!({
"object": "list",
"data": models
})))
}
async fn get_model_cost(
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
cache_read_tokens: u32,
cache_write_tokens: u32,
provider: &Arc<dyn crate::providers::Provider>,
state: &AppState,
) -> f64 {
// Check in-memory cache for cost overrides (no SQLite hit)
if let Some(cached) = state.model_config_cache.get(model).await {
if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
// Manual overrides logic: if cache rates are provided, use cache-aware formula.
// Formula: (non_cached_prompt * input_rate) + (cache_read * read_rate) + (cache_write * write_rate) + (completion * output_rate)
let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
let mut total = (non_cached_prompt as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0);
if let Some(cr) = cached.cache_read_cost_per_m {
total += cache_read_tokens as f64 * cr / 1_000_000.0;
} else {
// No manual cache_read rate — charge cached tokens at full input rate (backwards compatibility)
total += cache_read_tokens as f64 * p / 1_000_000.0;
}
if let Some(cw) = cached.cache_write_cost_per_m {
total += cache_write_tokens as f64 * cw / 1_000_000.0;
}
return total;
}
}
// Fallback to provider's registry-based calculation (cache-aware)
provider.calculate_cost(model, prompt_tokens, completion_tokens, cache_read_tokens, cache_write_tokens, &state.model_registry)
}
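A worked example of the cache-aware override formula in `get_model_cost` above, pulled out as a free function. The rates are hypothetical ($/1M tokens), chosen only to make the arithmetic visible:

```rust
fn override_cost(
    prompt: u32, completion: u32, cache_read: u32, cache_write: u32,
    p: f64, c: f64, cr: Option<f64>, cw: Option<f64>,
) -> f64 {
    // Cached reads are excluded from the full-rate prompt portion...
    let non_cached = prompt.saturating_sub(cache_read) as f64;
    let mut total = non_cached * p / 1e6 + completion as f64 * c / 1e6;
    // ...then billed at the discounted rate when set, else at the full input rate.
    total += cache_read as f64 * cr.unwrap_or(p) / 1e6;
    if let Some(cw) = cw {
        total += cache_write as f64 * cw / 1e6;
    }
    total
}

fn main() {
    // 1M prompt tokens (400k of them cached), 100k completion tokens,
    // at $3 in / $15 out per 1M with a $0.30 cache-read rate:
    // 600k * $3 + 100k * $15 + 400k * $0.30 = 1.8 + 1.5 + 0.12
    let cost = override_cost(1_000_000, 100_000, 400_000, 0, 3.0, 15.0, Some(0.30), None);
    assert!((cost - 3.42).abs() < 1e-9);
}
```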
async fn chat_completions(
State(state): State<AppState>,
auth: AuthenticatedClient,
Json(mut request): Json<ChatCompletionRequest>,
) -> Result<axum::response::Response, AppError> {
let client_id = auth.client_id.clone();
let token = auth.token.clone();
// Verify token if env tokens are configured
if !state.auth_tokens.is_empty() && !state.auth_tokens.contains(&token) {
// If not in env tokens, check if it was a DB token (client_id wouldn't be client_XXXX prefix)
if client_id.starts_with("client_") {
return Err(AppError::AuthError("Invalid authentication token".to_string()));
}
}
let start_time = std::time::Instant::now();
let model = request.model.clone();
info!("Chat completion request from client {} for model {}", client_id, model);
// Check if model is enabled via in-memory cache (no SQLite hit)
let cached_config = state.model_config_cache.get(&model).await;
let (model_enabled, model_mapping) = match cached_config {
Some(cfg) => (cfg.enabled, cfg.mapping),
None => (true, None),
};
if !model_enabled {
return Err(AppError::ValidationError(format!(
"Model {} is currently disabled",
model
)));
}
// Apply mapping if present
if let Some(target_model) = model_mapping {
info!("Mapping model {} to {}", model, target_model);
request.model = target_model;
}
// Find appropriate provider for the model
let provider = state
.provider_manager
.get_provider_for_model(&request.model)
.await
.ok_or_else(|| AppError::ProviderError(format!("No provider found for model: {}", request.model)))?;
let provider_name = provider.name().to_string();
// Check circuit breaker for this provider
rate_limiting::middleware::circuit_breaker_middleware(&provider_name, &state).await?;
// Convert to unified request format
let mut unified_request =
crate::models::UnifiedRequest::try_from(request).map_err(|e| AppError::ValidationError(e.to_string()))?;
// Set client_id from authentication
unified_request.client_id = client_id.clone();
// Hydrate images if present
if unified_request.has_images {
unified_request
.hydrate_images()
.await
.map_err(|e| AppError::ValidationError(format!("Failed to process images: {}", e)))?;
}
let has_images = unified_request.has_images;
// Measure proxy overhead (time spent before sending to upstream provider)
let proxy_overhead = start_time.elapsed();
// Check if streaming is requested
if unified_request.stream {
// Estimate prompt tokens for logging later
let prompt_tokens = crate::utils::tokens::estimate_request_tokens(&model, &unified_request);
// Handle streaming response
// Allow provider-specific routing for streaming too
let use_responses = provider.name() == "openai"
&& crate::utils::registry::model_prefers_responses(&state.model_registry, &unified_request.model);
let stream_result = if use_responses {
provider.chat_responses_stream(unified_request).await
} else {
provider.chat_completion_stream(unified_request).await
};
match stream_result {
Ok(stream) => {
// Record provider success
state.rate_limit_manager.record_provider_success(&provider_name).await;
info!(
"Streaming started for {} (proxy overhead: {}ms)",
model,
proxy_overhead.as_millis()
);
// Wrap with AggregatingStream for token counting and database logging
let aggregating_stream = crate::utils::streaming::AggregatingStream::new(
stream,
crate::utils::streaming::StreamConfig {
client_id: client_id.clone(),
provider: provider.clone(),
model: model.clone(),
prompt_tokens,
has_images,
logger: state.request_logger.clone(),
model_registry: state.model_registry.clone(),
model_config_cache: state.model_config_cache.clone(),
},
);
// Build the SSE envelope; the id and created timestamp are shared across all chunks
let stream_id = format!("chatcmpl-{}", Uuid::new_v4());
let stream_created = chrono::Utc::now().timestamp() as u64;
let stream_id_sse = stream_id.clone();
// Build stream that yields events wrapped in Result
let stream = async_stream::stream! {
let mut aggregator = Box::pin(aggregating_stream);
let mut first_chunk = true;
while let Some(chunk_result) = aggregator.next().await {
match chunk_result {
Ok(chunk) => {
let role = if first_chunk {
first_chunk = false;
Some("assistant".to_string())
} else {
None
};
let response = ChatCompletionStreamResponse {
id: stream_id_sse.clone(),
object: "chat.completion.chunk".to_string(),
created: stream_created,
model: chunk.model.clone(),
choices: vec![ChatStreamChoice {
index: 0,
delta: ChatStreamDelta {
role,
content: Some(chunk.content),
reasoning_content: chunk.reasoning_content,
tool_calls: chunk.tool_calls,
},
finish_reason: chunk.finish_reason,
}],
usage: chunk.usage.as_ref().map(|u| crate::models::Usage {
prompt_tokens: u.prompt_tokens,
completion_tokens: u.completion_tokens,
total_tokens: u.total_tokens,
reasoning_tokens: if u.reasoning_tokens > 0 { Some(u.reasoning_tokens) } else { None },
cache_read_tokens: if u.cache_read_tokens > 0 { Some(u.cache_read_tokens) } else { None },
cache_write_tokens: if u.cache_write_tokens > 0 { Some(u.cache_write_tokens) } else { None },
}),
};
// Use axum's Event directly, wrap in Ok
match Event::default().json_data(response) {
Ok(event) => yield Ok::<_, crate::errors::AppError>(event),
Err(e) => {
warn!("Failed to serialize SSE: {}", e);
}
}
}
Err(e) => {
warn!("Stream error: {}", e);
}
}
}
// Yield [DONE] at the end
yield Ok::<_, crate::errors::AppError>(Event::default().data("[DONE]"));
};
Ok(Sse::new(stream).into_response())
}
Err(e) => {
// Record provider failure
state.rate_limit_manager.record_provider_failure(&provider_name).await;
// Log failed request
let duration = start_time.elapsed();
warn!("Streaming request failed after {:?}: {}", duration, e);
Err(e)
}
}
} else {
// Handle non-streaming response
// Allow provider-specific routing: for OpenAI, some models prefer the
// Responses API (/v1/responses). Use the model registry heuristic to
// choose chat_responses vs chat_completion automatically.
let use_responses = provider.name() == "openai"
&& crate::utils::registry::model_prefers_responses(&state.model_registry, &unified_request.model);
let result = if use_responses {
provider.chat_responses(unified_request).await
} else {
provider.chat_completion(unified_request).await
};
match result {
Ok(response) => {
// Record provider success
state.rate_limit_manager.record_provider_success(&provider_name).await;
let duration = start_time.elapsed();
let cost = get_model_cost(
&response.model,
response.prompt_tokens,
response.completion_tokens,
response.cache_read_tokens,
response.cache_write_tokens,
&provider,
&state,
)
.await;
// Log request to database
state.request_logger.log_request(crate::logging::RequestLog {
timestamp: chrono::Utc::now(),
client_id: client_id.clone(),
provider: provider_name.clone(),
model: response.model.clone(),
prompt_tokens: response.prompt_tokens,
completion_tokens: response.completion_tokens,
reasoning_tokens: response.reasoning_tokens,
total_tokens: response.total_tokens,
cache_read_tokens: response.cache_read_tokens,
cache_write_tokens: response.cache_write_tokens,
cost,
has_images,
status: "success".to_string(),
error_message: None,
duration_ms: duration.as_millis() as u64,
});
// Convert ProviderResponse to ChatCompletionResponse
let finish_reason = if response.tool_calls.is_some() {
"tool_calls".to_string()
} else {
"stop".to_string()
};
let chat_response = ChatCompletionResponse {
id: format!("chatcmpl-{}", Uuid::new_v4()),
object: "chat.completion".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model: response.model,
choices: vec![ChatChoice {
index: 0,
message: ChatMessage {
role: "assistant".to_string(),
content: crate::models::MessageContent::Text {
content: response.content,
},
reasoning_content: response.reasoning_content,
tool_calls: response.tool_calls,
name: None,
tool_call_id: None,
},
finish_reason: Some(finish_reason),
}],
usage: Some(Usage {
prompt_tokens: response.prompt_tokens,
completion_tokens: response.completion_tokens,
total_tokens: response.total_tokens,
reasoning_tokens: if response.reasoning_tokens > 0 { Some(response.reasoning_tokens) } else { None },
cache_read_tokens: if response.cache_read_tokens > 0 { Some(response.cache_read_tokens) } else { None },
cache_write_tokens: if response.cache_write_tokens > 0 { Some(response.cache_write_tokens) } else { None },
}),
};
// Log successful request with proxy overhead breakdown
let upstream_ms = duration.as_millis() as u64 - proxy_overhead.as_millis() as u64;
info!(
"Request completed in {:?} (proxy: {}ms, upstream: {}ms)",
duration,
proxy_overhead.as_millis(),
upstream_ms
);
Ok(Json(chat_response).into_response())
}
Err(e) => {
// Record provider failure
state.rate_limit_manager.record_provider_failure(&provider_name).await;
// Log failed request to database
let duration = start_time.elapsed();
state.request_logger.log_request(crate::logging::RequestLog {
timestamp: chrono::Utc::now(),
client_id: client_id.clone(),
provider: provider_name.clone(),
model: model.clone(),
prompt_tokens: 0,
completion_tokens: 0,
reasoning_tokens: 0,
total_tokens: 0,
cache_read_tokens: 0,
cache_write_tokens: 0,
cost: 0.0,
has_images: false,
status: "error".to_string(),
error_message: Some(e.to_string()),
duration_ms: duration.as_millis() as u64,
});
warn!("Request failed after {:?}: {}", duration, e);
Err(e)
}
}
}
}


@@ -1,133 +0,0 @@
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use tracing::warn;
use crate::{
client::ClientManager, config::AppConfig, database::DbPool, logging::RequestLogger,
models::registry::ModelRegistry, providers::ProviderManager, rate_limiting::RateLimitManager,
};
/// Cached model configuration entry
#[derive(Debug, Clone)]
pub struct CachedModelConfig {
pub enabled: bool,
pub mapping: Option<String>,
pub prompt_cost_per_m: Option<f64>,
pub completion_cost_per_m: Option<f64>,
pub cache_read_cost_per_m: Option<f64>,
pub cache_write_cost_per_m: Option<f64>,
}
/// In-memory cache for model_configs table.
/// Refreshes periodically to avoid hitting SQLite on every request.
#[derive(Clone)]
pub struct ModelConfigCache {
cache: Arc<RwLock<HashMap<String, CachedModelConfig>>>,
db_pool: DbPool,
}
impl ModelConfigCache {
pub fn new(db_pool: DbPool) -> Self {
Self {
cache: Arc::new(RwLock::new(HashMap::new())),
db_pool,
}
}
/// Load all model configs from the database into cache
pub async fn refresh(&self) {
match sqlx::query_as::<_, (String, bool, Option<String>, Option<f64>, Option<f64>, Option<f64>, Option<f64>)>(
"SELECT id, enabled, mapping, prompt_cost_per_m, completion_cost_per_m, cache_read_cost_per_m, cache_write_cost_per_m FROM model_configs",
)
.fetch_all(&self.db_pool)
.await
{
Ok(rows) => {
let mut map = HashMap::with_capacity(rows.len());
for (id, enabled, mapping, prompt_cost, completion_cost, cache_read_cost, cache_write_cost) in rows {
map.insert(
id,
CachedModelConfig {
enabled,
mapping,
prompt_cost_per_m: prompt_cost,
completion_cost_per_m: completion_cost,
cache_read_cost_per_m: cache_read_cost,
cache_write_cost_per_m: cache_write_cost,
},
);
}
*self.cache.write().await = map;
}
Err(e) => {
warn!("Failed to refresh model config cache: {}", e);
}
}
}
/// Get a cached model config. Returns None if not in cache (model is unconfigured).
pub async fn get(&self, model: &str) -> Option<CachedModelConfig> {
self.cache.read().await.get(model).cloned()
}
/// Invalidate cache — call this after dashboard writes to model_configs
pub async fn invalidate(&self) {
self.refresh().await;
}
/// Start a background task that refreshes the cache every `interval` seconds
pub fn start_refresh_task(self, interval_secs: u64) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs));
loop {
interval.tick().await;
self.refresh().await;
}
});
}
}
/// Shared application state
#[derive(Clone)]
pub struct AppState {
pub config: Arc<AppConfig>,
pub provider_manager: ProviderManager,
pub db_pool: DbPool,
pub rate_limit_manager: Arc<RateLimitManager>,
pub client_manager: Arc<ClientManager>,
pub request_logger: Arc<RequestLogger>,
pub model_registry: Arc<ModelRegistry>,
pub model_config_cache: ModelConfigCache,
pub dashboard_tx: broadcast::Sender<serde_json::Value>,
pub auth_tokens: Vec<String>,
}
impl AppState {
pub fn new(
config: Arc<AppConfig>,
provider_manager: ProviderManager,
db_pool: DbPool,
rate_limit_manager: RateLimitManager,
model_registry: ModelRegistry,
auth_tokens: Vec<String>,
) -> Self {
let client_manager = Arc::new(ClientManager::new(db_pool.clone()));
let (dashboard_tx, _) = broadcast::channel(100);
let request_logger = Arc::new(RequestLogger::new(db_pool.clone(), dashboard_tx.clone()));
let model_config_cache = ModelConfigCache::new(db_pool.clone());
Self {
config,
provider_manager,
db_pool,
rate_limit_manager: Arc::new(rate_limit_manager),
client_manager,
request_logger,
model_registry: Arc::new(model_registry),
model_config_cache,
dashboard_tx,
auth_tokens,
}
}
}


@@ -1,171 +0,0 @@
use aes_gcm::{
aead::{Aead, AeadCore, KeyInit, OsRng},
Aes256Gcm, Key, Nonce,
};
use anyhow::{anyhow, Context, Result};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use std::env;
use std::sync::OnceLock;
static RAW_KEY: OnceLock<[u8; 32]> = OnceLock::new();
/// Initialize the encryption key from a hex or base64 encoded string.
/// Must be called before any encryption/decryption operations.
/// Returns error if the key is invalid or already initialized with a different key.
pub fn init_with_key(key_str: &str) -> Result<()> {
let key_bytes = hex::decode(key_str)
.or_else(|_| BASE64.decode(key_str))
.context("Encryption key must be hex or base64 encoded")?;
if key_bytes.len() != 32 {
anyhow::bail!(
"Encryption key must be 32 bytes (256 bits), got {} bytes",
key_bytes.len()
);
}
let key_array: [u8; 32] = key_bytes.try_into().unwrap(); // safe due to length check
// Check if already initialized with same key
if let Some(existing) = RAW_KEY.get() {
if existing == &key_array {
// Same key already initialized, okay
return Ok(());
} else {
anyhow::bail!("Encryption key already initialized with a different key");
}
}
// Store raw key bytes
RAW_KEY
.set(key_array)
.map_err(|_| anyhow::anyhow!("Encryption key already initialized"))?;
Ok(())
}
/// Initialize the encryption key from the environment variable `LLM_PROXY__ENCRYPTION_KEY`.
/// Must be called before any encryption/decryption operations.
/// Panics if the environment variable is missing or invalid.
pub fn init_from_env() -> Result<()> {
let key_str =
env::var("LLM_PROXY__ENCRYPTION_KEY").context("LLM_PROXY__ENCRYPTION_KEY environment variable not set")?;
init_with_key(&key_str)
}
/// Get the encryption key bytes, panicking if not initialized.
fn get_key() -> &'static [u8; 32] {
RAW_KEY
.get()
.expect("Encryption key not initialized. Call crypto::init_with_key() or crypto::init_from_env() first.")
}
/// Encrypt a plaintext string and return a base64-encoded ciphertext (nonce || ciphertext || tag).
pub fn encrypt(plaintext: &str) -> Result<String> {
let key = Key::<Aes256Gcm>::from_slice(get_key());
let cipher = Aes256Gcm::new(key);
let nonce = Aes256Gcm::generate_nonce(&mut OsRng); // 12 bytes
let ciphertext = cipher
.encrypt(&nonce, plaintext.as_bytes())
.map_err(|e| anyhow!("Encryption failed: {}", e))?;
// Combine nonce and ciphertext (ciphertext already includes tag)
let mut combined = Vec::with_capacity(nonce.len() + ciphertext.len());
combined.extend_from_slice(&nonce);
combined.extend_from_slice(&ciphertext);
Ok(BASE64.encode(combined))
}
/// Decrypt a base64-encoded ciphertext (nonce || ciphertext || tag) to a plaintext string.
pub fn decrypt(ciphertext_b64: &str) -> Result<String> {
let key = Key::<Aes256Gcm>::from_slice(get_key());
let cipher = Aes256Gcm::new(key);
let combined = BASE64.decode(ciphertext_b64).context("Invalid base64 ciphertext")?;
if combined.len() < 12 {
anyhow::bail!("Ciphertext too short");
}
let (nonce_bytes, ciphertext_and_tag) = combined.split_at(12);
let nonce = Nonce::from_slice(nonce_bytes);
let plaintext_bytes = cipher
.decrypt(nonce, ciphertext_and_tag)
.map_err(|e| anyhow!("Decryption failed (invalid key or corrupted ciphertext): {}", e))?;
String::from_utf8(plaintext_bytes).context("Decrypted bytes are not valid UTF-8")
}
#[cfg(test)]
mod tests {
use super::*;
const TEST_KEY_HEX: &str = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f";
#[test]
fn test_encrypt_decrypt() {
init_with_key(TEST_KEY_HEX).unwrap();
let plaintext = "super secret api key";
let ciphertext = encrypt(plaintext).unwrap();
assert_ne!(ciphertext, plaintext);
let decrypted = decrypt(&ciphertext).unwrap();
assert_eq!(decrypted, plaintext);
}
#[test]
fn test_different_inputs_produce_different_ciphertexts() {
init_with_key(TEST_KEY_HEX).unwrap();
let plaintext = "same";
let cipher1 = encrypt(plaintext).unwrap();
let cipher2 = encrypt(plaintext).unwrap();
assert_ne!(cipher1, cipher2, "Nonce should make ciphertexts differ");
assert_eq!(decrypt(&cipher1).unwrap(), plaintext);
assert_eq!(decrypt(&cipher2).unwrap(), plaintext);
}
#[test]
fn test_invalid_key_length() {
let result = init_with_key("tooshort");
assert!(result.is_err());
}
#[test]
fn test_init_from_env() {
unsafe { std::env::set_var("LLM_PROXY__ENCRYPTION_KEY", TEST_KEY_HEX) };
let result = init_from_env();
assert!(result.is_ok());
// Ensure encryption works
let ciphertext = encrypt("test").unwrap();
let decrypted = decrypt(&ciphertext).unwrap();
assert_eq!(decrypted, "test");
}
#[test]
fn test_missing_env_key() {
unsafe { std::env::remove_var("LLM_PROXY__ENCRYPTION_KEY") };
let result = init_from_env();
assert!(result.is_err());
}
#[test]
fn test_key_hex_and_base64() {
// Hex key works
init_with_key(TEST_KEY_HEX).unwrap();
// Base64 key (same bytes encoded as base64)
let base64_key = BASE64.encode(hex::decode(TEST_KEY_HEX).unwrap());
// Re-initialization with same key (different encoding) is allowed
let result = init_with_key(&base64_key);
assert!(result.is_ok());
// Encryption should still work
let ciphertext = encrypt("test").unwrap();
let decrypted = decrypt(&ciphertext).unwrap();
assert_eq!(decrypted, "test");
}
#[test]
#[ignore] // conflicts with global state from other tests
fn test_already_initialized() {
init_with_key(TEST_KEY_HEX).unwrap();
let result = init_with_key(TEST_KEY_HEX);
assert!(result.is_ok()); // same key allowed
}
#[test]
#[ignore] // conflicts with global state from other tests
fn test_already_initialized_different_key() {
init_with_key(TEST_KEY_HEX).unwrap();
let different_key = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e20";
let result = init_with_key(different_key);
assert!(result.is_err());
}
}
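The wire format above is base64( nonce(12 bytes) || ciphertext || tag(16 bytes) ). As a sketch of just the framing, independent of any crypto, the `split_at(12)` step in `decrypt` can be illustrated with a hypothetical `split_payload` helper:

```rust
// Sketch of the payload framing used above: base64( nonce(12) || ciphertext || tag(16) ).
// split_payload is illustrative only; real decryption uses the aes-gcm crate as in the module.
fn split_payload(combined: &[u8]) -> Option<(&[u8], &[u8])> {
    if combined.len() < 12 {
        return None; // too short to even contain a nonce
    }
    Some(combined.split_at(12))
}

fn main() {
    // 12-byte nonce + 5 ciphertext bytes + 16-byte GCM tag
    let payload = [0u8; 12 + 5 + 16];
    let (nonce, ct_and_tag) = split_payload(&payload).unwrap();
    assert_eq!(nonce.len(), 12);
    assert_eq!(ct_and_tag.len(), 21);
    assert!(split_payload(&[0u8; 11]).is_none());
}
```

This is also why `decrypt` rejects inputs shorter than 12 bytes before attempting AEAD decryption.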


@@ -1,4 +0,0 @@
pub mod crypto;
pub mod registry;
pub mod streaming;
pub mod tokens;


@@ -1,49 +0,0 @@
use crate::models::registry::ModelRegistry;
use anyhow::Result;
use tracing::info;
const MODELS_DEV_URL: &str = "https://models.dev/api.json";
pub async fn fetch_registry() -> Result<ModelRegistry> {
info!("Fetching model registry from {}", MODELS_DEV_URL);
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.build()?;
let response = client.get(MODELS_DEV_URL).send().await?;
if !response.status().is_success() {
return Err(anyhow::anyhow!("Failed to fetch registry: HTTP {}", response.status()));
}
let registry: ModelRegistry = response.json().await?;
info!("Successfully loaded model registry");
Ok(registry)
}
/// Heuristic: decide whether a model should be routed to OpenAI Responses API
/// instead of the legacy chat/completions endpoint.
///
/// Currently this uses simple patterns (codex, gpt-5 series) and also checks
/// the loaded registry metadata name for the substring "codex" as a hint.
pub fn model_prefers_responses(registry: &ModelRegistry, model: &str) -> bool {
let model_lc = model.to_lowercase();
if model_lc.contains("codex") {
return true;
}
if model_lc.starts_with("gpt-5") || model_lc.contains("gpt-5.") {
return true;
}
if let Some(meta) = registry.find_model(model) {
if meta.name.to_lowercase().contains("codex") {
return true;
}
}
false
}
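Stripped of the registry lookup, the heuristic reduces to plain string checks; a self-contained sketch (the hypothetical `prefers_responses` mirrors only the pattern-matching half above):

```rust
// Standalone version of the routing heuristic above, minus the registry metadata lookup.
fn prefers_responses(model: &str) -> bool {
    let m = model.to_lowercase();
    m.contains("codex") || m.starts_with("gpt-5") || m.contains("gpt-5.")
}

fn main() {
    assert!(prefers_responses("gpt-5-codex"));
    assert!(prefers_responses("GPT-5-mini"));
    assert!(prefers_responses("openai/gpt-5.1")); // caught by the contains("gpt-5.") check
    assert!(!prefers_responses("gpt-4o"));
}
```

The `contains("gpt-5.")` check is not redundant with `starts_with("gpt-5")`: it also catches provider-prefixed names such as `openai/gpt-5.1`.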


@@ -1,340 +0,0 @@
use crate::errors::AppError;
use crate::logging::{RequestLog, RequestLogger};
use crate::models::ToolCall;
use crate::providers::{Provider, ProviderStreamChunk, StreamUsage};
use crate::state::ModelConfigCache;
use crate::utils::tokens::estimate_completion_tokens;
use futures::stream::Stream;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
/// Configuration for creating an AggregatingStream.
pub struct StreamConfig {
pub client_id: String,
pub provider: Arc<dyn Provider>,
pub model: String,
pub prompt_tokens: u32,
pub has_images: bool,
pub logger: Arc<RequestLogger>,
pub model_registry: Arc<crate::models::registry::ModelRegistry>,
pub model_config_cache: ModelConfigCache,
}
pub struct AggregatingStream<S> {
inner: S,
client_id: String,
provider: Arc<dyn Provider>,
model: String,
prompt_tokens: u32,
has_images: bool,
accumulated_content: String,
accumulated_reasoning: String,
accumulated_tool_calls: Vec<ToolCall>,
/// Real usage data from the provider's final stream chunk (when available).
real_usage: Option<StreamUsage>,
logger: Arc<RequestLogger>,
model_registry: Arc<crate::models::registry::ModelRegistry>,
model_config_cache: ModelConfigCache,
start_time: std::time::Instant,
has_logged: bool,
}
impl<S> AggregatingStream<S>
where
S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin,
{
pub fn new(inner: S, config: StreamConfig) -> Self {
Self {
inner,
client_id: config.client_id,
provider: config.provider,
model: config.model,
prompt_tokens: config.prompt_tokens,
has_images: config.has_images,
accumulated_content: String::new(),
accumulated_reasoning: String::new(),
accumulated_tool_calls: Vec::new(),
real_usage: None,
logger: config.logger,
model_registry: config.model_registry,
model_config_cache: config.model_config_cache,
start_time: std::time::Instant::now(),
has_logged: false,
}
}
fn finalize(&mut self) {
if self.has_logged {
return;
}
self.has_logged = true;
let duration = self.start_time.elapsed();
let client_id = self.client_id.clone();
let provider_name = self.provider.name().to_string();
let model = self.model.clone();
let logger = self.logger.clone();
let provider = self.provider.clone();
let estimated_prompt_tokens = self.prompt_tokens;
let has_images = self.has_images;
let registry = self.model_registry.clone();
let config_cache = self.model_config_cache.clone();
let real_usage = self.real_usage.take();
// Estimate completion tokens (including reasoning if present)
let estimated_content_tokens = estimate_completion_tokens(&self.accumulated_content, &model);
let estimated_reasoning_tokens = if !self.accumulated_reasoning.is_empty() {
estimate_completion_tokens(&self.accumulated_reasoning, &model)
} else {
0
};
let estimated_completion = estimated_content_tokens + estimated_reasoning_tokens;
// Spawn a background task to log the completion
tokio::spawn(async move {
// Use real usage from the provider when available, otherwise fall back to estimates
let (prompt_tokens, completion_tokens, reasoning_tokens, total_tokens, cache_read_tokens, cache_write_tokens) =
if let Some(usage) = &real_usage {
(
usage.prompt_tokens,
usage.completion_tokens,
usage.reasoning_tokens,
usage.total_tokens,
usage.cache_read_tokens,
usage.cache_write_tokens,
)
} else {
(
estimated_prompt_tokens,
estimated_completion,
estimated_reasoning_tokens,
estimated_prompt_tokens + estimated_completion,
0u32,
0u32,
)
};
// Check in-memory cache for cost overrides (no SQLite hit)
let cost = if let Some(cached) = config_cache.get(&model).await {
if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
// Manual overrides logic: if cache rates are provided, use cache-aware formula.
let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
let mut total = (non_cached_prompt as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0);
if let Some(cr) = cached.cache_read_cost_per_m {
total += cache_read_tokens as f64 * cr / 1_000_000.0;
} else {
// Charge cached tokens at full input rate if no specific rate provided
total += cache_read_tokens as f64 * p / 1_000_000.0;
}
if let Some(cw) = cached.cache_write_cost_per_m {
total += cache_write_tokens as f64 * cw / 1_000_000.0;
}
total
} else {
provider.calculate_cost(
&model,
prompt_tokens,
completion_tokens,
cache_read_tokens,
cache_write_tokens,
&registry,
)
}
} else {
provider.calculate_cost(
&model,
prompt_tokens,
completion_tokens,
cache_read_tokens,
cache_write_tokens,
&registry,
)
};
// Log to database
logger.log_request(RequestLog {
timestamp: chrono::Utc::now(),
client_id: client_id.clone(),
provider: provider_name,
model,
prompt_tokens,
completion_tokens,
reasoning_tokens,
total_tokens,
cache_read_tokens,
cache_write_tokens,
cost,
has_images,
status: "success".to_string(),
error_message: None,
duration_ms: duration.as_millis() as u64,
});
});
}
}
impl<S> Stream for AggregatingStream<S>
where
S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin,
{
type Item = Result<ProviderStreamChunk, AppError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let result = Pin::new(&mut self.inner).poll_next(cx);
match &result {
Poll::Ready(Some(Ok(chunk))) => {
self.accumulated_content.push_str(&chunk.content);
if let Some(reasoning) = &chunk.reasoning_content {
self.accumulated_reasoning.push_str(reasoning);
}
// Capture real usage from the provider when present (typically on the final chunk)
if let Some(usage) = &chunk.usage {
self.real_usage = Some(usage.clone());
}
// Accumulate tool call deltas into complete tool calls
if let Some(deltas) = &chunk.tool_calls {
for delta in deltas {
let idx = delta.index as usize;
// Grow the accumulated_tool_calls vec if needed
while self.accumulated_tool_calls.len() <= idx {
self.accumulated_tool_calls.push(ToolCall {
id: String::new(),
call_type: "function".to_string(),
function: crate::models::FunctionCall {
name: String::new(),
arguments: String::new(),
},
});
}
let tc = &mut self.accumulated_tool_calls[idx];
if let Some(id) = &delta.id {
tc.id.clone_from(id);
}
if let Some(ct) = &delta.call_type {
tc.call_type.clone_from(ct);
}
if let Some(f) = &delta.function {
if let Some(name) = &f.name {
tc.function.name.push_str(name);
}
if let Some(args) = &f.arguments {
tc.function.arguments.push_str(args);
}
}
}
}
}
Poll::Ready(Some(Err(_))) => {
// On a stream error, log whatever content was accumulated so far
// (only if we actually received something), so partial responses
// still appear in usage tracking.
if !self.accumulated_content.is_empty() {
self.finalize();
}
}
Poll::Ready(None) => {
self.finalize();
}
Poll::Pending => {}
}
result
}
}
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
use futures::stream::{self, StreamExt};
// Simple mock provider for testing
struct MockProvider;
#[async_trait::async_trait]
impl Provider for MockProvider {
fn name(&self) -> &str {
"mock"
}
fn supports_model(&self, _model: &str) -> bool {
true
}
fn supports_multimodal(&self) -> bool {
false
}
async fn chat_completion(
&self,
_req: crate::models::UnifiedRequest,
) -> Result<crate::providers::ProviderResponse, AppError> {
unimplemented!()
}
async fn chat_completion_stream(
&self,
_req: crate::models::UnifiedRequest,
) -> Result<futures::stream::BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
unimplemented!()
}
fn estimate_tokens(&self, _req: &crate::models::UnifiedRequest) -> Result<u32> {
Ok(10)
}
fn calculate_cost(&self, _model: &str, _p: u32, _c: u32, _cr: u32, _cw: u32, _r: &crate::models::registry::ModelRegistry) -> f64 {
0.05
}
}
#[tokio::test]
async fn test_aggregating_stream() {
let chunks = vec![
Ok(ProviderStreamChunk {
content: "Hello".to_string(),
reasoning_content: None,
finish_reason: None,
tool_calls: None,
model: "test".to_string(),
usage: None,
}),
Ok(ProviderStreamChunk {
content: " World".to_string(),
reasoning_content: None,
finish_reason: Some("stop".to_string()),
tool_calls: None,
model: "test".to_string(),
usage: None,
}),
];
let inner_stream = stream::iter(chunks);
let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
let (dashboard_tx, _) = tokio::sync::broadcast::channel(16);
let logger = Arc::new(RequestLogger::new(pool.clone(), dashboard_tx));
let registry = Arc::new(crate::models::registry::ModelRegistry {
providers: std::collections::HashMap::new(),
});
let mut agg_stream = AggregatingStream::new(
inner_stream,
StreamConfig {
client_id: "client_1".to_string(),
provider: Arc::new(MockProvider),
model: "test".to_string(),
prompt_tokens: 10,
has_images: false,
logger,
model_registry: registry,
model_config_cache: ModelConfigCache::new(pool.clone()),
},
);
while let Some(item) = agg_stream.next().await {
assert!(item.is_ok());
}
assert_eq!(agg_stream.accumulated_content, "Hello World");
assert!(agg_stream.has_logged);
}
}
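The cache-aware cost override in `finalize` is worth a worked example. This sketch reproduces the same arithmetic with a hypothetical `override_cost` function (all rates are per million tokens, as in `CachedModelConfig`):

```rust
// Worked sketch of the cache-aware override: non-cached prompt tokens bill at the
// input rate, cache reads at their own rate (or the full input rate when none is
// configured), cache writes only when a write rate exists.
fn override_cost(
    prompt: u32, completion: u32, cache_read: u32, cache_write: u32,
    p: f64, c: f64, cr: Option<f64>, cw: Option<f64>,
) -> f64 {
    let non_cached = prompt.saturating_sub(cache_read) as f64;
    let mut total = non_cached * p / 1e6 + completion as f64 * c / 1e6;
    total += cache_read as f64 * cr.unwrap_or(p) / 1e6;
    if let Some(cw_rate) = cw {
        total += cache_write as f64 * cw_rate / 1e6;
    }
    total
}

fn main() {
    // 10_000 prompt tokens (4_000 of them cache reads), 500 completion tokens;
    // $3/M input, $15/M output, $0.30/M cache reads:
    // 6_000*3e-6 + 500*15e-6 + 4_000*0.3e-6 = 0.018 + 0.0075 + 0.0012 = 0.0267
    let cost = override_cost(10_000, 500, 4_000, 0, 3.0, 15.0, Some(0.30), None);
    assert!((cost - 0.0267).abs() < 1e-9);
}
```

Note the `saturating_sub`: some providers report cache reads as a subset of prompt tokens, so the subtraction is clamped at zero rather than allowed to underflow.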


@@ -1,69 +0,0 @@
use crate::models::UnifiedRequest;
use tiktoken_rs::get_bpe_from_model;
/// Count tokens for a given model and text
pub fn count_tokens(model: &str, text: &str) -> u32 {
// If we can't get the bpe for the model, fallback to a safe default (cl100k_base for GPT-4/o1)
let bpe = get_bpe_from_model(model)
.unwrap_or_else(|_| tiktoken_rs::cl100k_base().expect("Failed to get cl100k_base encoding"));
bpe.encode_with_special_tokens(text).len() as u32
}
/// Estimate tokens for a unified request.
/// Small inputs (< 1KB) are tokenized exactly with tiktoken; larger inputs fall
/// back to a chars/4 heuristic so the async runtime is never blocked.
pub fn estimate_request_tokens(model: &str, request: &UnifiedRequest) -> u32 {
let mut total_text = String::new();
let msg_count = request.messages.len();
// Base tokens per message for OpenAI (approximate)
let tokens_per_message: u32 = 3;
for msg in &request.messages {
for part in &msg.content {
match part {
crate::models::ContentPart::Text { text } => {
total_text.push_str(text);
total_text.push('\n');
}
crate::models::ContentPart::Image { .. } => {
// Vision models usually have a fixed cost or calculation based on size
}
}
}
}
// Small inputs (< 1KB): exact tokenization is cheap enough to run inline
if total_text.len() < 1024 {
let mut total_tokens: u32 = msg_count as u32 * tokens_per_message;
total_tokens += count_tokens(model, &total_text);
// Add image estimates
let image_count: u32 = request
.messages
.iter()
.flat_map(|m| m.content.iter())
.filter(|p| matches!(p, crate::models::ContentPart::Image { .. }))
.count() as u32;
total_tokens += image_count * 1000;
total_tokens += 3; // assistant reply header
return total_tokens;
}
// For large inputs, use the fast heuristic (chars / 4) to avoid blocking
// the async runtime. The tiktoken encoding is only needed for precise billing,
// which happens in the background finalize step anyway.
let estimated_text_tokens = (total_text.len() as u32) / 4;
let image_count: u32 = request
.messages
.iter()
.flat_map(|m| m.content.iter())
.filter(|p| matches!(p, crate::models::ContentPart::Image { .. }))
.count() as u32;
(msg_count as u32 * tokens_per_message) + estimated_text_tokens + (image_count * 1000) + 3
}
/// Estimate tokens for completion text
pub fn estimate_completion_tokens(text: &str, model: &str) -> u32 {
count_tokens(model, text)
}
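The large-input fallback above is pure arithmetic: ~4 characters per token, a flat 3-token header per message, 1000 tokens per image, and 3 tokens for the assistant reply header. A minimal sketch (the name `rough_estimate` is illustrative, not part of the module):

```rust
// Mirrors the large-input branch of estimate_request_tokens above.
fn rough_estimate(text_len: usize, msg_count: u32, image_count: u32) -> u32 {
    (msg_count * 3) + (text_len as u32 / 4) + (image_count * 1000) + 3
}

fn main() {
    // 4000 chars across 2 messages, no images: 2*3 + 4000/4 + 0 + 3 = 1009
    assert_eq!(rough_estimate(4000, 2, 0), 1009);
    // Adding one image contributes a flat 1000 tokens
    assert_eq!(rough_estimate(4000, 2, 1), 2009);
}
```

The deliberate imprecision here is fine because, per the comment in the module, exact counts are only needed for billing, which happens in the background finalize step using real provider usage when available.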


@@ -148,22 +148,54 @@ body {
     width: 80px;
     height: 80px;
     margin: 0 auto 1.25rem;
-    border-radius: 16px;
-    background: var(--bg2);
     display: flex;
     align-items: center;
     justify-content: center;
-    color: var(--orange);
-    font-size: 2rem;
+    background: rgba(254, 128, 25, 0.15);
+    color: var(--primary);
+    border-radius: 12px;
+    font-size: 2.5rem;
 }
+/* GopherGate Logo Icon */
+.logo-icon-container {
+    width: 60px;
+    height: 60px;
+    background: var(--blue-light);
+    border-radius: 12px;
+    display: flex;
+    align-items: center;
+    justify-content: center;
+    box-shadow: var(--shadow);
+    border: 2px solid var(--fg1);
+    margin: 0 auto;
+}
+.logo-icon-container.small {
+    width: 32px;
+    height: 32px;
+    border-radius: 6px;
+    margin: 0;
+}
+.logo-icon-text {
+    font-family: 'JetBrains Mono', monospace;
+    font-weight: 700;
+    color: var(--bg0);
+    font-size: 1.8rem;
+}
+.logo-icon-container.small .logo-icon-text {
+    font-size: 1rem;
+}
 .login-header h1 {
-    font-size: 1.75rem;
+    font-size: 2rem;
     font-weight: 800;
-    color: var(--fg0);
+    color: var(--primary-light);
     margin-bottom: 0.5rem;
     letter-spacing: -0.025em;
+    text-transform: uppercase;
 }
 .login-subtitle {
@@ -297,6 +329,25 @@ body {
     font-size: 1.125rem;
 }
+/* Badges */
+.badge {
+    display: inline-block;
+    padding: 0.25rem 0.5rem;
+    font-size: 0.75rem;
+    font-weight: 600;
+    line-height: 1;
+    text-align: center;
+    white-space: nowrap;
+    vertical-align: baseline;
+    border-radius: 4px;
+}
+.badge-success { background-color: rgba(152, 151, 26, 0.15); color: var(--green-light); border: 1px solid var(--green); }
+.badge-info { background-color: rgba(69, 133, 136, 0.15); color: var(--blue-light); border: 1px solid var(--blue); }
+.badge-warning { background-color: rgba(215, 153, 33, 0.15); color: var(--yellow-light); border: 1px solid var(--yellow); }
+.badge-danger { background-color: rgba(204, 36, 29, 0.15); color: var(--red-light); border: 1px solid var(--red); }
+.badge-client { background-color: var(--bg2); color: var(--fg1); border: 1px solid var(--bg3); padding: 2px 6px; font-size: 0.7rem; text-transform: uppercase; }
 /* Responsive Login */
 @media (max-width: 480px) {
     .login-card {
@@ -375,11 +426,15 @@ body {
 }
 .sidebar.collapsed .logo {
+    display: flex;
+}
+.sidebar.collapsed .logo span {
     display: none;
 }
 .sidebar.collapsed .sidebar-toggle {
-    opacity: 1;
+    margin-left: 0;
 }
 .logo {
@@ -394,6 +449,7 @@ body {
     white-space: nowrap;
 }
 .sidebar-logo {
     width: 32px;
     height: 32px;
@@ -588,17 +644,48 @@ body {
 /* Main Content Area */
 .main-content {
-    margin-left: 260px;
+    padding-left: 260px;
     flex: 1;
     min-height: 100vh;
-    transition: all 0.3s;
+    transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
     display: flex;
     flex-direction: column;
     background-color: var(--bg-primary);
 }
-.sidebar.collapsed ~ .main-content {
-    margin-left: 80px;
+.sidebar.collapsed + .main-content {
+    padding-left: 80px;
+}
+.top-bar {
+    height: 70px;
+    background: var(--bg0);
+    border-bottom: 1px solid var(--bg2);
+    display: flex;
+    align-items: center;
+    justify-content: space-between;
+    padding: 0 var(--spacing-xl);
+    position: sticky;
+    top: 0;
+    z-index: 100;
+}
+.top-bar .page-title h2 {
+    font-size: 1.25rem;
+    font-weight: 700;
+    color: var(--fg0);
+}
+.top-bar-actions {
+    display: flex;
+    align-items: center;
+    gap: var(--spacing-lg);
+}
+.content-body {
+    padding: var(--spacing-xl);
+    flex: 1;
+    position: relative;
 }
 .top-nav {
@@ -1047,6 +1134,53 @@ body {
     gap: 0.75rem;
 }
+/* Connection Status Indicator */
+.status-indicator {
+    display: flex;
+    align-items: center;
+    gap: 0.75rem;
+    padding: 0.5rem 0.875rem;
+    background: var(--bg1);
+    border: 1px solid var(--bg3);
+    border-radius: 6px;
+    font-size: 0.8rem;
+    font-weight: 600;
+    color: var(--fg3);
+    transition: all 0.2s;
+}
+.status-dot {
+    width: 8px;
+    height: 8px;
+    border-radius: 50%;
+    background: var(--fg4);
+    position: relative;
+}
+.status-dot.connected {
+    background: var(--green-light);
+    box-shadow: 0 0 0 0 rgba(184, 187, 38, 0.4);
+    animation: status-pulse 2s infinite;
+}
+.status-dot.disconnected {
+    background: var(--red-light);
+}
+.status-dot.connecting {
+    background: var(--yellow-light);
+}
+.status-dot.error {
+    background: var(--red);
+}
+@keyframes status-pulse {
+    0% { box-shadow: 0 0 0 0 rgba(184, 187, 38, 0.4); }
+    70% { box-shadow: 0 0 0 6px rgba(184, 187, 38, 0); }
+    100% { box-shadow: 0 0 0 0 rgba(184, 187, 38, 0); }
+}
 /* WebSocket Dot Pulse */
 @keyframes ws-pulse {
     0% { box-shadow: 0 0 0 0 rgba(184, 187, 38, 0.4); }

static/favicon.ico: new binary file (1002 B), not shown.

@@ -3,49 +3,38 @@
 <head>
   <meta charset="UTF-8">
   <meta name="viewport" content="width=device-width, initial-scale=1.0">
-  <title>LLM Proxy Gateway - Admin Dashboard</title>
+  <title>GopherGate - Admin Dashboard</title>
   <link rel="stylesheet" href="/css/dashboard.css?v=11">
   <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
-  <link rel="icon" href="img/logo-icon.png" type="image/png" sizes="any">
-  <link rel="apple-touch-icon" href="img/logo-icon.png">
-  <link href="https://fonts.googleapis.com/css2?family=Fira+Code:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;700&display=swap" rel="stylesheet">
-  <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
-  <script src="https://cdn.jsdelivr.net/npm/luxon@3.4.4/build/global/luxon.min.js"></script>
+  <link rel="preconnect" href="https://fonts.googleapis.com">
+  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
+  <link href="https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;700&family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
 </head>
-<body>
-  <!-- Login Screen -->
-  <div id="login-screen" class="login-container">
+<body class="gruvbox-dark">
+  <!-- Auth Page -->
+  <div id="auth-page" class="login-container">
     <div class="login-card">
       <div class="login-header">
-        <i class="fas fa-terminal login-logo-fallback"></i>
-        <h1>LLM Proxy Gateway</h1>
-        <p class="login-subtitle">Admin Dashboard</p>
+        <div class="logo-icon-container">
+          <span class="logo-icon-text">GG</span>
+        </div>
+        <h1>GopherGate</h1>
+        <p class="login-subtitle">Secure LLM Gateway & Management</p>
       </div>
-      <form id="login-form" class="login-form" onsubmit="event.preventDefault();">
-        <div class="form-group">
-          <input type="text" id="username" name="username" placeholder=" " required>
-          <label for="username">
-            <i class="fas fa-user"></i> Username
-          </label>
+      <form id="login-form">
+        <div class="form-control">
+          <label for="username">Username</label>
+          <input type="text" id="username" name="username" required autocomplete="username">
         </div>
-        <div class="form-group">
-          <input type="password" id="password" name="password" placeholder=" " required>
-          <label for="password">
-            <i class="fas fa-lock"></i> Password
-          </label>
-        </div>
-        <div class="form-group">
-          <button type="submit" class="login-btn">
-            <i class="fas fa-sign-in-alt"></i> Sign In
-          </button>
-        </div>
-        <div class="login-footer">
-          <p>Default: <code>admin</code> / <code>admin</code> (change in Settings &gt; Security)</p>
+        <div class="form-control">
+          <label for="password">Password</label>
+          <input type="password" id="password" name="password" required autocomplete="current-password">
         </div>
+        <button type="submit" id="login-btn" class="btn btn-primary btn-block">Sign In</button>
       </form>
       <div id="login-error" class="error-message" style="display: none;">
         <i class="fas fa-exclamation-circle"></i>
-        <span>Invalid credentials. Please try again.</span>
+        <span></span>
       </div>
     </div>
   </div>
@@ -56,9 +45,10 @@
 <nav class="sidebar">
   <div class="sidebar-header">
     <div class="logo">
-      <img src="img/logo-icon.png" alt="LLM Proxy" class="sidebar-logo" onerror="this.style.display='none'; this.nextElementSibling.style.display='inline-block';">
-      <i class="fas fa-shield-alt logo-fallback" style="display: none;"></i>
-      <span>LLM Proxy</span>
+      <div class="logo-icon-container small">
+        <span class="logo-icon-text">GG</span>
+      </div>
+      <span>GopherGate</span>
     </div>
     <button class="sidebar-toggle" id="sidebar-toggle">
       <i class="fas fa-bars"></i>
@@ -68,68 +58,74 @@
 <div class="sidebar-menu">
   <div class="menu-section">
     <h3 class="menu-title">MAIN</h3>
-    <a href="#overview" class="menu-item active" data-page="overview" data-tooltip="Dashboard Overview">
-      <i class="fas fa-th-large"></i>
-      <span>Overview</span>
-    </a>
-    <a href="#analytics" class="menu-item" data-page="analytics" data-tooltip="Usage Analytics">
-      <i class="fas fa-chart-line"></i>
-      <span>Analytics</span>
-    </a>
-    <a href="#costs" class="menu-item" data-page="costs" data-tooltip="Cost Tracking">
-      <i class="fas fa-dollar-sign"></i>
-      <span>Cost Management</span>
-    </a>
+    <ul class="menu-list">
+      <li class="menu-item active" data-page="overview">
+        <i class="fas fa-th-large"></i>
+        <span>Overview</span>
+      </li>
+      <li class="menu-item" data-page="analytics">
+        <i class="fas fa-chart-bar"></i>
+        <span>Analytics</span>
+      </li>
+      <li class="menu-item" data-page="costs">
+        <i class="fas fa-dollar-sign"></i>
+        <span>Costs & Billing</span>
+      </li>
+    </ul>
   </div>
   <div class="menu-section">
     <h3 class="menu-title">MANAGEMENT</h3>
-    <a href="#clients" class="menu-item" data-page="clients" data-tooltip="API Clients">
-      <i class="fas fa-users"></i>
-      <span>Client Management</span>
-    </a>
-    <a href="#providers" class="menu-item" data-page="providers" data-tooltip="Model Providers">
-      <i class="fas fa-server"></i>
-      <span>Providers</span>
-    </a>
-    <a href="#models" class="menu-item" data-page="models" data-tooltip="Manage Models">
-      <i class="fas fa-cube"></i>
-      <span>Models</span>
-    </a>
-    <a href="#monitoring" class="menu-item" data-page="monitoring" data-tooltip="Live Monitoring">
-      <i class="fas fa-heartbeat"></i>
-      <span>Real-time Monitoring</span>
-    </a>
+    <ul class="menu-list">
+      <li class="menu-item" data-page="clients">
+        <i class="fas fa-users"></i>
+        <span>Clients</span>
+      </li>
+      <li class="menu-item" data-page="providers">
+        <i class="fas fa-server"></i>
+        <span>Providers</span>
+      </li>
+      <li class="menu-item" data-page="models">
+        <i class="fas fa-brain"></i>
+        <span>Models</span>
+      </li>
+    </ul>
   </div>
   <div class="menu-section">
     <h3 class="menu-title">SYSTEM</h3>
-    <a href="#users" class="menu-item admin-only" data-page="users" data-tooltip="User Accounts">
-      <i class="fas fa-user-shield"></i>
-      <span>User Management</span>
-    </a>
-    <a href="#settings" class="menu-item admin-only" data-page="settings" data-tooltip="System Settings">
-      <i class="fas fa-cog"></i>
-      <span>Settings</span>
-    </a>
-    <a href="#logs" class="menu-item" data-page="logs" data-tooltip="System Logs">
-      <i class="fas fa-list-alt"></i>
-      <span>System Logs</span>
-    </a>
+    <ul class="menu-list">
+      <li class="menu-item" data-page="monitoring">
+        <i class="fas fa-activity"></i>
+        <span>Live Monitoring</span>
+      </li>
+      <li class="menu-item" data-page="logs">
+        <i class="fas fa-list-alt"></i>
+        <span>Logs</span>
+      </li>
+      <li class="menu-item" data-page="users">
+        <i class="fas fa-user-shield"></i>
+        <span>Admin Users</span>
+      </li>
+      <li class="menu-item" data-page="settings">
+        <i class="fas fa-cog"></i>
+        <span>Settings</span>
+      </li>
+    </ul>
   </div>
 </div>
 <div class="sidebar-footer">
   <div class="user-info">
     <div class="user-avatar">
-      <i class="fas fa-user-circle"></i>
+      <i class="fas fa-user"></i>
     </div>
     <div class="user-details">
-      <span class="user-name">Loading...</span>
-      <span class="user-role">...</span>
+      <div class="user-name" id="display-username">Admin</div>
+      <div class="user-role" id="display-role">Administrator</div>
     </div>
   </div>
-  <button class="logout-btn" id="logout-btn" title="Logout">
+  <button id="logout-btn" class="btn-icon" title="Logout">
     <i class="fas fa-sign-out-alt"></i>
   </button>
 </div>
@@ -137,43 +133,40 @@
 <!-- Main Content -->
 <main class="main-content">
-  <!-- Top Navigation -->
-  <header class="top-nav">
-    <div class="nav-left">
-      <h1 class="page-title" id="page-title">Dashboard Overview</h1>
+  <header class="top-bar">
+    <div class="page-title">
+      <h2 id="current-page-title">Overview</h2>
     </div>
-    <div class="nav-right">
-      <div class="nav-item" id="ws-status-nav" title="WebSocket Connection Status">
-        <div class="ws-dot"></div>
-        <span class="ws-text">Connecting...</span>
+    <div class="top-bar-actions">
+      <div id="connection-status" class="status-indicator">
+        <span class="status-dot"></span>
+        <span class="status-text">Disconnected</span>
       </div>
-      <div class="nav-item" title="Refresh Current Page">
-        <i class="fas fa-sync-alt" id="refresh-btn"></i>
-      </div>
-      <div class="nav-item">
-        <span id="current-time">Loading...</span>
+      <div class="theme-toggle" id="theme-toggle">
+        <i class="fas fa-moon"></i>
       </div>
     </div>
   </header>
-  <!-- Page Content -->
-  <div class="page-content" id="page-content">
-    <!-- Dynamic content container -->
-  </div>
-  <!-- Global Spinner -->
-  <div class="spinner-container">
-    <div class="spinner"></div>
+  <div id="page-content" class="content-body">
+    <!-- Content will be loaded dynamically -->
+    <div class="loader-container">
+      <div class="loader"></div>
+    </div>
   </div>
 </main>
 </div>
-<!-- Scripts (cache-busted with version query params) -->
+<!-- Scripts -->
+<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
+<script src="https://cdn.jsdelivr.net/npm/luxon@3.3.0/build/global/luxon.min.js"></script>
 <script src="/js/api.js?v=7"></script>
 <script src="/js/auth.js?v=7"></script>
-<script src="/js/dashboard.js?v=7"></script>
-<script src="/js/websocket.js?v=7"></script>
 <script src="/js/charts.js?v=7"></script>
+<script src="/js/websocket.js?v=7"></script>
+<script src="/js/dashboard.js?v=7"></script>
+<!-- Page Modules -->
 <script src="/js/pages/overview.js?v=7"></script>
 <script src="/js/pages/analytics.js?v=7"></script>
 <script src="/js/pages/costs.js?v=7"></script>
@@ -185,4 +178,4 @@
 <script src="/js/pages/logs.js?v=7"></script>
 <script src="/js/pages/users.js?v=7"></script>
 </body>
 </html>


@@ -1,4 +1,4 @@
-// Authentication Module for LLM Proxy Dashboard
+// Authentication Module for GopherGate Dashboard
 class AuthManager {
   constructor() {
@@ -58,7 +58,7 @@ class AuthManager {
   async login(username, password) {
     const errorElement = document.getElementById('login-error');
-    const loginBtn = document.querySelector('.login-btn');
+    const loginBtn = document.getElementById('login-btn');
     try {
       loginBtn.innerHTML = '<i class="fas fa-spinner fa-spin"></i> Authenticating...';
@@ -124,7 +124,7 @@ class AuthManager {
   }
   showLogin() {
-    const loginScreen = document.getElementById('login-screen');
+    const loginScreen = document.getElementById('auth-page');
     const dashboard = document.getElementById('dashboard');
     if (loginScreen) loginScreen.style.display = 'flex';
@@ -139,7 +139,7 @@ class AuthManager {
     if (errorElement) errorElement.style.display = 'none';
     // Reset button
-    const loginBtn = document.querySelector('.login-btn');
+    const loginBtn = document.getElementById('login-btn');
     if (loginBtn) {
       loginBtn.innerHTML = '<i class="fas fa-sign-in-alt"></i> Sign In';
       loginBtn.disabled = false;
@@ -147,7 +147,7 @@ class AuthManager {
   }
   showDashboard() {
-    const loginScreen = document.getElementById('login-screen');
+    const loginScreen = document.getElementById('auth-page');
     const dashboard = document.getElementById('dashboard');
     if (loginScreen) loginScreen.style.display = 'none';
@@ -167,7 +167,7 @@ class AuthManager {
   const userRoleElement = document.querySelector('.user-role');
   if (userNameElement && this.user) {
-    userNameElement.textContent = this.user.name || this.user.username || 'User';
+    userNameElement.textContent = this.user.display_name || this.user.username || 'User';
   }
   if (userRoleElement && this.user) {
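
The `display_name || username || 'User'` fallback in the last hunk can be sketched as a pure helper (the standalone function form and its name are illustrative, not part of the diff):

```javascript
// Illustrative helper mirroring the fallback chain used when rendering
// the sidebar user name: prefer display_name, then username, else 'User'.
function displayName(user) {
  if (!user) return 'User';
  return user.display_name || user.username || 'User';
}

console.log(displayName({ display_name: 'Ada', username: 'ada' })); // Ada
console.log(displayName({ username: 'admin' })); // admin
console.log(displayName(null)); // User
```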


@@ -492,7 +492,7 @@ class MonitoringPage {
   simulateRequest() {
     const clients = ['client-1', 'client-2', 'client-3', 'client-4', 'client-5'];
     const providers = ['OpenAI', 'Gemini', 'DeepSeek', 'Grok'];
-    const models = ['gpt-4', 'gpt-3.5-turbo', 'gemini-pro', 'deepseek-chat', 'grok-beta'];
+    const models = ['gpt-4o', 'gpt-4o-mini', 'gemini-2.0-flash', 'deepseek-chat', 'grok-4-1-fast-non-reasoning'];
     const statuses = ['success', 'success', 'success', 'error', 'warning']; // Mostly success
     const request = {
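
The hunk is truncated at the `request` object, but `simulateRequest()` presumably builds a fake request by sampling these pools; a minimal sketch of that sampling (the `pick` helper is hypothetical, only the model list comes from the diff):

```javascript
// Hypothetical uniform sampler for the demo pools above.
function pick(arr) {
  return arr[Math.floor(Math.random() * arr.length)];
}

const models = ['gpt-4o', 'gpt-4o-mini', 'gemini-2.0-flash', 'deepseek-chat', 'grok-4-1-fast-non-reasoning'];
const model = pick(models);
console.log(models.includes(model)); // true
```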


@@ -248,21 +248,19 @@ class WebSocketManager {
   }
   updateStatus(status) {
-    const statusElement = document.getElementById('ws-status-nav');
+    const statusElement = document.getElementById('connection-status');
     if (!statusElement) return;
-    const dot = statusElement.querySelector('.ws-dot');
-    const text = statusElement.querySelector('.ws-text');
+    const dot = statusElement.querySelector('.status-dot');
+    const text = statusElement.querySelector('.status-text');
     if (!dot || !text) return;
     // Remove all status classes
-    dot.classList.remove('connected', 'disconnected');
-    statusElement.classList.remove('connected', 'disconnected');
+    dot.classList.remove('connected', 'disconnected', 'error', 'connecting');
     // Add new status class
     dot.classList.add(status);
-    statusElement.classList.add(status);
     // Update text
     const statusText = {
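
The hunk cuts off at the `statusText` map. A plausible completion, matching the four `.status-dot` modifier classes the new CSS defines (the exact labels are assumptions, not taken from the diff):

```javascript
// Assumed status → label map; keys match the status-dot modifier classes
// (connected, connecting, disconnected, error) from the stylesheet.
const statusText = {
  connected: 'Connected',
  connecting: 'Connecting...',
  disconnected: 'Disconnected',
  error: 'Connection Error',
};

// text.textContent would then be set from the map, with a safe fallback:
function labelFor(status) {
  return statusText[status] || 'Disconnected';
}

console.log(labelFor('connected')); // Connected
console.log(labelFor('unknown')); // Disconnected
```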


@@ -1,38 +0,0 @@
#!/bin/bash
# Test script for LLM Proxy Dashboard
echo "Building LLM Proxy Gateway..."
cargo build --release
echo ""
echo "Starting server in background..."
./target/release/llm-proxy &
SERVER_PID=$!
# Wait for server to start
sleep 3
echo ""
echo "Testing dashboard endpoints..."
# Test health endpoint
echo "1. Testing health endpoint:"
curl -s http://localhost:8080/health
echo ""
echo "2. Testing dashboard static files:"
curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/
echo ""
echo "3. Testing API endpoints:"
curl -s http://localhost:8080/api/auth/status | jq . 2>/dev/null || echo "JSON response received"
echo ""
echo "Dashboard should be available at: http://localhost:8080"
echo "Default login: admin / admin"
echo ""
echo "Press Ctrl+C to stop the server"
# Keep script running
wait $SERVER_PID


@@ -1,75 +0,0 @@
#!/bin/bash
# Test script for LLM Proxy Gateway
echo "Building LLM Proxy Gateway..."
cargo build --release
if [ $? -ne 0 ]; then
echo "Build failed!"
exit 1
fi
echo "Build successful!"
echo ""
echo "Project Structure Summary:"
echo "=========================="
echo "Core Components:"
echo " - main.rs: Application entry point with server setup"
echo " - config/: Configuration management"
echo " - server/: API route handlers"
echo " - auth/: Bearer token authentication"
echo " - database/: SQLite database setup"
echo " - models/: Data structures (OpenAI-compatible)"
echo " - providers/: LLM provider implementations (OpenAI, Gemini, DeepSeek, Grok)"
echo " - errors/: Custom error types"
echo " - dashboard/: Admin dashboard with WebSocket support"
echo " - logging/: Request logging middleware"
echo " - state/: Shared application state"
echo " - multimodal/: Image processing support (basic structure)"
echo ""
echo "Key Features Implemented:"
echo "=========================="
echo "✓ OpenAI-compatible API endpoint (/v1/chat/completions)"
echo "✓ Bearer token authentication"
echo "✓ SQLite database for request tracking"
echo "✓ Request logging with token/cost calculation"
echo "✓ Provider abstraction layer"
echo "✓ Admin dashboard with real-time monitoring"
echo "✓ WebSocket support for live updates"
echo "✓ Configuration management (config.toml, .env, env vars)"
echo "✓ Multimodal support structure (images)"
echo "✓ Error handling with proper HTTP status codes"
echo ""
echo "Next Steps Needed:"
echo "=================="
echo "1. Add API keys to .env file:"
echo " OPENAI_API_KEY=your_key_here"
echo " GEMINI_API_KEY=your_key_here"
echo " DEEPSEEK_API_KEY=your_key_here"
echo " GROK_API_KEY=your_key_here (optional)"
echo ""
echo "2. Create config.toml for custom configuration (optional)"
echo ""
echo "3. Run the server:"
echo " cargo run"
echo ""
echo "4. Access dashboard at: http://localhost:8080"
echo ""
echo "5. Test API with curl:"
echo " curl -X POST http://localhost:8080/v1/chat/completions \\"
echo " -H 'Authorization: Bearer your_token' \\"
echo " -H 'Content-Type: application/json' \\"
echo " -d '{\"model\": \"gpt-4\", \"messages\": [{\"role\": \"user\", \"content\": \"Hello\"}]}'"
echo ""
echo "Deployment Notes:"
echo "================="
echo "Memory: Designed for 512MB RAM (LXC container)"
echo "Database: SQLite (./data/llm_proxy.db)"
echo "Port: 8080 (configurable)"
echo "Authentication: Single Bearer token (configurable)"
echo "Providers: OpenAI, Gemini, DeepSeek, Grok (disabled by default)"


@@ -1,14 +0,0 @@
gantt
title LLM Proxy Project Timeline
dateFormat YYYY-MM-DD
section Frontend
Standardize Escaping (users.js) :a1, 2026-03-06, 1d
section Backend Cleanup
Remove Unused Imports :b1, 2026-03-06, 1d
section HMAC Migration
Architecture Design :c1, 2026-03-07, 1d
Backend Implementation :c2, after c1, 2d
Session Refresh Logic :c3, after c2, 1d
section Testing
Integration Test (Encrypted Keys) :d1, 2026-03-09, 2d
HMAC Verification Tests :d2, after c3, 1d