Compare commits
128 Commits
6010ec97a8
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 9375448087 | |||
| 5be2f6f7aa | |||
| eebcadcba1 | |||
| 6b2bd13903 | |||
| 5dfda0a10c | |||
| a8a02d9e1c | |||
| bd1d17cc4d | |||
| 9207a7231c | |||
| c6efff9034 | |||
| 27fbd8ed15 | |||
| 348341f304 | |||
| 9380580504 | |||
| 08cf5cc1d9 | |||
| 0f0486d8d4 | |||
| 0ea2a3a985 | |||
| 21e5908c35 | |||
| 6f0a159245 | |||
| 4120a83b67 | |||
| 742cd9e921 | |||
| 593971ecb5 | |||
| 03dca998df | |||
| 0ce5f4f490 | |||
| dec4b927dc | |||
| 3f1e6d3407 | |||
| f02fd6c249 | |||
| f23796f0cc | |||
| 3f76a544e0 | |||
| e474549940 | |||
| b7e37b0399 | |||
| 263c0f0dc9 | |||
| 26d8431998 | |||
| 1f3adceda4 | |||
| 9c64a8fe42 | |||
| b04b794705 | |||
| 0f3c5b6eb4 | |||
| 66a1643bca | |||
| edc6445d70 | |||
| 2d8f1a1fd0 | |||
| cd1a1b45aa | |||
| 246a6d88f0 | |||
| 7d43b2c31b | |||
| 45c2d5e643 | |||
| 1d032c6732 | |||
| 2245cca67a | |||
| c7c244992a | |||
| 4f5b55d40f | |||
| 90874a6721 | |||
| 6b10d4249c | |||
| 57aa0aa70e | |||
| 4de457cc5e | |||
| 66e8b114b9 | |||
| 1cac45502a | |||
| 79dc8fe409 | |||
| 24a898c9a7 | |||
| 7c2a317c01 | |||
| cb619f9286 | |||
| 441270317c | |||
| 2e4318d84b | |||
| d0be16d8e3 | |||
| 83e0ad0240 | |||
| 275ce34d05 | |||
| cb5b921550 | |||
| 649371154f | |||
| 78fff61660 | |||
| b131094dfd | |||
| c3d81c1733 | |||
| e123f542f1 | |||
| 0d28241e39 | |||
| 754ee9cb84 | |||
| 5a9086b883 | |||
| cc5eba1957 | |||
| 3ab00fb188 | |||
| c2595f7a74 | |||
| 0526304398 | |||
| 75e2967727 | |||
| e1bc3b35eb | |||
| 0d32d953d2 | |||
| bd5ca2dd98 | |||
| 6a0aca1a6c | |||
| 4c629e17cb | |||
| fc3bc6968d | |||
| d6280abad9 | |||
| 96486b6318 | |||
| e8955fd36c | |||
| a243a3987d | |||
| 4be23629d8 | |||
| dd54c14ff8 | |||
| 633b69a07b | |||
| 975ae124d1 | |||
| 9b8483e797 | |||
| d32386df3f | |||
| 149a7c3a29 | |||
| d9cfffea62 | |||
| 90ef026c96 | |||
| 5ddf284b8f | |||
| f5677afba0 | |||
| 4ffc6452e0 | |||
| 94162a3dcc | |||
| c26925c253 | |||
| d0d64e2064 | |||
| 6a324c08c7 | |||
| 1ddb5277e9 | |||
| 1067ceaecd | |||
| fc5d3ed636 | |||
| 7411d3dbed | |||
| e3c1b9fa20 | |||
| c2bad90a8f | |||
| c7b67d5840 | |||
| 7efb36029c | |||
| 6440e8cc13 | |||
| 5c5f836eca | |||
| febfcafed4 | |||
| 811885274b | |||
| e307ecf11d | |||
| eac3781079 | |||
| 76bf5b81d4 | |||
| 90a3f5d7f8 | |||
| f7f6768333 | |||
| 5bbd5f77b9 | |||
| 8a33b147f1 | |||
| 154b7b3b77 | |||
| 3d43948dbe | |||
| a75c10bcd8 | |||
| 0dd6212f0a | |||
| f8598060f9 | |||
| 3086a3b6d9 | |||
| fb98f0ebb8 | |||
| 6b7e245827 |
22
.env
22
.env
@@ -1,22 +0,0 @@
|
||||
# LLM Proxy Gateway Environment Variables
|
||||
|
||||
# OpenAI
|
||||
OPENAI_API_KEY=sk-demo-openai-key
|
||||
|
||||
# Google Gemini
|
||||
GEMINI_API_KEY=AIza-demo-gemini-key
|
||||
|
||||
# DeepSeek
|
||||
DEEPSEEK_API_KEY=sk-demo-deepseek-key
|
||||
|
||||
# xAI Grok (not yet available)
|
||||
GROK_API_KEY=gk-demo-grok-key
|
||||
|
||||
# Authentication tokens (comma-separated list)
|
||||
LLM_PROXY__SERVER__AUTH_TOKENS=demo-token-123456,another-token
|
||||
|
||||
# Server port (optional)
|
||||
LLM_PROXY__SERVER__PORT=8080
|
||||
|
||||
# Database path (optional)
|
||||
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
|
||||
61
.env.example
61
.env.example
@@ -1,28 +1,47 @@
|
||||
# LLM Proxy Gateway Environment Variables
|
||||
# Copy to .env and fill in your API keys
|
||||
# GopherGate Configuration Example
|
||||
# Copy this file to .env and fill in your values
|
||||
|
||||
# OpenAI
|
||||
OPENAI_API_KEY=your_openai_api_key_here
|
||||
# ==============================================================================
|
||||
# MANDATORY: Encryption & Security
|
||||
# ==============================================================================
|
||||
# A 32-byte hex or base64 encoded string used for session signing and
|
||||
# database encryption.
|
||||
# Generate one with: openssl rand -hex 32
|
||||
LLM_PROXY__ENCRYPTION_KEY=your_secure_32_byte_key_here
|
||||
|
||||
# Google Gemini
|
||||
GEMINI_API_KEY=your_gemini_api_key_here
|
||||
# ==============================================================================
|
||||
# LLM Provider API Keys
|
||||
# ==============================================================================
|
||||
OPENAI_API_KEY=sk-...
|
||||
GEMINI_API_KEY=AIza...
|
||||
DEEPSEEK_API_KEY=sk-...
|
||||
MOONSHOT_API_KEY=sk-...
|
||||
GROK_API_KEY=xai-...
|
||||
|
||||
# DeepSeek
|
||||
DEEPSEEK_API_KEY=your_deepseek_api_key_here
|
||||
# ==============================================================================
|
||||
# Server Configuration
|
||||
# ==============================================================================
|
||||
LLM_PROXY__SERVER__PORT=8080
|
||||
LLM_PROXY__SERVER__HOST=0.0.0.0
|
||||
|
||||
# xAI Grok (not yet available)
|
||||
GROK_API_KEY=your_grok_api_key_here
|
||||
# Optional: Bearer tokens for client authentication (comma-separated)
|
||||
# If not set, the proxy will look up tokens in the database.
|
||||
# LLM_PROXY__SERVER__AUTH_TOKENS=token1,token2
|
||||
|
||||
# Ollama (local server)
|
||||
# LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL=http://your-ollama-host:11434/v1
|
||||
# ==============================================================================
|
||||
# Database Configuration
|
||||
# ==============================================================================
|
||||
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
|
||||
LLM_PROXY__DATABASE__MAX_CONNECTIONS=10
|
||||
|
||||
# ==============================================================================
|
||||
# Provider Overrides (Optional)
|
||||
# ==============================================================================
|
||||
# LLM_PROXY__PROVIDERS__OPENAI__BASE_URL=https://api.openai.com/v1
|
||||
# LLM_PROXY__PROVIDERS__GEMINI__ENABLED=true
|
||||
# LLM_PROXY__PROVIDERS__MOONSHOT__BASE_URL=https://api.moonshot.ai/v1
|
||||
# LLM_PROXY__PROVIDERS__MOONSHOT__ENABLED=true
|
||||
# LLM_PROXY__PROVIDERS__MOONSHOT__DEFAULT_MODEL=kimi-k2.5
|
||||
# LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL=http://localhost:11434/v1
|
||||
# LLM_PROXY__PROVIDERS__OLLAMA__ENABLED=true
|
||||
# LLM_PROXY__PROVIDERS__OLLAMA__MODELS=llama3,mistral,llava
|
||||
|
||||
# Authentication tokens (comma-separated list)
|
||||
LLM_PROXY__SERVER__AUTH_TOKENS=your_bearer_token_here,another_token
|
||||
|
||||
# Server port (optional)
|
||||
LLM_PROXY__SERVER__PORT=8080
|
||||
|
||||
# Database path (optional)
|
||||
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
|
||||
62
.github/workflows/ci.yml
vendored
62
.github/workflows/ci.yml
vendored
@@ -6,56 +6,44 @@ on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
RUST_BACKTRACE: 1
|
||||
|
||||
jobs:
|
||||
check:
|
||||
name: Check
|
||||
lint:
|
||||
name: Lint
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: cargo check --all-targets
|
||||
|
||||
clippy:
|
||||
name: Clippy
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
components: clippy
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: cargo clippy --all-targets -- -D warnings
|
||||
|
||||
fmt:
|
||||
name: Formatting
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
go-version: '1.22'
|
||||
cache: true
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v4
|
||||
with:
|
||||
components: rustfmt
|
||||
- run: cargo fmt --all -- --check
|
||||
version: latest
|
||||
|
||||
test:
|
||||
name: Test
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: cargo test --all-targets
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
cache: true
|
||||
- name: Run Tests
|
||||
run: go test -v ./...
|
||||
|
||||
build-release:
|
||||
name: Release Build
|
||||
build:
|
||||
name: Build
|
||||
runs-on: ubuntu-latest
|
||||
needs: [check, clippy, test]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
- uses: Swatinem/rust-cache@v2
|
||||
- run: cargo build --release
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
cache: true
|
||||
- name: Build
|
||||
run: go build -v -o gophergate ./cmd/gophergate
|
||||
|
||||
16
.gitignore
vendored
16
.gitignore
vendored
@@ -1,5 +1,13 @@
|
||||
.env
|
||||
.env.*
|
||||
!.env.example
|
||||
/target
|
||||
/.env
|
||||
/*.db
|
||||
/*.db-shm
|
||||
/*.db-wal
|
||||
/llm-proxy
|
||||
/llm-proxy-go
|
||||
/gophergate
|
||||
/data/
|
||||
*.db
|
||||
*.db-shm
|
||||
*.db-wal
|
||||
*.log
|
||||
server.pid
|
||||
|
||||
62
BACKEND_ARCHITECTURE.md
Normal file
62
BACKEND_ARCHITECTURE.md
Normal file
@@ -0,0 +1,62 @@
|
||||
# Backend Architecture (Go)
|
||||
|
||||
The GopherGate backend is implemented in Go, focusing on high performance, clear concurrency patterns, and maintainability.
|
||||
|
||||
## Core Technologies
|
||||
|
||||
- **Runtime:** Go 1.22+
|
||||
- **Web Framework:** [Gin Gonic](https://github.com/gin-gonic/gin) - Fast and lightweight HTTP routing.
|
||||
- **Database:** [sqlx](https://github.com/jmoiron/sqlx) - Lightweight wrapper for standard `database/sql`.
|
||||
- **SQLite Driver:** [modernc.org/sqlite](https://modernc.org/sqlite) - CGO-free SQLite implementation for ease of cross-compilation.
|
||||
- **Config:** [Viper](https://github.com/spf13/viper) - Robust configuration management supporting environment variables and files.
|
||||
- **Metrics:** [gopsutil](https://github.com/shirou/gopsutil) - System-level resource monitoring.
|
||||
|
||||
## Project Structure
|
||||
|
||||
```text
|
||||
├── cmd/
|
||||
│ └── gophergate/ # Entry point (main.go)
|
||||
├── internal/
|
||||
│ ├── config/ # Configuration loading and validation
|
||||
│ ├── db/ # Database schema, migrations, and models
|
||||
│ ├── middleware/ # Auth and logging middleware
|
||||
│ ├── models/ # Unified request/response structs
|
||||
│ ├── providers/ # LLM provider implementations (OpenAI, Gemini, etc.)
|
||||
│ ├── server/ # HTTP server, dashboard handlers, and WebSocket hub
|
||||
│ └── utils/ # Common utilities (registry, pricing, etc.)
|
||||
└── static/ # Frontend assets (served by the backend)
|
||||
```
|
||||
|
||||
## Key Components
|
||||
|
||||
### 1. Provider Interface (`internal/providers/provider.go`)
|
||||
Standardized interface for all LLM backends. Implementations handle mapping between the unified format and provider-specific APIs (OpenAI, Gemini, DeepSeek, Grok).
|
||||
|
||||
### 2. Model Registry & Pricing (`internal/utils/registry.go`)
|
||||
Integrates with `models.dev/api.json` to provide real-time model metadata and pricing.
|
||||
- **Fuzzy Matching:** Supports matching versioned model IDs (e.g., `gpt-4o-2024-08-06`) to base registry entries.
|
||||
- **Automatic Refreshes:** The registry is fetched at startup and refreshed every 24 hours via a background goroutine.
|
||||
|
||||
### 3. Asynchronous Logging (`internal/server/logging.go`)
|
||||
Uses a buffered channel and background worker to log every request to SQLite without blocking the client response. It also broadcasts logs to the WebSocket hub for real-time dashboard updates.
|
||||
|
||||
### 4. Session Management (`internal/server/sessions.go`)
|
||||
Implements HMAC-SHA256 signed tokens for dashboard authentication. Tokens secure the management interface while standard Bearer tokens are used for LLM API access.
|
||||
|
||||
### 5. WebSocket Hub (`internal/server/websocket.go`)
|
||||
A centralized hub for managing WebSocket connections, allowing real-time broadcast of system events, system metrics, and request logs to the dashboard.
|
||||
|
||||
## Concurrency Model
|
||||
|
||||
Go's goroutines and channels are used extensively:
|
||||
- **Streaming:** Each streaming request uses a goroutine to read and parse the provider's response, feeding chunks into a channel for SSE delivery.
|
||||
- **Logging:** A single background worker processes the `logChan` to perform serial database writes.
|
||||
- **WebSocket:** The `Hub` runs in a dedicated goroutine, handling registration and broadcasting.
|
||||
- **Maintenance:** Background tasks handle registry refreshes and status monitoring.
|
||||
|
||||
## Security
|
||||
|
||||
- **Encryption Key:** A mandatory 32-byte key is used for both session signing and encryption of sensitive data.
|
||||
- **Auth Middleware:** Scoped to `/v1` routes to verify client API keys against the database.
|
||||
- **Bcrypt:** Passwords for dashboard users are hashed using Bcrypt with a work factor of 12.
|
||||
- **Database Hardening:** Automatic migrations ensure the schema is always current with the code.
|
||||
3938
Cargo.lock
generated
3938
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
71
Cargo.toml
71
Cargo.toml
@@ -1,71 +0,0 @@
|
||||
[package]
|
||||
name = "llm-proxy"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
rust-version = "1.87"
|
||||
description = "Unified LLM proxy gateway supporting OpenAI, Gemini, DeepSeek, and Grok with token tracking and cost calculation"
|
||||
authors = ["newkirk"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
repository = ""
|
||||
|
||||
[dependencies]
|
||||
# ========== Web Framework & Async Runtime ==========
|
||||
axum = { version = "0.8", features = ["macros", "ws"] }
|
||||
tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "net", "time", "signal", "fs"] }
|
||||
tower = "0.5"
|
||||
tower-http = { version = "0.6", features = ["trace", "cors", "compression-gzip", "fs"] }
|
||||
|
||||
# ========== HTTP Clients ==========
|
||||
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
|
||||
tiktoken-rs = "0.9"
|
||||
|
||||
# ========== Database & ORM ==========
|
||||
sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite", "macros", "migrate", "chrono"] }
|
||||
|
||||
# ========== Authentication & Middleware ==========
|
||||
axum-extra = { version = "0.12", features = ["typed-header"] }
|
||||
headers = "0.4"
|
||||
|
||||
# ========== Configuration Management ==========
|
||||
config = "0.13"
|
||||
dotenvy = "0.15"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
toml = "0.8"
|
||||
|
||||
# ========== Logging & Monitoring ==========
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
|
||||
|
||||
# ========== Multimodal & Image Processing ==========
|
||||
base64 = "0.21"
|
||||
image = { version = "0.25", default-features = false, features = ["jpeg", "png", "webp"] }
|
||||
mime = "0.3"
|
||||
|
||||
# ========== Error Handling & Utilities ==========
|
||||
anyhow = "1.0"
|
||||
thiserror = "1.0"
|
||||
bcrypt = "0.15"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
uuid = { version = "1.0", features = ["v4", "serde"] }
|
||||
futures = "0.3"
|
||||
async-trait = "0.1"
|
||||
async-stream = "0.3"
|
||||
reqwest-eventsource = "0.6"
|
||||
rand = "0.9"
|
||||
hex = "0.4"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = "0.4"
|
||||
mockito = "1.0"
|
||||
tempfile = "3.10"
|
||||
assert_cmd = "2.0"
|
||||
insta = "1.39"
|
||||
anyhow = "1.0"
|
||||
|
||||
[profile.release]
|
||||
opt-level = 3
|
||||
lto = true
|
||||
codegen-units = 1
|
||||
strip = true
|
||||
panic = "abort"
|
||||
@@ -1,220 +0,0 @@
|
||||
# LLM Proxy Gateway - Admin Dashboard
|
||||
|
||||
## Overview
|
||||
|
||||
This is a comprehensive admin dashboard for the LLM Proxy Gateway, providing real-time monitoring, analytics, and management capabilities for the proxy service.
|
||||
|
||||
## Features
|
||||
|
||||
### 1. Dashboard Overview
|
||||
- Real-time request counters and statistics
|
||||
- System health indicators
|
||||
- Provider status monitoring
|
||||
- Recent requests stream
|
||||
|
||||
### 2. Usage Analytics
|
||||
- Time series charts for requests, tokens, and costs
|
||||
- Filter by date range, client, provider, and model
|
||||
- Top clients and models analysis
|
||||
- Export functionality to CSV/JSON
|
||||
|
||||
### 3. Cost Management
|
||||
- Cost breakdown by provider, client, and model
|
||||
- Budget tracking with alerts
|
||||
- Cost projections
|
||||
- Pricing configuration management
|
||||
|
||||
### 4. Client Management
|
||||
- List, create, revoke, and rotate API tokens
|
||||
- Client-specific rate limits
|
||||
- Usage statistics per client
|
||||
- Token management interface
|
||||
|
||||
### 5. Provider Configuration
|
||||
- Enable/disable LLM providers
|
||||
- Configure API keys (masked display)
|
||||
- Test provider connections
|
||||
- Model availability management
|
||||
|
||||
### 6. User Management (RBAC)
|
||||
- **Admin Role:** Full access to all dashboard features, user management, system configuration
|
||||
- **Viewer Role:** Read-only access to usage analytics, costs, and monitoring
|
||||
- Create/manage dashboard users with role assignment
|
||||
- Secure password management
|
||||
|
||||
### 7. Real-time Monitoring
|
||||
- Live request stream via WebSocket
|
||||
- System metrics dashboard
|
||||
- Response time and error rate tracking
|
||||
- Live system logs
|
||||
|
||||
### 7. **System Settings**
|
||||
- General configuration
|
||||
- Database management
|
||||
- Logging settings
|
||||
- Security settings
|
||||
|
||||
## Technology Stack
|
||||
|
||||
### Frontend
|
||||
- **HTML5/CSS3**: Modern, responsive design with CSS Grid/Flexbox
|
||||
- **JavaScript (ES6+)**: Vanilla JavaScript with modular architecture
|
||||
- **Chart.js**: Interactive data visualizations
|
||||
- **Luxon**: Date/time manipulation
|
||||
- **WebSocket API**: Real-time updates
|
||||
|
||||
### Backend (Rust/Axum)
|
||||
- **Axum**: Web framework with WebSocket support
|
||||
- **Tokio**: Async runtime
|
||||
- **Serde**: JSON serialization/deserialization
|
||||
- **Broadcast channels**: Real-time event distribution
|
||||
|
||||
## Installation & Setup
|
||||
|
||||
### 1. Build and Run the Server
|
||||
```bash
|
||||
# Build the project
|
||||
cargo build --release
|
||||
|
||||
# Run the server
|
||||
cargo run --release
|
||||
```
|
||||
|
||||
### 2. Access the Dashboard
|
||||
Once the server is running, access the dashboard at:
|
||||
```
|
||||
http://localhost:8080
|
||||
```
|
||||
|
||||
### 3. Default Login Credentials
|
||||
- **Username**: `admin`
|
||||
- **Password**: `admin123`
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### Authentication
|
||||
- `POST /api/auth/login` - Dashboard login
|
||||
- `GET /api/auth/status` - Authentication status
|
||||
|
||||
### Analytics
|
||||
- `GET /api/usage/summary` - Overall usage summary
|
||||
- `GET /api/usage/time-series` - Time series data
|
||||
- `GET /api/usage/clients` - Client breakdown
|
||||
- `GET /api/usage/providers` - Provider breakdown
|
||||
|
||||
### Clients
|
||||
- `GET /api/clients` - List all clients
|
||||
- `POST /api/clients` - Create new client
|
||||
- `PUT /api/clients/{id}` - Update client
|
||||
- `DELETE /api/clients/{id}` - Revoke client
|
||||
- `GET /api/clients/{id}/usage` - Client-specific usage
|
||||
|
||||
### Users (RBAC)
|
||||
- `GET /api/users` - List all dashboard users
|
||||
- `POST /api/users` - Create new user
|
||||
- `PUT /api/users/{id}` - Update user (admin only)
|
||||
- `DELETE /api/users/{id}` - Delete user (admin only)
|
||||
|
||||
### Providers
|
||||
- `GET /api/providers` - List providers and status
|
||||
- `PUT /api/providers/{name}` - Update provider config
|
||||
- `POST /api/providers/{name}/test` - Test provider connection
|
||||
|
||||
### System
|
||||
- `GET /api/system/health` - System health
|
||||
- `GET /api/system/logs` - Recent logs
|
||||
- `POST /api/system/backup` - Trigger backup
|
||||
|
||||
### WebSocket
|
||||
- `GET /ws` - WebSocket endpoint for real-time updates
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
llm-proxy/
|
||||
├── src/
|
||||
│ ├── dashboard/ # Dashboard backend module
|
||||
│ │ └── mod.rs # Dashboard routes and handlers
|
||||
│ ├── server/ # Main proxy server
|
||||
│ ├── providers/ # LLM provider implementations
|
||||
│ └── ... # Other modules
|
||||
├── static/ # Frontend dashboard files
|
||||
│ ├── index.html # Main dashboard HTML
|
||||
│ ├── css/
|
||||
│ │ └── dashboard.css # Dashboard styles
|
||||
│ ├── js/
|
||||
│ │ ├── auth.js # Authentication module
|
||||
│ │ ├── dashboard.js # Main dashboard controller
|
||||
│ │ ├── websocket.js # WebSocket manager
|
||||
│ │ ├── charts.js # Chart.js utilities
|
||||
│ │ └── pages/ # Page-specific modules
|
||||
│ │ ├── overview.js
|
||||
│ │ ├── analytics.js
|
||||
│ │ ├── costs.js
|
||||
│ │ ├── clients.js
|
||||
│ │ ├── providers.js
|
||||
│ │ ├── monitoring.js
|
||||
│ │ ├── settings.js
|
||||
│ │ └── logs.js
|
||||
│ ├── img/ # Images and icons
|
||||
│ └── fonts/ # Font files
|
||||
└── Cargo.toml # Rust dependencies
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
### Adding New Pages
|
||||
1. Create a new JavaScript module in `static/js/pages/`
|
||||
2. Implement the page class with `init()` method
|
||||
3. Register the page in `dashboard.js`
|
||||
4. Add menu item in `index.html`
|
||||
|
||||
### Adding New API Endpoints
|
||||
1. Add route in `src/dashboard/mod.rs`
|
||||
2. Implement handler function
|
||||
3. Update frontend JavaScript to call the endpoint
|
||||
|
||||
### Styling Guidelines
|
||||
- Use CSS custom properties (variables) from `:root`
|
||||
- Follow mobile-first responsive design
|
||||
- Use BEM-like naming convention for CSS classes
|
||||
- Maintain consistent spacing with CSS variables
|
||||
|
||||
## Security Considerations
|
||||
|
||||
1. **Authentication**: Simple password-based auth for demo; replace with proper auth in production
|
||||
2. **API Keys**: Tokens are masked in the UI (only last 4 characters shown)
|
||||
3. **CORS**: Configure appropriate CORS headers for production
|
||||
4. **Rate Limiting**: Implement rate limiting for API endpoints
|
||||
5. **HTTPS**: Always use HTTPS in production
|
||||
|
||||
## Performance Optimizations
|
||||
|
||||
1. **Code Splitting**: JavaScript modules are loaded on-demand
|
||||
2. **Caching**: Static assets are served with cache headers
|
||||
3. **WebSocket**: Real-time updates reduce polling overhead
|
||||
4. **Lazy Loading**: Charts and tables load data as needed
|
||||
5. **Compression**: Enable gzip/brotli compression for static files
|
||||
|
||||
## Browser Support
|
||||
|
||||
- Chrome 60+
|
||||
- Firefox 55+
|
||||
- Safari 11+
|
||||
- Edge 79+
|
||||
|
||||
## License
|
||||
|
||||
MIT License - See LICENSE file for details.
|
||||
|
||||
## Contributing
|
||||
|
||||
1. Fork the repository
|
||||
2. Create a feature branch
|
||||
3. Make your changes
|
||||
4. Add tests if applicable
|
||||
5. Submit a pull request
|
||||
|
||||
## Support
|
||||
|
||||
For issues and feature requests, please use the GitHub issue tracker.
|
||||
43
Dockerfile
43
Dockerfile
@@ -1,35 +1,34 @@
|
||||
# ── Build stage ──────────────────────────────────────────────
|
||||
FROM rust:1-bookworm AS builder
|
||||
# Build stage
|
||||
FROM golang:1.22-alpine AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Cache dependency build
|
||||
COPY Cargo.toml Cargo.lock ./
|
||||
RUN mkdir src && echo 'fn main() {}' > src/main.rs && \
|
||||
cargo build --release && \
|
||||
rm -rf src
|
||||
# Copy go mod and sum files
|
||||
COPY go.mod go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
# Build the actual binary
|
||||
COPY src/ src/
|
||||
RUN touch src/main.rs && cargo build --release
|
||||
# Copy the source code
|
||||
COPY . .
|
||||
|
||||
# ── Runtime stage ────────────────────────────────────────────
|
||||
FROM debian:bookworm-slim
|
||||
# Build the application
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o gophergate ./cmd/gophergate
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends ca-certificates && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
# Final stage
|
||||
FROM alpine:latest
|
||||
|
||||
RUN apk --no-cache add ca-certificates tzdata
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=builder /app/target/release/llm-proxy /app/llm-proxy
|
||||
COPY static/ /app/static/
|
||||
# Copy the binary from the builder stage
|
||||
COPY --from=builder /app/gophergate .
|
||||
COPY --from=builder /app/static ./static
|
||||
|
||||
# Default config location
|
||||
VOLUME ["/app/config", "/app/data"]
|
||||
# Create data directory
|
||||
RUN mkdir -p /app/data
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8080
|
||||
|
||||
ENV RUST_LOG=info
|
||||
|
||||
ENTRYPOINT ["/app/llm-proxy"]
|
||||
# Run the application
|
||||
CMD ["./gophergate"]
|
||||
|
||||
232
OPTIMIZATION.md
232
OPTIMIZATION.md
@@ -1,232 +0,0 @@
|
||||
# Optimization for 512MB RAM Environment
|
||||
|
||||
This document provides guidance for optimizing the LLM Proxy Gateway for deployment in resource-constrained environments (512MB RAM).
|
||||
|
||||
## Memory Optimization Strategies
|
||||
|
||||
### 1. Build Optimization
|
||||
|
||||
The project is already configured with optimized build settings in `Cargo.toml`:
|
||||
|
||||
```toml
|
||||
[profile.release]
|
||||
opt-level = 3 # Maximum optimization
|
||||
lto = true # Link-time optimization
|
||||
codegen-units = 1 # Single codegen unit for better optimization
|
||||
strip = true # Strip debug symbols
|
||||
```
|
||||
|
||||
**Additional optimizations you can apply:**
|
||||
|
||||
```bash
|
||||
# Build with specific target for better optimization
|
||||
cargo build --release --target x86_64-unknown-linux-musl
|
||||
|
||||
# Or for ARM (Raspberry Pi, etc.)
|
||||
cargo build --release --target aarch64-unknown-linux-musl
|
||||
```
|
||||
|
||||
### 2. Runtime Memory Management
|
||||
|
||||
#### Database Connection Pool
|
||||
- Default: 10 connections
|
||||
- Recommended for 512MB: 5 connections
|
||||
|
||||
Update `config.toml`:
|
||||
```toml
|
||||
[database]
|
||||
max_connections = 5
|
||||
```
|
||||
|
||||
#### Rate Limiting Memory Usage
|
||||
- Client rate limit buckets: Store in memory
|
||||
- Circuit breakers: Minimal memory usage
|
||||
- Consider reducing burst capacity if memory is critical
|
||||
|
||||
#### Provider Management
|
||||
- Only enable providers you actually use
|
||||
- Disable unused providers in configuration
|
||||
|
||||
### 3. Configuration for Low Memory
|
||||
|
||||
Create a `config-low-memory.toml`:
|
||||
|
||||
```toml
|
||||
[server]
|
||||
port = 8080
|
||||
host = "0.0.0.0"
|
||||
|
||||
[database]
|
||||
path = "./data/llm_proxy.db"
|
||||
max_connections = 3 # Reduced from default 10
|
||||
|
||||
[providers]
|
||||
# Only enable providers you need
|
||||
openai.enabled = true
|
||||
gemini.enabled = false # Disable if not used
|
||||
deepseek.enabled = false # Disable if not used
|
||||
grok.enabled = false # Disable if not used
|
||||
|
||||
[rate_limiting]
|
||||
# Reduce memory usage for rate limiting
|
||||
client_requests_per_minute = 30 # Reduced from 60
|
||||
client_burst_size = 5 # Reduced from 10
|
||||
global_requests_per_minute = 300 # Reduced from 600
|
||||
```
|
||||
|
||||
### 4. System-Level Optimizations
|
||||
|
||||
#### Linux Kernel Parameters
|
||||
Add to `/etc/sysctl.conf`:
|
||||
```bash
|
||||
# Reduce TCP buffer sizes
|
||||
net.ipv4.tcp_rmem = 4096 87380 174760
|
||||
net.ipv4.tcp_wmem = 4096 65536 131072
|
||||
|
||||
# Reduce connection tracking
|
||||
net.netfilter.nf_conntrack_max = 65536
|
||||
net.netfilter.nf_conntrack_tcp_timeout_established = 1200
|
||||
|
||||
# Reduce socket buffer sizes
|
||||
net.core.rmem_max = 131072
|
||||
net.core.wmem_max = 131072
|
||||
net.core.rmem_default = 65536
|
||||
net.core.wmem_default = 65536
|
||||
```
|
||||
|
||||
#### Systemd Service Configuration
|
||||
Create `/etc/systemd/system/llm-proxy.service`:
|
||||
```ini
|
||||
[Unit]
|
||||
Description=LLM Proxy Gateway
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User=llmproxy
|
||||
Group=llmproxy
|
||||
WorkingDirectory=/opt/llm-proxy
|
||||
ExecStart=/opt/llm-proxy/llm-proxy
|
||||
Restart=on-failure
|
||||
RestartSec=5
|
||||
|
||||
# Memory limits
|
||||
MemoryMax=400M
|
||||
MemorySwapMax=100M
|
||||
|
||||
# CPU limits
|
||||
CPUQuota=50%
|
||||
|
||||
# Process limits
|
||||
LimitNOFILE=65536
|
||||
LimitNPROC=512
|
||||
|
||||
Environment="RUST_LOG=info"
|
||||
Environment="LLM_PROXY__DATABASE__MAX_CONNECTIONS=3"
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
|
||||
### 5. Application-Specific Optimizations
|
||||
|
||||
#### Disable Unused Features
|
||||
- **Multimodal support**: If not using images, disable image processing dependencies
|
||||
- **Dashboard**: The dashboard uses WebSockets and additional memory. Consider disabling if not needed.
|
||||
- **Detailed logging**: Reduce log verbosity in production
|
||||
|
||||
#### Memory Pool Sizes
|
||||
The application uses several memory pools:
|
||||
1. **Database connection pool**: Configured via `max_connections`
|
||||
2. **HTTP client pool**: Reqwest client pool (defaults to reasonable values)
|
||||
3. **Async runtime**: Tokio worker threads
|
||||
|
||||
Reduce Tokio worker threads for low-core systems:
|
||||
```rust
|
||||
// In main.rs, modify tokio runtime creation
|
||||
#[tokio::main(flavor = "current_thread")] // Single-threaded runtime
|
||||
async fn main() -> Result<()> {
|
||||
// Or for multi-threaded with limited threads:
|
||||
// #[tokio::main(worker_threads = 2)]
|
||||
```
|
||||
|
||||
### 6. Monitoring and Profiling
|
||||
|
||||
#### Memory Usage Monitoring
|
||||
```bash
|
||||
# Install heaptrack for memory profiling
|
||||
cargo install heaptrack
|
||||
|
||||
# Profile memory usage
|
||||
heaptrack ./target/release/llm-proxy
|
||||
|
||||
# Monitor with ps
|
||||
ps aux --sort=-%mem | head -10
|
||||
|
||||
# Monitor with top
|
||||
top -p $(pgrep llm-proxy)
|
||||
```
|
||||
|
||||
#### Performance Benchmarks
|
||||
Test with different configurations:
|
||||
```bash
|
||||
# Test with 100 concurrent connections
|
||||
wrk -t4 -c100 -d30s http://localhost:8080/health
|
||||
|
||||
# Test chat completion endpoint
|
||||
ab -n 1000 -c 10 -p test_request.json -T application/json http://localhost:8080/v1/chat/completions
|
||||
```
|
||||
|
||||
### 7. Deployment Checklist for 512MB RAM
|
||||
|
||||
- [ ] Build with release profile: `cargo build --release`
|
||||
- [ ] Configure database with `max_connections = 3`
|
||||
- [ ] Disable unused providers in configuration
|
||||
- [ ] Set appropriate rate limiting limits
|
||||
- [ ] Configure systemd with memory limits
|
||||
- [ ] Set up log rotation to prevent disk space issues
|
||||
- [ ] Monitor memory usage during initial deployment
|
||||
- [ ] Consider using swap space (512MB-1GB) for safety
|
||||
|
||||
### 8. Troubleshooting High Memory Usage
|
||||
|
||||
#### Common Issues and Solutions:
|
||||
|
||||
1. **Database connection leaks**: Ensure connections are properly closed
|
||||
2. **Memory fragmentation**: Use jemalloc or mimalloc as allocator
|
||||
3. **Unbounded queues**: Check WebSocket message queues
|
||||
4. **Cache growth**: Implement cache limits or TTL
|
||||
|
||||
#### Add to Cargo.toml for alternative allocator:
|
||||
```toml
|
||||
[dependencies]
|
||||
mimalloc = { version = "0.1", default-features = false }
|
||||
|
||||
[features]
|
||||
default = ["mimalloc"]
|
||||
```
|
||||
|
||||
#### In main.rs:
|
||||
```rust
|
||||
#[global_allocator]
|
||||
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
|
||||
```
|
||||
|
||||
### 9. Expected Memory Usage
|
||||
|
||||
| Component | Baseline | With 10 clients | With 100 clients |
|
||||
|-----------|----------|-----------------|------------------|
|
||||
| Base executable | 15MB | 15MB | 15MB |
|
||||
| Database connections | 5MB | 8MB | 15MB |
|
||||
| Rate limiting | 2MB | 5MB | 20MB |
|
||||
| HTTP clients | 3MB | 5MB | 10MB |
|
||||
| **Total** | **25MB** | **33MB** | **60MB** |
|
||||
|
||||
**Note**: These are estimates. Actual usage depends on request volume, payload sizes, and configuration.
|
||||
|
||||
### 10. Further Reading
|
||||
|
||||
- [Tokio performance guide](https://tokio.rs/tokio/topics/performance)
|
||||
- [Rust performance book](https://nnethercote.github.io/perf-book/)
|
||||
- [Linux memory management](https://www.kernel.org/doc/html/latest/admin-guide/mm/)
|
||||
- [SQLite performance tips](https://www.sqlite.org/faq.html#q19)
|
||||
170
README.md
170
README.md
@@ -1,127 +1,125 @@
|
||||
# LLM Proxy Gateway
|
||||
# GopherGate
|
||||
|
||||
A unified, high-performance LLM proxy gateway built in Rust. It provides a single OpenAI-compatible API to access multiple providers (OpenAI, Gemini, DeepSeek, Grok, Ollama) with built-in token tracking, real-time cost calculation, multi-user authentication, and a management dashboard.
|
||||
A unified, high-performance LLM proxy gateway built in Go. It provides a single OpenAI-compatible API to access multiple providers (OpenAI, Gemini, DeepSeek, Moonshot, Grok, Ollama) with built-in token tracking, real-time cost calculation, multi-user authentication, and a management dashboard.
|
||||
|
||||
## Features
|
||||
|
||||
- **Unified API:** OpenAI-compatible `/v1/chat/completions` and `/v1/models` endpoints.
|
||||
- **Multi-Provider Support:**
|
||||
- **OpenAI:** GPT-4o, GPT-4o Mini, o1, o3 reasoning models.
|
||||
- **Google Gemini:** Gemini 2.0 Flash, Pro, and vision models.
|
||||
- **DeepSeek:** DeepSeek Chat and Reasoner models.
|
||||
- **xAI Grok:** Grok-beta models.
|
||||
- **Google Gemini:** Gemini 2.0 Flash, Pro, and vision models (with native CoT support).
|
||||
- **DeepSeek:** DeepSeek Chat and Reasoner (R1) models.
|
||||
- **Moonshot:** Kimi K2.5 and other Kimi models.
|
||||
- **xAI Grok:** Grok-4 models.
|
||||
- **Ollama:** Local LLMs running on your network.
|
||||
- **Observability & Tracking:**
|
||||
- **Real-time Costing:** Fetches live pricing and context specs from `models.dev` on startup.
|
||||
- **Token Counting:** Precise estimation using `tiktoken-rs`.
|
||||
- **Database Logging:** Every request logged to SQLite for historical analysis.
|
||||
- **Streaming Support:** Full SSE (Server-Sent Events) with `[DONE]` termination for client compatibility.
|
||||
- **Asynchronous Logging:** Non-blocking request logging to SQLite using background workers.
|
||||
- **Token Counting:** Precise estimation and tracking of prompt, completion, and reasoning tokens.
|
||||
- **Database Persistence:** Every request logged to SQLite for historical analysis and dashboard analytics.
|
||||
- **Streaming Support:** Full SSE (Server-Sent Events) support for all providers.
|
||||
- **Multimodal (Vision):** Image processing (Base64 and remote URLs) across compatible providers.
|
||||
- **Multi-User Access Control:**
|
||||
- **Admin Role:** Full access to all dashboard features, user management, and system configuration.
|
||||
- **Viewer Role:** Read-only access to usage analytics, costs, and monitoring.
|
||||
- **Client API Keys:** Create and manage multiple client tokens for external integrations.
|
||||
- **Reliability:**
|
||||
- **Circuit Breaking:** Automatically protects when providers are down.
|
||||
- **Rate Limiting:** Per-client and global rate limits.
|
||||
- **Cache-Aware Costing:** Tracks cache hit/miss tokens for accurate billing.
|
||||
- **Circuit Breaking:** Automatically protects when providers are down (coming soon).
|
||||
- **Rate Limiting:** Per-client and global rate limits (coming soon).
|
||||
|
||||
## Security
|
||||
|
||||
GopherGate is designed with security in mind:
|
||||
|
||||
- **Signed Session Tokens:** Management dashboard sessions are secured using HMAC-SHA256 signed tokens.
|
||||
- **Encrypted Storage:** Support for encrypted provider API keys in the database.
|
||||
- **Auth Middleware:** Secure client authentication via database-backed API keys.
|
||||
|
||||
**Note:** You must define an `LLM_PROXY__ENCRYPTION_KEY` in your `.env` file for secure session signing and encryption.
|
||||
|
||||
## Tech Stack
|
||||
|
||||
- **Runtime:** Rust with Tokio.
|
||||
- **Web Framework:** Axum.
|
||||
- **Database:** SQLx with SQLite.
|
||||
- **Frontend:** Vanilla JS/CSS with Chart.js for visualizations.
|
||||
- **Runtime:** Go 1.22+
|
||||
- **Web Framework:** Gin Gonic
|
||||
- **Database:** sqlx with SQLite (CGO-free via `modernc.org/sqlite`)
|
||||
- **Frontend:** Vanilla JS/CSS with Chart.js for visualizations
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Rust (1.80+)
|
||||
- SQLite3
|
||||
- Go (1.22+)
|
||||
- SQLite3 (optional, driver is built-in)
|
||||
- Docker (optional, for containerized deployment)
|
||||
|
||||
### Quick Start
|
||||
|
||||
1. Clone and build:
|
||||
```bash
|
||||
git clone ssh://git.dustin.coffee:2222/hobokenchicken/llm-proxy.git
|
||||
cd llm-proxy
|
||||
cargo build --release
|
||||
git clone <repository-url>
|
||||
cd gophergate
|
||||
go build -o gophergate ./cmd/gophergate
|
||||
```
|
||||
|
||||
2. Configure environment:
|
||||
```bash
|
||||
cp .env.example .env
|
||||
# Edit .env and add your API keys:
|
||||
# Edit .env and add your configuration:
|
||||
# LLM_PROXY__ENCRYPTION_KEY=... (32-byte hex or base64 string)
|
||||
# OPENAI_API_KEY=sk-...
|
||||
# GEMINI_API_KEY=AIza...
|
||||
# DEEPSEEK_API_KEY=sk-...
|
||||
# GROK_API_KEY=gk-... (optional)
|
||||
# MOONSHOT_API_KEY=...
|
||||
```
|
||||
|
||||
3. Run the proxy:
|
||||
```bash
|
||||
cargo run --release
|
||||
./gophergate
|
||||
```
|
||||
|
||||
The server starts on `http://localhost:8080` by default.
|
||||
The server starts on `http://0.0.0.0:8080` by default.
|
||||
|
||||
### Configuration
|
||||
### Deployment (Docker)
|
||||
|
||||
Edit `config.toml` to customize providers, models, and settings:
|
||||
```bash
|
||||
# Build the container
|
||||
docker build -t gophergate .
|
||||
|
||||
```toml
|
||||
[server]
|
||||
port = 8080
|
||||
host = "0.0.0.0"
|
||||
|
||||
[database]
|
||||
path = "./data/llm_proxy.db"
|
||||
|
||||
[providers.openai]
|
||||
enabled = true
|
||||
default_model = "gpt-4o"
|
||||
|
||||
[providers.gemini]
|
||||
enabled = true
|
||||
default_model = "gemini-2.0-flash"
|
||||
|
||||
[providers.deepseek]
|
||||
enabled = true
|
||||
default_model = "deepseek-reasoner"
|
||||
|
||||
[providers.grok]
|
||||
enabled = false
|
||||
default_model = "grok-beta"
|
||||
|
||||
[providers.ollama]
|
||||
enabled = false
|
||||
base_url = "http://localhost:11434/v1"
|
||||
# Run the container
|
||||
docker run -p 8080:8080 \
|
||||
-e LLM_PROXY__ENCRYPTION_KEY=your-secure-key \
|
||||
-v ./data:/app/data \
|
||||
gophergate
|
||||
```
|
||||
|
||||
## Management Dashboard
|
||||
|
||||
Access the dashboard at `http://localhost:8080`:
|
||||
Access the dashboard at `http://localhost:8080`.
|
||||
|
||||
- **Overview:** Real-time request counters, system health, provider status.
|
||||
- **Analytics:** Time-series charts, filterable by date, client, provider, and model.
|
||||
- **Costs:** Budget tracking, cost breakdown by provider/client/model, projections.
|
||||
- **Clients:** Create, revoke, and rotate API tokens; per-client usage stats.
|
||||
- **Providers:** Enable/disable providers, test connections, configure API keys.
|
||||
- **Monitoring:** Live request stream via WebSocket, response times, error rates.
|
||||
- **Users:** Admin/user management with role-based access control.
|
||||
- **Auth:** Login, session management, and status tracking.
|
||||
- **Usage:** Summary stats, time-series analytics, and provider breakdown.
|
||||
- **Clients:** API key management and per-client usage tracking.
|
||||
- **Providers:** Provider configuration and status monitoring.
|
||||
- **Users:** Admin-only user management for dashboard access.
|
||||
- **Monitoring:** Live request stream via WebSocket.
|
||||
|
||||
### Default Credentials
|
||||
|
||||
- **Username:** `admin`
|
||||
- **Password:** `admin123`
|
||||
- **Password:** `admin123` (You will be prompted to change this on first login)
|
||||
|
||||
Change the admin password in the dashboard after first login!
|
||||
**Forgot Password?**
|
||||
You can reset the admin password to default by running:
|
||||
```bash
|
||||
./gophergate -reset-admin
|
||||
```
|
||||
|
||||
## API Usage
|
||||
|
||||
The proxy is a drop-in replacement for OpenAI. Configure your client:
|
||||
|
||||
Moonshot models are available through the same OpenAI-compatible endpoint. For
|
||||
example, use `kimi-k2.5` as the model name after setting `MOONSHOT_API_KEY` in
|
||||
your environment.
|
||||
|
||||
### Python
|
||||
```python
|
||||
from openai import OpenAI
|
||||
@@ -137,46 +135,6 @@ response = client.chat.completions.create(
|
||||
)
|
||||
```
|
||||
|
||||
### Open WebUI
|
||||
```
|
||||
API Base URL: http://your-server:8080/v1
|
||||
API Key: YOUR_CLIENT_API_KEY
|
||||
```
|
||||
|
||||
### cURL
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer YOUR_CLIENT_API_KEY" \
|
||||
-d '{
|
||||
"model": "gpt-4o",
|
||||
"messages": [{"role": "user", "content": "Hello!"}],
|
||||
"stream": false
|
||||
}'
|
||||
```
|
||||
|
||||
## Model Discovery
|
||||
|
||||
The proxy exposes `/v1/models` for OpenAI-compatible client model discovery:
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/v1/models \
|
||||
-H "Authorization: Bearer YOUR_CLIENT_API_KEY"
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Streaming Issues
|
||||
If clients timeout or show "TransferEncodingError", ensure:
|
||||
1. Proxy buffering is disabled in nginx: `proxy_buffering off;`
|
||||
2. Chunked transfer is enabled: `chunked_transfer_encoding on;`
|
||||
3. Timeouts are sufficient: `proxy_read_timeout 7200s;`
|
||||
|
||||
### Provider Errors
|
||||
- Check API keys are set in `.env`
|
||||
- Test provider in dashboard (Settings → Providers → Test)
|
||||
- Review logs: `journalctl -u llm-proxy -f`
|
||||
|
||||
## License
|
||||
|
||||
MIT OR Apache-2.0
|
||||
MIT
|
||||
|
||||
56
TODO.md
Normal file
56
TODO.md
Normal file
@@ -0,0 +1,56 @@
|
||||
# Migration TODO List
|
||||
|
||||
## Completed Tasks
|
||||
- [x] Initial Go project setup
|
||||
- [x] Database schema & migrations (hardcoded in `db.go`)
|
||||
- [x] Configuration loader (Viper)
|
||||
- [x] Auth Middleware (scoped to `/v1`)
|
||||
- [x] Basic Provider implementations (OpenAI, Gemini, DeepSeek, Grok)
|
||||
- [x] Streaming Support (SSE & Gemini custom streaming)
|
||||
- [x] Archive Rust files to `rust` branch
|
||||
- [x] Clean root and set Go version as `main`
|
||||
- [x] Enhanced `helpers.go` for Multimodal & Tool Calling (OpenAI compatible)
|
||||
- [x] Enhanced `server.go` for robust request conversion
|
||||
- [x] Dashboard Management APIs (Clients, Tokens, Users, Providers)
|
||||
- [x] Dashboard Analytics & Usage Summary (Fixed SQL robustness)
|
||||
- [x] WebSocket for real-time dashboard updates (Hub with client counting)
|
||||
- [x] Asynchronous Request Logging to SQLite
|
||||
- [x] Update documentation (README, deployment, architecture)
|
||||
- [x] Cost Tracking accuracy (Registry integration with `models.dev`)
|
||||
- [x] Model Listing endpoint (`/v1/models`) with provider filtering
|
||||
- [x] System Metrics endpoint (`/api/system/metrics` using `gopsutil`)
|
||||
- [x] Fixed dashboard 404s and 500s
|
||||
|
||||
## Feature Parity Checklist (High Priority)
|
||||
|
||||
### OpenAI Provider
|
||||
- [x] Tool Calling
|
||||
- [x] Multimodal (Images) support
|
||||
- [x] Accurate usage parsing (cached & reasoning tokens)
|
||||
- [ ] Reasoning Content (CoT) support for `o1`, `o3` (need to ensure it's parsed in responses)
|
||||
- [ ] Support for `/v1/responses` API (required for some gpt-5/o1 models)
|
||||
|
||||
### Gemini Provider
|
||||
- [x] Tool Calling (mapping to Gemini format)
|
||||
- [x] Multimodal (Images) support
|
||||
- [x] Reasoning/Thought support
|
||||
- [x] Handle Tool Response role in unified format
|
||||
|
||||
### DeepSeek Provider
|
||||
- [x] Reasoning Content (CoT) support
|
||||
- [x] Parameter sanitization for `deepseek-reasoner`
|
||||
- [x] Tool Calling support
|
||||
- [x] Accurate usage parsing (cache hits & reasoning)
|
||||
|
||||
### Grok Provider
|
||||
- [x] Tool Calling support
|
||||
- [x] Multimodal support
|
||||
- [x] Accurate usage parsing (via OpenAI helper)
|
||||
|
||||
## Infrastructure & Middleware
|
||||
- [ ] Implement Rate Limiting (`golang.org/x/time/rate`)
|
||||
- [ ] Implement Circuit Breaker (`github.com/sony/gobreaker`)
|
||||
|
||||
## Verification
|
||||
- [ ] Unit tests for feature-specific mapping (CoT, Tools, Images)
|
||||
- [ ] Integration tests with live LLM APIs
|
||||
@@ -1 +0,0 @@
|
||||
too-many-arguments-threshold = 8
|
||||
55
cmd/gophergate/main.go
Normal file
55
cmd/gophergate/main.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/db"
|
||||
"gophergate/internal/server"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func main() {
|
||||
resetAdmin := flag.Bool("reset-admin", false, "Reset admin password to admin123")
|
||||
flag.Parse()
|
||||
|
||||
// Load environment variables
|
||||
if err := godotenv.Load(); err != nil {
|
||||
log.Println("No .env file found")
|
||||
}
|
||||
|
||||
// Load configuration
|
||||
cfg, err := config.Load()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load configuration: %v", err)
|
||||
}
|
||||
|
||||
// Initialize database
|
||||
database, err := db.Init(cfg.Database.Path)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to initialize database: %v", err)
|
||||
}
|
||||
|
||||
if *resetAdmin {
|
||||
hash, _ := bcrypt.GenerateFromPassword([]byte("admin123"), 12)
|
||||
_, err = database.Exec("UPDATE users SET password_hash = ?, must_change_password = 1 WHERE username = 'admin'", string(hash))
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to reset admin password: %v", err)
|
||||
}
|
||||
log.Println("Admin password has been reset to 'admin123'")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// Initialize server
|
||||
s := server.NewServer(cfg, database)
|
||||
|
||||
// Run server
|
||||
log.Printf("Starting GopherGate on %s:%d", cfg.Server.Host, cfg.Server.Port)
|
||||
if err := s.Run(); err != nil {
|
||||
log.Fatalf("Server failed: %v", err)
|
||||
}
|
||||
}
|
||||
Binary file not shown.
BIN
data/backups/llm_proxy.db.20260303T205057Z
Normal file
BIN
data/backups/llm_proxy.db.20260303T205057Z
Normal file
Binary file not shown.
667
deploy.sh
667
deploy.sh
@@ -1,667 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# LLM Proxy Gateway Deployment Script
|
||||
# This script automates the deployment of the LLM Proxy Gateway on a Linux server
|
||||
|
||||
set -e # Exit on error
|
||||
set -u # Exit on undefined variable
|
||||
|
||||
# Configuration
|
||||
APP_NAME="llm-proxy"
|
||||
APP_USER="llmproxy"
|
||||
APP_GROUP="llmproxy"
|
||||
GIT_REPO="ssh://git.dustin.coffee:2222/hobokenchicken/llm-proxy.git"
|
||||
INSTALL_DIR="/opt/$APP_NAME"
|
||||
CONFIG_DIR="/etc/$APP_NAME"
|
||||
DATA_DIR="/var/lib/$APP_NAME"
|
||||
LOG_DIR="/var/log/$APP_NAME"
|
||||
SERVICE_FILE="/etc/systemd/system/$APP_NAME.service"
|
||||
ENV_FILE="$CONFIG_DIR/.env"
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# Logging functions
|
||||
log_info() {
|
||||
echo -e "${GREEN}[INFO]${NC} $1"
|
||||
}
|
||||
|
||||
log_warn() {
|
||||
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||
}
|
||||
|
||||
log_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
}
|
||||
|
||||
# Check if running as root
|
||||
check_root() {
|
||||
if [[ $EUID -ne 0 ]]; then
|
||||
log_error "This script must be run as root"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Install system dependencies
|
||||
install_dependencies() {
|
||||
log_info "Installing system dependencies..."
|
||||
|
||||
# Detect package manager
|
||||
if command -v apt-get &> /dev/null; then
|
||||
# Debian/Ubuntu
|
||||
apt-get update
|
||||
apt-get install -y \
|
||||
build-essential \
|
||||
pkg-config \
|
||||
libssl-dev \
|
||||
sqlite3 \
|
||||
curl \
|
||||
git
|
||||
elif command -v yum &> /dev/null; then
|
||||
# RHEL/CentOS
|
||||
yum groupinstall -y "Development Tools"
|
||||
yum install -y \
|
||||
openssl-devel \
|
||||
sqlite \
|
||||
curl \
|
||||
git
|
||||
elif command -v dnf &> /dev/null; then
|
||||
# Fedora
|
||||
dnf groupinstall -y "Development Tools"
|
||||
dnf install -y \
|
||||
openssl-devel \
|
||||
sqlite \
|
||||
curl \
|
||||
git
|
||||
elif command -v pacman &> /dev/null; then
|
||||
# Arch Linux
|
||||
pacman -Syu --noconfirm \
|
||||
base-devel \
|
||||
openssl \
|
||||
sqlite \
|
||||
curl \
|
||||
git
|
||||
else
|
||||
log_warn "Could not detect package manager. Please install dependencies manually."
|
||||
fi
|
||||
}
|
||||
|
||||
# Install Rust if not present
|
||||
install_rust() {
|
||||
log_info "Checking for Rust installation..."
|
||||
|
||||
if ! command -v rustc &> /dev/null; then
|
||||
log_info "Installing Rust..."
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
source "$HOME/.cargo/env"
|
||||
else
|
||||
log_info "Rust is already installed"
|
||||
fi
|
||||
|
||||
# Verify installation
|
||||
rustc --version
|
||||
cargo --version
|
||||
}
|
||||
|
||||
# Create system user and directories
|
||||
setup_directories() {
|
||||
log_info "Creating system user and directories..."
|
||||
|
||||
# Create user and group if they don't exist
|
||||
if ! id "$APP_USER" &>/dev/null; then
|
||||
# Arch uses /usr/bin/nologin, Debian/Ubuntu use /usr/sbin/nologin
|
||||
NOLOGIN=$(command -v nologin 2>/dev/null || echo "/usr/bin/nologin")
|
||||
useradd -r -s "$NOLOGIN" -M "$APP_USER"
|
||||
fi
|
||||
|
||||
# Create directories
|
||||
mkdir -p "$INSTALL_DIR"
|
||||
mkdir -p "$CONFIG_DIR"
|
||||
mkdir -p "$DATA_DIR"
|
||||
mkdir -p "$LOG_DIR"
|
||||
|
||||
# Set permissions
|
||||
chown -R "$APP_USER:$APP_GROUP" "$INSTALL_DIR"
|
||||
chown -R "$APP_USER:$APP_GROUP" "$CONFIG_DIR"
|
||||
chown -R "$APP_USER:$APP_GROUP" "$DATA_DIR"
|
||||
chown -R "$APP_USER:$APP_GROUP" "$LOG_DIR"
|
||||
|
||||
chmod 750 "$INSTALL_DIR"
|
||||
chmod 750 "$CONFIG_DIR"
|
||||
chmod 750 "$DATA_DIR"
|
||||
chmod 750 "$LOG_DIR"
|
||||
}
|
||||
|
||||
# Build the application
|
||||
build_application() {
|
||||
log_info "Building the application..."
|
||||
|
||||
# Clone or update repository
|
||||
if [[ ! -d "$INSTALL_DIR/.git" ]]; then
|
||||
log_info "Cloning repository..."
|
||||
git clone "$GIT_REPO" "$INSTALL_DIR"
|
||||
else
|
||||
log_info "Updating repository..."
|
||||
cd "$INSTALL_DIR"
|
||||
git pull
|
||||
fi
|
||||
|
||||
# Build in release mode
|
||||
cd "$INSTALL_DIR"
|
||||
log_info "Building release binary..."
|
||||
cargo build --release
|
||||
|
||||
# Verify build
|
||||
if [[ -f "target/release/$APP_NAME" ]]; then
|
||||
log_info "Build successful"
|
||||
else
|
||||
log_error "Build failed"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Create configuration files
|
||||
create_configuration() {
|
||||
log_info "Creating configuration files..."
|
||||
|
||||
# Create .env file with API keys
|
||||
cat > "$ENV_FILE" << EOF
|
||||
# LLM Proxy Gateway Environment Variables
|
||||
# Add your API keys here
|
||||
|
||||
# OpenAI API Key
|
||||
# OPENAI_API_KEY=sk-your-key-here
|
||||
|
||||
# Google Gemini API Key
|
||||
# GEMINI_API_KEY=AIza-your-key-here
|
||||
|
||||
# DeepSeek API Key
|
||||
# DEEPSEEK_API_KEY=sk-your-key-here
|
||||
|
||||
# xAI Grok API Key
|
||||
# GROK_API_KEY=gk-your-key-here
|
||||
|
||||
# Authentication tokens (comma-separated)
|
||||
# LLM_PROXY__SERVER__AUTH_TOKENS=token1,token2,token3
|
||||
EOF
|
||||
|
||||
# Create config.toml
|
||||
cat > "$CONFIG_DIR/config.toml" << EOF
|
||||
# LLM Proxy Gateway Configuration
|
||||
|
||||
[server]
|
||||
port = 8080
|
||||
host = "0.0.0.0"
|
||||
# auth_tokens = ["token1", "token2", "token3"] # Uncomment to enable authentication
|
||||
|
||||
[database]
|
||||
path = "$DATA_DIR/llm_proxy.db"
|
||||
max_connections = 5
|
||||
|
||||
[providers.openai]
|
||||
enabled = true
|
||||
api_key_env = "OPENAI_API_KEY"
|
||||
base_url = "https://api.openai.com/v1"
|
||||
default_model = "gpt-4o"
|
||||
|
||||
[providers.gemini]
|
||||
enabled = true
|
||||
api_key_env = "GEMINI_API_KEY"
|
||||
base_url = "https://generativelanguage.googleapis.com/v1"
|
||||
default_model = "gemini-2.0-flash"
|
||||
|
||||
[providers.deepseek]
|
||||
enabled = true
|
||||
api_key_env = "DEEPSEEK_API_KEY"
|
||||
base_url = "https://api.deepseek.com"
|
||||
default_model = "deepseek-reasoner"
|
||||
|
||||
[providers.grok]
|
||||
enabled = false # Disabled by default until API is researched
|
||||
api_key_env = "GROK_API_KEY"
|
||||
base_url = "https://api.x.ai/v1"
|
||||
default_model = "grok-beta"
|
||||
|
||||
[model_mapping]
|
||||
"gpt-*" = "openai"
|
||||
"gemini-*" = "gemini"
|
||||
"deepseek-*" = "deepseek"
|
||||
"grok-*" = "grok"
|
||||
|
||||
[pricing]
|
||||
openai = { input = 0.01, output = 0.03 }
|
||||
gemini = { input = 0.0005, output = 0.0015 }
|
||||
deepseek = { input = 0.00014, output = 0.00028 }
|
||||
grok = { input = 0.001, output = 0.003 }
|
||||
EOF
|
||||
|
||||
# Set permissions
|
||||
chown "$APP_USER:$APP_GROUP" "$ENV_FILE"
|
||||
chown "$APP_USER:$APP_GROUP" "$CONFIG_DIR/config.toml"
|
||||
chmod 640 "$ENV_FILE"
|
||||
chmod 640 "$CONFIG_DIR/config.toml"
|
||||
}
|
||||
|
||||
# Create systemd service
|
||||
create_systemd_service() {
|
||||
log_info "Creating systemd service..."
|
||||
|
||||
cat > "$SERVICE_FILE" << EOF
|
||||
[Unit]
|
||||
Description=LLM Proxy Gateway
|
||||
Documentation=https://git.dustin.coffee/hobokenchicken/llm-proxy
|
||||
After=network.target
|
||||
Wants=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User=$APP_USER
|
||||
Group=$APP_GROUP
|
||||
WorkingDirectory=$INSTALL_DIR
|
||||
EnvironmentFile=$ENV_FILE
|
||||
Environment="RUST_LOG=info"
|
||||
Environment="LLM_PROXY__CONFIG_PATH=$CONFIG_DIR/config.toml"
|
||||
Environment="LLM_PROXY__DATABASE__PATH=$DATA_DIR/llm_proxy.db"
|
||||
ExecStart=$INSTALL_DIR/target/release/$APP_NAME
|
||||
Restart=on-failure
|
||||
RestartSec=5
|
||||
|
||||
# Security hardening
|
||||
NoNewPrivileges=true
|
||||
PrivateTmp=true
|
||||
ProtectSystem=strict
|
||||
ProtectHome=true
|
||||
ReadWritePaths=$DATA_DIR $LOG_DIR
|
||||
|
||||
# Resource limits (adjust based on your server)
|
||||
MemoryMax=400M
|
||||
MemorySwapMax=100M
|
||||
CPUQuota=50%
|
||||
LimitNOFILE=65536
|
||||
|
||||
# Logging
|
||||
StandardOutput=journal
|
||||
StandardError=journal
|
||||
SyslogIdentifier=$APP_NAME
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
EOF
|
||||
|
||||
# Reload systemd
|
||||
systemctl daemon-reload
|
||||
}
|
||||
|
||||
# Setup nginx reverse proxy (optional)
|
||||
setup_nginx_proxy() {
|
||||
if ! command -v nginx &> /dev/null; then
|
||||
log_warn "nginx not installed. Skipping reverse proxy setup."
|
||||
return
|
||||
fi
|
||||
|
||||
log_info "Setting up nginx reverse proxy..."
|
||||
|
||||
cat > "/etc/nginx/sites-available/$APP_NAME" << EOF
|
||||
server {
|
||||
listen 80;
|
||||
server_name your-domain.com; # Change to your domain
|
||||
|
||||
# Redirect to HTTPS (recommended)
|
||||
return 301 https://\$server_name\$request_uri;
|
||||
}
|
||||
|
||||
server {
|
||||
listen 443 ssl http2;
|
||||
server_name your-domain.com; # Change to your domain
|
||||
|
||||
# SSL certificates (adjust paths)
|
||||
ssl_certificate /etc/letsencrypt/live/your-domain.com/fullchain.pem;
|
||||
ssl_certificate_key /etc/letsencrypt/live/your-domain.com/privkey.pem;
|
||||
|
||||
# SSL configuration
|
||||
ssl_protocols TLSv1.2 TLSv1.3;
|
||||
ssl_ciphers ECDHE-RSA-AES256-GCM-SHA512:DHE-RSA-AES256-GCM-SHA512:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384;
|
||||
ssl_prefer_server_ciphers off;
|
||||
|
||||
# Proxy to LLM Proxy Gateway
|
||||
location / {
|
||||
proxy_pass http://127.0.0.1:8080;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade \$http_upgrade;
|
||||
proxy_set_header Connection "upgrade";
|
||||
proxy_set_header Host \$host;
|
||||
proxy_set_header X-Real-IP \$remote_addr;
|
||||
proxy_set_header X-Forwarded-For \$proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto \$scheme;
|
||||
|
||||
# Timeouts
|
||||
proxy_connect_timeout 60s;
|
||||
proxy_send_timeout 60s;
|
||||
proxy_read_timeout 60s;
|
||||
}
|
||||
|
||||
# Health check endpoint
|
||||
location /health {
|
||||
proxy_pass http://127.0.0.1:8080/health;
|
||||
access_log off;
|
||||
}
|
||||
|
||||
# Dashboard
|
||||
location /dashboard {
|
||||
proxy_pass http://127.0.0.1:8080/dashboard;
|
||||
}
|
||||
}
|
||||
EOF
|
||||
|
||||
# Enable site
|
||||
ln -sf "/etc/nginx/sites-available/$APP_NAME" "/etc/nginx/sites-enabled/"
|
||||
|
||||
# Test nginx configuration
|
||||
nginx -t
|
||||
|
||||
log_info "nginx configuration created. Please update the domain and SSL certificate paths."
|
||||
}
|
||||
|
||||
# Setup firewall
|
||||
setup_firewall() {
|
||||
log_info "Configuring firewall..."
|
||||
|
||||
# Check for ufw (Ubuntu)
|
||||
if command -v ufw &> /dev/null; then
|
||||
ufw allow 22/tcp # SSH
|
||||
ufw allow 80/tcp # HTTP
|
||||
ufw allow 443/tcp # HTTPS
|
||||
ufw --force enable
|
||||
log_info "UFW firewall configured"
|
||||
fi
|
||||
|
||||
# Check for firewalld (RHEL/CentOS)
|
||||
if command -v firewall-cmd &> /dev/null; then
|
||||
firewall-cmd --permanent --add-service=ssh
|
||||
firewall-cmd --permanent --add-service=http
|
||||
firewall-cmd --permanent --add-service=https
|
||||
firewall-cmd --reload
|
||||
log_info "Firewalld configured"
|
||||
fi
|
||||
}
|
||||
|
||||
# Initialize database
|
||||
initialize_database() {
|
||||
log_info "Initializing database..."
|
||||
|
||||
# Run the application once to create database
|
||||
sudo -u "$APP_USER" "$INSTALL_DIR/target/release/$APP_NAME" --help &> /dev/null || true
|
||||
|
||||
log_info "Database initialized at $DATA_DIR/llm_proxy.db"
|
||||
}
|
||||
|
||||
# Start and enable service
|
||||
start_service() {
|
||||
log_info "Starting $APP_NAME service..."
|
||||
|
||||
systemctl enable "$APP_NAME"
|
||||
systemctl start "$APP_NAME"
|
||||
|
||||
# Check status
|
||||
sleep 2
|
||||
systemctl status "$APP_NAME" --no-pager
|
||||
}
|
||||
|
||||
# Verify installation
|
||||
verify_installation() {
|
||||
log_info "Verifying installation..."
|
||||
|
||||
# Check if service is running
|
||||
if systemctl is-active --quiet "$APP_NAME"; then
|
||||
log_info "Service is running"
|
||||
else
|
||||
log_error "Service is not running"
|
||||
journalctl -u "$APP_NAME" -n 20 --no-pager
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Test health endpoint
|
||||
if curl -s http://localhost:8080/health | grep -q "OK"; then
|
||||
log_info "Health check passed"
|
||||
else
|
||||
log_error "Health check failed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Test dashboard
|
||||
if curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/dashboard | grep -q "200"; then
|
||||
log_info "Dashboard is accessible"
|
||||
else
|
||||
log_warn "Dashboard may not be accessible (this is normal if not configured)"
|
||||
fi
|
||||
|
||||
log_info "Installation verified successfully!"
|
||||
}
|
||||
|
||||
# Print next steps
|
||||
print_next_steps() {
|
||||
cat << EOF
|
||||
|
||||
${GREEN}=== LLM Proxy Gateway Installation Complete ===${NC}
|
||||
|
||||
${YELLOW}Next steps:${NC}
|
||||
|
||||
1. ${GREEN}Configure API keys${NC}
|
||||
Edit: $ENV_FILE
|
||||
Add your API keys for the providers you want to use
|
||||
|
||||
2. ${GREEN}Configure authentication${NC}
|
||||
Edit: $CONFIG_DIR/config.toml
|
||||
Uncomment and set auth_tokens for client authentication
|
||||
|
||||
3. ${GREEN}Configure nginx${NC}
|
||||
Edit: /etc/nginx/sites-available/$APP_NAME
|
||||
Update domain name and SSL certificate paths
|
||||
|
||||
4. ${GREEN}Test the API${NC}
|
||||
curl -X POST http://localhost:8080/v1/chat/completions \\
|
||||
-H "Content-Type: application/json" \\
|
||||
-H "Authorization: Bearer your-token" \\
|
||||
-d '{
|
||||
"model": "gpt-4o",
|
||||
"messages": [{"role": "user", "content": "Hello!"}]
|
||||
}'
|
||||
|
||||
5. ${GREEN}Access the dashboard${NC}
|
||||
Open: http://your-server-ip:8080/dashboard
|
||||
Or: https://your-domain.com/dashboard (if nginx configured)
|
||||
|
||||
${YELLOW}Useful commands:${NC}
|
||||
systemctl status $APP_NAME # Check service status
|
||||
journalctl -u $APP_NAME -f # View logs
|
||||
systemctl restart $APP_NAME # Restart service
|
||||
|
||||
${YELLOW}Configuration files:${NC}
|
||||
Service: $SERVICE_FILE
|
||||
Config: $CONFIG_DIR/config.toml
|
||||
Environment: $ENV_FILE
|
||||
Database: $DATA_DIR/llm_proxy.db
|
||||
Logs: $LOG_DIR/
|
||||
|
||||
${GREEN}For more information, see:${NC}
|
||||
https://git.dustin.coffee/hobokenchicken/llm-proxy
|
||||
$INSTALL_DIR/README.md
|
||||
$INSTALL_DIR/deployment.md
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
# Main deployment function
|
||||
deploy() {
|
||||
log_info "Starting LLM Proxy Gateway deployment..."
|
||||
|
||||
check_root
|
||||
install_dependencies
|
||||
install_rust
|
||||
setup_directories
|
||||
build_application
|
||||
create_configuration
|
||||
create_systemd_service
|
||||
initialize_database
|
||||
start_service
|
||||
verify_installation
|
||||
print_next_steps
|
||||
|
||||
# Optional steps (uncomment if needed)
|
||||
# setup_nginx_proxy
|
||||
# setup_firewall
|
||||
|
||||
log_info "Deployment completed successfully!"
|
||||
}
|
||||
|
||||
# Update function
|
||||
update() {
|
||||
log_info "Updating LLM Proxy Gateway..."
|
||||
|
||||
check_root
|
||||
|
||||
# Pull latest changes (while service keeps running)
|
||||
cd "$INSTALL_DIR"
|
||||
log_info "Pulling latest changes..."
|
||||
git pull
|
||||
|
||||
# Build new binary (service stays up on the old binary)
|
||||
log_info "Building release binary (service still running)..."
|
||||
if ! cargo build --release; then
|
||||
log_error "Build failed — service was NOT interrupted. Fix the error and try again."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Verify binary exists
|
||||
if [[ ! -f "target/release/$APP_NAME" ]]; then
|
||||
log_error "Binary not found after build — aborting."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Restart service to pick up new binary
|
||||
log_info "Build succeeded. Restarting service..."
|
||||
systemctl restart "$APP_NAME"
|
||||
|
||||
sleep 2
|
||||
if systemctl is-active --quiet "$APP_NAME"; then
|
||||
log_info "Update completed successfully!"
|
||||
systemctl status "$APP_NAME" --no-pager
|
||||
else
|
||||
log_error "Service failed to start after update. Check logs:"
|
||||
journalctl -u "$APP_NAME" -n 20 --no-pager
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Uninstall function
|
||||
uninstall() {
|
||||
log_info "Uninstalling LLM Proxy Gateway..."
|
||||
|
||||
check_root
|
||||
|
||||
# Stop and disable service
|
||||
systemctl stop "$APP_NAME" 2>/dev/null || true
|
||||
systemctl disable "$APP_NAME" 2>/dev/null || true
|
||||
rm -f "$SERVICE_FILE"
|
||||
systemctl daemon-reload
|
||||
|
||||
# Remove application files
|
||||
rm -rf "$INSTALL_DIR"
|
||||
rm -rf "$CONFIG_DIR"
|
||||
|
||||
# Keep data and logs (comment out to remove)
|
||||
log_warn "Data directory $DATA_DIR and logs $LOG_DIR have been preserved"
|
||||
log_warn "Remove manually if desired:"
|
||||
log_warn " rm -rf $DATA_DIR $LOG_DIR"
|
||||
|
||||
# Remove user (optional)
|
||||
read -p "Remove user $APP_USER? [y/N]: " -n 1 -r
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]]; then
|
||||
userdel "$APP_USER" 2>/dev/null || true
|
||||
groupdel "$APP_GROUP" 2>/dev/null || true
|
||||
fi
|
||||
|
||||
log_info "Uninstallation completed!"
|
||||
}
|
||||
|
||||
# Show usage
|
||||
usage() {
|
||||
cat << EOF
|
||||
LLM Proxy Gateway Deployment Script
|
||||
|
||||
Usage: $0 [command]
|
||||
|
||||
Commands:
|
||||
deploy - Install and configure LLM Proxy Gateway
|
||||
update - Pull latest changes, rebuild, and restart
|
||||
status - Show service status and health check
|
||||
logs - Tail the service logs (Ctrl+C to stop)
|
||||
uninstall - Remove LLM Proxy Gateway
|
||||
help - Show this help message
|
||||
|
||||
Examples:
|
||||
$0 deploy # Full installation
|
||||
$0 update # Update to latest version
|
||||
$0 status # Check if service is healthy
|
||||
$0 logs # Follow live logs
|
||||
|
||||
EOF
|
||||
}
|
||||
|
||||
# Status function
|
||||
status() {
|
||||
echo ""
|
||||
log_info "Service status:"
|
||||
systemctl status "$APP_NAME" --no-pager 2>/dev/null || log_warn "Service not found"
|
||||
echo ""
|
||||
|
||||
# Health check
|
||||
if curl -sf http://localhost:8080/health &>/dev/null; then
|
||||
log_info "Health check: OK"
|
||||
else
|
||||
log_warn "Health check: FAILED (service may not be running or port 8080 not responding)"
|
||||
fi
|
||||
|
||||
# Show current git commit
|
||||
if [[ -d "$INSTALL_DIR/.git" ]]; then
|
||||
echo ""
|
||||
log_info "Installed version:"
|
||||
git -C "$INSTALL_DIR" log -1 --format=" %h %s (%cr)" 2>/dev/null
|
||||
fi
|
||||
}
|
||||
|
||||
# Logs function
|
||||
logs() {
|
||||
log_info "Tailing $APP_NAME logs (Ctrl+C to stop)..."
|
||||
journalctl -u "$APP_NAME" -f
|
||||
}
|
||||
|
||||
# Parse command line arguments
|
||||
case "${1:-}" in
|
||||
deploy)
|
||||
deploy
|
||||
;;
|
||||
update)
|
||||
update
|
||||
;;
|
||||
status)
|
||||
status
|
||||
;;
|
||||
logs)
|
||||
logs
|
||||
;;
|
||||
uninstall)
|
||||
uninstall
|
||||
;;
|
||||
help|--help|-h)
|
||||
usage
|
||||
;;
|
||||
*)
|
||||
usage
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
328
deployment.md
328
deployment.md
@@ -1,322 +1,52 @@
|
||||
# LLM Proxy Gateway - Deployment Guide
|
||||
# Deployment Guide (Go)
|
||||
|
||||
## Overview
|
||||
A unified LLM proxy gateway supporting OpenAI, Google Gemini, DeepSeek, and xAI Grok with token tracking, cost calculation, and admin dashboard.
|
||||
This guide covers deploying the Go-based GopherGate.
|
||||
|
||||
## System Requirements
|
||||
- **CPU**: 2 cores minimum
|
||||
- **RAM**: 512MB minimum (1GB recommended)
|
||||
- **Storage**: 10GB minimum
|
||||
- **OS**: Linux (tested on Arch Linux, Ubuntu, Debian)
|
||||
- **Runtime**: Rust 1.70+ with Cargo
|
||||
## Environment Setup
|
||||
|
||||
## Deployment Options
|
||||
|
||||
### Option 1: Docker (Recommended)
|
||||
```dockerfile
|
||||
FROM rust:1.70-alpine as builder
|
||||
WORKDIR /app
|
||||
COPY . .
|
||||
RUN cargo build --release
|
||||
|
||||
FROM alpine:latest
|
||||
RUN apk add --no-cache libgcc
|
||||
COPY --from=builder /app/target/release/llm-proxy /usr/local/bin/
|
||||
COPY --from=builder /app/static /app/static
|
||||
WORKDIR /app
|
||||
EXPOSE 8080
|
||||
CMD ["llm-proxy"]
|
||||
```
|
||||
|
||||
### Option 2: Systemd Service (Bare Metal/LXC)
|
||||
```ini
|
||||
# /etc/systemd/system/llm-proxy.service
|
||||
[Unit]
|
||||
Description=LLM Proxy Gateway
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User=llmproxy
|
||||
Group=llmproxy
|
||||
WorkingDirectory=/opt/llm-proxy
|
||||
ExecStart=/opt/llm-proxy/llm-proxy
|
||||
Restart=always
|
||||
RestartSec=10
|
||||
Environment="RUST_LOG=info"
|
||||
Environment="LLM_PROXY__SERVER__PORT=8080"
|
||||
Environment="LLM_PROXY__SERVER__AUTH_TOKENS=sk-test-123,sk-test-456"
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
```
|
||||
|
||||
### Option 3: LXC Container (Proxmox)
|
||||
1. Create Alpine Linux LXC container
|
||||
2. Install Rust: `apk add rust cargo`
|
||||
3. Copy application files
|
||||
4. Build: `cargo build --release`
|
||||
5. Run: `./target/release/llm-proxy`
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
1. **Mandatory Configuration:**
|
||||
Create a `.env` file from the example:
|
||||
```bash
|
||||
# Required API Keys
|
||||
OPENAI_API_KEY=sk-...
|
||||
GEMINI_API_KEY=AIza...
|
||||
DEEPSEEK_API_KEY=sk-...
|
||||
GROK_API_KEY=gk-... # Optional
|
||||
|
||||
# Server Configuration (with LLM_PROXY__ prefix)
|
||||
LLM_PROXY__SERVER__PORT=8080
|
||||
LLM_PROXY__SERVER__HOST=0.0.0.0
|
||||
LLM_PROXY__SERVER__AUTH_TOKENS=sk-test-123,sk-test-456
|
||||
|
||||
# Database Configuration
|
||||
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
|
||||
LLM_PROXY__DATABASE__MAX_CONNECTIONS=10
|
||||
|
||||
# Provider Configuration
|
||||
LLM_PROXY__PROVIDERS__OPENAI__ENABLED=true
|
||||
LLM_PROXY__PROVIDERS__GEMINI__ENABLED=true
|
||||
LLM_PROXY__PROVIDERS__DEEPSEEK__ENABLED=true
|
||||
LLM_PROXY__PROVIDERS__GROK__ENABLED=false
|
||||
cp .env.example .env
|
||||
```
|
||||
Ensure `LLM_PROXY__ENCRYPTION_KEY` is set to a secure 32-byte string.
|
||||
|
||||
### Configuration File (config.toml)
|
||||
Create `config.toml` in the application directory:
|
||||
```toml
|
||||
[server]
|
||||
port = 8080
|
||||
host = "0.0.0.0"
|
||||
auth_tokens = ["sk-test-123", "sk-test-456"]
|
||||
2. **Data Directory:**
|
||||
The proxy stores its database in `./data/llm_proxy.db` by default. Ensure this directory exists and is writable.
|
||||
|
||||
[database]
|
||||
path = "./data/llm_proxy.db"
|
||||
max_connections = 10
|
||||
## Binary Deployment
|
||||
|
||||
[providers.openai]
|
||||
enabled = true
|
||||
base_url = "https://api.openai.com/v1"
|
||||
default_model = "gpt-4o"
|
||||
|
||||
[providers.gemini]
|
||||
enabled = true
|
||||
base_url = "https://generativelanguage.googleapis.com/v1"
|
||||
default_model = "gemini-2.0-flash"
|
||||
|
||||
[providers.deepseek]
|
||||
enabled = true
|
||||
base_url = "https://api.deepseek.com"
|
||||
default_model = "deepseek-reasoner"
|
||||
|
||||
[providers.grok]
|
||||
enabled = false
|
||||
base_url = "https://api.x.ai/v1"
|
||||
default_model = "grok-beta"
|
||||
```
|
||||
|
||||
## Nginx Reverse Proxy Configuration
|
||||
|
||||
**Important for SSE/Streaming:** Disable buffering and configure timeouts for proper SSE support.
|
||||
|
||||
```nginx
|
||||
server {
|
||||
listen 80;
|
||||
server_name llm-proxy.yourdomain.com;
|
||||
|
||||
location / {
|
||||
proxy_pass http://localhost:8080;
|
||||
proxy_http_version 1.1;
|
||||
|
||||
# SSE/Streaming support
|
||||
proxy_buffering off;
|
||||
chunked_transfer_encoding on;
|
||||
proxy_set_header Connection '';
|
||||
|
||||
# Timeouts for long-running streams
|
||||
proxy_connect_timeout 7200s;
|
||||
proxy_read_timeout 7200s;
|
||||
proxy_send_timeout 7200s;
|
||||
|
||||
# Disable gzip for streaming
|
||||
gzip off;
|
||||
|
||||
# Headers
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
}
|
||||
|
||||
# SSL configuration (recommended)
|
||||
listen 443 ssl http2;
|
||||
ssl_certificate /etc/letsencrypt/live/llm-proxy.yourdomain.com/fullchain.pem;
|
||||
ssl_certificate_key /etc/letsencrypt/live/llm-proxy.yourdomain.com/privkey.pem;
|
||||
}
|
||||
```
|
||||
|
||||
### NGINX Proxy Manager
|
||||
|
||||
If using NGINX Proxy Manager, add this to **Advanced Settings**:
|
||||
|
||||
```nginx
|
||||
proxy_buffering off;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Connection '';
|
||||
chunked_transfer_encoding on;
|
||||
proxy_connect_timeout 7200s;
|
||||
proxy_read_timeout 7200s;
|
||||
proxy_send_timeout 7200s;
|
||||
gzip off;
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### 1. Authentication
|
||||
- Use strong Bearer tokens
|
||||
- Rotate tokens regularly
|
||||
- Consider implementing JWT for production
|
||||
|
||||
### 2. Rate Limiting
|
||||
- Implement per-client rate limiting
|
||||
- Consider using `governor` crate for advanced rate limiting
|
||||
|
||||
### 3. Network Security
|
||||
- Run behind reverse proxy (nginx)
|
||||
- Enable HTTPS
|
||||
- Restrict access by IP if needed
|
||||
- Use firewall rules
|
||||
|
||||
### 4. Data Security
|
||||
- Database encryption (SQLCipher for SQLite)
|
||||
- Secure API key storage
|
||||
- Regular backups
|
||||
|
||||
## Monitoring & Maintenance
|
||||
|
||||
### Logging
|
||||
- Application logs: `RUST_LOG=info` (or `debug` for troubleshooting)
|
||||
- Access logs via nginx
|
||||
- Database logs for audit trail
|
||||
|
||||
### Health Checks
|
||||
### 1. Build
|
||||
```bash
|
||||
# Health endpoint
|
||||
curl http://localhost:8080/health
|
||||
|
||||
# Database check
|
||||
sqlite3 ./data/llm_proxy.db "SELECT COUNT(*) FROM llm_requests;"
|
||||
go build -o gophergate ./cmd/gophergate
|
||||
```
|
||||
|
||||
### Backup Strategy
|
||||
### 2. Run
|
||||
```bash
|
||||
#!/bin/bash
|
||||
# backup.sh
|
||||
BACKUP_DIR="/backups/llm-proxy"
|
||||
DATE=$(date +%Y%m%d_%H%M%S)
|
||||
|
||||
# Backup database
|
||||
sqlite3 ./data/llm_proxy.db ".backup $BACKUP_DIR/llm_proxy_$DATE.db"
|
||||
|
||||
# Backup configuration
|
||||
cp config.toml $BACKUP_DIR/config_$DATE.toml
|
||||
|
||||
# Rotate old backups (keep 30 days)
|
||||
find $BACKUP_DIR -name "*.db" -mtime +30 -delete
|
||||
find $BACKUP_DIR -name "*.toml" -mtime +30 -delete
|
||||
./gophergate
|
||||
```
|
||||
|
||||
## Performance Tuning
|
||||
## Docker Deployment
|
||||
|
||||
### Database Optimization
|
||||
```sql
|
||||
-- Run these SQL commands periodically
|
||||
VACUUM;
|
||||
ANALYZE;
|
||||
```
|
||||
The project includes a multi-stage `Dockerfile` for minimal image size.
|
||||
|
||||
### Memory Management
|
||||
- Monitor memory usage with `htop` or `ps aux`
|
||||
- Adjust `max_connections` based on load
|
||||
- Consider connection pooling for high traffic
|
||||
|
||||
### Scaling
|
||||
1. **Vertical Scaling**: Increase container resources
|
||||
2. **Horizontal Scaling**: Deploy multiple instances behind load balancer
|
||||
3. **Database**: Migrate to PostgreSQL for high-volume usage
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Port already in use**
|
||||
### 1. Build Image
|
||||
```bash
|
||||
netstat -tulpn | grep :8080
|
||||
kill <PID> # or change port in config
|
||||
docker build -t gophergate .
|
||||
```
|
||||
|
||||
2. **Database permissions**
|
||||
### 2. Run Container
|
||||
```bash
|
||||
chown -R llmproxy:llmproxy /opt/llm-proxy/data
|
||||
chmod 600 /opt/llm-proxy/data/llm_proxy.db
|
||||
docker run -d \
|
||||
--name gophergate \
|
||||
-p 8080:8080 \
|
||||
-v $(pwd)/data:/app/data \
|
||||
--env-file .env \
|
||||
gophergate
|
||||
```
|
||||
|
||||
3. **API key errors**
|
||||
- Verify environment variables are set
|
||||
- Check provider status (dashboard)
|
||||
- Test connectivity: `curl https://api.openai.com/v1/models`
|
||||
## Production Considerations
|
||||
|
||||
4. **High memory usage**
|
||||
- Check for memory leaks
|
||||
- Reduce `max_connections`
|
||||
- Implement connection timeouts
|
||||
|
||||
### Debug Mode
|
||||
```bash
|
||||
# Run with debug logging
|
||||
RUST_LOG=debug ./llm-proxy
|
||||
|
||||
# Check system logs
|
||||
journalctl -u llm-proxy -f
|
||||
```
|
||||
|
||||
## Integration
|
||||
|
||||
### Open-WebUI Compatibility
|
||||
The proxy provides OpenAI-compatible API, so configure Open-WebUI:
|
||||
```
|
||||
API Base URL: http://your-proxy-address:8080
|
||||
API Key: sk-test-123 (or your configured token)
|
||||
```
|
||||
|
||||
### Custom Clients
|
||||
```python
|
||||
import openai
|
||||
|
||||
client = openai.OpenAI(
|
||||
base_url="http://localhost:8080/v1",
|
||||
api_key="sk-test-123"
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-4",
|
||||
messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
```
|
||||
|
||||
## Updates & Upgrades
|
||||
|
||||
1. **Backup** current configuration and database
|
||||
2. **Stop** the service: `systemctl stop llm-proxy`
|
||||
3. **Update** code: `git pull` or copy new binaries
|
||||
4. **Migrate** database if needed (check migrations/)
|
||||
5. **Restart**: `systemctl start llm-proxy`
|
||||
6. **Verify**: Check logs and test endpoints
|
||||
|
||||
## Support
|
||||
- Check logs in `/var/log/llm-proxy/`
|
||||
- Monitor dashboard at `http://your-server:8080`
|
||||
- Review database metrics in dashboard
|
||||
- Enable debug logging for troubleshooting
|
||||
- **SSL/TLS:** It is recommended to run the proxy behind a reverse proxy like Nginx or Caddy for SSL termination.
|
||||
- **Backups:** Regularly backup the `data/llm_proxy.db` file.
|
||||
- **Monitoring:** Monitor the `/health` endpoint for system status.
|
||||
|
||||
70
go.mod
Normal file
70
go.mod
Normal file
@@ -0,0 +1,70 @@
|
||||
module gophergate
|
||||
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.12.0
|
||||
github.com/go-resty/resty/v2 v2.17.2
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/jmoiron/sqlx v1.4.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/shirou/gopsutil/v3 v3.24.5
|
||||
github.com/spf13/viper v1.21.0
|
||||
golang.org/x/crypto v0.48.0
|
||||
modernc.org/sqlite v1.47.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||
github.com/bytedance/sonic v1.15.0 // indirect
|
||||
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
|
||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.30.1 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||
github.com/quic-go/qpack v0.6.0 // indirect
|
||||
github.com/quic-go/quic-go v0.59.0 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
||||
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.3.1 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/arch v0.22.0 // indirect
|
||||
golang.org/x/net v0.51.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
modernc.org/libc v1.70.0 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
)
|
||||
207
go.sum
Normal file
207
go.sum
Normal file
@@ -0,0 +1,207 @@
|
||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
|
||||
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
|
||||
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
|
||||
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
|
||||
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
|
||||
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
|
||||
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
|
||||
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
|
||||
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
|
||||
github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8=
|
||||
github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w=
|
||||
github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM=
|
||||
github.com/go-resty/resty/v2 v2.17.2 h1:FQW5oHYcIlkCNrMD2lloGScxcHJ0gkjshV3qcQAyHQk=
|
||||
github.com/go-resty/resty/v2 v2.17.2/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
|
||||
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
|
||||
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
||||
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
|
||||
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
|
||||
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
|
||||
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
|
||||
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
|
||||
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
|
||||
github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI=
|
||||
github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk=
|
||||
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
|
||||
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
|
||||
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
|
||||
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
|
||||
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
|
||||
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
|
||||
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
||||
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
|
||||
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY=
|
||||
github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
|
||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
|
||||
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
|
||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
||||
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
|
||||
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
|
||||
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
|
||||
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
|
||||
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
||||
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.32.0 h1:hjG66bI/kqIPX1b2yT6fr/jt+QedtP2fqojG2VrFuVw=
|
||||
modernc.org/ccgo/v4 v4.32.0/go.mod h1:6F08EBCx5uQc38kMGl+0Nm0oWczoo1c7cgpzEry7Uc0=
|
||||
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
|
||||
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
|
||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
|
||||
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||
modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw=
|
||||
modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.47.0 h1:R1XyaNpoW4Et9yly+I2EeX7pBza/w+pmYee/0HJDyKk=
|
||||
modernc.org/sqlite v1.47.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||
207
internal/config/config.go
Normal file
207
internal/config/config.go
Normal file
@@ -0,0 +1,207 @@
|
||||
package config
|
||||
|
||||
import (
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"os"
	"strconv"
	"strings"

	"github.com/spf13/viper"
)
|
||||
|
||||
// Config is the root application configuration, decoded by Load from
// defaults, an optional TOML file, and LLM_PROXY__-prefixed env vars.
type Config struct {
	Server    ServerConfig   `mapstructure:"server"`
	Database  DatabaseConfig `mapstructure:"database"`
	Providers ProviderConfig `mapstructure:"providers"`
	// EncryptionKey is the hex- or base64-encoded key exactly as supplied
	// via config/env; Load validates and decodes it.
	EncryptionKey string `mapstructure:"encryption_key"`
	// KeyBytes is the decoded 32-byte key, populated by Load (not read from
	// config directly).
	KeyBytes []byte
}
|
||||
|
||||
// ServerConfig holds HTTP listener settings.
type ServerConfig struct {
	Port int    `mapstructure:"port"`
	Host string `mapstructure:"host"`
	// AuthTokens presumably lists statically accepted bearer tokens —
	// confirm against the auth layer.
	AuthTokens []string `mapstructure:"auth_tokens"`
}

// DatabaseConfig holds SQLite connection settings.
type DatabaseConfig struct {
	Path           string `mapstructure:"path"`
	MaxConnections int    `mapstructure:"max_connections"`
}

// ProviderConfig groups per-upstream provider settings, one field per
// supported LLM backend.
type ProviderConfig struct {
	OpenAI   OpenAIConfig   `mapstructure:"openai"`
	Gemini   GeminiConfig   `mapstructure:"gemini"`
	DeepSeek DeepSeekConfig `mapstructure:"deepseek"`
	Moonshot MoonshotConfig `mapstructure:"moonshot"`
	Grok     GrokConfig     `mapstructure:"grok"`
	Ollama   OllamaConfig   `mapstructure:"ollama"`
}
|
||||
|
||||
// The per-provider config structs below share the same shape: APIKeyEnv names
// the environment variable holding the secret (the key itself is never stored
// in config), BaseURL the upstream endpoint, DefaultModel the fallback model,
// and Enabled toggles the provider. They are kept as distinct types so each
// provider can diverge later (Ollama already has; see below).

// OpenAIConfig configures the OpenAI upstream.
type OpenAIConfig struct {
	APIKeyEnv    string `mapstructure:"api_key_env"`
	BaseURL      string `mapstructure:"base_url"`
	DefaultModel string `mapstructure:"default_model"`
	Enabled      bool   `mapstructure:"enabled"`
}

// GeminiConfig configures the Google Gemini upstream.
type GeminiConfig struct {
	APIKeyEnv    string `mapstructure:"api_key_env"`
	BaseURL      string `mapstructure:"base_url"`
	DefaultModel string `mapstructure:"default_model"`
	Enabled      bool   `mapstructure:"enabled"`
}

// DeepSeekConfig configures the DeepSeek upstream.
type DeepSeekConfig struct {
	APIKeyEnv    string `mapstructure:"api_key_env"`
	BaseURL      string `mapstructure:"base_url"`
	DefaultModel string `mapstructure:"default_model"`
	Enabled      bool   `mapstructure:"enabled"`
}

// MoonshotConfig configures the Moonshot upstream.
type MoonshotConfig struct {
	APIKeyEnv    string `mapstructure:"api_key_env"`
	BaseURL      string `mapstructure:"base_url"`
	DefaultModel string `mapstructure:"default_model"`
	Enabled      bool   `mapstructure:"enabled"`
}

// GrokConfig configures the xAI Grok upstream.
type GrokConfig struct {
	APIKeyEnv    string `mapstructure:"api_key_env"`
	BaseURL      string `mapstructure:"base_url"`
	DefaultModel string `mapstructure:"default_model"`
	Enabled      bool   `mapstructure:"enabled"`
}

// OllamaConfig configures a local Ollama instance. No API key is needed;
// Models lists the locally available model names.
type OllamaConfig struct {
	BaseURL      string   `mapstructure:"base_url"`
	Enabled      bool     `mapstructure:"enabled"`
	DefaultModel string   `mapstructure:"default_model"`
	Models       []string `mapstructure:"models"`
}
|
||||
|
||||
func Load() (*Config, error) {
|
||||
v := viper.New()
|
||||
|
||||
// Defaults
|
||||
v.SetDefault("server.port", 8080)
|
||||
v.SetDefault("server.host", "0.0.0.0")
|
||||
v.SetDefault("server.auth_tokens", []string{})
|
||||
v.SetDefault("database.path", "./data/llm_proxy.db")
|
||||
v.SetDefault("database.max_connections", 10)
|
||||
|
||||
v.SetDefault("providers.openai.api_key_env", "OPENAI_API_KEY")
|
||||
v.SetDefault("providers.openai.base_url", "https://api.openai.com/v1")
|
||||
v.SetDefault("providers.openai.default_model", "gpt-4o")
|
||||
v.SetDefault("providers.openai.enabled", true)
|
||||
|
||||
v.SetDefault("providers.gemini.api_key_env", "GEMINI_API_KEY")
|
||||
v.SetDefault("providers.gemini.base_url", "https://generativelanguage.googleapis.com/v1")
|
||||
v.SetDefault("providers.gemini.default_model", "gemini-2.0-flash")
|
||||
v.SetDefault("providers.gemini.enabled", true)
|
||||
|
||||
v.SetDefault("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")
|
||||
v.SetDefault("providers.deepseek.base_url", "https://api.deepseek.com")
|
||||
v.SetDefault("providers.deepseek.default_model", "deepseek-reasoner")
|
||||
v.SetDefault("providers.deepseek.enabled", true)
|
||||
|
||||
v.SetDefault("providers.moonshot.api_key_env", "MOONSHOT_API_KEY")
|
||||
v.SetDefault("providers.moonshot.base_url", "https://api.moonshot.ai/v1")
|
||||
v.SetDefault("providers.moonshot.default_model", "kimi-k2.5")
|
||||
v.SetDefault("providers.moonshot.enabled", true)
|
||||
|
||||
v.SetDefault("providers.grok.api_key_env", "GROK_API_KEY")
|
||||
v.SetDefault("providers.grok.base_url", "https://api.x.ai/v1")
|
||||
v.SetDefault("providers.grok.default_model", "grok-4-1-fast-non-reasoning")
|
||||
v.SetDefault("providers.grok.enabled", true)
|
||||
|
||||
v.SetDefault("providers.ollama.base_url", "http://localhost:11434/v1")
|
||||
v.SetDefault("providers.ollama.enabled", false)
|
||||
v.SetDefault("providers.ollama.models", []string{})
|
||||
|
||||
// Environment variables
|
||||
v.SetEnvPrefix("LLM_PROXY")
|
||||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "__"))
|
||||
v.AutomaticEnv()
|
||||
|
||||
// Explicitly bind keys that might use double underscores in .env
|
||||
v.BindEnv("encryption_key", "LLM_PROXY__ENCRYPTION_KEY")
|
||||
v.BindEnv("server.port", "LLM_PROXY__SERVER__PORT")
|
||||
v.BindEnv("server.host", "LLM_PROXY__SERVER__HOST")
|
||||
|
||||
// Config file
|
||||
v.SetConfigName("config")
|
||||
v.SetConfigType("toml")
|
||||
v.AddConfigPath(".")
|
||||
if envPath := os.Getenv("LLM_PROXY__CONFIG_PATH"); envPath != "" {
|
||||
v.SetConfigFile(envPath)
|
||||
}
|
||||
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Debug Config: port from viper=%d, host from viper=%s\n", cfg.Server.Port, cfg.Server.Host)
|
||||
fmt.Printf("Debug Env: LLM_PROXY__SERVER__PORT=%s, LLM_PROXY__SERVER__HOST=%s\n", os.Getenv("LLM_PROXY__SERVER__PORT"), os.Getenv("LLM_PROXY__SERVER__HOST"))
|
||||
|
||||
// Manual overrides for nested keys which Viper doesn't always bind correctly with AutomaticEnv + SetEnvPrefix
|
||||
if port := os.Getenv("LLM_PROXY__SERVER__PORT"); port != "" {
|
||||
fmt.Sscanf(port, "%d", &cfg.Server.Port)
|
||||
fmt.Printf("Overriding port to %d from env\n", cfg.Server.Port)
|
||||
}
|
||||
if host := os.Getenv("LLM_PROXY__SERVER__HOST"); host != "" {
|
||||
cfg.Server.Host = host
|
||||
fmt.Printf("Overriding host to %s from env\n", cfg.Server.Host)
|
||||
}
|
||||
|
||||
// Validate encryption key
|
||||
if cfg.EncryptionKey == "" {
|
||||
return nil, fmt.Errorf("encryption key is required (LLM_PROXY__ENCRYPTION_KEY)")
|
||||
}
|
||||
|
||||
keyBytes, err := hex.DecodeString(cfg.EncryptionKey)
|
||||
if err != nil {
|
||||
keyBytes, err = base64.StdEncoding.DecodeString(cfg.EncryptionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encryption key must be hex or base64 encoded")
|
||||
}
|
||||
}
|
||||
|
||||
if len(keyBytes) != 32 {
|
||||
return nil, fmt.Errorf("encryption key must be 32 bytes, got %d", len(keyBytes))
|
||||
}
|
||||
cfg.KeyBytes = keyBytes
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (c *Config) GetAPIKey(provider string) (string, error) {
|
||||
var envVar string
|
||||
switch provider {
|
||||
case "openai":
|
||||
envVar = c.Providers.OpenAI.APIKeyEnv
|
||||
case "gemini":
|
||||
envVar = c.Providers.Gemini.APIKeyEnv
|
||||
case "deepseek":
|
||||
envVar = c.Providers.DeepSeek.APIKeyEnv
|
||||
case "moonshot":
|
||||
envVar = c.Providers.Moonshot.APIKeyEnv
|
||||
case "grok":
|
||||
envVar = c.Providers.Grok.APIKeyEnv
|
||||
default:
|
||||
return "", fmt.Errorf("unknown provider: %s", provider)
|
||||
}
|
||||
|
||||
val := os.Getenv(envVar)
|
||||
if val == "" {
|
||||
return "", fmt.Errorf("environment variable %s not set for %s", envVar, provider)
|
||||
}
|
||||
return strings.TrimSpace(val), nil
|
||||
}
|
||||
264
internal/db/db.go
Normal file
264
internal/db/db.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "modernc.org/sqlite"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// DB embeds sqlx.DB so migrations and queries hang off a single type while
// all sqlx methods remain directly available.
type DB struct {
	*sqlx.DB
}
|
||||
|
||||
func Init(path string) (*DB, error) {
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(path)
|
||||
if dir != "." && dir != "/" {
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create database directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Connect to SQLite
|
||||
dsn := fmt.Sprintf("file:%s?_pragma=foreign_keys(1)", path)
|
||||
db, err := sqlx.Connect("sqlite", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
instance := &DB{db}
|
||||
|
||||
// Run migrations
|
||||
if err := instance.RunMigrations(); err != nil {
|
||||
return nil, fmt.Errorf("failed to run migrations: %w", err)
|
||||
}
|
||||
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
// RunMigrations creates the full schema (idempotently, via IF NOT EXISTS),
// builds query indices, and seeds two rows on first run: an 'admin' user and
// a 'default' client. Safe to call on every startup.
func (db *DB) RunMigrations() error {
	// Tables creation. Order matters: tables with foreign keys reference
	// tables created earlier in this slice.
	queries := []string{
		`CREATE TABLE IF NOT EXISTS clients (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			client_id TEXT UNIQUE NOT NULL,
			name TEXT,
			description TEXT,
			created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
			updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
			is_active BOOLEAN DEFAULT TRUE,
			rate_limit_per_minute INTEGER DEFAULT 60,
			total_requests INTEGER DEFAULT 0,
			total_tokens INTEGER DEFAULT 0,
			total_cost REAL DEFAULT 0.0
		)`,
		// Per-request audit log; client_id survives (as NULL) if the owning
		// client row is deleted.
		`CREATE TABLE IF NOT EXISTS llm_requests (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
			client_id TEXT,
			provider TEXT,
			model TEXT,
			prompt_tokens INTEGER,
			completion_tokens INTEGER,
			reasoning_tokens INTEGER DEFAULT 0,
			total_tokens INTEGER,
			cost REAL,
			has_images BOOLEAN DEFAULT FALSE,
			status TEXT DEFAULT 'success',
			error_message TEXT,
			duration_ms INTEGER,
			request_body TEXT,
			response_body TEXT,
			cache_read_tokens INTEGER DEFAULT 0,
			cache_write_tokens INTEGER DEFAULT 0,
			FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE SET NULL
		)`,
		// Runtime-editable provider settings; api_key may be stored encrypted
		// (see api_key_encrypted flag).
		`CREATE TABLE IF NOT EXISTS provider_configs (
			id TEXT PRIMARY KEY,
			display_name TEXT NOT NULL,
			enabled BOOLEAN DEFAULT TRUE,
			base_url TEXT,
			api_key TEXT,
			credit_balance REAL DEFAULT 0.0,
			low_credit_threshold REAL DEFAULT 5.0,
			billing_mode TEXT,
			api_key_encrypted BOOLEAN DEFAULT FALSE,
			updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
		)`,
		// Per-model pricing and mapping; cascades away with its provider.
		`CREATE TABLE IF NOT EXISTS model_configs (
			id TEXT PRIMARY KEY,
			provider_id TEXT NOT NULL,
			display_name TEXT,
			enabled BOOLEAN DEFAULT TRUE,
			prompt_cost_per_m REAL,
			completion_cost_per_m REAL,
			cache_read_cost_per_m REAL,
			cache_write_cost_per_m REAL,
			mapping TEXT,
			updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
			FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE
		)`,
		// Admin-UI users; password_hash holds a bcrypt hash.
		`CREATE TABLE IF NOT EXISTS users (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			username TEXT UNIQUE NOT NULL,
			password_hash TEXT NOT NULL,
			display_name TEXT,
			role TEXT DEFAULT 'admin',
			must_change_password BOOLEAN DEFAULT FALSE,
			created_at DATETIME DEFAULT CURRENT_TIMESTAMP
		)`,
		// API bearer tokens; deleted together with their client.
		`CREATE TABLE IF NOT EXISTS client_tokens (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			client_id TEXT NOT NULL,
			token TEXT NOT NULL UNIQUE,
			name TEXT DEFAULT 'default',
			is_active BOOLEAN DEFAULT TRUE,
			created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
			last_used_at DATETIME,
			FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE CASCADE
		)`,
	}

	for _, q := range queries {
		if _, err := db.Exec(q); err != nil {
			return fmt.Errorf("migration failed for query [%s]: %w", q, err)
		}
	}

	// Add indices. All use IF NOT EXISTS so re-running is a no-op.
	indices := []string{
		"CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)",
		"CREATE INDEX IF NOT EXISTS idx_clients_created_at ON clients(created_at)",
		"CREATE INDEX IF NOT EXISTS idx_llm_requests_timestamp ON llm_requests(timestamp)",
		"CREATE INDEX IF NOT EXISTS idx_llm_requests_client_id ON llm_requests(client_id)",
		"CREATE INDEX IF NOT EXISTS idx_llm_requests_provider ON llm_requests(provider)",
		"CREATE INDEX IF NOT EXISTS idx_llm_requests_status ON llm_requests(status)",
		"CREATE UNIQUE INDEX IF NOT EXISTS idx_client_tokens_token ON client_tokens(token)",
		"CREATE INDEX IF NOT EXISTS idx_client_tokens_client_id ON client_tokens(client_id)",
		"CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp)",
		"CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp)",
		"CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id)",
	}

	for _, idx := range indices {
		if _, err := db.Exec(idx); err != nil {
			return fmt.Errorf("failed to create index [%s]: %w", idx, err)
		}
	}

	// Default admin user, created only when the users table is empty.
	// NOTE(review): the bootstrap password 'admin123' is hard-coded and
	// logged below; must_change_password forces rotation on first login,
	// but confirm this trade-off is acceptable for the deployment model.
	var count int
	if err := db.Get(&count, "SELECT COUNT(*) FROM users"); err != nil {
		return fmt.Errorf("failed to count users: %w", err)
	}

	if count == 0 {
		// bcrypt cost 12 (above the library default of 10).
		hash, err := bcrypt.GenerateFromPassword([]byte("admin123"), 12)
		if err != nil {
			return fmt.Errorf("failed to hash default password: %w", err)
		}
		_, err = db.Exec("INSERT INTO users (username, password_hash, role, must_change_password) VALUES ('admin', ?, 'admin', 1)", string(hash))
		if err != nil {
			return fmt.Errorf("failed to insert default admin: %w", err)
		}
		log.Println("Created default admin user with password 'admin123' (must change on first login)")
	}

	// Default client; INSERT OR IGNORE keeps this idempotent across restarts.
	_, err := db.Exec(`INSERT OR IGNORE INTO clients (client_id, name, description)
		VALUES ('default', 'Default Client', 'Default client for anonymous requests')`)
	if err != nil {
		return fmt.Errorf("failed to insert default client: %w", err)
	}

	return nil
}
|
||||
|
||||
// Data models for DB tables. Pointer fields map to nullable columns.

// Client mirrors a row of the clients table, including rolling usage totals.
type Client struct {
	ID                 int       `db:"id"`
	ClientID           string    `db:"client_id"`
	Name               *string   `db:"name"`
	Description        *string   `db:"description"`
	CreatedAt          time.Time `db:"created_at"`
	UpdatedAt          time.Time `db:"updated_at"`
	IsActive           bool      `db:"is_active"`
	RateLimitPerMinute int       `db:"rate_limit_per_minute"`
	TotalRequests      int       `db:"total_requests"`
	TotalTokens        int       `db:"total_tokens"`
	TotalCost          float64   `db:"total_cost"`
}

// LLMRequest mirrors a row of the llm_requests audit table; most columns are
// nullable because failed requests may lack token/cost data.
type LLMRequest struct {
	ID               int       `db:"id"`
	Timestamp        time.Time `db:"timestamp"`
	ClientID         *string   `db:"client_id"`
	Provider         *string   `db:"provider"`
	Model            *string   `db:"model"`
	PromptTokens     *int      `db:"prompt_tokens"`
	CompletionTokens *int      `db:"completion_tokens"`
	ReasoningTokens  int       `db:"reasoning_tokens"`
	TotalTokens      *int      `db:"total_tokens"`
	Cost             *float64  `db:"cost"`
	HasImages        bool      `db:"has_images"`
	Status           string    `db:"status"`
	ErrorMessage     *string   `db:"error_message"`
	DurationMS       *int      `db:"duration_ms"`
	RequestBody      *string   `db:"request_body"`
	ResponseBody     *string   `db:"response_body"`
	CacheReadTokens  int       `db:"cache_read_tokens"`
	CacheWriteTokens int       `db:"cache_write_tokens"`
}

// ProviderConfig mirrors a row of provider_configs; APIKeyEncrypted tells
// readers whether APIKey holds ciphertext.
type ProviderConfig struct {
	ID                 string    `db:"id"`
	DisplayName        string    `db:"display_name"`
	Enabled            bool      `db:"enabled"`
	BaseURL            *string   `db:"base_url"`
	APIKey             *string   `db:"api_key"`
	CreditBalance      float64   `db:"credit_balance"`
	LowCreditThreshold float64   `db:"low_credit_threshold"`
	BillingMode        *string   `db:"billing_mode"`
	APIKeyEncrypted    bool      `db:"api_key_encrypted"`
	UpdatedAt          time.Time `db:"updated_at"`
}

// ModelConfig mirrors a row of model_configs; *_per_m columns appear to be
// costs per million tokens — confirm against the pricing code.
type ModelConfig struct {
	ID                 string    `db:"id"`
	ProviderID         string    `db:"provider_id"`
	DisplayName        *string   `db:"display_name"`
	Enabled            bool      `db:"enabled"`
	PromptCostPerM     *float64  `db:"prompt_cost_per_m"`
	CompletionCostPerM *float64  `db:"completion_cost_per_m"`
	CacheReadCostPerM  *float64  `db:"cache_read_cost_per_m"`
	CacheWriteCostPerM *float64  `db:"cache_write_cost_per_m"`
	Mapping            *string   `db:"mapping"`
	UpdatedAt          time.Time `db:"updated_at"`
}

// User mirrors a row of the users table. PasswordHash is excluded from JSON
// serialization via the "-" tag.
type User struct {
	ID                 int       `db:"id" json:"id"`
	Username           string    `db:"username" json:"username"`
	PasswordHash       string    `db:"password_hash" json:"-"`
	DisplayName        *string   `db:"display_name" json:"display_name"`
	Role               string    `db:"role" json:"role"`
	MustChangePassword bool      `db:"must_change_password" json:"must_change_password"`
	CreatedAt          time.Time `db:"created_at" json:"created_at"`
}

// ClientToken mirrors a row of client_tokens; LastUsedAt is nil until the
// token is first used.
type ClientToken struct {
	ID         int        `db:"id"`
	ClientID   string     `db:"client_id"`
	Token      string     `db:"token"`
	Name       string     `db:"name"`
	IsActive   bool       `db:"is_active"`
	CreatedAt  time.Time  `db:"created_at"`
	LastUsedAt *time.Time `db:"last_used_at"`
}
|
||||
52
internal/middleware/auth.go
Normal file
52
internal/middleware/auth.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/db"
|
||||
"gophergate/internal/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AuthMiddleware resolves the calling client from a "Bearer <token>"
// Authorization header and stores a models.AuthInfo in the Gin context under
// the "auth" key. Requests with no header or a non-Bearer header pass through
// without an identity; this middleware never rejects a request.
func AuthMiddleware(database *db.DB) gin.HandlerFunc {
	return func(c *gin.Context) {
		authHeader := c.GetHeader("Authorization")
		if authHeader == "" {
			// No credentials supplied; continue anonymously.
			c.Next()
			return
		}

		token := strings.TrimPrefix(authHeader, "Bearer ")
		if token == authHeader { // No "Bearer " prefix
			c.Next()
			return
		}

		// Try to resolve client from database. The single UPDATE...RETURNING
		// statement both bumps last_used_at and yields the owning client_id
		// when the token exists and is active.
		var clientID string
		err := database.Get(&clientID, "UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = 1 RETURNING client_id", token)

		if err == nil {
			c.Set("auth", models.AuthInfo{
				Token:    token,
				ClientID: clientID,
			})
		} else {
			// Fallback to token-prefix derivation (matches Rust behavior)
			// NOTE(review): any DB error — not just "no rows" — lands here, so
			// a database outage silently degrades to derived client IDs;
			// confirm this is intentional.
			prefixLen := len(token)
			if prefixLen > 8 {
				prefixLen = 8
			}
			clientID = "client_" + token[:prefixLen]
			c.Set("auth", models.AuthInfo{
				Token:    token,
				ClientID: clientID,
			})
			log.Printf("Token not found in DB, using fallback client ID: %s", clientID)
		}

		c.Next()
	}
}
|
||||
216
internal/models/models.go
Normal file
216
internal/models/models.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
// OpenAI-compatible Request/Response Structs
|
||||
|
||||
// ChatCompletionRequest mirrors the OpenAI /chat/completions request body.
// Pointer fields distinguish "absent" from zero values on re-serialization.
type ChatCompletionRequest struct {
	Model            string          `json:"model"`
	Messages         []ChatMessage   `json:"messages"`
	Temperature      *float64        `json:"temperature,omitempty"`
	TopP             *float64        `json:"top_p,omitempty"`
	TopK             *uint32         `json:"top_k,omitempty"`
	N                *uint32         `json:"n,omitempty"`
	Stop             json.RawMessage `json:"stop,omitempty"` // Can be string or array of strings
	MaxTokens        *uint32         `json:"max_tokens,omitempty"`
	PresencePenalty  *float64        `json:"presence_penalty,omitempty"`
	FrequencyPenalty *float64        `json:"frequency_penalty,omitempty"`
	Stream           *bool           `json:"stream,omitempty"`
	Tools            []Tool          `json:"tools,omitempty"`
	// ToolChoice is kept raw because its JSON shape varies (string or object).
	ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
}
|
||||
|
||||
// ChatMessage is one conversation turn. Content is interface{} because the
// wire format allows either a plain string or an array of content parts.
type ChatMessage struct {
	Role             string      `json:"role"` // "system", "user", "assistant", "tool"
	Content          interface{} `json:"content"`
	ReasoningContent *string     `json:"reasoning_content,omitempty"`
	ToolCalls        []ToolCall  `json:"tool_calls,omitempty"`
	Name             *string     `json:"name,omitempty"`
	// ToolCallID links a role=="tool" message back to the call it answers.
	ToolCallID *string `json:"tool_call_id,omitempty"`
}

// ContentPart is one element of an array-form message content: either text
// or an image reference, selected by Type.
type ContentPart struct {
	Type     string    `json:"type"`
	Text     string    `json:"text,omitempty"`
	ImageUrl *ImageUrl `json:"image_url,omitempty"`
}

// ImageUrl references an image by URL (or data URI) with an optional
// detail hint.
type ImageUrl struct {
	URL    string  `json:"url"`
	Detail *string `json:"detail,omitempty"`
}
|
||||
|
||||
// Tool-Calling Types
|
||||
|
||||
// Tool-Calling Types (OpenAI-compatible).

// Tool declares a callable tool offered to the model; currently the only
// type seen on the wire is "function".
type Tool struct {
	Type     string      `json:"type"`
	Function FunctionDef `json:"function"`
}

// FunctionDef describes a callable function; Parameters is a raw JSON Schema
// object passed through untouched.
type FunctionDef struct {
	Name        string          `json:"name"`
	Description *string         `json:"description,omitempty"`
	Parameters  json.RawMessage `json:"parameters,omitempty"`
}

// ToolCall is a complete tool invocation emitted by the model.
type ToolCall struct {
	ID       string       `json:"id"`
	Type     string       `json:"type"`
	Function FunctionCall `json:"function"`
}

// FunctionCall carries the function name and its arguments as a JSON-encoded
// string (per the OpenAI wire format).
type FunctionCall struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}

// ToolCallDelta is the streaming fragment of a tool call; fields are pointers
// because each chunk carries only what changed, keyed by Index.
type ToolCallDelta struct {
	Index    uint32             `json:"index"`
	ID       *string            `json:"id,omitempty"`
	Type     *string            `json:"type,omitempty"`
	Function *FunctionCallDelta `json:"function,omitempty"`
}

// FunctionCallDelta is the streamed, incremental counterpart of FunctionCall.
type FunctionCallDelta struct {
	Name      *string `json:"name,omitempty"`
	Arguments *string `json:"arguments,omitempty"`
}
|
||||
|
||||
// OpenAI-compatible Response Structs
|
||||
|
||||
// OpenAI-compatible Response Structs.

// ChatCompletionResponse mirrors the non-streaming /chat/completions reply.
type ChatCompletionResponse struct {
	ID      string       `json:"id"`
	Object  string       `json:"object"`
	Created int64        `json:"created"` // Unix timestamp, per the OpenAI format
	Model   string       `json:"model"`
	Choices []ChatChoice `json:"choices"`
	Usage   *Usage       `json:"usage,omitempty"`
}

// ChatChoice is one generated alternative within a response.
type ChatChoice struct {
	Index        uint32      `json:"index"`
	Message      ChatMessage `json:"message"`
	FinishReason *string     `json:"finish_reason,omitempty"`
}

// Usage is the unified token accounting; optional pointer fields cover
// provider-specific extras (reasoning and prompt-cache tokens).
type Usage struct {
	PromptTokens     uint32  `json:"prompt_tokens"`
	CompletionTokens uint32  `json:"completion_tokens"`
	TotalTokens      uint32  `json:"total_tokens"`
	ReasoningTokens  *uint32 `json:"reasoning_tokens,omitempty"`
	CacheReadTokens  *uint32 `json:"cache_read_tokens,omitempty"`
	CacheWriteTokens *uint32 `json:"cache_write_tokens,omitempty"`
}
|
||||
|
||||
// Streaming Response Structs
|
||||
|
||||
// Streaming Response Structs (SSE "chunk" objects).

// ChatCompletionStreamResponse mirrors one streamed chunk; Usage is typically
// only present on the final chunk when requested.
type ChatCompletionStreamResponse struct {
	ID      string             `json:"id"`
	Object  string             `json:"object"`
	Created int64              `json:"created"`
	Model   string             `json:"model"`
	Choices []ChatStreamChoice `json:"choices"`
	Usage   *Usage             `json:"usage,omitempty"`
}

// ChatStreamChoice is one alternative's incremental update in a chunk.
type ChatStreamChoice struct {
	Index        uint32          `json:"index"`
	Delta        ChatStreamDelta `json:"delta"`
	FinishReason *string         `json:"finish_reason,omitempty"`
}

// ChatStreamDelta holds only the fields that changed in this chunk; all are
// optional by design.
type ChatStreamDelta struct {
	Role             *string         `json:"role,omitempty"`
	Content          *string         `json:"content,omitempty"`
	ReasoningContent *string         `json:"reasoning_content,omitempty"`
	ToolCalls        []ToolCallDelta `json:"tool_calls,omitempty"`
}

// StreamUsage is a non-pointer variant of Usage used while accumulating
// token counts during streaming.
type StreamUsage struct {
	PromptTokens     uint32 `json:"prompt_tokens"`
	CompletionTokens uint32 `json:"completion_tokens"`
	TotalTokens      uint32 `json:"total_tokens"`
	ReasoningTokens  uint32 `json:"reasoning_tokens"`
	CacheReadTokens  uint32 `json:"cache_read_tokens"`
	CacheWriteTokens uint32 `json:"cache_write_tokens"`
}
|
||||
|
||||
// Unified Request Format (for internal use)
|
||||
|
||||
// Unified Request Format (for internal use): the provider-agnostic shape that
// incoming OpenAI-style requests are normalized into before dispatch.

// UnifiedRequest is the internal, provider-neutral request. Unlike the wire
// struct, Stop is always a slice and Stream/HasImages are plain bools.
type UnifiedRequest struct {
	ClientID         string
	Model            string
	Messages         []UnifiedMessage
	Temperature      *float64
	TopP             *float64
	TopK             *uint32
	N                *uint32
	Stop             []string
	MaxTokens        *uint32
	PresencePenalty  *float64
	FrequencyPenalty *float64
	Stream           bool
	// HasImages flags multimodal content so providers/logging can branch on it.
	HasImages  bool
	Tools      []Tool
	ToolChoice json.RawMessage
}

// UnifiedMessage is the internal message form; Content is always a parts
// slice (never a bare string).
type UnifiedMessage struct {
	Role             string
	Content          []UnifiedContentPart
	ReasoningContent *string
	ToolCalls        []ToolCall
	Name             *string
	ToolCallID       *string
}

// UnifiedContentPart is one piece of message content: text or an image,
// selected by Type.
type UnifiedContentPart struct {
	Type  string
	Text  string
	Image *ImageInput
}

// ImageInput carries an image either inline (Base64 + MimeType) or by URL;
// ToBase64 normalizes to the inline form.
type ImageInput struct {
	Base64   string `json:"base64,omitempty"`
	URL      string `json:"url,omitempty"`
	MimeType string `json:"mime_type,omitempty"`
}
|
||||
|
||||
func (i *ImageInput) ToBase64() (string, string, error) {
|
||||
if i.Base64 != "" {
|
||||
return i.Base64, i.MimeType, nil
|
||||
}
|
||||
|
||||
if i.URL != "" {
|
||||
client := resty.New()
|
||||
resp, err := client.R().Get(i.URL)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to fetch image: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return "", "", fmt.Errorf("failed to fetch image: HTTP %d", resp.StatusCode())
|
||||
}
|
||||
|
||||
mimeType := resp.Header().Get("Content-Type")
|
||||
if mimeType == "" {
|
||||
mimeType = "image/jpeg"
|
||||
}
|
||||
|
||||
encoded := base64.StdEncoding.EncodeToString(resp.Body())
|
||||
return encoded, mimeType, nil
|
||||
}
|
||||
|
||||
return "", "", fmt.Errorf("empty image input")
|
||||
}
|
||||
|
||||
// AuthInfo is the per-request identity that the auth middleware stores in the
// request context under the "auth" key.
type AuthInfo struct {
	Token    string // raw bearer token as presented
	ClientID string // resolved (or prefix-derived) client identifier
}
|
||||
69
internal/models/registry.go
Normal file
69
internal/models/registry.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package models
|
||||
|
||||
import "strings"
|
||||
|
||||
// ModelRegistry is an in-memory catalog of providers and their models,
// keyed by provider ID.
type ModelRegistry struct {
	Providers map[string]ProviderInfo `json:"-"`
}

// ProviderInfo describes one upstream provider; Models is keyed by model ID.
type ProviderInfo struct {
	ID     string                   `json:"id"`
	Name   string                   `json:"name"`
	Models map[string]ModelMetadata `json:"models"`
}

// ModelMetadata carries optional pricing, limits and capability flags for a
// single model; nil pointers mean "unknown/unspecified".
type ModelMetadata struct {
	ID         string           `json:"id"`
	Name       string           `json:"name"`
	Cost       *ModelCost       `json:"cost,omitempty"`
	Limit      *ModelLimit      `json:"limit,omitempty"`
	Modalities *ModelModalities `json:"modalities,omitempty"`
	ToolCall   *bool            `json:"tool_call,omitempty"`
	Reasoning  *bool            `json:"reasoning,omitempty"`
}

// ModelCost holds token pricing. The unit (per token vs. per million tokens)
// is not established here — confirm against the cost-calculation code.
type ModelCost struct {
	Input      float64  `json:"input"`
	Output     float64  `json:"output"`
	CacheRead  *float64 `json:"cache_read,omitempty"`
	CacheWrite *float64 `json:"cache_write,omitempty"`
}

// ModelLimit holds the context-window and max-output token counts.
type ModelLimit struct {
	Context uint32 `json:"context"`
	Output  uint32 `json:"output"`
}

// ModelModalities lists supported input/output modalities (e.g. "text",
// "image").
type ModelModalities struct {
	Input  []string `json:"input"`
	Output []string `json:"output"`
}
|
||||
|
||||
func (r *ModelRegistry) FindModel(modelID string) *ModelMetadata {
|
||||
// First try exact match in models map
|
||||
for _, provider := range r.Providers {
|
||||
if model, ok := provider.Models[modelID]; ok {
|
||||
return &model
|
||||
}
|
||||
}
|
||||
|
||||
// Try searching by ID in metadata
|
||||
for _, provider := range r.Providers {
|
||||
for _, model := range provider.Models {
|
||||
if model.ID == modelID {
|
||||
return &model
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try fuzzy matching (e.g. gpt-4o-2024-05-13 matching gpt-4o)
|
||||
for _, provider := range r.Providers {
|
||||
for id, model := range provider.Models {
|
||||
if strings.HasPrefix(modelID, id) {
|
||||
return &model
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
220
internal/providers/deepseek.go
Normal file
220
internal/providers/deepseek.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/models"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
// DeepSeekProvider talks to the DeepSeek OpenAI-compatible chat API.
type DeepSeekProvider struct {
	client *resty.Client           // shared HTTP client for all requests
	config config.DeepSeekConfig   // base URL and model defaults
	apiKey string                  // bearer credential resolved at startup
}

// NewDeepSeekProvider builds a provider around a fresh resty HTTP client.
func NewDeepSeekProvider(cfg config.DeepSeekConfig, apiKey string) *DeepSeekProvider {
	return &DeepSeekProvider{
		client: resty.New(),
		config: cfg,
		apiKey: apiKey,
	}
}

// Name returns the provider's registry identifier.
func (p *DeepSeekProvider) Name() string {
	return "deepseek"
}
|
||||
|
||||
// deepSeekUsage is DeepSeek's wire format for token accounting; it extends
// the standard OpenAI usage object with prompt-cache hit/miss counts and a
// nested reasoning-token detail.
type deepSeekUsage struct {
	PromptTokens          uint32 `json:"prompt_tokens"`
	CompletionTokens      uint32 `json:"completion_tokens"`
	TotalTokens           uint32 `json:"total_tokens"`
	PromptCacheHitTokens  uint32 `json:"prompt_cache_hit_tokens"`
	PromptCacheMissTokens uint32 `json:"prompt_cache_miss_tokens"`
	// CompletionTokensDetails is nil when the API omits the detail object.
	CompletionTokensDetails *struct {
		ReasoningTokens uint32 `json:"reasoning_tokens"`
	} `json:"completion_tokens_details"`
}
|
||||
|
||||
func (u *deepSeekUsage) ToUnified() *models.Usage {
|
||||
usage := &models.Usage{
|
||||
PromptTokens: u.PromptTokens,
|
||||
CompletionTokens: u.CompletionTokens,
|
||||
TotalTokens: u.TotalTokens,
|
||||
}
|
||||
if u.PromptCacheHitTokens > 0 {
|
||||
usage.CacheReadTokens = &u.PromptCacheHitTokens
|
||||
}
|
||||
if u.PromptCacheMissTokens > 0 {
|
||||
usage.CacheWriteTokens = &u.PromptCacheMissTokens
|
||||
}
|
||||
if u.CompletionTokensDetails != nil && u.CompletionTokensDetails.ReasoningTokens > 0 {
|
||||
usage.ReasoningTokens = &u.CompletionTokensDetails.ReasoningTokens
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, false)
|
||||
|
||||
// Sanitize for deepseek-reasoner
|
||||
if req.Model == "deepseek-reasoner" {
|
||||
delete(body, "temperature")
|
||||
delete(body, "top_p")
|
||||
delete(body, "presence_penalty")
|
||||
delete(body, "frequency_penalty")
|
||||
|
||||
if msgs, ok := body["messages"].([]interface{}); ok {
|
||||
for _, m := range msgs {
|
||||
if msg, ok := m.(map[string]interface{}); ok {
|
||||
if msg["role"] == "assistant" {
|
||||
if msg["reasoning_content"] == nil {
|
||||
msg["reasoning_content"] = " "
|
||||
}
|
||||
if msg["content"] == nil || msg["content"] == "" {
|
||||
msg["content"] = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
result, err := ParseOpenAIResponse(respJSON, req.Model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fix usage for DeepSeek specifically if details were missing in ParseOpenAIResponse
|
||||
if usageData, ok := respJSON["usage"]; ok {
|
||||
var dUsage deepSeekUsage
|
||||
usageBytes, _ := json.Marshal(usageData)
|
||||
if err := json.Unmarshal(usageBytes, &dUsage); err == nil {
|
||||
result.Usage = dUsage.ToUnified()
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, true)
|
||||
|
||||
// Sanitize for deepseek-reasoner
|
||||
if req.Model == "deepseek-reasoner" {
|
||||
delete(body, "temperature")
|
||||
delete(body, "top_p")
|
||||
delete(body, "presence_penalty")
|
||||
delete(body, "frequency_penalty")
|
||||
|
||||
if msgs, ok := body["messages"].([]interface{}); ok {
|
||||
for _, m := range msgs {
|
||||
if msg, ok := m.(map[string]interface{}); ok {
|
||||
if msg["role"] == "assistant" {
|
||||
if msg["reasoning_content"] == nil {
|
||||
msg["reasoning_content"] = " "
|
||||
}
|
||||
if msg["content"] == nil || msg["content"] == "" {
|
||||
msg["content"] = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
// Custom scanner loop to handle DeepSeek specific usage in chunks
|
||||
err := StreamDeepSeek(resp.RawBody(), ch)
|
||||
if err != nil {
|
||||
fmt.Printf("DeepSeek Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func StreamDeepSeek(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse) error {
|
||||
defer ctx.Close()
|
||||
scanner := bufio.NewScanner(ctx)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" || !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
var chunk models.ChatCompletionStreamResponse
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Fix DeepSeek specific usage in stream
|
||||
var rawChunk struct {
|
||||
Usage *deepSeekUsage `json:"usage"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
|
||||
chunk.Usage = rawChunk.Usage.ToUnified()
|
||||
}
|
||||
|
||||
ch <- &chunk
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
254
internal/providers/gemini.go
Normal file
254
internal/providers/gemini.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package providers
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"gophergate/internal/config"
	"gophergate/internal/models"

	"github.com/go-resty/resty/v2"
)
|
||||
|
||||
type GeminiProvider struct {
|
||||
client *resty.Client
|
||||
config config.GeminiConfig
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func NewGeminiProvider(cfg config.GeminiConfig, apiKey string) *GeminiProvider {
|
||||
return &GeminiProvider{
|
||||
client: resty.New(),
|
||||
config: cfg,
|
||||
apiKey: apiKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) Name() string {
|
||||
return "gemini"
|
||||
}
|
||||
|
||||
// GeminiRequest is the top-level generateContent request body.
type GeminiRequest struct {
	Contents []GeminiContent `json:"contents"`
}

// GeminiContent is one conversation turn; Role is "user" or "model".
type GeminiContent struct {
	Role string `json:"role,omitempty"`
	Parts []GeminiPart `json:"parts"`
}

// GeminiPart is a single piece of a turn. Exactly one of the fields is
// expected to be populated per part (text, inline media, a function call
// emitted by the model, or a function result sent back to it).
type GeminiPart struct {
	Text string `json:"text,omitempty"`
	InlineData *GeminiInlineData `json:"inlineData,omitempty"`
	FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
	FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
}

// GeminiInlineData carries base64-encoded media (e.g. an image) inline.
type GeminiInlineData struct {
	MimeType string `json:"mimeType"`
	Data string `json:"data"`
}

// GeminiFunctionCall mirrors a tool invocation produced by the model;
// Args is passed through as raw JSON.
type GeminiFunctionCall struct {
	Name string `json:"name"`
	Args json.RawMessage `json:"args"`
}

// GeminiFunctionResponse returns a tool result to the model. Response must
// be a JSON object per the Gemini API contract.
type GeminiFunctionResponse struct {
	Name string `json:"name"`
	Response json.RawMessage `json:"response"`
}
|
||||
|
||||
func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
// Gemini mapping
|
||||
var contents []GeminiContent
|
||||
for _, msg := range req.Messages {
|
||||
role := "user"
|
||||
if msg.Role == "assistant" {
|
||||
role = "model"
|
||||
} else if msg.Role == "tool" {
|
||||
role = "user" // Tool results are user-side in Gemini
|
||||
}
|
||||
|
||||
var parts []GeminiPart
|
||||
|
||||
// Handle tool responses
|
||||
if msg.Role == "tool" {
|
||||
text := ""
|
||||
if len(msg.Content) > 0 {
|
||||
text = msg.Content[0].Text
|
||||
}
|
||||
|
||||
// Gemini expects functionResponse to be an object
|
||||
name := "unknown_function"
|
||||
if msg.Name != nil {
|
||||
name = *msg.Name
|
||||
}
|
||||
|
||||
parts = append(parts, GeminiPart{
|
||||
FunctionResponse: &GeminiFunctionResponse{
|
||||
Name: name,
|
||||
Response: json.RawMessage(text),
|
||||
},
|
||||
})
|
||||
} else {
|
||||
for _, cp := range msg.Content {
|
||||
if cp.Type == "text" {
|
||||
parts = append(parts, GeminiPart{Text: cp.Text})
|
||||
} else if cp.Image != nil {
|
||||
base64Data, mimeType, _ := cp.Image.ToBase64()
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: base64Data,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Handle assistant tool calls
|
||||
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
parts = append(parts, GeminiPart{
|
||||
FunctionCall: &GeminiFunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Args: json.RawMessage(tc.Function.Arguments),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
contents = append(contents, GeminiContent{
|
||||
Role: role,
|
||||
Parts: parts,
|
||||
})
|
||||
}
|
||||
|
||||
body := GeminiRequest{
|
||||
Contents: contents,
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetBody(body).
|
||||
Post(url)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
// Parse Gemini response and convert to OpenAI format
|
||||
var geminiResp struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"parts"`
|
||||
} `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
} `json:"candidates"`
|
||||
UsageMetadata struct {
|
||||
PromptTokenCount uint32 `json:"promptTokenCount"`
|
||||
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
|
||||
TotalTokenCount uint32 `json:"totalTokenCount"`
|
||||
} `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(resp.Body(), &geminiResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(geminiResp.Candidates) == 0 {
|
||||
return nil, fmt.Errorf("no candidates in Gemini response")
|
||||
}
|
||||
|
||||
content := ""
|
||||
for _, p := range geminiResp.Candidates[0].Content.Parts {
|
||||
content += p.Text
|
||||
}
|
||||
|
||||
openAIResp := &models.ChatCompletionResponse{
|
||||
ID: "gemini-" + req.Model,
|
||||
Object: "chat.completion",
|
||||
Created: 0, // Should be current timestamp
|
||||
Model: req.Model,
|
||||
Choices: []models.ChatChoice{
|
||||
{
|
||||
Index: 0,
|
||||
Message: models.ChatMessage{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: &geminiResp.Candidates[0].FinishReason,
|
||||
},
|
||||
},
|
||||
Usage: &models.Usage{
|
||||
PromptTokens: geminiResp.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResp.UsageMetadata.CandidatesTokenCount,
|
||||
TotalTokens: geminiResp.UsageMetadata.TotalTokenCount,
|
||||
},
|
||||
}
|
||||
|
||||
return openAIResp, nil
|
||||
}
|
||||
|
||||
func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
// Simplified Gemini mapping
|
||||
var contents []GeminiContent
|
||||
for _, msg := range req.Messages {
|
||||
role := "user"
|
||||
if msg.Role == "assistant" {
|
||||
role = "model"
|
||||
}
|
||||
|
||||
var parts []GeminiPart
|
||||
for _, p := range msg.Content {
|
||||
parts = append(parts, GeminiPart{Text: p.Text})
|
||||
}
|
||||
|
||||
contents = append(contents, GeminiContent{
|
||||
Role: role,
|
||||
Parts: parts,
|
||||
})
|
||||
}
|
||||
|
||||
body := GeminiRequest{
|
||||
Contents: contents,
|
||||
}
|
||||
|
||||
// Use streamGenerateContent for streaming
|
||||
url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(url)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamGemini(resp.RawBody(), ch, req.Model)
|
||||
if err != nil {
|
||||
fmt.Printf("Gemini Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
95
internal/providers/grok.go
Normal file
95
internal/providers/grok.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/models"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
type GrokProvider struct {
|
||||
client *resty.Client
|
||||
config config.GrokConfig
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func NewGrokProvider(cfg config.GrokConfig, apiKey string) *GrokProvider {
|
||||
return &GrokProvider{
|
||||
client: resty.New(),
|
||||
config: cfg,
|
||||
apiKey: apiKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GrokProvider) Name() string {
|
||||
return "grok"
|
||||
}
|
||||
|
||||
func (p *GrokProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, false)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return ParseOpenAIResponse(respJSON, req.Model)
|
||||
}
|
||||
|
||||
func (p *GrokProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, true)
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamOpenAI(resp.RawBody(), ch)
|
||||
if err != nil {
|
||||
fmt.Printf("Grok Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
318
internal/providers/helpers.go
Normal file
318
internal/providers/helpers.go
Normal file
@@ -0,0 +1,318 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON, including tools and images.
|
||||
// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON,
// including tool messages, tool calls, and base64-embedded images.
//
// Notable behaviors:
//   - tool_call_id and tool-call IDs are truncated to 40 characters —
//     presumably a provider-side ID length limit; TODO confirm which API
//     mandates 40.
//   - A message whose content is a single text part is collapsed to a plain
//     string instead of a one-element parts array.
//   - An assistant message that has tool calls but no content parts gets
//     content set to "" so the field is present but empty.
func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) {
	var result []interface{}
	for _, m := range messages {
		// Tool results are flat: role/content/tool_call_id/name only.
		if m.Role == "tool" {
			text := ""
			if len(m.Content) > 0 {
				// Only the first content part of a tool result is used.
				text = m.Content[0].Text
			}
			msg := map[string]interface{}{
				"role":    "tool",
				"content": text,
			}
			if m.ToolCallID != nil {
				id := *m.ToolCallID
				if len(id) > 40 {
					id = id[:40] // keep in sync with the tool_calls truncation below
				}
				msg["tool_call_id"] = id
			}
			if m.Name != nil {
				msg["name"] = *m.Name
			}
			result = append(result, msg)
			continue
		}

		// Non-tool messages: build a multimodal parts array.
		var parts []interface{}
		for _, p := range m.Content {
			if p.Type == "text" {
				parts = append(parts, map[string]interface{}{
					"type": "text",
					"text": p.Text,
				})
			} else if p.Image != nil {
				base64Data, mimeType, err := p.Image.ToBase64()
				if err != nil {
					return nil, fmt.Errorf("failed to convert image to base64: %w", err)
				}
				// Images are embedded as data: URLs per the OpenAI vision format.
				parts = append(parts, map[string]interface{}{
					"type": "image_url",
					"image_url": map[string]interface{}{
						"url": fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data),
					},
				})
			}
		}

		// Collapse a lone text part to a plain string for compatibility with
		// providers that reject array-form content on text-only messages.
		var finalContent interface{}
		if len(parts) == 1 {
			if p, ok := parts[0].(map[string]interface{}); ok && p["type"] == "text" {
				finalContent = p["text"]
			} else {
				finalContent = parts
			}
		} else {
			finalContent = parts
		}

		msg := map[string]interface{}{
			"role":    m.Role,
			"content": finalContent,
		}

		if m.ReasoningContent != nil {
			msg["reasoning_content"] = *m.ReasoningContent
		}

		if len(m.ToolCalls) > 0 {
			// Copy before truncating IDs so the caller's slice is untouched.
			sanitizedCalls := make([]models.ToolCall, len(m.ToolCalls))
			copy(sanitizedCalls, m.ToolCalls)
			for i := range sanitizedCalls {
				if len(sanitizedCalls[i].ID) > 40 {
					sanitizedCalls[i].ID = sanitizedCalls[i].ID[:40]
				}
			}
			msg["tool_calls"] = sanitizedCalls
			if len(parts) == 0 {
				// Ensure content is present (empty) on tool-call-only turns.
				msg["content"] = ""
			}
		}

		if m.Name != nil {
			msg["name"] = *m.Name
		}

		result = append(result, msg)
	}
	return result, nil
}
|
||||
|
||||
func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{}, stream bool) map[string]interface{} {
|
||||
body := map[string]interface{}{
|
||||
"model": request.Model,
|
||||
"messages": messagesJSON,
|
||||
"stream": stream,
|
||||
}
|
||||
|
||||
if stream {
|
||||
body["stream_options"] = map[string]interface{}{
|
||||
"include_usage": true,
|
||||
}
|
||||
}
|
||||
|
||||
if request.Temperature != nil {
|
||||
body["temperature"] = *request.Temperature
|
||||
}
|
||||
if request.MaxTokens != nil {
|
||||
body["max_tokens"] = *request.MaxTokens
|
||||
}
|
||||
if len(request.Tools) > 0 {
|
||||
body["tools"] = request.Tools
|
||||
}
|
||||
if request.ToolChoice != nil {
|
||||
var toolChoice interface{}
|
||||
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
|
||||
body["tool_choice"] = toolChoice
|
||||
}
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// openAIUsage mirrors the OpenAI-compatible "usage" object, including the
// optional detail sub-objects that the unified models.Usage cannot decode
// directly from raw JSON.
type openAIUsage struct {
	PromptTokens uint32 `json:"prompt_tokens"`
	CompletionTokens uint32 `json:"completion_tokens"`
	TotalTokens uint32 `json:"total_tokens"`
	// Present only when the provider reports prompt-cache hits.
	PromptTokensDetails *struct {
		CachedTokens uint32 `json:"cached_tokens"`
	} `json:"prompt_tokens_details"`
	// Present only for reasoning-capable models.
	CompletionTokensDetails *struct {
		ReasoningTokens uint32 `json:"reasoning_tokens"`
	} `json:"completion_tokens_details"`
}
|
||||
|
||||
func (u *openAIUsage) ToUnified() *models.Usage {
|
||||
usage := &models.Usage{
|
||||
PromptTokens: u.PromptTokens,
|
||||
CompletionTokens: u.CompletionTokens,
|
||||
TotalTokens: u.TotalTokens,
|
||||
}
|
||||
if u.PromptTokensDetails != nil && u.PromptTokensDetails.CachedTokens > 0 {
|
||||
usage.CacheReadTokens = &u.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
if u.CompletionTokensDetails != nil && u.CompletionTokensDetails.ReasoningTokens > 0 {
|
||||
usage.ReasoningTokens = &u.CompletionTokensDetails.ReasoningTokens
|
||||
}
|
||||
return usage
|
||||
}
|
||||
|
||||
// ParseOpenAIResponse maps a decoded OpenAI-style response body onto the
// unified ChatCompletionResponse by round-tripping through JSON, then
// re-decodes the usage object with the richer openAIUsage schema so detail
// counters (cached/reasoning tokens) survive the conversion.
// NOTE(review): the model parameter is currently unused — confirm whether
// callers rely on it before removing.
func ParseOpenAIResponse(respJSON map[string]interface{}, model string) (*models.ChatCompletionResponse, error) {
	data, err := json.Marshal(respJSON)
	if err != nil {
		return nil, err
	}

	var resp models.ChatCompletionResponse
	if err := json.Unmarshal(data, &resp); err != nil {
		return nil, err
	}

	// Manually fix usage because ChatCompletionResponse uses the unified
	// Usage struct but the provider might have returned more details.
	if usageData, ok := respJSON["usage"]; ok {
		var oUsage openAIUsage
		// Marshal of an already-decoded JSON value cannot realistically fail.
		usageBytes, _ := json.Marshal(usageData)
		if err := json.Unmarshal(usageBytes, &oUsage); err == nil {
			resp.Usage = oUsage.ToUnified()
		}
	}

	return &resp, nil
}
|
||||
|
||||
// Streaming support
|
||||
|
||||
func ParseOpenAIStreamChunk(line string) (*models.ChatCompletionStreamResponse, bool, error) {
|
||||
if line == "" {
|
||||
return nil, false, nil
|
||||
}
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
if data == "[DONE]" {
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
var chunk models.ChatCompletionStreamResponse
|
||||
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||
return nil, false, fmt.Errorf("failed to unmarshal stream chunk: %w", err)
|
||||
}
|
||||
|
||||
// Handle specialized usage in stream chunks
|
||||
var rawChunk struct {
|
||||
Usage *openAIUsage `json:"usage"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
|
||||
chunk.Usage = rawChunk.Usage.ToUnified()
|
||||
}
|
||||
|
||||
return &chunk, false, nil
|
||||
}
|
||||
|
||||
func StreamOpenAI(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse) error {
|
||||
defer ctx.Close()
|
||||
scanner := bufio.NewScanner(ctx)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
chunk, done, err := ParseOpenAIStreamChunk(line)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
if chunk != nil {
|
||||
ch <- chunk
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
// StreamGemini incrementally decodes the JSON array returned by Gemini's
// streamGenerateContent endpoint and forwards each element as an
// OpenAI-style stream chunk on ch. The reader is always closed.
//
// NOTE(review): the first parameter is an io.ReadCloser, not a
// context.Context, despite its name. Also, if the payload is not a
// top-level JSON array the function returns nil after consuming one token
// — i.e. non-array responses are silently ignored; confirm that is intended.
func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error {
	defer ctx.Close()

	dec := json.NewDecoder(ctx)

	// Consume the opening '[' of the chunk array.
	t, err := dec.Token()
	if err != nil {
		return err
	}
	if delim, ok := t.(json.Delim); ok && delim == '[' {
		// Decode one array element (one streamed chunk) per iteration.
		for dec.More() {
			var geminiChunk struct {
				Candidates []struct {
					Content struct {
						Parts []struct {
							Text string `json:"text,omitempty"`
							// NOTE(review): modeled here as a string of
							// reasoning text; verify against the live API —
							// Gemini has also documented "thought" as a
							// boolean flag on a part.
							Thought string `json:"thought,omitempty"`
						} `json:"parts"`
					} `json:"content"`
					FinishReason string `json:"finishReason"`
				} `json:"candidates"`
				UsageMetadata struct {
					PromptTokenCount uint32 `json:"promptTokenCount"`
					CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
					TotalTokenCount uint32 `json:"totalTokenCount"`
				} `json:"usageMetadata"`
			}

			if err := dec.Decode(&geminiChunk); err != nil {
				return err
			}

			// Forward only chunks that carry content or usage data.
			if len(geminiChunk.Candidates) > 0 || geminiChunk.UsageMetadata.TotalTokenCount > 0 {
				content := ""
				var reasoning *string
				if len(geminiChunk.Candidates) > 0 {
					// Concatenate text parts; accumulate thought parts
					// separately as reasoning content.
					for _, p := range geminiChunk.Candidates[0].Content.Parts {
						if p.Text != "" {
							content += p.Text
						}
						if p.Thought != "" {
							if reasoning == nil {
								reasoning = new(string)
							}
							*reasoning += p.Thought
						}
					}
				}

				// Lowercase to match the OpenAI-style finish_reason values.
				// NOTE(review): this yields a pointer to "" on chunks where
				// Gemini omits finishReason — confirm consumers treat empty
				// as "not finished".
				var finishReason *string
				if len(geminiChunk.Candidates) > 0 {
					fr := strings.ToLower(geminiChunk.Candidates[0].FinishReason)
					finishReason = &fr
				}

				ch <- &models.ChatCompletionStreamResponse{
					ID: "gemini-stream",
					Object: "chat.completion.chunk",
					Created: 0,
					Model: model,
					Choices: []models.ChatStreamChoice{
						{
							Index: 0,
							Delta: models.ChatStreamDelta{
								Content: &content,
								ReasoningContent: reasoning,
							},
							FinishReason: finishReason,
						},
					},
					Usage: &models.Usage{
						PromptTokens: geminiChunk.UsageMetadata.PromptTokenCount,
						CompletionTokens: geminiChunk.UsageMetadata.CandidatesTokenCount,
						TotalTokens: geminiChunk.UsageMetadata.TotalTokenCount,
					},
				}
			}
		}
	}

	return nil
}
|
||||
114
internal/providers/moonshot.go
Normal file
114
internal/providers/moonshot.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/models"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
type MoonshotProvider struct {
|
||||
client *resty.Client
|
||||
config config.MoonshotConfig
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func NewMoonshotProvider(cfg config.MoonshotConfig, apiKey string) *MoonshotProvider {
|
||||
return &MoonshotProvider{
|
||||
client: resty.New(),
|
||||
config: cfg,
|
||||
apiKey: strings.TrimSpace(apiKey),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *MoonshotProvider) Name() string {
|
||||
return "moonshot"
|
||||
}
|
||||
|
||||
func (p *MoonshotProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, false)
|
||||
if strings.Contains(strings.ToLower(req.Model), "kimi-k2.5") {
|
||||
if maxTokens, ok := body["max_tokens"]; ok {
|
||||
delete(body, "max_tokens")
|
||||
body["max_completion_tokens"] = maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := strings.TrimRight(p.config.BaseURL, "/")
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("Accept", "application/json").
|
||||
SetBody(body).
|
||||
Post(fmt.Sprintf("%s/chat/completions", baseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return ParseOpenAIResponse(respJSON, req.Model)
|
||||
}
|
||||
|
||||
func (p *MoonshotProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, true)
|
||||
if strings.Contains(strings.ToLower(req.Model), "kimi-k2.5") {
|
||||
if maxTokens, ok := body["max_tokens"]; ok {
|
||||
delete(body, "max_tokens")
|
||||
body["max_completion_tokens"] = maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := strings.TrimRight(p.config.BaseURL, "/")
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("Accept", "text/event-stream").
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(fmt.Sprintf("%s/chat/completions", baseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
if err := StreamOpenAI(resp.RawBody(), ch); err != nil {
|
||||
fmt.Printf("Moonshot Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
113
internal/providers/openai.go
Normal file
113
internal/providers/openai.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/models"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
type OpenAIProvider struct {
|
||||
client *resty.Client
|
||||
config config.OpenAIConfig
|
||||
apiKey string
|
||||
}
|
||||
|
||||
func NewOpenAIProvider(cfg config.OpenAIConfig, apiKey string) *OpenAIProvider {
|
||||
return &OpenAIProvider{
|
||||
client: resty.New(),
|
||||
config: cfg,
|
||||
apiKey: apiKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) Name() string {
|
||||
return "openai"
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, false)
|
||||
|
||||
// Transition: Newer models require max_completion_tokens
|
||||
if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
|
||||
if maxTokens, ok := body["max_tokens"]; ok {
|
||||
delete(body, "max_tokens")
|
||||
body["max_completion_tokens"] = maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
var respJSON map[string]interface{}
|
||||
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return ParseOpenAIResponse(respJSON, req.Model)
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
|
||||
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert messages: %w", err)
|
||||
}
|
||||
|
||||
body := BuildOpenAIBody(req, messagesJSON, true)
|
||||
|
||||
// Transition: Newer models require max_completion_tokens
|
||||
if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
|
||||
if maxTokens, ok := body["max_tokens"]; ok {
|
||||
delete(body, "max_tokens")
|
||||
body["max_completion_tokens"] = maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := p.client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+p.apiKey).
|
||||
SetBody(body).
|
||||
SetDoNotParseResponse(true).
|
||||
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
|
||||
}
|
||||
|
||||
ch := make(chan *models.ChatCompletionStreamResponse)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
err := StreamOpenAI(resp.RawBody(), ch)
|
||||
if err != nil {
|
||||
// In a real app, you might want to send an error chunk or log it
|
||||
fmt.Printf("Stream error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
13
internal/providers/provider.go
Normal file
13
internal/providers/provider.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gophergate/internal/models"
|
||||
)
|
||||
|
||||
// Provider is the common abstraction implemented by every upstream LLM
// backend in this package (OpenAI, DeepSeek, Gemini, Grok, Moonshot).
type Provider interface {
	// Name returns the provider's stable identifier used for routing/logging.
	Name() string
	// ChatCompletion performs a blocking request and returns the full reply.
	ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error)
	// ChatCompletionStream starts a streaming request; implementations close
	// the returned channel when the stream ends.
	ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error)
}
|
||||
1393
internal/server/dashboard.go
Normal file
1393
internal/server/dashboard.go
Normal file
File diff suppressed because it is too large
Load Diff
113
internal/server/logging.go
Normal file
113
internal/server/logging.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"gophergate/internal/db"
|
||||
)
|
||||
|
||||
// RequestLog is one record of an upstream LLM call: who made it, which
// provider/model served it, token usage, derived cost and the outcome.
// It is both persisted to the llm_requests table and broadcast to the
// dashboard over WebSocket.
type RequestLog struct {
	Timestamp        time.Time `json:"timestamp"`
	ClientID         string    `json:"client_id"`
	Provider         string    `json:"provider"`
	Model            string    `json:"model"`
	PromptTokens     uint32    `json:"prompt_tokens"`
	CompletionTokens uint32    `json:"completion_tokens"`
	ReasoningTokens  uint32    `json:"reasoning_tokens"`
	TotalTokens      uint32    `json:"total_tokens"`
	CacheReadTokens  uint32    `json:"cache_read_tokens"`
	CacheWriteTokens uint32    `json:"cache_write_tokens"`
	Cost             float64   `json:"cost"` // USD, computed from the model registry
	HasImages        bool      `json:"has_images"`
	Status           string    `json:"status"` // "success" or "error"
	ErrorMessage     string    `json:"error_message,omitempty"`
	DurationMS       int64     `json:"duration_ms"`
}
|
||||
|
||||
// RequestLogger decouples request handling from persistence: entries
// are queued on a buffered channel and written by a single background
// worker goroutine (see Start).
type RequestLogger struct {
	database *db.DB
	hub      *Hub            // dashboard broadcast target
	logChan  chan RequestLog // buffered queue; entries are dropped when full
}
|
||||
|
||||
func NewRequestLogger(database *db.DB, hub *Hub) *RequestLogger {
|
||||
return &RequestLogger{
|
||||
database: database,
|
||||
hub: hub,
|
||||
logChan: make(chan RequestLog, 100),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *RequestLogger) Start() {
|
||||
go func() {
|
||||
for entry := range l.logChan {
|
||||
l.processLog(entry)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// LogRequest enqueues an entry for asynchronous processing. If the
// queue is full the entry is dropped rather than blocking the request
// path (metrics loss is preferred over added latency).
func (l *RequestLogger) LogRequest(entry RequestLog) {
	select {
	case l.logChan <- entry:
	default:
		log.Println("Request log channel full, dropping log entry")
	}
}
|
||||
|
||||
// processLog handles a single dequeued entry: it broadcasts the entry
// to dashboard WebSocket clients, then persists it and updates the
// derived per-client and per-provider aggregates in one transaction.
// Runs only on the Start() worker goroutine.
func (l *RequestLogger) processLog(entry RequestLog) {
	// Broadcast to dashboard.
	// NOTE(review): hub.broadcast is unbuffered, so this blocks until
	// Hub.Run receives — confirm Run() is always started before Start().
	l.hub.broadcast <- map[string]interface{}{
		"type":    "request",
		"channel": "requests",
		"payload": entry,
	}

	// Insert into DB
	tx, err := l.database.Begin()
	if err != nil {
		log.Printf("Failed to begin transaction for logging: %v", err)
		return
	}
	// Rollback is a no-op after a successful Commit; this covers every
	// early-return path below.
	defer tx.Rollback()

	// Ensure client exists (auto-provision unknown client IDs so the
	// foreign-key/rollup updates below have a row to target).
	_, _ = tx.Exec("INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')",
		entry.ClientID, entry.ClientID)

	// Insert log
	_, err = tx.Exec(`
		INSERT INTO llm_requests
		(timestamp, client_id, provider, model, prompt_tokens, completion_tokens, reasoning_tokens, total_tokens, cache_read_tokens, cache_write_tokens, cost, has_images, status, error_message, duration_ms)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
	`, entry.Timestamp, entry.ClientID, entry.Provider, entry.Model,
		entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.TotalTokens,
		entry.CacheReadTokens, entry.CacheWriteTokens, entry.Cost, entry.HasImages,
		entry.Status, entry.ErrorMessage, entry.DurationMS)

	if err != nil {
		log.Printf("Failed to insert request log: %v", err)
		return
	}

	// Update client stats (denormalized rollups on the clients row).
	_, _ = tx.Exec(`
		UPDATE clients SET
			total_requests = total_requests + 1,
			total_tokens = total_tokens + ?,
			total_cost = total_cost + ?,
			updated_at = CURRENT_TIMESTAMP
		WHERE client_id = ?
	`, entry.TotalTokens, entry.Cost, entry.ClientID)

	// Update provider balance: prepaid providers have their credit
	// balance drawn down; postpaid providers are left untouched.
	if entry.Cost > 0 {
		_, _ = tx.Exec("UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')",
			entry.Cost, entry.Provider)
	}

	err = tx.Commit()
	if err != nil {
		log.Printf("Failed to commit logging transaction: %v", err)
	}
}
|
||||
494
internal/server/server.go
Normal file
494
internal/server/server.go
Normal file
@@ -0,0 +1,494 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gophergate/internal/config"
|
||||
"gophergate/internal/db"
|
||||
"gophergate/internal/middleware"
|
||||
"gophergate/internal/models"
|
||||
"gophergate/internal/providers"
|
||||
"gophergate/internal/utils"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Server bundles every long-lived component of the gateway: the HTTP
// router, configuration, database handle, provider routing map, admin
// session manager, dashboard WebSocket hub, async request logger and
// the model/pricing registry.
type Server struct {
	router    *gin.Engine
	cfg       *config.Config
	database  *db.DB
	providers map[string]providers.Provider // keyed by provider ID ("openai", "gemini", ...)
	sessions  *SessionManager
	hub       *Hub
	logger    *RequestLogger
	registry  *models.ModelRegistry // model metadata and pricing fetched from models.dev
}
|
||||
|
||||
// NewServer wires the full application: HTTP router, provider map,
// session manager, WebSocket hub and async request logger. The model
// registry is fetched in a background goroutine so startup is never
// blocked on the network.
func NewServer(cfg *config.Config, database *db.DB) *Server {
	router := gin.Default()
	hub := NewHub()

	s := &Server{
		router:    router,
		cfg:       cfg,
		database:  database,
		providers: make(map[string]providers.Provider),
		sessions:  NewSessionManager(cfg.KeyBytes, 24*time.Hour), // admin sessions last 24h
		hub:       hub,
		logger:    NewRequestLogger(database, hub),
		registry:  &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)},
	}

	// Fetch registry in background.
	// NOTE(review): s.registry is replaced here without synchronization
	// while request handlers may read it concurrently — confirm whether
	// a mutex or atomic pointer is needed.
	go func() {
		registry, err := utils.FetchRegistry()
		if err != nil {
			fmt.Printf("Warning: Failed to fetch initial model registry: %v\n", err)
		} else {
			s.registry = registry
		}
	}()

	// Initialize providers from DB and Config.
	if err := s.RefreshProviders(); err != nil {
		fmt.Printf("Warning: Failed to initial refresh providers: %v\n", err)
	}

	s.setupRoutes()
	return s
}
|
||||
|
||||
// RefreshProviders rebuilds the provider routing map from the static
// config merged with per-provider overrides stored in the
// provider_configs table. DB values win over config for the enabled
// flag, base URL and API key; disabled providers are removed from the
// map. Safe to call again whenever settings change.
func (s *Server) RefreshProviders() error {
	var dbConfigs []db.ProviderConfig
	err := s.database.Select(&dbConfigs, "SELECT * FROM provider_configs")
	if err != nil {
		return fmt.Errorf("failed to fetch provider configs from db: %w", err)
	}

	// Index DB rows by provider ID for the override lookup below.
	dbMap := make(map[string]db.ProviderConfig)
	for _, cfg := range dbConfigs {
		dbMap[cfg.ID] = cfg
	}

	providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok"}
	for _, id := range providerIDs {
		// Default values from config
		enabled := false
		baseURL := ""
		apiKey := ""

		switch id {
		case "openai":
			enabled = s.cfg.Providers.OpenAI.Enabled
			baseURL = s.cfg.Providers.OpenAI.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("openai")
		case "gemini":
			enabled = s.cfg.Providers.Gemini.Enabled
			baseURL = s.cfg.Providers.Gemini.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("gemini")
		case "deepseek":
			enabled = s.cfg.Providers.DeepSeek.Enabled
			baseURL = s.cfg.Providers.DeepSeek.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("deepseek")
		case "moonshot":
			enabled = s.cfg.Providers.Moonshot.Enabled
			baseURL = s.cfg.Providers.Moonshot.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("moonshot")
		case "grok":
			enabled = s.cfg.Providers.Grok.Enabled
			baseURL = s.cfg.Providers.Grok.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("grok")
		}

		// Overrides from DB
		if dbCfg, ok := dbMap[id]; ok {
			enabled = dbCfg.Enabled
			if dbCfg.BaseURL != nil && *dbCfg.BaseURL != "" {
				baseURL = *dbCfg.BaseURL
			}
			if dbCfg.APIKey != nil && *dbCfg.APIKey != "" {
				key := *dbCfg.APIKey
				// Keys stored encrypted are decrypted with the server
				// master key; on failure the raw stored value is used.
				if dbCfg.APIKeyEncrypted {
					decrypted, err := utils.Decrypt(key, s.cfg.KeyBytes)
					if err == nil {
						key = decrypted
					} else {
						fmt.Printf("Warning: Failed to decrypt API key for %s: %v\n", id, err)
					}
				}
				apiKey = key
			}
		}

		if !enabled {
			delete(s.providers, id)
			continue
		}

		// Initialize provider with the (possibly overridden) base URL.
		switch id {
		case "openai":
			cfg := s.cfg.Providers.OpenAI
			cfg.BaseURL = baseURL
			s.providers["openai"] = providers.NewOpenAIProvider(cfg, apiKey)
		case "gemini":
			cfg := s.cfg.Providers.Gemini
			cfg.BaseURL = baseURL
			s.providers["gemini"] = providers.NewGeminiProvider(cfg, apiKey)
		case "deepseek":
			cfg := s.cfg.Providers.DeepSeek
			cfg.BaseURL = baseURL
			s.providers["deepseek"] = providers.NewDeepSeekProvider(cfg, apiKey)
		case "moonshot":
			cfg := s.cfg.Providers.Moonshot
			cfg.BaseURL = baseURL
			s.providers["moonshot"] = providers.NewMoonshotProvider(cfg, apiKey)
		case "grok":
			cfg := s.cfg.Providers.Grok
			cfg.BaseURL = baseURL
			s.providers["grok"] = providers.NewGrokProvider(cfg, apiKey)
		}
	}

	return nil
}
|
||||
|
||||
func (s *Server) setupRoutes() {
|
||||
s.router.Use(middleware.AuthMiddleware(s.database))
|
||||
|
||||
// Static files
|
||||
s.router.StaticFile("/", "./static/index.html")
|
||||
s.router.StaticFile("/favicon.ico", "./static/favicon.ico")
|
||||
s.router.Static("/css", "./static/css")
|
||||
s.router.Static("/js", "./static/js")
|
||||
s.router.Static("/img", "./static/img")
|
||||
|
||||
// WebSocket
|
||||
s.router.GET("/ws", s.handleWebSocket)
|
||||
|
||||
// API V1 (External LLM Access) - Secured with AuthMiddleware
|
||||
v1 := s.router.Group("/v1")
|
||||
v1.Use(middleware.AuthMiddleware(s.database))
|
||||
{
|
||||
v1.POST("/chat/completions", s.handleChatCompletions)
|
||||
v1.GET("/models", s.handleListModels)
|
||||
}
|
||||
|
||||
// Dashboard API Group
|
||||
api := s.router.Group("/api")
|
||||
{
|
||||
api.POST("/auth/login", s.handleLogin)
|
||||
api.GET("/auth/status", s.handleAuthStatus)
|
||||
api.POST("/auth/logout", s.handleLogout)
|
||||
api.POST("/auth/change-password", s.handleChangePassword)
|
||||
|
||||
// Protected dashboard routes (need admin session)
|
||||
admin := api.Group("/")
|
||||
admin.Use(s.adminAuthMiddleware())
|
||||
{
|
||||
admin.GET("/usage/summary", s.handleUsageSummary)
|
||||
admin.GET("/usage/time-series", s.handleTimeSeries)
|
||||
admin.GET("/usage/providers", s.handleProvidersUsage)
|
||||
admin.GET("/usage/clients", s.handleClientsUsage)
|
||||
admin.GET("/usage/detailed", s.handleDetailedUsage)
|
||||
admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)
|
||||
|
||||
admin.GET("/clients", s.handleGetClients)
|
||||
admin.POST("/clients", s.handleCreateClient)
|
||||
admin.GET("/clients/:id", s.handleGetClient)
|
||||
admin.PUT("/clients/:id", s.handleUpdateClient)
|
||||
admin.DELETE("/clients/:id", s.handleDeleteClient)
|
||||
|
||||
admin.GET("/clients/:id/tokens", s.handleGetClientTokens)
|
||||
admin.POST("/clients/:id/tokens", s.handleCreateClientToken)
|
||||
admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)
|
||||
|
||||
admin.GET("/providers", s.handleGetProviders)
|
||||
admin.PUT("/providers/:name", s.handleUpdateProvider)
|
||||
admin.POST("/providers/:name/test", s.handleTestProvider)
|
||||
|
||||
admin.GET("/models", s.handleGetModels)
|
||||
admin.PUT("/models/:id", s.handleUpdateModel)
|
||||
|
||||
admin.GET("/users", s.handleGetUsers)
|
||||
admin.POST("/users", s.handleCreateUser)
|
||||
admin.PUT("/users/:id", s.handleUpdateUser)
|
||||
admin.DELETE("/users/:id", s.handleDeleteUser)
|
||||
|
||||
admin.GET("/system/health", s.handleSystemHealth)
|
||||
admin.GET("/system/metrics", s.handleSystemMetrics)
|
||||
admin.GET("/system/settings", s.handleGetSettings)
|
||||
admin.POST("/system/backup", s.handleCreateBackup)
|
||||
admin.GET("/system/logs", s.handleGetLogs)
|
||||
}
|
||||
}
|
||||
|
||||
s.router.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleListModels(c *gin.Context) {
|
||||
type OpenAIModel struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
OwnedBy string `json:"owned_by"`
|
||||
}
|
||||
|
||||
var data []OpenAIModel
|
||||
allowedProviders := map[string]bool{
|
||||
"openai": true,
|
||||
"google": true, // Models from models.dev use 'google' ID for Gemini
|
||||
"deepseek": true,
|
||||
"moonshot": true,
|
||||
"xai": true, // Models from models.dev use 'xai' ID for Grok
|
||||
}
|
||||
|
||||
if s.registry != nil {
|
||||
for pID, pInfo := range s.registry.Providers {
|
||||
if !allowedProviders[pID] {
|
||||
continue
|
||||
}
|
||||
for mID := range pInfo.Models {
|
||||
data = append(data, OpenAIModel{
|
||||
ID: mID,
|
||||
Object: "model",
|
||||
Created: 1700000000,
|
||||
OwnedBy: pID,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"object": "list",
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
// handleChatCompletions serves POST /v1/chat/completions. It picks a
// provider from the model name, converts the OpenAI-style payload into
// the internal UnifiedRequest (normalizing stop sequences and
// multimodal content), and proxies either a streaming (SSE) or a
// blocking response. Every outcome is recorded via logRequest.
func (s *Server) handleChatCompletions(c *gin.Context) {
	startTime := time.Now()
	var req models.ChatCompletionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Select provider based on model name (substring heuristics;
	// anything unrecognized falls through to OpenAI).
	providerName := "openai" // default
	if strings.Contains(req.Model, "gemini") {
		providerName = "gemini"
	} else if strings.Contains(req.Model, "deepseek") {
		providerName = "deepseek"
	} else if strings.Contains(req.Model, "kimi") || strings.Contains(req.Model, "moonshot") {
		providerName = "moonshot"
	} else if strings.Contains(req.Model, "grok") {
		providerName = "grok"
	}

	provider, ok := s.providers[providerName]
	if !ok {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
		return
	}

	// Convert ChatCompletionRequest to UnifiedRequest
	unifiedReq := &models.UnifiedRequest{
		Model:            req.Model,
		Messages:         []models.UnifiedMessage{},
		Temperature:      req.Temperature,
		TopP:             req.TopP,
		TopK:             req.TopK,
		N:                req.N,
		MaxTokens:        req.MaxTokens,
		PresencePenalty:  req.PresencePenalty,
		FrequencyPenalty: req.FrequencyPenalty,
		Stream:           req.Stream != nil && *req.Stream,
		Tools:            req.Tools,
		ToolChoice:       req.ToolChoice,
	}

	// Handle Stop sequences: the OpenAI API allows either a JSON string
	// or an array of strings; normalize both into a []string.
	if req.Stop != nil {
		var stop []string
		if err := json.Unmarshal(req.Stop, &stop); err == nil {
			unifiedReq.Stop = stop
		} else {
			var singleStop string
			if err := json.Unmarshal(req.Stop, &singleStop); err == nil {
				unifiedReq.Stop = []string{singleStop}
			}
		}
	}

	// Convert messages
	for _, msg := range req.Messages {
		unifiedMsg := models.UnifiedMessage{
			Role:             msg.Role,
			Content:          []models.UnifiedContentPart{},
			ReasoningContent: msg.ReasoningContent,
			ToolCalls:        msg.ToolCalls,
			Name:             msg.Name,
			ToolCallID:       msg.ToolCallID,
		}

		// Handle multimodal content: a message body is either a plain
		// string or a list of typed parts ("text" / "image_url").
		if strContent, ok := msg.Content.(string); ok {
			unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
				Type: "text",
				Text: strContent,
			})
		} else if parts, ok := msg.Content.([]interface{}); ok {
			for _, part := range parts {
				if partMap, ok := part.(map[string]interface{}); ok {
					partType, _ := partMap["type"].(string)
				if partType == "text" {
						text, _ := partMap["text"].(string)
						unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
							Type: "text",
							Text: text,
						})
					} else if partType == "image_url" {
						if imgURLMap, ok := partMap["image_url"].(map[string]interface{}); ok {
							url, _ := imgURLMap["url"].(string)
							imageInput := &models.ImageInput{}
							// Inline data URLs are decoded into base64 +
							// MIME type; anything else is kept as a URL.
							if strings.HasPrefix(url, "data:") {
								mime, data, err := utils.ParseDataURL(url)
								if err == nil {
									imageInput.Base64 = data
									imageInput.MimeType = mime
								}
							} else {
								imageInput.URL = url
							}
							unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
								Type:  "image",
								Image: imageInput,
							})
							unifiedReq.HasImages = true
						}
					}
				}
			}
		}

		unifiedReq.Messages = append(unifiedReq.Messages, unifiedMsg)
	}

	// Attribute the request to the authenticated client when the auth
	// middleware populated the context; otherwise use "default".
	clientID := "default"
	if auth, ok := c.Get("auth"); ok {
		if authInfo, ok := auth.(models.AuthInfo); ok {
			unifiedReq.ClientID = authInfo.ClientID
			clientID = authInfo.ClientID
		}
	} else {
		unifiedReq.ClientID = clientID
	}

	if unifiedReq.Stream {
		ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq)
		if err != nil {
			s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}

		// Standard SSE headers for a streaming response.
		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")

		// Track the last usage chunk seen so the final log entry has
		// token counts.
		var lastUsage *models.Usage
		c.Stream(func(w io.Writer) bool {
			chunk, ok := <-ch
			if !ok {
				// Channel closed by the provider: terminate the SSE
				// stream and log the completed request.
				fmt.Fprintf(w, "data: [DONE]\n\n")
				s.logRequest(startTime, clientID, providerName, req.Model, lastUsage, nil, unifiedReq.HasImages)
				return false
			}
			if chunk.Usage != nil {
				lastUsage = chunk.Usage
			}
			data, err := json.Marshal(chunk)
			if err != nil {
				return false
			}
			fmt.Fprintf(w, "data: %s\n\n", data)
			return true
		})
		return
	}

	resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq)
	if err != nil {
		s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage, nil, unifiedReq.HasImages)
	c.JSON(http.StatusOK, resp)
}
|
||||
|
||||
func (s *Server) logRequest(start time.Time, clientID, provider, model string, usage *models.Usage, err error, hasImages bool) {
|
||||
entry := RequestLog{
|
||||
Timestamp: start,
|
||||
ClientID: clientID,
|
||||
Provider: provider,
|
||||
Model: model,
|
||||
Status: "success",
|
||||
DurationMS: time.Since(start).Milliseconds(),
|
||||
HasImages: hasImages,
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
entry.Status = "error"
|
||||
entry.ErrorMessage = err.Error()
|
||||
}
|
||||
|
||||
if usage != nil {
|
||||
entry.PromptTokens = usage.PromptTokens
|
||||
entry.CompletionTokens = usage.CompletionTokens
|
||||
entry.TotalTokens = usage.TotalTokens
|
||||
if usage.ReasoningTokens != nil {
|
||||
entry.ReasoningTokens = *usage.ReasoningTokens
|
||||
}
|
||||
if usage.CacheReadTokens != nil {
|
||||
entry.CacheReadTokens = *usage.CacheReadTokens
|
||||
}
|
||||
if usage.CacheWriteTokens != nil {
|
||||
entry.CacheWriteTokens = *usage.CacheWriteTokens
|
||||
}
|
||||
|
||||
// Calculate cost using registry
|
||||
entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
|
||||
fmt.Printf("[DEBUG] Request logged: model=%s, prompt=%d, completion=%d, reasoning=%d, cache_read=%d, cost=%f\n",
|
||||
model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.Cost)
|
||||
}
|
||||
|
||||
s.logger.LogRequest(entry)
|
||||
}
|
||||
|
||||
// Run starts the background services (WebSocket hub, async request
// logger, daily model-registry refresh) and then blocks serving HTTP
// on the configured host:port.
func (s *Server) Run() error {
	go s.hub.Run()
	s.logger.Start()

	// Start registry refresher (re-fetch pricing/model data daily).
	go func() {
		ticker := time.NewTicker(24 * time.Hour)
		for range ticker.C {
			newRegistry, err := utils.FetchRegistry()
			if err == nil {
				// NOTE(review): unsynchronized write while handlers may
				// read s.registry concurrently — confirm whether a
				// mutex or atomic pointer is needed.
				s.registry = newRegistry
			}
		}
	}()

	addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
	return s.router.Run(addr)
}
|
||||
155
internal/server/sessions.go
Normal file
155
internal/server/sessions.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Session is the server-side record of an authenticated dashboard
// login. It mirrors the claims embedded in the signed token.
type Session struct {
	Username    string    `json:"username"`
	DisplayName string    `json:"display_name"`
	Role        string    `json:"role"`
	CreatedAt   time.Time `json:"created_at"`
	ExpiresAt   time.Time `json:"expires_at"`
	SessionID   string    `json:"session_id"`
}
|
||||
|
||||
// SessionManager issues and validates HMAC-signed session tokens.
// The authoritative session set lives in memory, so a server restart
// invalidates all outstanding sessions.
type SessionManager struct {
	sessions map[string]Session // keyed by session ID
	mu       sync.RWMutex       // guards sessions
	secret   []byte             // HMAC-SHA256 signing key
	ttl      time.Duration      // session lifetime
}
|
||||
|
||||
// sessionPayload is the JSON claim set embedded in a signed token.
// Exp is a Unix timestamp; the HMAC signature covers the raw JSON
// bytes of this struct.
type sessionPayload struct {
	SessionID   string `json:"session_id"`
	Username    string `json:"username"`
	DisplayName string `json:"display_name"`
	Role        string `json:"role"`
	Exp         int64  `json:"exp"`
}
|
||||
|
||||
func NewSessionManager(secret []byte, ttl time.Duration) *SessionManager {
|
||||
return &SessionManager{
|
||||
sessions: make(map[string]Session),
|
||||
secret: secret,
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SessionManager) CreateSession(username, displayName, role string) (string, error) {
|
||||
sessionID := uuid.New().String()
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(m.ttl)
|
||||
|
||||
m.mu.Lock()
|
||||
m.sessions[sessionID] = Session{
|
||||
Username: username,
|
||||
DisplayName: displayName,
|
||||
Role: role,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: expiresAt,
|
||||
SessionID: sessionID,
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
return m.createSignedToken(sessionID, username, displayName, role, expiresAt.Unix())
|
||||
}
|
||||
|
||||
func (m *SessionManager) createSignedToken(sessionID, username, displayName, role string, exp int64) (string, error) {
|
||||
payload := sessionPayload{
|
||||
SessionID: sessionID,
|
||||
Username: username,
|
||||
DisplayName: displayName,
|
||||
Role: role,
|
||||
Exp: exp,
|
||||
}
|
||||
|
||||
payloadJSON, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON)
|
||||
|
||||
h := hmac.New(sha256.New, m.secret)
|
||||
h.Write(payloadJSON)
|
||||
signature := h.Sum(nil)
|
||||
signatureB64 := base64.RawURLEncoding.EncodeToString(signature)
|
||||
|
||||
return fmt.Sprintf("%s.%s", payloadB64, signatureB64), nil
|
||||
}
|
||||
|
||||
// ValidateSession checks a token produced by CreateSession: it
// verifies the HMAC-SHA256 signature over the raw JSON payload,
// rejects expired tokens, and requires the session to still exist in
// the in-memory store (so RevokeSession takes effect immediately).
// The second return value is currently always "" (reserved).
func (m *SessionManager) ValidateSession(token string) (*Session, string, error) {
	parts := strings.Split(token, ".")
	if len(parts) != 2 {
		return nil, "", fmt.Errorf("invalid token format")
	}

	payloadB64 := parts[0]
	signatureB64 := parts[1]

	payloadJSON, err := base64.RawURLEncoding.DecodeString(payloadB64)
	if err != nil {
		return nil, "", err
	}

	signature, err := base64.RawURLEncoding.DecodeString(signatureB64)
	if err != nil {
		return nil, "", err
	}

	// Constant-time comparison of the recomputed MAC (hmac.Equal).
	h := hmac.New(sha256.New, m.secret)
	h.Write(payloadJSON)
	if !hmac.Equal(signature, h.Sum(nil)) {
		return nil, "", fmt.Errorf("invalid signature")
	}

	var payload sessionPayload
	if err := json.Unmarshal(payloadJSON, &payload); err != nil {
		return nil, "", err
	}

	if time.Now().Unix() > payload.Exp {
		return nil, "", fmt.Errorf("token expired")
	}

	m.mu.RLock()
	session, ok := m.sessions[payload.SessionID]
	m.mu.RUnlock()

	if !ok {
		return nil, "", fmt.Errorf("session not found")
	}

	// NOTE(review): expired sessions are never pruned from m.sessions,
	// so the map grows over time — consider a periodic cleanup sweep.
	return &session, "", nil
}
|
||||
|
||||
func (m *SessionManager) RevokeSession(token string) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 2 {
|
||||
return
|
||||
}
|
||||
|
||||
payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var payload sessionPayload
|
||||
if err := json.Unmarshal(payloadJSON, &payload); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
delete(m.sessions, payload.SessionID)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
107
internal/server/websocket.go
Normal file
107
internal/server/websocket.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// upgrader promotes dashboard HTTP requests to WebSocket connections.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	// NOTE(review): accepting every Origin disables cross-site
	// WebSocket hijacking protection — restrict before production.
	CheckOrigin: func(r *http.Request) bool {
		return true // In production, refine this
	},
}
|
||||
|
||||
// Hub fans dashboard events out to every connected WebSocket client.
// Connections are added/removed via the register/unregister channels
// and messages are pushed through broadcast; Run serializes all of it.
type Hub struct {
	clients     map[*websocket.Conn]bool
	broadcast   chan interface{} // unbuffered: senders block until Run receives
	register    chan *websocket.Conn
	unregister  chan *websocket.Conn
	mu          sync.Mutex // guards clients
	clientCount int32      // updated atomically; read by GetClientCount
}
|
||||
|
||||
func NewHub() *Hub {
|
||||
return &Hub{
|
||||
clients: make(map[*websocket.Conn]bool),
|
||||
broadcast: make(chan interface{}),
|
||||
register: make(chan *websocket.Conn),
|
||||
unregister: make(chan *websocket.Conn),
|
||||
}
|
||||
}
|
||||
|
||||
// Run is the hub's event loop: it owns the clients map and serializes
// register/unregister/broadcast operations. Intended to run on its own
// goroutine for the life of the server (it never returns).
func (h *Hub) Run() {
	for {
		select {
		case client := <-h.register:
			h.mu.Lock()
			h.clients[client] = true
			atomic.AddInt32(&h.clientCount, 1)
			h.mu.Unlock()
			log.Println("WebSocket client registered")
		case client := <-h.unregister:
			h.mu.Lock()
			if _, ok := h.clients[client]; ok {
				delete(h.clients, client)
				client.Close()
				atomic.AddInt32(&h.clientCount, -1)
			}
			h.mu.Unlock()
			log.Println("WebSocket client unregistered")
		case message := <-h.broadcast:
			h.mu.Lock()
			for client := range h.clients {
				err := client.WriteJSON(message)
				if err != nil {
					// Evict clients whose writes fail (disconnected
					// peers); deleting during range is safe in Go.
					log.Printf("WebSocket error: %v", err)
					client.Close()
					delete(h.clients, client)
					atomic.AddInt32(&h.clientCount, -1)
				}
			}
			h.mu.Unlock()
		}
	}
}
|
||||
|
||||
func (h *Hub) GetClientCount() int {
|
||||
return int(atomic.LoadInt32(&h.clientCount))
|
||||
}
|
||||
|
||||
// handleWebSocket upgrades GET /ws to a WebSocket, registers the
// connection with the hub (which then pushes dashboard events to it),
// and answers ping messages until the peer disconnects.
func (s *Server) handleWebSocket(c *gin.Context) {
	conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		log.Printf("Failed to set websocket upgrade: %v", err)
		return
	}

	s.hub.register <- conn

	// Unregister (the hub also closes the connection) on any exit path.
	defer func() {
		s.hub.unregister <- conn
	}()

	// Initial message
	conn.WriteJSON(gin.H{
		"type":    "connected",
		"message": "Connected to GopherGate Dashboard",
	})

	// Read loop: exit on any read error (closed connection); answer
	// pings so the client can keep the connection alive.
	for {
		var msg map[string]interface{}
		err := conn.ReadJSON(&msg)
		if err != nil {
			break
		}

		if msg["type"] == "ping" {
			conn.WriteJSON(gin.H{"type": "pong", "payload": gin.H{}})
		}
	}
}
|
||||
71
internal/utils/crypto.go
Normal file
71
internal/utils/crypto.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Encrypt encrypts plain text using AES-GCM with the given 32-byte key.
|
||||
func Encrypt(plainText string, key []byte) (string, error) {
|
||||
if len(key) != 32 {
|
||||
return "", fmt.Errorf("encryption key must be 32 bytes")
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// The nonce should be prepended to the ciphertext
|
||||
cipherText := gcm.Seal(nonce, nonce, []byte(plainText), nil)
|
||||
return base64.StdEncoding.EncodeToString(cipherText), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts base64-encoded cipher text using AES-GCM with the given 32-byte key.
|
||||
func Decrypt(encodedCipherText string, key []byte) (string, error) {
|
||||
if len(key) != 32 {
|
||||
return "", fmt.Errorf("encryption key must be 32 bytes")
|
||||
}
|
||||
|
||||
cipherText, err := base64.StdEncoding.DecodeString(encodedCipherText)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(cipherText) < nonceSize {
|
||||
return "", fmt.Errorf("cipher text too short")
|
||||
}
|
||||
|
||||
nonce, actualCipherText := cipherText[:nonceSize], cipherText[nonceSize:]
|
||||
plainText, err := gcm.Open(nil, nonce, actualCipherText, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(plainText), nil
|
||||
}
|
||||
69
internal/utils/registry.go
Normal file
69
internal/utils/registry.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"gophergate/internal/models"
|
||||
"github.com/go-resty/resty/v2"
|
||||
)
|
||||
|
||||
const ModelsDevURL = "https://models.dev/api.json"
|
||||
|
||||
func FetchRegistry() (*models.ModelRegistry, error) {
|
||||
log.Printf("Fetching model registry from %s", ModelsDevURL)
|
||||
|
||||
client := resty.New().SetTimeout(10 * time.Second)
|
||||
resp, err := client.R().Get(ModelsDevURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch registry: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccess() {
|
||||
return nil, fmt.Errorf("failed to fetch registry: HTTP %d", resp.StatusCode())
|
||||
}
|
||||
|
||||
var providers map[string]models.ProviderInfo
|
||||
if err := json.Unmarshal(resp.Body(), &providers); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal registry: %w", err)
|
||||
}
|
||||
|
||||
log.Println("Successfully loaded model registry")
|
||||
return &models.ModelRegistry{Providers: providers}, nil
|
||||
}
|
||||
|
||||
func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens, completionTokens, reasoningTokens, cacheRead, cacheWrite uint32) float64 {
|
||||
meta := registry.FindModel(modelID)
|
||||
if meta == nil || meta.Cost == nil {
|
||||
log.Printf("[DEBUG] CalculateCost: model %s not found or has no cost metadata", modelID)
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// promptTokens is usually the TOTAL prompt size.
|
||||
// We subtract cacheRead from it to get the uncached part.
|
||||
uncachedTokens := promptTokens
|
||||
if cacheRead > 0 {
|
||||
if cacheRead > promptTokens {
|
||||
uncachedTokens = 0
|
||||
} else {
|
||||
uncachedTokens = promptTokens - cacheRead
|
||||
}
|
||||
}
|
||||
|
||||
cost := (float64(uncachedTokens) * meta.Cost.Input / 1000000.0) +
|
||||
(float64(completionTokens) * meta.Cost.Output / 1000000.0)
|
||||
|
||||
if meta.Cost.CacheRead != nil {
|
||||
cost += float64(cacheRead) * (*meta.Cost.CacheRead) / 1000000.0
|
||||
}
|
||||
if meta.Cost.CacheWrite != nil {
|
||||
cost += float64(cacheWrite) * (*meta.Cost.CacheWrite) / 1000000.0
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG] CalculateCost: model=%s, uncached=%d, completion=%d, reasoning=%d, cache_read=%d, cache_write=%d, cost=%f (input_rate=%f, output_rate=%f)",
|
||||
modelID, uncachedTokens, completionTokens, reasoningTokens, cacheRead, cacheWrite, cost, meta.Cost.Input, meta.Cost.Output)
|
||||
|
||||
return cost
|
||||
}
|
||||
19
internal/utils/utils.go
Normal file
19
internal/utils/utils.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ParseDataURL splits a base64 "data:" URL into its media-type part and its
// base64 payload. It errors when the input does not start with "data:" or
// does not contain exactly one ";base64," marker.
func ParseDataURL(dataURL string) (string, string, error) {
	const scheme = "data:"
	if !strings.HasPrefix(dataURL, scheme) {
		return "", "", fmt.Errorf("not a data URL")
	}

	// Everything after "data:" must be "<mediatype>;base64,<payload>".
	body := dataURL[len(scheme):]
	pieces := strings.Split(body, ";base64,")
	if len(pieces) != 2 {
		return "", "", fmt.Errorf("invalid data URL format")
	}

	mediaType, payload := pieces[0], pieces[1]
	return mediaType, payload, nil
}
|
||||
@@ -1,13 +0,0 @@
|
||||
-- Migration: add billing_mode to provider_configs
-- Adds a billing_mode TEXT column with default 'prepaid'
-- After applying, set Gemini to postpaid with:
--   UPDATE provider_configs SET billing_mode = 'postpaid' WHERE id = 'gemini';

BEGIN TRANSACTION;

-- Existing rows pick up the DEFAULT ('prepaid') automatically.
ALTER TABLE provider_configs ADD COLUMN billing_mode TEXT DEFAULT 'prepaid';

COMMIT;

-- NOTE: If you use a production SQLite file, run the following to set Gemini to postpaid:
-- sqlite3 /path/to/db.sqlite "UPDATE provider_configs SET billing_mode='postpaid' WHERE id='gemini';"
|
||||
@@ -1,2 +0,0 @@
|
||||
max_width = 120
|
||||
use_field_init_shorthand = true
|
||||
@@ -1,38 +0,0 @@
|
||||
use axum::{extract::FromRequestParts, http::request::Parts};
|
||||
use axum_extra::TypedHeader;
|
||||
use axum_extra::headers::Authorization;
|
||||
use headers::authorization::Bearer;
|
||||
|
||||
use crate::errors::AppError;
|
||||
|
||||
/// A caller authenticated via a bearer token in the `Authorization` header.
pub struct AuthenticatedClient {
    /// The raw bearer token as presented by the caller.
    pub token: String,
    /// Identifier derived from the token prefix (see `from_request_parts`).
    pub client_id: String,
}
|
||||
|
||||
impl<S> FromRequestParts<S> for AuthenticatedClient
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = AppError;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
// Extract bearer token from Authorization header
|
||||
let TypedHeader(Authorization(bearer)) = TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
|
||||
.await
|
||||
.map_err(|_| AppError::AuthError("Missing or invalid bearer token".to_string()))?;
|
||||
|
||||
let token = bearer.token().to_string();
|
||||
|
||||
// Derive client_id from the token prefix
|
||||
let client_id = format!("client_{}", &token[..8.min(token.len())]);
|
||||
|
||||
Ok(AuthenticatedClient { token, client_id })
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true when `token` exactly matches one of `valid_tokens`.
///
/// Simple allow-list membership check; production deployments should use
/// proper token validation (JWT, database lookup, etc.).
pub fn validate_token(token: &str, valid_tokens: &[String]) -> bool {
    valid_tokens.iter().any(|candidate| candidate == token)
}
|
||||
@@ -1,304 +0,0 @@
|
||||
//! Client management for LLM proxy
|
||||
//!
|
||||
//! This module handles:
|
||||
//! 1. Client registration and management
|
||||
//! 2. Client usage tracking
|
||||
//! 3. Client rate limit configuration
|
||||
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{Row, SqlitePool};
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Client information as stored in the `clients` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Client {
    /// Database primary key.
    pub id: i64,
    /// Stable external identifier used by API callers.
    pub client_id: String,
    pub name: String,
    pub description: String,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    /// Inactive clients fail `ClientManager::validate_client`.
    pub is_active: bool,
    pub rate_limit_per_minute: i64,
    // Aggregate usage counters, maintained by `update_client_usage`.
    pub total_requests: i64,
    pub total_tokens: i64,
    pub total_cost: f64,
}

/// Client creation request
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateClientRequest {
    pub client_id: String,
    pub name: String,
    pub description: String,
    /// Optional; defaults to 60 requests/minute when omitted.
    pub rate_limit_per_minute: Option<i64>,
}

/// Client update request; `None` fields are left unchanged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateClientRequest {
    pub name: Option<String>,
    pub description: Option<String>,
    pub is_active: Option<bool>,
    pub rate_limit_per_minute: Option<i64>,
}
|
||||
|
||||
/// Client manager for database operations on the `clients` table.
pub struct ClientManager {
    // Shared SQLite pool; all queries below go through it.
    db_pool: SqlitePool,
}
|
||||
|
||||
impl ClientManager {
    /// Wraps an existing pool; no connections are opened here.
    pub fn new(db_pool: SqlitePool) -> Self {
        Self { db_pool }
    }

    /// Create a new client
    ///
    /// Inserts the row (rate limit defaults to 60/min when not supplied),
    /// then re-reads it so DB-generated fields (id, timestamps, counters)
    /// are populated in the returned `Client`.
    pub async fn create_client(&self, request: CreateClientRequest) -> Result<Client> {
        let rate_limit = request.rate_limit_per_minute.unwrap_or(60);

        // First insert the client
        sqlx::query(
            r#"
            INSERT INTO clients (client_id, name, description, rate_limit_per_minute)
            VALUES (?, ?, ?, ?)
            "#,
        )
        .bind(&request.client_id)
        .bind(&request.name)
        .bind(&request.description)
        .bind(rate_limit)
        .execute(&self.db_pool)
        .await?;

        // Then fetch the created client
        let client = self
            .get_client(&request.client_id)
            .await?
            .ok_or_else(|| anyhow::anyhow!("Failed to retrieve created client"))?;

        info!("Created client: {} ({})", client.name, client.client_id);
        Ok(client)
    }

    /// Get a client by ID
    ///
    /// Returns `Ok(None)` when no row matches `client_id`.
    pub async fn get_client(&self, client_id: &str) -> Result<Option<Client>> {
        let row = sqlx::query(
            r#"
            SELECT
                id, client_id, name, description,
                created_at, updated_at, is_active,
                rate_limit_per_minute, total_requests, total_tokens, total_cost
            FROM clients
            WHERE client_id = ?
            "#,
        )
        .bind(client_id)
        .fetch_optional(&self.db_pool)
        .await?;

        if let Some(row) = row {
            let client = Client {
                id: row.get("id"),
                client_id: row.get("client_id"),
                name: row.get("name"),
                description: row.get("description"),
                created_at: row.get("created_at"),
                updated_at: row.get("updated_at"),
                is_active: row.get("is_active"),
                rate_limit_per_minute: row.get("rate_limit_per_minute"),
                total_requests: row.get("total_requests"),
                total_tokens: row.get("total_tokens"),
                total_cost: row.get("total_cost"),
            };
            Ok(Some(client))
        } else {
            Ok(None)
        }
    }

    /// Update a client
    ///
    /// Applies only the `Some(..)` fields of `request`. Returns `Ok(None)`
    /// when the client does not exist; otherwise the re-fetched client.
    /// When no fields are provided, the existing row is returned unchanged.
    pub async fn update_client(&self, client_id: &str, request: UpdateClientRequest) -> Result<Option<Client>> {
        // First, get the current client to check if it exists
        let current_client = self.get_client(client_id).await?;
        if current_client.is_none() {
            return Ok(None);
        }

        // Build update query dynamically based on provided fields
        let mut query_builder = sqlx::QueryBuilder::new("UPDATE clients SET ");
        let mut has_updates = false;

        if let Some(name) = &request.name {
            query_builder.push("name = ");
            query_builder.push_bind(name);
            has_updates = true;
        }

        if let Some(description) = &request.description {
            // Comma only between assignments, never before the first one.
            if has_updates {
                query_builder.push(", ");
            }
            query_builder.push("description = ");
            query_builder.push_bind(description);
            has_updates = true;
        }

        if let Some(is_active) = request.is_active {
            if has_updates {
                query_builder.push(", ");
            }
            query_builder.push("is_active = ");
            query_builder.push_bind(is_active);
            has_updates = true;
        }

        if let Some(rate_limit) = request.rate_limit_per_minute {
            if has_updates {
                query_builder.push(", ");
            }
            query_builder.push("rate_limit_per_minute = ");
            query_builder.push_bind(rate_limit);
            has_updates = true;
        }

        // Always update the updated_at timestamp
        if has_updates {
            query_builder.push(", ");
        }
        query_builder.push("updated_at = CURRENT_TIMESTAMP");

        if !has_updates {
            // No updates to make
            // (the builder above is discarded without being executed)
            return self.get_client(client_id).await;
        }

        query_builder.push(" WHERE client_id = ");
        query_builder.push_bind(client_id);

        let query = query_builder.build();
        query.execute(&self.db_pool).await?;

        // Fetch the updated client
        let updated_client = self.get_client(client_id).await?;

        if updated_client.is_some() {
            info!("Updated client: {}", client_id);
        }

        Ok(updated_client)
    }

    /// List all clients
    ///
    /// Ordered newest-first; `limit` defaults to 100 and `offset` to 0.
    pub async fn list_clients(&self, limit: Option<i64>, offset: Option<i64>) -> Result<Vec<Client>> {
        let limit = limit.unwrap_or(100);
        let offset = offset.unwrap_or(0);

        let rows = sqlx::query(
            r#"
            SELECT
                id, client_id, name, description,
                created_at, updated_at, is_active,
                rate_limit_per_minute, total_requests, total_tokens, total_cost
            FROM clients
            ORDER BY created_at DESC
            LIMIT ? OFFSET ?
            "#,
        )
        .bind(limit)
        .bind(offset)
        .fetch_all(&self.db_pool)
        .await?;

        let mut clients = Vec::new();
        for row in rows {
            let client = Client {
                id: row.get("id"),
                client_id: row.get("client_id"),
                name: row.get("name"),
                description: row.get("description"),
                created_at: row.get("created_at"),
                updated_at: row.get("updated_at"),
                is_active: row.get("is_active"),
                rate_limit_per_minute: row.get("rate_limit_per_minute"),
                total_requests: row.get("total_requests"),
                total_tokens: row.get("total_tokens"),
                total_cost: row.get("total_cost"),
            };
            clients.push(client);
        }

        Ok(clients)
    }

    /// Delete a client
    ///
    /// Returns true when a row was actually removed.
    pub async fn delete_client(&self, client_id: &str) -> Result<bool> {
        let result = sqlx::query("DELETE FROM clients WHERE client_id = ?")
            .bind(client_id)
            .execute(&self.db_pool)
            .await?;

        let deleted = result.rows_affected() > 0;

        if deleted {
            info!("Deleted client: {}", client_id);
        } else {
            warn!("Client not found for deletion: {}", client_id);
        }

        Ok(deleted)
    }

    /// Update client usage statistics after a request
    ///
    /// Adds one request plus the given token count and cost to the client's
    /// running totals. An unknown client_id updates zero rows (no error).
    pub async fn update_client_usage(&self, client_id: &str, tokens: i64, cost: f64) -> Result<()> {
        sqlx::query(
            r#"
            UPDATE clients
            SET
                total_requests = total_requests + 1,
                total_tokens = total_tokens + ?,
                total_cost = total_cost + ?,
                updated_at = CURRENT_TIMESTAMP
            WHERE client_id = ?
            "#,
        )
        .bind(tokens)
        .bind(cost)
        .bind(client_id)
        .execute(&self.db_pool)
        .await?;

        Ok(())
    }

    /// Get client usage statistics
    ///
    /// Returns `(total_requests, total_tokens, total_cost)`, or `None`
    /// when the client does not exist.
    pub async fn get_client_usage(&self, client_id: &str) -> Result<Option<(i64, i64, f64)>> {
        let row = sqlx::query(
            r#"
            SELECT total_requests, total_tokens, total_cost
            FROM clients
            WHERE client_id = ?
            "#,
        )
        .bind(client_id)
        .fetch_optional(&self.db_pool)
        .await?;

        if let Some(row) = row {
            let total_requests: i64 = row.get("total_requests");
            let total_tokens: i64 = row.get("total_tokens");
            let total_cost: f64 = row.get("total_cost");
            Ok(Some((total_requests, total_tokens, total_cost)))
        } else {
            Ok(None)
        }
    }

    /// Check if a client exists and is active
    pub async fn validate_client(&self, client_id: &str) -> Result<bool> {
        let client = self.get_client(client_id).await?;
        Ok(client.map(|c| c.is_active).unwrap_or(false))
    }
}
|
||||
@@ -1,242 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use config::{Config, File, FileFormat};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// HTTP server settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    pub port: u16,
    pub host: String,
    /// Accepts either a sequence or a comma-separated string.
    #[serde(deserialize_with = "deserialize_vec_or_string")]
    pub auth_tokens: Vec<String>,
}

/// SQLite database settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    pub path: PathBuf,
    pub max_connections: u32,
}

/// Settings for every supported upstream LLM provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
    pub openai: OpenAIConfig,
    pub gemini: GeminiConfig,
    pub deepseek: DeepSeekConfig,
    pub grok: GrokConfig,
    pub ollama: OllamaConfig,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIConfig {
    /// Name of the env var holding the API key (not the key itself).
    pub api_key_env: String,
    pub base_url: String,
    pub default_model: String,
    pub enabled: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeminiConfig {
    /// Name of the env var holding the API key (not the key itself).
    pub api_key_env: String,
    pub base_url: String,
    pub default_model: String,
    pub enabled: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeepSeekConfig {
    /// Name of the env var holding the API key (not the key itself).
    pub api_key_env: String,
    pub base_url: String,
    pub default_model: String,
    pub enabled: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrokConfig {
    /// Name of the env var holding the API key (not the key itself).
    pub api_key_env: String,
    pub base_url: String,
    pub default_model: String,
    pub enabled: bool,
}

/// Ollama is local and keyless, hence no `api_key_env`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaConfig {
    pub base_url: String,
    pub enabled: bool,
    /// Accepts either a sequence or a comma-separated string.
    #[serde(deserialize_with = "deserialize_vec_or_string")]
    pub models: Vec<String>,
}

/// (pattern, target) pairs mapping model names to providers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMappingConfig {
    pub patterns: Vec<(String, String)>,
}

/// Per-provider pricing tables.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PricingConfig {
    pub openai: Vec<ModelPricing>,
    pub gemini: Vec<ModelPricing>,
    pub deepseek: Vec<ModelPricing>,
    pub grok: Vec<ModelPricing>,
    pub ollama: Vec<ModelPricing>,
}

/// Cost per million tokens for a single model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing {
    pub model: String,
    pub prompt_tokens_per_million: f64,
    pub completion_tokens_per_million: f64,
}

/// Fully-resolved application configuration (defaults + file + env).
#[derive(Debug, Clone)]
pub struct AppConfig {
    pub server: ServerConfig,
    pub database: DatabaseConfig,
    pub providers: ProviderConfig,
    pub model_mapping: ModelMappingConfig,
    pub pricing: PricingConfig,
    /// The config file path that was considered (it may not have existed).
    pub config_path: Option<PathBuf>,
}
|
||||
|
||||
impl AppConfig {
    /// Loads configuration using the default path resolution.
    pub async fn load() -> Result<Arc<Self>> {
        Self::load_from_path(None).await
    }

    /// Load configuration from a specific path (for testing)
    ///
    /// Layering, later sources overriding earlier ones:
    /// built-in defaults -> optional TOML file -> `LLM_PROXY__*` env vars.
    pub async fn load_from_path(config_path: Option<PathBuf>) -> Result<Arc<Self>> {
        // Load configuration from multiple sources
        let mut config_builder = Config::builder();

        // Default configuration
        config_builder = config_builder
            .set_default("server.port", 8080)?
            .set_default("server.host", "0.0.0.0")?
            .set_default("server.auth_tokens", Vec::<String>::new())?
            .set_default("database.path", "./data/llm_proxy.db")?
            .set_default("database.max_connections", 10)?
            .set_default("providers.openai.api_key_env", "OPENAI_API_KEY")?
            .set_default("providers.openai.base_url", "https://api.openai.com/v1")?
            .set_default("providers.openai.default_model", "gpt-4o")?
            .set_default("providers.openai.enabled", true)?
            .set_default("providers.gemini.api_key_env", "GEMINI_API_KEY")?
            .set_default(
                "providers.gemini.base_url",
                "https://generativelanguage.googleapis.com/v1",
            )?
            .set_default("providers.gemini.default_model", "gemini-2.0-flash")?
            .set_default("providers.gemini.enabled", true)?
            .set_default("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")?
            .set_default("providers.deepseek.base_url", "https://api.deepseek.com")?
            .set_default("providers.deepseek.default_model", "deepseek-reasoner")?
            .set_default("providers.deepseek.enabled", true)?
            .set_default("providers.grok.api_key_env", "GROK_API_KEY")?
            .set_default("providers.grok.base_url", "https://api.x.ai/v1")?
            .set_default("providers.grok.default_model", "grok-beta")?
            .set_default("providers.grok.enabled", true)?
            .set_default("providers.ollama.base_url", "http://localhost:11434/v1")?
            .set_default("providers.ollama.enabled", false)?
            .set_default("providers.ollama.models", Vec::<String>::new())?;

        // Load from config file if exists
        // Priority: explicit path arg > LLM_PROXY__CONFIG_PATH env var > ./config.toml
        let config_path = config_path
            .or_else(|| std::env::var("LLM_PROXY__CONFIG_PATH").ok().map(PathBuf::from))
            .unwrap_or_else(|| {
                std::env::current_dir()
                    .unwrap_or_else(|_| PathBuf::from("."))
                    .join("config.toml")
            });
        if config_path.exists() {
            config_builder = config_builder.add_source(File::from(config_path.clone()).format(FileFormat::Toml));
        }

        // Load from .env file
        // (best-effort: a missing .env is fine)
        dotenvy::dotenv().ok();

        // Load from environment variables (with prefix "LLM_PROXY_")
        config_builder = config_builder.add_source(
            config::Environment::with_prefix("LLM_PROXY")
                .separator("__")
                .try_parsing(true),
        );

        let config = config_builder.build()?;

        // Deserialize configuration
        let server: ServerConfig = config.get("server")?;
        let database: DatabaseConfig = config.get("database")?;
        let providers: ProviderConfig = config.get("providers")?;

        // For now, use empty model mapping and pricing (will be populated later)
        let model_mapping = ModelMappingConfig { patterns: vec![] };
        let pricing = PricingConfig {
            openai: vec![],
            gemini: vec![],
            deepseek: vec![],
            grok: vec![],
            ollama: vec![],
        };

        Ok(Arc::new(AppConfig {
            server,
            database,
            providers,
            model_mapping,
            pricing,
            config_path: Some(config_path),
        }))
    }

    /// Resolves the API key for `provider` by reading the environment
    /// variable named in its config. Errors for unknown providers or when
    /// the variable is unset. Ollama is intentionally absent (no key).
    pub fn get_api_key(&self, provider: &str) -> Result<String> {
        let env_var = match provider {
            "openai" => &self.providers.openai.api_key_env,
            "gemini" => &self.providers.gemini.api_key_env,
            "deepseek" => &self.providers.deepseek.api_key_env,
            "grok" => &self.providers.grok.api_key_env,
            _ => return Err(anyhow::anyhow!("Unknown provider: {}", provider)),
        };

        std::env::var(env_var).map_err(|_| anyhow::anyhow!("Environment variable {} not set for {}", env_var, provider))
    }
}
|
||||
|
||||
/// Helper function to deserialize a Vec<String> from either a sequence or a comma-separated string
///
/// The string form is split on ',', each piece trimmed, and empty pieces
/// dropped — so "a, b,,c" yields ["a", "b", "c"].
fn deserialize_vec_or_string<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    // Visitor accepting either representation.
    struct VecOrString;

    impl<'de> serde::de::Visitor<'de> for VecOrString {
        type Value = Vec<String>;

        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
            formatter.write_str("a sequence or a comma-separated string")
        }

        // Scalar form: comma-separated string.
        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
        where
            E: serde::de::Error,
        {
            Ok(value
                .split(',')
                .map(|s| s.trim().to_string())
                .filter(|s| !s.is_empty())
                .collect())
        }

        // Sequence form: collect elements as-is.
        fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
        where
            S: serde::de::SeqAccess<'de>,
        {
            let mut vec = Vec::new();
            while let Some(element) = seq.next_element()? {
                vec.push(element);
            }
            Ok(vec)
        }
    }

    // deserialize_any lets the input's own type drive which visit_* runs.
    deserializer.deserialize_any(VecOrString)
}
|
||||
@@ -1,212 +0,0 @@
|
||||
use axum::{extract::State, response::Json};
|
||||
use bcrypt;
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
|
||||
// Authentication handlers

/// Request body for the login endpoint.
#[derive(Deserialize)]
pub(super) struct LoginRequest {
    pub(super) username: String,
    pub(super) password: String,
}
|
||||
|
||||
/// Login handler: verifies username/password against the `users` table and,
/// on success, creates a session and returns its token plus user info.
pub(super) async fn handle_login(
    State(state): State<DashboardState>,
    Json(payload): Json<LoginRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;

    let user_result = sqlx::query(
        "SELECT username, password_hash, display_name, role, must_change_password FROM users WHERE username = ?",
    )
    .bind(&payload.username)
    .fetch_optional(pool)
    .await;

    match user_result {
        Ok(Some(row)) => {
            let hash = row.get::<String, _>("password_hash");
            // bcrypt::verify errors (e.g. malformed hash) count as a failed login.
            if bcrypt::verify(&payload.password, &hash).unwrap_or(false) {
                let username = row.get::<String, _>("username");
                let role = row.get::<String, _>("role");
                // Fall back to the username when no display name is stored.
                let display_name = row
                    .get::<Option<String>, _>("display_name")
                    .unwrap_or_else(|| username.clone());
                let must_change_password = row.get::<bool, _>("must_change_password");
                let token = state
                    .session_manager
                    .create_session(username.clone(), role.clone())
                    .await;
                Json(ApiResponse::success(serde_json::json!({
                    "token": token,
                    "must_change_password": must_change_password,
                    "user": {
                        "username": username,
                        "name": display_name,
                        "role": role
                    }
                })))
            } else {
                Json(ApiResponse::error("Invalid username or password".to_string()))
            }
        }
        // Unknown user: same message as a bad password, avoiding user enumeration.
        Ok(None) => Json(ApiResponse::error("Invalid username or password".to_string())),
        Err(e) => {
            warn!("Database error during login: {}", e);
            Json(ApiResponse::error("Login failed due to system error".to_string()))
        }
    }
}
|
||||
|
||||
/// Auth-status handler: reports whether the bearer token maps to a live
/// session and, if so, returns that session's user info.
pub(super) async fn handle_auth_status(
    State(state): State<DashboardState>,
    headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
    // "Bearer <token>" -> <token>; None when the header is absent/malformed.
    let token = headers
        .get("Authorization")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.strip_prefix("Bearer "));

    if let Some(token) = token
        && let Some(session) = state.session_manager.validate_session(token).await
    {
        // Look up display_name from DB
        // Double flatten: fetch_optional of a nullable column yields
        // Option<Option<String>>; DB errors fall back to the username.
        let display_name = sqlx::query_scalar::<_, Option<String>>(
            "SELECT display_name FROM users WHERE username = ?",
        )
        .bind(&session.username)
        .fetch_optional(&state.app_state.db_pool)
        .await
        .ok()
        .flatten()
        .flatten()
        .unwrap_or_else(|| session.username.clone());

        return Json(ApiResponse::success(serde_json::json!({
            "authenticated": true,
            "user": {
                "username": session.username,
                "name": display_name,
                "role": session.role
            }
        })));
    }

    Json(ApiResponse::error("Not authenticated".to_string()))
}
|
||||
|
||||
/// Request body for the change-password endpoint.
#[derive(Deserialize)]
pub(super) struct ChangePasswordRequest {
    pub(super) current_password: String,
    pub(super) new_password: String,
}
|
||||
|
||||
/// Change-password handler: authenticates via session token, verifies the
/// current password, then stores a bcrypt hash of the new one and clears
/// the `must_change_password` flag.
pub(super) async fn handle_change_password(
    State(state): State<DashboardState>,
    headers: axum::http::HeaderMap,
    Json(payload): Json<ChangePasswordRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;

    // Extract the authenticated user from the session token
    let token = headers
        .get("Authorization")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.strip_prefix("Bearer "));

    let session = match token {
        Some(t) => state.session_manager.validate_session(t).await,
        None => None,
    };

    let username = match session {
        Some(s) => s.username,
        None => return Json(ApiResponse::error("Not authenticated".to_string())),
    };

    let user_result = sqlx::query("SELECT password_hash FROM users WHERE username = ?")
        .bind(&username)
        .fetch_one(pool)
        .await;

    match user_result {
        Ok(row) => {
            let hash = row.get::<String, _>("password_hash");
            // verify errors (e.g. malformed stored hash) are treated as a mismatch.
            if bcrypt::verify(&payload.current_password, &hash).unwrap_or(false) {
                // 12 is the bcrypt cost factor.
                let new_hash = match bcrypt::hash(&payload.new_password, 12) {
                    Ok(h) => h,
                    Err(_) => return Json(ApiResponse::error("Failed to hash new password".to_string())),
                };

                let update_result = sqlx::query(
                    "UPDATE users SET password_hash = ?, must_change_password = FALSE WHERE username = ?",
                )
                .bind(new_hash)
                .bind(&username)
                .execute(pool)
                .await;

                match update_result {
                    Ok(_) => Json(ApiResponse::success(
                        serde_json::json!({ "message": "Password updated successfully" }),
                    )),
                    Err(e) => Json(ApiResponse::error(format!("Failed to update database: {}", e))),
                }
            } else {
                Json(ApiResponse::error("Current password incorrect".to_string()))
            }
        }
        Err(e) => Json(ApiResponse::error(format!("User not found: {}", e))),
    }
}
|
||||
|
||||
pub(super) async fn handle_logout(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let token = headers
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
if let Some(token) = token {
|
||||
state.session_manager.revoke_session(token).await;
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({ "message": "Logged out" })))
|
||||
}
|
||||
|
||||
/// Helper: Extract and validate a session from the Authorization header.
/// Returns the Session if valid, or an error response.
pub(super) async fn extract_session(
    state: &DashboardState,
    headers: &axum::http::HeaderMap,
) -> Result<super::sessions::Session, Json<ApiResponse<serde_json::Value>>> {
    // "Bearer <token>" -> <token>; None when the header is absent/malformed.
    let token = headers
        .get("Authorization")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.strip_prefix("Bearer "));

    match token {
        Some(t) => match state.session_manager.validate_session(t).await {
            Some(session) => Ok(session),
            None => Err(Json(ApiResponse::error("Session expired or invalid".to_string()))),
        },
        None => Err(Json(ApiResponse::error("Not authenticated".to_string()))),
    }
}
|
||||
|
||||
/// Helper: Extract session and require admin role.
/// Returns the session on success; otherwise a ready-to-return error response.
pub(super) async fn require_admin(
    state: &DashboardState,
    headers: &axum::http::HeaderMap,
) -> Result<super::sessions::Session, Json<ApiResponse<serde_json::Value>>> {
    let session = extract_session(state, headers).await?;
    if session.role != "admin" {
        return Err(Json(ApiResponse::error("Admin access required".to_string())));
    }
    Ok(session)
}
|
||||
@@ -1,513 +0,0 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::Json,
|
||||
};
|
||||
use chrono;
|
||||
use rand::Rng;
|
||||
use serde::Deserialize;
|
||||
use serde_json;
|
||||
use sqlx::Row;
|
||||
use tracing::warn;
|
||||
use uuid;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
|
||||
/// Generate a random API token: sk-{48 hex chars}
|
||||
fn generate_token() -> String {
|
||||
let mut rng = rand::rng();
|
||||
let bytes: Vec<u8> = (0..24).map(|_| rng.random::<u8>()).collect();
|
||||
format!("sk-{}", hex::encode(bytes))
|
||||
}
|
||||
|
||||
/// Request body for creating a client.
#[derive(Deserialize)]
pub(super) struct CreateClientRequest {
    pub(super) name: String,
    /// Optional explicit id; a "client-xxxxxxxx" id is generated when omitted.
    pub(super) client_id: Option<String>,
}

/// Request body for updating a client; `None` fields are left unchanged.
#[derive(Deserialize)]
pub(super) struct UpdateClientPayload {
    pub(super) name: Option<String>,
    pub(super) description: Option<String>,
    pub(super) is_active: Option<bool>,
    pub(super) rate_limit_per_minute: Option<i64>,
}
|
||||
|
||||
/// List-clients handler: returns every client with its aggregate usage
/// counters, newest first.
pub(super) async fn handle_get_clients(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;

    let result = sqlx::query(
        r#"
        SELECT
            client_id as id,
            name,
            description,
            created_at,
            total_requests,
            total_tokens,
            total_cost,
            is_active,
            rate_limit_per_minute
        FROM clients
        ORDER BY created_at DESC
        "#,
    )
    .fetch_all(pool)
    .await;

    match result {
        Ok(rows) => {
            // Map each row to the dashboard's JSON shape.
            let clients: Vec<serde_json::Value> = rows
                .into_iter()
                .map(|row| {
                    serde_json::json!({
                        "id": row.get::<String, _>("id"),
                        "name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
                        "description": row.get::<Option<String>, _>("description"),
                        "created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
                        "requests_count": row.get::<i64, _>("total_requests"),
                        "total_tokens": row.get::<i64, _>("total_tokens"),
                        "total_cost": row.get::<f64, _>("total_cost"),
                        "status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
                        "rate_limit_per_minute": row.get::<Option<i64>, _>("rate_limit_per_minute"),
                    })
                })
                .collect();

            Json(ApiResponse::success(serde_json::json!(clients)))
        }
        Err(e) => {
            warn!("Failed to fetch clients: {}", e);
            Json(ApiResponse::error("Failed to fetch clients".to_string()))
        }
    }
}
|
||||
|
||||
/// Create-client handler (admin only): inserts the client row and
/// auto-generates its first API token, which is returned once in the
/// response body.
pub(super) async fn handle_create_client(
    State(state): State<DashboardState>,
    headers: axum::http::HeaderMap,
    Json(payload): Json<CreateClientRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
    if let Err(e) = super::auth::require_admin(&state, &headers).await {
        return e;
    }

    let pool = &state.app_state.db_pool;

    // Use the caller-supplied id, or derive one from a UUID prefix.
    let client_id = payload
        .client_id
        .unwrap_or_else(|| format!("client-{}", &uuid::Uuid::new_v4().to_string()[..8]));

    let result = sqlx::query(
        r#"
        INSERT INTO clients (client_id, name, is_active)
        VALUES (?, ?, TRUE)
        RETURNING *
        "#,
    )
    .bind(&client_id)
    .bind(&payload.name)
    .fetch_one(pool)
    .await;

    match result {
        Ok(row) => {
            // Auto-generate a token for the new client
            let token = generate_token();
            let token_result = sqlx::query(
                "INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, 'default')",
            )
            .bind(&client_id)
            .bind(&token)
            .execute(pool)
            .await;

            // Token creation failure is non-fatal: the client row already exists.
            if let Err(e) = token_result {
                warn!("Client created but failed to generate token: {}", e);
            }

            Json(ApiResponse::success(serde_json::json!({
                "id": row.get::<String, _>("client_id"),
                "name": row.get::<Option<String>, _>("name"),
                "created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
                "status": "active",
                "token": token,
            })))
        }
        Err(e) => {
            warn!("Failed to create client: {}", e);
            Json(ApiResponse::error(format!("Failed to create client: {}", e)))
        }
    }
}
|
||||
|
||||
/// Get-client handler: returns one client, joining `llm_requests` to add a
/// request count and the last-request timestamp.
pub(super) async fn handle_get_client(
    State(state): State<DashboardState>,
    Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;

    let result = sqlx::query(
        r#"
        SELECT
            c.client_id as id,
            c.name,
            c.description,
            c.is_active,
            c.rate_limit_per_minute,
            c.created_at,
            COALESCE(c.total_tokens, 0) as total_tokens,
            COALESCE(c.total_cost, 0.0) as total_cost,
            COUNT(r.id) as total_requests,
            MAX(r.timestamp) as last_request
        FROM clients c
        LEFT JOIN llm_requests r ON c.client_id = r.client_id
        WHERE c.client_id = ?
        GROUP BY c.client_id
        "#,
    )
    .bind(&id)
    .fetch_optional(pool)
    .await;

    match result {
        Ok(Some(row)) => Json(ApiResponse::success(serde_json::json!({
            "id": row.get::<String, _>("id"),
            "name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
            "description": row.get::<Option<String>, _>("description"),
            "is_active": row.get::<bool, _>("is_active"),
            "rate_limit_per_minute": row.get::<Option<i64>, _>("rate_limit_per_minute"),
            "created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
            "total_tokens": row.get::<i64, _>("total_tokens"),
            "total_cost": row.get::<f64, _>("total_cost"),
            "total_requests": row.get::<i64, _>("total_requests"),
            "last_request": row.get::<Option<chrono::DateTime<chrono::Utc>>, _>("last_request"),
            "status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
        }))),
        Ok(None) => Json(ApiResponse::error(format!("Client '{}' not found", id))),
        Err(e) => {
            warn!("Failed to fetch client: {}", e);
            Json(ApiResponse::error(format!("Failed to fetch client: {}", e)))
        }
    }
}
|
||||
|
||||
/// PUT /api/clients/{id} — partial update of a client (admin only).
///
/// Builds the UPDATE statement dynamically from whichever payload fields are
/// present, rejects requests that supply no updatable field, and returns the
/// refreshed row on success. Bind order must mirror the order placeholders
/// were pushed into `sets` — keep the two in sync when editing.
pub(super) async fn handle_update_client(
    State(state): State<DashboardState>,
    headers: axum::http::HeaderMap,
    Path(id): Path<String>,
    Json(payload): Json<UpdateClientPayload>,
) -> Json<ApiResponse<serde_json::Value>> {
    if let Err(e) = super::auth::require_admin(&state, &headers).await {
        return e;
    }

    let pool = &state.app_state.db_pool;

    // Build dynamic UPDATE query from provided fields
    let mut sets = Vec::new();
    // Holds only the string-typed binds (name/description); bool and i64
    // binds are appended later in the same relative order as their `?`s.
    let mut binds: Vec<String> = Vec::new();

    if let Some(ref name) = payload.name {
        sets.push("name = ?");
        binds.push(name.clone());
    }
    if let Some(ref desc) = payload.description {
        sets.push("description = ?");
        binds.push(desc.clone());
    }
    if payload.is_active.is_some() {
        sets.push("is_active = ?");
    }
    if payload.rate_limit_per_minute.is_some() {
        sets.push("rate_limit_per_minute = ?");
    }

    if sets.is_empty() {
        return Json(ApiResponse::error("No fields to update".to_string()));
    }

    // Always update updated_at
    sets.push("updated_at = CURRENT_TIMESTAMP");

    // Only fixed column fragments are interpolated here; all user-supplied
    // values go through bound parameters, so this is not an injection risk.
    let sql = format!("UPDATE clients SET {} WHERE client_id = ?", sets.join(", "));
    let mut query = sqlx::query(&sql);

    // Bind in the same order as sets
    for b in &binds {
        query = query.bind(b);
    }
    if let Some(active) = payload.is_active {
        query = query.bind(active);
    }
    if let Some(rate) = payload.rate_limit_per_minute {
        query = query.bind(rate);
    }
    query = query.bind(&id);

    match query.execute(pool).await {
        Ok(result) => {
            if result.rows_affected() == 0 {
                return Json(ApiResponse::error(format!("Client '{}' not found", id)));
            }
            // Return the updated client
            let row = sqlx::query(
                r#"
                SELECT client_id as id, name, description, is_active, rate_limit_per_minute,
                       created_at, total_requests, total_tokens, total_cost
                FROM clients WHERE client_id = ?
                "#,
            )
            .bind(&id)
            .fetch_one(pool)
            .await;

            match row {
                Ok(row) => Json(ApiResponse::success(serde_json::json!({
                    "id": row.get::<String, _>("id"),
                    "name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
                    "description": row.get::<Option<String>, _>("description"),
                    "is_active": row.get::<bool, _>("is_active"),
                    "rate_limit_per_minute": row.get::<Option<i64>, _>("rate_limit_per_minute"),
                    "created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
                    "total_requests": row.get::<i64, _>("total_requests"),
                    "total_tokens": row.get::<i64, _>("total_tokens"),
                    "total_cost": row.get::<f64, _>("total_cost"),
                    "status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
                }))),
                Err(e) => {
                    warn!("Failed to fetch updated client: {}", e);
                    // Update succeeded but fetch failed — still report success
                    Json(ApiResponse::success(serde_json::json!({ "message": "Client updated" })))
                }
            }
        }
        Err(e) => {
            warn!("Failed to update client: {}", e);
            Json(ApiResponse::error(format!("Failed to update client: {}", e)))
        }
    }
}
|
||||
|
||||
pub(super) async fn handle_delete_client(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(id): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Don't allow deleting the default client
|
||||
if id == "default" {
|
||||
return Json(ApiResponse::error("Cannot delete default client".to_string()));
|
||||
}
|
||||
|
||||
let result = sqlx::query("DELETE FROM clients WHERE client_id = ?")
|
||||
.bind(id)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Client deleted" }))),
|
||||
Err(e) => Json(ApiResponse::error(format!("Failed to delete client: {}", e))),
|
||||
}
|
||||
}
|
||||
|
||||
/// GET /api/clients/{id}/usage — per-model/provider usage breakdown for one
/// client, ordered by total cost descending. Read-only, no admin gate.
pub(super) async fn handle_client_usage(
    State(state): State<DashboardState>,
    Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;

    // Get per-model breakdown for this client
    let result = sqlx::query(
        r#"
        SELECT
            model,
            provider,
            COUNT(*) as request_count,
            SUM(prompt_tokens) as prompt_tokens,
            SUM(completion_tokens) as completion_tokens,
            SUM(total_tokens) as total_tokens,
            SUM(cost) as total_cost,
            AVG(duration_ms) as avg_duration_ms
        FROM llm_requests
        WHERE client_id = ?
        GROUP BY model, provider
        ORDER BY total_cost DESC
        "#,
    )
    .bind(&id)
    .fetch_all(pool)
    .await;

    match result {
        Ok(rows) => {
            // NOTE(review): SUM()/AVG() produce SQL NULL when every aggregated
            // value is NULL; decoding straight into i64/f64 below assumes the
            // token/cost/duration columns are NOT NULL — confirm with schema.
            let breakdown: Vec<serde_json::Value> = rows
                .into_iter()
                .map(|row| {
                    serde_json::json!({
                        "model": row.get::<String, _>("model"),
                        "provider": row.get::<String, _>("provider"),
                        "request_count": row.get::<i64, _>("request_count"),
                        "prompt_tokens": row.get::<i64, _>("prompt_tokens"),
                        "completion_tokens": row.get::<i64, _>("completion_tokens"),
                        "total_tokens": row.get::<i64, _>("total_tokens"),
                        "total_cost": row.get::<f64, _>("total_cost"),
                        "avg_duration_ms": row.get::<f64, _>("avg_duration_ms"),
                    })
                })
                .collect();

            Json(ApiResponse::success(serde_json::json!({
                "client_id": id,
                "breakdown": breakdown,
            })))
        }
        Err(e) => {
            warn!("Failed to fetch client usage: {}", e);
            Json(ApiResponse::error(format!("Failed to fetch client usage: {}", e)))
        }
    }
}
|
||||
|
||||
// ── Token management endpoints ──────────────────────────────────────
|
||||
|
||||
pub(super) async fn handle_get_client_tokens(
|
||||
State(state): State<DashboardState>,
|
||||
Path(id): Path<String>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
SELECT id, token, name, is_active, created_at, last_used_at
|
||||
FROM client_tokens
|
||||
WHERE client_id = ?
|
||||
ORDER BY created_at DESC
|
||||
"#,
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let tokens: Vec<serde_json::Value> = rows
|
||||
.into_iter()
|
||||
.map(|row| {
|
||||
let token: String = row.get("token");
|
||||
// Mask all but last 8 chars: sk-••••abcd1234
|
||||
let masked = if token.len() > 8 {
|
||||
format!("{}••••{}", &token[..3], &token[token.len() - 8..])
|
||||
} else {
|
||||
"••••".to_string()
|
||||
};
|
||||
serde_json::json!({
|
||||
"id": row.get::<i64, _>("id"),
|
||||
"token_masked": masked,
|
||||
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "default".to_string()),
|
||||
"is_active": row.get::<bool, _>("is_active"),
|
||||
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
||||
"last_used_at": row.get::<Option<chrono::DateTime<chrono::Utc>>, _>("last_used_at"),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(tokens)))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch client tokens: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to fetch client tokens: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Request body for POST /api/clients/{id}/tokens.
#[derive(Deserialize)]
pub(super) struct CreateTokenRequest {
    // Optional human-readable label; the handler defaults it to "default".
    pub(super) name: Option<String>,
}
|
||||
|
||||
/// POST /api/clients/{id}/tokens — mint a new API token (admin only).
///
/// Verifies the client exists first, then inserts the token and returns the
/// plaintext exactly once; list endpoints only expose a masked preview later.
pub(super) async fn handle_create_client_token(
    State(state): State<DashboardState>,
    headers: axum::http::HeaderMap,
    Path(id): Path<String>,
    Json(payload): Json<CreateTokenRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
    if let Err(e) = super::auth::require_admin(&state, &headers).await {
        return e;
    }

    let pool = &state.app_state.db_pool;

    // Verify client exists
    // NOTE(review): unwrap_or(None) folds a DB error into "client not found",
    // which can mislead during an outage — consider surfacing the error.
    let exists: Option<(i64,)> = sqlx::query_as("SELECT 1 as x FROM clients WHERE client_id = ?")
        .bind(&id)
        .fetch_optional(pool)
        .await
        .unwrap_or(None);

    if exists.is_none() {
        return Json(ApiResponse::error(format!("Client '{}' not found", id)));
    }

    let token = generate_token();
    let token_name = payload.name.unwrap_or_else(|| "default".to_string());

    let result = sqlx::query(
        "INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, ?) RETURNING id, created_at",
    )
    .bind(&id)
    .bind(&token)
    .bind(&token_name)
    .fetch_one(pool)
    .await;

    match result {
        Ok(row) => Json(ApiResponse::success(serde_json::json!({
            "id": row.get::<i64, _>("id"),
            "token": token,
            "name": token_name,
            "created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
        }))),
        Err(e) => {
            warn!("Failed to create client token: {}", e);
            Json(ApiResponse::error(format!("Failed to create token: {}", e)))
        }
    }
}
|
||||
|
||||
pub(super) async fn handle_delete_client_token(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path((client_id, token_id)): Path<(String, i64)>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let result = sqlx::query("DELETE FROM client_tokens WHERE id = ? AND client_id = ?")
|
||||
.bind(token_id)
|
||||
.bind(&client_id)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) => {
|
||||
if r.rows_affected() == 0 {
|
||||
Json(ApiResponse::error("Token not found".to_string()))
|
||||
} else {
|
||||
Json(ApiResponse::success(serde_json::json!({ "message": "Token revoked" })))
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to delete client token: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to revoke token: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,123 +0,0 @@
|
||||
// Dashboard module for LLM Proxy Gateway
|
||||
|
||||
mod auth;
|
||||
mod clients;
|
||||
mod models;
|
||||
mod providers;
|
||||
pub mod sessions;
|
||||
mod system;
|
||||
mod usage;
|
||||
mod users;
|
||||
mod websocket;
|
||||
|
||||
use axum::{
|
||||
Router,
|
||||
routing::{delete, get, post, put},
|
||||
};
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::state::AppState;
|
||||
use sessions::SessionManager;
|
||||
|
||||
// Dashboard state
/// Shared state handed to every dashboard route handler via axum `State`.
#[derive(Clone)]
struct DashboardState {
    // Application-wide state (db pool, config, provider manager, registry, ...).
    app_state: AppState,
    // In-memory session store used by the auth endpoints.
    session_manager: SessionManager,
}
|
||||
|
||||
// API Response types
/// Uniform JSON envelope for every dashboard API endpoint:
/// `{ "success": bool, "data": ..., "error": ... }`.
/// Exactly one of `data` / `error` is populated, depending on the constructor.
#[derive(Serialize)]
struct ApiResponse<T> {
    success: bool,
    data: Option<T>,
    error: Option<String>,
}

impl<T> ApiResponse<T> {
    /// Build a success envelope wrapping `data`; `error` is `None`.
    fn success(data: T) -> Self {
        Self {
            success: true,
            data: Some(data),
            error: None,
        }
    }

    /// Build a failure envelope carrying `error`; `data` is `None`.
    fn error(error: String) -> Self {
        Self {
            success: false,
            data: None,
            error: Some(error),
        }
    }
}
|
||||
|
||||
// Dashboard routes
/// Build the dashboard `Router`: static file fallback, the live WebSocket
/// feed, and all `/api/*` endpoints, all sharing one `DashboardState`.
/// Write endpoints enforce admin auth inside their handlers, not here.
pub fn router(state: AppState) -> Router {
    let session_manager = SessionManager::new(24); // 24-hour session TTL
    let dashboard_state = DashboardState {
        app_state: state,
        session_manager,
    };

    Router::new()
        // Static file serving
        .fallback_service(tower_http::services::ServeDir::new("static"))
        // WebSocket endpoint
        .route("/ws", get(websocket::handle_websocket))
        // API endpoints
        .route("/api/auth/login", post(auth::handle_login))
        .route("/api/auth/status", get(auth::handle_auth_status))
        .route("/api/auth/logout", post(auth::handle_logout))
        .route("/api/auth/change-password", post(auth::handle_change_password))
        .route(
            "/api/users",
            get(users::handle_get_users).post(users::handle_create_user),
        )
        .route(
            "/api/users/{id}",
            put(users::handle_update_user).delete(users::handle_delete_user),
        )
        .route("/api/usage/summary", get(usage::handle_usage_summary))
        .route("/api/usage/time-series", get(usage::handle_time_series))
        .route("/api/usage/clients", get(usage::handle_clients_usage))
        .route("/api/usage/providers", get(usage::handle_providers_usage))
        .route("/api/usage/detailed", get(usage::handle_detailed_usage))
        .route("/api/analytics/breakdown", get(usage::handle_analytics_breakdown))
        .route("/api/models", get(models::handle_get_models))
        .route("/api/models/{id}", put(models::handle_update_model))
        .route(
            "/api/clients",
            get(clients::handle_get_clients).post(clients::handle_create_client),
        )
        .route(
            "/api/clients/{id}",
            get(clients::handle_get_client)
                .put(clients::handle_update_client)
                .delete(clients::handle_delete_client),
        )
        .route("/api/clients/{id}/usage", get(clients::handle_client_usage))
        .route(
            "/api/clients/{id}/tokens",
            get(clients::handle_get_client_tokens).post(clients::handle_create_client_token),
        )
        .route(
            "/api/clients/{id}/tokens/{token_id}",
            delete(clients::handle_delete_client_token),
        )
        .route("/api/providers", get(providers::handle_get_providers))
        .route(
            "/api/providers/{name}",
            get(providers::handle_get_provider).put(providers::handle_update_provider),
        )
        .route("/api/providers/{name}/test", post(providers::handle_test_provider))
        .route("/api/system/health", get(system::handle_system_health))
        .route("/api/system/metrics", get(system::handle_system_metrics))
        .route("/api/system/logs", get(system::handle_system_logs))
        .route("/api/system/backup", post(system::handle_system_backup))
        .route(
            "/api/system/settings",
            get(system::handle_get_settings).post(system::handle_update_settings),
        )
        .with_state(dashboard_state)
}
|
||||
@@ -1,204 +0,0 @@
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
response::Json,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json;
|
||||
use sqlx::Row;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
use crate::models::registry::{ModelFilter, ModelSortBy, SortOrder};
|
||||
|
||||
/// Request body for PUT /api/models/{id}.
#[derive(Deserialize)]
pub(super) struct UpdateModelRequest {
    // Whether the model is selectable through the proxy.
    pub(super) enabled: bool,
    // Cost overrides (stored per million tokens in model_configs); `None`
    // leaves the registry-provided pricing in effect.
    pub(super) prompt_cost: Option<f64>,
    pub(super) completion_cost: Option<f64>,
    // Optional alias/remap target persisted to model_configs.mapping.
    pub(super) mapping: Option<String>,
}
|
||||
|
||||
/// Query parameters for `GET /api/models`.
///
/// All filters are optional and combine conjunctively; `Default` yields an
/// unfiltered listing with the registry's default sort.
#[derive(Debug, Deserialize, Default)]
pub(super) struct ModelListParams {
    /// Filter by provider ID.
    pub provider: Option<String>,
    /// Text search on model ID or name.
    pub search: Option<String>,
    /// Filter by input modality (e.g. "image").
    pub modality: Option<String>,
    /// Only models that support tool calling.
    pub tool_call: Option<bool>,
    /// Only models that support reasoning.
    pub reasoning: Option<bool>,
    /// Only models that have pricing data.
    pub has_cost: Option<bool>,
    /// Only models that have been used in requests.
    pub used_only: Option<bool>,
    /// Sort field (name, id, provider, context_limit, input_cost, output_cost).
    pub sort_by: Option<ModelSortBy>,
    /// Sort direction (asc, desc).
    pub sort_order: Option<SortOrder>,
}
|
||||
|
||||
/// GET /api/models — list registry models with DB overrides merged in.
///
/// Pipeline: (1) optionally collect the set of models seen in `llm_requests`
/// for the `used_only` filter, (2) filter/sort registry entries, (3) overlay
/// per-model settings stored in `model_configs` (enabled flag, cost
/// overrides, mapping), (4) serialize to the dashboard JSON shape.
pub(super) async fn handle_get_models(
    State(state): State<DashboardState>,
    Query(params): Query<ModelListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
    let registry = &state.app_state.model_registry;
    let pool = &state.app_state.db_pool;

    // If used_only, fetch the set of models that appear in llm_requests
    let used_models: Option<std::collections::HashSet<String>> =
        if params.used_only.unwrap_or(false) {
            match sqlx::query_scalar::<_, String>(
                "SELECT DISTINCT model FROM llm_requests",
            )
            .fetch_all(pool)
            .await
            {
                Ok(models) => Some(models.into_iter().collect()),
                // On DB error fall back to an empty set: used_only then
                // yields no models rather than failing the request.
                Err(_) => Some(std::collections::HashSet::new()),
            }
        } else {
            None
        };

    // Build filter from query params
    let filter = ModelFilter {
        provider: params.provider,
        search: params.search,
        modality: params.modality,
        tool_call: params.tool_call,
        reasoning: params.reasoning,
        has_cost: params.has_cost,
    };
    let sort_by = params.sort_by.unwrap_or_default();
    let sort_order = params.sort_order.unwrap_or_default();

    // Get filtered and sorted model entries
    let entries = registry.list_models(&filter, &sort_by, &sort_order);

    // Load overrides from database
    let db_models_result =
        sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
            .fetch_all(pool)
            .await;

    // Index override rows by model id for O(1) lookup in the merge loop.
    let mut db_models = HashMap::new();
    if let Ok(rows) = db_models_result {
        for row in rows {
            let id: String = row.get("id");
            db_models.insert(id, row);
        }
    }

    let mut models_json = Vec::new();

    for entry in &entries {
        let m_key = entry.model_key;

        // Skip models not in the used set (when used_only is active)
        if let Some(ref used) = used_models {
            if !used.contains(m_key) {
                continue;
            }
        }

        let m_meta = entry.metadata;

        // Registry defaults; DB overrides (if any) are applied below.
        let mut enabled = true;
        let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
        let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
        let cache_read_cost = m_meta.cost.as_ref().and_then(|c| c.cache_read);
        let cache_write_cost = m_meta.cost.as_ref().and_then(|c| c.cache_write);
        let mut mapping = None::<String>;

        if let Some(row) = db_models.get(m_key) {
            enabled = row.get("enabled");
            // NULL cost columns mean "no override" — keep the registry value.
            if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
                prompt_cost = p;
            }
            if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
                completion_cost = c;
            }
            mapping = row.get("mapping");
        }

        models_json.push(serde_json::json!({
            "id": m_key,
            "provider": entry.provider_id,
            "provider_name": entry.provider_name,
            "name": m_meta.name,
            "enabled": enabled,
            "prompt_cost": prompt_cost,
            "completion_cost": completion_cost,
            "cache_read_cost": cache_read_cost,
            "cache_write_cost": cache_write_cost,
            "mapping": mapping,
            "context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
            "output_limit": m_meta.limit.as_ref().map(|l| l.output).unwrap_or(0),
            "modalities": m_meta.modalities.as_ref().map(|m| serde_json::json!({
                "input": m.input,
                "output": m.output,
            })),
            "tool_call": m_meta.tool_call,
            "reasoning": m_meta.reasoning,
        }));
    }

    Json(ApiResponse::success(serde_json::json!(models_json)))
}
|
||||
|
||||
/// PUT /api/models/{id} — upsert a model's override row (admin only).
///
/// Persists enabled flag, cost overrides, and mapping into `model_configs`
/// via INSERT … ON CONFLICT, then invalidates the in-memory config cache so
/// the proxy observes the change immediately.
pub(super) async fn handle_update_model(
    State(state): State<DashboardState>,
    headers: axum::http::HeaderMap,
    Path(id): Path<String>,
    Json(payload): Json<UpdateModelRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
    if let Err(e) = super::auth::require_admin(&state, &headers).await {
        return e;
    }

    let pool = &state.app_state.db_pool;

    // Find provider_id for this model in registry
    // Falls back to "unknown" when the model isn't in any registry provider,
    // so overrides can still be stored for unregistered/custom models.
    let provider_id = state
        .app_state
        .model_registry
        .providers
        .iter()
        .find(|(_, p)| p.models.contains_key(&id))
        .map(|(id, _)| id.clone())
        .unwrap_or_else(|| "unknown".to_string());

    let result = sqlx::query(
        r#"
        INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping)
        VALUES (?, ?, ?, ?, ?, ?)
        ON CONFLICT(id) DO UPDATE SET
            enabled = excluded.enabled,
            prompt_cost_per_m = excluded.prompt_cost_per_m,
            completion_cost_per_m = excluded.completion_cost_per_m,
            mapping = excluded.mapping,
            updated_at = CURRENT_TIMESTAMP
        "#,
    )
    .bind(&id)
    .bind(provider_id)
    .bind(payload.enabled)
    .bind(payload.prompt_cost)
    .bind(payload.completion_cost)
    .bind(payload.mapping)
    .execute(pool)
    .await;

    match result {
        Ok(_) => {
            // Invalidate the in-memory cache so the proxy picks up the change immediately
            state.app_state.model_config_cache.invalidate().await;
            Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" })))
        }
        Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))),
    }
}
|
||||
@@ -1,388 +0,0 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::Json,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde_json;
|
||||
use sqlx::Row;
|
||||
use std::collections::HashMap;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
|
||||
/// Request body for PUT /api/providers/{name}.
///
/// `None` fields are COALESCEd against the existing `provider_configs` row
/// by the upsert in `handle_update_provider`, so omitting a field keeps the
/// stored value (notably: an omitted api_key is never cleared).
#[derive(Deserialize)]
pub(super) struct UpdateProviderRequest {
    pub(super) enabled: bool,
    pub(super) base_url: Option<String>,
    pub(super) api_key: Option<String>,
    pub(super) credit_balance: Option<f64>,
    pub(super) low_credit_threshold: Option<f64>,
    pub(super) billing_mode: Option<String>,
}
|
||||
|
||||
/// GET /api/providers — list all supported providers with config-file
/// defaults, DB overrides, registry model lists, and a live status derived
/// from the provider manager and circuit breaker.
pub(super) async fn handle_get_providers(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
    let registry = &state.app_state.model_registry;
    let config = &state.app_state.config;
    let pool = &state.app_state.db_pool;

    // Load all overrides from database (including billing_mode)
    let db_configs_result = sqlx::query(
        "SELECT id, enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs",
    )
    .fetch_all(pool)
    .await;

    // Index by provider id; DB read failures simply leave config defaults.
    let mut db_configs = HashMap::new();
    if let Ok(rows) = db_configs_result {
        for row in rows {
            let id: String = row.get("id");
            let enabled: bool = row.get("enabled");
            let base_url: Option<String> = row.get("base_url");
            let balance: f64 = row.get("credit_balance");
            let threshold: f64 = row.get("low_credit_threshold");
            let billing_mode: Option<String> = row.get("billing_mode");
            db_configs.insert(id, (enabled, base_url, balance, threshold, billing_mode));
        }
    }

    let mut providers_json = Vec::new();

    // Define the list of providers we support
    let provider_ids = vec!["openai", "gemini", "deepseek", "grok", "ollama"];

    for id in provider_ids {
        // Get base config
        let (mut enabled, mut base_url, display_name) = match id {
            "openai" => (
                config.providers.openai.enabled,
                config.providers.openai.base_url.clone(),
                "OpenAI",
            ),
            "gemini" => (
                config.providers.gemini.enabled,
                config.providers.gemini.base_url.clone(),
                "Google Gemini",
            ),
            "deepseek" => (
                config.providers.deepseek.enabled,
                config.providers.deepseek.base_url.clone(),
                "DeepSeek",
            ),
            "grok" => (
                config.providers.grok.enabled,
                config.providers.grok.base_url.clone(),
                "xAI Grok",
            ),
            "ollama" => (
                config.providers.ollama.enabled,
                config.providers.ollama.base_url.clone(),
                "Ollama",
            ),
            _ => (false, "".to_string(), "Unknown"),
        };

        // Defaults when no provider_configs row exists.
        let mut balance = 0.0;
        let mut threshold = 5.0;
        let mut billing_mode: Option<String> = None;

        // Apply database overrides
        if let Some((db_enabled, db_url, db_balance, db_threshold, db_billing)) = db_configs.get(id) {
            enabled = *db_enabled;
            if let Some(url) = db_url {
                base_url = url.clone();
            }
            balance = *db_balance;
            threshold = *db_threshold;
            billing_mode = db_billing.clone();
        }

        // Find models for this provider in registry
        // NOTE: registry provider IDs differ from internal IDs for some providers.
        let registry_key = match id {
            "gemini" => "google",
            "grok" => "xai",
            _ => id,
        };

        let mut models = Vec::new();
        if let Some(p_info) = registry.providers.get(registry_key) {
            models = p_info.models.keys().cloned().collect();
        } else if id == "ollama" {
            // Ollama models come from the local config, not the registry.
            models = config.providers.ollama.models.clone();
        }

        // Determine status
        let status = if !enabled {
            "disabled"
        } else {
            // Check if it's actually initialized in the provider manager
            if state.app_state.provider_manager.get_provider(id).await.is_some() {
                // Check circuit breaker
                if state
                    .app_state
                    .rate_limit_manager
                    .check_provider_request(id)
                    .await
                    .unwrap_or(true)
                {
                    "online"
                } else {
                    "degraded"
                }
            } else {
                "error" // Enabled but failed to initialize (e.g. missing API key)
            }
        };

        providers_json.push(serde_json::json!({
            "id": id,
            "name": display_name,
            "enabled": enabled,
            "status": status,
            "models": models,
            "base_url": base_url,
            "credit_balance": balance,
            "low_credit_threshold": threshold,
            "billing_mode": billing_mode,
            "last_used": None::<String>,
        }));
    }

    Json(ApiResponse::success(serde_json::json!(providers_json)))
}
|
||||
|
||||
/// GET /api/providers/{name} — single-provider view of the same data shape
/// as `handle_get_providers`: config defaults, DB overrides, registry model
/// list, and live status. Keep the merge logic in sync with the list handler.
pub(super) async fn handle_get_provider(
    State(state): State<DashboardState>,
    Path(name): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
    let registry = &state.app_state.model_registry;
    let config = &state.app_state.config;
    let pool = &state.app_state.db_pool;

    // Validate provider name
    let (mut enabled, mut base_url, display_name) = match name.as_str() {
        "openai" => (
            config.providers.openai.enabled,
            config.providers.openai.base_url.clone(),
            "OpenAI",
        ),
        "gemini" => (
            config.providers.gemini.enabled,
            config.providers.gemini.base_url.clone(),
            "Google Gemini",
        ),
        "deepseek" => (
            config.providers.deepseek.enabled,
            config.providers.deepseek.base_url.clone(),
            "DeepSeek",
        ),
        "grok" => (
            config.providers.grok.enabled,
            config.providers.grok.base_url.clone(),
            "xAI Grok",
        ),
        "ollama" => (
            config.providers.ollama.enabled,
            config.providers.ollama.base_url.clone(),
            "Ollama",
        ),
        _ => return Json(ApiResponse::error(format!("Unknown provider '{}'", name))),
    };

    // Defaults when no provider_configs row exists.
    let mut balance = 0.0;
    let mut threshold = 5.0;
    let mut billing_mode: Option<String> = None;

    // Apply database overrides
    let db_config = sqlx::query(
        "SELECT enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs WHERE id = ?",
    )
    .bind(&name)
    .fetch_optional(pool)
    .await;

    if let Ok(Some(row)) = db_config {
        enabled = row.get::<bool, _>("enabled");
        if let Some(url) = row.get::<Option<String>, _>("base_url") {
            base_url = url;
        }
        balance = row.get::<f64, _>("credit_balance");
        threshold = row.get::<f64, _>("low_credit_threshold");
        billing_mode = row.get::<Option<String>, _>("billing_mode");
    }

    // Find models for this provider
    // NOTE: registry provider IDs differ from internal IDs for some providers.
    let registry_key = match name.as_str() {
        "gemini" => "google",
        "grok" => "xai",
        _ => name.as_str(),
    };

    let mut models = Vec::new();
    if let Some(p_info) = registry.providers.get(registry_key) {
        models = p_info.models.keys().cloned().collect();
    } else if name == "ollama" {
        // Ollama models come from the local config, not the registry.
        models = config.providers.ollama.models.clone();
    }

    // Determine status: disabled → "disabled"; initialized + circuit breaker
    // open → "online"; breaker tripped → "degraded"; not initialized → "error".
    let status = if !enabled {
        "disabled"
    } else if state.app_state.provider_manager.get_provider(&name).await.is_some() {
        if state
            .app_state
            .rate_limit_manager
            .check_provider_request(&name)
            .await
            .unwrap_or(true)
        {
            "online"
        } else {
            "degraded"
        }
    } else {
        "error"
    };

    Json(ApiResponse::success(serde_json::json!({
        "id": name,
        "name": display_name,
        "enabled": enabled,
        "status": status,
        "models": models,
        "base_url": base_url,
        "credit_balance": balance,
        "low_credit_threshold": threshold,
        "billing_mode": billing_mode,
        "last_used": None::<String>,
    })))
}
|
||||
|
||||
/// PUT /api/providers/{name} — persist provider settings and re-initialize
/// the live provider (admin only).
///
/// The upsert COALESCEs nullable fields against the existing row, so omitted
/// values (notably the API key) are preserved rather than cleared. After a
/// successful save the provider is re-initialized; a failure there is
/// reported to the caller even though the settings were stored.
pub(super) async fn handle_update_provider(
    State(state): State<DashboardState>,
    headers: axum::http::HeaderMap,
    Path(name): Path<String>,
    Json(payload): Json<UpdateProviderRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
    if let Err(e) = super::auth::require_admin(&state, &headers).await {
        return e;
    }

    let pool = &state.app_state.db_pool;

    // Update or insert into database (include billing_mode)
    let result = sqlx::query(
        r#"
        INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        ON CONFLICT(id) DO UPDATE SET
            enabled = excluded.enabled,
            base_url = excluded.base_url,
            api_key = COALESCE(excluded.api_key, provider_configs.api_key),
            credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
            low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
            billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
            updated_at = CURRENT_TIMESTAMP
        "#,
    )
    .bind(&name)
    .bind(name.to_uppercase())
    .bind(payload.enabled)
    .bind(&payload.base_url)
    .bind(&payload.api_key)
    .bind(payload.credit_balance)
    .bind(payload.low_credit_threshold)
    .bind(payload.billing_mode)
    .execute(pool)
    .await;

    match result {
        Ok(_) => {
            // Re-initialize provider in manager
            if let Err(e) = state
                .app_state
                .provider_manager
                .initialize_provider(&name, &state.app_state.config, &state.app_state.db_pool)
                .await
            {
                warn!("Failed to re-initialize provider {}: {}", name, e);
                return Json(ApiResponse::error(format!(
                    "Provider settings saved but initialization failed: {}",
                    e
                )));
            }

            Json(ApiResponse::success(
                serde_json::json!({ "message": "Provider updated and re-initialized" }),
            ))
        }
        Err(e) => {
            warn!("Failed to update provider config: {}", e);
            Json(ApiResponse::error(format!("Failed to update provider: {}", e)))
        }
    }
}
|
||||
|
||||
/// Live connectivity test for a provider: send a minimal one-message
/// chat completion and report round-trip latency in milliseconds.
pub(super) async fn handle_test_provider(
    State(state): State<DashboardState>,
    Path(name): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
    let start = std::time::Instant::now();

    // The provider must already be registered with the manager
    // (i.e. enabled and successfully initialized).
    let provider = match state.app_state.provider_manager.get_provider(&name).await {
        Some(p) => p,
        None => {
            return Json(ApiResponse::error(format!(
                "Provider '{}' not found or not enabled",
                name
            )));
        }
    };

    // Pick a real model for this provider from the registry
    // NOTE: registry provider IDs differ from internal IDs for some providers.
    let registry_key = match name.as_str() {
        "gemini" => "google",
        "grok" => "xai",
        _ => name.as_str(),
    };

    // First registered model, falling back to the provider name itself
    // (the downstream call will then surface a clear model error).
    let test_model = state
        .app_state
        .model_registry
        .providers
        .get(registry_key)
        .and_then(|p| p.models.keys().next().cloned())
        .unwrap_or_else(|| name.clone());

    // Minimal, cheap probe: single "Hi" message, max_tokens capped at 5.
    let test_request = crate::models::UnifiedRequest {
        client_id: "system-test".to_string(),
        model: test_model,
        messages: vec![crate::models::UnifiedMessage {
            role: "user".to_string(),
            content: vec![crate::models::ContentPart::Text { text: "Hi".to_string() }],
            tool_calls: None,
            name: None,
            tool_call_id: None,
        }],
        temperature: None,
        max_tokens: Some(5),
        stream: false,
        has_images: false,
        tools: None,
        tool_choice: None,
    };

    match provider.chat_completion(test_request).await {
        Ok(_) => {
            let latency = start.elapsed().as_millis();
            Json(ApiResponse::success(serde_json::json!({
                "success": true,
                "latency": latency,
                "message": "Connection test successful"
            })))
        }
        Err(e) => Json(ApiResponse::error(format!("Provider test failed: {}", e))),
    }
}
|
||||
@@ -1,64 +0,0 @@
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// An authenticated dashboard session held in memory by `SessionManager`.
#[derive(Clone, Debug)]
pub struct Session {
    // Account the session belongs to.
    pub username: String,
    // Authorization role string (presumably e.g. "admin" — checked by
    // admin-only endpoints elsewhere; confirm against auth module).
    pub role: String,
    // Creation instant (UTC).
    pub created_at: DateTime<Utc>,
    // Hard expiry; sessions at or past this instant fail validation.
    pub expires_at: DateTime<Utc>,
}
|
||||
|
||||
/// In-memory, async-safe session store. `Clone` is cheap: clones share
/// the same underlying map via `Arc`.
#[derive(Clone)]
pub struct SessionManager {
    // token -> Session; guarded by an async RwLock for concurrent handlers.
    sessions: Arc<RwLock<HashMap<String, Session>>>,
    // Lifetime applied to each session at creation time, in hours.
    ttl_hours: i64,
}
|
||||
|
||||
impl SessionManager {
|
||||
pub fn new(ttl_hours: i64) -> Self {
|
||||
Self {
|
||||
sessions: Arc::new(RwLock::new(HashMap::new())),
|
||||
ttl_hours,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new session and return the session token.
|
||||
pub async fn create_session(&self, username: String, role: String) -> String {
|
||||
let token = format!("session-{}", uuid::Uuid::new_v4());
|
||||
let now = Utc::now();
|
||||
let session = Session {
|
||||
username,
|
||||
role,
|
||||
created_at: now,
|
||||
expires_at: now + Duration::hours(self.ttl_hours),
|
||||
};
|
||||
self.sessions.write().await.insert(token.clone(), session);
|
||||
token
|
||||
}
|
||||
|
||||
/// Validate a session token and return the session if valid and not expired.
|
||||
pub async fn validate_session(&self, token: &str) -> Option<Session> {
|
||||
let sessions = self.sessions.read().await;
|
||||
sessions.get(token).and_then(|s| {
|
||||
if s.expires_at > Utc::now() {
|
||||
Some(s.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Revoke (delete) a session by token.
|
||||
pub async fn revoke_session(&self, token: &str) {
|
||||
self.sessions.write().await.remove(token);
|
||||
}
|
||||
|
||||
/// Remove all expired sessions from the store.
|
||||
pub async fn cleanup_expired(&self) {
|
||||
let now = Utc::now();
|
||||
self.sessions.write().await.retain(|_, s| s.expires_at > now);
|
||||
}
|
||||
}
|
||||
@@ -1,365 +0,0 @@
|
||||
use axum::{extract::State, response::Json};
|
||||
use chrono;
|
||||
use serde_json;
|
||||
use sqlx::Row;
|
||||
use std::collections::HashMap;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
|
||||
/// Best-effort read of an entire file (typically under /proc) into a
/// `String`. Any failure — missing file, permissions, invalid UTF-8 —
/// is collapsed into `None` so callers can chain with `and_then`.
fn read_proc_file(path: &str) -> Option<String> {
    match std::fs::read_to_string(path) {
        Ok(contents) => Some(contents),
        Err(_) => None,
    }
}
|
||||
|
||||
pub(super) async fn handle_system_health(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let mut components = HashMap::new();
|
||||
components.insert("api_server".to_string(), "online".to_string());
|
||||
components.insert("database".to_string(), "online".to_string());
|
||||
|
||||
// Check provider health via circuit breakers
|
||||
let provider_ids: Vec<String> = state
|
||||
.app_state
|
||||
.provider_manager
|
||||
.get_all_providers()
|
||||
.await
|
||||
.iter()
|
||||
.map(|p| p.name().to_string())
|
||||
.collect();
|
||||
|
||||
for p_id in provider_ids {
|
||||
if state
|
||||
.app_state
|
||||
.rate_limit_manager
|
||||
.check_provider_request(&p_id)
|
||||
.await
|
||||
.unwrap_or(true)
|
||||
{
|
||||
components.insert(p_id, "online".to_string());
|
||||
} else {
|
||||
components.insert(p_id, "degraded".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Read real memory usage from /proc/self/status
|
||||
let memory_mb = read_proc_file("/proc/self/status")
|
||||
.and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string()))
|
||||
.and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::<f64>().ok()))
|
||||
.map(|kb| kb / 1024.0)
|
||||
.unwrap_or(0.0);
|
||||
|
||||
// Get real database pool stats
|
||||
let db_pool_size = state.app_state.db_pool.size();
|
||||
let db_pool_idle = state.app_state.db_pool.num_idle();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"status": "healthy",
|
||||
"timestamp": chrono::Utc::now().to_rfc3339(),
|
||||
"components": components,
|
||||
"metrics": {
|
||||
"memory_usage_mb": (memory_mb * 10.0).round() / 10.0,
|
||||
"db_connections_active": db_pool_size - db_pool_idle as u32,
|
||||
"db_connections_idle": db_pool_idle,
|
||||
}
|
||||
})))
|
||||
}
|
||||
|
||||
/// Real system metrics from /proc (Linux only).
|
||||
pub(super) async fn handle_system_metrics(
|
||||
State(state): State<DashboardState>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
// --- CPU usage (aggregate across all cores) ---
|
||||
// /proc/stat first line: cpu user nice system idle iowait irq softirq steal guest guest_nice
|
||||
let cpu_percent = read_proc_file("/proc/stat")
|
||||
.and_then(|s| {
|
||||
let line = s.lines().find(|l| l.starts_with("cpu "))?.to_string();
|
||||
let fields: Vec<u64> = line
|
||||
.split_whitespace()
|
||||
.skip(1)
|
||||
.filter_map(|v| v.parse().ok())
|
||||
.collect();
|
||||
if fields.len() >= 4 {
|
||||
let idle = fields[3];
|
||||
let total: u64 = fields.iter().sum();
|
||||
if total > 0 {
|
||||
Some(((total - idle) as f64 / total as f64 * 100.0 * 10.0).round() / 10.0)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or(0.0);
|
||||
|
||||
// --- Memory (system-wide from /proc/meminfo) ---
|
||||
let meminfo = read_proc_file("/proc/meminfo").unwrap_or_default();
|
||||
let parse_meminfo = |key: &str| -> u64 {
|
||||
meminfo
|
||||
.lines()
|
||||
.find(|l| l.starts_with(key))
|
||||
.and_then(|l| l.split_whitespace().nth(1))
|
||||
.and_then(|v| v.parse::<u64>().ok())
|
||||
.unwrap_or(0)
|
||||
};
|
||||
let mem_total_kb = parse_meminfo("MemTotal:");
|
||||
let mem_available_kb = parse_meminfo("MemAvailable:");
|
||||
let mem_used_kb = mem_total_kb.saturating_sub(mem_available_kb);
|
||||
let mem_total_mb = mem_total_kb as f64 / 1024.0;
|
||||
let mem_used_mb = mem_used_kb as f64 / 1024.0;
|
||||
let mem_percent = if mem_total_kb > 0 {
|
||||
(mem_used_kb as f64 / mem_total_kb as f64 * 100.0 * 10.0).round() / 10.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// --- Process-specific memory (VmRSS) ---
|
||||
let process_rss_mb = read_proc_file("/proc/self/status")
|
||||
.and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string()))
|
||||
.and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::<f64>().ok()))
|
||||
.map(|kb| (kb / 1024.0 * 10.0).round() / 10.0)
|
||||
.unwrap_or(0.0);
|
||||
|
||||
// --- Disk usage of the data directory ---
|
||||
let (disk_total_gb, disk_used_gb, disk_percent) = {
|
||||
// statvfs via libc would be ideal; use df as a simple fallback
|
||||
std::process::Command::new("df")
|
||||
.args(["-BM", "--output=size,used,pcent", "."])
|
||||
.output()
|
||||
.ok()
|
||||
.and_then(|o| {
|
||||
let out = String::from_utf8_lossy(&o.stdout);
|
||||
let line = out.lines().nth(1)?.to_string();
|
||||
let parts: Vec<&str> = line.split_whitespace().collect();
|
||||
if parts.len() >= 3 {
|
||||
let total = parts[0].trim_end_matches('M').parse::<f64>().unwrap_or(0.0) / 1024.0;
|
||||
let used = parts[1].trim_end_matches('M').parse::<f64>().unwrap_or(0.0) / 1024.0;
|
||||
let pct = parts[2].trim_end_matches('%').parse::<f64>().unwrap_or(0.0);
|
||||
Some(((total * 10.0).round() / 10.0, (used * 10.0).round() / 10.0, pct))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or((0.0, 0.0, 0.0))
|
||||
};
|
||||
|
||||
// --- Uptime ---
|
||||
let uptime_seconds = read_proc_file("/proc/uptime")
|
||||
.and_then(|s| s.split_whitespace().next().and_then(|v| v.parse::<f64>().ok()))
|
||||
.unwrap_or(0.0) as u64;
|
||||
|
||||
// --- Load average ---
|
||||
let (load_1, load_5, load_15) = read_proc_file("/proc/loadavg")
|
||||
.and_then(|s| {
|
||||
let parts: Vec<&str> = s.split_whitespace().collect();
|
||||
if parts.len() >= 3 {
|
||||
Some((
|
||||
parts[0].parse::<f64>().unwrap_or(0.0),
|
||||
parts[1].parse::<f64>().unwrap_or(0.0),
|
||||
parts[2].parse::<f64>().unwrap_or(0.0),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or((0.0, 0.0, 0.0));
|
||||
|
||||
// --- Network (from /proc/net/dev, aggregate non-lo interfaces) ---
|
||||
let (net_rx_bytes, net_tx_bytes) = read_proc_file("/proc/net/dev")
|
||||
.map(|s| {
|
||||
s.lines()
|
||||
.skip(2) // skip header lines
|
||||
.filter(|l| !l.trim().starts_with("lo:"))
|
||||
.fold((0u64, 0u64), |(rx, tx), line| {
|
||||
let parts: Vec<&str> = line.split_whitespace().collect();
|
||||
if parts.len() >= 10 {
|
||||
let r = parts[1].parse::<u64>().unwrap_or(0);
|
||||
let t = parts[9].parse::<u64>().unwrap_or(0);
|
||||
(rx + r, tx + t)
|
||||
} else {
|
||||
(rx, tx)
|
||||
}
|
||||
})
|
||||
})
|
||||
.unwrap_or((0, 0));
|
||||
|
||||
// --- Database pool ---
|
||||
let db_pool_size = state.app_state.db_pool.size();
|
||||
let db_pool_idle = state.app_state.db_pool.num_idle();
|
||||
|
||||
// --- Active WebSocket listeners ---
|
||||
let ws_listeners = state.app_state.dashboard_tx.receiver_count();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"cpu": {
|
||||
"usage_percent": cpu_percent,
|
||||
"load_average": [load_1, load_5, load_15],
|
||||
},
|
||||
"memory": {
|
||||
"total_mb": (mem_total_mb * 10.0).round() / 10.0,
|
||||
"used_mb": (mem_used_mb * 10.0).round() / 10.0,
|
||||
"usage_percent": mem_percent,
|
||||
"process_rss_mb": process_rss_mb,
|
||||
},
|
||||
"disk": {
|
||||
"total_gb": disk_total_gb,
|
||||
"used_gb": disk_used_gb,
|
||||
"usage_percent": disk_percent,
|
||||
},
|
||||
"network": {
|
||||
"rx_bytes": net_rx_bytes,
|
||||
"tx_bytes": net_tx_bytes,
|
||||
},
|
||||
"uptime_seconds": uptime_seconds,
|
||||
"connections": {
|
||||
"db_active": db_pool_size - db_pool_idle as u32,
|
||||
"db_idle": db_pool_idle,
|
||||
"websocket_listeners": ws_listeners,
|
||||
},
|
||||
"timestamp": chrono::Utc::now().to_rfc3339(),
|
||||
})))
|
||||
}
|
||||
|
||||
pub(super) async fn handle_system_logs(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
id,
|
||||
timestamp,
|
||||
client_id,
|
||||
provider,
|
||||
model,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
cost,
|
||||
status,
|
||||
error_message,
|
||||
duration_ms
|
||||
FROM llm_requests
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 100
|
||||
"#,
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let logs: Vec<serde_json::Value> = rows
|
||||
.into_iter()
|
||||
.map(|row| {
|
||||
serde_json::json!({
|
||||
"id": row.get::<i64, _>("id"),
|
||||
"timestamp": row.get::<chrono::DateTime<chrono::Utc>, _>("timestamp"),
|
||||
"client_id": row.get::<String, _>("client_id"),
|
||||
"provider": row.get::<String, _>("provider"),
|
||||
"model": row.get::<String, _>("model"),
|
||||
"tokens": row.get::<i64, _>("total_tokens"),
|
||||
"cost": row.get::<f64, _>("cost"),
|
||||
"status": row.get::<String, _>("status"),
|
||||
"error": row.get::<Option<String>, _>("error_message"),
|
||||
"duration": row.get::<i64, _>("duration_ms"),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(logs)))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch system logs: {}", e);
|
||||
Json(ApiResponse::error("Failed to fetch system logs".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an on-disk backup of the SQLite database (admin-only).
///
/// Uses `VACUUM INTO`, which writes a consistent snapshot to
/// `data/backup-<unix-ts>.db` without taking the database offline.
pub(super) async fn handle_system_backup(
    State(state): State<DashboardState>,
    headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
    if let Err(e) = super::auth::require_admin(&state, &headers).await {
        return e;
    }

    let pool = &state.app_state.db_pool;
    // Second-resolution timestamp id; two backups within the same second
    // would collide (VACUUM INTO then fails on the existing file).
    let backup_id = format!("backup-{}", chrono::Utc::now().timestamp());
    let backup_path = format!("data/{}.db", backup_id);

    // Ensure the data directory exists
    if let Err(e) = std::fs::create_dir_all("data") {
        return Json(ApiResponse::error(format!("Failed to create backup directory: {}", e)));
    }

    // Use SQLite VACUUM INTO for a consistent backup.
    // The path is interpolated into the SQL string; acceptable only because
    // it is entirely server-generated (no user input) — keep it that way.
    let result = sqlx::query(&format!("VACUUM INTO '{}'", backup_path))
        .execute(pool)
        .await;

    match result {
        Ok(_) => {
            // Get backup file size (best effort; reported as 0 on failure)
            let size_bytes = std::fs::metadata(&backup_path).map(|m| m.len()).unwrap_or(0);

            Json(ApiResponse::success(serde_json::json!({
                "success": true,
                "message": "Backup completed successfully",
                "backup_id": backup_id,
                "backup_path": backup_path,
                "size_bytes": size_bytes,
            })))
        }
        Err(e) => {
            warn!("Database backup failed: {}", e);
            Json(ApiResponse::error(format!("Backup failed: {}", e)))
        }
    }
}
|
||||
|
||||
pub(super) async fn handle_get_settings(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let registry = &state.app_state.model_registry;
|
||||
let provider_count = registry.providers.len();
|
||||
let model_count: usize = registry.providers.values().map(|p| p.models.len()).sum();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"server": {
|
||||
"auth_tokens": state.app_state.auth_tokens.iter().map(|t| mask_token(t)).collect::<Vec<_>>(),
|
||||
"version": env!("CARGO_PKG_VERSION"),
|
||||
},
|
||||
"registry": {
|
||||
"provider_count": provider_count,
|
||||
"model_count": model_count,
|
||||
},
|
||||
"database": {
|
||||
"type": "SQLite",
|
||||
}
|
||||
})))
|
||||
}
|
||||
|
||||
pub(super) async fn handle_update_settings(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = super::auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
|
||||
Json(ApiResponse::error(
|
||||
"Changing settings at runtime is not yet supported. Please update your config file and restart the server."
|
||||
.to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
/// Mask an auth token for display: show only its last 4 characters,
/// preceded by asterisks, with the rendered width capped at 12.
/// Tokens of 8 characters or fewer are fully masked ("*****") so short
/// secrets never leak a meaningful suffix.
fn mask_token(token: &str) -> String {
    // Work on characters, not bytes: the original byte slice
    // `&token[token.len() - 4..]` panics when the cut falls inside a
    // multi-byte UTF-8 character. For ASCII tokens the output is identical.
    let chars: Vec<char> = token.chars().collect();
    if chars.len() <= 8 {
        return "*****".to_string();
    }

    let visible_len = 4;
    // Cap total rendered width so very long tokens stay compact.
    let masked_len = chars.len().min(12);
    let mask_len = masked_len - visible_len;

    let suffix: String = chars[chars.len() - visible_len..].iter().collect();
    format!("{}{}", "*".repeat(mask_len), suffix)
}
|
||||
@@ -1,482 +0,0 @@
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
response::Json,
|
||||
};
|
||||
use chrono;
|
||||
use serde::Deserialize;
|
||||
use serde_json;
|
||||
use sqlx::Row;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{ApiResponse, DashboardState};
|
||||
|
||||
/// Query parameters for time-based filtering on usage endpoints.
#[derive(Debug, Deserialize, Default)]
pub(super) struct UsagePeriodFilter {
    /// Preset period: "today", "24h", "7d", "30d", "all" (default: "all").
    /// The special value "custom" selects the explicit `from`/`to` range.
    pub period: Option<String>,
    /// Custom range start (ISO 8601, e.g. "2025-06-01T00:00:00Z");
    /// only honored when `period` is "custom".
    pub from: Option<String>,
    /// Custom range end (ISO 8601); only honored when `period` is "custom".
    pub to: Option<String>,
}
|
||||
|
||||
impl UsagePeriodFilter {
|
||||
/// Returns `(sql_fragment, bind_values)` for a WHERE clause.
|
||||
/// The fragment is either empty (no filter) or " AND timestamp >= ? [AND timestamp <= ?]".
|
||||
fn to_sql(&self) -> (String, Vec<String>) {
|
||||
let period = self.period.as_deref().unwrap_or("all");
|
||||
|
||||
if period == "custom" {
|
||||
let mut clause = String::new();
|
||||
let mut binds = Vec::new();
|
||||
if let Some(ref from) = self.from {
|
||||
clause.push_str(" AND timestamp >= ?");
|
||||
binds.push(from.clone());
|
||||
}
|
||||
if let Some(ref to) = self.to {
|
||||
clause.push_str(" AND timestamp <= ?");
|
||||
binds.push(to.clone());
|
||||
}
|
||||
return (clause, binds);
|
||||
}
|
||||
|
||||
let now = chrono::Utc::now();
|
||||
let cutoff = match period {
|
||||
"today" => {
|
||||
// Start of today (UTC)
|
||||
let today = now.format("%Y-%m-%dT00:00:00Z").to_string();
|
||||
Some(today)
|
||||
}
|
||||
"24h" => Some((now - chrono::Duration::hours(24)).to_rfc3339()),
|
||||
"7d" => Some((now - chrono::Duration::days(7)).to_rfc3339()),
|
||||
"30d" => Some((now - chrono::Duration::days(30)).to_rfc3339()),
|
||||
_ => None, // "all" or unrecognized
|
||||
};
|
||||
|
||||
match cutoff {
|
||||
Some(ts) => (" AND timestamp >= ?".to_string(), vec![ts]),
|
||||
None => (String::new(), vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine the time-series granularity label for grouping.
|
||||
fn granularity(&self) -> &'static str {
|
||||
match self.period.as_deref().unwrap_or("all") {
|
||||
"today" | "24h" => "hour",
|
||||
_ => "day",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_usage_summary(
|
||||
State(state): State<DashboardState>,
|
||||
Query(filter): Query<UsagePeriodFilter>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
let (period_clause, period_binds) = filter.to_sql();
|
||||
|
||||
// Total stats (filtered by period)
|
||||
let period_sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(total_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(cost), 0.0) as total_cost,
|
||||
COUNT(DISTINCT client_id) as active_clients,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read,
|
||||
COALESCE(SUM(cache_write_tokens), 0) as total_cache_write
|
||||
FROM llm_requests
|
||||
WHERE 1=1 {}
|
||||
"#,
|
||||
period_clause
|
||||
);
|
||||
let mut q = sqlx::query(&period_sql);
|
||||
for b in &period_binds {
|
||||
q = q.bind(b);
|
||||
}
|
||||
let total_stats = q.fetch_one(pool);
|
||||
|
||||
// Today's stats
|
||||
let today = chrono::Utc::now().format("%Y-%m-%d").to_string();
|
||||
let today_stats = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
COUNT(*) as today_requests,
|
||||
COALESCE(SUM(total_tokens), 0) as today_tokens,
|
||||
COALESCE(SUM(cost), 0.0) as today_cost
|
||||
FROM llm_requests
|
||||
WHERE strftime('%Y-%m-%d', timestamp) = ?
|
||||
"#,
|
||||
)
|
||||
.bind(today)
|
||||
.fetch_one(pool);
|
||||
|
||||
// Error stats
|
||||
let error_stats = sqlx::query(
|
||||
r#"
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as errors
|
||||
FROM llm_requests
|
||||
"#,
|
||||
)
|
||||
.fetch_one(pool);
|
||||
|
||||
// Average response time
|
||||
let avg_response = sqlx::query(
|
||||
r#"
|
||||
SELECT COALESCE(AVG(duration_ms), 0.0) as avg_duration
|
||||
FROM llm_requests
|
||||
WHERE status = 'success'
|
||||
"#,
|
||||
)
|
||||
.fetch_one(pool);
|
||||
|
||||
match tokio::join!(total_stats, today_stats, error_stats, avg_response) {
|
||||
(Ok(t), Ok(d), Ok(e), Ok(a)) => {
|
||||
let total_requests: i64 = t.get("total_requests");
|
||||
let total_tokens: i64 = t.get("total_tokens");
|
||||
let total_cost: f64 = t.get("total_cost");
|
||||
let active_clients: i64 = t.get("active_clients");
|
||||
let total_cache_read: i64 = t.get("total_cache_read");
|
||||
let total_cache_write: i64 = t.get("total_cache_write");
|
||||
|
||||
let today_requests: i64 = d.get("today_requests");
|
||||
let today_cost: f64 = d.get("today_cost");
|
||||
|
||||
let total_count: i64 = e.get("total");
|
||||
let error_count: i64 = e.get("errors");
|
||||
let error_rate = if total_count > 0 {
|
||||
(error_count as f64 / total_count as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let avg_response_time: f64 = a.get("avg_duration");
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"total_requests": total_requests,
|
||||
"total_tokens": total_tokens,
|
||||
"total_cost": total_cost,
|
||||
"active_clients": active_clients,
|
||||
"today_requests": today_requests,
|
||||
"today_cost": today_cost,
|
||||
"error_rate": error_rate,
|
||||
"avg_response_time": avg_response_time,
|
||||
"total_cache_read_tokens": total_cache_read,
|
||||
"total_cache_write_tokens": total_cache_write,
|
||||
})))
|
||||
}
|
||||
_ => Json(ApiResponse::error("Failed to fetch usage statistics".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Bucketed time series of requests/tokens/cost for charting.
/// Buckets are hourly for "today"/"24h" periods and daily otherwise;
/// without an explicit period a default lookback window is applied
/// (24h for hourly buckets, 30d for daily).
pub(super) async fn handle_time_series(
    State(state): State<DashboardState>,
    Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;
    let (period_clause, period_binds) = filter.to_sql();
    let granularity = filter.granularity();

    // Determine the strftime format and default lookback
    let (strftime_fmt, _label_key, default_lookback) = match granularity {
        "hour" => ("%H:00", "hour", chrono::Duration::hours(24)),
        _ => ("%Y-%m-%d", "day", chrono::Duration::days(30)),
    };

    // If no period filter, apply a sensible default lookback
    let (clause, binds) = if period_clause.is_empty() {
        let cutoff = (chrono::Utc::now() - default_lookback).to_rfc3339();
        (" AND timestamp >= ?".to_string(), vec![cutoff])
    } else {
        (period_clause, period_binds)
    };

    // Only the strftime format and the pre-built clause are interpolated;
    // user-supplied values travel through bind parameters.
    let sql = format!(
        r#"
        SELECT
            strftime('{strftime_fmt}', timestamp) as bucket,
            COUNT(*) as requests,
            COALESCE(SUM(total_tokens), 0) as tokens,
            COALESCE(SUM(cost), 0.0) as cost
        FROM llm_requests
        WHERE 1=1 {clause}
        GROUP BY bucket
        ORDER BY bucket
        "#,
    );

    let mut q = sqlx::query(&sql);
    for b in &binds {
        q = q.bind(b);
    }

    let result = q.fetch_all(pool).await;

    match result {
        Ok(rows) => {
            let mut series = Vec::new();

            for row in rows {
                let bucket: String = row.get("bucket");
                let requests: i64 = row.get("requests");
                let tokens: i64 = row.get("tokens");
                let cost: f64 = row.get("cost");

                series.push(serde_json::json!({
                    "time": bucket,
                    "requests": requests,
                    "tokens": tokens,
                    "cost": cost,
                }));
            }

            Json(ApiResponse::success(serde_json::json!({
                "series": series,
                "period": filter.period.as_deref().unwrap_or("all"),
                "granularity": granularity,
            })))
        }
        Err(e) => {
            warn!("Failed to fetch time series data: {}", e);
            Json(ApiResponse::error("Failed to fetch time series data".to_string()))
        }
    }
}
|
||||
|
||||
pub(super) async fn handle_clients_usage(
|
||||
State(state): State<DashboardState>,
|
||||
Query(filter): Query<UsagePeriodFilter>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
let (period_clause, period_binds) = filter.to_sql();
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
client_id,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(total_tokens), 0) as tokens,
|
||||
COALESCE(SUM(cost), 0.0) as cost,
|
||||
MAX(timestamp) as last_request
|
||||
FROM llm_requests
|
||||
WHERE 1=1 {}
|
||||
GROUP BY client_id
|
||||
ORDER BY requests DESC
|
||||
"#,
|
||||
period_clause
|
||||
);
|
||||
|
||||
let mut q = sqlx::query(&sql);
|
||||
for b in &period_binds {
|
||||
q = q.bind(b);
|
||||
}
|
||||
|
||||
let result = q.fetch_all(pool).await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let mut client_usage = Vec::new();
|
||||
|
||||
for row in rows {
|
||||
let client_id: String = row.get("client_id");
|
||||
let requests: i64 = row.get("requests");
|
||||
let tokens: i64 = row.get("tokens");
|
||||
let cost: f64 = row.get("cost");
|
||||
let last_request: Option<chrono::DateTime<chrono::Utc>> = row.get("last_request");
|
||||
|
||||
client_usage.push(serde_json::json!({
|
||||
"client_id": client_id,
|
||||
"client_name": client_id,
|
||||
"requests": requests,
|
||||
"tokens": tokens,
|
||||
"cost": cost,
|
||||
"last_request": last_request,
|
||||
}));
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(client_usage)))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch client usage data: {}", e);
|
||||
Json(ApiResponse::error("Failed to fetch client usage data".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_providers_usage(
|
||||
State(state): State<DashboardState>,
|
||||
Query(filter): Query<UsagePeriodFilter>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let pool = &state.app_state.db_pool;
|
||||
let (period_clause, period_binds) = filter.to_sql();
|
||||
|
||||
let sql = format!(
|
||||
r#"
|
||||
SELECT
|
||||
provider,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(total_tokens), 0) as tokens,
|
||||
COALESCE(SUM(cost), 0.0) as cost,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as cache_read,
|
||||
COALESCE(SUM(cache_write_tokens), 0) as cache_write
|
||||
FROM llm_requests
|
||||
WHERE 1=1 {}
|
||||
GROUP BY provider
|
||||
ORDER BY requests DESC
|
||||
"#,
|
||||
period_clause
|
||||
);
|
||||
|
||||
let mut q = sqlx::query(&sql);
|
||||
for b in &period_binds {
|
||||
q = q.bind(b);
|
||||
}
|
||||
|
||||
let result = q.fetch_all(pool).await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let mut provider_usage = Vec::new();
|
||||
|
||||
for row in rows {
|
||||
let provider: String = row.get("provider");
|
||||
let requests: i64 = row.get("requests");
|
||||
let tokens: i64 = row.get("tokens");
|
||||
let cost: f64 = row.get("cost");
|
||||
let cache_read: i64 = row.get("cache_read");
|
||||
let cache_write: i64 = row.get("cache_write");
|
||||
|
||||
provider_usage.push(serde_json::json!({
|
||||
"provider": provider,
|
||||
"requests": requests,
|
||||
"tokens": tokens,
|
||||
"cost": cost,
|
||||
"cache_read_tokens": cache_read,
|
||||
"cache_write_tokens": cache_write,
|
||||
}));
|
||||
}
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(provider_usage)))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch provider usage data: {}", e);
|
||||
Json(ApiResponse::error("Failed to fetch provider usage data".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Detailed usage rows grouped by (date, client, provider, model) within
/// the requested period, newest dates first, capped at 200 rows.
pub(super) async fn handle_detailed_usage(
    State(state): State<DashboardState>,
    Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;
    let (period_clause, period_binds) = filter.to_sql();

    // The period clause is pre-built SQL; user values arrive via binds.
    let sql = format!(
        r#"
        SELECT
            strftime('%Y-%m-%d', timestamp) as date,
            client_id,
            provider,
            model,
            COUNT(*) as requests,
            COALESCE(SUM(total_tokens), 0) as tokens,
            COALESCE(SUM(cost), 0.0) as cost,
            COALESCE(SUM(cache_read_tokens), 0) as cache_read,
            COALESCE(SUM(cache_write_tokens), 0) as cache_write
        FROM llm_requests
        WHERE 1=1 {}
        GROUP BY date, client_id, provider, model
        ORDER BY date DESC
        LIMIT 200
        "#,
        period_clause
    );

    let mut q = sqlx::query(&sql);
    for b in &period_binds {
        q = q.bind(b);
    }

    let result = q.fetch_all(pool).await;

    match result {
        Ok(rows) => {
            let usage: Vec<serde_json::Value> = rows
                .into_iter()
                .map(|row| {
                    serde_json::json!({
                        "date": row.get::<String, _>("date"),
                        "client": row.get::<String, _>("client_id"),
                        "provider": row.get::<String, _>("provider"),
                        "model": row.get::<String, _>("model"),
                        "requests": row.get::<i64, _>("requests"),
                        "tokens": row.get::<i64, _>("tokens"),
                        "cost": row.get::<f64, _>("cost"),
                        "cache_read_tokens": row.get::<i64, _>("cache_read"),
                        "cache_write_tokens": row.get::<i64, _>("cache_write"),
                    })
                })
                .collect();

            Json(ApiResponse::success(serde_json::json!(usage)))
        }
        Err(e) => {
            warn!("Failed to fetch detailed usage: {}", e);
            Json(ApiResponse::error("Failed to fetch detailed usage".to_string()))
        }
    }
}
|
||||
|
||||
/// Breakdown of request counts by model and by client (pie-chart data),
/// each sorted descending, filtered by the requested period.
pub(super) async fn handle_analytics_breakdown(
    State(state): State<DashboardState>,
    Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;
    let (period_clause, period_binds) = filter.to_sql();

    // Model breakdown
    let model_sql = format!(
        "SELECT model as label, COUNT(*) as value FROM llm_requests WHERE 1=1 {} GROUP BY model ORDER BY value DESC",
        period_clause
    );
    let mut mq = sqlx::query(&model_sql);
    for b in &period_binds {
        mq = mq.bind(b);
    }
    let models = mq.fetch_all(pool);

    // Client breakdown
    let client_sql = format!(
        "SELECT client_id as label, COUNT(*) as value FROM llm_requests WHERE 1=1 {} GROUP BY client_id ORDER BY value DESC",
        period_clause
    );
    let mut cq = sqlx::query(&client_sql);
    for b in &period_binds {
        cq = cq.bind(b);
    }
    let clients = cq.fetch_all(pool);

    // Both queries run concurrently; either failure yields a generic error.
    match tokio::join!(models, clients) {
        (Ok(m_rows), Ok(c_rows)) => {
            let model_breakdown: Vec<serde_json::Value> = m_rows
                .into_iter()
                .map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
                .collect();

            let client_breakdown: Vec<serde_json::Value> = c_rows
                .into_iter()
                .map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
                .collect();

            Json(ApiResponse::success(serde_json::json!({
                "models": model_breakdown,
                "clients": client_breakdown
            })))
        }
        _ => Json(ApiResponse::error("Failed to fetch analytics breakdown".to_string())),
    }
}
|
||||
@@ -1,287 +0,0 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::Json,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use sqlx::Row;
|
||||
use tracing::warn;
|
||||
|
||||
use super::{ApiResponse, DashboardState, auth};
|
||||
|
||||
// ── User management endpoints (admin-only) ──────────────────────────
|
||||
|
||||
pub(super) async fn handle_get_users(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
let result = sqlx::query(
|
||||
"SELECT id, username, display_name, role, must_change_password, created_at FROM users ORDER BY created_at ASC",
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(rows) => {
|
||||
let users: Vec<serde_json::Value> = rows
|
||||
.into_iter()
|
||||
.map(|row| {
|
||||
let username: String = row.get("username");
|
||||
let display_name: Option<String> = row.get("display_name");
|
||||
serde_json::json!({
|
||||
"id": row.get::<i64, _>("id"),
|
||||
"username": &username,
|
||||
"display_name": display_name.as_deref().unwrap_or(&username),
|
||||
"role": row.get::<String, _>("role"),
|
||||
"must_change_password": row.get::<bool, _>("must_change_password"),
|
||||
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Json(ApiResponse::success(serde_json::json!(users)))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch users: {}", e);
|
||||
Json(ApiResponse::error("Failed to fetch users".to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Request body for creating a dashboard user (admin only).
#[derive(Deserialize)]
pub(super) struct CreateUserRequest {
    // 1-64 characters after trimming; must be unique.
    pub(super) username: String,
    // Minimum 4 characters; stored as a bcrypt hash, never plaintext.
    pub(super) password: String,
    // Optional human-friendly name; the username is used when absent.
    pub(super) display_name: Option<String>,
    // Either "admin" or "viewer"; defaults to "viewer" when absent.
    pub(super) role: Option<String>,
}
|
||||
|
||||
pub(super) async fn handle_create_user(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Json(payload): Json<CreateUserRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Validate role
|
||||
let role = payload.role.as_deref().unwrap_or("viewer");
|
||||
if role != "admin" && role != "viewer" {
|
||||
return Json(ApiResponse::error("Role must be 'admin' or 'viewer'".to_string()));
|
||||
}
|
||||
|
||||
// Validate username
|
||||
let username = payload.username.trim();
|
||||
if username.is_empty() || username.len() > 64 {
|
||||
return Json(ApiResponse::error("Username must be 1-64 characters".to_string()));
|
||||
}
|
||||
|
||||
// Validate password
|
||||
if payload.password.len() < 4 {
|
||||
return Json(ApiResponse::error("Password must be at least 4 characters".to_string()));
|
||||
}
|
||||
|
||||
let password_hash = match bcrypt::hash(&payload.password, 12) {
|
||||
Ok(h) => h,
|
||||
Err(_) => return Json(ApiResponse::error("Failed to hash password".to_string())),
|
||||
};
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
INSERT INTO users (username, password_hash, display_name, role, must_change_password)
|
||||
VALUES (?, ?, ?, ?, TRUE)
|
||||
RETURNING id, username, display_name, role, must_change_password, created_at
|
||||
"#,
|
||||
)
|
||||
.bind(username)
|
||||
.bind(&password_hash)
|
||||
.bind(&payload.display_name)
|
||||
.bind(role)
|
||||
.fetch_one(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(row) => {
|
||||
let uname: String = row.get("username");
|
||||
let display_name: Option<String> = row.get("display_name");
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"id": row.get::<i64, _>("id"),
|
||||
"username": &uname,
|
||||
"display_name": display_name.as_deref().unwrap_or(&uname),
|
||||
"role": row.get::<String, _>("role"),
|
||||
"must_change_password": row.get::<bool, _>("must_change_password"),
|
||||
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
||||
})))
|
||||
}
|
||||
Err(e) => {
|
||||
let msg = if e.to_string().contains("UNIQUE") {
|
||||
format!("Username '{}' already exists", username)
|
||||
} else {
|
||||
format!("Failed to create user: {}", e)
|
||||
};
|
||||
warn!("Failed to create user: {}", e);
|
||||
Json(ApiResponse::error(msg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Request body for partially updating a dashboard user; every field is
/// optional and only the supplied fields are written.
#[derive(Deserialize)]
pub(super) struct UpdateUserRequest {
    pub(super) display_name: Option<String>,
    // Either "admin" or "viewer" when provided.
    pub(super) role: Option<String>,
    // Minimum 4 characters; re-hashed with bcrypt before storage.
    pub(super) password: Option<String>,
    pub(super) must_change_password: Option<bool>,
}
|
||||
|
||||
pub(super) async fn handle_update_user(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(id): Path<i64>,
|
||||
Json(payload): Json<UpdateUserRequest>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
if let Err(e) = auth::require_admin(&state, &headers).await {
|
||||
return e;
|
||||
}
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Validate role if provided
|
||||
if let Some(ref role) = payload.role {
|
||||
if role != "admin" && role != "viewer" {
|
||||
return Json(ApiResponse::error("Role must be 'admin' or 'viewer'".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
// Build dynamic update
|
||||
let mut sets = Vec::new();
|
||||
let mut string_binds: Vec<String> = Vec::new();
|
||||
let mut has_password = false;
|
||||
let mut has_must_change = false;
|
||||
|
||||
if let Some(ref display_name) = payload.display_name {
|
||||
sets.push("display_name = ?");
|
||||
string_binds.push(display_name.clone());
|
||||
}
|
||||
if let Some(ref role) = payload.role {
|
||||
sets.push("role = ?");
|
||||
string_binds.push(role.clone());
|
||||
}
|
||||
if let Some(ref password) = payload.password {
|
||||
if password.len() < 4 {
|
||||
return Json(ApiResponse::error("Password must be at least 4 characters".to_string()));
|
||||
}
|
||||
let hash = match bcrypt::hash(password, 12) {
|
||||
Ok(h) => h,
|
||||
Err(_) => return Json(ApiResponse::error("Failed to hash password".to_string())),
|
||||
};
|
||||
sets.push("password_hash = ?");
|
||||
string_binds.push(hash);
|
||||
has_password = true;
|
||||
}
|
||||
if let Some(mcp) = payload.must_change_password {
|
||||
sets.push("must_change_password = ?");
|
||||
has_must_change = true;
|
||||
let _ = mcp; // used below in bind
|
||||
}
|
||||
|
||||
if sets.is_empty() {
|
||||
return Json(ApiResponse::error("No fields to update".to_string()));
|
||||
}
|
||||
|
||||
let sql = format!("UPDATE users SET {} WHERE id = ?", sets.join(", "));
|
||||
let mut query = sqlx::query(&sql);
|
||||
|
||||
for b in &string_binds {
|
||||
query = query.bind(b);
|
||||
}
|
||||
if has_must_change {
|
||||
query = query.bind(payload.must_change_password.unwrap());
|
||||
}
|
||||
let _ = has_password; // consumed above via string_binds
|
||||
query = query.bind(id);
|
||||
|
||||
match query.execute(pool).await {
|
||||
Ok(result) => {
|
||||
if result.rows_affected() == 0 {
|
||||
return Json(ApiResponse::error("User not found".to_string()));
|
||||
}
|
||||
// Fetch updated user
|
||||
let row = sqlx::query(
|
||||
"SELECT id, username, display_name, role, must_change_password, created_at FROM users WHERE id = ?",
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_one(pool)
|
||||
.await;
|
||||
|
||||
match row {
|
||||
Ok(row) => {
|
||||
let uname: String = row.get("username");
|
||||
let display_name: Option<String> = row.get("display_name");
|
||||
Json(ApiResponse::success(serde_json::json!({
|
||||
"id": row.get::<i64, _>("id"),
|
||||
"username": &uname,
|
||||
"display_name": display_name.as_deref().unwrap_or(&uname),
|
||||
"role": row.get::<String, _>("role"),
|
||||
"must_change_password": row.get::<bool, _>("must_change_password"),
|
||||
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
||||
})))
|
||||
}
|
||||
Err(_) => Json(ApiResponse::success(serde_json::json!({ "message": "User updated" }))),
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to update user: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to update user: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_delete_user(
|
||||
State(state): State<DashboardState>,
|
||||
headers: axum::http::HeaderMap,
|
||||
Path(id): Path<i64>,
|
||||
) -> Json<ApiResponse<serde_json::Value>> {
|
||||
let session = match auth::require_admin(&state, &headers).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => return e,
|
||||
};
|
||||
|
||||
let pool = &state.app_state.db_pool;
|
||||
|
||||
// Don't allow deleting yourself
|
||||
let target_username: Option<String> =
|
||||
sqlx::query_scalar::<_, String>("SELECT username FROM users WHERE id = ?")
|
||||
.bind(id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.unwrap_or(None);
|
||||
|
||||
match target_username {
|
||||
None => return Json(ApiResponse::error("User not found".to_string())),
|
||||
Some(ref uname) if uname == &session.username => {
|
||||
return Json(ApiResponse::error("Cannot delete your own account".to_string()));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let result = sqlx::query("DELETE FROM users WHERE id = ?")
|
||||
.bind(id)
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "User deleted" }))),
|
||||
Err(e) => {
|
||||
warn!("Failed to delete user: {}", e);
|
||||
Json(ApiResponse::error(format!("Failed to delete user: {}", e)))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
use axum::{
|
||||
extract::{
|
||||
State,
|
||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use serde_json;
|
||||
use tracing::info;
|
||||
|
||||
use super::DashboardState;
|
||||
|
||||
// WebSocket handler
|
||||
pub(super) async fn handle_websocket(ws: WebSocketUpgrade, State(state): State<DashboardState>) -> impl IntoResponse {
|
||||
ws.on_upgrade(|socket| handle_websocket_connection(socket, state))
|
||||
}
|
||||
|
||||
pub(super) async fn handle_websocket_connection(mut socket: WebSocket, state: DashboardState) {
|
||||
info!("WebSocket connection established");
|
||||
|
||||
// Subscribe to events from the global bus
|
||||
let mut rx = state.app_state.dashboard_tx.subscribe();
|
||||
|
||||
// Send initial connection message
|
||||
let _ = socket
|
||||
.send(Message::Text(
|
||||
serde_json::json!({
|
||||
"type": "connected",
|
||||
"message": "Connected to LLM Proxy Dashboard"
|
||||
})
|
||||
.to_string()
|
||||
.into(),
|
||||
))
|
||||
.await;
|
||||
|
||||
// Handle incoming messages and broadcast events
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Receive broadcast events
|
||||
Ok(event) = rx.recv() => {
|
||||
let Ok(json_str) = serde_json::to_string(&event) else {
|
||||
continue;
|
||||
};
|
||||
let message = Message::Text(json_str.into());
|
||||
if socket.send(message).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Receive WebSocket messages
|
||||
result = socket.recv() => {
|
||||
match result {
|
||||
Some(Ok(Message::Text(text))) => {
|
||||
handle_websocket_message(&text, &state).await;
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("WebSocket connection closed");
|
||||
}
|
||||
|
||||
pub(super) async fn handle_websocket_message(text: &str, state: &DashboardState) {
|
||||
// Parse and handle WebSocket messages
|
||||
if let Ok(data) = serde_json::from_str::<serde_json::Value>(text)
|
||||
&& data.get("type").and_then(|v| v.as_str()) == Some("ping")
|
||||
{
|
||||
let _ = state.app_state.dashboard_tx.send(serde_json::json!({
|
||||
"type": "pong",
|
||||
"payload": {}
|
||||
}));
|
||||
}
|
||||
}
|
||||
@@ -1,235 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool};
|
||||
use std::str::FromStr;
|
||||
use tracing::info;
|
||||
|
||||
use crate::config::DatabaseConfig;
|
||||
|
||||
pub type DbPool = SqlitePool;
|
||||
|
||||
pub async fn init(config: &DatabaseConfig) -> Result<DbPool> {
|
||||
// Ensure the database directory exists
|
||||
if let Some(parent) = config.path.parent()
|
||||
&& !parent.as_os_str().is_empty()
|
||||
{
|
||||
tokio::fs::create_dir_all(parent).await?;
|
||||
}
|
||||
|
||||
let database_path = config.path.to_string_lossy().to_string();
|
||||
info!("Connecting to database at {}", database_path);
|
||||
|
||||
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?.create_if_missing(true);
|
||||
|
||||
let pool = SqlitePool::connect_with(options).await?;
|
||||
|
||||
// Run migrations
|
||||
run_migrations(&pool).await?;
|
||||
info!("Database migrations completed");
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
async fn run_migrations(pool: &DbPool) -> Result<()> {
|
||||
// Create clients table if it doesn't exist
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS clients (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
client_id TEXT UNIQUE NOT NULL,
|
||||
name TEXT,
|
||||
description TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
is_active BOOLEAN DEFAULT TRUE,
|
||||
rate_limit_per_minute INTEGER DEFAULT 60,
|
||||
total_requests INTEGER DEFAULT 0,
|
||||
total_tokens INTEGER DEFAULT 0,
|
||||
total_cost REAL DEFAULT 0.0
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Create llm_requests table if it doesn't exist
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS llm_requests (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
client_id TEXT,
|
||||
provider TEXT,
|
||||
model TEXT,
|
||||
prompt_tokens INTEGER,
|
||||
completion_tokens INTEGER,
|
||||
total_tokens INTEGER,
|
||||
cost REAL,
|
||||
has_images BOOLEAN DEFAULT FALSE,
|
||||
status TEXT DEFAULT 'success',
|
||||
error_message TEXT,
|
||||
duration_ms INTEGER,
|
||||
request_body TEXT,
|
||||
response_body TEXT,
|
||||
FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE SET NULL
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Create provider_configs table
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS provider_configs (
|
||||
id TEXT PRIMARY KEY,
|
||||
display_name TEXT NOT NULL,
|
||||
enabled BOOLEAN DEFAULT TRUE,
|
||||
base_url TEXT,
|
||||
api_key TEXT,
|
||||
credit_balance REAL DEFAULT 0.0,
|
||||
low_credit_threshold REAL DEFAULT 5.0,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Create model_configs table
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS model_configs (
|
||||
id TEXT PRIMARY KEY,
|
||||
provider_id TEXT NOT NULL,
|
||||
display_name TEXT,
|
||||
enabled BOOLEAN DEFAULT TRUE,
|
||||
prompt_cost_per_m REAL,
|
||||
completion_cost_per_m REAL,
|
||||
mapping TEXT,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Create users table for dashboard access
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT UNIQUE NOT NULL,
|
||||
password_hash TEXT NOT NULL,
|
||||
display_name TEXT,
|
||||
role TEXT DEFAULT 'admin',
|
||||
must_change_password BOOLEAN DEFAULT FALSE,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Create client_tokens table for DB-based token auth
|
||||
sqlx::query(
|
||||
r#"
|
||||
CREATE TABLE IF NOT EXISTS client_tokens (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
client_id TEXT NOT NULL,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
name TEXT DEFAULT 'default',
|
||||
is_active BOOLEAN DEFAULT TRUE,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
last_used_at DATETIME,
|
||||
FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE CASCADE
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Add must_change_password column if it doesn't exist (migration for existing DBs)
|
||||
let _ = sqlx::query("ALTER TABLE users ADD COLUMN must_change_password BOOLEAN DEFAULT FALSE")
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
// Add display_name column if it doesn't exist (migration for existing DBs)
|
||||
let _ = sqlx::query("ALTER TABLE users ADD COLUMN display_name TEXT")
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
// Add cache token columns if they don't exist (migration for existing DBs)
|
||||
let _ = sqlx::query("ALTER TABLE llm_requests ADD COLUMN cache_read_tokens INTEGER DEFAULT 0")
|
||||
.execute(pool)
|
||||
.await;
|
||||
let _ = sqlx::query("ALTER TABLE llm_requests ADD COLUMN cache_write_tokens INTEGER DEFAULT 0")
|
||||
.execute(pool)
|
||||
.await;
|
||||
|
||||
// Insert default admin user if none exists (default password: admin)
|
||||
let user_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users").fetch_one(pool).await?;
|
||||
|
||||
if user_count.0 == 0 {
|
||||
// 'admin' hashed with default cost (12)
|
||||
let default_admin_hash =
|
||||
bcrypt::hash("admin", 12).map_err(|e| anyhow::anyhow!("Failed to hash default password: {}", e))?;
|
||||
sqlx::query(
|
||||
"INSERT INTO users (username, password_hash, role, must_change_password) VALUES ('admin', ?, 'admin', TRUE)"
|
||||
)
|
||||
.bind(default_admin_hash)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
info!("Created default admin user with password 'admin' (must change on first login)");
|
||||
}
|
||||
|
||||
// Create indices
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_clients_created_at ON clients(created_at)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_timestamp ON llm_requests(timestamp)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_client_id ON llm_requests(client_id)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_provider ON llm_requests(provider)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_status ON llm_requests(status)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query("CREATE UNIQUE INDEX IF NOT EXISTS idx_client_tokens_token ON client_tokens(token)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query("CREATE INDEX IF NOT EXISTS idx_client_tokens_client_id ON client_tokens(client_id)")
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// Insert default client if none exists
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT OR IGNORE INTO clients (client_id, name, description)
|
||||
VALUES ('default', 'Default Client', 'Default client for anonymous requests')
|
||||
"#,
|
||||
)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Cheap liveness probe: runs `SELECT 1` against the pool and propagates any
/// connection error.
pub async fn test_connection(pool: &DbPool) -> Result<()> {
    sqlx::query("SELECT 1").execute(pool).await?;
    Ok(())
}
|
||||
@@ -1,58 +0,0 @@
|
||||
use thiserror::Error;
|
||||
|
||||
/// Application-wide error type. The `#[error]` strings become the
/// human-readable message via `Display`; the `IntoResponse` impl below maps
/// each variant to an HTTP status code.
#[derive(Error, Debug, Clone)]
pub enum AppError {
    // Maps to 401 Unauthorized.
    #[error("Authentication failed: {0}")]
    AuthError(String),

    #[error("Configuration error: {0}")]
    ConfigError(String),

    #[error("Database error: {0}")]
    DatabaseError(String),

    #[error("Provider error: {0}")]
    ProviderError(String),

    // Maps to 400 Bad Request.
    #[error("Validation error: {0}")]
    ValidationError(String),

    #[error("Multimodal processing error: {0}")]
    MultimodalError(String),

    // Maps to 429 Too Many Requests.
    #[error("Rate limit exceeded: {0}")]
    RateLimitError(String),

    #[error("Internal server error: {0}")]
    InternalError(String),
}
|
||||
|
||||
impl From<sqlx::Error> for AppError {
|
||||
fn from(err: sqlx::Error) -> Self {
|
||||
AppError::DatabaseError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<anyhow::Error> for AppError {
|
||||
fn from(err: anyhow::Error) -> Self {
|
||||
AppError::InternalError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl axum::response::IntoResponse for AppError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
let status = match self {
|
||||
AppError::AuthError(_) => axum::http::StatusCode::UNAUTHORIZED,
|
||||
AppError::RateLimitError(_) => axum::http::StatusCode::TOO_MANY_REQUESTS,
|
||||
AppError::ValidationError(_) => axum::http::StatusCode::BAD_REQUEST,
|
||||
_ => axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
};
|
||||
|
||||
let body = axum::Json(serde_json::json!({
|
||||
"error": self.to_string(),
|
||||
"type": format!("{:?}", self)
|
||||
}));
|
||||
|
||||
(status, body).into_response()
|
||||
}
|
||||
}
|
||||
151
src/lib.rs
151
src/lib.rs
@@ -1,151 +0,0 @@
|
||||
//! LLM Proxy Library
|
||||
//!
|
||||
//! This library provides the core functionality for the LLM proxy gateway,
|
||||
//! including provider integration, token tracking, and API endpoints.
|
||||
|
||||
pub mod auth;
|
||||
pub mod client;
|
||||
pub mod config;
|
||||
pub mod dashboard;
|
||||
pub mod database;
|
||||
pub mod errors;
|
||||
pub mod logging;
|
||||
pub mod models;
|
||||
pub mod multimodal;
|
||||
pub mod providers;
|
||||
pub mod rate_limiting;
|
||||
pub mod server;
|
||||
pub mod state;
|
||||
pub mod utils;
|
||||
|
||||
// Re-exports for convenience
|
||||
pub use auth::{AuthenticatedClient, validate_token};
|
||||
pub use config::{
|
||||
AppConfig, DatabaseConfig, DeepSeekConfig, GeminiConfig, GrokConfig, ModelMappingConfig, ModelPricing,
|
||||
OllamaConfig, OpenAIConfig, PricingConfig, ProviderConfig, ServerConfig,
|
||||
};
|
||||
pub use database::{DbPool, init as init_db, test_connection};
|
||||
pub use errors::AppError;
|
||||
pub use logging::{LoggingContext, RequestLog, RequestLogger};
|
||||
pub use models::{
|
||||
ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage,
|
||||
ChatStreamChoice, ChatStreamDelta, ContentPart, ContentPartValue, FromOpenAI, ImageUrl, MessageContent,
|
||||
OpenAIContentPart, OpenAIMessage, OpenAIRequest, ToOpenAI, UnifiedMessage, UnifiedRequest, Usage,
|
||||
};
|
||||
pub use providers::{Provider, ProviderManager, ProviderResponse, ProviderStreamChunk};
|
||||
pub use server::router;
|
||||
pub use state::AppState;
|
||||
|
||||
/// Test utilities for integration testing.
#[cfg(test)]
pub mod test_utils {
    use std::sync::Arc;

    use crate::{client::ClientManager, providers::ProviderManager, rate_limiting::RateLimitManager, state::AppState};
    use sqlx::sqlite::SqlitePool;

    /// Build a fully-populated `AppState` backed by an in-memory SQLite pool,
    /// with all providers enabled but unconfigured (empty base URLs, no keys).
    pub async fn create_test_state() -> Arc<AppState> {
        // In-memory database for the state itself.
        let pool = SqlitePool::connect("sqlite::memory:")
            .await
            .expect("Failed to create test database");

        // Run migrations.
        // NOTE(review): `init` opens its own ':memory:' connection, so these
        // migrations do not run against `pool` above — confirm intended.
        crate::database::init(&crate::config::DatabaseConfig {
            path: std::path::PathBuf::from(":memory:"),
            max_connections: 5,
        })
        .await
        .expect("Failed to initialize test database");

        let rate_limit_manager = RateLimitManager::new(
            crate::rate_limiting::RateLimiterConfig::default(),
            crate::rate_limiting::CircuitBreakerConfig::default(),
        );

        let client_manager = Arc::new(ClientManager::new(pool.clone()));
        let provider_manager = ProviderManager::new();

        // Empty registry: no provider/model metadata needed for most tests.
        let model_registry = crate::models::registry::ModelRegistry {
            providers: std::collections::HashMap::new(),
        };

        let (dashboard_tx, _) = tokio::sync::broadcast::channel(100);

        // Minimal config: every provider enabled but pointing nowhere.
        let config = Arc::new(crate::config::AppConfig {
            server: crate::config::ServerConfig {
                port: 8080,
                host: "127.0.0.1".to_string(),
                auth_tokens: vec![],
            },
            database: crate::config::DatabaseConfig {
                path: std::path::PathBuf::from(":memory:"),
                max_connections: 5,
            },
            providers: crate::config::ProviderConfig {
                openai: crate::config::OpenAIConfig {
                    api_key_env: "OPENAI_API_KEY".to_string(),
                    base_url: "".to_string(),
                    default_model: "".to_string(),
                    enabled: true,
                },
                gemini: crate::config::GeminiConfig {
                    api_key_env: "GEMINI_API_KEY".to_string(),
                    base_url: "".to_string(),
                    default_model: "".to_string(),
                    enabled: true,
                },
                deepseek: crate::config::DeepSeekConfig {
                    api_key_env: "DEEPSEEK_API_KEY".to_string(),
                    base_url: "".to_string(),
                    default_model: "".to_string(),
                    enabled: true,
                },
                grok: crate::config::GrokConfig {
                    api_key_env: "GROK_API_KEY".to_string(),
                    base_url: "".to_string(),
                    default_model: "".to_string(),
                    enabled: true,
                },
                ollama: crate::config::OllamaConfig {
                    base_url: "".to_string(),
                    enabled: true,
                    models: vec![],
                },
            },
            model_mapping: crate::config::ModelMappingConfig { patterns: vec![] },
            pricing: crate::config::PricingConfig {
                openai: vec![],
                gemini: vec![],
                deepseek: vec![],
                grok: vec![],
                ollama: vec![],
            },
            config_path: None,
        });

        Arc::new(AppState {
            config,
            provider_manager,
            db_pool: pool.clone(),
            rate_limit_manager: Arc::new(rate_limit_manager),
            client_manager,
            request_logger: Arc::new(crate::logging::RequestLogger::new(pool.clone(), dashboard_tx.clone())),
            model_registry: Arc::new(model_registry),
            model_config_cache: crate::state::ModelConfigCache::new(pool.clone()),
            dashboard_tx,
            auth_tokens: vec![],
        })
    }

    /// Build a reqwest client with a 30s timeout for exercising the HTTP API.
    pub fn create_test_client() -> reqwest::Client {
        reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(30))
            .build()
            .expect("Failed to create test HTTP client")
    }
}
|
||||
@@ -1,223 +0,0 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::Serialize;
|
||||
use sqlx::SqlitePool;
|
||||
use tokio::sync::broadcast;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::errors::AppError;
|
||||
|
||||
/// Request log entry for database storage: one LLM request as written to the
/// `llm_requests` table and broadcast to the dashboard WebSocket bus.
#[derive(Debug, Clone, Serialize)]
pub struct RequestLog {
    pub timestamp: DateTime<Utc>,
    pub client_id: String,
    pub provider: String,
    pub model: String,
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    pub cache_read_tokens: u32,
    pub cache_write_tokens: u32,
    // Cost in account currency; deducted from prepaid provider balances.
    pub cost: f64,
    pub has_images: bool,
    pub status: String, // "success", "error"
    pub error_message: Option<String>,
    pub duration_ms: u64,
}
|
||||
|
||||
/// Database operations for request logging: persists `RequestLog` entries and
/// pushes them onto the dashboard broadcast channel.
pub struct RequestLogger {
    db_pool: SqlitePool,
    // Broadcast sender shared with the dashboard WebSocket handlers.
    dashboard_tx: broadcast::Sender<serde_json::Value>,
}
|
||||
|
||||
impl RequestLogger {
|
||||
pub fn new(db_pool: SqlitePool, dashboard_tx: broadcast::Sender<serde_json::Value>) -> Self {
|
||||
Self { db_pool, dashboard_tx }
|
||||
}
|
||||
|
||||
/// Log a request to the database (async, spawns a task)
|
||||
pub fn log_request(&self, log: RequestLog) {
|
||||
let pool = self.db_pool.clone();
|
||||
let tx = self.dashboard_tx.clone();
|
||||
|
||||
// Spawn async task to log without blocking response
|
||||
tokio::spawn(async move {
|
||||
// Broadcast to dashboard
|
||||
let broadcast_result = tx.send(serde_json::json!({
|
||||
"type": "request",
|
||||
"channel": "requests",
|
||||
"payload": log
|
||||
}));
|
||||
match broadcast_result {
|
||||
Ok(receivers) => info!("Broadcast request log to {} dashboard listeners", receivers),
|
||||
Err(_) => {} // No active WebSocket clients — expected when dashboard isn't open
|
||||
}
|
||||
|
||||
match Self::insert_log(&pool, log).await {
|
||||
Ok(()) => info!("Request logged to database successfully"),
|
||||
Err(e) => warn!("Failed to log request to database: {}", e),
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Insert a log entry into the database.
///
/// Everything runs inside one transaction so the log row, the auto-created
/// client row, and the balance deduction commit (or roll back) together:
///   1. `INSERT OR IGNORE` the client row — the FK on `llm_requests` requires it.
///   2. Insert the request log row (request/response bodies are deliberately
///      stored as NULL to save disk space).
///   3. Deduct the request cost from the provider's prepaid credit balance,
///      unless the provider is configured as postpaid.
///
/// # Errors
/// Returns any `sqlx::Error` from beginning, executing, or committing the
/// transaction.
async fn insert_log(pool: &SqlitePool, log: RequestLog) -> Result<(), sqlx::Error> {
    let mut tx = pool.begin().await?;

    // Ensure the client row exists (FK constraint requires it).
    // `client_id` is bound twice on purpose: once as the id and once as the
    // placeholder name for the auto-created row.
    sqlx::query(
        "INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')",
    )
    .bind(&log.client_id)
    .bind(&log.client_id)
    .execute(&mut *tx)
    .await?;

    sqlx::query(
        r#"
        INSERT INTO llm_requests
        (timestamp, client_id, provider, model, prompt_tokens, completion_tokens, total_tokens, cache_read_tokens, cache_write_tokens, cost, has_images, status, error_message, duration_ms, request_body, response_body)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        "#,
    )
    .bind(log.timestamp)
    .bind(log.client_id)
    .bind(&log.provider)
    .bind(log.model)
    // Token counters are unsigned in Rust; SQLite integers are signed, so
    // widen to i64 for binding.
    .bind(log.prompt_tokens as i64)
    .bind(log.completion_tokens as i64)
    .bind(log.total_tokens as i64)
    .bind(log.cache_read_tokens as i64)
    .bind(log.cache_write_tokens as i64)
    .bind(log.cost)
    .bind(log.has_images)
    .bind(log.status)
    .bind(log.error_message)
    .bind(log.duration_ms as i64)
    .bind(None::<String>) // request_body - optional, not stored to save disk space
    .bind(None::<String>) // response_body - optional, not stored to save disk space
    .execute(&mut *tx)
    .await?;

    // Deduct from provider balance if successful.
    // Providers configured with billing_mode = 'postpaid' will not have their
    // credit_balance decremented. Use a conditional UPDATE so we don't need
    // a prior SELECT and avoid extra round-trips.
    if log.cost > 0.0 {
        sqlx::query(
            "UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')",
        )
        .bind(log.cost)
        // NOTE(review): `provider` is bound to the `id` column — assumes the
        // provider name is the provider_configs primary key; confirm schema.
        .bind(&log.provider)
        .execute(&mut *tx)
        .await?;
    }

    tx.commit().await?;

    Ok(())
}
|
||||
}
|
||||
|
||||
// /// Middleware to log LLM API requests
|
||||
// /// TODO: Implement proper middleware that can extract response body details
|
||||
// pub async fn request_logging_middleware(
|
||||
// // Extract the authenticated client (if available)
|
||||
// auth_result: Result<AuthenticatedClient, AppError>,
|
||||
// request: Request,
|
||||
// next: Next,
|
||||
// ) -> Response {
|
||||
// let start_time = std::time::Instant::now();
|
||||
//
|
||||
// // Extract client_id from auth or use "unknown"
|
||||
// let client_id = match auth_result {
|
||||
// Ok(auth) => auth.client_id,
|
||||
// Err(_) => "unknown".to_string(),
|
||||
// };
|
||||
//
|
||||
// // Try to extract request details
|
||||
// let (request_parts, request_body) = request.into_parts();
|
||||
//
|
||||
// // Clone request parts for logging
|
||||
// let path = request_parts.uri.path().to_string();
|
||||
//
|
||||
// // Check if this is a chat completion request
|
||||
// let is_chat_completion = path == "/v1/chat/completions";
|
||||
//
|
||||
// // Reconstruct request for downstream handlers
|
||||
// let request = Request::from_parts(request_parts, request_body);
|
||||
//
|
||||
// // Process request and get response
|
||||
// let response = next.run(request).await;
|
||||
//
|
||||
// // Calculate duration
|
||||
// let duration = start_time.elapsed();
|
||||
// let duration_ms = duration.as_millis() as u64;
|
||||
//
|
||||
// // Log basic request info
|
||||
// info!(
|
||||
// "Request from {} to {} - Status: {} - Duration: {}ms",
|
||||
// client_id,
|
||||
// path,
|
||||
// response.status().as_u16(),
|
||||
// duration_ms
|
||||
// );
|
||||
//
|
||||
// // TODO: Extract more details from request/response for logging
|
||||
// // For now, we'll need to modify the server handler to pass additional context
|
||||
//
|
||||
// response
|
||||
// }
|
||||
|
||||
/// Context for request logging that can be passed through extensions.
///
/// Accumulates per-request accounting data (token usage, cost, error state)
/// as the request moves through the proxy; see the builder methods on
/// `impl LoggingContext`.
#[derive(Clone)]
pub struct LoggingContext {
    /// Authenticated client that issued the request.
    pub client_id: String,
    /// Name of the upstream provider handling the request.
    pub provider_name: String,
    /// Model identifier the request targets.
    pub model: String,
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens produced by the completion.
    pub completion_tokens: u32,
    /// Sum of prompt and completion tokens.
    pub total_tokens: u32,
    /// Computed request cost (currency units unspecified here — see pricing code).
    pub cost: f64,
    /// Whether the request contained image content.
    pub has_images: bool,
    /// Error captured during processing, if any.
    pub error: Option<AppError>,
}
|
||||
|
||||
impl LoggingContext {
|
||||
pub fn new(client_id: String, provider_name: String, model: String) -> Self {
|
||||
Self {
|
||||
client_id,
|
||||
provider_name,
|
||||
model,
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
cost: 0.0,
|
||||
has_images: false,
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_token_counts(mut self, prompt_tokens: u32, completion_tokens: u32) -> Self {
|
||||
self.prompt_tokens = prompt_tokens;
|
||||
self.completion_tokens = completion_tokens;
|
||||
self.total_tokens = prompt_tokens + completion_tokens;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_cost(mut self, cost: f64) -> Self {
|
||||
self.cost = cost;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_images(mut self, has_images: bool) -> Self {
|
||||
self.has_images = has_images;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_error(mut self, error: AppError) -> Self {
|
||||
self.error = Some(error);
|
||||
self
|
||||
}
|
||||
}
|
||||
91
src/main.rs
91
src/main.rs
@@ -1,91 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use axum::{Router, routing::get};
|
||||
use std::net::SocketAddr;
|
||||
use tracing::{error, info};
|
||||
|
||||
use llm_proxy::{
|
||||
config::AppConfig,
|
||||
dashboard, database,
|
||||
providers::ProviderManager,
|
||||
rate_limiting::{CircuitBreakerConfig, RateLimitManager, RateLimiterConfig},
|
||||
server,
|
||||
state::AppState,
|
||||
};
|
||||
|
||||
#[tokio::main]
async fn main() -> Result<()> {
    // Initialize tracing (logging) — INFO level, target names suppressed.
    tracing_subscriber::fmt()
        .with_max_level(tracing::Level::INFO)
        .with_target(false)
        .init();

    info!("Starting LLM Proxy Gateway v{}", env!("CARGO_PKG_VERSION"));

    // Load configuration
    let config = AppConfig::load().await?;
    info!("Configuration loaded from {:?}", config.config_path);

    // Initialize database connection pool
    let db_pool = database::init(&config.database).await?;
    info!("Database initialized at {:?}", config.database.path);

    // Initialize provider manager with configured providers
    let provider_manager = ProviderManager::new();

    // Initialize all supported providers (they handle their own enabled check).
    // A single provider failing to initialize is logged but non-fatal — the
    // gateway still starts with the remaining providers.
    let supported_providers = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
    for name in supported_providers {
        if let Err(e) = provider_manager.initialize_provider(name, &config, &db_pool).await {
            error!("Failed to initialize provider {}: {}", name, e);
        }
    }

    // Create rate limit manager
    let rate_limit_manager = RateLimitManager::new(RateLimiterConfig::default(), CircuitBreakerConfig::default());

    // Fetch model registry from models.dev; on failure fall back to an empty
    // registry so startup never blocks on the external service.
    let model_registry = match llm_proxy::utils::registry::fetch_registry().await {
        Ok(registry) => registry,
        Err(e) => {
            error!("Failed to fetch model registry: {}. Using empty registry.", e);
            llm_proxy::models::registry::ModelRegistry {
                providers: std::collections::HashMap::new(),
            }
        }
    };

    // Create application state
    let state = AppState::new(
        config.clone(),
        provider_manager,
        db_pool,
        rate_limit_manager,
        model_registry,
        config.server.auth_tokens.clone(),
    );

    // Initialize model config cache and start background refresh (every 30s)
    state.model_config_cache.refresh().await;
    state.model_config_cache.clone().start_refresh_task(30);
    info!("Model config cache initialized");

    // Create application router: health probe plus the API and dashboard routers.
    let app = Router::new()
        .route("/health", get(health_check))
        .merge(server::router(state.clone()))
        .merge(dashboard::router(state.clone()));

    // Start server — binds on all interfaces at the configured port.
    let addr = SocketAddr::from(([0, 0, 0, 0], config.server.port));
    info!("Server listening on http://{}", addr);

    let listener = tokio::net::TcpListener::bind(&addr).await?;
    axum::serve(listener, app).await?;

    Ok(())
}
|
||||
|
||||
/// Liveness probe handler: always responds with plain-text "OK".
async fn health_check() -> &'static str {
    "OK"
}
|
||||
@@ -1,341 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
pub mod registry;
|
||||
|
||||
// ========== OpenAI-compatible Request/Response Structs ==========

/// Incoming chat-completion request in the OpenAI wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(default)]
    pub temperature: Option<f64>,
    #[serde(default)]
    pub max_tokens: Option<u32>,
    #[serde(default)]
    pub stream: Option<bool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}

/// A single message in the conversation; `content` is flattened so both the
/// plain-string and content-parts wire shapes deserialize into it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    pub role: String, // "system", "user", "assistant", "tool"
    #[serde(flatten)]
    pub content: MessageContent,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}

/// Message content: either a plain string, an array of typed parts, or
/// absent. `untagged` lets serde pick the variant by shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
    Text { content: String },
    Parts { content: Vec<ContentPartValue> },
    None, // Handle cases where content might be null but reasoning is present
}

/// One typed content part, discriminated by the `"type"` field
/// ("text" or "image_url").
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPartValue {
    Text { text: String },
    ImageUrl { image_url: ImageUrl },
}

/// Image reference within a content part.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrl {
    pub url: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
|
||||
|
||||
// ========== Tool-Calling Types ==========

/// A tool the model may call; wire field `"type"` is mapped to `tool_type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: FunctionDef,
}

/// Declaration of a callable function: name, optional description, and an
/// optional free-form JSON parameters schema.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDef {
    pub name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<Value>,
}

/// Caller's tool-choice directive: either a mode string or a specific
/// function selection. `untagged` distinguishes by shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    Mode(String), // "auto", "none", "required"
    Specific(ToolChoiceSpecific),
}

/// Explicit selection of one function as the tool choice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceSpecific {
    #[serde(rename = "type")]
    pub choice_type: String,
    pub function: ToolChoiceFunction,
}

/// Function name inside an explicit tool choice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceFunction {
    pub name: String,
}

/// A concrete tool invocation emitted by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    pub id: String,
    #[serde(rename = "type")]
    pub call_type: String,
    pub function: FunctionCall,
}

/// Function name plus its arguments as a raw JSON string (per OpenAI wire
/// format, arguments are not pre-parsed).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
    pub name: String,
    pub arguments: String,
}

/// Incremental tool-call fragment in a streaming response; `index`
/// identifies which call the fragment belongs to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallDelta {
    pub index: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub call_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<FunctionCallDelta>,
}

/// Incremental fragment of a function call (streamed name/arguments pieces).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
|
||||
|
||||
// ========== OpenAI-compatible Response Structs ==========

/// Non-streaming chat-completion response envelope.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatChoice>,
    pub usage: Option<Usage>,
}

/// One completion alternative within a response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatChoice {
    pub index: u32,
    pub message: ChatMessage,
    pub finish_reason: Option<String>,
}

/// Token accounting for a completed request; cache counters are optional
/// extensions omitted from the wire when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_tokens: Option<u32>,
}

// ========== Streaming Response Structs ==========

/// One SSE chunk of a streaming chat completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatStreamChoice>,
}

/// One choice within a streaming chunk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatStreamChoice {
    pub index: u32,
    pub delta: ChatStreamDelta,
    pub finish_reason: Option<String>,
}

/// Incremental message delta carried by a streaming chunk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatStreamDelta {
    pub role: Option<String>,
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
}
|
||||
|
||||
// ========== Unified Request Format (for internal use) ==========

/// Provider-agnostic request representation used internally by the proxy.
/// Built from `ChatCompletionRequest` via `TryFrom`.
#[derive(Debug, Clone)]
pub struct UnifiedRequest {
    /// Set by auth middleware after conversion (empty until then).
    pub client_id: String,
    pub model: String,
    pub messages: Vec<UnifiedMessage>,
    pub temperature: Option<f64>,
    pub max_tokens: Option<u32>,
    pub stream: bool,
    /// True when any message contains an image part.
    pub has_images: bool,
    pub tools: Option<Vec<Tool>>,
    pub tool_choice: Option<ToolChoice>,
}

/// Provider-agnostic message: content is always a list of parts.
#[derive(Debug, Clone)]
pub struct UnifiedMessage {
    pub role: String,
    pub content: Vec<ContentPart>,
    pub tool_calls: Option<Vec<ToolCall>>,
    pub name: Option<String>,
    pub tool_call_id: Option<String>,
}

/// A single content part: plain text or an image input.
#[derive(Debug, Clone)]
pub enum ContentPart {
    Text { text: String },
    Image(crate::multimodal::ImageInput),
}
|
||||
|
||||
// ========== Provider-specific Structs ==========

/// Outbound request body in OpenAI's wire shape (serialize-only).
#[derive(Debug, Clone, Serialize)]
pub struct OpenAIRequest {
    pub model: String,
    pub messages: Vec<OpenAIMessage>,
    pub temperature: Option<f64>,
    pub max_tokens: Option<u32>,
    pub stream: Option<bool>,
}

/// Outbound message with content always expressed as typed parts.
#[derive(Debug, Clone, Serialize)]
pub struct OpenAIMessage {
    pub role: String,
    pub content: Vec<OpenAIContentPart>,
}

/// Outbound content part, tagged by `"type"` ("text" or "image_url").
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum OpenAIContentPart {
    Text { text: String },
    ImageUrl { image_url: ImageUrl },
}

// Note: ImageUrl struct is defined earlier in the file

// ========== Conversion Traits ==========

/// Convert an internal representation into an OpenAI-shaped request.
pub trait ToOpenAI {
    fn to_openai(&self) -> Result<OpenAIRequest, anyhow::Error>;
}

/// Build an internal representation from an OpenAI-shaped request.
pub trait FromOpenAI {
    fn from_openai(request: &OpenAIRequest) -> Result<Self, anyhow::Error>
    where
        Self: Sized;
}
|
||||
|
||||
impl UnifiedRequest {
|
||||
/// Hydrate all image content by fetching URLs and converting to base64/bytes
|
||||
pub async fn hydrate_images(&mut self) -> anyhow::Result<()> {
|
||||
if !self.has_images {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for msg in &mut self.messages {
|
||||
for part in &mut msg.content {
|
||||
if let ContentPart::Image(image_input) = part {
|
||||
// Pre-fetch and validate if it's a URL
|
||||
if let crate::multimodal::ImageInput::Url(_url) = image_input {
|
||||
let (base64_data, mime_type) = image_input.to_base64().await?;
|
||||
*image_input = crate::multimodal::ImageInput::Base64 {
|
||||
data: base64_data,
|
||||
mime_type,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ChatCompletionRequest> for UnifiedRequest {
    type Error = anyhow::Error;

    /// Convert an OpenAI-compatible request into the internal unified format.
    ///
    /// String content becomes a single text part; typed part arrays are
    /// converted part-by-part, with `image_url` parts becoming un-hydrated
    /// `ImageInput::from_url` entries. `has_images` is set if any message
    /// contained an image part.
    fn try_from(req: ChatCompletionRequest) -> Result<Self, Self::Error> {
        // Mutably captured by the map closure below; safe because `collect`
        // finishes before `has_images` is read again.
        let mut has_images = false;

        // Convert OpenAI-compatible request to unified format
        let messages = req
            .messages
            .into_iter()
            .map(|msg| {
                let (content, _images_in_message) = match msg.content {
                    MessageContent::Text { content } => (vec![ContentPart::Text { text: content }], false),
                    MessageContent::Parts { content } => {
                        let mut unified_content = Vec::new();
                        let mut has_images_in_msg = false;

                        for part in content {
                            match part {
                                ContentPartValue::Text { text } => {
                                    unified_content.push(ContentPart::Text { text });
                                }
                                ContentPartValue::ImageUrl { image_url } => {
                                    has_images_in_msg = true;
                                    has_images = true;
                                    unified_content.push(ContentPart::Image(crate::multimodal::ImageInput::from_url(
                                        image_url.url,
                                    )));
                                }
                            }
                        }

                        (unified_content, has_images_in_msg)
                    }
                    // Null content (e.g. reasoning-only messages) maps to an
                    // empty part list.
                    MessageContent::None => (vec![], false),
                };

                UnifiedMessage {
                    role: msg.role,
                    content,
                    tool_calls: msg.tool_calls,
                    name: msg.name,
                    tool_call_id: msg.tool_call_id,
                }
            })
            .collect();

        Ok(UnifiedRequest {
            client_id: String::new(), // Will be populated by auth middleware
            model: req.model,
            messages,
            temperature: req.temperature,
            max_tokens: req.max_tokens,
            stream: req.stream.unwrap_or(false),
            has_images,
            tools: req.tools,
            tool_choice: req.tool_choice,
        })
    }
}
|
||||
@@ -1,219 +0,0 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Registry of providers and their models, as fetched from models.dev;
/// the JSON top level maps provider id -> provider info (hence `flatten`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRegistry {
    #[serde(flatten)]
    pub providers: HashMap<String, ProviderInfo>,
}

/// One provider's identity and model catalog (keyed by model id).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderInfo {
    pub id: String,
    pub name: String,
    pub models: HashMap<String, ModelMetadata>,
}

/// Metadata for one model; optional fields are absent when the upstream
/// registry does not publish them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    pub id: String,
    pub name: String,
    pub cost: Option<ModelCost>,
    pub limit: Option<ModelLimit>,
    pub modalities: Option<ModelModalities>,
    pub tool_call: Option<bool>,
    pub reasoning: Option<bool>,
}

/// Per-token pricing; cache read/write rates are optional.
/// NOTE(review): units (per 1K vs per 1M tokens) are not visible here —
/// confirm against the cost-calculation code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCost {
    pub input: f64,
    pub output: f64,
    pub cache_read: Option<f64>,
    pub cache_write: Option<f64>,
}

/// Context window and maximum output size, in tokens.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelLimit {
    pub context: u32,
    pub output: u32,
}

/// Supported input/output modalities (e.g. "text", "image").
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelModalities {
    pub input: Vec<String>,
    pub output: Vec<String>,
}

/// A model entry paired with its provider ID, returned by listing/filtering methods.
#[derive(Debug, Clone)]
pub struct ModelEntry<'a> {
    pub model_key: &'a str,
    pub provider_id: &'a str,
    pub provider_name: &'a str,
    pub metadata: &'a ModelMetadata,
}

/// Filter criteria for listing models. All fields are optional; `None` means no filter.
#[derive(Debug, Default, Clone, Deserialize)]
pub struct ModelFilter {
    /// Filter by provider ID (exact match).
    pub provider: Option<String>,
    /// Text search on model ID or name (case-insensitive substring).
    pub search: Option<String>,
    /// Filter by input modality (e.g. "image", "text").
    pub modality: Option<String>,
    /// Only models that support tool calling.
    pub tool_call: Option<bool>,
    /// Only models that support reasoning.
    pub reasoning: Option<bool>,
    /// Only models that have pricing data.
    pub has_cost: Option<bool>,
}

/// Sort field for model listings.
#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ModelSortBy {
    #[default]
    Name,
    Id,
    Provider,
    ContextLimit,
    InputCost,
    OutputCost,
}

/// Sort direction.
#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum SortOrder {
    #[default]
    Asc,
    Desc,
}
|
||||
|
||||
impl ModelRegistry {
    /// Find a model by its ID (searching across all providers).
    ///
    /// Two passes: first the models-map keys (the fast, common case), then a
    /// scan of every entry's `metadata.id` in case the map key differs from
    /// the model's declared id. Returns `None` when no provider knows the id.
    /// Note: if multiple providers define the same id, which one is returned
    /// depends on HashMap iteration order.
    pub fn find_model(&self, model_id: &str) -> Option<&ModelMetadata> {
        // First try exact match if the key in models map matches the ID
        for provider in self.providers.values() {
            if let Some(model) = provider.models.get(model_id) {
                return Some(model);
            }
        }

        // Try searching for the model ID inside the metadata if the key was different
        for provider in self.providers.values() {
            for model in provider.models.values() {
                if model.id == model_id {
                    return Some(model);
                }
            }
        }

        None
    }

    /// List all models with optional filtering and sorting.
    ///
    /// Filters are conjunctive: an entry must pass every criterion present in
    /// `filter`. The result borrows from `self` (see `ModelEntry<'_>`).
    pub fn list_models(
        &self,
        filter: &ModelFilter,
        sort_by: &ModelSortBy,
        sort_order: &SortOrder,
    ) -> Vec<ModelEntry<'_>> {
        let mut entries: Vec<ModelEntry<'_>> = Vec::new();

        for (p_id, p_info) in &self.providers {
            // Provider filter
            if let Some(ref prov) = filter.provider {
                if p_id != prov {
                    continue;
                }
            }

            for (m_key, m_meta) in &p_info.models {
                // Text search filter — matches id, display name, or map key.
                if let Some(ref search) = filter.search {
                    let search_lower = search.to_lowercase();
                    if !m_meta.id.to_lowercase().contains(&search_lower)
                        && !m_meta.name.to_lowercase().contains(&search_lower)
                        && !m_key.to_lowercase().contains(&search_lower)
                    {
                        continue;
                    }
                }

                // Modality filter — checks input modalities only.
                if let Some(ref modality) = filter.modality {
                    let has_modality = m_meta
                        .modalities
                        .as_ref()
                        .is_some_and(|m| m.input.iter().any(|i| i.eq_ignore_ascii_case(modality)));
                    if !has_modality {
                        continue;
                    }
                }

                // Tool call filter — missing capability data counts as "false".
                if let Some(tc) = filter.tool_call {
                    if m_meta.tool_call.unwrap_or(false) != tc {
                        continue;
                    }
                }

                // Reasoning filter — missing capability data counts as "false".
                if let Some(r) = filter.reasoning {
                    if m_meta.reasoning.unwrap_or(false) != r {
                        continue;
                    }
                }

                // Has cost filter
                if let Some(hc) = filter.has_cost {
                    if hc != m_meta.cost.is_some() {
                        continue;
                    }
                }

                entries.push(ModelEntry {
                    model_key: m_key,
                    provider_id: p_id,
                    provider_name: &p_info.name,
                    metadata: m_meta,
                });
            }
        }

        // Sort — missing limits/costs sort as 0; float comparison falls back
        // to Equal for non-comparable values (NaN).
        entries.sort_by(|a, b| {
            let cmp = match sort_by {
                ModelSortBy::Name => a.metadata.name.to_lowercase().cmp(&b.metadata.name.to_lowercase()),
                ModelSortBy::Id => a.model_key.cmp(b.model_key),
                ModelSortBy::Provider => a.provider_id.cmp(b.provider_id),
                ModelSortBy::ContextLimit => {
                    let a_ctx = a.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0);
                    let b_ctx = b.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0);
                    a_ctx.cmp(&b_ctx)
                }
                ModelSortBy::InputCost => {
                    let a_cost = a.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
                    let b_cost = b.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
                    a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
                }
                ModelSortBy::OutputCost => {
                    let a_cost = a.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
                    let b_cost = b.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
                    a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
                }
            };

            match sort_order {
                SortOrder::Asc => cmp,
                SortOrder::Desc => cmp.reverse(),
            }
        });

        entries
    }
}
|
||||
@@ -1,299 +0,0 @@
|
||||
//! Multimodal support for image processing and conversion
|
||||
//!
|
||||
//! This module handles:
|
||||
//! 1. Image format detection and conversion
|
||||
//! 2. Base64 encoding/decoding
|
||||
//! 3. URL fetching for images
|
||||
//! 4. Provider-specific image format conversion
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use base64::{Engine as _, engine::general_purpose};
|
||||
use std::sync::LazyLock;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Shared HTTP client for image fetching — avoids creating a new TCP+TLS
/// connection for every image URL.
///
/// Timeouts: 5s to connect, 30s per request overall; idle pooled connections
/// are kept for 60s. Built lazily on first use.
static IMAGE_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
    reqwest::Client::builder()
        .connect_timeout(std::time::Duration::from_secs(5))
        .timeout(std::time::Duration::from_secs(30))
        .pool_idle_timeout(std::time::Duration::from_secs(60))
        .build()
        .expect("Failed to build image HTTP client")
});
|
||||
|
||||
/// Supported image formats for multimodal input.
///
/// A `Url` variant is un-hydrated — fetching and conversion happen in
/// `to_base64` / `hydrate_images`, not at construction time.
#[derive(Debug, Clone)]
pub enum ImageInput {
    /// Base64-encoded image data with MIME type
    Base64 { data: String, mime_type: String },
    /// URL to fetch image from
    Url(String),
    /// Raw bytes with MIME type
    Bytes { data: Vec<u8>, mime_type: String },
}
|
||||
|
||||
impl ImageInput {
    /// Create ImageInput from base64 string.
    pub fn from_base64(data: String, mime_type: String) -> Self {
        Self::Base64 { data, mime_type }
    }

    /// Create ImageInput from URL (not fetched until `to_base64`).
    pub fn from_url(url: String) -> Self {
        Self::Url(url)
    }

    /// Create ImageInput from raw bytes.
    pub fn from_bytes(data: Vec<u8>, mime_type: String) -> Self {
        Self::Bytes { data, mime_type }
    }

    /// Get MIME type if available; `None` for URLs (unknown until fetched).
    pub fn mime_type(&self) -> Option<&str> {
        match self {
            Self::Base64 { mime_type, .. } => Some(mime_type),
            Self::Bytes { mime_type, .. } => Some(mime_type),
            Self::Url(_) => None,
        }
    }

    /// Convert to base64 if not already.
    ///
    /// Returns `(base64_data, mime_type)`. For `Url`, performs a network
    /// fetch via the shared `IMAGE_CLIENT`; the MIME type comes from the
    /// response Content-Type header, defaulting to "image/jpeg" when absent.
    ///
    /// # Errors
    /// Fails on network errors, non-success HTTP status, or body read errors.
    pub async fn to_base64(&self) -> Result<(String, String)> {
        match self {
            Self::Base64 { data, mime_type } => Ok((data.clone(), mime_type.clone())),
            Self::Bytes { data, mime_type } => {
                let base64_data = general_purpose::STANDARD.encode(data);
                Ok((base64_data, mime_type.clone()))
            }
            Self::Url(url) => {
                // Fetch image from URL using shared client
                info!("Fetching image from URL: {}", url);
                let response = IMAGE_CLIENT
                    .get(url)
                    .send()
                    .await
                    .context("Failed to fetch image from URL")?;

                if !response.status().is_success() {
                    anyhow::bail!("Failed to fetch image: HTTP {}", response.status());
                }

                let mime_type = response
                    .headers()
                    .get(reqwest::header::CONTENT_TYPE)
                    .and_then(|h| h.to_str().ok())
                    .unwrap_or("image/jpeg")
                    .to_string();

                let bytes = response.bytes().await.context("Failed to read image bytes")?;

                let base64_data = general_purpose::STANDARD.encode(&bytes);
                Ok((base64_data, mime_type))
            }
        }
    }

    /// Get image dimensions (width, height) by decoding the image in memory.
    /// For `Url`, this fetches the image first (via `to_base64`).
    pub async fn get_dimensions(&self) -> Result<(u32, u32)> {
        let bytes = match self {
            Self::Base64 { data, .. } => general_purpose::STANDARD
                .decode(data)
                .context("Failed to decode base64")?,
            Self::Bytes { data, .. } => data.clone(),
            Self::Url(_) => {
                let (base64_data, _) = self.to_base64().await?;
                general_purpose::STANDARD
                    .decode(&base64_data)
                    .context("Failed to decode base64")?
            }
        };

        let img = image::load_from_memory(&bytes).context("Failed to load image from bytes")?;
        Ok((img.width(), img.height()))
    }

    /// Validate image size and format.
    ///
    /// Oversized dimensions (> 4096 px) only produce a warning; only the
    /// byte-size limit is a hard error. NOTE(review): for `Url` inputs this
    /// still fetches the whole image to measure dimensions, then skips the
    /// size check — consider checking size from the fetched bytes.
    pub async fn validate(&self, max_size_mb: f64) -> Result<()> {
        let (width, height) = self.get_dimensions().await?;

        // Check dimensions
        if width > 4096 || height > 4096 {
            warn!("Image dimensions too large: {}x{}", width, height);
            // Continue anyway, but log warning
        }

        // Check file size
        let size_bytes = match self {
            Self::Base64 { data, .. } => {
                // Base64 size is ~4/3 of original
                (data.len() as f64 * 0.75) as usize
            }
            Self::Bytes { data, .. } => data.len(),
            Self::Url(_) => {
                // For URLs, we'd need to fetch to check size
                // Skip size check for URLs for now
                return Ok(());
            }
        };

        let size_mb = size_bytes as f64 / (1024.0 * 1024.0);
        if size_mb > max_size_mb {
            anyhow::bail!("Image too large: {:.2}MB > {:.2}MB limit", size_mb, max_size_mb);
        }

        Ok(())
    }
}
|
||||
|
||||
/// Provider-specific image format conversion.
pub struct ImageConverter;

impl ImageConverter {
    /// Convert image to OpenAI-compatible format.
    /// Produces an `image_url` content part with a `data:` URL payload.
    pub async fn to_openai_format(image: &ImageInput) -> Result<serde_json::Value> {
        let (base64_data, mime_type) = image.to_base64().await?;

        // OpenAI expects data URL format: "data:image/jpeg;base64,{data}"
        let data_url = format!("data:{};base64,{}", mime_type, base64_data);

        Ok(serde_json::json!({
            "type": "image_url",
            "image_url": {
                "url": data_url,
                "detail": "auto" // Can be "low", "high", or "auto"
            }
        }))
    }

    /// Convert image to Gemini-compatible format (inline base64 data).
    pub async fn to_gemini_format(image: &ImageInput) -> Result<serde_json::Value> {
        let (base64_data, mime_type) = image.to_base64().await?;

        // Gemini expects inline data format
        Ok(serde_json::json!({
            "inline_data": {
                "mime_type": mime_type,
                "data": base64_data
            }
        }))
    }

    /// Convert image to DeepSeek-compatible format.
    pub async fn to_deepseek_format(image: &ImageInput) -> Result<serde_json::Value> {
        // DeepSeek uses OpenAI-compatible format for vision models
        Self::to_openai_format(image).await
    }

    /// Detect if a model supports multimodal input, by model-name heuristics
    /// only (no registry lookup).
    pub fn model_supports_multimodal(model: &str) -> bool {
        // OpenAI vision models
        if (model.starts_with("gpt-4") && (model.contains("vision") || model.contains("-v") || model.contains("4o")))
            || model.starts_with("o1-")
            || model.starts_with("o3-")
        {
            return true;
        }

        // Gemini vision models
        if model.starts_with("gemini") {
            // Most Gemini models support vision
            return true;
        }

        // DeepSeek vision models
        if model.starts_with("deepseek-vl") {
            return true;
        }

        false
    }
}
|
||||
|
||||
/// Parse OpenAI-compatible multimodal message content.
///
/// Accepts either a plain string or an array of typed parts. Returns a list
/// of `(text, optional_image)` tuples: text parts carry no image, image
/// parts carry an empty text. Malformed parts and unknown part types are
/// skipped (unknown types are logged), never an error.
pub fn parse_openai_content(content: &serde_json::Value) -> Result<Vec<(String, Option<ImageInput>)>> {
    let mut parts = Vec::new();

    if let Some(content_str) = content.as_str() {
        // Simple text content
        parts.push((content_str.to_string(), None));
    } else if let Some(content_array) = content.as_array() {
        // Array of content parts (text and/or images)
        for part in content_array {
            if let Some(part_obj) = part.as_object()
                && let Some(part_type) = part_obj.get("type").and_then(|t| t.as_str())
            {
                match part_type {
                    "text" => {
                        if let Some(text) = part_obj.get("text").and_then(|t| t.as_str()) {
                            parts.push((text.to_string(), None));
                        }
                    }
                    "image_url" => {
                        if let Some(image_url_obj) = part_obj.get("image_url").and_then(|o| o.as_object())
                            && let Some(url) = image_url_obj.get("url").and_then(|u| u.as_str())
                        {
                            if url.starts_with("data:") {
                                // Parse data URL into inline base64 input
                                if let Some((mime_type, data)) = parse_data_url(url) {
                                    let image_input = ImageInput::from_base64(data, mime_type);
                                    parts.push(("".to_string(), Some(image_input)));
                                }
                            } else {
                                // Regular URL — fetched later during hydration
                                let image_input = ImageInput::from_url(url.to_string());
                                parts.push(("".to_string(), Some(image_input)));
                            }
                        }
                    }
                    _ => {
                        warn!("Unknown content part type: {}", part_type);
                    }
                }
            }
        }
    }

    Ok(parts)
}
|
||||
|
||||
/// Split a base64 data URL ("data:{mime};base64,{payload}") into
/// `(mime_type, payload)`.
///
/// Returns `None` unless the string has the "data:" scheme and contains the
/// ";base64," separator exactly once.
fn parse_data_url(data_url: &str) -> Option<(String, String)> {
    let rest = data_url.strip_prefix("data:")?;
    let pieces: Vec<&str> = rest.split(";base64,").collect();
    match pieces.as_slice() {
        [mime, payload] => Some(((*mime).to_string(), (*payload).to_string())),
        _ => None,
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Neither test awaits anything, so plain #[test] is used instead of
    // #[tokio::test]; this avoids spinning up a Tokio runtime per test.
    #[test]
    fn test_parse_data_url() {
        let test_url = "data:image/jpeg;base64,SGVsbG8gV29ybGQ="; // "Hello World" in base64
        let (mime_type, data) = parse_data_url(test_url).unwrap();

        assert_eq!(mime_type, "image/jpeg");
        assert_eq!(data, "SGVsbG8gV29ybGQ=");
    }

    #[test]
    fn test_model_supports_multimodal() {
        assert!(ImageConverter::model_supports_multimodal("gpt-4-vision-preview"));
        assert!(ImageConverter::model_supports_multimodal("gpt-4o"));
        assert!(ImageConverter::model_supports_multimodal("gemini-pro-vision"));
        assert!(ImageConverter::model_supports_multimodal("gemini-pro"));
        assert!(!ImageConverter::model_supports_multimodal("gpt-3.5-turbo"));
        assert!(!ImageConverter::model_supports_multimodal("claude-3-opus"));
    }
}
|
||||
@@ -1,129 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
|
||||
use super::helpers;
|
||||
use super::{ProviderResponse, ProviderStreamChunk};
|
||||
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
|
||||
|
||||
/// Provider backend for DeepSeek's OpenAI-compatible chat API.
pub struct DeepSeekProvider {
    // Shared HTTP client (connection pooling, timeouts set at construction).
    client: reqwest::Client,
    // DeepSeek-specific settings (e.g. base_url) cloned from app config.
    config: crate::config::DeepSeekConfig,
    // Bearer token sent in the Authorization header.
    api_key: String,
    // Per-model price table used by calculate_cost.
    pricing: Vec<crate::config::ModelPricing>,
}
|
||||
|
||||
impl DeepSeekProvider {
|
||||
pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result<Self> {
|
||||
let api_key = app_config.get_api_key("deepseek")?;
|
||||
Self::new_with_key(config, app_config, api_key)
|
||||
}
|
||||
|
||||
pub fn new_with_key(
|
||||
config: &crate::config::DeepSeekConfig,
|
||||
app_config: &AppConfig,
|
||||
api_key: String,
|
||||
) -> Result<Self> {
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(std::time::Duration::from_secs(5))
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
||||
.pool_max_idle_per_host(4)
|
||||
.tcp_keepalive(std::time::Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
config: config.clone(),
|
||||
api_key,
|
||||
pricing: app_config.pricing.deepseek.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Wires DeepSeek into the unified provider interface. The wire protocol is
// OpenAI-compatible, so request building and response parsing are delegated
// to the shared `helpers` module.
#[async_trait]
impl super::Provider for DeepSeekProvider {
    fn name(&self) -> &str {
        "deepseek"
    }

    // Matches both canonical "deepseek-*" names and any alias containing
    // "deepseek".
    fn supports_model(&self, model: &str) -> bool {
        model.starts_with("deepseek-") || model.contains("deepseek")
    }

    fn supports_multimodal(&self) -> bool {
        false
    }

    // Non-streaming chat completion via the OpenAI-compatible
    // /chat/completions endpoint.
    async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
        let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
        let body = helpers::build_openai_body(&request, messages_json, false);

        let response = self
            .client
            .post(format!("{}/chat/completions", self.config.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&body)
            .send()
            .await
            .map_err(|e| AppError::ProviderError(e.to_string()))?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(AppError::ProviderError(format!("DeepSeek API error: {}", error_text)));
        }

        let resp_json: serde_json::Value = response
            .json()
            .await
            .map_err(|e| AppError::ProviderError(e.to_string()))?;

        helpers::parse_openai_response(&resp_json, request.model)
    }

    fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
        Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
    }

    // Registry/pricing-table lookup with hard-coded fallback prices.
    // NOTE(review): 0.14 / 0.28 look like fallback prompt/completion prices
    // (presumably USD per 1M tokens) — confirm against the pricing config.
    fn calculate_cost(
        &self,
        model: &str,
        prompt_tokens: u32,
        completion_tokens: u32,
        cache_read_tokens: u32,
        cache_write_tokens: u32,
        registry: &crate::models::registry::ModelRegistry,
    ) -> f64 {
        helpers::calculate_cost_with_registry(
            model,
            prompt_tokens,
            completion_tokens,
            cache_read_tokens,
            cache_write_tokens,
            registry,
            &self.pricing,
            0.14,
            0.28,
        )
    }

    // SSE streaming via the same endpoint with stream=true.
    async fn chat_completion_stream(
        &self,
        request: UnifiedRequest,
    ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
        // DeepSeek doesn't support images in streaming, use text-only
        let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?;
        let mut body = helpers::build_openai_body(&request, messages_json, true);
        // NOTE(review): "stream_options" is stripped before sending —
        // presumably DeepSeek rejects this OpenAI extension; confirm.
        body.as_object_mut().expect("body is object").remove("stream_options");

        let es = reqwest_eventsource::EventSource::new(
            self.client
                .post(format!("{}/chat/completions", self.config.base_url))
                .header("Authorization", format!("Bearer {}", self.api_key))
                .json(&body),
        )
        .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;

        Ok(helpers::create_openai_stream(es, request.model, None))
    }
}
|
||||
@@ -1,836 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{ProviderResponse, ProviderStreamChunk};
|
||||
use crate::{
|
||||
config::AppConfig,
|
||||
errors::AppError,
|
||||
models::{ContentPart, FunctionCall, FunctionCallDelta, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest},
|
||||
};
|
||||
|
||||
// ========== Gemini Request Structs ==========
|
||||
|
||||
/// Request body for Gemini `generateContent` / `streamGenerateContent`.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiRequest {
    /// Conversation turns, in order.
    contents: Vec<GeminiContent>,
    /// System prompt, sent separately from the turn list.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GeminiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GeminiGenerationConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GeminiTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_config: Option<GeminiToolConfig>,
}
|
||||
|
||||
/// A single conversation turn: one or more parts plus an optional role
/// ("user"/"model"; `None` for system instructions).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiContent {
    parts: Vec<GeminiPart>,
    #[serde(skip_serializing_if = "Option::is_none")]
    role: Option<String>,
}
|
||||
|
||||
/// One part of a turn; exactly one of the optional fields is expected to be
/// populated (text, inline image, function call, or function response).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiPart {
    #[serde(skip_serializing_if = "Option::is_none")]
    text: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    inline_data: Option<GeminiInlineData>,
    #[serde(skip_serializing_if = "Option::is_none")]
    function_call: Option<GeminiFunctionCall>,
    #[serde(skip_serializing_if = "Option::is_none")]
    function_response: Option<GeminiFunctionResponse>,
}
|
||||
|
||||
/// Inline (base64-encoded) binary payload, e.g. an image.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiInlineData {
    mime_type: String,
    /// Base64-encoded bytes.
    data: String,
}
|
||||
|
||||
/// Model-emitted function call: name plus structured JSON arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiFunctionCall {
    name: String,
    args: Value,
}
|
||||
|
||||
/// Tool execution result returned to the model for a prior function call.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GeminiFunctionResponse {
    name: String,
    response: Value,
}
|
||||
|
||||
|
||||
/// Sampling/limit knobs forwarded from the unified request.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiGenerationConfig {
    temperature: Option<f64>,
    max_output_tokens: Option<u32>,
}
|
||||
|
||||
// ========== Gemini Tool Structs ==========
|
||||
|
||||
/// Tool bundle: Gemini groups all function declarations under one tool entry.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiTool {
    function_declarations: Vec<GeminiFunctionDeclaration>,
}
|
||||
|
||||
/// One callable function exposed to the model (JSON-schema parameters).
#[derive(Debug, Clone, Serialize)]
struct GeminiFunctionDeclaration {
    name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    parameters: Option<Value>,
}
|
||||
|
||||
/// Wrapper for the function-calling mode configuration.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiToolConfig {
    function_calling_config: GeminiFunctionCallingConfig,
}
|
||||
|
||||
/// Function-calling mode ("AUTO"/"NONE"/"ANY") and optional allow-list of
/// callable function names.
#[derive(Debug, Clone, Serialize)]
struct GeminiFunctionCallingConfig {
    mode: String,
    #[serde(skip_serializing_if = "Option::is_none", rename = "allowedFunctionNames")]
    allowed_function_names: Option<Vec<String>>,
}
|
||||
|
||||
// ========== Gemini Response Structs ==========
|
||||
|
||||
/// One response candidate from a non-streaming generateContent call.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiCandidate {
    content: GeminiContent,
    #[serde(default)]
    #[allow(dead_code)]
    finish_reason: Option<String>,
}
|
||||
|
||||
/// Token accounting reported by Gemini; all fields default to 0 when absent.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiUsageMetadata {
    #[serde(default)]
    prompt_token_count: u32,
    #[serde(default)]
    candidates_token_count: u32,
    #[serde(default)]
    total_token_count: u32,
    #[serde(default)]
    cached_content_token_count: u32,
}
|
||||
|
||||
/// Full (non-streaming) generateContent response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiResponse {
    candidates: Vec<GeminiCandidate>,
    usage_metadata: Option<GeminiUsageMetadata>,
}
|
||||
|
||||
// Streaming responses from Gemini may include messages without `candidates` (e.g. promptFeedback).
|
||||
// Use a more permissive struct for streaming to avoid aborting the SSE prematurely.
|
||||
/// Streaming chunk; everything is optional/defaulted because stream events
/// may carry only promptFeedback or usage metadata with no candidates.
#[derive(Debug, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
struct GeminiStreamResponse {
    #[serde(default)]
    candidates: Vec<GeminiStreamCandidate>,
    #[serde(default)]
    usage_metadata: Option<GeminiUsageMetadata>,
}
|
||||
|
||||
/// Candidate inside a streaming chunk; content may be absent (e.g. a
/// finish-reason-only event).
#[derive(Debug, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
struct GeminiStreamCandidate {
    #[serde(default)]
    content: Option<GeminiContent>,
    #[serde(default)]
    finish_reason: Option<String>,
}
|
||||
|
||||
// ========== Provider Implementation ==========
|
||||
|
||||
/// Provider backend for Google's Gemini generateContent API.
pub struct GeminiProvider {
    // Shared HTTP client (pooling and timeouts set at construction).
    client: reqwest::Client,
    // Gemini-specific settings (base_url, default_model) from app config.
    config: crate::config::GeminiConfig,
    // API key sent via the x-goog-api-key header.
    api_key: String,
    // Per-model price table used by calculate_cost.
    pricing: Vec<crate::config::ModelPricing>,
}
|
||||
|
||||
impl GeminiProvider {
    /// Build a provider, resolving the API key from the app configuration.
    pub fn new(config: &crate::config::GeminiConfig, app_config: &AppConfig) -> Result<Self> {
        let api_key = app_config.get_api_key("gemini")?;
        Self::new_with_key(config, app_config, api_key)
    }

    /// Build a provider with an explicitly supplied API key.
    pub fn new_with_key(config: &crate::config::GeminiConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
        // HTTP client tuned for long-running completion requests over a
        // small, kept-alive connection pool.
        let client = reqwest::Client::builder()
            .connect_timeout(std::time::Duration::from_secs(5))
            .timeout(std::time::Duration::from_secs(300))
            .pool_idle_timeout(std::time::Duration::from_secs(90))
            .pool_max_idle_per_host(4)
            .tcp_keepalive(std::time::Duration::from_secs(30))
            .build()?;

        Ok(Self {
            client,
            config: config.clone(),
            api_key,
            pricing: app_config.pricing.gemini.clone(),
        })
    }

    /// Convert unified messages to Gemini content format.
    /// Handles text, images, tool calls (assistant), and tool results.
    /// Returns (contents, system_instruction)
    ///
    /// Additional invariants enforced here: consecutive same-role turns are
    /// merged, empty turns are dropped, and the first turn is forced to be a
    /// "user" turn.
    async fn convert_messages(
        messages: Vec<UnifiedMessage>,
    ) -> Result<(Vec<GeminiContent>, Option<GeminiContent>), AppError> {
        let mut contents: Vec<GeminiContent> = Vec::new();
        let mut system_parts = Vec::new();

        for msg in messages {
            // System messages are collected separately and emitted as a
            // single system_instruction (text parts only; empty text dropped).
            if msg.role == "system" {
                for part in msg.content {
                    if let ContentPart::Text { text } = part {
                        if !text.trim().is_empty() {
                            system_parts.push(GeminiPart {
                                text: Some(text),
                                inline_data: None,
                                function_call: None,
                                function_response: None,
                            });
                        }
                    }
                }
                continue;
            }

            let role = match msg.role.as_str() {
                "assistant" => "model".to_string(),
                "tool" => "user".to_string(), // Tool results are technically from the user side in Gemini
                _ => "user".to_string(),
            };

            let mut parts = Vec::new();

            // Handle tool results (role "tool")
            if msg.role == "tool" {
                // Only the first content part is used as the tool output;
                // image results are replaced by a "[Image]" placeholder.
                let text_content = msg
                    .content
                    .first()
                    .map(|p| match p {
                        ContentPart::Text { text } => text.clone(),
                        ContentPart::Image(_) => "[Image]".to_string(),
                    })
                    .unwrap_or_default();

                // Gemini function response MUST have a name. Fallback to tool_call_id if name is missing.
                let name = msg.name.clone().or_else(|| msg.tool_call_id.clone()).unwrap_or_else(|| "unknown_function".to_string());
                // Non-JSON tool output is wrapped as {"result": <text>}.
                let response_value = serde_json::from_str::<Value>(&text_content)
                    .unwrap_or_else(|_| serde_json::json!({ "result": text_content }));

                parts.push(GeminiPart {
                    text: None,
                    inline_data: None,
                    function_call: None,
                    function_response: Some(GeminiFunctionResponse {
                        name,
                        response: response_value,
                    }),
                });
            } else if msg.role == "assistant" && msg.tool_calls.is_some() {
                // Assistant messages with tool_calls
                if let Some(tool_calls) = &msg.tool_calls {
                    // Include text content if present
                    for p in &msg.content {
                        if let ContentPart::Text { text } = p {
                            if !text.trim().is_empty() {
                                parts.push(GeminiPart {
                                    text: Some(text.clone()),
                                    inline_data: None,
                                    function_call: None,
                                    function_response: None,
                                });
                            }
                        }
                    }

                    for tc in tool_calls {
                        // Malformed argument JSON degrades to an empty object.
                        let args = serde_json::from_str::<Value>(&tc.function.arguments)
                            .unwrap_or_else(|_| serde_json::json!({}));
                        parts.push(GeminiPart {
                            text: None,
                            inline_data: None,
                            function_call: Some(GeminiFunctionCall {
                                name: tc.function.name.clone(),
                                args,
                            }),
                            function_response: None,
                        });
                    }
                }
            } else {
                // Regular text/image messages
                for part in msg.content {
                    match part {
                        ContentPart::Text { text } => {
                            if !text.trim().is_empty() {
                                parts.push(GeminiPart {
                                    text: Some(text),
                                    inline_data: None,
                                    function_call: None,
                                    function_response: None,
                                });
                            }
                        }
                        ContentPart::Image(image_input) => {
                            let (base64_data, mime_type) = image_input
                                .to_base64()
                                .await
                                .map_err(|e| AppError::ProviderError(format!("Failed to convert image: {}", e)))?;

                            parts.push(GeminiPart {
                                text: None,
                                inline_data: Some(GeminiInlineData {
                                    mime_type,
                                    data: base64_data,
                                }),
                                function_call: None,
                                function_response: None,
                            });
                        }
                    }
                }
            }

            if parts.is_empty() {
                continue;
            }

            // Merge with previous message if role matches
            if let Some(last_content) = contents.last_mut() {
                if last_content.role.as_ref() == Some(&role) {
                    last_content.parts.extend(parts);
                    continue;
                }
            }

            contents.push(GeminiContent {
                parts,
                role: Some(role),
            });
        }

        // Gemini requires the first message to be from "user".
        // If it starts with "model", we prepend a placeholder user message.
        if let Some(first) = contents.first() {
            if first.role.as_deref() == Some("model") {
                contents.insert(0, GeminiContent {
                    role: Some("user".to_string()),
                    parts: vec![GeminiPart {
                        text: Some("Continue conversation.".to_string()),
                        inline_data: None,
                        function_call: None,
                        function_response: None,
                    }],
                });
            }
        }

        let system_instruction = if !system_parts.is_empty() {
            Some(GeminiContent {
                parts: system_parts,
                role: None,
            })
        } else {
            None
        };

        Ok((contents, system_instruction))
    }

    /// Convert OpenAI tools to Gemini function declarations.
    fn convert_tools(request: &UnifiedRequest) -> Option<Vec<GeminiTool>> {
        request.tools.as_ref().map(|tools| {
            let declarations: Vec<GeminiFunctionDeclaration> = tools
                .iter()
                .map(|t| GeminiFunctionDeclaration {
                    name: t.function.name.clone(),
                    description: t.function.description.clone(),
                    parameters: t.function.parameters.clone(),
                })
                .collect();
            // Gemini expects a single tool entry holding all declarations.
            vec![GeminiTool {
                function_declarations: declarations,
            }]
        })
    }

    /// Convert OpenAI tool_choice to Gemini tool_config.
    fn convert_tool_config(request: &UnifiedRequest) -> Option<GeminiToolConfig> {
        request.tool_choice.as_ref().map(|tc| {
            let (mode, allowed_names) = match tc {
                crate::models::ToolChoice::Mode(mode) => {
                    // OpenAI "required" maps to Gemini "ANY"; unknown modes
                    // fall back to "AUTO".
                    let gemini_mode = match mode.as_str() {
                        "auto" => "AUTO",
                        "none" => "NONE",
                        "required" => "ANY",
                        _ => "AUTO",
                    };
                    (gemini_mode.to_string(), None)
                }
                crate::models::ToolChoice::Specific(specific) => {
                    // Forcing a specific function: ANY mode restricted to
                    // that single name.
                    ("ANY".to_string(), Some(vec![specific.function.name.clone()]))
                }
            };
            GeminiToolConfig {
                function_calling_config: GeminiFunctionCallingConfig {
                    mode,
                    allowed_function_names: allowed_names,
                },
            }
        })
    }

    /// Extract tool calls from Gemini response parts into OpenAI-format ToolCalls.
    /// Gemini does not supply call IDs, so fresh "call_<uuid>" IDs are minted.
    fn extract_tool_calls(parts: &[GeminiPart]) -> Option<Vec<ToolCall>> {
        let calls: Vec<ToolCall> = parts
            .iter()
            .filter_map(|p| p.function_call.as_ref())
            .map(|fc| ToolCall {
                id: format!("call_{}", Uuid::new_v4().simple()),
                call_type: "function".to_string(),
                function: FunctionCall {
                    name: fc.name.clone(),
                    arguments: serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string()),
                },
            })
            .collect();

        if calls.is_empty() { None } else { Some(calls) }
    }

    /// Extract tool call deltas from Gemini response parts for streaming.
    fn extract_tool_call_deltas(parts: &[GeminiPart]) -> Option<Vec<ToolCallDelta>> {
        let deltas: Vec<ToolCallDelta> = parts
            .iter()
            .filter_map(|p| p.function_call.as_ref())
            .enumerate()
            .map(|(i, fc)| ToolCallDelta {
                index: i as u32,
                id: Some(format!("call_{}", Uuid::new_v4().simple())),
                call_type: Some("function".to_string()),
                function: Some(FunctionCallDelta {
                    name: Some(fc.name.clone()),
                    arguments: Some(serde_json::to_string(&fc.args).unwrap_or_else(|_| "{}".to_string())),
                }),
            })
            .collect();

        if deltas.is_empty() { None } else { Some(deltas) }
    }

    /// Determine the appropriate base URL for the model.
    /// "preview" models often require the v1beta endpoint, but newer promoted ones may be on v1.
    fn get_base_url(&self, model: &str) -> String {
        // Only use v1beta for older preview models or specific "thinking" experimental models.
        // Gemini 3.0+ models are typically released on v1 even in preview.
        // NOTE(review): if config.base_url already ends in "/v1beta", this
        // replace would yield "/v1betabeta" — confirm config always uses /v1.
        if (model.contains("preview") && !model.contains("gemini-3")) || model.contains("thinking") {
            self.config.base_url.replace("/v1", "/v1beta")
        } else {
            self.config.base_url.clone()
        }
    }
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::Provider for GeminiProvider {
|
||||
// Stable identifier used for routing, logging, and key lookup.
fn name(&self) -> &str {
    "gemini"
}
|
||||
|
||||
// Claims every model whose name starts with "gemini-".
fn supports_model(&self, model: &str) -> bool {
    model.starts_with("gemini-")
}
|
||||
|
||||
fn supports_multimodal(&self) -> bool {
    true // Gemini supports vision
}
|
||||
|
||||
// Non-streaming completion: normalizes the model name, converts the unified
// request to Gemini's schema, calls generateContent, and maps the first
// candidate back into a ProviderResponse.
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
    let mut model = request.model.clone();

    // Normalize model name: If it's a known Gemini model version, use it;
    // otherwise, if it starts with gemini- but is an unknown legacy version,
    // fallback to the default model to avoid 400 errors.
    // We now allow gemini-3+ as valid versions.
    let is_known_version = model.starts_with("gemini-1.5") ||
        model.starts_with("gemini-2.0") ||
        model.starts_with("gemini-2.5") ||
        model.starts_with("gemini-3");

    if !is_known_version && model.starts_with("gemini-") {
        tracing::info!("Mapping unknown Gemini model {} to default {}", model, self.config.default_model);
        model = self.config.default_model.clone();
    }

    let tools = Self::convert_tools(&request);
    let tool_config = Self::convert_tool_config(&request);
    let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;

    if contents.is_empty() {
        return Err(AppError::ProviderError("No valid messages to send".to_string()));
    }

    // Only send a generationConfig when the caller set something.
    let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
        // Some Gemini models (especially 1.5) have lower max_output_tokens limits (e.g. 8192)
        // than what clients like opencode might request. Clamp to a safe maximum.
        // Note: Gemini 2.0+ supports much higher limits, but 8192 is a safe universal floor.
        let max_tokens = request.max_tokens.map(|t| t.min(8192));

        Some(GeminiGenerationConfig {
            temperature: request.temperature,
            max_output_tokens: max_tokens,
        })
    } else {
        None
    };

    let gemini_request = GeminiRequest {
        contents,
        system_instruction,
        generation_config,
        tools,
        tool_config,
    };

    let base_url = self.get_base_url(&model);
    let url = format!("{}/models/{}:generateContent", base_url, model);
    tracing::debug!("Calling Gemini API: {}", url);

    let response = self
        .client
        .post(&url)
        .header("x-goog-api-key", &self.api_key)
        .json(&gemini_request)
        .send()
        .await
        .map_err(|e| AppError::ProviderError(format!("HTTP request failed: {}", e)))?;

    let status = response.status();
    if !status.is_success() {
        let error_text = response.text().await.unwrap_or_default();
        return Err(AppError::ProviderError(format!(
            "Gemini API error ({}): {}",
            status, error_text
        )));
    }

    let gemini_response: GeminiResponse = response
        .json()
        .await
        .map_err(|e| AppError::ProviderError(format!("Failed to parse response: {}", e)))?;

    // Only the first candidate is considered.
    let candidate = gemini_response.candidates.first();

    // Extract text content (may be absent if only function calls)
    // NOTE(review): find_map takes only the FIRST text part — presumably
    // Gemini returns one text part per candidate; confirm for multi-part
    // responses.
    let content = candidate
        .and_then(|c| c.content.parts.iter().find_map(|p| p.text.clone()))
        .unwrap_or_default();

    // Extract function calls → OpenAI tool_calls
    let tool_calls = candidate.and_then(|c| Self::extract_tool_calls(&c.content.parts));

    // Usage counters default to 0 when usageMetadata is missing.
    let prompt_tokens = gemini_response
        .usage_metadata
        .as_ref()
        .map(|u| u.prompt_token_count)
        .unwrap_or(0);
    let completion_tokens = gemini_response
        .usage_metadata
        .as_ref()
        .map(|u| u.candidates_token_count)
        .unwrap_or(0);
    let total_tokens = gemini_response
        .usage_metadata
        .as_ref()
        .map(|u| u.total_token_count)
        .unwrap_or(0);
    let cache_read_tokens = gemini_response
        .usage_metadata
        .as_ref()
        .map(|u| u.cached_content_token_count)
        .unwrap_or(0);

    Ok(ProviderResponse {
        content,
        reasoning_content: None,
        tool_calls,
        prompt_tokens,
        completion_tokens,
        total_tokens,
        cache_read_tokens,
        cache_write_tokens: 0, // Gemini doesn't report cache writes separately
        model,
    })
}
|
||||
|
||||
// Delegates to the shared heuristic token estimator.
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
    Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
|
||||
|
||||
// Registry/pricing-table lookup with hard-coded fallback prices.
// NOTE(review): 0.075 / 0.30 look like fallback prompt/completion prices
// (presumably USD per 1M tokens) — confirm against the pricing config.
fn calculate_cost(
    &self,
    model: &str,
    prompt_tokens: u32,
    completion_tokens: u32,
    cache_read_tokens: u32,
    cache_write_tokens: u32,
    registry: &crate::models::registry::ModelRegistry,
) -> f64 {
    super::helpers::calculate_cost_with_registry(
        model,
        prompt_tokens,
        completion_tokens,
        cache_read_tokens,
        cache_write_tokens,
        registry,
        &self.pricing,
        0.075,
        0.30,
    )
}
|
||||
|
||||
async fn chat_completion_stream(
|
||||
&self,
|
||||
request: UnifiedRequest,
|
||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||
let mut model = request.model.clone();
|
||||
|
||||
// Normalize model name: fallback to default if unknown Gemini model is requested
|
||||
let is_known_version = model.starts_with("gemini-1.5") ||
|
||||
model.starts_with("gemini-2.0") ||
|
||||
model.starts_with("gemini-2.5") ||
|
||||
model.starts_with("gemini-3");
|
||||
|
||||
if !is_known_version && model.starts_with("gemini-") {
|
||||
tracing::info!("Mapping unknown Gemini model {} to default {}", model, self.config.default_model);
|
||||
model = self.config.default_model.clone();
|
||||
}
|
||||
|
||||
let tools = Self::convert_tools(&request);
|
||||
let tool_config = Self::convert_tool_config(&request);
|
||||
let (contents, system_instruction) = Self::convert_messages(request.messages.clone()).await?;
|
||||
|
||||
if contents.is_empty() {
|
||||
return Err(AppError::ProviderError("No valid messages to send".to_string()));
|
||||
}
|
||||
|
||||
let generation_config = if request.temperature.is_some() || request.max_tokens.is_some() {
|
||||
// Some Gemini models (especially 1.5) have lower max_output_tokens limits (e.g. 8192)
|
||||
// than what clients like opencode might request. Clamp to a safe maximum.
|
||||
let max_tokens = request.max_tokens.map(|t| t.min(8192));
|
||||
|
||||
Some(GeminiGenerationConfig {
|
||||
temperature: request.temperature,
|
||||
max_output_tokens: max_tokens,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let gemini_request = GeminiRequest {
|
||||
contents,
|
||||
system_instruction,
|
||||
generation_config,
|
||||
tools,
|
||||
tool_config,
|
||||
};
|
||||
|
||||
let base_url = self.get_base_url(&model);
|
||||
let url = format!(
|
||||
"{}/models/{}:streamGenerateContent?alt=sse",
|
||||
base_url, model,
|
||||
);
|
||||
tracing::debug!("Calling Gemini Stream API: {}", url);
|
||||
|
||||
// (no fallback_request needed here)
|
||||
|
||||
use futures::StreamExt;
|
||||
use reqwest_eventsource::Event;
|
||||
|
||||
// Try to create an SSE event source for streaming. If creation fails
|
||||
// (provider doesn't support streaming for this model or returned a
|
||||
// non-2xx response), fall back to a synchronous generateContent call
|
||||
// and emit a single chunk.
|
||||
// Prepare clones for HTTP fallback usage inside non-streaming paths.
|
||||
let http_client = self.client.clone();
|
||||
let http_api_key = self.api_key.clone();
|
||||
let http_base = base_url.clone();
|
||||
let gemini_request_clone = gemini_request.clone();
|
||||
|
||||
let es_result = reqwest_eventsource::EventSource::new(
|
||||
self.client
|
||||
.post(&url)
|
||||
.header("x-goog-api-key", &self.api_key)
|
||||
.json(&gemini_request),
|
||||
);
|
||||
|
||||
if let Err(_e) = es_result {
|
||||
// Fallback: call non-streaming generateContent via HTTP and convert to a single-stream chunk
|
||||
let resp_http = http_client
|
||||
.post(format!("{}/models/{}:generateContent", http_base, model))
|
||||
.header("x-goog-api-key", &http_api_key)
|
||||
.json(&gemini_request_clone)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e2| AppError::ProviderError(format!("Failed to call generateContent fallback: {}", e2)))?;
|
||||
|
||||
if !resp_http.status().is_success() {
|
||||
let status = resp_http.status();
|
||||
let err = resp_http.text().await.unwrap_or_default();
|
||||
return Err(AppError::ProviderError(format!("Gemini API error ({}): {}", status, err)));
|
||||
}
|
||||
|
||||
let gemini_response: GeminiResponse = resp_http
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e2| AppError::ProviderError(format!("Failed to parse generateContent response: {}", e2)))?;
|
||||
|
||||
let candidate = gemini_response.candidates.first();
|
||||
let content = candidate
|
||||
.and_then(|c| c.content.parts.iter().find_map(|p| p.text.clone()))
|
||||
.unwrap_or_default();
|
||||
|
||||
let prompt_tokens = gemini_response
|
||||
.usage_metadata
|
||||
.as_ref()
|
||||
.map(|u| u.prompt_token_count)
|
||||
.unwrap_or(0);
|
||||
let completion_tokens = gemini_response
|
||||
.usage_metadata
|
||||
.as_ref()
|
||||
.map(|u| u.candidates_token_count)
|
||||
.unwrap_or(0);
|
||||
let total_tokens = gemini_response
|
||||
.usage_metadata
|
||||
.as_ref()
|
||||
.map(|u| u.total_token_count)
|
||||
.unwrap_or(0);
|
||||
|
||||
let single_stream = async_stream::try_stream! {
|
||||
let chunk = ProviderStreamChunk {
|
||||
content,
|
||||
reasoning_content: None,
|
||||
finish_reason: Some("stop".to_string()),
|
||||
tool_calls: None,
|
||||
model: model.clone(),
|
||||
usage: Some(super::StreamUsage {
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
cache_read_tokens: gemini_response.usage_metadata.as_ref().map(|u| u.cached_content_token_count).unwrap_or(0),
|
||||
cache_write_tokens: 0,
|
||||
}),
|
||||
};
|
||||
|
||||
yield chunk;
|
||||
};
|
||||
|
||||
return Ok(Box::pin(single_stream));
|
||||
}
|
||||
|
||||
let es = es_result.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||
|
||||
let stream = async_stream::try_stream! {
|
||||
let mut es = es;
|
||||
while let Some(event) = es.next().await {
|
||||
match event {
|
||||
Ok(Event::Message(msg)) => {
|
||||
let gemini_response: GeminiStreamResponse = serde_json::from_str(&msg.data)
|
||||
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
||||
|
||||
// Extract usage from usageMetadata if present (reported on every/last chunk)
|
||||
let stream_usage = gemini_response.usage_metadata.as_ref().map(|u| {
|
||||
super::StreamUsage {
|
||||
prompt_tokens: u.prompt_token_count,
|
||||
completion_tokens: u.candidates_token_count,
|
||||
total_tokens: u.total_token_count,
|
||||
cache_read_tokens: u.cached_content_token_count,
|
||||
cache_write_tokens: 0,
|
||||
}
|
||||
});
|
||||
|
||||
// Some streaming events may not contain candidates (e.g. promptFeedback).
|
||||
// Only emit chunks when we have candidate content or tool calls.
|
||||
if let Some(candidate) = gemini_response.candidates.first() {
|
||||
if let Some(content_obj) = &candidate.content {
|
||||
let content = content_obj
|
||||
.parts
|
||||
.iter()
|
||||
.find_map(|p| p.text.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
let tool_calls = Self::extract_tool_call_deltas(&content_obj.parts);
|
||||
|
||||
// Determine finish_reason
|
||||
let finish_reason = candidate.finish_reason.as_ref().map(|fr| {
|
||||
match fr.as_str() {
|
||||
"STOP" => "stop".to_string(),
|
||||
_ => fr.to_lowercase(),
|
||||
}
|
||||
});
|
||||
|
||||
// Avoid emitting completely empty chunks unless they carry usage.
|
||||
if !content.is_empty() || tool_calls.is_some() || stream_usage.is_some() {
|
||||
yield ProviderStreamChunk {
|
||||
content,
|
||||
reasoning_content: None,
|
||||
finish_reason,
|
||||
tool_calls,
|
||||
model: model.clone(),
|
||||
usage: stream_usage,
|
||||
};
|
||||
}
|
||||
} else if stream_usage.is_some() {
|
||||
// Usage-only update
|
||||
yield ProviderStreamChunk {
|
||||
content: String::new(),
|
||||
reasoning_content: None,
|
||||
finish_reason: None,
|
||||
tool_calls: None,
|
||||
model: model.clone(),
|
||||
usage: stream_usage,
|
||||
};
|
||||
}
|
||||
} else if stream_usage.is_some() {
|
||||
// No candidates but usage present
|
||||
yield ProviderStreamChunk {
|
||||
content: String::new(),
|
||||
reasoning_content: None,
|
||||
finish_reason: None,
|
||||
tool_calls: None,
|
||||
model: model.clone(),
|
||||
usage: stream_usage,
|
||||
};
|
||||
}
|
||||
}
|
||||
Ok(_) => continue,
|
||||
Err(e) => {
|
||||
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
}
|
||||
@@ -1,123 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
|
||||
use super::helpers;
|
||||
use super::{ProviderResponse, ProviderStreamChunk};
|
||||
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
|
||||
|
||||
pub struct GrokProvider {
|
||||
client: reqwest::Client,
|
||||
config: crate::config::GrokConfig,
|
||||
api_key: String,
|
||||
pricing: Vec<crate::config::ModelPricing>,
|
||||
}
|
||||
|
||||
impl GrokProvider {
|
||||
pub fn new(config: &crate::config::GrokConfig, app_config: &AppConfig) -> Result<Self> {
|
||||
let api_key = app_config.get_api_key("grok")?;
|
||||
Self::new_with_key(config, app_config, api_key)
|
||||
}
|
||||
|
||||
pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(std::time::Duration::from_secs(5))
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
||||
.pool_max_idle_per_host(4)
|
||||
.tcp_keepalive(std::time::Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
config: config.clone(),
|
||||
api_key,
|
||||
pricing: app_config.pricing.grok.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::Provider for GrokProvider {
|
||||
fn name(&self) -> &str {
|
||||
"grok"
|
||||
}
|
||||
|
||||
fn supports_model(&self, model: &str) -> bool {
|
||||
model.starts_with("grok-")
|
||||
}
|
||||
|
||||
fn supports_multimodal(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
||||
let body = helpers::build_openai_body(&request, messages_json, false);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(format!("{}/chat/completions", self.config.base_url))
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
return Err(AppError::ProviderError(format!("Grok API error: {}", error_text)));
|
||||
}
|
||||
|
||||
let resp_json: serde_json::Value = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
||||
|
||||
helpers::parse_openai_response(&resp_json, request.model)
|
||||
}
|
||||
|
||||
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
||||
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
||||
}
|
||||
|
||||
fn calculate_cost(
|
||||
&self,
|
||||
model: &str,
|
||||
prompt_tokens: u32,
|
||||
completion_tokens: u32,
|
||||
cache_read_tokens: u32,
|
||||
cache_write_tokens: u32,
|
||||
registry: &crate::models::registry::ModelRegistry,
|
||||
) -> f64 {
|
||||
helpers::calculate_cost_with_registry(
|
||||
model,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
registry,
|
||||
&self.pricing,
|
||||
5.0,
|
||||
15.0,
|
||||
)
|
||||
}
|
||||
|
||||
async fn chat_completion_stream(
|
||||
&self,
|
||||
request: UnifiedRequest,
|
||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
||||
let body = helpers::build_openai_body(&request, messages_json, true);
|
||||
|
||||
let es = reqwest_eventsource::EventSource::new(
|
||||
self.client
|
||||
.post(format!("{}/chat/completions", self.config.base_url))
|
||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
||||
.json(&body),
|
||||
)
|
||||
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
||||
|
||||
Ok(helpers::create_openai_stream(es, request.model, None))
|
||||
}
|
||||
}
|
||||
@@ -1,390 +0,0 @@
|
||||
use super::{ProviderResponse, ProviderStreamChunk, StreamUsage};
|
||||
use crate::errors::AppError;
|
||||
use crate::models::{ContentPart, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest};
|
||||
use futures::stream::{BoxStream, StreamExt};
|
||||
use serde_json::Value;
|
||||
|
||||
/// Convert messages to OpenAI-compatible JSON, resolving images asynchronously.
|
||||
///
|
||||
/// This avoids the deadlock caused by `futures::executor::block_on` inside a
|
||||
/// Tokio async context. All image base64 conversions are awaited properly.
|
||||
/// Handles tool-calling messages: assistant messages with tool_calls, and
|
||||
/// tool-role messages with tool_call_id/name.
|
||||
pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result<Vec<serde_json::Value>, AppError> {
|
||||
let mut result = Vec::new();
|
||||
for m in messages {
|
||||
// Tool-role messages: { role: "tool", content: "...", tool_call_id: "...", name: "..." }
|
||||
if m.role == "tool" {
|
||||
let text_content = m
|
||||
.content
|
||||
.first()
|
||||
.map(|p| match p {
|
||||
ContentPart::Text { text } => text.clone(),
|
||||
ContentPart::Image(_) => "[Image]".to_string(),
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut msg = serde_json::json!({
|
||||
"role": "tool",
|
||||
"content": text_content
|
||||
});
|
||||
if let Some(tool_call_id) = &m.tool_call_id {
|
||||
msg["tool_call_id"] = serde_json::json!(tool_call_id);
|
||||
}
|
||||
if let Some(name) = &m.name {
|
||||
msg["name"] = serde_json::json!(name);
|
||||
}
|
||||
result.push(msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Build content parts for non-tool messages
|
||||
let mut parts = Vec::new();
|
||||
for p in &m.content {
|
||||
match p {
|
||||
ContentPart::Text { text } => {
|
||||
parts.push(serde_json::json!({ "type": "text", "text": text }));
|
||||
}
|
||||
ContentPart::Image(image_input) => {
|
||||
let (base64_data, mime_type) = image_input
|
||||
.to_base64()
|
||||
.await
|
||||
.map_err(|e| AppError::MultimodalError(e.to_string()))?;
|
||||
parts.push(serde_json::json!({
|
||||
"type": "image_url",
|
||||
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut msg = serde_json::json!({ "role": m.role });
|
||||
|
||||
// For assistant messages with tool_calls, content can be null
|
||||
if let Some(tool_calls) = &m.tool_calls {
|
||||
if parts.is_empty() {
|
||||
msg["content"] = serde_json::Value::Null;
|
||||
} else {
|
||||
msg["content"] = serde_json::json!(parts);
|
||||
}
|
||||
msg["tool_calls"] = serde_json::json!(tool_calls);
|
||||
} else {
|
||||
msg["content"] = serde_json::json!(parts);
|
||||
}
|
||||
|
||||
if let Some(name) = &m.name {
|
||||
msg["name"] = serde_json::json!(name);
|
||||
}
|
||||
|
||||
result.push(msg);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Convert messages to OpenAI-compatible JSON, but replace images with a
|
||||
/// text placeholder "[Image]". Useful for providers that don't support
|
||||
/// multimodal in streaming mode or at all.
|
||||
///
|
||||
/// Handles tool-calling messages identically to `messages_to_openai_json`:
|
||||
/// assistant messages with `tool_calls`, and tool-role messages with
|
||||
/// `tool_call_id`/`name`.
|
||||
pub async fn messages_to_openai_json_text_only(
|
||||
messages: &[UnifiedMessage],
|
||||
) -> Result<Vec<serde_json::Value>, AppError> {
|
||||
let mut result = Vec::new();
|
||||
for m in messages {
|
||||
// Tool-role messages: { role: "tool", content: "...", tool_call_id: "...", name: "..." }
|
||||
if m.role == "tool" {
|
||||
let text_content = m
|
||||
.content
|
||||
.first()
|
||||
.map(|p| match p {
|
||||
ContentPart::Text { text } => text.clone(),
|
||||
ContentPart::Image(_) => "[Image]".to_string(),
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut msg = serde_json::json!({
|
||||
"role": "tool",
|
||||
"content": text_content
|
||||
});
|
||||
if let Some(tool_call_id) = &m.tool_call_id {
|
||||
msg["tool_call_id"] = serde_json::json!(tool_call_id);
|
||||
}
|
||||
if let Some(name) = &m.name {
|
||||
msg["name"] = serde_json::json!(name);
|
||||
}
|
||||
result.push(msg);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Build content parts for non-tool messages (images become "[Image]" text)
|
||||
let mut parts = Vec::new();
|
||||
for p in &m.content {
|
||||
match p {
|
||||
ContentPart::Text { text } => {
|
||||
parts.push(serde_json::json!({ "type": "text", "text": text }));
|
||||
}
|
||||
ContentPart::Image(_) => {
|
||||
parts.push(serde_json::json!({ "type": "text", "text": "[Image]" }));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut msg = serde_json::json!({ "role": m.role });
|
||||
|
||||
// For assistant messages with tool_calls, content can be null
|
||||
if let Some(tool_calls) = &m.tool_calls {
|
||||
if parts.is_empty() {
|
||||
msg["content"] = serde_json::Value::Null;
|
||||
} else {
|
||||
msg["content"] = serde_json::json!(parts);
|
||||
}
|
||||
msg["tool_calls"] = serde_json::json!(tool_calls);
|
||||
} else {
|
||||
msg["content"] = serde_json::json!(parts);
|
||||
}
|
||||
|
||||
if let Some(name) = &m.name {
|
||||
msg["name"] = serde_json::json!(name);
|
||||
}
|
||||
|
||||
result.push(msg);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Build an OpenAI-compatible request body from a UnifiedRequest and pre-converted messages.
|
||||
/// Includes tools and tool_choice when present.
|
||||
/// When streaming, adds `stream_options.include_usage: true` so providers report
|
||||
/// token counts in the final SSE chunk.
|
||||
pub fn build_openai_body(
|
||||
request: &UnifiedRequest,
|
||||
messages_json: Vec<serde_json::Value>,
|
||||
stream: bool,
|
||||
) -> serde_json::Value {
|
||||
let mut body = serde_json::json!({
|
||||
"model": request.model,
|
||||
"messages": messages_json,
|
||||
"stream": stream,
|
||||
});
|
||||
|
||||
if stream {
|
||||
body["stream_options"] = serde_json::json!({ "include_usage": true });
|
||||
}
|
||||
|
||||
if let Some(temp) = request.temperature {
|
||||
body["temperature"] = serde_json::json!(temp);
|
||||
}
|
||||
if let Some(max_tokens) = request.max_tokens {
|
||||
body["max_tokens"] = serde_json::json!(max_tokens);
|
||||
}
|
||||
if let Some(tools) = &request.tools {
|
||||
body["tools"] = serde_json::json!(tools);
|
||||
}
|
||||
if let Some(tool_choice) = &request.tool_choice {
|
||||
body["tool_choice"] = serde_json::json!(tool_choice);
|
||||
}
|
||||
|
||||
body
|
||||
}
|
||||
|
||||
/// Parse an OpenAI-compatible chat completion response JSON into a ProviderResponse.
|
||||
/// Extracts tool_calls from the message when present.
|
||||
/// Extracts cache token counts from:
|
||||
/// - OpenAI/Grok: `usage.prompt_tokens_details.cached_tokens`
|
||||
/// - DeepSeek: `usage.prompt_cache_hit_tokens` / `usage.prompt_cache_miss_tokens`
|
||||
pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<ProviderResponse, AppError> {
|
||||
let choice = resp_json["choices"]
|
||||
.get(0)
|
||||
.ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
|
||||
let message = &choice["message"];
|
||||
|
||||
let content = message["content"].as_str().unwrap_or_default().to_string();
|
||||
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
|
||||
|
||||
// Parse tool_calls from the response message
|
||||
let tool_calls: Option<Vec<ToolCall>> = message
|
||||
.get("tool_calls")
|
||||
.and_then(|tc| serde_json::from_value(tc.clone()).ok());
|
||||
|
||||
let usage = &resp_json["usage"];
|
||||
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
|
||||
// Extract cache tokens — try OpenAI/Grok format first, then DeepSeek format
|
||||
let cache_read_tokens = usage["prompt_tokens_details"]["cached_tokens"]
|
||||
.as_u64()
|
||||
// DeepSeek uses a different field name
|
||||
.or_else(|| usage["prompt_cache_hit_tokens"].as_u64())
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
// DeepSeek reports cache_write as prompt_cache_miss_tokens (tokens written to cache for future use).
|
||||
// OpenAI doesn't report cache_write in this location, but may in the future.
|
||||
let cache_write_tokens = usage["prompt_cache_miss_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
|
||||
Ok(ProviderResponse {
|
||||
content,
|
||||
reasoning_content,
|
||||
tool_calls,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
model,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create an SSE stream that parses OpenAI-compatible streaming chunks.
|
||||
///
|
||||
/// The optional `reasoning_field` allows overriding the field name for
|
||||
/// reasoning content (e.g., "thought" for Ollama).
|
||||
/// Parses tool_calls deltas from streaming chunks when present.
|
||||
/// When `stream_options.include_usage: true` was sent, the provider sends a
|
||||
/// final chunk with `usage` data — this is parsed into `StreamUsage` and
|
||||
/// attached to the yielded `ProviderStreamChunk`.
|
||||
pub fn create_openai_stream(
|
||||
es: reqwest_eventsource::EventSource,
|
||||
model: String,
|
||||
reasoning_field: Option<&'static str>,
|
||||
) -> BoxStream<'static, Result<ProviderStreamChunk, AppError>> {
|
||||
use reqwest_eventsource::Event;
|
||||
|
||||
let stream = async_stream::try_stream! {
|
||||
let mut es = es;
|
||||
while let Some(event) = es.next().await {
|
||||
match event {
|
||||
Ok(Event::Message(msg)) => {
|
||||
if msg.data == "[DONE]" {
|
||||
break;
|
||||
}
|
||||
|
||||
let chunk: Value = serde_json::from_str(&msg.data)
|
||||
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
||||
|
||||
// Parse usage from the final chunk (sent when stream_options.include_usage is true).
|
||||
// This chunk may have an empty `choices` array.
|
||||
let stream_usage = chunk.get("usage").and_then(|u| {
|
||||
if u.is_null() {
|
||||
return None;
|
||||
}
|
||||
let prompt_tokens = u["prompt_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
let completion_tokens = u["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
let total_tokens = u["total_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
|
||||
let cache_read_tokens = u["prompt_tokens_details"]["cached_tokens"]
|
||||
.as_u64()
|
||||
.or_else(|| u["prompt_cache_hit_tokens"].as_u64())
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
let cache_write_tokens = u["prompt_cache_miss_tokens"]
|
||||
.as_u64()
|
||||
.unwrap_or(0) as u32;
|
||||
|
||||
Some(StreamUsage {
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
})
|
||||
});
|
||||
|
||||
if let Some(choice) = chunk["choices"].get(0) {
|
||||
let delta = &choice["delta"];
|
||||
let content = delta["content"].as_str().unwrap_or_default().to_string();
|
||||
let reasoning_content = delta["reasoning_content"]
|
||||
.as_str()
|
||||
.or_else(|| reasoning_field.and_then(|f| delta[f].as_str()))
|
||||
.map(|s| s.to_string());
|
||||
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
|
||||
|
||||
// Parse tool_calls deltas from the stream chunk
|
||||
let tool_calls: Option<Vec<ToolCallDelta>> = delta
|
||||
.get("tool_calls")
|
||||
.and_then(|tc| serde_json::from_value(tc.clone()).ok());
|
||||
|
||||
yield ProviderStreamChunk {
|
||||
content,
|
||||
reasoning_content,
|
||||
finish_reason,
|
||||
tool_calls,
|
||||
model: model.clone(),
|
||||
usage: stream_usage,
|
||||
};
|
||||
} else if stream_usage.is_some() {
|
||||
// Final usage-only chunk (empty choices array) — yield it so
|
||||
// AggregatingStream can capture the real token counts.
|
||||
yield ProviderStreamChunk {
|
||||
content: String::new(),
|
||||
reasoning_content: None,
|
||||
finish_reason: None,
|
||||
tool_calls: None,
|
||||
model: model.clone(),
|
||||
usage: stream_usage,
|
||||
};
|
||||
}
|
||||
}
|
||||
Ok(_) => continue,
|
||||
Err(e) => {
|
||||
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Box::pin(stream)
|
||||
}
|
||||
|
||||
/// Calculate cost using the model registry first, then falling back to provider pricing config.
|
||||
///
|
||||
/// When the registry provides `cache_read` / `cache_write` rates, the formula is:
|
||||
/// (prompt_tokens - cache_read_tokens) * input_rate
|
||||
/// + cache_read_tokens * cache_read_rate
|
||||
/// + cache_write_tokens * cache_write_rate (if applicable)
|
||||
/// + completion_tokens * output_rate
|
||||
///
|
||||
/// All rates are per-token (the registry stores per-million-token rates).
|
||||
pub fn calculate_cost_with_registry(
|
||||
model: &str,
|
||||
prompt_tokens: u32,
|
||||
completion_tokens: u32,
|
||||
cache_read_tokens: u32,
|
||||
cache_write_tokens: u32,
|
||||
registry: &crate::models::registry::ModelRegistry,
|
||||
pricing: &[crate::config::ModelPricing],
|
||||
default_prompt_rate: f64,
|
||||
default_completion_rate: f64,
|
||||
) -> f64 {
|
||||
if let Some(metadata) = registry.find_model(model)
|
||||
&& let Some(cost) = &metadata.cost
|
||||
{
|
||||
let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
|
||||
let mut total = (non_cached_prompt as f64 * cost.input / 1_000_000.0)
|
||||
+ (completion_tokens as f64 * cost.output / 1_000_000.0);
|
||||
|
||||
if let Some(cache_read_rate) = cost.cache_read {
|
||||
total += cache_read_tokens as f64 * cache_read_rate / 1_000_000.0;
|
||||
} else {
|
||||
// No cache_read rate — charge cached tokens at full input rate
|
||||
total += cache_read_tokens as f64 * cost.input / 1_000_000.0;
|
||||
}
|
||||
|
||||
if let Some(cache_write_rate) = cost.cache_write {
|
||||
total += cache_write_tokens as f64 * cache_write_rate / 1_000_000.0;
|
||||
}
|
||||
|
||||
return total;
|
||||
}
|
||||
|
||||
// Fallback: no registry entry — use provider pricing config (no cache awareness)
|
||||
let (prompt_rate, completion_rate) = pricing
|
||||
.iter()
|
||||
.find(|p| model.contains(&p.model))
|
||||
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
|
||||
.unwrap_or((default_prompt_rate, default_completion_rate));
|
||||
|
||||
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
|
||||
}
|
||||
@@ -1,334 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
use sqlx::Row;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::errors::AppError;
|
||||
use crate::models::UnifiedRequest;
|
||||
|
||||
pub mod deepseek;
|
||||
pub mod gemini;
|
||||
pub mod grok;
|
||||
pub mod helpers;
|
||||
pub mod ollama;
|
||||
pub mod openai;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Provider: Send + Sync {
|
||||
/// Get provider name (e.g., "openai", "gemini")
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Check if provider supports a specific model
|
||||
fn supports_model(&self, model: &str) -> bool;
|
||||
|
||||
/// Check if provider supports multimodal (images, etc.)
|
||||
fn supports_multimodal(&self) -> bool;
|
||||
|
||||
/// Process a chat completion request
|
||||
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError>;
|
||||
|
||||
/// Process a chat request using provider-specific "responses" style endpoint
|
||||
/// Default implementation falls back to `chat_completion` for providers
|
||||
/// that do not implement a dedicated responses endpoint.
|
||||
async fn chat_responses(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
||||
self.chat_completion(request).await
|
||||
}
|
||||
|
||||
/// Process a streaming chat completion request
|
||||
async fn chat_completion_stream(
|
||||
&self,
|
||||
request: UnifiedRequest,
|
||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError>;
|
||||
|
||||
/// Estimate token count for a request (for cost calculation)
|
||||
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32>;
|
||||
|
||||
/// Calculate cost based on token usage and model using the registry.
|
||||
/// `cache_read_tokens` / `cache_write_tokens` allow cache-aware pricing
|
||||
/// when the registry provides `cache_read` / `cache_write` rates.
|
||||
fn calculate_cost(
|
||||
&self,
|
||||
model: &str,
|
||||
prompt_tokens: u32,
|
||||
completion_tokens: u32,
|
||||
cache_read_tokens: u32,
|
||||
cache_write_tokens: u32,
|
||||
registry: &crate::models::registry::ModelRegistry,
|
||||
) -> f64;
|
||||
}
|
||||
|
||||
pub struct ProviderResponse {
|
||||
pub content: String,
|
||||
pub reasoning_content: Option<String>,
|
||||
pub tool_calls: Option<Vec<crate::models::ToolCall>>,
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
pub cache_read_tokens: u32,
|
||||
pub cache_write_tokens: u32,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
/// Usage data from the final streaming chunk (when providers report real token counts).
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct StreamUsage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
pub cache_read_tokens: u32,
|
||||
pub cache_write_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProviderStreamChunk {
|
||||
pub content: String,
|
||||
pub reasoning_content: Option<String>,
|
||||
pub finish_reason: Option<String>,
|
||||
pub tool_calls: Option<Vec<crate::models::ToolCallDelta>>,
|
||||
pub model: String,
|
||||
/// Populated only on the final chunk when providers report usage (e.g. stream_options.include_usage).
|
||||
pub usage: Option<StreamUsage>,
|
||||
}
|
||||
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::config::AppConfig;
|
||||
use crate::providers::{
|
||||
deepseek::DeepSeekProvider, gemini::GeminiProvider, grok::GrokProvider, ollama::OllamaProvider,
|
||||
openai::OpenAIProvider,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ProviderManager {
|
||||
providers: Arc<RwLock<Vec<Arc<dyn Provider>>>>,
|
||||
}
|
||||
|
||||
impl Default for ProviderManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ProviderManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
providers: Arc::new(RwLock::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize a provider by name using config and database overrides
|
||||
pub async fn initialize_provider(
|
||||
&self,
|
||||
name: &str,
|
||||
app_config: &AppConfig,
|
||||
db_pool: &crate::database::DbPool,
|
||||
) -> Result<()> {
|
||||
// Load override from database
|
||||
let db_config = sqlx::query("SELECT enabled, base_url, api_key FROM provider_configs WHERE id = ?")
|
||||
.bind(name)
|
||||
.fetch_optional(db_pool)
|
||||
.await?;
|
||||
|
||||
let (enabled, base_url, api_key) = if let Some(row) = db_config {
|
||||
(
|
||||
row.get::<bool, _>("enabled"),
|
||||
row.get::<Option<String>, _>("base_url"),
|
||||
row.get::<Option<String>, _>("api_key"),
|
||||
)
|
||||
} else {
|
||||
// No database override, use defaults from AppConfig
|
||||
match name {
|
||||
"openai" => (
|
||||
app_config.providers.openai.enabled,
|
||||
Some(app_config.providers.openai.base_url.clone()),
|
||||
None,
|
||||
),
|
||||
"gemini" => (
|
||||
app_config.providers.gemini.enabled,
|
||||
Some(app_config.providers.gemini.base_url.clone()),
|
||||
None,
|
||||
),
|
||||
"deepseek" => (
|
||||
app_config.providers.deepseek.enabled,
|
||||
Some(app_config.providers.deepseek.base_url.clone()),
|
||||
None,
|
||||
),
|
||||
"grok" => (
|
||||
app_config.providers.grok.enabled,
|
||||
Some(app_config.providers.grok.base_url.clone()),
|
||||
None,
|
||||
),
|
||||
"ollama" => (
|
||||
app_config.providers.ollama.enabled,
|
||||
Some(app_config.providers.ollama.base_url.clone()),
|
||||
None,
|
||||
),
|
||||
_ => (false, None, None),
|
||||
}
|
||||
};
|
||||
|
||||
if !enabled {
|
||||
self.remove_provider(name).await;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Create provider instance with merged config
|
||||
let provider: Arc<dyn Provider> = match name {
|
||||
"openai" => {
|
||||
let mut cfg = app_config.providers.openai.clone();
|
||||
if let Some(url) = base_url {
|
||||
cfg.base_url = url;
|
||||
}
|
||||
// Handle API key override if present
|
||||
let p = if let Some(key) = api_key {
|
||||
// We need a way to create a provider with an explicit key
|
||||
// Let's modify the providers to allow this
|
||||
OpenAIProvider::new_with_key(&cfg, app_config, key)?
|
||||
} else {
|
||||
OpenAIProvider::new(&cfg, app_config)?
|
||||
};
|
||||
Arc::new(p)
|
||||
}
|
||||
"ollama" => {
|
||||
let mut cfg = app_config.providers.ollama.clone();
|
||||
if let Some(url) = base_url {
|
||||
cfg.base_url = url;
|
||||
}
|
||||
Arc::new(OllamaProvider::new(&cfg, app_config)?)
|
||||
}
|
||||
"gemini" => {
|
||||
let mut cfg = app_config.providers.gemini.clone();
|
||||
if let Some(url) = base_url {
|
||||
cfg.base_url = url;
|
||||
}
|
||||
let p = if let Some(key) = api_key {
|
||||
GeminiProvider::new_with_key(&cfg, app_config, key)?
|
||||
} else {
|
||||
GeminiProvider::new(&cfg, app_config)?
|
||||
};
|
||||
Arc::new(p)
|
||||
}
|
||||
"deepseek" => {
|
||||
let mut cfg = app_config.providers.deepseek.clone();
|
||||
if let Some(url) = base_url {
|
||||
cfg.base_url = url;
|
||||
}
|
||||
let p = if let Some(key) = api_key {
|
||||
DeepSeekProvider::new_with_key(&cfg, app_config, key)?
|
||||
} else {
|
||||
DeepSeekProvider::new(&cfg, app_config)?
|
||||
};
|
||||
Arc::new(p)
|
||||
}
|
||||
"grok" => {
|
||||
let mut cfg = app_config.providers.grok.clone();
|
||||
if let Some(url) = base_url {
|
||||
cfg.base_url = url;
|
||||
}
|
||||
let p = if let Some(key) = api_key {
|
||||
GrokProvider::new_with_key(&cfg, app_config, key)?
|
||||
} else {
|
||||
GrokProvider::new(&cfg, app_config)?
|
||||
};
|
||||
Arc::new(p)
|
||||
}
|
||||
_ => return Err(anyhow::anyhow!("Unknown provider: {}", name)),
|
||||
};
|
||||
|
||||
self.add_provider(provider).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn add_provider(&self, provider: Arc<dyn Provider>) {
|
||||
let mut providers = self.providers.write().await;
|
||||
// If provider with same name exists, replace it
|
||||
if let Some(index) = providers.iter().position(|p| p.name() == provider.name()) {
|
||||
providers[index] = provider;
|
||||
} else {
|
||||
providers.push(provider);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn remove_provider(&self, name: &str) {
|
||||
let mut providers = self.providers.write().await;
|
||||
providers.retain(|p| p.name() != name);
|
||||
}
|
||||
|
||||
pub async fn get_provider_for_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
|
||||
let providers = self.providers.read().await;
|
||||
providers.iter().find(|p| p.supports_model(model)).map(Arc::clone)
|
||||
}
|
||||
|
||||
pub async fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
|
||||
let providers = self.providers.read().await;
|
||||
providers.iter().find(|p| p.name() == name).map(Arc::clone)
|
||||
}
|
||||
|
||||
pub async fn get_all_providers(&self) -> Vec<Arc<dyn Provider>> {
|
||||
let providers = self.providers.read().await;
|
||||
providers.clone()
|
||||
}
|
||||
}
|
||||
|
||||
// Create placeholder provider implementations
|
||||
pub mod placeholder {
|
||||
use super::*;
|
||||
|
||||
pub struct PlaceholderProvider {
|
||||
name: String,
|
||||
}
|
||||
|
||||
impl PlaceholderProvider {
|
||||
pub fn new(name: &str) -> Self {
|
||||
Self { name: name.to_string() }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for PlaceholderProvider {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn supports_model(&self, _model: &str) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn supports_multimodal(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
async fn chat_completion_stream(
|
||||
&self,
|
||||
_request: UnifiedRequest,
|
||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
||||
Err(AppError::ProviderError(
|
||||
"Streaming not supported for placeholder provider".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn chat_completion(&self, _request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
||||
Err(AppError::ProviderError(format!(
|
||||
"Provider {} not implemented",
|
||||
self.name
|
||||
)))
|
||||
}
|
||||
|
||||
fn estimate_tokens(&self, _request: &UnifiedRequest) -> Result<u32> {
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
fn calculate_cost(
|
||||
&self,
|
||||
_model: &str,
|
||||
_prompt_tokens: u32,
|
||||
_completion_tokens: u32,
|
||||
_cache_read_tokens: u32,
|
||||
_cache_write_tokens: u32,
|
||||
_registry: &crate::models::registry::ModelRegistry,
|
||||
) -> f64 {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,140 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
|
||||
use super::helpers;
|
||||
use super::{ProviderResponse, ProviderStreamChunk};
|
||||
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
|
||||
|
||||
/// Provider backed by an Ollama server exposed through its
/// OpenAI-compatible `/chat/completions` endpoint (see `base_url` usage below).
pub struct OllamaProvider {
    // Reused HTTP client; pooling/timeouts are configured in `new`.
    client: reqwest::Client,
    // Ollama section of the app configuration (base URL, served model list).
    config: crate::config::OllamaConfig,
    // Pricing rows copied from `AppConfig` at construction time.
    pricing: Vec<crate::config::ModelPricing>,
}
|
||||
|
||||
impl OllamaProvider {
|
||||
pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(std::time::Duration::from_secs(5))
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
||||
.pool_max_idle_per_host(4)
|
||||
.tcp_keepalive(std::time::Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
config: config.clone(),
|
||||
pricing: app_config.pricing.ollama.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
impl super::Provider for OllamaProvider {
    fn name(&self) -> &str {
        "ollama"
    }

    /// A model is supported when it appears in the configured model list or
    /// carries the explicit "ollama/" routing prefix.
    fn supports_model(&self, model: &str) -> bool {
        self.config.models.iter().any(|m| m == model) || model.starts_with("ollama/")
    }

    fn supports_multimodal(&self) -> bool {
        true
    }

    /// Non-streaming completion against `{base_url}/chat/completions`.
    ///
    /// The "ollama/" routing prefix is stripped before the API call, but the
    /// caller-visible model name in the response keeps the original string.
    async fn chat_completion(&self, mut request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
        // Strip "ollama/" prefix if present for the API call
        let api_model = request
            .model
            .strip_prefix("ollama/")
            .unwrap_or(&request.model)
            .to_string();
        let original_model = request.model.clone();
        request.model = api_model;

        let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
        let body = helpers::build_openai_body(&request, messages_json, false);

        // No Authorization header: local Ollama servers are unauthenticated.
        let response = self
            .client
            .post(format!("{}/chat/completions", self.config.base_url))
            .json(&body)
            .send()
            .await
            .map_err(|e| AppError::ProviderError(e.to_string()))?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(AppError::ProviderError(format!("Ollama API error: {}", error_text)));
        }

        let resp_json: serde_json::Value = response
            .json()
            .await
            .map_err(|e| AppError::ProviderError(e.to_string()))?;

        // Ollama also supports "thought" as an alias for reasoning_content:
        // if the generic parser found none, probe choices[0].message.thought.
        let mut result = helpers::parse_openai_response(&resp_json, original_model)?;
        if result.reasoning_content.is_none() {
            result.reasoning_content = resp_json["choices"]
                .get(0)
                .and_then(|c| c["message"]["thought"].as_str())
                .map(|s| s.to_string());
        }
        Ok(result)
    }

    /// Heuristic token estimate; delegates to the shared token utility.
    fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
        Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
    }

    /// Cost via the shared registry/pricing helper; the trailing 0.0/0.0 are
    /// the fallback input/output rates (local models default to free).
    fn calculate_cost(
        &self,
        model: &str,
        prompt_tokens: u32,
        completion_tokens: u32,
        cache_read_tokens: u32,
        cache_write_tokens: u32,
        registry: &crate::models::registry::ModelRegistry,
    ) -> f64 {
        helpers::calculate_cost_with_registry(
            model,
            prompt_tokens,
            completion_tokens,
            cache_read_tokens,
            cache_write_tokens,
            registry,
            &self.pricing,
            0.0,
            0.0,
        )
    }

    /// Streaming completion via SSE against the same endpoint.
    ///
    /// NOTE(review): this path serializes messages with
    /// `messages_to_openai_json_text_only`, while the non-streaming path uses
    /// the full (presumably multimodal) serializer — if intentional
    /// (e.g. images unsupported when streaming), worth a comment at the helper;
    /// otherwise streaming may silently drop non-text parts. TODO confirm.
    async fn chat_completion_stream(
        &self,
        mut request: UnifiedRequest,
    ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
        // Same prefix-stripping dance as `chat_completion`.
        let api_model = request
            .model
            .strip_prefix("ollama/")
            .unwrap_or(&request.model)
            .to_string();
        let original_model = request.model.clone();
        request.model = api_model;

        let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?;
        let body = helpers::build_openai_body(&request, messages_json, true);

        let es = reqwest_eventsource::EventSource::new(
            self.client
                .post(format!("{}/chat/completions", self.config.base_url))
                .json(&body),
        )
        .map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;

        // Ollama uses "thought" as an alternative field for reasoning content
        Ok(helpers::create_openai_stream(es, original_model, Some("thought")))
    }
}
|
||||
@@ -1,342 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
|
||||
use super::helpers;
|
||||
use super::{ProviderResponse, ProviderStreamChunk};
|
||||
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
|
||||
|
||||
/// Provider for the OpenAI API (chat completions, with a Responses-API
/// fallback handled in the trait impl below).
pub struct OpenAIProvider {
    // Reused HTTP client; pooling/timeouts configured in `new_with_key`.
    client: reqwest::Client,
    // OpenAI section of the app configuration (base URL etc.).
    config: crate::config::OpenAIConfig,
    // Bearer token sent in the Authorization header of every request.
    api_key: String,
    // Pricing rows copied from `AppConfig` at construction time.
    pricing: Vec<crate::config::ModelPricing>,
}
|
||||
|
||||
impl OpenAIProvider {
|
||||
pub fn new(config: &crate::config::OpenAIConfig, app_config: &AppConfig) -> Result<Self> {
|
||||
let api_key = app_config.get_api_key("openai")?;
|
||||
Self::new_with_key(config, app_config, api_key)
|
||||
}
|
||||
|
||||
pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
||||
let client = reqwest::Client::builder()
|
||||
.connect_timeout(std::time::Duration::from_secs(5))
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
||||
.pool_max_idle_per_host(4)
|
||||
.tcp_keepalive(std::time::Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
config: config.clone(),
|
||||
api_key,
|
||||
pricing: app_config.pricing.openai.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
impl super::Provider for OpenAIProvider {
    fn name(&self) -> &str {
        "openai"
    }

    /// Supported models are recognized purely by name prefix.
    fn supports_model(&self, model: &str) -> bool {
        model.starts_with("gpt-") || model.starts_with("o1-") || model.starts_with("o3-") || model.starts_with("o4-")
    }

    fn supports_multimodal(&self) -> bool {
        true
    }

    /// Non-streaming completion against `{base_url}/chat/completions`.
    ///
    /// On a failed response whose error text mentions `v1/responses`, the
    /// request is retried inline against the Responses API and the result is
    /// normalized into a `ProviderResponse`.
    ///
    /// NOTE(review): the inline fallback below is a near-duplicate of
    /// `chat_responses`; the one behavioral difference is that this copy also
    /// handles a chat-style `choices` payload. Consolidating the two would
    /// need care to preserve that extra branch.
    async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
        let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
        let body = helpers::build_openai_body(&request, messages_json, false);

        let response = self
            .client
            .post(format!("{}/chat/completions", self.config.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&body)
            .send()
            .await
            .map_err(|e| AppError::ProviderError(e.to_string()))?;

        if !response.status().is_success() {
            // Read error body to diagnose. If the model requires the Responses
            // API (v1/responses), retry against that endpoint.
            let error_text = response.text().await.unwrap_or_default();
            if error_text.to_lowercase().contains("v1/responses") || error_text.to_lowercase().contains("only supported in v1/responses") {
                // Build a simple `input` string by concatenating message parts.
                let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
                let mut inputs: Vec<String> = Vec::new();
                for m in &messages_json {
                    let role = m["role"].as_str().unwrap_or("");
                    let parts = m.get("content").and_then(|c| c.as_array()).cloned().unwrap_or_default();
                    let mut text_parts = Vec::new();
                    for p in parts {
                        // Only text parts survive the flattening; images etc. are dropped.
                        if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
                            text_parts.push(t.to_string());
                        }
                    }
                    inputs.push(format!("{}: {}", role, text_parts.join("")));
                }
                let input_text = inputs.join("\n");

                let resp = self
                    .client
                    .post(format!("{}/responses", self.config.base_url))
                    .header("Authorization", format!("Bearer {}", self.api_key))
                    .json(&serde_json::json!({ "model": request.model, "input": input_text }))
                    .send()
                    .await
                    .map_err(|e| AppError::ProviderError(e.to_string()))?;

                if !resp.status().is_success() {
                    let err = resp.text().await.unwrap_or_default();
                    return Err(AppError::ProviderError(format!("OpenAI Responses API error: {}", err)));
                }

                let resp_json: serde_json::Value = resp.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
                // Try to normalize: if it's chat-style, use existing parser
                if resp_json.get("choices").is_some() {
                    return helpers::parse_openai_response(&resp_json, request.model);
                }

                // Responses API: try to extract text from `output` or `candidates`
                // output -> [{"content": [{"type":..., "text": "..."}, ...]}]
                let mut content_text = String::new();
                if let Some(output) = resp_json.get("output").and_then(|o| o.as_array()) {
                    if let Some(first) = output.get(0) {
                        if let Some(contents) = first.get("content").and_then(|c| c.as_array()) {
                            for item in contents {
                                if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
                                    // Join multiple text items with newlines.
                                    if !content_text.is_empty() {
                                        content_text.push_str("\n");
                                    }
                                    content_text.push_str(text);
                                } else if let Some(parts) = item.get("parts").and_then(|p| p.as_array()) {
                                    for p in parts {
                                        if let Some(t) = p.as_str() {
                                            if !content_text.is_empty() { content_text.push_str("\n"); }
                                            content_text.push_str(t);
                                        }
                                    }
                                }
                            }
                        }
                    }
                }

                // Fallback: check `candidates` -> candidate.content.parts.text
                if content_text.is_empty() {
                    if let Some(cands) = resp_json.get("candidates").and_then(|c| c.as_array()) {
                        if let Some(c0) = cands.get(0) {
                            if let Some(content) = c0.get("content") {
                                if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) {
                                    for p in parts {
                                        if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
                                            if !content_text.is_empty() { content_text.push_str("\n"); }
                                            content_text.push_str(t);
                                        }
                                    }
                                }
                            }
                        }
                    }
                }

                // Extract simple usage if present
                let prompt_tokens = resp_json.get("usage").and_then(|u| u.get("prompt_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
                let completion_tokens = resp_json.get("usage").and_then(|u| u.get("completion_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
                let total_tokens = resp_json.get("usage").and_then(|u| u.get("total_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;

                return Ok(ProviderResponse {
                    content: content_text,
                    reasoning_content: None,
                    tool_calls: None,
                    prompt_tokens,
                    completion_tokens,
                    total_tokens,
                    cache_read_tokens: 0,
                    cache_write_tokens: 0,
                    model: request.model,
                });
            }

            return Err(AppError::ProviderError(format!("OpenAI API error: {}", error_text)));
        }

        let resp_json: serde_json::Value = response
            .json()
            .await
            .map_err(|e| AppError::ProviderError(e.to_string()))?;

        helpers::parse_openai_response(&resp_json, request.model)
    }

    /// Direct Responses-API call: flattens messages to a "role: text" input
    /// string, posts to `{base_url}/responses`, and normalizes the payload
    /// (`output` first, then `candidates`) into a `ProviderResponse`.
    async fn chat_responses(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
        // Build a simple `input` string by concatenating message parts.
        let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
        let mut inputs: Vec<String> = Vec::new();
        for m in &messages_json {
            let role = m["role"].as_str().unwrap_or("");
            let parts = m.get("content").and_then(|c| c.as_array()).cloned().unwrap_or_default();
            let mut text_parts = Vec::new();
            for p in parts {
                if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
                    text_parts.push(t.to_string());
                }
            }
            inputs.push(format!("{}: {}", role, text_parts.join("")));
        }
        let input_text = inputs.join("\n");

        let resp = self
            .client
            .post(format!("{}/responses", self.config.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&serde_json::json!({ "model": request.model, "input": input_text }))
            .send()
            .await
            .map_err(|e| AppError::ProviderError(e.to_string()))?;

        if !resp.status().is_success() {
            let err = resp.text().await.unwrap_or_default();
            return Err(AppError::ProviderError(format!("OpenAI Responses API error: {}", err)));
        }

        let resp_json: serde_json::Value = resp.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;

        // Normalize Responses API output into ProviderResponse
        let mut content_text = String::new();
        if let Some(output) = resp_json.get("output").and_then(|o| o.as_array()) {
            if let Some(first) = output.get(0) {
                if let Some(contents) = first.get("content").and_then(|c| c.as_array()) {
                    for item in contents {
                        if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
                            if !content_text.is_empty() { content_text.push_str("\n"); }
                            content_text.push_str(text);
                        } else if let Some(parts) = item.get("parts").and_then(|p| p.as_array()) {
                            for p in parts {
                                if let Some(t) = p.as_str() {
                                    if !content_text.is_empty() { content_text.push_str("\n"); }
                                    content_text.push_str(t);
                                }
                            }
                        }
                    }
                }
            }
        }

        // Secondary shape: candidates[0].content.parts[].text
        if content_text.is_empty() {
            if let Some(cands) = resp_json.get("candidates").and_then(|c| c.as_array()) {
                if let Some(c0) = cands.get(0) {
                    if let Some(content) = c0.get("content") {
                        if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) {
                            for p in parts {
                                if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
                                    if !content_text.is_empty() { content_text.push_str("\n"); }
                                    content_text.push_str(t);
                                }
                            }
                        }
                    }
                }
            }
        }

        // Usage fields default to 0 when the payload omits them.
        let prompt_tokens = resp_json.get("usage").and_then(|u| u.get("prompt_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
        let completion_tokens = resp_json.get("usage").and_then(|u| u.get("completion_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
        let total_tokens = resp_json.get("usage").and_then(|u| u.get("total_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;

        Ok(ProviderResponse {
            content: content_text,
            reasoning_content: None,
            tool_calls: None,
            prompt_tokens,
            completion_tokens,
            total_tokens,
            cache_read_tokens: 0,
            cache_write_tokens: 0,
            model: request.model,
        })
    }

    /// Heuristic token estimate; delegates to the shared token utility.
    fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
        Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
    }

    /// Cost via the shared registry/pricing helper; 0.15/0.60 are the
    /// fallback per-token input/output rates used when no pricing row matches.
    fn calculate_cost(
        &self,
        model: &str,
        prompt_tokens: u32,
        completion_tokens: u32,
        cache_read_tokens: u32,
        cache_write_tokens: u32,
        registry: &crate::models::registry::ModelRegistry,
    ) -> f64 {
        helpers::calculate_cost_with_registry(
            model,
            prompt_tokens,
            completion_tokens,
            cache_read_tokens,
            cache_write_tokens,
            registry,
            &self.pricing,
            0.15,
            0.60,
        )
    }

    /// Streaming completion via SSE, with a single-chunk fallback built from
    /// the non-streaming path when the EventSource cannot be created.
    async fn chat_completion_stream(
        &self,
        request: UnifiedRequest,
    ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
        let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
        let body = helpers::build_openai_body(&request, messages_json, true);

        // Try to create an EventSource for streaming; if creation fails or
        // the stream errors, fall back to a single synchronous request and
        // emit its result as a single chunk.
        let es_result = reqwest_eventsource::EventSource::new(
            self.client
                .post(format!("{}/chat/completions", self.config.base_url))
                .header("Authorization", format!("Bearer {}", self.api_key))
                .json(&body),
        );

        if es_result.is_err() {
            // Fallback to non-streaming request which itself may retry to
            // Responses API if necessary (handled in chat_completion).
            let resp = self.chat_completion(request.clone()).await?;
            let single_stream = async_stream::try_stream! {
                // Entire completion delivered as one chunk with final usage.
                let chunk = ProviderStreamChunk {
                    content: resp.content,
                    reasoning_content: resp.reasoning_content,
                    finish_reason: Some("stop".to_string()),
                    tool_calls: None,
                    model: resp.model.clone(),
                    usage: Some(super::StreamUsage {
                        prompt_tokens: resp.prompt_tokens,
                        completion_tokens: resp.completion_tokens,
                        total_tokens: resp.total_tokens,
                        cache_read_tokens: resp.cache_read_tokens,
                        cache_write_tokens: resp.cache_write_tokens,
                    }),
                };

                yield chunk;
            };

            return Ok(Box::pin(single_stream));
        }

        let es = es_result.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;

        Ok(helpers::create_openai_stream(es, request.model, None))
    }
}
|
||||
@@ -1,369 +0,0 @@
|
||||
//! Rate limiting and circuit breaking for LLM proxy
|
||||
//!
|
||||
//! This module provides:
|
||||
//! 1. Per-client rate limiting using governor crate
|
||||
//! 2. Provider circuit breaking to handle API failures
|
||||
//! 3. Global rate limiting for overall system protection
|
||||
|
||||
use anyhow::Result;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Rate limiter configuration.
///
/// Per-minute figures are converted to per-second token-bucket refill rates
/// by `RateLimitManager::new`.
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
    /// Requests per minute per client
    pub requests_per_minute: u32,
    /// Burst size (maximum burst capacity); used as the bucket capacity for
    /// both the per-client buckets and the global bucket.
    pub burst_size: u32,
    /// Global requests per minute (across all clients)
    pub global_requests_per_minute: u32,
}
|
||||
|
||||
impl Default for RateLimiterConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
requests_per_minute: 60, // 1 request per second per client
|
||||
burst_size: 10, // Allow bursts of up to 10 requests
|
||||
global_requests_per_minute: 600, // 10 requests per second globally
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Circuit breaker state (classic three-state breaker).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CircuitState {
    Closed,   // Normal operation
    Open,     // Circuit is open, requests fail fast
    HalfOpen, // Testing if service has recovered
}
|
||||
|
||||
/// Circuit breaker configuration.
///
/// Governs the Closed -> Open -> HalfOpen -> Closed transitions implemented
/// by `ProviderCircuitBreaker`.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Failure threshold to open circuit
    pub failure_threshold: u32,
    /// Time window for failure counting (seconds); failures older than this
    /// reset the count before it is incremented.
    pub failure_window_secs: u64,
    /// Time to wait before trying half-open state (seconds)
    pub reset_timeout_secs: u64,
    /// Success threshold to close circuit (counted while half-open)
    pub success_threshold: u32,
}
|
||||
|
||||
impl Default for CircuitBreakerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
failure_threshold: 5, // 5 failures
|
||||
failure_window_secs: 60, // within 60 seconds
|
||||
reset_timeout_secs: 30, // wait 30 seconds before half-open
|
||||
success_threshold: 3, // 3 successes to close circuit
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple token bucket rate limiter for a single client.
///
/// Tokens accrue continuously at `refill_rate` (computed lazily on each
/// acquire) and are capped at `capacity`.
#[derive(Debug)]
struct TokenBucket {
    // Currently available tokens (fractional, since refill is continuous).
    tokens: f64,
    // Maximum tokens the bucket can hold (the burst size).
    capacity: f64,
    refill_rate: f64, // tokens per second
    // Timestamp of the last lazy refill.
    last_refill: Instant,
}
|
||||
|
||||
impl TokenBucket {
|
||||
fn new(capacity: f64, refill_rate: f64) -> Self {
|
||||
Self {
|
||||
tokens: capacity,
|
||||
capacity,
|
||||
refill_rate,
|
||||
last_refill: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
fn refill(&mut self) {
|
||||
let now = Instant::now();
|
||||
let elapsed = now.duration_since(self.last_refill).as_secs_f64();
|
||||
let new_tokens = elapsed * self.refill_rate;
|
||||
|
||||
self.tokens = (self.tokens + new_tokens).min(self.capacity);
|
||||
self.last_refill = now;
|
||||
}
|
||||
|
||||
fn try_acquire(&mut self, tokens: f64) -> bool {
|
||||
self.refill();
|
||||
|
||||
if self.tokens >= tokens {
|
||||
self.tokens -= tokens;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Circuit breaker for a provider.
///
/// Mutated under the `RateLimitManager` write lock; not internally
/// synchronized.
#[derive(Debug)]
pub struct ProviderCircuitBreaker {
    // Current breaker state (see `CircuitState`).
    state: CircuitState,
    // Failures within the current failure window.
    failure_count: u32,
    // Consecutive successes observed while half-open.
    success_count: u32,
    // When the most recent failure was recorded, if any.
    last_failure_time: Option<std::time::Instant>,
    // When `state` last changed; drives the open -> half-open timeout.
    last_state_change: std::time::Instant,
    config: CircuitBreakerConfig,
}
|
||||
|
||||
impl ProviderCircuitBreaker {
    /// Start in the Closed (healthy) state with zeroed counters.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            state: CircuitState::Closed,
            failure_count: 0,
            success_count: 0,
            last_failure_time: None,
            last_state_change: std::time::Instant::now(),
            config,
        }
    }

    /// Check if request is allowed.
    ///
    /// Side effect: an Open breaker whose reset timeout has elapsed
    /// transitions to HalfOpen here.
    ///
    /// NOTE(review): HalfOpen admits every request (not a single probe), so
    /// many concurrent callers may hit a still-failing provider — confirm
    /// this is acceptable for the expected request rate.
    pub fn allow_request(&mut self) -> bool {
        match self.state {
            CircuitState::Closed => true,
            CircuitState::Open => {
                // Check if reset timeout has passed
                let elapsed = self.last_state_change.elapsed();
                if elapsed.as_secs() >= self.config.reset_timeout_secs {
                    self.state = CircuitState::HalfOpen;
                    self.last_state_change = std::time::Instant::now();
                    info!("Circuit breaker transitioning to half-open state");
                    true
                } else {
                    false
                }
            }
            CircuitState::HalfOpen => true,
        }
    }

    /// Record a successful request.
    pub fn record_success(&mut self) {
        match self.state {
            CircuitState::Closed => {
                // Reset failure count on success
                self.failure_count = 0;
                self.last_failure_time = None;
            }
            CircuitState::HalfOpen => {
                // Count consecutive probe successes; close once the
                // configured threshold is reached.
                self.success_count += 1;
                if self.success_count >= self.config.success_threshold {
                    self.state = CircuitState::Closed;
                    self.success_count = 0;
                    self.failure_count = 0;
                    self.last_state_change = std::time::Instant::now();
                    info!("Circuit breaker closed after successful requests");
                }
            }
            CircuitState::Open => {
                // Should not happen, but handle gracefully
            }
        }
    }

    /// Record a failed request.
    ///
    /// Failures older than `failure_window_secs` reset the count before the
    /// new failure is added; crossing the threshold while Closed opens the
    /// circuit, and any failure while HalfOpen re-opens it.
    pub fn record_failure(&mut self) {
        let now = std::time::Instant::now();

        // Check if failure window has expired
        if let Some(last_failure) = self.last_failure_time
            && now.duration_since(last_failure).as_secs() > self.config.failure_window_secs
        {
            // Reset failure count if window expired
            self.failure_count = 0;
        }

        self.failure_count += 1;
        self.last_failure_time = Some(now);

        if self.failure_count >= self.config.failure_threshold && self.state == CircuitState::Closed {
            self.state = CircuitState::Open;
            self.last_state_change = now;
            warn!("Circuit breaker opened due to {} failures", self.failure_count);
        } else if self.state == CircuitState::HalfOpen {
            // Failure in half-open state, go back to open
            self.state = CircuitState::Open;
            self.success_count = 0;
            self.last_state_change = now;
            warn!("Circuit breaker re-opened after failure in half-open state");
        }
    }

    /// Get current state.
    pub fn state(&self) -> CircuitState {
        self.state
    }
}
|
||||
|
||||
/// Rate limiting and circuit breaking manager.
///
/// NOTE(review): `client_buckets` grows without bound — one entry per
/// distinct client_id, never evicted. Worth a TTL/size cap if client ids are
/// attacker-controlled.
#[derive(Debug)]
pub struct RateLimitManager {
    // Per-client token buckets, created lazily on first request.
    client_buckets: Arc<RwLock<HashMap<String, TokenBucket>>>,
    // Single bucket shared by all clients for the global ceiling.
    global_bucket: Arc<RwLock<TokenBucket>>,
    // One circuit breaker per provider name, created lazily.
    circuit_breakers: Arc<RwLock<HashMap<String, ProviderCircuitBreaker>>>,
    config: RateLimiterConfig,
    circuit_config: CircuitBreakerConfig,
}
|
||||
|
||||
impl RateLimitManager {
|
||||
pub fn new(config: RateLimiterConfig, circuit_config: CircuitBreakerConfig) -> Self {
|
||||
// Convert requests per minute to tokens per second
|
||||
let global_refill_rate = config.global_requests_per_minute as f64 / 60.0;
|
||||
|
||||
Self {
|
||||
client_buckets: Arc::new(RwLock::new(HashMap::new())),
|
||||
global_bucket: Arc::new(RwLock::new(TokenBucket::new(
|
||||
config.burst_size as f64,
|
||||
global_refill_rate,
|
||||
))),
|
||||
circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
|
||||
config,
|
||||
circuit_config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a client request is allowed
|
||||
pub async fn check_client_request(&self, client_id: &str) -> Result<bool> {
|
||||
// Check global rate limit first (1 token per request)
|
||||
{
|
||||
let mut global_bucket = self.global_bucket.write().await;
|
||||
if !global_bucket.try_acquire(1.0) {
|
||||
warn!("Global rate limit exceeded");
|
||||
return Ok(false);
|
||||
}
|
||||
}
|
||||
|
||||
// Check client-specific rate limit
|
||||
let mut buckets = self.client_buckets.write().await;
|
||||
let bucket = buckets.entry(client_id.to_string()).or_insert_with(|| {
|
||||
TokenBucket::new(
|
||||
self.config.burst_size as f64,
|
||||
self.config.requests_per_minute as f64 / 60.0,
|
||||
)
|
||||
});
|
||||
|
||||
Ok(bucket.try_acquire(1.0))
|
||||
}
|
||||
|
||||
/// Check if provider requests are allowed (circuit breaker)
|
||||
pub async fn check_provider_request(&self, provider_name: &str) -> Result<bool> {
|
||||
let mut breakers = self.circuit_breakers.write().await;
|
||||
let breaker = breakers
|
||||
.entry(provider_name.to_string())
|
||||
.or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone()));
|
||||
|
||||
Ok(breaker.allow_request())
|
||||
}
|
||||
|
||||
/// Record provider success
|
||||
pub async fn record_provider_success(&self, provider_name: &str) {
|
||||
let mut breakers = self.circuit_breakers.write().await;
|
||||
if let Some(breaker) = breakers.get_mut(provider_name) {
|
||||
breaker.record_success();
|
||||
}
|
||||
}
|
||||
|
||||
/// Record provider failure
|
||||
pub async fn record_provider_failure(&self, provider_name: &str) {
|
||||
let mut breakers = self.circuit_breakers.write().await;
|
||||
let breaker = breakers
|
||||
.entry(provider_name.to_string())
|
||||
.or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone()));
|
||||
|
||||
breaker.record_failure();
|
||||
}
|
||||
|
||||
/// Get provider circuit state
|
||||
pub async fn get_provider_state(&self, provider_name: &str) -> CircuitState {
|
||||
let breakers = self.circuit_breakers.read().await;
|
||||
breakers
|
||||
.get(provider_name)
|
||||
.map(|b| b.state())
|
||||
.unwrap_or(CircuitState::Closed)
|
||||
}
|
||||
}
|
||||
|
||||
/// Axum middleware for rate limiting.
pub mod middleware {
    use super::*;
    use crate::errors::AppError;
    use crate::state::AppState;
    use axum::{
        extract::{Request, State},
        middleware::Next,
        response::Response,
    };
    use sqlx;

    /// Rate limiting middleware: identifies the caller from the bearer
    /// token, checks global + per-client limits, and either forwards the
    /// request or returns a `RateLimitError`.
    pub async fn rate_limit_middleware(
        State(state): State<AppState>,
        request: Request,
        next: Next,
    ) -> Result<Response, AppError> {
        // Extract token synchronously from headers (avoids holding &Request across await)
        let token = extract_bearer_token(&request);

        // Resolve client_id: DB token lookup, then prefix fallback
        let client_id = resolve_client_id(token, &state).await;

        // Check rate limits
        if !state.rate_limit_manager.check_client_request(&client_id).await? {
            return Err(AppError::RateLimitError("Rate limit exceeded".to_string()));
        }

        Ok(next.run(request).await)
    }

    /// Synchronously extract bearer token from request headers.
    /// Returns `None` when the header is absent, non-UTF-8, or not "Bearer ".
    fn extract_bearer_token(request: &Request) -> Option<String> {
        request.headers().get("Authorization")
            .and_then(|v| v.to_str().ok())
            .and_then(|s| s.strip_prefix("Bearer "))
            .map(|t| t.to_string())
    }

    /// Resolve client ID: try DB token first, then fall back to
    /// token-prefix derivation. Requests without a token are "anonymous".
    async fn resolve_client_id(token: Option<String>, state: &AppState) -> String {
        if let Some(token) = token {
            // Try DB token lookup first. A DB error falls through to the
            // prefix fallback rather than failing the request.
            if let Ok(Some(cid)) = sqlx::query_scalar::<_, String>(
                "SELECT client_id FROM client_tokens WHERE token = ? AND is_active = TRUE",
            )
            .bind(&token)
            .fetch_optional(&state.db_pool)
            .await
            {
                return cid;
            }

            // Fallback to token-prefix derivation (env tokens / permissive mode)
            // NOTE(review): byte-index slicing panics if byte 8 is not a
            // UTF-8 boundary; fine for ASCII tokens — confirm tokens are ASCII.
            return format!("client_{}", &token[..8.min(token.len())]);
        }

        // No token — anonymous
        "anonymous".to_string()
    }

    /// Circuit breaker middleware for provider requests: fails fast with a
    /// `ProviderError` while the provider's circuit is open.
    pub async fn circuit_breaker_middleware(provider_name: &str, state: &AppState) -> Result<(), AppError> {
        if !state.rate_limit_manager.check_provider_request(provider_name).await? {
            return Err(AppError::ProviderError(format!(
                "Provider {} is currently unavailable (circuit breaker open)",
                provider_name
            )));
        }
        Ok(())
    }
}
|
||||
@@ -1,445 +0,0 @@
|
||||
use axum::{
|
||||
Json, Router,
|
||||
extract::State,
|
||||
response::IntoResponse,
|
||||
response::sse::{Event, Sse},
|
||||
routing::{get, post},
|
||||
};
|
||||
|
||||
use futures::StreamExt;
|
||||
use sqlx;
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::{
|
||||
auth::AuthenticatedClient,
|
||||
errors::AppError,
|
||||
models::{
|
||||
ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage,
|
||||
ChatStreamChoice, ChatStreamDelta, Usage,
|
||||
},
|
||||
rate_limiting,
|
||||
state::AppState,
|
||||
};
|
||||
|
||||
pub fn router(state: AppState) -> Router {
|
||||
Router::new()
|
||||
.route("/v1/chat/completions", post(chat_completions))
|
||||
.route("/v1/models", get(list_models))
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
rate_limiting::middleware::rate_limit_middleware,
|
||||
))
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
|
||||
/// GET /v1/models — OpenAI-compatible model listing.
|
||||
/// Returns all models from enabled providers so clients like Open WebUI can
|
||||
/// discover which models are available through the proxy.
|
||||
async fn list_models(
|
||||
State(state): State<AppState>,
|
||||
_auth: AuthenticatedClient,
|
||||
) -> Result<Json<serde_json::Value>, AppError> {
|
||||
let registry = &state.model_registry;
|
||||
let providers = state.provider_manager.get_all_providers().await;
|
||||
|
||||
let mut models = Vec::new();
|
||||
|
||||
for provider in &providers {
|
||||
let provider_name = provider.name();
|
||||
|
||||
// Map internal provider names to registry provider IDs
|
||||
let registry_key = match provider_name {
|
||||
"gemini" => "google",
|
||||
"grok" => "xai",
|
||||
_ => provider_name,
|
||||
};
|
||||
|
||||
// Find this provider's models in the registry
|
||||
if let Some(provider_info) = registry.providers.get(registry_key) {
|
||||
for (model_id, meta) in &provider_info.models {
|
||||
// Skip disabled models via the config cache
|
||||
if let Some(cfg) = state.model_config_cache.get(model_id).await {
|
||||
if !cfg.enabled {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
models.push(serde_json::json!({
|
||||
"id": model_id,
|
||||
"object": "model",
|
||||
"created": 0,
|
||||
"owned_by": provider_name,
|
||||
"name": meta.name,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// For Ollama, models are configured in the TOML, not the registry
|
||||
if provider_name == "ollama" {
|
||||
for model_id in &state.config.providers.ollama.models {
|
||||
models.push(serde_json::json!({
|
||||
"id": model_id,
|
||||
"object": "model",
|
||||
"created": 0,
|
||||
"owned_by": "ollama",
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"object": "list",
|
||||
"data": models
|
||||
})))
|
||||
}
|
||||
|
||||
async fn get_model_cost(
|
||||
model: &str,
|
||||
prompt_tokens: u32,
|
||||
completion_tokens: u32,
|
||||
cache_read_tokens: u32,
|
||||
cache_write_tokens: u32,
|
||||
provider: &Arc<dyn crate::providers::Provider>,
|
||||
state: &AppState,
|
||||
) -> f64 {
|
||||
// Check in-memory cache for cost overrides (no SQLite hit)
|
||||
if let Some(cached) = state.model_config_cache.get(model).await {
|
||||
if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
|
||||
// Manual overrides don't have cache-specific rates, so use simple formula
|
||||
return (prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0);
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to provider's registry-based calculation (cache-aware)
|
||||
provider.calculate_cost(model, prompt_tokens, completion_tokens, cache_read_tokens, cache_write_tokens, &state.model_registry)
|
||||
}
|
||||
|
||||
/// POST /v1/chat/completions — OpenAI-compatible chat endpoint.
///
/// Flow: resolve the client id (DB token → env tokens → permissive fallback),
/// enforce model enable/mapping from the in-memory config cache, pick a
/// provider, pass the circuit breaker, then dispatch either an SSE stream
/// (wrapped in `AggregatingStream` for usage logging) or a single JSON
/// response. Success/failure is recorded against the provider's breaker and
/// the request logger in both paths.
async fn chat_completions(
    State(state): State<AppState>,
    auth: AuthenticatedClient,
    Json(mut request): Json<ChatCompletionRequest>,
) -> Result<axum::response::Response, AppError> {
    // Resolve client_id: try DB token first, then env tokens, then permissive fallback.
    // DB errors are treated as "no match" (unwrap_or(None)) rather than failing the request.
    let db_client_id: Option<String> = sqlx::query_scalar::<_, String>(
        "SELECT client_id FROM client_tokens WHERE token = ? AND is_active = TRUE",
    )
    .bind(&auth.token)
    .fetch_optional(&state.db_pool)
    .await
    .unwrap_or(None);

    let client_id = if let Some(cid) = db_client_id {
        // Update last_used_at in background (fire-and-forget).
        let pool = state.db_pool.clone();
        let token = auth.token.clone();
        tokio::spawn(async move {
            let _ = sqlx::query("UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ?")
                .bind(&token)
                .execute(&pool)
                .await;
        });
        cid
    } else if state.auth_tokens.is_empty() || state.auth_tokens.contains(&auth.token) {
        // Env token match or permissive mode (no env tokens configured).
        auth.client_id.clone()
    } else {
        return Err(AppError::AuthError("Invalid authentication token".to_string()));
    };

    let start_time = std::time::Instant::now();
    let model = request.model.clone();

    info!("Chat completion request from client {} for model {}", client_id, model);

    // Check if model is enabled via in-memory cache (no SQLite hit).
    let cached_config = state.model_config_cache.get(&model).await;

    // Unconfigured models default to enabled with no mapping.
    let (model_enabled, model_mapping) = match cached_config {
        Some(cfg) => (cfg.enabled, cfg.mapping),
        None => (true, None),
    };

    if !model_enabled {
        return Err(AppError::ValidationError(format!(
            "Model {} is currently disabled",
            model
        )));
    }

    // Apply mapping if present. Note: provider routing below uses the mapped
    // `request.model`, while logging keeps the client-requested `model`.
    if let Some(target_model) = model_mapping {
        info!("Mapping model {} to {}", model, target_model);
        request.model = target_model;
    }

    // Find appropriate provider for the model.
    let provider = state
        .provider_manager
        .get_provider_for_model(&request.model)
        .await
        .ok_or_else(|| AppError::ProviderError(format!("No provider found for model: {}", request.model)))?;

    let provider_name = provider.name().to_string();

    // Check circuit breaker for this provider.
    rate_limiting::middleware::circuit_breaker_middleware(&provider_name, &state).await?;

    // Convert to unified request format.
    let mut unified_request =
        crate::models::UnifiedRequest::try_from(request).map_err(|e| AppError::ValidationError(e.to_string()))?;

    // Set client_id from authentication.
    unified_request.client_id = client_id.clone();

    // Hydrate images if present.
    if unified_request.has_images {
        unified_request
            .hydrate_images()
            .await
            .map_err(|e| AppError::ValidationError(format!("Failed to process images: {}", e)))?;
    }

    let has_images = unified_request.has_images;

    // Measure proxy overhead (time spent before sending to upstream provider).
    let proxy_overhead = start_time.elapsed();

    // Check if streaming is requested.
    if unified_request.stream {
        // Estimate prompt tokens for logging later.
        let prompt_tokens = crate::utils::tokens::estimate_request_tokens(&model, &unified_request);

        // Handle streaming response.
        let stream_result = provider.chat_completion_stream(unified_request).await;

        match stream_result {
            Ok(stream) => {
                // Record provider success.
                state.rate_limit_manager.record_provider_success(&provider_name).await;

                info!(
                    "Streaming started for {} (proxy overhead: {}ms)",
                    model,
                    proxy_overhead.as_millis()
                );

                // Wrap with AggregatingStream for token counting and database logging.
                let aggregating_stream = crate::utils::streaming::AggregatingStream::new(
                    stream,
                    crate::utils::streaming::StreamConfig {
                        client_id: client_id.clone(),
                        provider: provider.clone(),
                        model: model.clone(),
                        prompt_tokens,
                        has_images,
                        logger: state.request_logger.clone(),
                        client_manager: state.client_manager.clone(),
                        model_registry: state.model_registry.clone(),
                        model_config_cache: state.model_config_cache.clone(),
                    },
                );

                // Create SSE stream - simpler approach that works.
                let stream_id = format!("chatcmpl-{}", Uuid::new_v4());
                let stream_created = chrono::Utc::now().timestamp() as u64;
                let stream_id_sse = stream_id.clone();

                // Build stream that yields events wrapped in Result.
                let stream = async_stream::stream! {
                    let mut aggregator = Box::pin(aggregating_stream);
                    let mut first_chunk = true;

                    while let Some(chunk_result) = aggregator.next().await {
                        match chunk_result {
                            Ok(chunk) => {
                                // OpenAI convention: only the first delta carries the role.
                                let role = if first_chunk {
                                    first_chunk = false;
                                    Some("assistant".to_string())
                                } else {
                                    None
                                };

                                let response = ChatCompletionStreamResponse {
                                    id: stream_id_sse.clone(),
                                    object: "chat.completion.chunk".to_string(),
                                    created: stream_created,
                                    model: chunk.model.clone(),
                                    choices: vec![ChatStreamChoice {
                                        index: 0,
                                        delta: ChatStreamDelta {
                                            role,
                                            content: Some(chunk.content),
                                            reasoning_content: chunk.reasoning_content,
                                            tool_calls: chunk.tool_calls,
                                        },
                                        finish_reason: chunk.finish_reason,
                                    }],
                                };

                                // Use axum's Event directly, wrap in Ok.
                                // Serialization failures drop the chunk with a warning
                                // rather than aborting the whole stream.
                                match Event::default().json_data(response) {
                                    Ok(event) => yield Ok::<_, crate::errors::AppError>(event),
                                    Err(e) => {
                                        warn!("Failed to serialize SSE: {}", e);
                                    }
                                }
                            }
                            Err(e) => {
                                // Mid-stream provider errors are logged and skipped;
                                // the [DONE] sentinel is still emitted below.
                                warn!("Stream error: {}", e);
                            }
                        }
                    }

                    // Yield [DONE] at the end.
                    yield Ok::<_, crate::errors::AppError>(Event::default().data("[DONE]"));
                };

                Ok(Sse::new(stream).into_response())
            }
            Err(e) => {
                // Record provider failure.
                state.rate_limit_manager.record_provider_failure(&provider_name).await;

                // Log failed request.
                let duration = start_time.elapsed();
                warn!("Streaming request failed after {:?}: {}", duration, e);

                Err(e)
            }
        }
    } else {
        // Handle non-streaming response.
        // Allow provider-specific routing: for OpenAI, some models prefer the
        // Responses API (/v1/responses). Use the model registry heuristic to
        // choose chat_responses vs chat_completion automatically.
        let use_responses = provider.name() == "openai"
            && crate::utils::registry::model_prefers_responses(&state.model_registry, &unified_request.model);

        let result = if use_responses {
            provider.chat_responses(unified_request).await
        } else {
            provider.chat_completion(unified_request).await
        };

        match result {
            Ok(response) => {
                // Record provider success.
                state.rate_limit_manager.record_provider_success(&provider_name).await;

                let duration = start_time.elapsed();
                let cost = get_model_cost(
                    &response.model,
                    response.prompt_tokens,
                    response.completion_tokens,
                    response.cache_read_tokens,
                    response.cache_write_tokens,
                    &provider,
                    &state,
                )
                .await;
                // Log request to database.
                state.request_logger.log_request(crate::logging::RequestLog {
                    timestamp: chrono::Utc::now(),
                    client_id: client_id.clone(),
                    provider: provider_name.clone(),
                    model: response.model.clone(),
                    prompt_tokens: response.prompt_tokens,
                    completion_tokens: response.completion_tokens,
                    total_tokens: response.total_tokens,
                    cache_read_tokens: response.cache_read_tokens,
                    cache_write_tokens: response.cache_write_tokens,
                    cost,
                    has_images,
                    status: "success".to_string(),
                    error_message: None,
                    duration_ms: duration.as_millis() as u64,
                });

                // Update client usage (fire-and-forget, don't block response).
                // NOTE(review): the async block references `response.total_tokens`
                // while `response` is consumed below — this relies on field-level
                // capture semantics; confirm it compiles on the project's edition,
                // otherwise copy total_tokens into a local before spawning.
                {
                    let cm = state.client_manager.clone();
                    let cid = client_id.clone();
                    tokio::spawn(async move {
                        let _ = cm.update_client_usage(&cid, response.total_tokens as i64, cost).await;
                    });
                }

                // Convert ProviderResponse to ChatCompletionResponse.
                let finish_reason = if response.tool_calls.is_some() {
                    "tool_calls".to_string()
                } else {
                    "stop".to_string()
                };

                let chat_response = ChatCompletionResponse {
                    id: format!("chatcmpl-{}", Uuid::new_v4()),
                    object: "chat.completion".to_string(),
                    created: chrono::Utc::now().timestamp() as u64,
                    model: response.model,
                    choices: vec![ChatChoice {
                        index: 0,
                        message: ChatMessage {
                            role: "assistant".to_string(),
                            content: crate::models::MessageContent::Text {
                                content: response.content,
                            },
                            reasoning_content: response.reasoning_content,
                            tool_calls: response.tool_calls,
                            name: None,
                            tool_call_id: None,
                        },
                        finish_reason: Some(finish_reason),
                    }],
                    usage: Some(Usage {
                        prompt_tokens: response.prompt_tokens,
                        completion_tokens: response.completion_tokens,
                        total_tokens: response.total_tokens,
                        // Cache counters are surfaced only when non-zero.
                        cache_read_tokens: if response.cache_read_tokens > 0 { Some(response.cache_read_tokens) } else { None },
                        cache_write_tokens: if response.cache_write_tokens > 0 { Some(response.cache_write_tokens) } else { None },
                    }),
                };

                // Log successful request with proxy overhead breakdown.
                // (proxy_overhead was measured before duration, so this cannot underflow.)
                let upstream_ms = duration.as_millis() as u64 - proxy_overhead.as_millis() as u64;
                info!(
                    "Request completed in {:?} (proxy: {}ms, upstream: {}ms)",
                    duration,
                    proxy_overhead.as_millis(),
                    upstream_ms
                );

                Ok(Json(chat_response).into_response())
            }
            Err(e) => {
                // Record provider failure.
                state.rate_limit_manager.record_provider_failure(&provider_name).await;

                // Log failed request to database (zeroed token/cost fields).
                let duration = start_time.elapsed();
                state.request_logger.log_request(crate::logging::RequestLog {
                    timestamp: chrono::Utc::now(),
                    client_id: client_id.clone(),
                    provider: provider_name.clone(),
                    model: model.clone(),
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                    cache_read_tokens: 0,
                    cache_write_tokens: 0,
                    cost: 0.0,
                    has_images: false,
                    status: "error".to_string(),
                    error_message: Some(e.to_string()),
                    duration_ms: duration.as_millis() as u64,
                });

                warn!("Request failed after {:?}: {}", duration, e);

                Err(e)
            }
        }
    }
}
|
||||
129
src/state/mod.rs
129
src/state/mod.rs
@@ -1,129 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{broadcast, RwLock};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::{
|
||||
client::ClientManager, config::AppConfig, database::DbPool, logging::RequestLogger,
|
||||
models::registry::ModelRegistry, providers::ProviderManager, rate_limiting::RateLimitManager,
|
||||
};
|
||||
|
||||
/// Cached model configuration entry
///
/// Mirrors one row of the `model_configs` table as held by `ModelConfigCache`.
#[derive(Debug, Clone)]
pub struct CachedModelConfig {
    /// Whether the model may be served; disabled models are rejected/hidden.
    pub enabled: bool,
    /// Optional rewrite target: requests for this model are routed to the
    /// mapped model id instead.
    pub mapping: Option<String>,
    /// Manual cost override, USD per 1M prompt tokens (None = registry pricing).
    pub prompt_cost_per_m: Option<f64>,
    /// Manual cost override, USD per 1M completion tokens (None = registry pricing).
    pub completion_cost_per_m: Option<f64>,
}
|
||||
|
||||
/// In-memory cache for model_configs table.
/// Refreshes periodically to avoid hitting SQLite on every request.
#[derive(Clone)]
pub struct ModelConfigCache {
    /// Model id → cached config row; replaced wholesale on each refresh.
    cache: Arc<RwLock<HashMap<String, CachedModelConfig>>>,
    /// Pool used by `refresh` to reload the table.
    db_pool: DbPool,
}
|
||||
|
||||
impl ModelConfigCache {
    /// Create an empty cache bound to `db_pool`. Call `refresh` (or
    /// `start_refresh_task`) to populate it.
    pub fn new(db_pool: DbPool) -> Self {
        Self {
            cache: Arc::new(RwLock::new(HashMap::new())),
            db_pool,
        }
    }

    /// Load all model configs from the database into cache.
    ///
    /// Builds a fresh map and swaps it in under a single write-lock
    /// acquisition. On a query error the previous contents are kept and a
    /// warning is logged — the cache degrades to stale rather than empty.
    pub async fn refresh(&self) {
        match sqlx::query_as::<_, (String, bool, Option<String>, Option<f64>, Option<f64>)>(
            "SELECT id, enabled, mapping, prompt_cost_per_m, completion_cost_per_m FROM model_configs",
        )
        .fetch_all(&self.db_pool)
        .await
        {
            Ok(rows) => {
                let mut map = HashMap::with_capacity(rows.len());
                for (id, enabled, mapping, prompt_cost, completion_cost) in rows {
                    map.insert(
                        id,
                        CachedModelConfig {
                            enabled,
                            mapping,
                            prompt_cost_per_m: prompt_cost,
                            completion_cost_per_m: completion_cost,
                        },
                    );
                }
                // Atomic snapshot swap: readers never observe a half-built map.
                *self.cache.write().await = map;
            }
            Err(e) => {
                warn!("Failed to refresh model config cache: {}", e);
            }
        }
    }

    /// Get a cached model config. Returns None if not in cache (model is unconfigured).
    pub async fn get(&self, model: &str) -> Option<CachedModelConfig> {
        self.cache.read().await.get(model).cloned()
    }

    /// Invalidate cache — call this after dashboard writes to model_configs.
    /// Implemented as an immediate full reload rather than a lazy drop.
    pub async fn invalidate(&self) {
        self.refresh().await;
    }

    /// Start a background task that refreshes the cache every `interval` seconds.
    ///
    /// `tokio::time::interval` completes its first tick immediately, so the
    /// initial refresh happens right away. The spawned task runs for the
    /// lifetime of the process (it owns this cache handle by value).
    pub fn start_refresh_task(self, interval_secs: u64) {
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs));
            loop {
                interval.tick().await;
                self.refresh().await;
            }
        });
    }
}
|
||||
|
||||
/// Shared application state
///
/// Cloned per request by axum; fields are Arc handles, pool handles, or a
/// broadcast sender, so cloning is cheap.
#[derive(Clone)]
pub struct AppState {
    pub config: Arc<AppConfig>,
    pub provider_manager: ProviderManager,
    pub db_pool: DbPool,
    /// Per-provider circuit breakers and request rate limits.
    pub rate_limit_manager: Arc<RateLimitManager>,
    pub client_manager: Arc<ClientManager>,
    /// Writes request logs and fans them out over `dashboard_tx`.
    pub request_logger: Arc<RequestLogger>,
    /// Static model/pricing metadata (models.dev snapshot).
    pub model_registry: Arc<ModelRegistry>,
    /// In-memory mirror of the `model_configs` table (enable/mapping/cost overrides).
    pub model_config_cache: ModelConfigCache,
    /// Broadcast channel feeding live dashboard updates.
    pub dashboard_tx: broadcast::Sender<serde_json::Value>,
    /// Static auth tokens from the environment; empty list = permissive mode.
    pub auth_tokens: Vec<String>,
}
|
||||
|
||||
impl AppState {
|
||||
pub fn new(
|
||||
config: Arc<AppConfig>,
|
||||
provider_manager: ProviderManager,
|
||||
db_pool: DbPool,
|
||||
rate_limit_manager: RateLimitManager,
|
||||
model_registry: ModelRegistry,
|
||||
auth_tokens: Vec<String>,
|
||||
) -> Self {
|
||||
let client_manager = Arc::new(ClientManager::new(db_pool.clone()));
|
||||
let (dashboard_tx, _) = broadcast::channel(100);
|
||||
let request_logger = Arc::new(RequestLogger::new(db_pool.clone(), dashboard_tx.clone()));
|
||||
let model_config_cache = ModelConfigCache::new(db_pool.clone());
|
||||
|
||||
Self {
|
||||
config,
|
||||
provider_manager,
|
||||
db_pool,
|
||||
rate_limit_manager: Arc::new(rate_limit_manager),
|
||||
client_manager,
|
||||
request_logger,
|
||||
model_registry: Arc::new(model_registry),
|
||||
model_config_cache,
|
||||
dashboard_tx,
|
||||
auth_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
pub mod registry;
|
||||
pub mod streaming;
|
||||
pub mod tokens;
|
||||
@@ -1,49 +0,0 @@
|
||||
use crate::models::registry::ModelRegistry;
|
||||
use anyhow::Result;
|
||||
use tracing::info;
|
||||
|
||||
const MODELS_DEV_URL: &str = "https://models.dev/api.json";
|
||||
|
||||
pub async fn fetch_registry() -> Result<ModelRegistry> {
|
||||
info!("Fetching model registry from {}", MODELS_DEV_URL);
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(10))
|
||||
.build()?;
|
||||
|
||||
let response = client.get(MODELS_DEV_URL).send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(anyhow::anyhow!("Failed to fetch registry: HTTP {}", response.status()));
|
||||
}
|
||||
|
||||
let registry: ModelRegistry = response.json().await?;
|
||||
info!("Successfully loaded model registry");
|
||||
|
||||
Ok(registry)
|
||||
}
|
||||
|
||||
/// Heuristic: decide whether a model should be routed to OpenAI Responses API
|
||||
/// instead of the legacy chat/completions endpoint.
|
||||
///
|
||||
/// Currently this uses simple patterns (codex, gpt-5 series) and also checks
|
||||
/// the loaded registry metadata name for the substring "codex" as a hint.
|
||||
pub fn model_prefers_responses(registry: &ModelRegistry, model: &str) -> bool {
|
||||
let model_lc = model.to_lowercase();
|
||||
|
||||
if model_lc.contains("codex") {
|
||||
return true;
|
||||
}
|
||||
|
||||
if model_lc.starts_with("gpt-5") || model_lc.contains("gpt-5.") {
|
||||
return true;
|
||||
}
|
||||
|
||||
if let Some(meta) = registry.find_model(model) {
|
||||
if meta.name.to_lowercase().contains("codex") {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
@@ -1,334 +0,0 @@
|
||||
use crate::client::ClientManager;
|
||||
use crate::errors::AppError;
|
||||
use crate::logging::{RequestLog, RequestLogger};
|
||||
use crate::models::ToolCall;
|
||||
use crate::providers::{Provider, ProviderStreamChunk, StreamUsage};
|
||||
use crate::state::ModelConfigCache;
|
||||
use crate::utils::tokens::estimate_completion_tokens;
|
||||
use futures::stream::Stream;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
/// Configuration for creating an AggregatingStream.
pub struct StreamConfig {
    /// Resolved client identifier used for logging and usage accounting.
    pub client_id: String,
    /// Provider that produced the stream (supplies name and cost calculation).
    pub provider: Arc<dyn Provider>,
    /// Model name as requested by the client.
    pub model: String,
    /// Pre-stream prompt-token estimate, used only when the provider reports
    /// no real usage on the final chunk.
    pub prompt_tokens: u32,
    /// Whether the request contained images (recorded in the request log).
    pub has_images: bool,
    /// Sink for the final request-log entry.
    pub logger: Arc<RequestLogger>,
    /// Used to update the client's cumulative token/cost totals.
    pub client_manager: Arc<ClientManager>,
    /// Pricing metadata for the fallback cost calculation.
    pub model_registry: Arc<crate::models::registry::ModelRegistry>,
    /// Source of manual per-model cost overrides.
    pub model_config_cache: ModelConfigCache,
}
|
||||
|
||||
/// Stream adapter that passes provider chunks through unchanged while
/// accumulating content, reasoning, tool calls, and usage, so the completed
/// request can be logged and billed exactly once when the stream ends.
pub struct AggregatingStream<S> {
    inner: S,
    client_id: String,
    provider: Arc<dyn Provider>,
    model: String,
    /// Estimated prompt tokens; fallback when no real usage arrives.
    prompt_tokens: u32,
    has_images: bool,
    /// Full assistant text accumulated across chunks.
    accumulated_content: String,
    /// Full reasoning text accumulated across chunks (may stay empty).
    accumulated_reasoning: String,
    /// Tool calls reassembled from per-chunk deltas, indexed by delta.index.
    accumulated_tool_calls: Vec<ToolCall>,
    /// Real usage data from the provider's final stream chunk (when available).
    real_usage: Option<StreamUsage>,
    logger: Arc<RequestLogger>,
    client_manager: Arc<ClientManager>,
    model_registry: Arc<crate::models::registry::ModelRegistry>,
    model_config_cache: ModelConfigCache,
    /// Start of streaming; used for the logged duration.
    start_time: std::time::Instant,
    /// Guard so `finalize` logs at most once.
    has_logged: bool,
}
|
||||
|
||||
impl<S> AggregatingStream<S>
where
    S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin,
{
    /// Wrap `inner`, taking ownership of the logging/accounting handles from
    /// `config` and starting the duration timer now.
    pub fn new(inner: S, config: StreamConfig) -> Self {
        Self {
            inner,
            client_id: config.client_id,
            provider: config.provider,
            model: config.model,
            prompt_tokens: config.prompt_tokens,
            has_images: config.has_images,
            accumulated_content: String::new(),
            accumulated_reasoning: String::new(),
            accumulated_tool_calls: Vec::new(),
            real_usage: None,
            logger: config.logger,
            client_manager: config.client_manager,
            model_registry: config.model_registry,
            model_config_cache: config.model_config_cache,
            start_time: std::time::Instant::now(),
            has_logged: false,
        }
    }

    /// Log the finished stream and update client usage, exactly once.
    ///
    /// Prefers the provider-reported usage (captured in `real_usage`) and
    /// falls back to local estimates. The actual logging runs in a spawned
    /// background task so it never blocks stream polling.
    fn finalize(&mut self) {
        if self.has_logged {
            return;
        }
        self.has_logged = true;

        // Snapshot everything the background task needs; `self` is not Send
        // into the task.
        let duration = self.start_time.elapsed();
        let client_id = self.client_id.clone();
        let provider_name = self.provider.name().to_string();
        let model = self.model.clone();
        let logger = self.logger.clone();
        let client_manager = self.client_manager.clone();
        let provider = self.provider.clone();
        let estimated_prompt_tokens = self.prompt_tokens;
        let has_images = self.has_images;
        let registry = self.model_registry.clone();
        let config_cache = self.model_config_cache.clone();
        let real_usage = self.real_usage.take();

        // Estimate completion tokens (including reasoning if present).
        let estimated_content_tokens = estimate_completion_tokens(&self.accumulated_content, &model);
        let estimated_reasoning_tokens = if !self.accumulated_reasoning.is_empty() {
            estimate_completion_tokens(&self.accumulated_reasoning, &model)
        } else {
            0
        };

        let estimated_completion = estimated_content_tokens + estimated_reasoning_tokens;

        // Spawn a background task to log the completion.
        tokio::spawn(async move {
            // Use real usage from the provider when available, otherwise fall back to estimates.
            let (prompt_tokens, completion_tokens, total_tokens, cache_read_tokens, cache_write_tokens) =
                if let Some(usage) = &real_usage {
                    (
                        usage.prompt_tokens,
                        usage.completion_tokens,
                        usage.total_tokens,
                        usage.cache_read_tokens,
                        usage.cache_write_tokens,
                    )
                } else {
                    // Estimates carry no cache information.
                    (
                        estimated_prompt_tokens,
                        estimated_completion,
                        estimated_prompt_tokens + estimated_completion,
                        0u32,
                        0u32,
                    )
                };

            // Check in-memory cache for cost overrides (no SQLite hit).
            // An override applies only when both rates are set; otherwise the
            // provider's registry-based, cache-aware pricing is used.
            let cost = if let Some(cached) = config_cache.get(&model).await {
                if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
                    // Cost override doesn't have cache-aware pricing, use simple formula.
                    (prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0)
                } else {
                    provider.calculate_cost(
                        &model,
                        prompt_tokens,
                        completion_tokens,
                        cache_read_tokens,
                        cache_write_tokens,
                        &registry,
                    )
                }
            } else {
                provider.calculate_cost(
                    &model,
                    prompt_tokens,
                    completion_tokens,
                    cache_read_tokens,
                    cache_write_tokens,
                    &registry,
                )
            };

            // Log to database.
            logger.log_request(RequestLog {
                timestamp: chrono::Utc::now(),
                client_id: client_id.clone(),
                provider: provider_name,
                model,
                prompt_tokens,
                completion_tokens,
                total_tokens,
                cache_read_tokens,
                cache_write_tokens,
                cost,
                has_images,
                status: "success".to_string(),
                error_message: None,
                duration_ms: duration.as_millis() as u64,
            });

            // Update client usage (errors intentionally ignored — best-effort).
            let _ = client_manager
                .update_client_usage(&client_id, total_tokens as i64, cost)
                .await;
        });
    }
}
|
||||
|
||||
impl<S> Stream for AggregatingStream<S>
where
    S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin,
{
    type Item = Result<ProviderStreamChunk, AppError>;

    /// Poll the inner stream, inspect (but never alter) the yielded item to
    /// update the accumulators, and trigger `finalize` when the stream ends
    /// or errors after producing content.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let result = Pin::new(&mut self.inner).poll_next(cx);

        match &result {
            Poll::Ready(Some(Ok(chunk))) => {
                self.accumulated_content.push_str(&chunk.content);
                if let Some(reasoning) = &chunk.reasoning_content {
                    self.accumulated_reasoning.push_str(reasoning);
                }
                // Capture real usage from the provider when present (typically on the final chunk).
                if let Some(usage) = &chunk.usage {
                    self.real_usage = Some(usage.clone());
                }
                // Accumulate tool call deltas into complete tool calls.
                if let Some(deltas) = &chunk.tool_calls {
                    for delta in deltas {
                        let idx = delta.index as usize;
                        // Grow the accumulated_tool_calls vec if needed.
                        while self.accumulated_tool_calls.len() <= idx {
                            self.accumulated_tool_calls.push(ToolCall {
                                id: String::new(),
                                call_type: "function".to_string(),
                                function: crate::models::FunctionCall {
                                    name: String::new(),
                                    arguments: String::new(),
                                },
                            });
                        }
                        let tc = &mut self.accumulated_tool_calls[idx];
                        // id and call_type are replaced when re-sent; name and
                        // arguments arrive as fragments and are concatenated.
                        if let Some(id) = &delta.id {
                            tc.id.clone_from(id);
                        }
                        if let Some(ct) = &delta.call_type {
                            tc.call_type.clone_from(ct);
                        }
                        if let Some(f) = &delta.function {
                            if let Some(name) = &f.name {
                                tc.function.name.push_str(name);
                            }
                            if let Some(args) = &f.arguments {
                                tc.function.arguments.push_str(args);
                            }
                        }
                    }
                }
            }
            Poll::Ready(Some(Err(_))) => {
                // If there's an error, we might still want to log what we got so far?
                // For now, just finalize if we have content.
                if !self.accumulated_content.is_empty() {
                    self.finalize();
                }
            }
            Poll::Ready(None) => {
                // Normal end of stream: log exactly once.
                self.finalize();
            }
            Poll::Pending => {}
        }

        // Pass the original item through untouched.
        result
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use futures::stream::{self, StreamExt};

    // Simple mock provider for testing: fixed name, flat 0.05 cost,
    // streaming/completion entry points are never called by these tests.
    struct MockProvider;
    #[async_trait::async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            "mock"
        }
        fn supports_model(&self, _model: &str) -> bool {
            true
        }
        fn supports_multimodal(&self) -> bool {
            false
        }
        async fn chat_completion(
            &self,
            _req: crate::models::UnifiedRequest,
        ) -> Result<crate::providers::ProviderResponse, AppError> {
            unimplemented!()
        }
        async fn chat_completion_stream(
            &self,
            _req: crate::models::UnifiedRequest,
        ) -> Result<futures::stream::BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
            unimplemented!()
        }
        fn estimate_tokens(&self, _req: &crate::models::UnifiedRequest) -> Result<u32> {
            Ok(10)
        }
        fn calculate_cost(&self, _model: &str, _p: u32, _c: u32, _cr: u32, _cw: u32, _r: &crate::models::registry::ModelRegistry) -> f64 {
            0.05
        }
    }

    // Verifies that AggregatingStream passes chunks through, concatenates
    // their content, and marks the request as logged when the stream ends.
    #[tokio::test]
    async fn test_aggregating_stream() {
        let chunks = vec![
            Ok(ProviderStreamChunk {
                content: "Hello".to_string(),
                reasoning_content: None,
                finish_reason: None,
                tool_calls: None,
                model: "test".to_string(),
                usage: None,
            }),
            Ok(ProviderStreamChunk {
                content: " World".to_string(),
                reasoning_content: None,
                finish_reason: Some("stop".to_string()),
                tool_calls: None,
                model: "test".to_string(),
                usage: None,
            }),
        ];
        let inner_stream = stream::iter(chunks);

        // In-memory SQLite keeps the test hermetic; logging is fire-and-forget.
        let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
        let (dashboard_tx, _) = tokio::sync::broadcast::channel(16);
        let logger = Arc::new(RequestLogger::new(pool.clone(), dashboard_tx));
        let client_manager = Arc::new(ClientManager::new(pool.clone()));
        let registry = Arc::new(crate::models::registry::ModelRegistry {
            providers: std::collections::HashMap::new(),
        });

        let mut agg_stream = AggregatingStream::new(
            inner_stream,
            StreamConfig {
                client_id: "client_1".to_string(),
                provider: Arc::new(MockProvider),
                model: "test".to_string(),
                prompt_tokens: 10,
                has_images: false,
                logger,
                client_manager,
                model_registry: registry,
                model_config_cache: ModelConfigCache::new(pool.clone()),
            },
        );

        while let Some(item) = agg_stream.next().await {
            assert!(item.is_ok());
        }

        assert_eq!(agg_stream.accumulated_content, "Hello World");
        assert!(agg_stream.has_logged);
    }
}
|
||||
@@ -1,69 +0,0 @@
|
||||
use crate::models::UnifiedRequest;
|
||||
use tiktoken_rs::get_bpe_from_model;
|
||||
|
||||
/// Count tokens for a given model and text
|
||||
pub fn count_tokens(model: &str, text: &str) -> u32 {
|
||||
// If we can't get the bpe for the model, fallback to a safe default (cl100k_base for GPT-4/o1)
|
||||
let bpe = get_bpe_from_model(model)
|
||||
.unwrap_or_else(|_| tiktoken_rs::cl100k_base().expect("Failed to get cl100k_base encoding"));
|
||||
|
||||
bpe.encode_with_special_tokens(text).len() as u32
|
||||
}
|
||||
|
||||
/// Estimate tokens for a unified request.
|
||||
/// Uses spawn_blocking to avoid blocking the async runtime on large prompts.
|
||||
pub fn estimate_request_tokens(model: &str, request: &UnifiedRequest) -> u32 {
|
||||
let mut total_text = String::new();
|
||||
let msg_count = request.messages.len();
|
||||
|
||||
// Base tokens per message for OpenAI (approximate)
|
||||
let tokens_per_message: u32 = 3;
|
||||
|
||||
for msg in &request.messages {
|
||||
for part in &msg.content {
|
||||
match part {
|
||||
crate::models::ContentPart::Text { text } => {
|
||||
total_text.push_str(text);
|
||||
total_text.push('\n');
|
||||
}
|
||||
crate::models::ContentPart::Image { .. } => {
|
||||
// Vision models usually have a fixed cost or calculation based on size
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Quick heuristic for small inputs (< 1KB) — avoid spawn_blocking overhead
|
||||
if total_text.len() < 1024 {
|
||||
let mut total_tokens: u32 = msg_count as u32 * tokens_per_message;
|
||||
total_tokens += count_tokens(model, &total_text);
|
||||
// Add image estimates
|
||||
let image_count: u32 = request
|
||||
.messages
|
||||
.iter()
|
||||
.flat_map(|m| m.content.iter())
|
||||
.filter(|p| matches!(p, crate::models::ContentPart::Image { .. }))
|
||||
.count() as u32;
|
||||
total_tokens += image_count * 1000;
|
||||
total_tokens += 3; // assistant reply header
|
||||
return total_tokens;
|
||||
}
|
||||
|
||||
// For large inputs, use the fast heuristic (chars / 4) to avoid blocking
|
||||
// the async runtime. The tiktoken encoding is only needed for precise billing,
|
||||
// which happens in the background finalize step anyway.
|
||||
let estimated_text_tokens = (total_text.len() as u32) / 4;
|
||||
let image_count: u32 = request
|
||||
.messages
|
||||
.iter()
|
||||
.flat_map(|m| m.content.iter())
|
||||
.filter(|p| matches!(p, crate::models::ContentPart::Image { .. }))
|
||||
.count() as u32;
|
||||
|
||||
(msg_count as u32 * tokens_per_message) + estimated_text_tokens + (image_count * 1000) + 3
|
||||
}
|
||||
|
||||
/// Estimate tokens for completion text
|
||||
pub fn estimate_completion_tokens(text: &str, model: &str) -> u32 {
|
||||
count_tokens(model, text)
|
||||
}
|
||||
@@ -61,9 +61,9 @@
|
||||
--text-white: var(--fg0);
|
||||
|
||||
/* Borders */
|
||||
--border-color: var(--bg2);
|
||||
--border-radius: 8px;
|
||||
--border-radius-sm: 4px;
|
||||
--border-color: var(--bg3);
|
||||
--border-radius: 0px;
|
||||
--border-radius-sm: 0px;
|
||||
|
||||
/* Spacing System */
|
||||
--spacing-xs: 0.25rem;
|
||||
@@ -72,15 +72,15 @@
|
||||
--spacing-lg: 1.5rem;
|
||||
--spacing-xl: 2rem;
|
||||
|
||||
/* Shadows */
|
||||
--shadow-sm: 0 1px 2px 0 rgba(0, 0, 0, 0.2);
|
||||
--shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.3);
|
||||
--shadow-md: 0 10px 15px -3px rgba(0, 0, 0, 0.4);
|
||||
--shadow-lg: 0 20px 25px -5px rgba(0, 0, 0, 0.5);
|
||||
/* Shadows - Retro Block Style */
|
||||
--shadow-sm: 2px 2px 0px rgba(0, 0, 0, 0.4);
|
||||
--shadow: 4px 4px 0px rgba(0, 0, 0, 0.5);
|
||||
--shadow-md: 6px 6px 0px rgba(0, 0, 0, 0.6);
|
||||
--shadow-lg: 8px 8px 0px rgba(0, 0, 0, 0.7);
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: 'Inter', -apple-system, sans-serif;
|
||||
font-family: 'JetBrains Mono', 'Fira Code', 'Courier New', monospace;
|
||||
background-color: var(--bg-primary);
|
||||
color: var(--text-primary);
|
||||
line-height: 1.6;
|
||||
@@ -105,12 +105,12 @@ body {
|
||||
|
||||
.login-card {
|
||||
background: var(--bg1);
|
||||
border-radius: 24px;
|
||||
border-radius: var(--border-radius);
|
||||
padding: 4rem 2.5rem 3rem;
|
||||
width: 100%;
|
||||
max-width: 440px;
|
||||
box-shadow: var(--shadow-lg);
|
||||
border: 1px solid var(--bg2);
|
||||
border: 2px solid var(--bg3);
|
||||
text-align: center;
|
||||
animation: slideUp 0.6s cubic-bezier(0.34, 1.56, 0.64, 1);
|
||||
position: relative;
|
||||
@@ -148,22 +148,54 @@ body {
|
||||
width: 80px;
|
||||
height: 80px;
|
||||
margin: 0 auto 1.25rem;
|
||||
border-radius: 16px;
|
||||
background: var(--bg2);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
color: var(--orange);
|
||||
font-size: 2rem;
|
||||
background: rgba(254, 128, 25, 0.15);
|
||||
color: var(--primary);
|
||||
border-radius: 12px;
|
||||
font-size: 2.5rem;
|
||||
}
|
||||
|
||||
/* GopherGate Logo Icon */
|
||||
.logo-icon-container {
|
||||
width: 60px;
|
||||
height: 60px;
|
||||
background: var(--blue-light);
|
||||
border-radius: 12px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
box-shadow: var(--shadow);
|
||||
border: 2px solid var(--fg1);
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
.logo-icon-container.small {
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
border-radius: 6px;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.logo-icon-text {
|
||||
font-family: 'JetBrains Mono', monospace;
|
||||
font-weight: 700;
|
||||
color: var(--bg0);
|
||||
font-size: 1.8rem;
|
||||
}
|
||||
|
||||
.logo-icon-container.small .logo-icon-text {
|
||||
font-size: 1rem;
|
||||
}
|
||||
|
||||
.login-header h1 {
|
||||
font-size: 1.75rem;
|
||||
font-size: 2rem;
|
||||
font-weight: 800;
|
||||
color: var(--fg0);
|
||||
color: var(--primary-light);
|
||||
margin-bottom: 0.5rem;
|
||||
letter-spacing: -0.025em;
|
||||
text-transform: uppercase;
|
||||
}
|
||||
|
||||
.login-subtitle {
|
||||
@@ -191,7 +223,7 @@ body {
|
||||
color: var(--fg3);
|
||||
pointer-events: none;
|
||||
transition: all 0.25s ease;
|
||||
background: var(--bg1);
|
||||
background: transparent;
|
||||
padding: 0 0.375rem;
|
||||
z-index: 2;
|
||||
font-weight: 500;
|
||||
@@ -202,30 +234,32 @@ body {
|
||||
|
||||
.form-group input:focus ~ label,
|
||||
.form-group input:not(:placeholder-shown) ~ label {
|
||||
top: -0.625rem;
|
||||
top: 0;
|
||||
left: 0.875rem;
|
||||
font-size: 0.7rem;
|
||||
font-size: 0.75rem;
|
||||
color: var(--orange);
|
||||
font-weight: 600;
|
||||
transform: translateY(0);
|
||||
transform: translateY(-50%);
|
||||
background: linear-gradient(180deg, var(--bg1) 50%, var(--bg0) 50%);
|
||||
}
|
||||
|
||||
.form-group input {
|
||||
padding: 1rem 1.25rem;
|
||||
background: var(--bg0);
|
||||
border: 2px solid var(--bg3);
|
||||
border-radius: 12px;
|
||||
border-radius: var(--border-radius);
|
||||
font-family: inherit;
|
||||
font-size: 1rem;
|
||||
color: var(--fg1);
|
||||
transition: all 0.3s;
|
||||
transition: all 0.2s;
|
||||
width: 100%;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.form-group input:focus {
|
||||
border-color: var(--orange);
|
||||
box-shadow: 0 0 0 4px rgba(214, 93, 14, 0.2);
|
||||
outline: none;
|
||||
box-shadow: 4px 4px 0px rgba(214, 93, 14, 0.4);
|
||||
}
|
||||
|
||||
.login-btn {
|
||||
@@ -295,6 +329,25 @@ body {
|
||||
font-size: 1.125rem;
|
||||
}
|
||||
|
||||
/* Badges */
|
||||
.badge {
|
||||
display: inline-block;
|
||||
padding: 0.25rem 0.5rem;
|
||||
font-size: 0.75rem;
|
||||
font-weight: 600;
|
||||
line-height: 1;
|
||||
text-align: center;
|
||||
white-space: nowrap;
|
||||
vertical-align: baseline;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.badge-success { background-color: rgba(152, 151, 26, 0.15); color: var(--green-light); border: 1px solid var(--green); }
|
||||
.badge-info { background-color: rgba(69, 133, 136, 0.15); color: var(--blue-light); border: 1px solid var(--blue); }
|
||||
.badge-warning { background-color: rgba(215, 153, 33, 0.15); color: var(--yellow-light); border: 1px solid var(--yellow); }
|
||||
.badge-danger { background-color: rgba(204, 36, 29, 0.15); color: var(--red-light); border: 1px solid var(--red); }
|
||||
.badge-client { background-color: var(--bg2); color: var(--fg1); border: 1px solid var(--bg3); padding: 2px 6px; font-size: 0.7rem; text-transform: uppercase; }
|
||||
|
||||
/* Responsive Login */
|
||||
@media (max-width: 480px) {
|
||||
.login-card {
|
||||
@@ -373,11 +426,15 @@ body {
|
||||
}
|
||||
|
||||
.sidebar.collapsed .logo {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.sidebar.collapsed .logo span {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.sidebar.collapsed .sidebar-toggle {
|
||||
opacity: 1;
|
||||
margin-left: 0;
|
||||
}
|
||||
|
||||
.logo {
|
||||
@@ -392,6 +449,7 @@ body {
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
|
||||
.sidebar-logo {
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
@@ -586,17 +644,48 @@ body {
|
||||
|
||||
/* Main Content Area */
|
||||
.main-content {
|
||||
margin-left: 260px;
|
||||
padding-left: 260px;
|
||||
flex: 1;
|
||||
min-height: 100vh;
|
||||
transition: all 0.3s;
|
||||
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
background-color: var(--bg-primary);
|
||||
}
|
||||
|
||||
.sidebar.collapsed ~ .main-content {
|
||||
margin-left: 80px;
|
||||
.sidebar.collapsed + .main-content {
|
||||
padding-left: 80px;
|
||||
}
|
||||
|
||||
.top-bar {
|
||||
height: 70px;
|
||||
background: var(--bg0);
|
||||
border-bottom: 1px solid var(--bg2);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: 0 var(--spacing-xl);
|
||||
position: sticky;
|
||||
top: 0;
|
||||
z-index: 100;
|
||||
}
|
||||
|
||||
.top-bar .page-title h2 {
|
||||
font-size: 1.25rem;
|
||||
font-weight: 700;
|
||||
color: var(--fg0);
|
||||
}
|
||||
|
||||
.top-bar-actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: var(--spacing-lg);
|
||||
}
|
||||
|
||||
.content-body {
|
||||
padding: var(--spacing-xl);
|
||||
flex: 1;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.top-nav {
|
||||
@@ -732,11 +821,11 @@ body {
|
||||
.stat-change.positive { color: var(--green-light); }
|
||||
.stat-change.negative { color: var(--red-light); }
|
||||
|
||||
/* Generic Cards */
|
||||
/* Cards */
|
||||
.card {
|
||||
background: var(--bg1);
|
||||
border-radius: var(--border-radius);
|
||||
border: 1px solid var(--bg2);
|
||||
border: 1px solid var(--bg3);
|
||||
box-shadow: var(--shadow-sm);
|
||||
margin-bottom: 1.5rem;
|
||||
display: flex;
|
||||
@@ -749,6 +838,15 @@ body {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
flex-wrap: wrap;
|
||||
gap: 1rem;
|
||||
}
|
||||
|
||||
.card-actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.card-title {
|
||||
@@ -817,25 +915,26 @@ body {
|
||||
/* Badges */
|
||||
.status-badge {
|
||||
padding: 0.25rem 0.75rem;
|
||||
border-radius: 9999px;
|
||||
border-radius: var(--border-radius);
|
||||
font-size: 0.7rem;
|
||||
font-weight: 700;
|
||||
text-transform: uppercase;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.375rem;
|
||||
border: 1px solid transparent;
|
||||
}
|
||||
|
||||
.status-badge.online, .status-badge.success { background: rgba(184, 187, 38, 0.2); color: var(--green-light); }
|
||||
.status-badge.offline, .status-badge.danger { background: rgba(251, 73, 52, 0.2); color: var(--red-light); }
|
||||
.status-badge.warning { background: rgba(250, 189, 47, 0.2); color: var(--yellow-light); }
|
||||
.status-badge.online, .status-badge.success { background: rgba(184, 187, 38, 0.2); color: var(--green-light); border-color: rgba(184, 187, 38, 0.4); }
|
||||
.status-badge.offline, .status-badge.danger { background: rgba(251, 73, 52, 0.2); color: var(--red-light); border-color: rgba(251, 73, 52, 0.4); }
|
||||
.status-badge.warning { background: rgba(250, 189, 47, 0.2); color: var(--yellow-light); border-color: rgba(250, 189, 47, 0.4); }
|
||||
|
||||
.badge-client {
|
||||
background: var(--bg2);
|
||||
color: var(--blue-light);
|
||||
padding: 2px 8px;
|
||||
border-radius: 6px;
|
||||
font-family: monospace;
|
||||
border-radius: var(--border-radius);
|
||||
font-family: inherit;
|
||||
font-size: 0.85rem;
|
||||
border: 1px solid var(--bg3);
|
||||
}
|
||||
@@ -889,7 +988,7 @@ body {
|
||||
width: 100%;
|
||||
background: var(--bg0);
|
||||
border: 1px solid var(--bg3);
|
||||
border-radius: 8px;
|
||||
border-radius: var(--border-radius);
|
||||
padding: 0.75rem;
|
||||
font-family: inherit;
|
||||
font-size: 0.875rem;
|
||||
@@ -900,7 +999,7 @@ body {
|
||||
.form-control input:focus, .form-control textarea:focus, .form-control select:focus {
|
||||
outline: none;
|
||||
border-color: var(--orange);
|
||||
box-shadow: 0 0 0 2px rgba(214, 93, 14, 0.2);
|
||||
box-shadow: 2px 2px 0px rgba(214, 93, 14, 0.4);
|
||||
}
|
||||
|
||||
.btn {
|
||||
@@ -908,21 +1007,27 @@ body {
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
padding: 0.625rem 1.25rem;
|
||||
border-radius: 8px;
|
||||
border-radius: var(--border-radius);
|
||||
font-weight: 600;
|
||||
font-size: 0.875rem;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
transition: all 0.1s;
|
||||
border: 1px solid transparent;
|
||||
text-transform: uppercase;
|
||||
}
|
||||
|
||||
.btn-primary { background: var(--orange); color: var(--bg0); }
|
||||
.btn:active {
|
||||
transform: translate(2px, 2px);
|
||||
box-shadow: none !important;
|
||||
}
|
||||
|
||||
.btn-primary { background: var(--orange); color: var(--bg0); box-shadow: 2px 2px 0px var(--bg4); }
|
||||
.btn-primary:hover { background: var(--orange-light); }
|
||||
|
||||
.btn-secondary { background: var(--bg2); border-color: var(--bg3); color: var(--fg1); }
|
||||
.btn-secondary { background: var(--bg2); border-color: var(--bg3); color: var(--fg1); box-shadow: 2px 2px 0px var(--bg0); }
|
||||
.btn-secondary:hover { background: var(--bg3); color: var(--fg0); }
|
||||
|
||||
.btn-danger { background: var(--red); color: var(--fg0); }
|
||||
.btn-danger { background: var(--red); color: var(--fg0); box-shadow: 2px 2px 0px var(--bg4); }
|
||||
.btn-danger:hover { background: var(--red-light); }
|
||||
|
||||
/* Small inline action buttons (edit, delete, copy) */
|
||||
@@ -981,13 +1086,13 @@ body {
|
||||
|
||||
.modal-content {
|
||||
background: var(--bg1);
|
||||
border-radius: 16px;
|
||||
border-radius: var(--border-radius);
|
||||
width: 90%;
|
||||
max-width: 500px;
|
||||
box-shadow: var(--shadow-lg);
|
||||
border: 1px solid var(--bg3);
|
||||
border: 2px solid var(--bg3);
|
||||
transform: translateY(20px);
|
||||
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.modal.active .modal-content {
|
||||
@@ -1029,6 +1134,53 @@ body {
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
/* Connection Status Indicator */
|
||||
.status-indicator {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
padding: 0.5rem 0.875rem;
|
||||
background: var(--bg1);
|
||||
border: 1px solid var(--bg3);
|
||||
border-radius: 6px;
|
||||
font-size: 0.8rem;
|
||||
font-weight: 600;
|
||||
color: var(--fg3);
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.status-dot {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border-radius: 50%;
|
||||
background: var(--fg4);
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.status-dot.connected {
|
||||
background: var(--green-light);
|
||||
box-shadow: 0 0 0 0 rgba(184, 187, 38, 0.4);
|
||||
animation: status-pulse 2s infinite;
|
||||
}
|
||||
|
||||
.status-dot.disconnected {
|
||||
background: var(--red-light);
|
||||
}
|
||||
|
||||
.status-dot.connecting {
|
||||
background: var(--yellow-light);
|
||||
}
|
||||
|
||||
.status-dot.error {
|
||||
background: var(--red);
|
||||
}
|
||||
|
||||
@keyframes status-pulse {
|
||||
0% { box-shadow: 0 0 0 0 rgba(184, 187, 38, 0.4); }
|
||||
70% { box-shadow: 0 0 0 6px rgba(184, 187, 38, 0); }
|
||||
100% { box-shadow: 0 0 0 0 rgba(184, 187, 38, 0); }
|
||||
}
|
||||
|
||||
/* WebSocket Dot Pulse */
|
||||
@keyframes ws-pulse {
|
||||
0% { box-shadow: 0 0 0 0 rgba(184, 187, 38, 0.4); }
|
||||
|
||||
BIN
static/favicon.ico
Normal file
BIN
static/favicon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1002 B |
@@ -3,50 +3,38 @@
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>LLM Proxy Gateway - Admin Dashboard</title>
|
||||
<link rel="stylesheet" href="/css/dashboard.css?v=7">
|
||||
<title>GopherGate - Admin Dashboard</title>
|
||||
<link rel="stylesheet" href="/css/dashboard.css?v=11">
|
||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
|
||||
<link rel="icon" href="img/logo-icon.png" type="image/png" sizes="any">
|
||||
<link rel="apple-touch-icon" href="img/logo-icon.png">
|
||||
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet">
|
||||
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/luxon@3.4.4/build/global/luxon.min.js"></script>
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||||
<link href="https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;700&family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
|
||||
</head>
|
||||
<body>
|
||||
<!-- Login Screen -->
|
||||
<div id="login-screen" class="login-container">
|
||||
<body class="gruvbox-dark">
|
||||
<!-- Auth Page -->
|
||||
<div id="auth-page" class="login-container">
|
||||
<div class="login-card">
|
||||
<div class="login-header">
|
||||
<img src="img/logo-full.png" alt="LLM Proxy Logo" class="login-logo" onerror="this.style.display='none'; this.nextElementSibling.style.display='block';">
|
||||
<i class="fas fa-robot login-logo-fallback" style="display: none;"></i>
|
||||
<h1>LLM Proxy Gateway</h1>
|
||||
<p class="login-subtitle">Admin Dashboard</p>
|
||||
<div class="logo-icon-container">
|
||||
<span class="logo-icon-text">GG</span>
|
||||
</div>
|
||||
<form id="login-form" class="login-form">
|
||||
<div class="form-group">
|
||||
<input type="text" id="username" name="username" placeholder=" " required>
|
||||
<label for="username">
|
||||
<i class="fas fa-user"></i> Username
|
||||
</label>
|
||||
<h1>GopherGate</h1>
|
||||
<p class="login-subtitle">Secure LLM Gateway & Management</p>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<input type="password" id="password" name="password" placeholder=" " required>
|
||||
<label for="password">
|
||||
<i class="fas fa-lock"></i> Password
|
||||
</label>
|
||||
<form id="login-form">
|
||||
<div class="form-control">
|
||||
<label for="username">Username</label>
|
||||
<input type="text" id="username" name="username" required autocomplete="username">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<button type="submit" class="login-btn">
|
||||
<i class="fas fa-sign-in-alt"></i> Sign In
|
||||
</button>
|
||||
</div>
|
||||
<div class="login-footer">
|
||||
<p>Default: <code>admin</code> / <code>admin</code> (change in Settings > Security)</p>
|
||||
<div class="form-control">
|
||||
<label for="password">Password</label>
|
||||
<input type="password" id="password" name="password" required autocomplete="current-password">
|
||||
</div>
|
||||
<button type="submit" id="login-btn" class="btn btn-primary btn-block">Sign In</button>
|
||||
</form>
|
||||
<div id="login-error" class="error-message" style="display: none;">
|
||||
<i class="fas fa-exclamation-circle"></i>
|
||||
<span>Invalid credentials. Please try again.</span>
|
||||
<span></span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -57,9 +45,10 @@
|
||||
<nav class="sidebar">
|
||||
<div class="sidebar-header">
|
||||
<div class="logo">
|
||||
<img src="img/logo-icon.png" alt="LLM Proxy" class="sidebar-logo" onerror="this.style.display='none'; this.nextElementSibling.style.display='inline-block';">
|
||||
<i class="fas fa-shield-alt logo-fallback" style="display: none;"></i>
|
||||
<span>LLM Proxy</span>
|
||||
<div class="logo-icon-container small">
|
||||
<span class="logo-icon-text">GG</span>
|
||||
</div>
|
||||
<span>GopherGate</span>
|
||||
</div>
|
||||
<button class="sidebar-toggle" id="sidebar-toggle">
|
||||
<i class="fas fa-bars"></i>
|
||||
@@ -69,68 +58,74 @@
|
||||
<div class="sidebar-menu">
|
||||
<div class="menu-section">
|
||||
<h3 class="menu-title">MAIN</h3>
|
||||
<a href="#overview" class="menu-item active" data-page="overview" data-tooltip="Dashboard Overview">
|
||||
<ul class="menu-list">
|
||||
<li class="menu-item active" data-page="overview">
|
||||
<i class="fas fa-th-large"></i>
|
||||
<span>Overview</span>
|
||||
</a>
|
||||
<a href="#analytics" class="menu-item" data-page="analytics" data-tooltip="Usage Analytics">
|
||||
<i class="fas fa-chart-line"></i>
|
||||
</li>
|
||||
<li class="menu-item" data-page="analytics">
|
||||
<i class="fas fa-chart-bar"></i>
|
||||
<span>Analytics</span>
|
||||
</a>
|
||||
<a href="#costs" class="menu-item" data-page="costs" data-tooltip="Cost Tracking">
|
||||
</li>
|
||||
<li class="menu-item" data-page="costs">
|
||||
<i class="fas fa-dollar-sign"></i>
|
||||
<span>Cost Management</span>
|
||||
</a>
|
||||
<span>Costs & Billing</span>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div class="menu-section">
|
||||
<h3 class="menu-title">MANAGEMENT</h3>
|
||||
<a href="#clients" class="menu-item" data-page="clients" data-tooltip="API Clients">
|
||||
<ul class="menu-list">
|
||||
<li class="menu-item" data-page="clients">
|
||||
<i class="fas fa-users"></i>
|
||||
<span>Client Management</span>
|
||||
</a>
|
||||
<a href="#providers" class="menu-item" data-page="providers" data-tooltip="Model Providers">
|
||||
<span>Clients</span>
|
||||
</li>
|
||||
<li class="menu-item" data-page="providers">
|
||||
<i class="fas fa-server"></i>
|
||||
<span>Providers</span>
|
||||
</a>
|
||||
<a href="#models" class="menu-item" data-page="models" data-tooltip="Manage Models">
|
||||
<i class="fas fa-cube"></i>
|
||||
</li>
|
||||
<li class="menu-item" data-page="models">
|
||||
<i class="fas fa-brain"></i>
|
||||
<span>Models</span>
|
||||
</a>
|
||||
<a href="#monitoring" class="menu-item" data-page="monitoring" data-tooltip="Live Monitoring">
|
||||
<i class="fas fa-heartbeat"></i>
|
||||
<span>Real-time Monitoring</span>
|
||||
</a>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div class="menu-section">
|
||||
<h3 class="menu-title">SYSTEM</h3>
|
||||
<a href="#users" class="menu-item admin-only" data-page="users" data-tooltip="User Accounts">
|
||||
<ul class="menu-list">
|
||||
<li class="menu-item" data-page="monitoring">
|
||||
<i class="fas fa-activity"></i>
|
||||
<span>Live Monitoring</span>
|
||||
</li>
|
||||
<li class="menu-item" data-page="logs">
|
||||
<i class="fas fa-list-alt"></i>
|
||||
<span>Logs</span>
|
||||
</li>
|
||||
<li class="menu-item" data-page="users">
|
||||
<i class="fas fa-user-shield"></i>
|
||||
<span>User Management</span>
|
||||
</a>
|
||||
<a href="#settings" class="menu-item admin-only" data-page="settings" data-tooltip="System Settings">
|
||||
<span>Admin Users</span>
|
||||
</li>
|
||||
<li class="menu-item" data-page="settings">
|
||||
<i class="fas fa-cog"></i>
|
||||
<span>Settings</span>
|
||||
</a>
|
||||
<a href="#logs" class="menu-item" data-page="logs" data-tooltip="System Logs">
|
||||
<i class="fas fa-list-alt"></i>
|
||||
<span>System Logs</span>
|
||||
</a>
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="sidebar-footer">
|
||||
<div class="user-info">
|
||||
<div class="user-avatar">
|
||||
<i class="fas fa-user-circle"></i>
|
||||
<i class="fas fa-user"></i>
|
||||
</div>
|
||||
<div class="user-details">
|
||||
<span class="user-name">Loading...</span>
|
||||
<span class="user-role">...</span>
|
||||
<div class="user-name" id="display-username">Admin</div>
|
||||
<div class="user-role" id="display-role">Administrator</div>
|
||||
</div>
|
||||
</div>
|
||||
<button class="logout-btn" id="logout-btn" title="Logout">
|
||||
<button id="logout-btn" class="btn-icon" title="Logout">
|
||||
<i class="fas fa-sign-out-alt"></i>
|
||||
</button>
|
||||
</div>
|
||||
@@ -138,43 +133,40 @@
|
||||
|
||||
<!-- Main Content -->
|
||||
<main class="main-content">
|
||||
<!-- Top Navigation -->
|
||||
<header class="top-nav">
|
||||
<div class="nav-left">
|
||||
<h1 class="page-title" id="page-title">Dashboard Overview</h1>
|
||||
<header class="top-bar">
|
||||
<div class="page-title">
|
||||
<h2 id="current-page-title">Overview</h2>
|
||||
</div>
|
||||
<div class="nav-right">
|
||||
<div class="nav-item" id="ws-status-nav" title="WebSocket Connection Status">
|
||||
<div class="ws-dot"></div>
|
||||
<span class="ws-text">Connecting...</span>
|
||||
<div class="top-bar-actions">
|
||||
<div id="connection-status" class="status-indicator">
|
||||
<span class="status-dot"></span>
|
||||
<span class="status-text">Disconnected</span>
|
||||
</div>
|
||||
<div class="nav-item" title="Refresh Current Page">
|
||||
<i class="fas fa-sync-alt" id="refresh-btn"></i>
|
||||
</div>
|
||||
<div class="nav-item">
|
||||
<span id="current-time">Loading...</span>
|
||||
<div class="theme-toggle" id="theme-toggle">
|
||||
<i class="fas fa-moon"></i>
|
||||
</div>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<!-- Page Content -->
|
||||
<div class="page-content" id="page-content">
|
||||
<!-- Dynamic content container -->
|
||||
<div id="page-content" class="content-body">
|
||||
<!-- Content will be loaded dynamically -->
|
||||
<div class="loader-container">
|
||||
<div class="loader"></div>
|
||||
</div>
|
||||
|
||||
<!-- Global Spinner -->
|
||||
<div class="spinner-container">
|
||||
<div class="spinner"></div>
|
||||
</div>
|
||||
</main>
|
||||
</div>
|
||||
|
||||
<!-- Scripts (cache-busted with version query params) -->
|
||||
<!-- Scripts -->
|
||||
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/luxon@3.3.0/build/global/luxon.min.js"></script>
|
||||
<script src="/js/api.js?v=7"></script>
|
||||
<script src="/js/auth.js?v=7"></script>
|
||||
<script src="/js/dashboard.js?v=7"></script>
|
||||
<script src="/js/websocket.js?v=7"></script>
|
||||
<script src="/js/charts.js?v=7"></script>
|
||||
<script src="/js/websocket.js?v=7"></script>
|
||||
<script src="/js/dashboard.js?v=7"></script>
|
||||
|
||||
<!-- Page Modules -->
|
||||
<script src="/js/pages/overview.js?v=7"></script>
|
||||
<script src="/js/pages/analytics.js?v=7"></script>
|
||||
<script src="/js/pages/costs.js?v=7"></script>
|
||||
|
||||
@@ -32,9 +32,28 @@ class ApiClient {
|
||||
}
|
||||
|
||||
if (!response.ok || !result.success) {
|
||||
// Handle authentication errors (session expired, server restarted, etc.)
|
||||
if (response.status === 401 ||
|
||||
result.error === 'Session expired or invalid' ||
|
||||
result.error === 'Not authenticated' ||
|
||||
result.error === 'Admin access required') {
|
||||
|
||||
if (window.authManager) {
|
||||
// Try to logout to clear local state and show login screen
|
||||
window.authManager.logout();
|
||||
}
|
||||
}
|
||||
throw new Error(result.error || `HTTP error! status: ${response.status}`);
|
||||
}
|
||||
|
||||
// Handling X-Refreshed-Token header
|
||||
if (response.headers.get('X-Refreshed-Token') && window.authManager) {
|
||||
window.authManager.token = response.headers.get('X-Refreshed-Token');
|
||||
if (window.authManager.setToken) {
|
||||
window.authManager.setToken(window.authManager.token);
|
||||
}
|
||||
}
|
||||
|
||||
return result.data;
|
||||
}
|
||||
|
||||
@@ -87,6 +106,17 @@ class ApiClient {
|
||||
const date = luxon.DateTime.fromISO(dateStr);
|
||||
return date.toRelative();
|
||||
}
|
||||
|
||||
// Helper for escaping HTML
|
||||
escapeHtml(unsafe) {
|
||||
if (unsafe === undefined || unsafe === null) return '';
|
||||
return unsafe.toString()
|
||||
.replace(/&/g, "&")
|
||||
.replace(/</g, "<")
|
||||
.replace(/>/g, ">")
|
||||
.replace(/"/g, """)
|
||||
.replace(/'/g, "'");
|
||||
}
|
||||
}
|
||||
|
||||
window.api = new ApiClient();
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Authentication Module for LLM Proxy Dashboard
|
||||
// Authentication Module for GopherGate Dashboard
|
||||
|
||||
class AuthManager {
|
||||
constructor() {
|
||||
@@ -50,9 +50,15 @@ class AuthManager {
|
||||
});
|
||||
}
|
||||
|
||||
setToken(newToken) {
|
||||
if (!newToken) return;
|
||||
this.token = newToken;
|
||||
localStorage.setItem('dashboard_token', this.token);
|
||||
}
|
||||
|
||||
async login(username, password) {
|
||||
const errorElement = document.getElementById('login-error');
|
||||
const loginBtn = document.querySelector('.login-btn');
|
||||
const loginBtn = document.getElementById('login-btn');
|
||||
|
||||
try {
|
||||
loginBtn.innerHTML = '<i class="fas fa-spinner fa-spin"></i> Authenticating...';
|
||||
@@ -118,7 +124,7 @@ class AuthManager {
|
||||
}
|
||||
|
||||
showLogin() {
|
||||
const loginScreen = document.getElementById('login-screen');
|
||||
const loginScreen = document.getElementById('auth-page');
|
||||
const dashboard = document.getElementById('dashboard');
|
||||
|
||||
if (loginScreen) loginScreen.style.display = 'flex';
|
||||
@@ -133,7 +139,7 @@ class AuthManager {
|
||||
if (errorElement) errorElement.style.display = 'none';
|
||||
|
||||
// Reset button
|
||||
const loginBtn = document.querySelector('.login-btn');
|
||||
const loginBtn = document.getElementById('login-btn');
|
||||
if (loginBtn) {
|
||||
loginBtn.innerHTML = '<i class="fas fa-sign-in-alt"></i> Sign In';
|
||||
loginBtn.disabled = false;
|
||||
@@ -141,7 +147,7 @@ class AuthManager {
|
||||
}
|
||||
|
||||
showDashboard() {
|
||||
const loginScreen = document.getElementById('login-screen');
|
||||
const loginScreen = document.getElementById('auth-page');
|
||||
const dashboard = document.getElementById('dashboard');
|
||||
|
||||
if (loginScreen) loginScreen.style.display = 'none';
|
||||
@@ -161,7 +167,7 @@ class AuthManager {
|
||||
const userRoleElement = document.querySelector('.user-role');
|
||||
|
||||
if (userNameElement && this.user) {
|
||||
userNameElement.textContent = this.user.name || this.user.username || 'User';
|
||||
userNameElement.textContent = this.user.display_name || this.user.username || 'User';
|
||||
}
|
||||
|
||||
if (userRoleElement && this.user) {
|
||||
|
||||
@@ -285,7 +285,30 @@ class Dashboard {
|
||||
<p class="card-subtitle">Manage model availability and custom pricing</p>
|
||||
</div>
|
||||
<div class="card-actions">
|
||||
<input type="text" id="model-search" placeholder="Search models..." class="form-control" style="margin-bottom: 0; padding: 4px 8px; width: 250px;">
|
||||
<select id="model-provider-filter" class="form-control" style="margin-bottom: 0; padding: 4px 8px; width: auto;">
|
||||
<option value="">All Providers</option>
|
||||
<option value="openai">OpenAI</option>
|
||||
<option value="anthropic">Anthropic / Gemini</option>
|
||||
<option value="google">Google</option>
|
||||
<option value="deepseek">DeepSeek</option>
|
||||
<option value="xai">xAI</option>
|
||||
<option value="meta">Meta</option>
|
||||
<option value="cohere">Cohere</option>
|
||||
<option value="mistral">Mistral</option>
|
||||
<option value="other">Other</option>
|
||||
</select>
|
||||
<select id="model-modality-filter" class="form-control" style="margin-bottom: 0; padding: 4px 8px; width: auto;">
|
||||
<option value="">All Modalities</option>
|
||||
<option value="text">Text</option>
|
||||
<option value="image">Vision/Image</option>
|
||||
<option value="audio">Audio</option>
|
||||
</select>
|
||||
<select id="model-capability-filter" class="form-control" style="margin-bottom: 0; padding: 4px 8px; width: auto;">
|
||||
<option value="">All Capabilities</option>
|
||||
<option value="tool_call">Tool Calling</option>
|
||||
<option value="reasoning">Reasoning</option>
|
||||
</select>
|
||||
<input type="text" id="model-search" placeholder="Search models..." class="form-control" style="margin-bottom: 0; padding: 4px 8px; width: 200px;">
|
||||
</div>
|
||||
</div>
|
||||
<div class="table-container">
|
||||
|
||||
@@ -42,12 +42,15 @@ class ClientsPage {
|
||||
const statusIcon = client.status === 'active' ? 'check-circle' : 'clock';
|
||||
const created = luxon.DateTime.fromISO(client.created_at).toFormat('MMM dd, yyyy');
|
||||
|
||||
const escapedId = window.api.escapeHtml(client.id);
|
||||
const escapedName = window.api.escapeHtml(client.name);
|
||||
|
||||
return `
|
||||
<tr>
|
||||
<td><span class="badge-client">${client.id}</span></td>
|
||||
<td><strong>${client.name}</strong></td>
|
||||
<td><span class="badge-client">${escapedId}</span></td>
|
||||
<td><strong>${escapedName}</strong></td>
|
||||
<td>
|
||||
<code class="token-display">sk-••••${client.id.substring(client.id.length - 4)}</code>
|
||||
<code class="token-display">sk-••••${escapedId.substring(escapedId.length - 4)}</code>
|
||||
</td>
|
||||
<td>${created}</td>
|
||||
<td>${client.last_used ? window.api.formatTimeAgo(client.last_used) : 'Never'}</td>
|
||||
@@ -55,16 +58,16 @@ class ClientsPage {
|
||||
<td>
|
||||
<span class="status-badge ${statusClass}">
|
||||
<i class="fas fa-${statusIcon}"></i>
|
||||
${client.status}
|
||||
${window.api.escapeHtml(client.status)}
|
||||
</span>
|
||||
</td>
|
||||
<td>
|
||||
${window._userRole === 'admin' ? `
|
||||
<div class="action-buttons">
|
||||
<button class="btn-action" title="Edit" onclick="window.clientsPage.editClient('${client.id}')">
|
||||
<button class="btn-action" title="Edit" onclick="window.clientsPage.editClient('${escapedId}')">
|
||||
<i class="fas fa-edit"></i>
|
||||
</button>
|
||||
<button class="btn-action danger" title="Delete" onclick="window.clientsPage.deleteClient('${client.id}')">
|
||||
<button class="btn-action danger" title="Delete" onclick="window.clientsPage.deleteClient('${escapedId}')">
|
||||
<i class="fas fa-trash"></i>
|
||||
</button>
|
||||
</div>
|
||||
@@ -188,10 +191,13 @@ class ClientsPage {
|
||||
showTokenRevealModal(clientName, token) {
|
||||
const modal = document.createElement('div');
|
||||
modal.className = 'modal active';
|
||||
const escapedName = window.api.escapeHtml(clientName);
|
||||
const escapedToken = window.api.escapeHtml(token);
|
||||
|
||||
modal.innerHTML = `
|
||||
<div class="modal-content">
|
||||
<div class="modal-header">
|
||||
<h3 class="modal-title">Client Created: ${clientName}</h3>
|
||||
<h3 class="modal-title">Client Created: ${escapedName}</h3>
|
||||
</div>
|
||||
<div class="modal-body">
|
||||
<p style="margin-bottom: 0.75rem; color: var(--yellow);">
|
||||
@@ -201,7 +207,7 @@ class ClientsPage {
|
||||
<div class="form-control">
|
||||
<label>API Token</label>
|
||||
<div style="display: flex; gap: 0.5rem;">
|
||||
<input type="text" id="revealed-token" value="${token}" readonly
|
||||
<input type="text" id="revealed-token" value="${escapedToken}" readonly
|
||||
style="font-family: monospace; font-size: 0.85rem;">
|
||||
<button class="btn btn-secondary" id="copy-token-btn" title="Copy">
|
||||
<i class="fas fa-copy"></i>
|
||||
@@ -248,10 +254,16 @@ class ClientsPage {
|
||||
showEditClientModal(client) {
|
||||
const modal = document.createElement('div');
|
||||
modal.className = 'modal active';
|
||||
|
||||
const escapedId = window.api.escapeHtml(client.id);
|
||||
const escapedName = window.api.escapeHtml(client.name);
|
||||
const escapedDescription = window.api.escapeHtml(client.description);
|
||||
const escapedRateLimit = window.api.escapeHtml(client.rate_limit_per_minute);
|
||||
|
||||
modal.innerHTML = `
|
||||
<div class="modal-content">
|
||||
<div class="modal-header">
|
||||
<h3 class="modal-title">Edit Client: ${client.id}</h3>
|
||||
<h3 class="modal-title">Edit Client: ${escapedId}</h3>
|
||||
<button class="modal-close" onclick="this.closest('.modal').remove()">
|
||||
<i class="fas fa-times"></i>
|
||||
</button>
|
||||
@@ -259,15 +271,15 @@ class ClientsPage {
|
||||
<div class="modal-body">
|
||||
<div class="form-control">
|
||||
<label for="edit-client-name">Display Name</label>
|
||||
<input type="text" id="edit-client-name" value="${client.name || ''}" placeholder="e.g. My Coding Assistant">
|
||||
<input type="text" id="edit-client-name" value="${escapedName}" placeholder="e.g. My Coding Assistant">
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label for="edit-client-description">Description</label>
|
||||
<textarea id="edit-client-description" rows="3" placeholder="Optional description">${client.description || ''}</textarea>
|
||||
<textarea id="edit-client-description" rows="3" placeholder="Optional description">${escapedDescription}</textarea>
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label for="edit-client-rate-limit">Rate Limit (requests/minute)</label>
|
||||
<input type="number" id="edit-client-rate-limit" min="0" value="${client.rate_limit_per_minute || ''}" placeholder="Leave empty for unlimited">
|
||||
<input type="number" id="edit-client-rate-limit" min="0" value="${escapedRateLimit}" placeholder="Leave empty for unlimited">
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label class="toggle-label">
|
||||
@@ -357,12 +369,16 @@ class ClientsPage {
|
||||
const lastUsed = t.last_used_at
|
||||
? luxon.DateTime.fromISO(t.last_used_at).toRelative()
|
||||
: 'Never';
|
||||
const escapedMaskedToken = window.api.escapeHtml(t.token_masked);
|
||||
const escapedClientId = window.api.escapeHtml(clientId);
|
||||
const tokenId = parseInt(t.id); // Assuming ID is numeric
|
||||
|
||||
return `
|
||||
<div style="display: flex; align-items: center; gap: 0.5rem; padding: 0.4rem 0; border-bottom: 1px solid var(--bg3);">
|
||||
<code style="flex: 1; font-size: 0.8rem; color: var(--fg2);">${t.token_masked}</code>
|
||||
<code style="flex: 1; font-size: 0.8rem; color: var(--fg2);">${escapedMaskedToken}</code>
|
||||
<span style="font-size: 0.75rem; color: var(--fg4);" title="Last used">${lastUsed}</span>
|
||||
<button class="btn-action danger" title="Revoke" style="padding: 0.2rem 0.4rem;"
|
||||
onclick="window.clientsPage.revokeToken('${clientId}', ${t.id}, this)">
|
||||
onclick="window.clientsPage.revokeToken('${escapedClientId}', ${tokenId}, this)">
|
||||
<i class="fas fa-trash" style="font-size: 0.75rem;"></i>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
@@ -38,6 +38,24 @@ class LogsPage {
|
||||
const statusClass = log.status === 'success' ? 'success' : 'danger';
|
||||
const timestamp = luxon.DateTime.fromISO(log.timestamp).toFormat('yyyy-MM-dd HH:mm:ss');
|
||||
|
||||
let tokenDetails = `${log.tokens} total tokens`;
|
||||
if (log.status === 'success') {
|
||||
const parts = [];
|
||||
parts.push(`${log.prompt_tokens} in`);
|
||||
|
||||
let completionStr = `${log.completion_tokens} out`;
|
||||
if (log.reasoning_tokens > 0) {
|
||||
completionStr += ` (${log.reasoning_tokens} reasoning)`;
|
||||
}
|
||||
parts.push(completionStr);
|
||||
|
||||
if (log.cache_read_tokens > 0) {
|
||||
parts.push(`${log.cache_read_tokens} cache-hit`);
|
||||
}
|
||||
|
||||
tokenDetails = parts.join(', ');
|
||||
}
|
||||
|
||||
return `
|
||||
<tr class="log-row">
|
||||
<td class="whitespace-nowrap">${timestamp}</td>
|
||||
@@ -55,7 +73,7 @@ class LogsPage {
|
||||
<td>
|
||||
<div class="log-message-container">
|
||||
<code class="log-model">${log.model}</code>
|
||||
<span class="log-tokens">${log.tokens} tokens</span>
|
||||
<span class="log-tokens" title="${log.tokens} total tokens">${tokenDetails}</span>
|
||||
<span class="log-duration">${log.duration}ms</span>
|
||||
${log.error ? `<div class="log-error-msg">${log.error}</div>` : ''}
|
||||
</div>
|
||||
|
||||
@@ -31,13 +31,58 @@ class ModelsPage {
|
||||
return;
|
||||
}
|
||||
|
||||
const searchInput = document.getElementById('model-search');
|
||||
const providerFilter = document.getElementById('model-provider-filter');
|
||||
const modalityFilter = document.getElementById('model-modality-filter');
|
||||
const capabilityFilter = document.getElementById('model-capability-filter');
|
||||
|
||||
const q = searchInput ? searchInput.value.toLowerCase() : '';
|
||||
const providerVal = providerFilter ? providerFilter.value : '';
|
||||
const modalityVal = modalityFilter ? modalityFilter.value : '';
|
||||
const capabilityVal = capabilityFilter ? capabilityFilter.value : '';
|
||||
|
||||
// Apply filters non-destructively
|
||||
let filteredModels = this.models.filter(m => {
|
||||
// Text search
|
||||
if (q && !(m.id.toLowerCase().includes(q) || m.name.toLowerCase().includes(q) || m.provider.toLowerCase().includes(q))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Provider filter
|
||||
if (providerVal) {
|
||||
if (providerVal === 'other') {
|
||||
const known = ['openai', 'anthropic', 'google', 'deepseek', 'xai', 'meta', 'cohere', 'mistral'];
|
||||
if (known.includes(m.provider.toLowerCase())) return false;
|
||||
} else if (m.provider.toLowerCase() !== providerVal) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Modality filter
|
||||
if (modalityVal) {
|
||||
const mods = m.modalities && m.modalities.input ? m.modalities.input.map(x => x.toLowerCase()) : [];
|
||||
if (!mods.includes(modalityVal)) return false;
|
||||
}
|
||||
|
||||
// Capability filter
|
||||
if (capabilityVal === 'tool_call' && !m.tool_call) return false;
|
||||
if (capabilityVal === 'reasoning' && !m.reasoning) return false;
|
||||
|
||||
return true;
|
||||
});
|
||||
|
||||
if (filteredModels.length === 0) {
|
||||
tableBody.innerHTML = '<tr><td colspan="7" class="text-center">No models match the selected filters</td></tr>';
|
||||
return;
|
||||
}
|
||||
|
||||
// Sort by provider then name
|
||||
this.models.sort((a, b) => {
|
||||
filteredModels.sort((a, b) => {
|
||||
if (a.provider !== b.provider) return a.provider.localeCompare(b.provider);
|
||||
return a.name.localeCompare(b.name);
|
||||
});
|
||||
|
||||
tableBody.innerHTML = this.models.map(model => {
|
||||
tableBody.innerHTML = filteredModels.map(model => {
|
||||
const statusClass = model.enabled ? 'success' : 'secondary';
|
||||
const statusIcon = model.enabled ? 'check-circle' : 'ban';
|
||||
|
||||
@@ -99,6 +144,14 @@ class ModelsPage {
|
||||
<input type="number" id="model-completion-cost" value="${model.completion_cost}" step="0.01">
|
||||
</div>
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label for="model-cache-read-cost">Cache Read Cost (per 1M tokens)</label>
|
||||
<input type="number" id="model-cache-read-cost" value="${model.cache_read_cost || 0}" step="0.01">
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label for="model-cache-write-cost">Cache Write Cost (per 1M tokens)</label>
|
||||
<input type="number" id="model-cache-write-cost" value="${model.cache_write_cost || 0}" step="0.01">
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label for="model-mapping">Internal Mapping (Optional)</label>
|
||||
<input type="text" id="model-mapping" value="${model.mapping || ''}" placeholder="e.g. gpt-4o-2024-05-13">
|
||||
@@ -118,6 +171,8 @@ class ModelsPage {
|
||||
const enabled = modal.querySelector('#model-enabled').checked;
|
||||
const promptCost = parseFloat(modal.querySelector('#model-prompt-cost').value);
|
||||
const completionCost = parseFloat(modal.querySelector('#model-completion-cost').value);
|
||||
const cacheReadCost = parseFloat(modal.querySelector('#model-cache-read-cost').value);
|
||||
const cacheWriteCost = parseFloat(modal.querySelector('#model-cache-write-cost').value);
|
||||
const mapping = modal.querySelector('#model-mapping').value;
|
||||
|
||||
try {
|
||||
@@ -125,6 +180,8 @@ class ModelsPage {
|
||||
enabled,
|
||||
prompt_cost: promptCost,
|
||||
completion_cost: completionCost,
|
||||
cache_read_cost: isNaN(cacheReadCost) ? null : cacheReadCost,
|
||||
cache_write_cost: isNaN(cacheWriteCost) ? null : cacheWriteCost,
|
||||
mapping: mapping || null
|
||||
});
|
||||
|
||||
@@ -138,27 +195,18 @@ class ModelsPage {
|
||||
}
|
||||
|
||||
setupEventListeners() {
|
||||
const searchInput = document.getElementById('model-search');
|
||||
if (searchInput) {
|
||||
searchInput.oninput = (e) => this.filterModels(e.target.value);
|
||||
}
|
||||
const attachFilter = (id) => {
|
||||
const el = document.getElementById(id);
|
||||
if (el) {
|
||||
el.addEventListener('input', () => this.renderModelsTable());
|
||||
el.addEventListener('change', () => this.renderModelsTable());
|
||||
}
|
||||
};
|
||||
|
||||
filterModels(query) {
|
||||
if (!query) {
|
||||
this.renderModelsTable();
|
||||
return;
|
||||
}
|
||||
|
||||
const q = query.toLowerCase();
|
||||
const originalModels = this.models;
|
||||
this.models = this.models.filter(m =>
|
||||
m.id.toLowerCase().includes(q) ||
|
||||
m.name.toLowerCase().includes(q) ||
|
||||
m.provider.toLowerCase().includes(q)
|
||||
);
|
||||
this.renderModelsTable();
|
||||
this.models = originalModels;
|
||||
attachFilter('model-search');
|
||||
attachFilter('model-provider-filter');
|
||||
attachFilter('model-modality-filter');
|
||||
attachFilter('model-capability-filter');
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -492,7 +492,7 @@ class MonitoringPage {
|
||||
simulateRequest() {
|
||||
const clients = ['client-1', 'client-2', 'client-3', 'client-4', 'client-5'];
|
||||
const providers = ['OpenAI', 'Gemini', 'DeepSeek', 'Grok'];
|
||||
const models = ['gpt-4', 'gpt-3.5-turbo', 'gemini-pro', 'deepseek-chat', 'grok-beta'];
|
||||
const models = ['gpt-4o', 'gpt-4o-mini', 'gemini-2.0-flash', 'deepseek-chat', 'grok-4-1-fast-non-reasoning'];
|
||||
const statuses = ['success', 'success', 'success', 'error', 'warning']; // Mostly success
|
||||
|
||||
const request = {
|
||||
|
||||
@@ -47,16 +47,21 @@ class ProvidersPage {
|
||||
const isLowBalance = provider.credit_balance <= provider.low_credit_threshold && provider.id !== 'ollama';
|
||||
const balanceColor = isLowBalance ? 'var(--red-light)' : 'var(--green-light)';
|
||||
|
||||
const escapedId = window.api.escapeHtml(provider.id);
|
||||
const escapedName = window.api.escapeHtml(provider.name);
|
||||
const escapedStatus = window.api.escapeHtml(provider.status);
|
||||
const billingMode = provider.billing_mode ? provider.billing_mode.toUpperCase() : 'PREPAID';
|
||||
|
||||
return `
|
||||
<div class="provider-card ${provider.status}">
|
||||
<div class="provider-card ${escapedStatus}">
|
||||
<div class="provider-card-header">
|
||||
<div class="provider-info">
|
||||
<h4 class="provider-name">${provider.name}</h4>
|
||||
<span class="provider-id">${provider.id}</span>
|
||||
<h4 class="provider-name">${escapedName}</h4>
|
||||
<span class="provider-id">${escapedId}</span>
|
||||
</div>
|
||||
<span class="status-badge ${statusClass}">
|
||||
<i class="fas fa-circle"></i>
|
||||
${provider.status}
|
||||
${escapedStatus}
|
||||
</span>
|
||||
</div>
|
||||
<div class="provider-card-body">
|
||||
@@ -67,12 +72,12 @@ class ProvidersPage {
|
||||
</div>
|
||||
<div class="meta-item" style="color: ${balanceColor}; font-weight: 700;">
|
||||
<i class="fas fa-wallet"></i>
|
||||
<span>Balance: ${provider.id === 'ollama' ? 'FREE' : window.api.formatCurrency(provider.credit_balance)}</span>
|
||||
<span>Balance: ${escapedId === 'ollama' ? 'FREE' : window.api.formatCurrency(provider.credit_balance)}</span>
|
||||
${isLowBalance ? '<i class="fas fa-exclamation-triangle" title="Low Balance"></i>' : ''}
|
||||
</div>
|
||||
<div class="meta-item">
|
||||
<i class="fas fa-exchange-alt"></i>
|
||||
<span>Billing: ${provider.billing_mode ? provider.billing_mode.toUpperCase() : 'PREPAID'}</span>
|
||||
<span>Billing: ${window.api.escapeHtml(billingMode)}</span>
|
||||
</div>
|
||||
<div class="meta-item">
|
||||
<i class="fas fa-clock"></i>
|
||||
@@ -80,16 +85,16 @@ class ProvidersPage {
|
||||
</div>
|
||||
</div>
|
||||
<div class="model-tags">
|
||||
${(provider.models || []).slice(0, 5).map(m => `<span class="model-tag">${m}</span>`).join('')}
|
||||
${(provider.models || []).slice(0, 5).map(m => `<span class="model-tag">${window.api.escapeHtml(m)}</span>`).join('')}
|
||||
${modelCount > 5 ? `<span class="model-tag more">+${modelCount - 5} more</span>` : ''}
|
||||
</div>
|
||||
</div>
|
||||
<div class="provider-card-footer">
|
||||
<button class="btn btn-secondary btn-sm" onclick="window.providersPage.testProvider('${provider.id}')">
|
||||
<button class="btn btn-secondary btn-sm" onclick="window.providersPage.testProvider('${escapedId}')">
|
||||
<i class="fas fa-vial"></i> Test
|
||||
</button>
|
||||
${window._userRole === 'admin' ? `
|
||||
<button class="btn btn-primary btn-sm" onclick="window.providersPage.configureProvider('${provider.id}')">
|
||||
<button class="btn btn-primary btn-sm" onclick="window.providersPage.configureProvider('${escapedId}')">
|
||||
<i class="fas fa-cog"></i> Config
|
||||
</button>
|
||||
` : ''}
|
||||
@@ -144,10 +149,17 @@ class ProvidersPage {
|
||||
|
||||
const modal = document.createElement('div');
|
||||
modal.className = 'modal active';
|
||||
|
||||
const escapedId = window.api.escapeHtml(provider.id);
|
||||
const escapedName = window.api.escapeHtml(provider.name);
|
||||
const escapedBaseUrl = window.api.escapeHtml(provider.base_url);
|
||||
const escapedBalance = window.api.escapeHtml(provider.credit_balance);
|
||||
const escapedThreshold = window.api.escapeHtml(provider.low_credit_threshold);
|
||||
|
||||
modal.innerHTML = `
|
||||
<div class="modal-content">
|
||||
<div class="modal-header">
|
||||
<h3 class="modal-title">Configure ${provider.name}</h3>
|
||||
<h3 class="modal-title">Configure ${escapedName}</h3>
|
||||
<button class="modal-close" onclick="this.closest('.modal').remove()">
|
||||
<i class="fas fa-times"></i>
|
||||
</button>
|
||||
@@ -161,7 +173,7 @@ class ProvidersPage {
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label for="provider-base-url">Base URL</label>
|
||||
<input type="text" id="provider-base-url" value="${provider.base_url || ''}" placeholder="Default API URL">
|
||||
<input type="text" id="provider-base-url" value="${escapedBaseUrl}" placeholder="Default API URL">
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label for="provider-api-key">API Key (Optional / Overwrite)</label>
|
||||
@@ -170,11 +182,11 @@ class ProvidersPage {
|
||||
<div class="grid-2">
|
||||
<div class="form-control">
|
||||
<label for="provider-balance">Current Credit Balance ($)</label>
|
||||
<input type="number" id="provider-balance" value="${provider.credit_balance}" step="0.01">
|
||||
<input type="number" id="provider-balance" value="${escapedBalance}" step="0.01">
|
||||
</div>
|
||||
<div class="form-control">
|
||||
<label for="provider-threshold">Low Balance Alert ($)</label>
|
||||
<input type="number" id="provider-threshold" value="${provider.low_credit_threshold}" step="0.50">
|
||||
<input type="number" id="provider-threshold" value="${escapedThreshold}" step="0.50">
|
||||
</div>
|
||||
</div>
|
||||
<div class="form-control">
|
||||
|
||||
@@ -280,8 +280,6 @@
|
||||
// ── Helpers ────────────────────────────────────────────────────
|
||||
|
||||
function escapeHtml(str) {
|
||||
const div = document.createElement('div');
|
||||
div.textContent = str;
|
||||
return div.innerHTML;
|
||||
return window.api.escapeHtml(str);
|
||||
}
|
||||
})();
|
||||
|
||||
@@ -248,21 +248,19 @@ class WebSocketManager {
|
||||
}
|
||||
|
||||
updateStatus(status) {
|
||||
const statusElement = document.getElementById('ws-status-nav');
|
||||
const statusElement = document.getElementById('connection-status');
|
||||
if (!statusElement) return;
|
||||
|
||||
const dot = statusElement.querySelector('.ws-dot');
|
||||
const text = statusElement.querySelector('.ws-text');
|
||||
const dot = statusElement.querySelector('.status-dot');
|
||||
const text = statusElement.querySelector('.status-text');
|
||||
|
||||
if (!dot || !text) return;
|
||||
|
||||
// Remove all status classes
|
||||
dot.classList.remove('connected', 'disconnected');
|
||||
statusElement.classList.remove('connected', 'disconnected');
|
||||
dot.classList.remove('connected', 'disconnected', 'error', 'connecting');
|
||||
|
||||
// Add new status class
|
||||
dot.classList.add(status);
|
||||
statusElement.classList.add(status);
|
||||
|
||||
// Update text
|
||||
const statusText = {
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Test script for LLM Proxy Dashboard
|
||||
|
||||
echo "Building LLM Proxy Gateway..."
|
||||
cargo build --release
|
||||
|
||||
echo ""
|
||||
echo "Starting server in background..."
|
||||
./target/release/llm-proxy &
|
||||
SERVER_PID=$!
|
||||
|
||||
# Wait for server to start
|
||||
sleep 3
|
||||
|
||||
echo ""
|
||||
echo "Testing dashboard endpoints..."
|
||||
|
||||
# Test health endpoint
|
||||
echo "1. Testing health endpoint:"
|
||||
curl -s http://localhost:8080/health
|
||||
|
||||
echo ""
|
||||
echo "2. Testing dashboard static files:"
|
||||
curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/
|
||||
|
||||
echo ""
|
||||
echo "3. Testing API endpoints:"
|
||||
curl -s http://localhost:8080/api/auth/status | jq . 2>/dev/null || echo "JSON response received"
|
||||
|
||||
echo ""
|
||||
echo "Dashboard should be available at: http://localhost:8080"
|
||||
echo "Default login: admin / admin"
|
||||
echo ""
|
||||
echo "Press Ctrl+C to stop the server"
|
||||
|
||||
# Keep script running
|
||||
wait $SERVER_PID
|
||||
@@ -1,75 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Test script for LLM Proxy Gateway
|
||||
|
||||
echo "Building LLM Proxy Gateway..."
|
||||
cargo build --release
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Build failed!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Build successful!"
|
||||
|
||||
echo ""
|
||||
echo "Project Structure Summary:"
|
||||
echo "=========================="
|
||||
echo "Core Components:"
|
||||
echo " - main.rs: Application entry point with server setup"
|
||||
echo " - config/: Configuration management"
|
||||
echo " - server/: API route handlers"
|
||||
echo " - auth/: Bearer token authentication"
|
||||
echo " - database/: SQLite database setup"
|
||||
echo " - models/: Data structures (OpenAI-compatible)"
|
||||
echo " - providers/: LLM provider implementations (OpenAI, Gemini, DeepSeek, Grok)"
|
||||
echo " - errors/: Custom error types"
|
||||
echo " - dashboard/: Admin dashboard with WebSocket support"
|
||||
echo " - logging/: Request logging middleware"
|
||||
echo " - state/: Shared application state"
|
||||
echo " - multimodal/: Image processing support (basic structure)"
|
||||
|
||||
echo ""
|
||||
echo "Key Features Implemented:"
|
||||
echo "=========================="
|
||||
echo "✓ OpenAI-compatible API endpoint (/v1/chat/completions)"
|
||||
echo "✓ Bearer token authentication"
|
||||
echo "✓ SQLite database for request tracking"
|
||||
echo "✓ Request logging with token/cost calculation"
|
||||
echo "✓ Provider abstraction layer"
|
||||
echo "✓ Admin dashboard with real-time monitoring"
|
||||
echo "✓ WebSocket support for live updates"
|
||||
echo "✓ Configuration management (config.toml, .env, env vars)"
|
||||
echo "✓ Multimodal support structure (images)"
|
||||
echo "✓ Error handling with proper HTTP status codes"
|
||||
|
||||
echo ""
|
||||
echo "Next Steps Needed:"
|
||||
echo "=================="
|
||||
echo "1. Add API keys to .env file:"
|
||||
echo " OPENAI_API_KEY=your_key_here"
|
||||
echo " GEMINI_API_KEY=your_key_here"
|
||||
echo " DEEPSEEK_API_KEY=your_key_here"
|
||||
echo " GROK_API_KEY=your_key_here (optional)"
|
||||
echo ""
|
||||
echo "2. Create config.toml for custom configuration (optional)"
|
||||
echo ""
|
||||
echo "3. Run the server:"
|
||||
echo " cargo run"
|
||||
echo ""
|
||||
echo "4. Access dashboard at: http://localhost:8080"
|
||||
echo ""
|
||||
echo "5. Test API with curl:"
|
||||
echo " curl -X POST http://localhost:8080/v1/chat/completions \\"
|
||||
echo " -H 'Authorization: Bearer your_token' \\"
|
||||
echo " -H 'Content-Type: application/json' \\"
|
||||
echo " -d '{\"model\": \"gpt-4\", \"messages\": [{\"role\": \"user\", \"content\": \"Hello\"}]}'"
|
||||
|
||||
echo ""
|
||||
echo "Deployment Notes:"
|
||||
echo "================="
|
||||
echo "Memory: Designed for 512MB RAM (LXC container)"
|
||||
echo "Database: SQLite (./data/llm_proxy.db)"
|
||||
echo "Port: 8080 (configurable)"
|
||||
echo "Authentication: Single Bearer token (configurable)"
|
||||
echo "Providers: OpenAI, Gemini, DeepSeek, Grok (disabled by default)"
|
||||
Reference in New Issue
Block a user