Compare commits
85 Commits
4be23629d8
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 9375448087 | |||
| 5be2f6f7aa | |||
| eebcadcba1 | |||
| 6b2bd13903 | |||
| 5dfda0a10c | |||
| a8a02d9e1c | |||
| bd1d17cc4d | |||
| 9207a7231c | |||
| c6efff9034 | |||
| 27fbd8ed15 | |||
| 348341f304 | |||
| 9380580504 | |||
| 08cf5cc1d9 | |||
| 0f0486d8d4 | |||
| 0ea2a3a985 | |||
| 21e5908c35 | |||
| 6f0a159245 | |||
| 4120a83b67 | |||
| 742cd9e921 | |||
| 593971ecb5 | |||
| 03dca998df | |||
| 0ce5f4f490 | |||
| dec4b927dc | |||
| 3f1e6d3407 | |||
| f02fd6c249 | |||
| f23796f0cc | |||
| 3f76a544e0 | |||
| e474549940 | |||
| b7e37b0399 | |||
| 263c0f0dc9 | |||
| 26d8431998 | |||
| 1f3adceda4 | |||
| 9c64a8fe42 | |||
| b04b794705 | |||
| 0f3c5b6eb4 | |||
| 66a1643bca | |||
| edc6445d70 | |||
| 2d8f1a1fd0 | |||
| cd1a1b45aa | |||
| 246a6d88f0 | |||
| 7d43b2c31b | |||
| 45c2d5e643 | |||
| 1d032c6732 | |||
| 2245cca67a | |||
| c7c244992a | |||
| 4f5b55d40f | |||
| 90874a6721 | |||
| 6b10d4249c | |||
| 57aa0aa70e | |||
| 4de457cc5e | |||
| 66e8b114b9 | |||
| 1cac45502a | |||
| 79dc8fe409 | |||
| 24a898c9a7 | |||
| 7c2a317c01 | |||
| cb619f9286 | |||
| 441270317c | |||
| 2e4318d84b | |||
| d0be16d8e3 | |||
| 83e0ad0240 | |||
| 275ce34d05 | |||
| cb5b921550 | |||
| 649371154f | |||
| 78fff61660 | |||
| b131094dfd | |||
| c3d81c1733 | |||
| e123f542f1 | |||
| 0d28241e39 | |||
| 754ee9cb84 | |||
| 5a9086b883 | |||
| cc5eba1957 | |||
| 3ab00fb188 | |||
| c2595f7a74 | |||
| 0526304398 | |||
| 75e2967727 | |||
| e1bc3b35eb | |||
| 0d32d953d2 | |||
| bd5ca2dd98 | |||
| 6a0aca1a6c | |||
| 4c629e17cb | |||
| fc3bc6968d | |||
| d6280abad9 | |||
| 96486b6318 | |||
| e8955fd36c | |||
| a243a3987d |
26
.env
26
.env
@@ -1,26 +0,0 @@
|
|||||||
# LLM Proxy Gateway Environment Variables
|
|
||||||
|
|
||||||
# OpenAI
|
|
||||||
OPENAI_API_KEY=sk-demo-openai-key
|
|
||||||
|
|
||||||
# Google Gemini
|
|
||||||
GEMINI_API_KEY=AIza-demo-gemini-key
|
|
||||||
|
|
||||||
# DeepSeek
|
|
||||||
DEEPSEEK_API_KEY=sk-demo-deepseek-key
|
|
||||||
|
|
||||||
# xAI Grok (not yet available)
|
|
||||||
GROK_API_KEY=gk-demo-grok-key
|
|
||||||
|
|
||||||
# Authentication tokens (comma-separated list)
|
|
||||||
LLM_PROXY__SERVER__AUTH_TOKENS=demo-token-123456,another-token
|
|
||||||
|
|
||||||
# Server port (optional)
|
|
||||||
LLM_PROXY__SERVER__PORT=8080
|
|
||||||
|
|
||||||
# Database path (optional)
|
|
||||||
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
|
|
||||||
|
|
||||||
SESSION_SECRET=ki9khXAk9usDkasMrD2UbK4LOgrDRJz0
|
|
||||||
|
|
||||||
LLM_PROXY__ENCRYPTION_KEY=eac0239bfc402c7eb888366dd76c314288a8693efd5b7457819aeaf1fe429ac2
|
|
||||||
64
.env.example
64
.env.example
@@ -1,31 +1,47 @@
|
|||||||
# LLM Proxy Gateway Environment Variables
|
# GopherGate Configuration Example
|
||||||
# Copy to .env and fill in your API keys
|
# Copy this file to .env and fill in your values
|
||||||
|
|
||||||
# OpenAI
|
# ==============================================================================
|
||||||
OPENAI_API_KEY=your_openai_api_key_here
|
# MANDATORY: Encryption & Security
|
||||||
|
# ==============================================================================
|
||||||
|
# A 32-byte hex or base64 encoded string used for session signing and
|
||||||
|
# database encryption.
|
||||||
|
# Generate one with: openssl rand -hex 32
|
||||||
|
LLM_PROXY__ENCRYPTION_KEY=your_secure_32_byte_key_here
|
||||||
|
|
||||||
# Google Gemini
|
# ==============================================================================
|
||||||
GEMINI_API_KEY=your_gemini_api_key_here
|
# LLM Provider API Keys
|
||||||
|
# ==============================================================================
|
||||||
|
OPENAI_API_KEY=sk-...
|
||||||
|
GEMINI_API_KEY=AIza...
|
||||||
|
DEEPSEEK_API_KEY=sk-...
|
||||||
|
MOONSHOT_API_KEY=sk-...
|
||||||
|
GROK_API_KEY=xai-...
|
||||||
|
|
||||||
# DeepSeek
|
# ==============================================================================
|
||||||
DEEPSEEK_API_KEY=your_deepseek_api_key_here
|
# Server Configuration
|
||||||
|
# ==============================================================================
|
||||||
|
LLM_PROXY__SERVER__PORT=8080
|
||||||
|
LLM_PROXY__SERVER__HOST=0.0.0.0
|
||||||
|
|
||||||
# xAI Grok (not yet available)
|
# Optional: Bearer tokens for client authentication (comma-separated)
|
||||||
GROK_API_KEY=your_grok_api_key_here
|
# If not set, the proxy will look up tokens in the database.
|
||||||
|
# LLM_PROXY__SERVER__AUTH_TOKENS=token1,token2
|
||||||
|
|
||||||
# Ollama (local server)
|
# ==============================================================================
|
||||||
# LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL=http://your-ollama-host:11434/v1
|
# Database Configuration
|
||||||
|
# ==============================================================================
|
||||||
|
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
|
||||||
|
LLM_PROXY__DATABASE__MAX_CONNECTIONS=10
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Provider Overrides (Optional)
|
||||||
|
# ==============================================================================
|
||||||
|
# LLM_PROXY__PROVIDERS__OPENAI__BASE_URL=https://api.openai.com/v1
|
||||||
|
# LLM_PROXY__PROVIDERS__GEMINI__ENABLED=true
|
||||||
|
# LLM_PROXY__PROVIDERS__MOONSHOT__BASE_URL=https://api.moonshot.ai/v1
|
||||||
|
# LLM_PROXY__PROVIDERS__MOONSHOT__ENABLED=true
|
||||||
|
# LLM_PROXY__PROVIDERS__MOONSHOT__DEFAULT_MODEL=kimi-k2.5
|
||||||
|
# LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL=http://localhost:11434/v1
|
||||||
# LLM_PROXY__PROVIDERS__OLLAMA__ENABLED=true
|
# LLM_PROXY__PROVIDERS__OLLAMA__ENABLED=true
|
||||||
# LLM_PROXY__PROVIDERS__OLLAMA__MODELS=llama3,mistral,llava
|
# LLM_PROXY__PROVIDERS__OLLAMA__MODELS=llama3,mistral,llava
|
||||||
|
|
||||||
# Authentication tokens (comma-separated list)
|
|
||||||
LLM_PROXY__SERVER__AUTH_TOKENS=your_bearer_token_here,another_token
|
|
||||||
|
|
||||||
# Server port (optional)
|
|
||||||
LLM_PROXY__SERVER__PORT=8080
|
|
||||||
|
|
||||||
# Database path (optional)
|
|
||||||
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
|
|
||||||
|
|
||||||
# Session secret for HMAC-signed tokens (hex or base64 encoded, 32 bytes)
|
|
||||||
SESSION_SECRET=your_session_secret_here_32_bytes
|
|
||||||
62
.github/workflows/ci.yml
vendored
62
.github/workflows/ci.yml
vendored
@@ -6,56 +6,44 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
branches: [main]
|
branches: [main]
|
||||||
|
|
||||||
env:
|
|
||||||
CARGO_TERM_COLOR: always
|
|
||||||
RUST_BACKTRACE: 1
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
check:
|
lint:
|
||||||
name: Check
|
name: Lint
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: dtolnay/rust-toolchain@stable
|
- name: Set up Go
|
||||||
- uses: Swatinem/rust-cache@v2
|
uses: actions/setup-go@v5
|
||||||
- run: cargo check --all-targets
|
|
||||||
|
|
||||||
clippy:
|
|
||||||
name: Clippy
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
- uses: dtolnay/rust-toolchain@stable
|
|
||||||
with:
|
with:
|
||||||
components: clippy
|
go-version: '1.22'
|
||||||
- uses: Swatinem/rust-cache@v2
|
cache: true
|
||||||
- run: cargo clippy --all-targets -- -D warnings
|
- name: golangci-lint
|
||||||
|
uses: golangci/golangci-lint-action@v4
|
||||||
fmt:
|
|
||||||
name: Formatting
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
- uses: dtolnay/rust-toolchain@stable
|
|
||||||
with:
|
with:
|
||||||
components: rustfmt
|
version: latest
|
||||||
- run: cargo fmt --all -- --check
|
|
||||||
|
|
||||||
test:
|
test:
|
||||||
name: Test
|
name: Test
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: dtolnay/rust-toolchain@stable
|
- name: Set up Go
|
||||||
- uses: Swatinem/rust-cache@v2
|
uses: actions/setup-go@v5
|
||||||
- run: cargo test --all-targets
|
with:
|
||||||
|
go-version: '1.22'
|
||||||
|
cache: true
|
||||||
|
- name: Run Tests
|
||||||
|
run: go test -v ./...
|
||||||
|
|
||||||
build-release:
|
build:
|
||||||
name: Release Build
|
name: Build
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: [check, clippy, test]
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: dtolnay/rust-toolchain@stable
|
- name: Set up Go
|
||||||
- uses: Swatinem/rust-cache@v2
|
uses: actions/setup-go@v5
|
||||||
- run: cargo build --release
|
with:
|
||||||
|
go-version: '1.22'
|
||||||
|
cache: true
|
||||||
|
- name: Build
|
||||||
|
run: go build -v -o gophergate ./cmd/gophergate
|
||||||
|
|||||||
16
.gitignore
vendored
16
.gitignore
vendored
@@ -1,5 +1,13 @@
|
|||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
!.env.example
|
||||||
/target
|
/target
|
||||||
/.env
|
/llm-proxy
|
||||||
/*.db
|
/llm-proxy-go
|
||||||
/*.db-shm
|
/gophergate
|
||||||
/*.db-wal
|
/data/
|
||||||
|
*.db
|
||||||
|
*.db-shm
|
||||||
|
*.db-wal
|
||||||
|
*.log
|
||||||
|
server.pid
|
||||||
|
|||||||
62
BACKEND_ARCHITECTURE.md
Normal file
62
BACKEND_ARCHITECTURE.md
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
# Backend Architecture (Go)
|
||||||
|
|
||||||
|
The GopherGate backend is implemented in Go, focusing on high performance, clear concurrency patterns, and maintainability.
|
||||||
|
|
||||||
|
## Core Technologies
|
||||||
|
|
||||||
|
- **Runtime:** Go 1.22+
|
||||||
|
- **Web Framework:** [Gin Gonic](https://github.com/gin-gonic/gin) - Fast and lightweight HTTP routing.
|
||||||
|
- **Database:** [sqlx](https://github.com/jmoiron/sqlx) - Lightweight wrapper for standard `database/sql`.
|
||||||
|
- **SQLite Driver:** [modernc.org/sqlite](https://modernc.org/sqlite) - CGO-free SQLite implementation for ease of cross-compilation.
|
||||||
|
- **Config:** [Viper](https://github.com/spf13/viper) - Robust configuration management supporting environment variables and files.
|
||||||
|
- **Metrics:** [gopsutil](https://github.com/shirou/gopsutil) - System-level resource monitoring.
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```text
|
||||||
|
├── cmd/
|
||||||
|
│ └── gophergate/ # Entry point (main.go)
|
||||||
|
├── internal/
|
||||||
|
│ ├── config/ # Configuration loading and validation
|
||||||
|
│ ├── db/ # Database schema, migrations, and models
|
||||||
|
│ ├── middleware/ # Auth and logging middleware
|
||||||
|
│ ├── models/ # Unified request/response structs
|
||||||
|
│ ├── providers/ # LLM provider implementations (OpenAI, Gemini, etc.)
|
||||||
|
│ ├── server/ # HTTP server, dashboard handlers, and WebSocket hub
|
||||||
|
│ └── utils/ # Common utilities (registry, pricing, etc.)
|
||||||
|
└── static/ # Frontend assets (served by the backend)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Key Components
|
||||||
|
|
||||||
|
### 1. Provider Interface (`internal/providers/provider.go`)
|
||||||
|
Standardized interface for all LLM backends. Implementations handle mapping between the unified format and provider-specific APIs (OpenAI, Gemini, DeepSeek, Grok).
|
||||||
|
|
||||||
|
### 2. Model Registry & Pricing (`internal/utils/registry.go`)
|
||||||
|
Integrates with `models.dev/api.json` to provide real-time model metadata and pricing.
|
||||||
|
- **Fuzzy Matching:** Supports matching versioned model IDs (e.g., `gpt-4o-2024-08-06`) to base registry entries.
|
||||||
|
- **Automatic Refreshes:** The registry is fetched at startup and refreshed every 24 hours via a background goroutine.
|
||||||
|
|
||||||
|
### 3. Asynchronous Logging (`internal/server/logging.go`)
|
||||||
|
Uses a buffered channel and background worker to log every request to SQLite without blocking the client response. It also broadcasts logs to the WebSocket hub for real-time dashboard updates.
|
||||||
|
|
||||||
|
### 4. Session Management (`internal/server/sessions.go`)
|
||||||
|
Implements HMAC-SHA256 signed tokens for dashboard authentication. Tokens secure the management interface while standard Bearer tokens are used for LLM API access.
|
||||||
|
|
||||||
|
### 5. WebSocket Hub (`internal/server/websocket.go`)
|
||||||
|
A centralized hub for managing WebSocket connections, allowing real-time broadcast of system events, system metrics, and request logs to the dashboard.
|
||||||
|
|
||||||
|
## Concurrency Model
|
||||||
|
|
||||||
|
Go's goroutines and channels are used extensively:
|
||||||
|
- **Streaming:** Each streaming request uses a goroutine to read and parse the provider's response, feeding chunks into a channel for SSE delivery.
|
||||||
|
- **Logging:** A single background worker processes the `logChan` to perform serial database writes.
|
||||||
|
- **WebSocket:** The `Hub` runs in a dedicated goroutine, handling registration and broadcasting.
|
||||||
|
- **Maintenance:** Background tasks handle registry refreshes and status monitoring.
|
||||||
|
|
||||||
|
## Security
|
||||||
|
|
||||||
|
- **Encryption Key:** A mandatory 32-byte key is used for both session signing and encryption of sensitive data.
|
||||||
|
- **Auth Middleware:** Scoped to `/v1` routes to verify client API keys against the database.
|
||||||
|
- **Bcrypt:** Passwords for dashboard users are hashed using Bcrypt with a work factor of 12.
|
||||||
|
- **Database Hardening:** Automatic migrations ensure the schema is always current with the code.
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
# LLM Proxy Code Review Plan
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
The **LLM Proxy** project is a Rust-based middleware designed to provide a unified interface for multiple Large Language Models (LLMs). Based on the repository structure, the project aims to implement a high-performance proxy server (`src/`) that handles request routing, usage tracking, and billing logic. A static dashboard (`static/`) provides a management interface for monitoring consumption and managing API keys. The architecture leverages Rust's async capabilities for efficient request handling and SQLite for persistent state management.
|
|
||||||
|
|
||||||
## Review Phases
|
|
||||||
|
|
||||||
### Phase 1: Backend Architecture & Rust Logic (@code-reviewer)
|
|
||||||
- **Focus on:**
|
|
||||||
- **Core Proxy Logic:** Efficiency of the request/response pipeline and streaming support.
|
|
||||||
- **State Management:** Thread-safety and shared state patterns using `Arc` and `Mutex`/`RwLock`.
|
|
||||||
- **Error Handling:** Use of idiomatic Rust error types and propagation.
|
|
||||||
- **Async Performance:** Proper use of `tokio` or similar runtimes to avoid blocking the executor.
|
|
||||||
- **Rust Idioms:** Adherence to Clippy suggestions and standard Rust naming conventions.
|
|
||||||
|
|
||||||
### Phase 2: Security & Authentication Audit (@security-auditor)
|
|
||||||
- **Focus on:**
|
|
||||||
- **API Key Management:** Secure storage, masking in logs, and rotation mechanisms.
|
|
||||||
- **JWT Handling:** Validation logic, signature verification, and expiration checks.
|
|
||||||
- **Input Validation:** Sanitization of prompts and configuration parameters to prevent injection.
|
|
||||||
- **Dependency Audit:** Scanning for known vulnerabilities in the `Cargo.lock` using `cargo-audit`.
|
|
||||||
|
|
||||||
### Phase 3: Database & Data Integrity Review (@database-optimizer)
|
|
||||||
- **Focus on:**
|
|
||||||
- **Schema Design:** Efficiency of the SQLite schema for usage tracking and billing.
|
|
||||||
- **Migration Strategy:** Robustness of the migration scripts to prevent data loss.
|
|
||||||
- **Usage Tracking:** Accuracy of token counting and concurrency handling during increments.
|
|
||||||
- **Query Optimization:** Identifying potential bottlenecks in reporting queries.
|
|
||||||
|
|
||||||
### Phase 4: Frontend & Dashboard Review (@frontend-developer)
|
|
||||||
- **Focus on:**
|
|
||||||
- **Vanilla JS Patterns:** Review of Web Components and modular JS in `static/js`.
|
|
||||||
- **Security:** Protection against XSS in the dashboard and secure handling of local storage.
|
|
||||||
- **UI/UX Consistency:** Ensuring the management interface is intuitive and responsive.
|
|
||||||
- **API Integration:** Robustness of the frontend's communication with the Rust backend.
|
|
||||||
|
|
||||||
### Phase 5: Infrastructure & Deployment Review (@devops-engineer)
|
|
||||||
- **Focus on:**
|
|
||||||
- **Dockerfile Optimization:** Multi-stage builds to minimize image size and attack surface.
|
|
||||||
- **Resource Limits:** Configuration of CPU/Memory limits for the proxy container.
|
|
||||||
- **Deployment Docs:** Clarity of the setup process and environment variable documentation.
|
|
||||||
|
|
||||||
## Timeline (Gantt)
|
|
||||||
|
|
||||||
```mermaid
|
|
||||||
gantt
|
|
||||||
title LLM Proxy Code Review Timeline (March 2026)
|
|
||||||
dateFormat YYYY-MM-DD
|
|
||||||
section Backend & Security
|
|
||||||
Architecture & Rust Logic (Phase 1) :active, p1, 2026-03-06, 1d
|
|
||||||
Security & Auth Audit (Phase 2) :p2, 2026-03-07, 1d
|
|
||||||
section Data & Frontend
|
|
||||||
Database & Integrity (Phase 3) :p3, 2026-03-07, 1d
|
|
||||||
Frontend & Dashboard (Phase 4) :p4, 2026-03-08, 1d
|
|
||||||
section DevOps
|
|
||||||
Infra & Deployment (Phase 5) :p5, 2026-03-08, 1d
|
|
||||||
Final Review & Sign-off :2026-03-08, 4h
|
|
||||||
```
|
|
||||||
|
|
||||||
## Success Criteria
|
|
||||||
- **Security:** Zero high-priority vulnerabilities identified; all API keys masked in logs.
|
|
||||||
- **Performance:** Proxy overhead is minimal (<10ms latency addition); queries are indexed.
|
|
||||||
- **Maintainability:** Code passes all linting (`cargo clippy`) and formatting (`cargo fmt`) checks.
|
|
||||||
- **Documentation:** README and deployment guides are up-to-date and accurate.
|
|
||||||
- **Reliability:** Usage tracking matches actual API consumption with 99.9% accuracy.
|
|
||||||
4139
Cargo.lock
generated
4139
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
75
Cargo.toml
75
Cargo.toml
@@ -1,75 +0,0 @@
|
|||||||
[package]
|
|
||||||
name = "llm-proxy"
|
|
||||||
version = "0.1.0"
|
|
||||||
edition = "2024"
|
|
||||||
rust-version = "1.87"
|
|
||||||
description = "Unified LLM proxy gateway supporting OpenAI, Gemini, DeepSeek, and Grok with token tracking and cost calculation"
|
|
||||||
authors = ["newkirk"]
|
|
||||||
license = "MIT OR Apache-2.0"
|
|
||||||
repository = ""
|
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
# ========== Web Framework & Async Runtime ==========
|
|
||||||
axum = { version = "0.8", features = ["macros", "ws"] }
|
|
||||||
tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "net", "time", "signal", "fs"] }
|
|
||||||
tower = "0.5"
|
|
||||||
tower-http = { version = "0.6", features = ["trace", "cors", "compression-gzip", "fs", "set-header", "limit"] }
|
|
||||||
governor = "0.7"
|
|
||||||
|
|
||||||
# ========== HTTP Clients ==========
|
|
||||||
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
|
|
||||||
tiktoken-rs = "0.9"
|
|
||||||
|
|
||||||
# ========== Database & ORM ==========
|
|
||||||
sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite", "macros", "migrate", "chrono"] }
|
|
||||||
|
|
||||||
# ========== Authentication & Middleware ==========
|
|
||||||
axum-extra = { version = "0.12", features = ["typed-header"] }
|
|
||||||
headers = "0.4"
|
|
||||||
|
|
||||||
# ========== Configuration Management ==========
|
|
||||||
config = "0.13"
|
|
||||||
dotenvy = "0.15"
|
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
|
||||||
serde_json = "1.0"
|
|
||||||
toml = "0.8"
|
|
||||||
|
|
||||||
# ========== Logging & Monitoring ==========
|
|
||||||
tracing = "0.1"
|
|
||||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
|
|
||||||
|
|
||||||
# ========== Multimodal & Image Processing ==========
|
|
||||||
base64 = "0.21"
|
|
||||||
image = { version = "0.25", default-features = false, features = ["jpeg", "png", "webp"] }
|
|
||||||
mime = "0.3"
|
|
||||||
|
|
||||||
# ========== Error Handling & Utilities ==========
|
|
||||||
anyhow = "1.0"
|
|
||||||
thiserror = "1.0"
|
|
||||||
bcrypt = "0.15"
|
|
||||||
aes-gcm = "0.10"
|
|
||||||
hmac = "0.12"
|
|
||||||
sha2 = "0.10"
|
|
||||||
chrono = { version = "0.4", features = ["serde"] }
|
|
||||||
uuid = { version = "1.0", features = ["v4", "serde"] }
|
|
||||||
futures = "0.3"
|
|
||||||
async-trait = "0.1"
|
|
||||||
async-stream = "0.3"
|
|
||||||
reqwest-eventsource = "0.6"
|
|
||||||
rand = "0.9"
|
|
||||||
hex = "0.4"
|
|
||||||
|
|
||||||
[dev-dependencies]
|
|
||||||
tokio-test = "0.4"
|
|
||||||
mockito = "1.0"
|
|
||||||
tempfile = "3.10"
|
|
||||||
assert_cmd = "2.0"
|
|
||||||
insta = "1.39"
|
|
||||||
anyhow = "1.0"
|
|
||||||
|
|
||||||
[profile.release]
|
|
||||||
opt-level = 3
|
|
||||||
lto = true
|
|
||||||
codegen-units = 1
|
|
||||||
strip = true
|
|
||||||
panic = "abort"
|
|
||||||
@@ -1,220 +0,0 @@
|
|||||||
# LLM Proxy Gateway - Admin Dashboard
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
This is a comprehensive admin dashboard for the LLM Proxy Gateway, providing real-time monitoring, analytics, and management capabilities for the proxy service.
|
|
||||||
|
|
||||||
## Features
|
|
||||||
|
|
||||||
### 1. Dashboard Overview
|
|
||||||
- Real-time request counters and statistics
|
|
||||||
- System health indicators
|
|
||||||
- Provider status monitoring
|
|
||||||
- Recent requests stream
|
|
||||||
|
|
||||||
### 2. Usage Analytics
|
|
||||||
- Time series charts for requests, tokens, and costs
|
|
||||||
- Filter by date range, client, provider, and model
|
|
||||||
- Top clients and models analysis
|
|
||||||
- Export functionality to CSV/JSON
|
|
||||||
|
|
||||||
### 3. Cost Management
|
|
||||||
- Cost breakdown by provider, client, and model
|
|
||||||
- Budget tracking with alerts
|
|
||||||
- Cost projections
|
|
||||||
- Pricing configuration management
|
|
||||||
|
|
||||||
### 4. Client Management
|
|
||||||
- List, create, revoke, and rotate API tokens
|
|
||||||
- Client-specific rate limits
|
|
||||||
- Usage statistics per client
|
|
||||||
- Token management interface
|
|
||||||
|
|
||||||
### 5. Provider Configuration
|
|
||||||
- Enable/disable LLM providers
|
|
||||||
- Configure API keys (masked display)
|
|
||||||
- Test provider connections
|
|
||||||
- Model availability management
|
|
||||||
|
|
||||||
### 6. User Management (RBAC)
|
|
||||||
- **Admin Role:** Full access to all dashboard features, user management, system configuration
|
|
||||||
- **Viewer Role:** Read-only access to usage analytics, costs, and monitoring
|
|
||||||
- Create/manage dashboard users with role assignment
|
|
||||||
- Secure password management
|
|
||||||
|
|
||||||
### 7. Real-time Monitoring
|
|
||||||
- Live request stream via WebSocket
|
|
||||||
- System metrics dashboard
|
|
||||||
- Response time and error rate tracking
|
|
||||||
- Live system logs
|
|
||||||
|
|
||||||
### 7. **System Settings**
|
|
||||||
- General configuration
|
|
||||||
- Database management
|
|
||||||
- Logging settings
|
|
||||||
- Security settings
|
|
||||||
|
|
||||||
## Technology Stack
|
|
||||||
|
|
||||||
### Frontend
|
|
||||||
- **HTML5/CSS3**: Modern, responsive design with CSS Grid/Flexbox
|
|
||||||
- **JavaScript (ES6+)**: Vanilla JavaScript with modular architecture
|
|
||||||
- **Chart.js**: Interactive data visualizations
|
|
||||||
- **Luxon**: Date/time manipulation
|
|
||||||
- **WebSocket API**: Real-time updates
|
|
||||||
|
|
||||||
### Backend (Rust/Axum)
|
|
||||||
- **Axum**: Web framework with WebSocket support
|
|
||||||
- **Tokio**: Async runtime
|
|
||||||
- **Serde**: JSON serialization/deserialization
|
|
||||||
- **Broadcast channels**: Real-time event distribution
|
|
||||||
|
|
||||||
## Installation & Setup
|
|
||||||
|
|
||||||
### 1. Build and Run the Server
|
|
||||||
```bash
|
|
||||||
# Build the project
|
|
||||||
cargo build --release
|
|
||||||
|
|
||||||
# Run the server
|
|
||||||
cargo run --release
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. Access the Dashboard
|
|
||||||
Once the server is running, access the dashboard at:
|
|
||||||
```
|
|
||||||
http://localhost:8080
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Default Login Credentials
|
|
||||||
- **Username**: `admin`
|
|
||||||
- **Password**: `admin123`
|
|
||||||
|
|
||||||
## API Endpoints
|
|
||||||
|
|
||||||
### Authentication
|
|
||||||
- `POST /api/auth/login` - Dashboard login
|
|
||||||
- `GET /api/auth/status` - Authentication status
|
|
||||||
|
|
||||||
### Analytics
|
|
||||||
- `GET /api/usage/summary` - Overall usage summary
|
|
||||||
- `GET /api/usage/time-series` - Time series data
|
|
||||||
- `GET /api/usage/clients` - Client breakdown
|
|
||||||
- `GET /api/usage/providers` - Provider breakdown
|
|
||||||
|
|
||||||
### Clients
|
|
||||||
- `GET /api/clients` - List all clients
|
|
||||||
- `POST /api/clients` - Create new client
|
|
||||||
- `PUT /api/clients/{id}` - Update client
|
|
||||||
- `DELETE /api/clients/{id}` - Revoke client
|
|
||||||
- `GET /api/clients/{id}/usage` - Client-specific usage
|
|
||||||
|
|
||||||
### Users (RBAC)
|
|
||||||
- `GET /api/users` - List all dashboard users
|
|
||||||
- `POST /api/users` - Create new user
|
|
||||||
- `PUT /api/users/{id}` - Update user (admin only)
|
|
||||||
- `DELETE /api/users/{id}` - Delete user (admin only)
|
|
||||||
|
|
||||||
### Providers
|
|
||||||
- `GET /api/providers` - List providers and status
|
|
||||||
- `PUT /api/providers/{name}` - Update provider config
|
|
||||||
- `POST /api/providers/{name}/test` - Test provider connection
|
|
||||||
|
|
||||||
### System
|
|
||||||
- `GET /api/system/health` - System health
|
|
||||||
- `GET /api/system/logs` - Recent logs
|
|
||||||
- `POST /api/system/backup` - Trigger backup
|
|
||||||
|
|
||||||
### WebSocket
|
|
||||||
- `GET /ws` - WebSocket endpoint for real-time updates
|
|
||||||
|
|
||||||
## Project Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
llm-proxy/
|
|
||||||
├── src/
|
|
||||||
│ ├── dashboard/ # Dashboard backend module
|
|
||||||
│ │ └── mod.rs # Dashboard routes and handlers
|
|
||||||
│ ├── server/ # Main proxy server
|
|
||||||
│ ├── providers/ # LLM provider implementations
|
|
||||||
│ └── ... # Other modules
|
|
||||||
├── static/ # Frontend dashboard files
|
|
||||||
│ ├── index.html # Main dashboard HTML
|
|
||||||
│ ├── css/
|
|
||||||
│ │ └── dashboard.css # Dashboard styles
|
|
||||||
│ ├── js/
|
|
||||||
│ │ ├── auth.js # Authentication module
|
|
||||||
│ │ ├── dashboard.js # Main dashboard controller
|
|
||||||
│ │ ├── websocket.js # WebSocket manager
|
|
||||||
│ │ ├── charts.js # Chart.js utilities
|
|
||||||
│ │ └── pages/ # Page-specific modules
|
|
||||||
│ │ ├── overview.js
|
|
||||||
│ │ ├── analytics.js
|
|
||||||
│ │ ├── costs.js
|
|
||||||
│ │ ├── clients.js
|
|
||||||
│ │ ├── providers.js
|
|
||||||
│ │ ├── monitoring.js
|
|
||||||
│ │ ├── settings.js
|
|
||||||
│ │ └── logs.js
|
|
||||||
│ ├── img/ # Images and icons
|
|
||||||
│ └── fonts/ # Font files
|
|
||||||
└── Cargo.toml # Rust dependencies
|
|
||||||
```
|
|
||||||
|
|
||||||
## Development
|
|
||||||
|
|
||||||
### Adding New Pages
|
|
||||||
1. Create a new JavaScript module in `static/js/pages/`
|
|
||||||
2. Implement the page class with `init()` method
|
|
||||||
3. Register the page in `dashboard.js`
|
|
||||||
4. Add menu item in `index.html`
|
|
||||||
|
|
||||||
### Adding New API Endpoints
|
|
||||||
1. Add route in `src/dashboard/mod.rs`
|
|
||||||
2. Implement handler function
|
|
||||||
3. Update frontend JavaScript to call the endpoint
|
|
||||||
|
|
||||||
### Styling Guidelines
|
|
||||||
- Use CSS custom properties (variables) from `:root`
|
|
||||||
- Follow mobile-first responsive design
|
|
||||||
- Use BEM-like naming convention for CSS classes
|
|
||||||
- Maintain consistent spacing with CSS variables
|
|
||||||
|
|
||||||
## Security Considerations
|
|
||||||
|
|
||||||
1. **Authentication**: Simple password-based auth for demo; replace with proper auth in production
|
|
||||||
2. **API Keys**: Tokens are masked in the UI (only last 4 characters shown)
|
|
||||||
3. **CORS**: Configure appropriate CORS headers for production
|
|
||||||
4. **Rate Limiting**: Implement rate limiting for API endpoints
|
|
||||||
5. **HTTPS**: Always use HTTPS in production
|
|
||||||
|
|
||||||
## Performance Optimizations
|
|
||||||
|
|
||||||
1. **Code Splitting**: JavaScript modules are loaded on-demand
|
|
||||||
2. **Caching**: Static assets are served with cache headers
|
|
||||||
3. **WebSocket**: Real-time updates reduce polling overhead
|
|
||||||
4. **Lazy Loading**: Charts and tables load data as needed
|
|
||||||
5. **Compression**: Enable gzip/brotli compression for static files
|
|
||||||
|
|
||||||
## Browser Support
|
|
||||||
|
|
||||||
- Chrome 60+
|
|
||||||
- Firefox 55+
|
|
||||||
- Safari 11+
|
|
||||||
- Edge 79+
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
MIT License - See LICENSE file for details.
|
|
||||||
|
|
||||||
## Contributing
|
|
||||||
|
|
||||||
1. Fork the repository
|
|
||||||
2. Create a feature branch
|
|
||||||
3. Make your changes
|
|
||||||
4. Add tests if applicable
|
|
||||||
5. Submit a pull request
|
|
||||||
|
|
||||||
## Support
|
|
||||||
|
|
||||||
For issues and feature requests, please use the GitHub issue tracker.
|
|
||||||
@@ -1,480 +0,0 @@
|
|||||||
# Database Review Report for LLM-Proxy Repository
|
|
||||||
|
|
||||||
**Review Date:** 2025-03-06
|
|
||||||
**Reviewer:** Database Optimization Expert
|
|
||||||
**Repository:** llm-proxy
|
|
||||||
**Focus Areas:** Schema Design, Query Optimization, Migration Strategy, Data Integrity, Usage Tracking Accuracy
|
|
||||||
|
|
||||||
## Executive Summary
|
|
||||||
|
|
||||||
The llm-proxy database implementation demonstrates a solid foundation with appropriate table structures and clear separation of concerns. However, several areas require improvement to ensure scalability, data consistency, and performance as usage grows. Key findings include:
|
|
||||||
|
|
||||||
1. **Schema Design**: Generally normalized but missing foreign key enforcement and some critical indexes.
|
|
||||||
2. **Query Optimization**: Well-optimized for most queries but missing composite indexes for common filtering patterns.
|
|
||||||
3. **Migration Strategy**: Ad-hoc migration approach that may cause issues with schema evolution.
|
|
||||||
4. **Data Integrity**: Potential race conditions in usage tracking and missing transaction boundaries.
|
|
||||||
5. **Usage Tracking**: Generally accurate but risk of inconsistent state between related tables.
|
|
||||||
|
|
||||||
This report provides detailed analysis and actionable recommendations for each area.
|
|
||||||
|
|
||||||
## 1. Schema Design Review
|
|
||||||
|
|
||||||
### Tables Overview
|
|
||||||
|
|
||||||
The database consists of 6 main tables:
|
|
||||||
|
|
||||||
1. **clients**: Client management with usage aggregates
|
|
||||||
2. **llm_requests**: Request logging with token counts and costs
|
|
||||||
3. **provider_configs**: Provider configuration and credit balances
|
|
||||||
4. **model_configs**: Model-specific configuration and cost overrides
|
|
||||||
5. **users**: Dashboard user authentication
|
|
||||||
6. **client_tokens**: API token storage for client authentication
|
|
||||||
|
|
||||||
### Normalization Assessment
|
|
||||||
|
|
||||||
**Strengths:**
|
|
||||||
- Tables follow 3rd Normal Form (3NF) with appropriate separation
|
|
||||||
- Foreign key relationships properly defined
|
|
||||||
- No obvious data duplication across tables
|
|
||||||
|
|
||||||
**Areas for Improvement:**
|
|
||||||
- **Denormalized aggregates**: `clients.total_requests`, `total_tokens`, `total_cost` are derived from `llm_requests`. This introduces risk of inconsistency.
|
|
||||||
- **Provider credit balance**: Stored in `provider_configs` but also updated based on `llm_requests`. No audit trail for balance changes.
|
|
||||||
|
|
||||||
### Data Type Analysis
|
|
||||||
|
|
||||||
**Appropriate Choices:**
|
|
||||||
- INTEGER for token counts (cast from u32 to i64)
|
|
||||||
- REAL for monetary values
|
|
||||||
- DATETIME for timestamps using SQLite's CURRENT_TIMESTAMP
|
|
||||||
- TEXT for identifiers with appropriate length
|
|
||||||
|
|
||||||
**Potential Issues:**
|
|
||||||
- `llm_requests.request_body` and `response_body` defined as TEXT but always set to NULL - consider removing or making optional columns.
|
|
||||||
- `provider_configs.billing_mode` added via migration but default value not consistently applied to existing rows.
|
|
||||||
|
|
||||||
### Constraints and Foreign Keys
|
|
||||||
|
|
||||||
**Current Constraints:**
|
|
||||||
- Primary keys defined for all tables
|
|
||||||
- UNIQUE constraints on `clients.client_id`, `users.username`, `client_tokens.token`
|
|
||||||
- Foreign key definitions present but **not enforced** (SQLite default)
|
|
||||||
|
|
||||||
**Missing Constraints:**
|
|
||||||
- NOT NULL constraints missing on several columns where nullability not intended
|
|
||||||
- CHECK constraints for positive values (`credit_balance >= 0`)
|
|
||||||
- Foreign key enforcement not enabled
|
|
||||||
|
|
||||||
## 2. Query Optimization Analysis
|
|
||||||
|
|
||||||
### Indexing Strategy
|
|
||||||
|
|
||||||
**Existing Indexes:**
|
|
||||||
- `idx_clients_client_id` - Essential for client lookups
|
|
||||||
- `idx_clients_created_at` - Useful for chronological listing
|
|
||||||
- `idx_llm_requests_timestamp` - Critical for time-based queries
|
|
||||||
- `idx_llm_requests_client_id` - Supports client-specific queries
|
|
||||||
- `idx_llm_requests_provider` - Good for provider breakdowns
|
|
||||||
- `idx_llm_requests_status` - Low cardinality but acceptable
|
|
||||||
- `idx_client_tokens_token` UNIQUE - Essential for authentication
|
|
||||||
- `idx_client_tokens_client_id` - Supports token management
|
|
||||||
|
|
||||||
**Missing Critical Indexes:**
|
|
||||||
1. `model_configs.provider_id` - Foreign key column used in JOINs
|
|
||||||
2. `llm_requests(client_id, timestamp)` - Composite index for client time-series queries
|
|
||||||
3. `llm_requests(provider, timestamp)` - For provider performance analysis
|
|
||||||
4. `llm_requests(status, timestamp)` - For error trend analysis
|
|
||||||
|
|
||||||
### N+1 Query Detection
|
|
||||||
|
|
||||||
**Well-Optimized Areas:**
|
|
||||||
- Model configuration caching prevents repeated database hits
|
|
||||||
- Provider configs loaded in batch for dashboard display
|
|
||||||
- Client listing uses single efficient query
|
|
||||||
|
|
||||||
**Potential N+1 Patterns:**
|
|
||||||
- In `server/mod.rs` list_models function, cache lookup per model but this is in-memory
|
|
||||||
- No significant database N+1 issues identified
|
|
||||||
|
|
||||||
### Inefficient Query Patterns
|
|
||||||
|
|
||||||
**Query 1: Time-series aggregation with strftime()**
|
|
||||||
```sql
|
|
||||||
SELECT strftime('%Y-%m-%d', timestamp) as date, ...
|
|
||||||
FROM llm_requests
|
|
||||||
WHERE 1=1 {}
|
|
||||||
GROUP BY date, client_id, provider, model
|
|
||||||
ORDER BY date DESC
|
|
||||||
LIMIT 200
|
|
||||||
```
|
|
||||||
**Issue:** Function on indexed column prevents index utilization for the WHERE clause when filtering by timestamp range.
|
|
||||||
|
|
||||||
**Recommendation:** Store computed date column or use range queries on timestamp directly.
|
|
||||||
|
|
||||||
**Query 2: Today's stats using strftime()**
|
|
||||||
```sql
|
|
||||||
WHERE strftime('%Y-%m-%d', timestamp) = ?
|
|
||||||
```
|
|
||||||
**Issue:** Non-sargable query prevents index usage.
|
|
||||||
|
|
||||||
**Recommendation:** Use range query:
|
|
||||||
```sql
|
|
||||||
WHERE timestamp >= date(?) AND timestamp < date(?, '+1 day')
|
|
||||||
```
|
|
||||||
|
|
||||||
### Recommended Index Additions
|
|
||||||
|
|
||||||
```sql
|
|
||||||
-- Composite indexes for common query patterns
|
|
||||||
CREATE INDEX idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp);
|
|
||||||
CREATE INDEX idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp);
|
|
||||||
CREATE INDEX idx_llm_requests_status_timestamp ON llm_requests(status, timestamp);
|
|
||||||
|
|
||||||
-- Foreign key index
|
|
||||||
CREATE INDEX idx_model_configs_provider_id ON model_configs(provider_id);
|
|
||||||
|
|
||||||
-- Optional: Covering index for client usage queries
|
|
||||||
CREATE INDEX idx_clients_usage ON clients(client_id, total_requests, total_tokens, total_cost);
|
|
||||||
```
|
|
||||||
|
|
||||||
## 3. Migration Strategy Assessment
|
|
||||||
|
|
||||||
### Current Approach
|
|
||||||
|
|
||||||
The migration system uses a hybrid approach:
|
|
||||||
|
|
||||||
1. **Schema synchronization**: `CREATE TABLE IF NOT EXISTS` on startup
|
|
||||||
2. **Ad-hoc migrations**: `ALTER TABLE` statements with error suppression
|
|
||||||
3. **Single migration file**: `migrations/001-add-billing-mode.sql` with transaction wrapper
|
|
||||||
|
|
||||||
**Pros:**
|
|
||||||
- Simple to understand and maintain
|
|
||||||
- Automatic schema creation for new deployments
|
|
||||||
- Error suppression prevents crashes on column existence
|
|
||||||
|
|
||||||
**Cons:**
|
|
||||||
- No version tracking of applied migrations
|
|
||||||
- Potential for inconsistent schema across deployments
|
|
||||||
- `ALTER TABLE` error suppression hides genuine schema issues
|
|
||||||
- No rollback capability
|
|
||||||
|
|
||||||
### Risks and Limitations
|
|
||||||
|
|
||||||
1. **Schema Drift**: Different instances may have different schemas if migrations are applied out of order
|
|
||||||
2. **Data Loss Risk**: No backup/verification before schema changes
|
|
||||||
3. **Production Issues**: Error suppression could mask migration failures until runtime
|
|
||||||
|
|
||||||
### Recommendations
|
|
||||||
|
|
||||||
1. **Implement Proper Migration Tooling**: Use `sqlx migrate` or similar versioned migration system
|
|
||||||
2. **Add Migration Version Table**: Track applied migrations and checksum verification
|
|
||||||
3. **Separate Migration Scripts**: One file per migration with up/down directions
|
|
||||||
4. **Pre-deployment Validation**: Schema checks in CI/CD pipeline
|
|
||||||
5. **Backup Strategy**: Automatic backups before migration execution
|
|
||||||
|
|
||||||
## 4. Data Integrity Evaluation
|
|
||||||
|
|
||||||
### Foreign Key Enforcement
|
|
||||||
|
|
||||||
**Critical Issue:** Foreign key constraints are defined but **not enforced** in SQLite.
|
|
||||||
|
|
||||||
**Impact:** Orphaned records, inconsistent referential integrity.
|
|
||||||
|
|
||||||
**Solution:** Enable foreign key support in connection string:
|
|
||||||
```rust
|
|
||||||
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?
|
|
||||||
.create_if_missing(true)
|
|
||||||
.pragma("foreign_keys", "ON");
|
|
||||||
```
|
|
||||||
|
|
||||||
### Transaction Usage
|
|
||||||
|
|
||||||
**Good Patterns:**
|
|
||||||
- Request logging uses transactions for insert + provider balance update
|
|
||||||
- Atomic UPDATE for client usage statistics
|
|
||||||
|
|
||||||
**Problematic Areas:**
|
|
||||||
|
|
||||||
1. **Split Transactions**: Client usage update and request logging are in separate transactions
|
|
||||||
- In `logging/mod.rs`: `insert_log` transaction includes provider balance update
|
|
||||||
- In `utils/streaming.rs`: Client usage updated separately after logging
|
|
||||||
- **Risk**: Partial updates if one transaction fails
|
|
||||||
|
|
||||||
2. **No Transaction for Client Creation**: Client and token creation not atomic
|
|
||||||
|
|
||||||
**Recommendations:**
|
|
||||||
- Wrap client usage update within the same transaction as request logging
|
|
||||||
- Use transaction for client + token creation
|
|
||||||
- Consider using savepoints for complex operations
|
|
||||||
|
|
||||||
### Race Conditions and Consistency
|
|
||||||
|
|
||||||
**Potential Race Conditions:**
|
|
||||||
1. **Provider credit balance**: Concurrent requests may cause lost updates
|
|
||||||
- Current: `UPDATE provider_configs SET credit_balance = credit_balance - ?`
|
|
||||||
- SQLite provides serializable isolation, but negative balances not prevented
|
|
||||||
|
|
||||||
2. **Client usage aggregates**: Concurrent updates to `total_requests`, `total_tokens`, `total_cost`
|
|
||||||
- Similar UPDATE pattern, generally safe but consider idempotency
|
|
||||||
|
|
||||||
**Recommendations:**
|
|
||||||
- Add check constraint: `CHECK (credit_balance >= 0)`
|
|
||||||
- Implement idempotent request logging with unique request IDs
|
|
||||||
- Consider optimistic concurrency control for critical balances
|
|
||||||
|
|
||||||
## 5. Usage Tracking Accuracy
|
|
||||||
|
|
||||||
### Token Counting Methodology
|
|
||||||
|
|
||||||
**Current Approach:**
|
|
||||||
- Prompt tokens: Estimated using provider-specific estimators
|
|
||||||
- Completion tokens: Estimated or from provider real usage data
|
|
||||||
- Cache tokens: Separately tracked for cache-aware pricing
|
|
||||||
|
|
||||||
**Strengths:**
|
|
||||||
- Fallback to estimation when provider doesn't report usage
|
|
||||||
- Cache token differentiation for accurate pricing
|
|
||||||
|
|
||||||
**Weaknesses:**
|
|
||||||
- Estimation may differ from actual provider counts
|
|
||||||
- No validation of provider-reported token counts
|
|
||||||
|
|
||||||
### Cost Calculation
|
|
||||||
|
|
||||||
**Well Implemented:**
|
|
||||||
- Model-specific cost overrides via `model_configs`
|
|
||||||
- Cache-aware pricing when supported by registry
|
|
||||||
- Provider fallback calculations
|
|
||||||
|
|
||||||
**Potential Issues:**
|
|
||||||
- Floating-point precision for monetary calculations
|
|
||||||
- No rounding strategy for fractional cents
|
|
||||||
|
|
||||||
### Update Consistency
|
|
||||||
|
|
||||||
**Inconsistency Risk:** Client aggregates updated separately from request logging.
|
|
||||||
|
|
||||||
**Example Flow:**
|
|
||||||
1. Request log inserted and provider balance updated (transaction)
|
|
||||||
2. Client usage updated (separate operation)
|
|
||||||
3. If step 2 fails, client stats undercount usage
|
|
||||||
|
|
||||||
**Solution:** Include client update in the same transaction:
|
|
||||||
```rust
|
|
||||||
// In insert_log function, add:
|
|
||||||
UPDATE clients
|
|
||||||
SET total_requests = total_requests + 1,
|
|
||||||
total_tokens = total_tokens + ?,
|
|
||||||
total_cost = total_cost + ?
|
|
||||||
WHERE client_id = ?;
|
|
||||||
```
|
|
||||||
|
|
||||||
### Financial Accuracy
|
|
||||||
|
|
||||||
**Good Practices:**
|
|
||||||
- Token-level granularity for cost calculation
|
|
||||||
- Separation of prompt/completion/cache pricing
|
|
||||||
- Database persistence for audit trail
|
|
||||||
|
|
||||||
**Recommendations:**
|
|
||||||
1. **Audit Trail**: Add `balance_transactions` table for provider credit changes
|
|
||||||
2. **Rounding Policy**: Define rounding strategy (e.g., to 6 decimal places)
|
|
||||||
3. **Validation**: Periodic reconciliation of aggregates vs. detail records
|
|
||||||
|
|
||||||
## 6. Performance Recommendations
|
|
||||||
|
|
||||||
### Schema Improvements
|
|
||||||
|
|
||||||
1. **Partitioning Strategy**: For high-volume `llm_requests`, consider:
|
|
||||||
- Monthly partitioning by timestamp
|
|
||||||
- Archive old data to separate tables
|
|
||||||
|
|
||||||
2. **Data Retention Policy**: Implement automatic cleanup of old request logs
|
|
||||||
```sql
|
|
||||||
DELETE FROM llm_requests WHERE timestamp < date('now', '-90 days');
|
|
||||||
```
|
|
||||||
|
|
||||||
3. **Column Optimization**: Remove unused `request_body`, `response_body` columns or implement compression
|
|
||||||
|
|
||||||
### Query Optimizations
|
|
||||||
|
|
||||||
1. **Avoid Functions on Indexed Columns**: Rewrite date queries as range queries
|
|
||||||
2. **Batch Updates**: Consider batch updates for client usage instead of per-request
|
|
||||||
3. **Read Replicas**: For dashboard queries, consider separate read connection
|
|
||||||
|
|
||||||
### Connection Pooling
|
|
||||||
|
|
||||||
**Current:** SQLx connection pool with default settings
|
|
||||||
|
|
||||||
**Recommendations:**
|
|
||||||
- Configure pool size based on expected concurrency
|
|
||||||
- Implement connection health checks
|
|
||||||
- Monitor pool utilization metrics
|
|
||||||
|
|
||||||
### Monitoring Setup
|
|
||||||
|
|
||||||
**Essential Metrics:**
|
|
||||||
- Query execution times (slow query logging)
|
|
||||||
- Index usage statistics
|
|
||||||
- Table growth trends
|
|
||||||
- Connection pool utilization
|
|
||||||
|
|
||||||
**Implementation:**
|
|
||||||
- Add `sqlx::metrics` integration
|
|
||||||
- Regular `ANALYZE` execution for query planner
|
|
||||||
- Dashboard for database health monitoring
|
|
||||||
|
|
||||||
## 7. Security Considerations
|
|
||||||
|
|
||||||
### Data Protection
|
|
||||||
|
|
||||||
**Sensitive Data:**
|
|
||||||
- `provider_configs.api_key` - Should be encrypted at rest
|
|
||||||
- `users.password_hash` - Already hashed with bcrypt
|
|
||||||
- `client_tokens.token` - Plain text storage
|
|
||||||
|
|
||||||
**Recommendations:**
|
|
||||||
- Encrypt API keys using libsodium or similar
|
|
||||||
- Implement token hashing (similar to password hashing)
|
|
||||||
- Regular security audits of authentication flows
|
|
||||||
|
|
||||||
### SQL Injection Prevention
|
|
||||||
|
|
||||||
**Good Practices:**
|
|
||||||
- Use sqlx query builder with parameter binding
|
|
||||||
- No raw SQL concatenation observed in code review
|
|
||||||
|
|
||||||
**Verification Needed:** Ensure all dynamic SQL uses parameterized queries
|
|
||||||
|
|
||||||
### Access Controls
|
|
||||||
|
|
||||||
**Database Level:**
|
|
||||||
- SQLite lacks built-in user management
|
|
||||||
- Consider file system permissions for database file
|
|
||||||
- Application-level authentication is primary control
|
|
||||||
|
|
||||||
## 8. Summary of Critical Issues
|
|
||||||
|
|
||||||
**Priority 1 (Critical):**
|
|
||||||
1. Foreign key constraints not enabled
|
|
||||||
2. Split transactions risking data inconsistency
|
|
||||||
3. Missing composite indexes for common queries
|
|
||||||
|
|
||||||
**Priority 2 (High):**
|
|
||||||
1. No proper migration versioning system
|
|
||||||
2. Potential race conditions in balance updates
|
|
||||||
3. Non-sargable date queries impacting performance
|
|
||||||
|
|
||||||
**Priority 3 (Medium):**
|
|
||||||
1. Denormalized aggregates without consistency guarantees
|
|
||||||
2. No data retention policy for request logs
|
|
||||||
3. Missing check constraints for data validation
|
|
||||||
|
|
||||||
## 9. Recommended Action Plan
|
|
||||||
|
|
||||||
### Phase 1: Immediate Fixes (1-2 weeks)
|
|
||||||
1. Enable foreign key constraints in database connection
|
|
||||||
2. Add composite indexes for common query patterns
|
|
||||||
3. Fix transaction boundaries for client usage updates
|
|
||||||
4. Rewrite non-sargable date queries
|
|
||||||
|
|
||||||
### Phase 2: Short-term Improvements (3-4 weeks)
|
|
||||||
1. Implement proper migration system with version tracking
|
|
||||||
2. Add check constraints for data validation
|
|
||||||
3. Implement connection pooling configuration
|
|
||||||
4. Create database monitoring dashboard
|
|
||||||
|
|
||||||
### Phase 3: Long-term Enhancements (2-3 months)
|
|
||||||
1. Implement data retention and archiving strategy
|
|
||||||
2. Add audit trail for provider balance changes
|
|
||||||
3. Consider partitioning for high-volume tables
|
|
||||||
4. Implement encryption for sensitive data
|
|
||||||
|
|
||||||
### Phase 4: Ongoing Maintenance
|
|
||||||
1. Regular index maintenance and query plan analysis
|
|
||||||
2. Periodic reconciliation of aggregate vs. detail data
|
|
||||||
3. Security audits and dependency updates
|
|
||||||
4. Performance benchmarking and optimization
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Appendices
|
|
||||||
|
|
||||||
### A. Sample Migration Implementation
|
|
||||||
|
|
||||||
```sql
|
|
||||||
-- migrations/002-enable-foreign-keys.sql
|
|
||||||
PRAGMA foreign_keys = ON; -- NOTE: this pragma is connection-scoped, not persistent; it must be set on every new connection (see Section 4), so placing it in a migration file alone is not sufficient
|
|
||||||
|
|
||||||
-- migrations/003-add-composite-indexes.sql
|
|
||||||
CREATE INDEX idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp);
|
|
||||||
CREATE INDEX idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp);
|
|
||||||
CREATE INDEX idx_model_configs_provider_id ON model_configs(provider_id);
|
|
||||||
```
|
|
||||||
|
|
||||||
### B. Transaction Fix Example
|
|
||||||
|
|
||||||
```rust
|
|
||||||
async fn insert_log(pool: &SqlitePool, log: RequestLog) -> Result<(), sqlx::Error> {
|
|
||||||
let mut tx = pool.begin().await?;
|
|
||||||
|
|
||||||
// Insert or ignore client
|
|
||||||
sqlx::query("INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')")
|
|
||||||
.bind(&log.client_id)
|
|
||||||
.bind(&log.client_id)
|
|
||||||
.execute(&mut *tx)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Insert request log
|
|
||||||
sqlx::query("INSERT INTO llm_requests ...")
|
|
||||||
.execute(&mut *tx)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Update provider balance
|
|
||||||
if log.cost > 0.0 {
|
|
||||||
sqlx::query("UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')")
|
|
||||||
.bind(log.cost)
|
|
||||||
.bind(&log.provider)
|
|
||||||
.execute(&mut *tx)
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update client aggregates within same transaction
|
|
||||||
sqlx::query("UPDATE clients SET total_requests = total_requests + 1, total_tokens = total_tokens + ?, total_cost = total_cost + ? WHERE client_id = ?")
|
|
||||||
.bind(log.total_tokens as i64)
|
|
||||||
.bind(log.cost)
|
|
||||||
.bind(&log.client_id)
|
|
||||||
.execute(&mut *tx)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
tx.commit().await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### C. Monitoring Query Examples
|
|
||||||
|
|
||||||
```sql
|
|
||||||
-- Identify unused indexes
|
|
||||||
SELECT * FROM sqlite_master
|
|
||||||
WHERE type = 'index'
|
|
||||||
AND name NOT IN (
|
|
||||||
SELECT DISTINCT idx
|
|
||||||
FROM sqlite_stat1
|
|
||||||
WHERE tbl = 'llm_requests'
|
|
||||||
);
|
|
||||||
|
|
||||||
-- Table size analysis
|
|
||||||
SELECT name, SUM(pgsize) / 1024.0 / 1024.0 AS size_mb
|
|
||||||
FROM dbstat
|
|
||||||
WHERE name = 'llm_requests';
|
|
||||||
|
|
||||||
-- Query performance analysis (requires EXPLAIN QUERY PLAN)
|
|
||||||
EXPLAIN QUERY PLAN
|
|
||||||
SELECT * FROM llm_requests
|
|
||||||
WHERE client_id = ? AND timestamp >= ?;
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
*This report provides a comprehensive analysis of the current database implementation and actionable recommendations for improvement. Regular review and iteration will ensure the database continues to meet performance, consistency, and scalability requirements as the application grows.*
|
|
||||||
43
Dockerfile
43
Dockerfile
@@ -1,35 +1,34 @@
|
|||||||
# ── Build stage ──────────────────────────────────────────────
|
# Build stage
|
||||||
FROM rust:1-bookworm AS builder
|
FROM golang:1.22-alpine AS builder
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Cache dependency build
|
# Copy go mod and sum files
|
||||||
COPY Cargo.toml Cargo.lock ./
|
COPY go.mod go.sum ./
|
||||||
RUN mkdir src && echo 'fn main() {}' > src/main.rs && \
|
RUN go mod download
|
||||||
cargo build --release && \
|
|
||||||
rm -rf src
|
|
||||||
|
|
||||||
# Build the actual binary
|
# Copy the source code
|
||||||
COPY src/ src/
|
COPY . .
|
||||||
RUN touch src/main.rs && cargo build --release
|
|
||||||
|
|
||||||
# ── Runtime stage ────────────────────────────────────────────
|
# Build the application
|
||||||
FROM debian:bookworm-slim
|
RUN CGO_ENABLED=0 GOOS=linux go build -o gophergate ./cmd/gophergate
|
||||||
|
|
||||||
RUN apt-get update && \
|
# Final stage
|
||||||
apt-get install -y --no-install-recommends ca-certificates && \
|
FROM alpine:latest
|
||||||
rm -rf /var/lib/apt/lists/*
|
|
||||||
|
RUN apk --no-cache add ca-certificates tzdata
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
COPY --from=builder /app/target/release/llm-proxy /app/llm-proxy
|
# Copy the binary from the builder stage
|
||||||
COPY static/ /app/static/
|
COPY --from=builder /app/gophergate .
|
||||||
|
COPY --from=builder /app/static ./static
|
||||||
|
|
||||||
# Default config location
|
# Create data directory
|
||||||
VOLUME ["/app/config", "/app/data"]
|
RUN mkdir -p /app/data
|
||||||
|
|
||||||
|
# Expose port
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
|
|
||||||
ENV RUST_LOG=info
|
# Run the application
|
||||||
|
CMD ["./gophergate"]
|
||||||
ENTRYPOINT ["/app/llm-proxy"]
|
|
||||||
|
|||||||
232
OPTIMIZATION.md
232
OPTIMIZATION.md
@@ -1,232 +0,0 @@
|
|||||||
# Optimization for 512MB RAM Environment
|
|
||||||
|
|
||||||
This document provides guidance for optimizing the LLM Proxy Gateway for deployment in resource-constrained environments (512MB RAM).
|
|
||||||
|
|
||||||
## Memory Optimization Strategies
|
|
||||||
|
|
||||||
### 1. Build Optimization
|
|
||||||
|
|
||||||
The project is already configured with optimized build settings in `Cargo.toml`:
|
|
||||||
|
|
||||||
```toml
|
|
||||||
[profile.release]
|
|
||||||
opt-level = 3 # Maximum optimization
|
|
||||||
lto = true # Link-time optimization
|
|
||||||
codegen-units = 1 # Single codegen unit for better optimization
|
|
||||||
strip = true # Strip debug symbols
|
|
||||||
```
|
|
||||||
|
|
||||||
**Additional optimizations you can apply:**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Build with specific target for better optimization
|
|
||||||
cargo build --release --target x86_64-unknown-linux-musl
|
|
||||||
|
|
||||||
# Or for ARM (Raspberry Pi, etc.)
|
|
||||||
cargo build --release --target aarch64-unknown-linux-musl
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. Runtime Memory Management
|
|
||||||
|
|
||||||
#### Database Connection Pool
|
|
||||||
- Default: 10 connections
|
|
||||||
- Recommended for 512MB: 5 connections
|
|
||||||
|
|
||||||
Update `config.toml`:
|
|
||||||
```toml
|
|
||||||
[database]
|
|
||||||
max_connections = 5
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Rate Limiting Memory Usage
|
|
||||||
- Client rate limit buckets: Store in memory
|
|
||||||
- Circuit breakers: Minimal memory usage
|
|
||||||
- Consider reducing burst capacity if memory is critical
|
|
||||||
|
|
||||||
#### Provider Management
|
|
||||||
- Only enable providers you actually use
|
|
||||||
- Disable unused providers in configuration
|
|
||||||
|
|
||||||
### 3. Configuration for Low Memory
|
|
||||||
|
|
||||||
Create a `config-low-memory.toml`:
|
|
||||||
|
|
||||||
```toml
|
|
||||||
[server]
|
|
||||||
port = 8080
|
|
||||||
host = "0.0.0.0"
|
|
||||||
|
|
||||||
[database]
|
|
||||||
path = "./data/llm_proxy.db"
|
|
||||||
max_connections = 3 # Reduced from default 10
|
|
||||||
|
|
||||||
[providers]
|
|
||||||
# Only enable providers you need
|
|
||||||
openai.enabled = true
|
|
||||||
gemini.enabled = false # Disable if not used
|
|
||||||
deepseek.enabled = false # Disable if not used
|
|
||||||
grok.enabled = false # Disable if not used
|
|
||||||
|
|
||||||
[rate_limiting]
|
|
||||||
# Reduce memory usage for rate limiting
|
|
||||||
client_requests_per_minute = 30 # Reduced from 60
|
|
||||||
client_burst_size = 5 # Reduced from 10
|
|
||||||
global_requests_per_minute = 300 # Reduced from 600
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. System-Level Optimizations
|
|
||||||
|
|
||||||
#### Linux Kernel Parameters
|
|
||||||
Add to `/etc/sysctl.conf`:
|
|
||||||
```bash
|
|
||||||
# Reduce TCP buffer sizes
|
|
||||||
net.ipv4.tcp_rmem = 4096 87380 174760
|
|
||||||
net.ipv4.tcp_wmem = 4096 65536 131072
|
|
||||||
|
|
||||||
# Reduce connection tracking
|
|
||||||
net.netfilter.nf_conntrack_max = 65536
|
|
||||||
net.netfilter.nf_conntrack_tcp_timeout_established = 1200
|
|
||||||
|
|
||||||
# Reduce socket buffer sizes
|
|
||||||
net.core.rmem_max = 131072
|
|
||||||
net.core.wmem_max = 131072
|
|
||||||
net.core.rmem_default = 65536
|
|
||||||
net.core.wmem_default = 65536
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Systemd Service Configuration
|
|
||||||
Create `/etc/systemd/system/llm-proxy.service`:
|
|
||||||
```ini
|
|
||||||
[Unit]
|
|
||||||
Description=LLM Proxy Gateway
|
|
||||||
After=network.target
|
|
||||||
|
|
||||||
[Service]
|
|
||||||
Type=simple
|
|
||||||
User=llmproxy
|
|
||||||
Group=llmproxy
|
|
||||||
WorkingDirectory=/opt/llm-proxy
|
|
||||||
ExecStart=/opt/llm-proxy/llm-proxy
|
|
||||||
Restart=on-failure
|
|
||||||
RestartSec=5
|
|
||||||
|
|
||||||
# Memory limits
|
|
||||||
MemoryMax=400M
|
|
||||||
MemorySwapMax=100M
|
|
||||||
|
|
||||||
# CPU limits
|
|
||||||
CPUQuota=50%
|
|
||||||
|
|
||||||
# Process limits
|
|
||||||
LimitNOFILE=65536
|
|
||||||
LimitNPROC=512
|
|
||||||
|
|
||||||
Environment="RUST_LOG=info"
|
|
||||||
Environment="LLM_PROXY__DATABASE__MAX_CONNECTIONS=3"
|
|
||||||
|
|
||||||
[Install]
|
|
||||||
WantedBy=multi-user.target
|
|
||||||
```
|
|
||||||
|
|
||||||
### 5. Application-Specific Optimizations
|
|
||||||
|
|
||||||
#### Disable Unused Features
|
|
||||||
- **Multimodal support**: If not using images, disable image processing dependencies
|
|
||||||
- **Dashboard**: The dashboard uses WebSockets and additional memory. Consider disabling if not needed.
|
|
||||||
- **Detailed logging**: Reduce log verbosity in production
|
|
||||||
|
|
||||||
#### Memory Pool Sizes
|
|
||||||
The application uses several memory pools:
|
|
||||||
1. **Database connection pool**: Configured via `max_connections`
|
|
||||||
2. **HTTP client pool**: Reqwest client pool (defaults to reasonable values)
|
|
||||||
3. **Async runtime**: Tokio worker threads
|
|
||||||
|
|
||||||
Reduce Tokio worker threads for low-core systems:
|
|
||||||
```rust
|
|
||||||
// In main.rs, modify tokio runtime creation
|
|
||||||
#[tokio::main(flavor = "current_thread")] // Single-threaded runtime
|
|
||||||
async fn main() -> Result<()> {
|
|
||||||
// Or for multi-threaded with limited threads:
|
|
||||||
// #[tokio::main(worker_threads = 2)]
|
|
||||||
```
|
|
||||||
|
|
||||||
### 6. Monitoring and Profiling
|
|
||||||
|
|
||||||
#### Memory Usage Monitoring
|
|
||||||
```bash
|
|
||||||
# Install heaptrack for memory profiling
|
|
||||||
# heaptrack is a system profiler (not a Rust crate); install it via your package manager, e.g.:
sudo apt install heaptrack
|
|
||||||
|
|
||||||
# Profile memory usage
|
|
||||||
heaptrack ./target/release/llm-proxy
|
|
||||||
|
|
||||||
# Monitor with ps
|
|
||||||
ps aux --sort=-%mem | head -10
|
|
||||||
|
|
||||||
# Monitor with top
|
|
||||||
top -p $(pgrep llm-proxy)
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Performance Benchmarks
|
|
||||||
Test with different configurations:
|
|
||||||
```bash
|
|
||||||
# Test with 100 concurrent connections
|
|
||||||
wrk -t4 -c100 -d30s http://localhost:8080/health
|
|
||||||
|
|
||||||
# Test chat completion endpoint
|
|
||||||
ab -n 1000 -c 10 -p test_request.json -T application/json http://localhost:8080/v1/chat/completions
|
|
||||||
```
|
|
||||||
|
|
||||||
### 7. Deployment Checklist for 512MB RAM
|
|
||||||
|
|
||||||
- [ ] Build with release profile: `cargo build --release`
|
|
||||||
- [ ] Configure database with `max_connections = 3`
|
|
||||||
- [ ] Disable unused providers in configuration
|
|
||||||
- [ ] Set appropriate rate limiting limits
|
|
||||||
- [ ] Configure systemd with memory limits
|
|
||||||
- [ ] Set up log rotation to prevent disk space issues
|
|
||||||
- [ ] Monitor memory usage during initial deployment
|
|
||||||
- [ ] Consider using swap space (512MB-1GB) for safety
|
|
||||||
|
|
||||||
### 8. Troubleshooting High Memory Usage
|
|
||||||
|
|
||||||
#### Common Issues and Solutions:
|
|
||||||
|
|
||||||
1. **Database connection leaks**: Ensure connections are properly closed
|
|
||||||
2. **Memory fragmentation**: Use jemalloc or mimalloc as allocator
|
|
||||||
3. **Unbounded queues**: Check WebSocket message queues
|
|
||||||
4. **Cache growth**: Implement cache limits or TTL
|
|
||||||
|
|
||||||
#### Add to Cargo.toml for alternative allocator:
|
|
||||||
```toml
|
|
||||||
[dependencies]
|
|
||||||
mimalloc = { version = "0.1", default-features = false }
|
|
||||||
|
|
||||||
[features]
|
|
||||||
default = ["mimalloc"]
|
|
||||||
```
|
|
||||||
|
|
||||||
#### In main.rs:
|
|
||||||
```rust
|
|
||||||
#[global_allocator]
|
|
||||||
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
|
|
||||||
```
|
|
||||||
|
|
||||||
### 9. Expected Memory Usage
|
|
||||||
|
|
||||||
| Component | Baseline | With 10 clients | With 100 clients |
|
|
||||||
|-----------|----------|-----------------|------------------|
|
|
||||||
| Base executable | 15MB | 15MB | 15MB |
|
|
||||||
| Database connections | 5MB | 8MB | 15MB |
|
|
||||||
| Rate limiting | 2MB | 5MB | 20MB |
|
|
||||||
| HTTP clients | 3MB | 5MB | 10MB |
|
|
||||||
| **Total** | **25MB** | **33MB** | **60MB** |
|
|
||||||
|
|
||||||
**Note**: These are estimates. Actual usage depends on request volume, payload sizes, and configuration.
|
|
||||||
|
|
||||||
### 10. Further Reading
|
|
||||||
|
|
||||||
- [Tokio performance guide](https://tokio.rs/tokio/topics/performance)
|
|
||||||
- [Rust performance book](https://nnethercote.github.io/perf-book/)
|
|
||||||
- [Linux memory management](https://www.kernel.org/doc/html/latest/admin-guide/mm/)
|
|
||||||
- [SQLite performance tips](https://www.sqlite.org/faq.html#q19)
|
|
||||||
62
PLAN.md
62
PLAN.md
@@ -1,62 +0,0 @@
|
|||||||
# Project Plan: LLM Proxy Enhancements & Security Upgrade
|
|
||||||
|
|
||||||
This document outlines the roadmap for standardizing frontend security, cleaning up the codebase, upgrading session management to HMAC-signed tokens, and extending integration testing.
|
|
||||||
|
|
||||||
## Phase 1: Frontend Security Standardization
|
|
||||||
**Primary Agent:** `frontend-developer`
|
|
||||||
|
|
||||||
- [x] Audit `static/js/pages/users.js` for manual HTML string concatenation.
|
|
||||||
- [x] Replace custom escaping or unescaped injections with `window.api.escapeHtml`.
|
|
||||||
- [x] Verify user list and user detail rendering for XSS vulnerabilities.
|
|
||||||
|
|
||||||
## Phase 2: Codebase Cleanup
|
|
||||||
**Primary Agent:** `backend-developer`
|
|
||||||
|
|
||||||
- [x] Identify and remove unused imports in `src/config/mod.rs`.
|
|
||||||
- [x] Identify and remove unused imports in `src/providers/mod.rs`.
|
|
||||||
- [x] Run `cargo clippy` and `cargo fmt` to ensure adherence to standards.
|
|
||||||
|
|
||||||
## Phase 3: HMAC Architectural Upgrade
|
|
||||||
**Primary Agents:** `fullstack-developer`, `security-auditor`, `backend-developer`
|
|
||||||
|
|
||||||
### 3.1 Design (Security Auditor)
|
|
||||||
- [x] Define Token Structure: `base64(payload).signature`.
|
|
||||||
- Payload: `{ "session_id": "...", "username": "...", "role": "...", "exp": ... }`
|
|
||||||
- [x] Select HMAC algorithm (HMAC-SHA256).
|
|
||||||
- [x] Define environment variable for secret key: `SESSION_SECRET`.
|
|
||||||
|
|
||||||
### 3.2 Implementation (Backend Developer)
|
|
||||||
- [x] Refactor `src/dashboard/sessions.rs`:
|
|
||||||
- Integrate `hmac` and `sha2` crates (or similar).
|
|
||||||
- Update `create_session` to return signed tokens.
|
|
||||||
- Update `validate_session` to verify signature before checking store.
|
|
||||||
- [x] Implement activity-based session refresh:
|
|
||||||
- If session is valid and >50% through its TTL, extend `expires_at` and issue new signed token.
|
|
||||||
|
|
||||||
### 3.3 Integration (Fullstack Developer)
|
|
||||||
- [x] Update dashboard API handlers to handle new token format.
|
|
||||||
- [x] Update frontend session storage/retrieval if necessary.
|
|
||||||
|
|
||||||
## Phase 4: Extended Integration Testing
|
|
||||||
**Primary Agent:** `qa-automation`
|
|
||||||
|
|
||||||
- [ ] Setup test environment with encrypted key storage enabled.
|
|
||||||
- [ ] Implement end-to-end flow:
|
|
||||||
1. Store encrypted provider key via API.
|
|
||||||
2. Authenticate through Proxy.
|
|
||||||
3. Make proxied LLM request (verifying decryption and usage).
|
|
||||||
- [ ] Validate HMAC token expiration and refresh logic in automated tests.
|
|
||||||
|
|
||||||
## Phase 5: Code Quality & Refactoring
|
|
||||||
**Primary Agent:** `fullstack-developer`
|
|
||||||
|
|
||||||
- [x] Refactor dashboard monolith into modular sub-modules (`auth.rs`, `usage.rs`, etc.).
|
|
||||||
- [x] Standardize error handling and remove `unwrap()` in production paths.
|
|
||||||
- [x] Implement system health metrics and backup functionality.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Technical Standards
|
|
||||||
- **Rust:** No `unwrap()` in production code; use proper error handling (`Result`).
|
|
||||||
- **Frontend:** Always use `window.api` wrappers for sensitive operations.
|
|
||||||
- **Security:** Secrets must never be logged or hardcoded.
|
|
||||||
96
README.md
96
README.md
@@ -1,119 +1,125 @@
|
|||||||
# LLM Proxy Gateway
|
# GopherGate
|
||||||
|
|
||||||
A unified, high-performance LLM proxy gateway built in Rust. It provides a single OpenAI-compatible API to access multiple providers (OpenAI, Gemini, DeepSeek, Grok, Ollama) with built-in token tracking, real-time cost calculation, multi-user authentication, and a management dashboard.
|
A unified, high-performance LLM proxy gateway built in Go. It provides a single OpenAI-compatible API to access multiple providers (OpenAI, Gemini, DeepSeek, Moonshot, Grok, Ollama) with built-in token tracking, real-time cost calculation, multi-user authentication, and a management dashboard.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- **Unified API:** OpenAI-compatible `/v1/chat/completions` and `/v1/models` endpoints.
|
- **Unified API:** OpenAI-compatible `/v1/chat/completions` and `/v1/models` endpoints.
|
||||||
- **Multi-Provider Support:**
|
- **Multi-Provider Support:**
|
||||||
- **OpenAI:** GPT-4o, GPT-4o Mini, o1, o3 reasoning models.
|
- **OpenAI:** GPT-4o, GPT-4o Mini, o1, o3 reasoning models.
|
||||||
- **Google Gemini:** Gemini 2.0 Flash, Pro, and vision models.
|
- **Google Gemini:** Gemini 2.0 Flash, Pro, and vision models (with native CoT support).
|
||||||
- **DeepSeek:** DeepSeek Chat and Reasoner models.
|
- **DeepSeek:** DeepSeek Chat and Reasoner (R1) models.
|
||||||
- **xAI Grok:** Grok-beta models.
|
- **Moonshot:** Kimi K2.5 and other Kimi models.
|
||||||
|
- **xAI Grok:** Grok-4 models.
|
||||||
- **Ollama:** Local LLMs running on your network.
|
- **Ollama:** Local LLMs running on your network.
|
||||||
- **Observability & Tracking:**
|
- **Observability & Tracking:**
|
||||||
- **Real-time Costing:** Fetches live pricing and context specs from `models.dev` on startup.
|
- **Asynchronous Logging:** Non-blocking request logging to SQLite using background workers.
|
||||||
- **Token Counting:** Precise estimation using `tiktoken-rs`.
|
- **Token Counting:** Precise estimation and tracking of prompt, completion, and reasoning tokens.
|
||||||
- **Database Logging:** Every request logged to SQLite for historical analysis.
|
- **Database Persistence:** Every request logged to SQLite for historical analysis and dashboard analytics.
|
||||||
- **Streaming Support:** Full SSE (Server-Sent Events) with `[DONE]` termination for client compatibility.
|
- **Streaming Support:** Full SSE (Server-Sent Events) support for all providers.
|
||||||
- **Multimodal (Vision):** Image processing (Base64 and remote URLs) across compatible providers.
|
- **Multimodal (Vision):** Image processing (Base64 and remote URLs) across compatible providers.
|
||||||
- **Multi-User Access Control:**
|
- **Multi-User Access Control:**
|
||||||
- **Admin Role:** Full access to all dashboard features, user management, and system configuration.
|
- **Admin Role:** Full access to all dashboard features, user management, and system configuration.
|
||||||
- **Viewer Role:** Read-only access to usage analytics, costs, and monitoring.
|
- **Viewer Role:** Read-only access to usage analytics, costs, and monitoring.
|
||||||
- **Client API Keys:** Create and manage multiple client tokens for external integrations.
|
- **Client API Keys:** Create and manage multiple client tokens for external integrations.
|
||||||
- **Reliability:**
|
- **Reliability:**
|
||||||
- **Circuit Breaking:** Automatically protects when providers are down.
|
- **Circuit Breaking:** Automatically protects when providers are down (coming soon).
|
||||||
- **Rate Limiting:** Per-client and global rate limits.
|
- **Rate Limiting:** Per-client and global rate limits (coming soon).
|
||||||
- **Cache-Aware Costing:** Tracks cache hit/miss tokens for accurate billing.
|
|
||||||
|
|
||||||
## Security
|
## Security
|
||||||
|
|
||||||
LLM Proxy is designed with security in mind:
|
GopherGate is designed with security in mind:
|
||||||
|
|
||||||
- **HMAC Session Tokens:** Management dashboard sessions are secured using HMAC-SHA256 signed tokens.
|
- **Signed Session Tokens:** Management dashboard sessions are secured using HMAC-SHA256 signed tokens.
|
||||||
- **Encrypted Provider Keys:** Sensitive LLM provider API keys are stored encrypted (AES-256-GCM) in the database.
|
- **Encrypted Storage:** Support for encrypted provider API keys in the database.
|
||||||
- **Session Refresh:** Activity-based session extension prevents session hijacking while maintaining user convenience.
|
- **Auth Middleware:** Secure client authentication via database-backed API keys.
|
||||||
- **XSS Prevention:** Standardized frontend escaping using `window.api.escapeHtml`.
|
|
||||||
|
|
||||||
**Note:** You must define a `SESSION_SECRET` in your `.env` file for secure session signing.
|
**Note:** You must define an `LLM_PROXY__ENCRYPTION_KEY` in your `.env` file for secure session signing and encryption.
|
||||||
|
|
||||||
## Tech Stack
|
## Tech Stack
|
||||||
|
|
||||||
- **Runtime:** Rust with Tokio.
|
- **Runtime:** Go 1.22+
|
||||||
- **Web Framework:** Axum.
|
- **Web Framework:** Gin Gonic
|
||||||
- **Database:** SQLx with SQLite.
|
- **Database:** sqlx with SQLite (CGO-free via `modernc.org/sqlite`)
|
||||||
- **Frontend:** Vanilla JS/CSS with Chart.js for visualizations.
|
- **Frontend:** Vanilla JS/CSS with Chart.js for visualizations
|
||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
|
|
||||||
### Prerequisites
|
### Prerequisites
|
||||||
|
|
||||||
- Rust (1.80+)
|
- Go (1.22+)
|
||||||
- SQLite3
|
- SQLite3 (optional, driver is built-in)
|
||||||
- Docker (optional, for containerized deployment)
|
- Docker (optional, for containerized deployment)
|
||||||
|
|
||||||
### Quick Start
|
### Quick Start
|
||||||
|
|
||||||
1. Clone and build:
|
1. Clone and build:
|
||||||
```bash
|
```bash
|
||||||
git clone ssh://git.dustin.coffee:2222/hobokenchicken/llm-proxy.git
|
git clone <repository-url>
|
||||||
cd llm-proxy
|
cd gophergate
|
||||||
cargo build --release
|
go build -o gophergate ./cmd/gophergate
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Configure environment:
|
2. Configure environment:
|
||||||
```bash
|
```bash
|
||||||
cp .env.example .env
|
cp .env.example .env
|
||||||
# Edit .env and add your API keys:
|
# Edit .env and add your configuration:
|
||||||
# SESSION_SECRET=... (Generate a strong random secret)
|
# LLM_PROXY__ENCRYPTION_KEY=... (32-byte hex or base64 string)
|
||||||
# OPENAI_API_KEY=sk-...
|
# OPENAI_API_KEY=sk-...
|
||||||
# GEMINI_API_KEY=AIza...
|
# GEMINI_API_KEY=AIza...
|
||||||
|
# MOONSHOT_API_KEY=...
|
||||||
```
|
```
|
||||||
|
|
||||||
3. Run the proxy:
|
3. Run the proxy:
|
||||||
```bash
|
```bash
|
||||||
cargo run --release
|
./gophergate
|
||||||
```
|
```
|
||||||
|
|
||||||
The server starts on `http://localhost:8080` by default.
|
The server starts on `http://0.0.0.0:8080` by default.
|
||||||
|
|
||||||
### Deployment (Docker)
|
### Deployment (Docker)
|
||||||
|
|
||||||
A multi-stage `Dockerfile` is provided for efficient deployment:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Build the container
|
# Build the container
|
||||||
docker build -t llm-proxy .
|
docker build -t gophergate .
|
||||||
|
|
||||||
# Run the container
|
# Run the container
|
||||||
docker run -p 8080:8080 \
|
docker run -p 8080:8080 \
|
||||||
-e SESSION_SECRET=your-secure-secret \
|
-e LLM_PROXY__ENCRYPTION_KEY=your-secure-key \
|
||||||
-v ./data:/app/data \
|
-v ./data:/app/data \
|
||||||
llm-proxy
|
gophergate
|
||||||
```
|
```
|
||||||
|
|
||||||
## Management Dashboard
|
## Management Dashboard
|
||||||
|
|
||||||
Access the dashboard at `http://localhost:8080`. The dashboard architecture has been refactored into modular sub-components for better maintainability:
|
Access the dashboard at `http://localhost:8080`.
|
||||||
|
|
||||||
- **Auth (`/api/auth`):** Login, session management, and password changes.
|
- **Auth:** Login, session management, and status tracking.
|
||||||
- **Usage (`/api/usage`):** Summary stats, time-series analytics, and provider breakdown.
|
- **Usage:** Summary stats, time-series analytics, and provider breakdown.
|
||||||
- **Clients (`/api/clients`):** API key management and per-client usage tracking.
|
- **Clients:** API key management and per-client usage tracking.
|
||||||
- **Providers (`/api/providers`):** Provider configuration, status monitoring, and connection testing.
|
- **Providers:** Provider configuration and status monitoring.
|
||||||
- **System (`/api/system`):** Health metrics, live logs, database backups, and global settings.
|
- **Users:** Admin-only user management for dashboard access.
|
||||||
- **Monitoring:** Live request stream via WebSocket.
|
- **Monitoring:** Live request stream via WebSocket.
|
||||||
|
|
||||||
### Default Credentials
|
### Default Credentials
|
||||||
|
|
||||||
- **Username:** `admin`
|
- **Username:** `admin`
|
||||||
- **Password:** `admin123`
|
- **Password:** `admin123` (You will be prompted to change this on first login)
|
||||||
|
|
||||||
Change the admin password in the dashboard after first login!
|
**Forgot Password?**
|
||||||
|
You can reset the admin password to default by running:
|
||||||
|
```bash
|
||||||
|
./gophergate -reset-admin
|
||||||
|
```
|
||||||
|
|
||||||
## API Usage
|
## API Usage
|
||||||
|
|
||||||
The proxy is a drop-in replacement for OpenAI. Configure your client:
|
The proxy is a drop-in replacement for OpenAI. Configure your client:
|
||||||
|
|
||||||
|
Moonshot models are available through the same OpenAI-compatible endpoint. For
|
||||||
|
example, use `kimi-k2.5` as the model name after setting `MOONSHOT_API_KEY` in
|
||||||
|
your environment.
|
||||||
|
|
||||||
### Python
|
### Python
|
||||||
```python
|
```python
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
@@ -131,4 +137,4 @@ response = client.chat.completions.create(
|
|||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
MIT OR Apache-2.0
|
MIT
|
||||||
|
|||||||
@@ -1,58 +0,0 @@
|
|||||||
# LLM Proxy Rust Backend Code Review Report
|
|
||||||
|
|
||||||
## Executive Summary
|
|
||||||
This code review examines the `llm-proxy` Rust backend, focusing on architectural soundness, performance characteristics, and adherence to Rust idioms. The codebase demonstrates solid engineering with well-structured modular design, comprehensive error handling, and thoughtful async patterns. However, several areas require attention for production readiness, particularly around thread safety, memory efficiency, and error recovery.
|
|
||||||
|
|
||||||
## 1. Core Proxy Logic Review
|
|
||||||
### Strengths
|
|
||||||
- Clean provider abstraction (`Provider` trait).
|
|
||||||
- Streaming support with `AggregatingStream` for token counting.
|
|
||||||
- Model mapping and caching system.
|
|
||||||
|
|
||||||
### Issues Found
|
|
||||||
- **Provider Manager Thread Safety Risk:** O(n) lookups using `Vec` with `RwLock`. Use `DashMap` instead.
|
|
||||||
- **Streaming Memory Inefficiency:** Accumulates complete response content in memory.
|
|
||||||
- **Model Registry Cache Invalidation:** No strategy when config changes via dashboard.
|
|
||||||
|
|
||||||
## 2. State Management Review
|
|
||||||
- **Token Bucket Algorithm Flaw:** Custom implementation lacks thread-safe refill sync. Use `governor` crate.
|
|
||||||
- **Broadcast Channel Unbounded Growth Risk:** Fixed-size (100) may drop messages.
|
|
||||||
- **Database Connection Pool Contention:** SQLite connections shared without differentiation.
|
|
||||||
|
|
||||||
## 3. Error Handling Review
|
|
||||||
- **Error Recovery Missing:** No circuit breaker or retry logic for provider calls.
|
|
||||||
- **Stream Error Logging Gap:** Stream errors swallowed without logging partial usage.
|
|
||||||
|
|
||||||
## 4. Rust Idioms and Performance
|
|
||||||
- **Unnecessary String Cloning:** Frequent cloning in authentication hot paths.
|
|
||||||
- **JSON Parsing Inefficiency:** Multiple passes with `serde_json::Value`. Use typed structs.
|
|
||||||
- **Missing `#[derive(Copy)]`:** For small enums like `CircuitState`.
|
|
||||||
|
|
||||||
## 5. Async Performance
|
|
||||||
- **Blocking Calls:** Token estimation may block async runtime.
|
|
||||||
- **Missing Connection Timeouts:** Only overall timeout, no separate read/write timeouts.
|
|
||||||
- **Unbounded Task Spawn:** For client usage updates under load.
|
|
||||||
|
|
||||||
## 6. Security Considerations
|
|
||||||
- **Token Leakage:** Redact tokens in `Debug` and `Display` impls.
|
|
||||||
- **No Request Size Limits:** Vulnerable to memory exhaustion.
|
|
||||||
|
|
||||||
## 7. Testability
|
|
||||||
- **Mocking Difficulty:** Tight coupling to concrete provider implementations.
|
|
||||||
- **Missing Integration Tests:** No E2E tests for streaming.
|
|
||||||
|
|
||||||
## 8. Summary of Critical Actions
|
|
||||||
### High Priority
|
|
||||||
1. Replace custom token bucket with `governor` crate.
|
|
||||||
2. Fix provider lookup O(n) scaling issue.
|
|
||||||
3. Implement proper error recovery with retries.
|
|
||||||
4. Add request size limits and timeout configurations.
|
|
||||||
|
|
||||||
### Medium Priority
|
|
||||||
1. Reduce string cloning in hot paths.
|
|
||||||
2. Implement cache invalidation for model configs.
|
|
||||||
3. Add connection pooling separation.
|
|
||||||
4. Improve streaming memory efficiency.
|
|
||||||
|
|
||||||
---
|
|
||||||
*Review conducted by: Senior Principal Engineer (Code Reviewer)*
|
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
# LLM Proxy Security Audit Report
|
|
||||||
|
|
||||||
## Executive Summary
|
|
||||||
A comprehensive security audit of the `llm-proxy` repository was conducted. The audit identified **1 critical vulnerability**, **3 high-risk issues**, **4 medium-risk issues**, and **3 low-risk issues**. The most severe findings include Cross-Site Scripting (XSS) in the dashboard interface and insecure storage of provider API keys in the database.
|
|
||||||
|
|
||||||
## Detailed Findings
|
|
||||||
|
|
||||||
### Critical Risk Vulnerabilities
|
|
||||||
#### **CRITICAL-01: Cross-Site Scripting (XSS) in Dashboard Interface**
|
|
||||||
- **Location**: `static/js/pages/clients.js` (multiple locations).
|
|
||||||
- **Description**: User-controlled data (e.g., `client.id`) inserted directly into HTML or `onclick` handlers without escaping.
|
|
||||||
- **Impact**: Arbitrary JavaScript execution in admin context, potentially stealing session tokens.
|
|
||||||
|
|
||||||
#### **CRITICAL-02: Insecure API Key Storage in Database**
|
|
||||||
- **Location**: `src/database/mod.rs`, `src/providers/mod.rs`, `src/dashboard/providers.rs`.
|
|
||||||
- **Description**: Provider API keys are stored in **plaintext** in the SQLite database.
|
|
||||||
- **Impact**: Compromised database file exposes all provider API keys.
|
|
||||||
|
|
||||||
### High Risk Vulnerabilities
|
|
||||||
#### **HIGH-01: Missing Input Validation and Size Limits**
|
|
||||||
- **Location**: `src/server/mod.rs`, `src/models/mod.rs`.
|
|
||||||
- **Impact**: Denial of Service via large payloads.
|
|
||||||
|
|
||||||
#### **HIGH-02: Sensitive Data Logging Without Encryption**
|
|
||||||
- **Location**: `src/database/mod.rs`, `src/logging/mod.rs`.
|
|
||||||
- **Description**: Full request and response bodies stored in `llm_requests` table without encryption or redaction.
|
|
||||||
|
|
||||||
#### **HIGH-03: Weak Default Credentials and Password Policy**
|
|
||||||
- **Description**: Default admin password is 'admin' with only 4-character minimum password length.
|
|
||||||
|
|
||||||
### Medium Risk Vulnerabilities
|
|
||||||
#### **MEDIUM-01: Missing CSRF Protection**
|
|
||||||
- No CSRF tokens or SameSite cookie attributes for state-changing dashboard endpoints.
|
|
||||||
|
|
||||||
#### **MEDIUM-02: Insecure Session Management**
|
|
||||||
- Session tokens stored in localStorage without HttpOnly flag.
|
|
||||||
- Tokens use simple `session-{uuid}` format.
|
|
||||||
|
|
||||||
#### **MEDIUM-03: Error Information Leakage**
|
|
||||||
- Internal error details exposed to clients in some cases.
|
|
||||||
|
|
||||||
#### **MEDIUM-04: Outdated Dependencies**
|
|
||||||
- Outdated versions of `chrono`, `tokio`, and `reqwest`.
|
|
||||||
|
|
||||||
### Low Risk Vulnerabilities
|
|
||||||
- Missing security headers (CSP, HSTS, X-Frame-Options).
|
|
||||||
- Insufficient rate limiting on dashboard authentication.
|
|
||||||
- No database encryption at rest.
|
|
||||||
|
|
||||||
## Recommendations
|
|
||||||
### Immediate Actions
|
|
||||||
1. **Fix XSS Vulnerabilities:** Implement proper HTML escaping for all user-controlled data.
|
|
||||||
2. **Secure API Key Storage:** Encrypt API keys in database using a library like `ring`.
|
|
||||||
3. **Implement Input Validation:** Add maximum payload size limits (e.g., 10MB).
|
|
||||||
4. **Improve Data Protection:** Add option to disable request/response body logging.
|
|
||||||
|
|
||||||
---
|
|
||||||
*Report generated by Security Auditor Agent on March 6, 2026*
|
|
||||||
56
TODO.md
Normal file
56
TODO.md
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
# Migration TODO List
|
||||||
|
|
||||||
|
## Completed Tasks
|
||||||
|
- [x] Initial Go project setup
|
||||||
|
- [x] Database schema & migrations (hardcoded in `db.go`)
|
||||||
|
- [x] Configuration loader (Viper)
|
||||||
|
- [x] Auth Middleware (scoped to `/v1`)
|
||||||
|
- [x] Basic Provider implementations (OpenAI, Gemini, DeepSeek, Grok)
|
||||||
|
- [x] Streaming Support (SSE & Gemini custom streaming)
|
||||||
|
- [x] Archive Rust files to `rust` branch
|
||||||
|
- [x] Clean root and set Go version as `main`
|
||||||
|
- [x] Enhanced `helpers.go` for Multimodal & Tool Calling (OpenAI compatible)
|
||||||
|
- [x] Enhanced `server.go` for robust request conversion
|
||||||
|
- [x] Dashboard Management APIs (Clients, Tokens, Users, Providers)
|
||||||
|
- [x] Dashboard Analytics & Usage Summary (Fixed SQL robustness)
|
||||||
|
- [x] WebSocket for real-time dashboard updates (Hub with client counting)
|
||||||
|
- [x] Asynchronous Request Logging to SQLite
|
||||||
|
- [x] Update documentation (README, deployment, architecture)
|
||||||
|
- [x] Cost Tracking accuracy (Registry integration with `models.dev`)
|
||||||
|
- [x] Model Listing endpoint (`/v1/models`) with provider filtering
|
||||||
|
- [x] System Metrics endpoint (`/api/system/metrics` using `gopsutil`)
|
||||||
|
- [x] Fixed dashboard 404s and 500s
|
||||||
|
|
||||||
|
## Feature Parity Checklist (High Priority)
|
||||||
|
|
||||||
|
### OpenAI Provider
|
||||||
|
- [x] Tool Calling
|
||||||
|
- [x] Multimodal (Images) support
|
||||||
|
- [x] Accurate usage parsing (cached & reasoning tokens)
|
||||||
|
- [ ] Reasoning Content (CoT) support for `o1`, `o3` (need to ensure it's parsed in responses)
|
||||||
|
- [ ] Support for `/v1/responses` API (required for some gpt-5/o1 models)
|
||||||
|
|
||||||
|
### Gemini Provider
|
||||||
|
- [x] Tool Calling (mapping to Gemini format)
|
||||||
|
- [x] Multimodal (Images) support
|
||||||
|
- [x] Reasoning/Thought support
|
||||||
|
- [x] Handle Tool Response role in unified format
|
||||||
|
|
||||||
|
### DeepSeek Provider
|
||||||
|
- [x] Reasoning Content (CoT) support
|
||||||
|
- [x] Parameter sanitization for `deepseek-reasoner`
|
||||||
|
- [x] Tool Calling support
|
||||||
|
- [x] Accurate usage parsing (cache hits & reasoning)
|
||||||
|
|
||||||
|
### Grok Provider
|
||||||
|
- [x] Tool Calling support
|
||||||
|
- [x] Multimodal support
|
||||||
|
- [x] Accurate usage parsing (via OpenAI helper)
|
||||||
|
|
||||||
|
## Infrastructure & Middleware
|
||||||
|
- [ ] Implement Rate Limiting (`golang.org/x/time/rate`)
|
||||||
|
- [ ] Implement Circuit Breaker (`github.com/sony/gobreaker`)
|
||||||
|
|
||||||
|
## Verification
|
||||||
|
- [ ] Unit tests for feature-specific mapping (CoT, Tools, Images)
|
||||||
|
- [ ] Integration tests with live LLM APIs
|
||||||
@@ -1 +0,0 @@
|
|||||||
too-many-arguments-threshold = 8
|
|
||||||
55
cmd/gophergate/main.go
Normal file
55
cmd/gophergate/main.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"gophergate/internal/config"
|
||||||
|
"gophergate/internal/db"
|
||||||
|
"gophergate/internal/server"
|
||||||
|
|
||||||
|
"github.com/joho/godotenv"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
resetAdmin := flag.Bool("reset-admin", false, "Reset admin password to admin123")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
// Load environment variables
|
||||||
|
if err := godotenv.Load(); err != nil {
|
||||||
|
log.Println("No .env file found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load configuration
|
||||||
|
cfg, err := config.Load()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to load configuration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize database
|
||||||
|
database, err := db.Init(cfg.Database.Path)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to initialize database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if *resetAdmin {
|
||||||
|
hash, _ := bcrypt.GenerateFromPassword([]byte("admin123"), 12)
|
||||||
|
_, err = database.Exec("UPDATE users SET password_hash = ?, must_change_password = 1 WHERE username = 'admin'", string(hash))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to reset admin password: %v", err)
|
||||||
|
}
|
||||||
|
log.Println("Admin password has been reset to 'admin123'")
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize server
|
||||||
|
s := server.NewServer(cfg, database)
|
||||||
|
|
||||||
|
// Run server
|
||||||
|
log.Printf("Starting GopherGate on %s:%d", cfg.Server.Host, cfg.Server.Port)
|
||||||
|
if err := s.Run(); err != nil {
|
||||||
|
log.Fatalf("Server failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
Binary file not shown.
667
deploy.sh
667
deploy.sh
@@ -1,667 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# LLM Proxy Gateway Deployment Script
|
|
||||||
# This script automates the deployment of the LLM Proxy Gateway on a Linux server
|
|
||||||
|
|
||||||
set -e # Exit on error
|
|
||||||
set -u # Exit on undefined variable
|
|
||||||
|
|
||||||
# Configuration
|
|
||||||
APP_NAME="llm-proxy"
|
|
||||||
APP_USER="llmproxy"
|
|
||||||
APP_GROUP="llmproxy"
|
|
||||||
GIT_REPO="ssh://git.dustin.coffee:2222/hobokenchicken/llm-proxy.git"
|
|
||||||
INSTALL_DIR="/opt/$APP_NAME"
|
|
||||||
CONFIG_DIR="/etc/$APP_NAME"
|
|
||||||
DATA_DIR="/var/lib/$APP_NAME"
|
|
||||||
LOG_DIR="/var/log/$APP_NAME"
|
|
||||||
SERVICE_FILE="/etc/systemd/system/$APP_NAME.service"
|
|
||||||
ENV_FILE="$CONFIG_DIR/.env"
|
|
||||||
|
|
||||||
# Colors for output
|
|
||||||
RED='\033[0;31m'
|
|
||||||
GREEN='\033[0;32m'
|
|
||||||
YELLOW='\033[1;33m'
|
|
||||||
NC='\033[0m' # No Color
|
|
||||||
|
|
||||||
# Logging functions
|
|
||||||
log_info() {
|
|
||||||
echo -e "${GREEN}[INFO]${NC} $1"
|
|
||||||
}
|
|
||||||
|
|
||||||
log_warn() {
|
|
||||||
echo -e "${YELLOW}[WARN]${NC} $1"
|
|
||||||
}
|
|
||||||
|
|
||||||
log_error() {
|
|
||||||
echo -e "${RED}[ERROR]${NC} $1"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Check if running as root
|
|
||||||
check_root() {
|
|
||||||
if [[ $EUID -ne 0 ]]; then
|
|
||||||
log_error "This script must be run as root"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
# Install system dependencies
|
|
||||||
install_dependencies() {
|
|
||||||
log_info "Installing system dependencies..."
|
|
||||||
|
|
||||||
# Detect package manager
|
|
||||||
if command -v apt-get &> /dev/null; then
|
|
||||||
# Debian/Ubuntu
|
|
||||||
apt-get update
|
|
||||||
apt-get install -y \
|
|
||||||
build-essential \
|
|
||||||
pkg-config \
|
|
||||||
libssl-dev \
|
|
||||||
sqlite3 \
|
|
||||||
curl \
|
|
||||||
git
|
|
||||||
elif command -v yum &> /dev/null; then
|
|
||||||
# RHEL/CentOS
|
|
||||||
yum groupinstall -y "Development Tools"
|
|
||||||
yum install -y \
|
|
||||||
openssl-devel \
|
|
||||||
sqlite \
|
|
||||||
curl \
|
|
||||||
git
|
|
||||||
elif command -v dnf &> /dev/null; then
|
|
||||||
# Fedora
|
|
||||||
dnf groupinstall -y "Development Tools"
|
|
||||||
dnf install -y \
|
|
||||||
openssl-devel \
|
|
||||||
sqlite \
|
|
||||||
curl \
|
|
||||||
git
|
|
||||||
elif command -v pacman &> /dev/null; then
|
|
||||||
# Arch Linux
|
|
||||||
pacman -Syu --noconfirm \
|
|
||||||
base-devel \
|
|
||||||
openssl \
|
|
||||||
sqlite \
|
|
||||||
curl \
|
|
||||||
git
|
|
||||||
else
|
|
||||||
log_warn "Could not detect package manager. Please install dependencies manually."
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
# Install Rust if not present
|
|
||||||
install_rust() {
|
|
||||||
log_info "Checking for Rust installation..."
|
|
||||||
|
|
||||||
if ! command -v rustc &> /dev/null; then
|
|
||||||
log_info "Installing Rust..."
|
|
||||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
|
||||||
source "$HOME/.cargo/env"
|
|
||||||
else
|
|
||||||
log_info "Rust is already installed"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Verify installation
|
|
||||||
rustc --version
|
|
||||||
cargo --version
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create system user and directories
|
|
||||||
setup_directories() {
|
|
||||||
log_info "Creating system user and directories..."
|
|
||||||
|
|
||||||
# Create user and group if they don't exist
|
|
||||||
if ! id "$APP_USER" &>/dev/null; then
|
|
||||||
# Arch uses /usr/bin/nologin, Debian/Ubuntu use /usr/sbin/nologin
|
|
||||||
NOLOGIN=$(command -v nologin 2>/dev/null || echo "/usr/bin/nologin")
|
|
||||||
useradd -r -s "$NOLOGIN" -M "$APP_USER"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Create directories
|
|
||||||
mkdir -p "$INSTALL_DIR"
|
|
||||||
mkdir -p "$CONFIG_DIR"
|
|
||||||
mkdir -p "$DATA_DIR"
|
|
||||||
mkdir -p "$LOG_DIR"
|
|
||||||
|
|
||||||
# Set permissions
|
|
||||||
chown -R "$APP_USER:$APP_GROUP" "$INSTALL_DIR"
|
|
||||||
chown -R "$APP_USER:$APP_GROUP" "$CONFIG_DIR"
|
|
||||||
chown -R "$APP_USER:$APP_GROUP" "$DATA_DIR"
|
|
||||||
chown -R "$APP_USER:$APP_GROUP" "$LOG_DIR"
|
|
||||||
|
|
||||||
chmod 750 "$INSTALL_DIR"
|
|
||||||
chmod 750 "$CONFIG_DIR"
|
|
||||||
chmod 750 "$DATA_DIR"
|
|
||||||
chmod 750 "$LOG_DIR"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Build the application
|
|
||||||
build_application() {
|
|
||||||
log_info "Building the application..."
|
|
||||||
|
|
||||||
# Clone or update repository
|
|
||||||
if [[ ! -d "$INSTALL_DIR/.git" ]]; then
|
|
||||||
log_info "Cloning repository..."
|
|
||||||
git clone "$GIT_REPO" "$INSTALL_DIR"
|
|
||||||
else
|
|
||||||
log_info "Updating repository..."
|
|
||||||
cd "$INSTALL_DIR"
|
|
||||||
git pull
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Build in release mode
|
|
||||||
cd "$INSTALL_DIR"
|
|
||||||
log_info "Building release binary..."
|
|
||||||
cargo build --release
|
|
||||||
|
|
||||||
# Verify build
|
|
||||||
if [[ -f "target/release/$APP_NAME" ]]; then
|
|
||||||
log_info "Build successful"
|
|
||||||
else
|
|
||||||
log_error "Build failed"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
# Create configuration files
|
|
||||||
#######################################
# Write the runtime configuration files (.env and config.toml) and
# restrict them to the service user, since they will hold API keys.
# Globals:   ENV_FILE, CONFIG_DIR, DATA_DIR, APP_USER, APP_GROUP (read)
# Outputs:   creates $ENV_FILE and $CONFIG_DIR/config.toml (mode 640)
# NOTE: no comments are added inside the heredocs below — their bodies
# are the literal file contents being generated.
#######################################
create_configuration() {
    log_info "Creating configuration files..."

    # Create .env file with API keys (all keys ship commented-out; the
    # operator fills in only the providers they use).
    # Unquoted EOF is intentional: nothing in this heredoc needs expansion,
    # and the literal text contains no $-references.
    cat > "$ENV_FILE" << EOF
# LLM Proxy Gateway Environment Variables
# Add your API keys here

# OpenAI API Key
# OPENAI_API_KEY=sk-your-key-here

# Google Gemini API Key
# GEMINI_API_KEY=AIza-your-key-here

# DeepSeek API Key
# DEEPSEEK_API_KEY=sk-your-key-here

# xAI Grok API Key
# GROK_API_KEY=gk-your-key-here

# Authentication tokens (comma-separated)
# LLM_PROXY__SERVER__AUTH_TOKENS=token1,token2,token3
EOF

    # Create config.toml — $DATA_DIR is expanded here at deploy time so
    # the database path is absolute in the generated file.
    cat > "$CONFIG_DIR/config.toml" << EOF
# LLM Proxy Gateway Configuration

[server]
port = 8080
host = "0.0.0.0"
# auth_tokens = ["token1", "token2", "token3"] # Uncomment to enable authentication

[database]
path = "$DATA_DIR/llm_proxy.db"
max_connections = 5

[providers.openai]
enabled = true
api_key_env = "OPENAI_API_KEY"
base_url = "https://api.openai.com/v1"
default_model = "gpt-4o"

[providers.gemini]
enabled = true
api_key_env = "GEMINI_API_KEY"
base_url = "https://generativelanguage.googleapis.com/v1"
default_model = "gemini-2.0-flash"

[providers.deepseek]
enabled = true
api_key_env = "DEEPSEEK_API_KEY"
base_url = "https://api.deepseek.com"
default_model = "deepseek-reasoner"

[providers.grok]
enabled = false # Disabled by default until API is researched
api_key_env = "GROK_API_KEY"
base_url = "https://api.x.ai/v1"
default_model = "grok-beta"

[model_mapping]
"gpt-*" = "openai"
"gemini-*" = "gemini"
"deepseek-*" = "deepseek"
"grok-*" = "grok"

[pricing]
openai = { input = 0.01, output = 0.03 }
gemini = { input = 0.0005, output = 0.0015 }
deepseek = { input = 0.00014, output = 0.00028 }
grok = { input = 0.001, output = 0.003 }
EOF

    # Set permissions: owner read/write, group read, world none — these
    # files may contain secrets.
    chown "$APP_USER:$APP_GROUP" "$ENV_FILE"
    chown "$APP_USER:$APP_GROUP" "$CONFIG_DIR/config.toml"
    chmod 640 "$ENV_FILE"
    chmod 640 "$CONFIG_DIR/config.toml"
}
|
|
||||||
|
|
||||||
# Create systemd service
|
|
||||||
#######################################
# Generate the systemd unit file for the gateway and reload systemd so
# the new unit is visible.
# Globals:   SERVICE_FILE, APP_USER, APP_GROUP, INSTALL_DIR, ENV_FILE,
#            CONFIG_DIR, DATA_DIR, LOG_DIR, APP_NAME (read)
# NOTE: the heredoc body is the literal unit file — comments inside it
# ("# Security hardening" etc.) belong to the generated unit, not to
# this script, and are left untouched.
#######################################
create_systemd_service() {
    log_info "Creating systemd service..."

    # Unquoted EOF is intentional: $APP_USER, $INSTALL_DIR, etc. are
    # expanded now so the written unit contains concrete paths.
    cat > "$SERVICE_FILE" << EOF
[Unit]
Description=LLM Proxy Gateway
Documentation=https://git.dustin.coffee/hobokenchicken/llm-proxy
After=network.target
Wants=network.target

[Service]
Type=simple
User=$APP_USER
Group=$APP_GROUP
WorkingDirectory=$INSTALL_DIR
EnvironmentFile=$ENV_FILE
Environment="RUST_LOG=info"
Environment="LLM_PROXY__CONFIG_PATH=$CONFIG_DIR/config.toml"
Environment="LLM_PROXY__DATABASE__PATH=$DATA_DIR/llm_proxy.db"
ExecStart=$INSTALL_DIR/target/release/$APP_NAME
Restart=on-failure
RestartSec=5

# Security hardening
NoNewPrivileges=true
PrivateTmp=true
ProtectSystem=strict
ProtectHome=true
ReadWritePaths=$DATA_DIR $LOG_DIR

# Resource limits (adjust based on your server)
MemoryMax=400M
MemorySwapMax=100M
CPUQuota=50%
LimitNOFILE=65536

# Logging
StandardOutput=journal
StandardError=journal
SyslogIdentifier=$APP_NAME

[Install]
WantedBy=multi-user.target
EOF

    # Reload systemd so the freshly written unit can be enabled/started.
    systemctl daemon-reload
}
|
|
||||||
|
|
||||||
# Setup nginx reverse proxy (optional)
|
|
||||||
#######################################
# Optionally generate an nginx reverse-proxy vhost for the gateway.
# Skips silently (warn only) when nginx is not installed.
# Globals:  APP_NAME (read)
# Exits:    non-zero if the resulting nginx configuration is invalid.
#######################################
setup_nginx_proxy() {
    if ! command -v nginx &> /dev/null; then
        log_warn "nginx not installed. Skipping reverse proxy setup."
        return
    fi

    log_info "Setting up nginx reverse proxy..."

    # \$ escapes keep nginx's own runtime variables ($host, $scheme, ...)
    # literal; unescaped $APP_NAME is expanded at deploy time.
    cat > "/etc/nginx/sites-available/$APP_NAME" << EOF
server {
    listen 80;
    server_name your-domain.com; # Change to your domain

    # Redirect to HTTPS (recommended)
    return 301 https://\$server_name\$request_uri;
}

server {
    listen 443 ssl http2;
    server_name your-domain.com; # Change to your domain

    # SSL certificates (adjust paths)
    ssl_certificate /etc/letsencrypt/live/your-domain.com/fullchain.pem;
    ssl_certificate_key /etc/letsencrypt/live/your-domain.com/privkey.pem;

    # SSL configuration
    ssl_protocols TLSv1.2 TLSv1.3;
    ssl_ciphers ECDHE-RSA-AES256-GCM-SHA512:DHE-RSA-AES256-GCM-SHA512:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384;
    ssl_prefer_server_ciphers off;

    # Proxy to LLM Proxy Gateway
    location / {
        proxy_pass http://127.0.0.1:8080;
        proxy_http_version 1.1;
        proxy_set_header Upgrade \$http_upgrade;
        proxy_set_header Connection "upgrade";
        proxy_set_header Host \$host;
        proxy_set_header X-Real-IP \$remote_addr;
        proxy_set_header X-Forwarded-For \$proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto \$scheme;

        # SSE/streaming support — buffered responses would stall token
        # streams (see the project's deployment.md nginx section).
        proxy_buffering off;

        # Timeouts
        proxy_connect_timeout 60s;
        proxy_send_timeout 60s;
        proxy_read_timeout 60s;
    }

    # Health check endpoint
    location /health {
        proxy_pass http://127.0.0.1:8080/health;
        access_log off;
    }

    # Dashboard
    location /dashboard {
        proxy_pass http://127.0.0.1:8080/dashboard;
    }
}
EOF

    # Enable site
    ln -sf "/etc/nginx/sites-available/$APP_NAME" "/etc/nginx/sites-enabled/"

    # Test nginx configuration and ABORT on failure: a broken vhost left
    # enabled would take down every site on the next nginx reload.
    if ! nginx -t; then
        log_error "nginx configuration test failed — review /etc/nginx/sites-available/$APP_NAME"
        exit 1
    fi

    log_info "nginx configuration created. Please update the domain and SSL certificate paths."
}
|
|
||||||
|
|
||||||
# Setup firewall
|
|
||||||
#######################################
# Open SSH/HTTP/HTTPS on whichever firewall front-end is present.
# Supports ufw (Debian/Ubuntu) and firewalld (RHEL/CentOS); both are
# configured if both happen to be installed.
#######################################
setup_firewall() {
    local tcp_port fw_service
    log_info "Configuring firewall..."

    # Check for ufw (Ubuntu)
    if command -v ufw > /dev/null 2>&1; then
        for tcp_port in 22 80 443; do   # SSH, HTTP, HTTPS
            ufw allow "${tcp_port}/tcp"
        done
        ufw --force enable
        log_info "UFW firewall configured"
    fi

    # Check for firewalld (RHEL/CentOS)
    if command -v firewall-cmd > /dev/null 2>&1; then
        for fw_service in ssh http https; do
            firewall-cmd --permanent --add-service="$fw_service"
        done
        firewall-cmd --reload
        log_info "Firewalld configured"
    fi
}
|
|
||||||
|
|
||||||
# Initialize database
|
|
||||||
#######################################
# Create the SQLite database by invoking the binary once as the service
# user. The --help invocation is best-effort; failure is ignored.
#######################################
initialize_database() {
    log_info "Initializing database..."

    # Run the application once (as $APP_USER so the db file gets the
    # right owner); discard output and never fail the deploy over it.
    sudo -u "$APP_USER" "$INSTALL_DIR/target/release/$APP_NAME" --help > /dev/null 2>&1 || true

    log_info "Database initialized at $DATA_DIR/llm_proxy.db"
}
|
|
||||||
|
|
||||||
# Start and enable service
|
|
||||||
#######################################
# Enable the unit at boot and start it now, then print its status.
# Globals:   APP_NAME (read)
# Outputs:   systemctl status to stdout
#######################################
start_service() {
    log_info "Starting $APP_NAME service..."

    systemctl enable "$APP_NAME"
    systemctl start "$APP_NAME"

    # Check status — brief pause gives the process time to come up (or
    # crash) so the status output below is meaningful.
    sleep 2
    systemctl status "$APP_NAME" --no-pager
}
|
|
||||||
|
|
||||||
# Verify installation
|
|
||||||
#######################################
# Verify the freshly started service: unit state, /health endpoint, and
# dashboard reachability.
# Globals:  APP_NAME (read)
# Exits:    non-zero when the unit is down or the health check fails;
#           an unreachable dashboard is only a warning.
#######################################
verify_installation() {
    local attempt
    log_info "Verifying installation..."

    # Check if service is running
    if systemctl is-active --quiet "$APP_NAME"; then
        log_info "Service is running"
    else
        log_error "Service is not running"
        journalctl -u "$APP_NAME" -n 20 --no-pager
        exit 1
    fi

    # Test health endpoint. Retry a few times: immediately after
    # systemctl start the process may not have bound port 8080 yet, and
    # a single-shot probe would report a false failure.
    for attempt in 1 2 3 4 5; do
        if curl -s http://localhost:8080/health | grep -q "OK"; then
            log_info "Health check passed"
            break
        fi
        if [[ $attempt -eq 5 ]]; then
            log_error "Health check failed"
            exit 1
        fi
        sleep 1
    done

    # Test dashboard (non-fatal: it may legitimately be unconfigured)
    if curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/dashboard | grep -q "200"; then
        log_info "Dashboard is accessible"
    else
        log_warn "Dashboard may not be accessible (this is normal if not configured)"
    fi

    log_info "Installation verified successfully!"
}
|
|
||||||
|
|
||||||
# Print next steps
|
|
||||||
#######################################
# Print the post-install checklist for the operator.
# Globals:   GREEN, YELLOW, NC (color codes), ENV_FILE, CONFIG_DIR,
#            APP_NAME, SERVICE_FILE, DATA_DIR, LOG_DIR, INSTALL_DIR (read)
# Outputs:   formatted instructions to stdout
# NOTE: unquoted EOF is intentional so the color variables and paths
# expand; the \\ sequences print literal line-continuation backslashes
# in the sample curl command.
#######################################
print_next_steps() {
    cat << EOF

${GREEN}=== LLM Proxy Gateway Installation Complete ===${NC}

${YELLOW}Next steps:${NC}

1. ${GREEN}Configure API keys${NC}
   Edit: $ENV_FILE
   Add your API keys for the providers you want to use

2. ${GREEN}Configure authentication${NC}
   Edit: $CONFIG_DIR/config.toml
   Uncomment and set auth_tokens for client authentication

3. ${GREEN}Configure nginx${NC}
   Edit: /etc/nginx/sites-available/$APP_NAME
   Update domain name and SSL certificate paths

4. ${GREEN}Test the API${NC}
   curl -X POST http://localhost:8080/v1/chat/completions \\
     -H "Content-Type: application/json" \\
     -H "Authorization: Bearer your-token" \\
     -d '{
       "model": "gpt-4o",
       "messages": [{"role": "user", "content": "Hello!"}]
     }'

5. ${GREEN}Access the dashboard${NC}
   Open: http://your-server-ip:8080/dashboard
   Or: https://your-domain.com/dashboard (if nginx configured)

${YELLOW}Useful commands:${NC}
  systemctl status $APP_NAME    # Check service status
  journalctl -u $APP_NAME -f    # View logs
  systemctl restart $APP_NAME   # Restart service

${YELLOW}Configuration files:${NC}
  Service: $SERVICE_FILE
  Config: $CONFIG_DIR/config.toml
  Environment: $ENV_FILE
  Database: $DATA_DIR/llm_proxy.db
  Logs: $LOG_DIR/

${GREEN}For more information, see:${NC}
  https://git.dustin.coffee/hobokenchicken/llm-proxy
  $INSTALL_DIR/README.md
  $INSTALL_DIR/deployment.md

EOF
}
|
|
||||||
|
|
||||||
# Main deployment function
|
|
||||||
#######################################
# Main deployment function: full install from scratch on a clean host.
# Each step exits non-zero on failure, so the sequence aborts at the
# first broken step. Order matters: the app must be built before
# configuration/service files reference its binary, and the database
# must exist before the service starts.
#######################################
deploy() {
    log_info "Starting LLM Proxy Gateway deployment..."

    check_root
    install_dependencies
    install_rust
    setup_directories
    build_application
    create_configuration
    create_systemd_service
    initialize_database
    start_service
    verify_installation
    print_next_steps

    # Optional steps (uncomment if needed)
    # setup_nginx_proxy
    # setup_firewall

    log_info "Deployment completed successfully!"
}
|
|
||||||
|
|
||||||
# Update function
|
|
||||||
#######################################
# Zero-surprise update: pull, rebuild, then restart only once the new
# binary is known to exist. The running service is never touched while
# the build can still fail.
#######################################
update() {
    log_info "Updating LLM Proxy Gateway..."

    check_root

    # Fetch the newest sources (the service keeps serving the old binary).
    cd "$INSTALL_DIR"
    log_info "Pulling latest changes..."
    git pull

    # Compile first, restart second: a failed build leaves the old
    # binary in place and the service untouched.
    log_info "Building release binary (service still running)..."
    cargo build --release || {
        log_error "Build failed — service was NOT interrupted. Fix the error and try again."
        exit 1
    }

    # Guard against a successful cargo run that still produced no
    # artifact under the expected name.
    [[ -f "target/release/$APP_NAME" ]] || {
        log_error "Binary not found after build — aborting."
        exit 1
    }

    # Restart service to pick up the freshly built binary.
    log_info "Build succeeded. Restarting service..."
    systemctl restart "$APP_NAME"

    sleep 2
    if ! systemctl is-active --quiet "$APP_NAME"; then
        log_error "Service failed to start after update. Check logs:"
        journalctl -u "$APP_NAME" -n 20 --no-pager
        exit 1
    fi

    log_info "Update completed successfully!"
    systemctl status "$APP_NAME" --no-pager
}
|
|
||||||
|
|
||||||
# Uninstall function
|
|
||||||
#######################################
# Remove the service, code, and configuration. Data and logs are
# deliberately preserved; removing the service user is interactive.
# Globals:  APP_NAME, SERVICE_FILE, INSTALL_DIR, CONFIG_DIR, DATA_DIR,
#           LOG_DIR, APP_USER, APP_GROUP (read)
#######################################
uninstall() {
    log_info "Uninstalling LLM Proxy Gateway..."

    check_root

    # Stop and disable service (best-effort: the unit may already be gone)
    systemctl stop "$APP_NAME" 2>/dev/null || true
    systemctl disable "$APP_NAME" 2>/dev/null || true
    rm -f "$SERVICE_FILE"
    systemctl daemon-reload

    # Remove application files.
    # ${VAR:?} aborts the script if the variable is unset or empty, so a
    # configuration bug can never turn this into `rm -rf ""`; `--` stops
    # option parsing in case a path ever starts with '-'.
    rm -rf -- "${INSTALL_DIR:?}"
    rm -rf -- "${CONFIG_DIR:?}"

    # Keep data and logs (comment out to remove)
    log_warn "Data directory $DATA_DIR and logs $LOG_DIR have been preserved"
    log_warn "Remove manually if desired:"
    log_warn " rm -rf $DATA_DIR $LOG_DIR"

    # Remove user (optional) — single-keypress confirmation, default No
    read -p "Remove user $APP_USER? [y/N]: " -n 1 -r
    echo
    if [[ $REPLY =~ ^[Yy]$ ]]; then
        userdel "$APP_USER" 2>/dev/null || true
        groupdel "$APP_GROUP" 2>/dev/null || true
    fi

    log_info "Uninstallation completed!"
}
|
|
||||||
|
|
||||||
# Show usage
|
|
||||||
#######################################
# Print command-line usage to stdout.
# Globals:   $0 (script name, expanded inside the heredoc)
#######################################
usage() {
    cat << EOF
LLM Proxy Gateway Deployment Script

Usage: $0 [command]

Commands:
  deploy     - Install and configure LLM Proxy Gateway
  update     - Pull latest changes, rebuild, and restart
  status     - Show service status and health check
  logs       - Tail the service logs (Ctrl+C to stop)
  uninstall  - Remove LLM Proxy Gateway
  help       - Show this help message

Examples:
  $0 deploy     # Full installation
  $0 update     # Update to latest version
  $0 status     # Check if service is healthy
  $0 logs       # Follow live logs

EOF
}
|
|
||||||
|
|
||||||
# Status function
|
|
||||||
#######################################
# Show a quick operational snapshot: unit state, /health probe, and the
# currently checked-out git revision. Never exits non-zero; everything
# is reported via log_info/log_warn.
#######################################
status() {
    echo ""
    log_info "Service status:"
    systemctl status "$APP_NAME" --no-pager 2>/dev/null || log_warn "Service not found"
    echo ""

    # Health probe: -f makes curl fail on HTTP error statuses too.
    if ! curl -sf http://localhost:8080/health > /dev/null 2>&1; then
        log_warn "Health check: FAILED (service may not be running or port 8080 not responding)"
    else
        log_info "Health check: OK"
    fi

    # Show the deployed commit when an installed checkout exists.
    if [[ -d "$INSTALL_DIR/.git" ]]; then
        echo ""
        log_info "Installed version:"
        git -C "$INSTALL_DIR" log -1 --format=" %h %s (%cr)" 2>/dev/null
    fi
}
|
|
||||||
|
|
||||||
# Logs function
|
|
||||||
#######################################
# Follow the service's journald output in real time.
# Blocks until interrupted (Ctrl+C).
#######################################
logs() {
    log_info "Tailing $APP_NAME logs (Ctrl+C to stop)..."
    journalctl -u "$APP_NAME" -f
}
|
|
||||||
|
|
||||||
# Parse command line arguments
|
|
||||||
# Dispatch on the first CLI argument. ${1:-} defaults to the empty
# string so `set -u`-style strictness cannot trip on a missing argument;
# an empty/unknown command falls through to the usage-and-fail arm.
case "${1:-}" in
    deploy)
        deploy
        ;;
    update)
        update
        ;;
    status)
        status
        ;;
    logs)
        logs
        ;;
    uninstall)
        uninstall
        ;;
    help|--help|-h)
        usage
        ;;
    *)
        # Unknown or missing command: print usage and exit non-zero so
        # callers/scripts can detect misuse.
        usage
        exit 1
        ;;
esac
|
|
||||||
336
deployment.md
336
deployment.md
@@ -1,322 +1,52 @@
|
|||||||
# LLM Proxy Gateway - Deployment Guide
|
# Deployment Guide (Go)
|
||||||
|
|
||||||
## Overview
|
This guide covers deploying the Go-based GopherGate.
|
||||||
A unified LLM proxy gateway supporting OpenAI, Google Gemini, DeepSeek, and xAI Grok with token tracking, cost calculation, and admin dashboard.
|
|
||||||
|
|
||||||
## System Requirements
|
## Environment Setup
|
||||||
- **CPU**: 2 cores minimum
|
|
||||||
- **RAM**: 512MB minimum (1GB recommended)
|
|
||||||
- **Storage**: 10GB minimum
|
|
||||||
- **OS**: Linux (tested on Arch Linux, Ubuntu, Debian)
|
|
||||||
- **Runtime**: Rust 1.70+ with Cargo
|
|
||||||
|
|
||||||
## Deployment Options
|
1. **Mandatory Configuration:**
|
||||||
|
Create a `.env` file from the example:
|
||||||
### Option 1: Docker (Recommended)
|
|
||||||
```dockerfile
|
|
||||||
FROM rust:1.70-alpine as builder
|
|
||||||
WORKDIR /app
|
|
||||||
COPY . .
|
|
||||||
RUN cargo build --release
|
|
||||||
|
|
||||||
FROM alpine:latest
|
|
||||||
RUN apk add --no-cache libgcc
|
|
||||||
COPY --from=builder /app/target/release/llm-proxy /usr/local/bin/
|
|
||||||
COPY --from=builder /app/static /app/static
|
|
||||||
WORKDIR /app
|
|
||||||
EXPOSE 8080
|
|
||||||
CMD ["llm-proxy"]
|
|
||||||
```
|
|
||||||
|
|
||||||
### Option 2: Systemd Service (Bare Metal/LXC)
|
|
||||||
```ini
|
|
||||||
# /etc/systemd/system/llm-proxy.service
|
|
||||||
[Unit]
|
|
||||||
Description=LLM Proxy Gateway
|
|
||||||
After=network.target
|
|
||||||
|
|
||||||
[Service]
|
|
||||||
Type=simple
|
|
||||||
User=llmproxy
|
|
||||||
Group=llmproxy
|
|
||||||
WorkingDirectory=/opt/llm-proxy
|
|
||||||
ExecStart=/opt/llm-proxy/llm-proxy
|
|
||||||
Restart=always
|
|
||||||
RestartSec=10
|
|
||||||
Environment="RUST_LOG=info"
|
|
||||||
Environment="LLM_PROXY__SERVER__PORT=8080"
|
|
||||||
Environment="LLM_PROXY__SERVER__AUTH_TOKENS=sk-test-123,sk-test-456"
|
|
||||||
|
|
||||||
[Install]
|
|
||||||
WantedBy=multi-user.target
|
|
||||||
```
|
|
||||||
|
|
||||||
### Option 3: LXC Container (Proxmox)
|
|
||||||
1. Create Alpine Linux LXC container
|
|
||||||
2. Install Rust: `apk add rust cargo`
|
|
||||||
3. Copy application files
|
|
||||||
4. Build: `cargo build --release`
|
|
||||||
5. Run: `./target/release/llm-proxy`
|
|
||||||
|
|
||||||
## Configuration
|
|
||||||
|
|
||||||
### Environment Variables
|
|
||||||
```bash
|
|
||||||
# Required API Keys
|
|
||||||
OPENAI_API_KEY=sk-...
|
|
||||||
GEMINI_API_KEY=AIza...
|
|
||||||
DEEPSEEK_API_KEY=sk-...
|
|
||||||
GROK_API_KEY=gk-... # Optional
|
|
||||||
|
|
||||||
# Server Configuration (with LLM_PROXY__ prefix)
|
|
||||||
LLM_PROXY__SERVER__PORT=8080
|
|
||||||
LLM_PROXY__SERVER__HOST=0.0.0.0
|
|
||||||
LLM_PROXY__SERVER__AUTH_TOKENS=sk-test-123,sk-test-456
|
|
||||||
|
|
||||||
# Database Configuration
|
|
||||||
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
|
|
||||||
LLM_PROXY__DATABASE__MAX_CONNECTIONS=10
|
|
||||||
|
|
||||||
# Provider Configuration
|
|
||||||
LLM_PROXY__PROVIDERS__OPENAI__ENABLED=true
|
|
||||||
LLM_PROXY__PROVIDERS__GEMINI__ENABLED=true
|
|
||||||
LLM_PROXY__PROVIDERS__DEEPSEEK__ENABLED=true
|
|
||||||
LLM_PROXY__PROVIDERS__GROK__ENABLED=false
|
|
||||||
```
|
|
||||||
|
|
||||||
### Configuration File (config.toml)
|
|
||||||
Create `config.toml` in the application directory:
|
|
||||||
```toml
|
|
||||||
[server]
|
|
||||||
port = 8080
|
|
||||||
host = "0.0.0.0"
|
|
||||||
auth_tokens = ["sk-test-123", "sk-test-456"]
|
|
||||||
|
|
||||||
[database]
|
|
||||||
path = "./data/llm_proxy.db"
|
|
||||||
max_connections = 10
|
|
||||||
|
|
||||||
[providers.openai]
|
|
||||||
enabled = true
|
|
||||||
base_url = "https://api.openai.com/v1"
|
|
||||||
default_model = "gpt-4o"
|
|
||||||
|
|
||||||
[providers.gemini]
|
|
||||||
enabled = true
|
|
||||||
base_url = "https://generativelanguage.googleapis.com/v1"
|
|
||||||
default_model = "gemini-2.0-flash"
|
|
||||||
|
|
||||||
[providers.deepseek]
|
|
||||||
enabled = true
|
|
||||||
base_url = "https://api.deepseek.com"
|
|
||||||
default_model = "deepseek-reasoner"
|
|
||||||
|
|
||||||
[providers.grok]
|
|
||||||
enabled = false
|
|
||||||
base_url = "https://api.x.ai/v1"
|
|
||||||
default_model = "grok-beta"
|
|
||||||
```
|
|
||||||
|
|
||||||
## Nginx Reverse Proxy Configuration
|
|
||||||
|
|
||||||
**Important for SSE/Streaming:** Disable buffering and configure timeouts for proper SSE support.
|
|
||||||
|
|
||||||
```nginx
|
|
||||||
server {
|
|
||||||
listen 80;
|
|
||||||
server_name llm-proxy.yourdomain.com;
|
|
||||||
|
|
||||||
location / {
|
|
||||||
proxy_pass http://localhost:8080;
|
|
||||||
proxy_http_version 1.1;
|
|
||||||
|
|
||||||
# SSE/Streaming support
|
|
||||||
proxy_buffering off;
|
|
||||||
chunked_transfer_encoding on;
|
|
||||||
proxy_set_header Connection '';
|
|
||||||
|
|
||||||
# Timeouts for long-running streams
|
|
||||||
proxy_connect_timeout 7200s;
|
|
||||||
proxy_read_timeout 7200s;
|
|
||||||
proxy_send_timeout 7200s;
|
|
||||||
|
|
||||||
# Disable gzip for streaming
|
|
||||||
gzip off;
|
|
||||||
|
|
||||||
# Headers
|
|
||||||
proxy_set_header Host $host;
|
|
||||||
proxy_set_header X-Real-IP $remote_addr;
|
|
||||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
|
||||||
proxy_set_header X-Forwarded-Proto $scheme;
|
|
||||||
}
|
|
||||||
|
|
||||||
# SSL configuration (recommended)
|
|
||||||
listen 443 ssl http2;
|
|
||||||
ssl_certificate /etc/letsencrypt/live/llm-proxy.yourdomain.com/fullchain.pem;
|
|
||||||
ssl_certificate_key /etc/letsencrypt/live/llm-proxy.yourdomain.com/privkey.pem;
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### NGINX Proxy Manager
|
|
||||||
|
|
||||||
If using NGINX Proxy Manager, add this to **Advanced Settings**:
|
|
||||||
|
|
||||||
```nginx
|
|
||||||
proxy_buffering off;
|
|
||||||
proxy_http_version 1.1;
|
|
||||||
proxy_set_header Connection '';
|
|
||||||
chunked_transfer_encoding on;
|
|
||||||
proxy_connect_timeout 7200s;
|
|
||||||
proxy_read_timeout 7200s;
|
|
||||||
proxy_send_timeout 7200s;
|
|
||||||
gzip off;
|
|
||||||
```
|
|
||||||
|
|
||||||
## Security Considerations
|
|
||||||
|
|
||||||
### 1. Authentication
|
|
||||||
- Use strong Bearer tokens
|
|
||||||
- Rotate tokens regularly
|
|
||||||
- Consider implementing JWT for production
|
|
||||||
|
|
||||||
### 2. Rate Limiting
|
|
||||||
- Implement per-client rate limiting
|
|
||||||
- Consider using `governor` crate for advanced rate limiting
|
|
||||||
|
|
||||||
### 3. Network Security
|
|
||||||
- Run behind reverse proxy (nginx)
|
|
||||||
- Enable HTTPS
|
|
||||||
- Restrict access by IP if needed
|
|
||||||
- Use firewall rules
|
|
||||||
|
|
||||||
### 4. Data Security
|
|
||||||
- Database encryption (SQLCipher for SQLite)
|
|
||||||
- Secure API key storage
|
|
||||||
- Regular backups
|
|
||||||
|
|
||||||
## Monitoring & Maintenance
|
|
||||||
|
|
||||||
### Logging
|
|
||||||
- Application logs: `RUST_LOG=info` (or `debug` for troubleshooting)
|
|
||||||
- Access logs via nginx
|
|
||||||
- Database logs for audit trail
|
|
||||||
|
|
||||||
### Health Checks
|
|
||||||
```bash
|
|
||||||
# Health endpoint
|
|
||||||
curl http://localhost:8080/health
|
|
||||||
|
|
||||||
# Database check
|
|
||||||
sqlite3 ./data/llm_proxy.db "SELECT COUNT(*) FROM llm_requests;"
|
|
||||||
```
|
|
||||||
|
|
||||||
### Backup Strategy
|
|
||||||
```bash
|
|
||||||
#!/bin/bash
|
|
||||||
# backup.sh
|
|
||||||
BACKUP_DIR="/backups/llm-proxy"
|
|
||||||
DATE=$(date +%Y%m%d_%H%M%S)
|
|
||||||
|
|
||||||
# Backup database
|
|
||||||
sqlite3 ./data/llm_proxy.db ".backup $BACKUP_DIR/llm_proxy_$DATE.db"
|
|
||||||
|
|
||||||
# Backup configuration
|
|
||||||
cp config.toml $BACKUP_DIR/config_$DATE.toml
|
|
||||||
|
|
||||||
# Rotate old backups (keep 30 days)
|
|
||||||
find $BACKUP_DIR -name "*.db" -mtime +30 -delete
|
|
||||||
find $BACKUP_DIR -name "*.toml" -mtime +30 -delete
|
|
||||||
```
|
|
||||||
|
|
||||||
## Performance Tuning
|
|
||||||
|
|
||||||
### Database Optimization
|
|
||||||
```sql
|
|
||||||
-- Run these SQL commands periodically
|
|
||||||
VACUUM;
|
|
||||||
ANALYZE;
|
|
||||||
```
|
|
||||||
|
|
||||||
### Memory Management
|
|
||||||
- Monitor memory usage with `htop` or `ps aux`
|
|
||||||
- Adjust `max_connections` based on load
|
|
||||||
- Consider connection pooling for high traffic
|
|
||||||
|
|
||||||
### Scaling
|
|
||||||
1. **Vertical Scaling**: Increase container resources
|
|
||||||
2. **Horizontal Scaling**: Deploy multiple instances behind load balancer
|
|
||||||
3. **Database**: Migrate to PostgreSQL for high-volume usage
|
|
||||||
|
|
||||||
## Troubleshooting
|
|
||||||
|
|
||||||
### Common Issues
|
|
||||||
|
|
||||||
1. **Port already in use**
|
|
||||||
```bash
|
```bash
|
||||||
netstat -tulpn | grep :8080
|
cp .env.example .env
|
||||||
kill <PID> # or change port in config
|
|
||||||
```
|
```
|
||||||
|
Ensure `LLM_PROXY__ENCRYPTION_KEY` is set to a secure 32-byte string.
|
||||||
|
|
||||||
2. **Database permissions**
|
2. **Data Directory:**
|
||||||
```bash
|
The proxy stores its database in `./data/llm_proxy.db` by default. Ensure this directory exists and is writable.
|
||||||
chown -R llmproxy:llmproxy /opt/llm-proxy/data
|
|
||||||
chmod 600 /opt/llm-proxy/data/llm_proxy.db
|
|
||||||
```
|
|
||||||
|
|
||||||
3. **API key errors**
|
## Binary Deployment
|
||||||
- Verify environment variables are set
|
|
||||||
- Check provider status (dashboard)
|
|
||||||
- Test connectivity: `curl https://api.openai.com/v1/models`
|
|
||||||
|
|
||||||
4. **High memory usage**
|
### 1. Build
|
||||||
- Check for memory leaks
|
|
||||||
- Reduce `max_connections`
|
|
||||||
- Implement connection timeouts
|
|
||||||
|
|
||||||
### Debug Mode
|
|
||||||
```bash
|
```bash
|
||||||
# Run with debug logging
|
go build -o gophergate ./cmd/gophergate
|
||||||
RUST_LOG=debug ./llm-proxy
|
|
||||||
|
|
||||||
# Check system logs
|
|
||||||
journalctl -u llm-proxy -f
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Integration
|
### 2. Run
|
||||||
|
```bash
|
||||||
### Open-WebUI Compatibility
|
./gophergate
|
||||||
The proxy provides OpenAI-compatible API, so configure Open-WebUI:
|
|
||||||
```
|
|
||||||
API Base URL: http://your-proxy-address:8080
|
|
||||||
API Key: sk-test-123 (or your configured token)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Custom Clients
|
## Docker Deployment
|
||||||
```python
|
|
||||||
import openai
|
|
||||||
|
|
||||||
client = openai.OpenAI(
|
The project includes a multi-stage `Dockerfile` for minimal image size.
|
||||||
base_url="http://localhost:8080/v1",
|
|
||||||
api_key="sk-test-123"
|
|
||||||
)
|
|
||||||
|
|
||||||
response = client.chat.completions.create(
|
### 1. Build Image
|
||||||
model="gpt-4",
|
```bash
|
||||||
messages=[{"role": "user", "content": "Hello"}]
|
docker build -t gophergate .
|
||||||
)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Updates & Upgrades
|
### 2. Run Container
|
||||||
|
```bash
|
||||||
|
docker run -d \
|
||||||
|
--name gophergate \
|
||||||
|
-p 8080:8080 \
|
||||||
|
-v $(pwd)/data:/app/data \
|
||||||
|
--env-file .env \
|
||||||
|
gophergate
|
||||||
|
```
|
||||||
|
|
||||||
1. **Backup** current configuration and database
|
## Production Considerations
|
||||||
2. **Stop** the service: `systemctl stop llm-proxy`
|
|
||||||
3. **Update** code: `git pull` or copy new binaries
|
|
||||||
4. **Migrate** database if needed (check migrations/)
|
|
||||||
5. **Restart**: `systemctl start llm-proxy`
|
|
||||||
6. **Verify**: Check logs and test endpoints
|
|
||||||
|
|
||||||
## Support
|
- **SSL/TLS:** It is recommended to run the proxy behind a reverse proxy like Nginx or Caddy for SSL termination.
|
||||||
- Check logs in `/var/log/llm-proxy/`
|
- **Backups:** Regularly backup the `data/llm_proxy.db` file.
|
||||||
- Monitor dashboard at `http://your-server:8080`
|
- **Monitoring:** Monitor the `/health` endpoint for system status.
|
||||||
- Review database metrics in dashboard
|
|
||||||
- Enable debug logging for troubleshooting
|
|
||||||
|
|||||||
70
go.mod
Normal file
70
go.mod
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
module gophergate
|
||||||
|
|
||||||
|
go 1.26.1
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/gin-gonic/gin v1.12.0
|
||||||
|
github.com/go-resty/resty/v2 v2.17.2
|
||||||
|
github.com/google/uuid v1.6.0
|
||||||
|
github.com/gorilla/websocket v1.5.3
|
||||||
|
github.com/jmoiron/sqlx v1.4.0
|
||||||
|
github.com/joho/godotenv v1.5.1
|
||||||
|
github.com/shirou/gopsutil/v3 v3.24.5
|
||||||
|
github.com/spf13/viper v1.21.0
|
||||||
|
golang.org/x/crypto v0.48.0
|
||||||
|
modernc.org/sqlite v1.47.0
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||||
|
github.com/bytedance/sonic v1.15.0 // indirect
|
||||||
|
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||||
|
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||||
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
|
||||||
|
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||||
|
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||||
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
|
github.com/go-playground/validator/v10 v10.30.1 // indirect
|
||||||
|
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||||
|
github.com/goccy/go-json v0.10.5 // indirect
|
||||||
|
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||||
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
|
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||||
|
github.com/leodido/go-urn v1.4.0 // indirect
|
||||||
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
|
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||||
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||||
|
github.com/quic-go/qpack v0.6.0 // indirect
|
||||||
|
github.com/quic-go/quic-go v0.59.0 // indirect
|
||||||
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
|
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
||||||
|
github.com/shoenig/go-m1cpu v0.1.6 // indirect
|
||||||
|
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
||||||
|
github.com/spf13/afero v1.15.0 // indirect
|
||||||
|
github.com/spf13/cast v1.10.0 // indirect
|
||||||
|
github.com/spf13/pflag v1.0.10 // indirect
|
||||||
|
github.com/subosito/gotenv v1.6.0 // indirect
|
||||||
|
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||||
|
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
|
github.com/ugorji/go/codec v1.3.1 // indirect
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||||
|
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
|
||||||
|
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||||
|
golang.org/x/arch v0.22.0 // indirect
|
||||||
|
golang.org/x/net v0.51.0 // indirect
|
||||||
|
golang.org/x/sys v0.42.0 // indirect
|
||||||
|
golang.org/x/text v0.34.0 // indirect
|
||||||
|
golang.org/x/time v0.15.0 // indirect
|
||||||
|
google.golang.org/protobuf v1.36.10 // indirect
|
||||||
|
modernc.org/libc v1.70.0 // indirect
|
||||||
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
|
modernc.org/memory v1.11.0 // indirect
|
||||||
|
)
|
||||||
207
go.sum
Normal file
207
go.sum
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||||
|
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||||
|
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
|
||||||
|
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
|
||||||
|
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
|
||||||
|
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
|
||||||
|
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
|
||||||
|
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
|
||||||
|
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||||
|
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||||
|
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||||
|
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||||
|
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||||
|
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
|
||||||
|
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
|
||||||
|
github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8=
|
||||||
|
github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc=
|
||||||
|
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||||
|
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||||
|
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||||
|
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||||
|
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||||
|
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||||
|
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||||
|
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||||
|
github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w=
|
||||||
|
github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM=
|
||||||
|
github.com/go-resty/resty/v2 v2.17.2 h1:FQW5oHYcIlkCNrMD2lloGScxcHJ0gkjshV3qcQAyHQk=
|
||||||
|
github.com/go-resty/resty/v2 v2.17.2/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
|
||||||
|
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
|
||||||
|
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
|
||||||
|
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
||||||
|
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||||
|
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||||
|
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||||
|
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
||||||
|
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||||
|
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||||
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||||
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
|
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||||
|
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||||
|
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
|
||||||
|
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
|
||||||
|
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||||
|
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||||
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||||
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
|
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||||
|
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||||
|
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||||
|
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||||
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||||
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||||
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
|
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||||
|
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||||
|
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||||
|
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||||
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||||
|
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
|
||||||
|
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
|
||||||
|
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
|
||||||
|
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
|
||||||
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
|
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||||
|
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||||
|
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
|
||||||
|
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
|
||||||
|
github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI=
|
||||||
|
github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk=
|
||||||
|
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
|
||||||
|
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
|
||||||
|
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
|
||||||
|
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
|
||||||
|
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
|
||||||
|
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
|
||||||
|
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||||
|
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
|
||||||
|
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||||
|
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
|
||||||
|
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
||||||
|
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
|
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
|
||||||
|
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||||
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||||
|
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||||
|
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||||
|
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||||
|
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||||
|
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
|
github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY=
|
||||||
|
github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||||
|
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
|
||||||
|
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
|
||||||
|
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||||
|
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||||
|
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||||
|
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||||
|
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
|
||||||
|
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||||
|
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||||
|
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||||
|
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
|
||||||
|
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
|
||||||
|
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
|
||||||
|
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
|
||||||
|
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||||
|
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
|
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||||
|
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||||
|
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
|
||||||
|
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
|
||||||
|
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||||
|
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||||
|
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
|
||||||
|
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
|
||||||
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||||
|
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
||||||
|
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||||
|
modernc.org/ccgo/v4 v4.32.0 h1:hjG66bI/kqIPX1b2yT6fr/jt+QedtP2fqojG2VrFuVw=
|
||||||
|
modernc.org/ccgo/v4 v4.32.0/go.mod h1:6F08EBCx5uQc38kMGl+0Nm0oWczoo1c7cgpzEry7Uc0=
|
||||||
|
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
|
||||||
|
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
|
||||||
|
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||||
|
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||||
|
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
|
||||||
|
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||||
|
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||||
|
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||||
|
modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw=
|
||||||
|
modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo=
|
||||||
|
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||||
|
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||||
|
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||||
|
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||||
|
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||||
|
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||||
|
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||||
|
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||||
|
modernc.org/sqlite v1.47.0 h1:R1XyaNpoW4Et9yly+I2EeX7pBza/w+pmYee/0HJDyKk=
|
||||||
|
modernc.org/sqlite v1.47.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
|
||||||
|
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||||
|
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||||
|
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||||
|
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||||
207
internal/config/config.go
Normal file
207
internal/config/config.go
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Server ServerConfig `mapstructure:"server"`
|
||||||
|
Database DatabaseConfig `mapstructure:"database"`
|
||||||
|
Providers ProviderConfig `mapstructure:"providers"`
|
||||||
|
EncryptionKey string `mapstructure:"encryption_key"`
|
||||||
|
KeyBytes []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerConfig struct {
|
||||||
|
Port int `mapstructure:"port"`
|
||||||
|
Host string `mapstructure:"host"`
|
||||||
|
AuthTokens []string `mapstructure:"auth_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DatabaseConfig struct {
|
||||||
|
Path string `mapstructure:"path"`
|
||||||
|
MaxConnections int `mapstructure:"max_connections"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProviderConfig struct {
|
||||||
|
OpenAI OpenAIConfig `mapstructure:"openai"`
|
||||||
|
Gemini GeminiConfig `mapstructure:"gemini"`
|
||||||
|
DeepSeek DeepSeekConfig `mapstructure:"deepseek"`
|
||||||
|
Moonshot MoonshotConfig `mapstructure:"moonshot"`
|
||||||
|
Grok GrokConfig `mapstructure:"grok"`
|
||||||
|
Ollama OllamaConfig `mapstructure:"ollama"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIConfig struct {
|
||||||
|
APIKeyEnv string `mapstructure:"api_key_env"`
|
||||||
|
BaseURL string `mapstructure:"base_url"`
|
||||||
|
DefaultModel string `mapstructure:"default_model"`
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiConfig struct {
|
||||||
|
APIKeyEnv string `mapstructure:"api_key_env"`
|
||||||
|
BaseURL string `mapstructure:"base_url"`
|
||||||
|
DefaultModel string `mapstructure:"default_model"`
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DeepSeekConfig struct {
|
||||||
|
APIKeyEnv string `mapstructure:"api_key_env"`
|
||||||
|
BaseURL string `mapstructure:"base_url"`
|
||||||
|
DefaultModel string `mapstructure:"default_model"`
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type MoonshotConfig struct {
|
||||||
|
APIKeyEnv string `mapstructure:"api_key_env"`
|
||||||
|
BaseURL string `mapstructure:"base_url"`
|
||||||
|
DefaultModel string `mapstructure:"default_model"`
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GrokConfig struct {
|
||||||
|
APIKeyEnv string `mapstructure:"api_key_env"`
|
||||||
|
BaseURL string `mapstructure:"base_url"`
|
||||||
|
DefaultModel string `mapstructure:"default_model"`
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OllamaConfig struct {
|
||||||
|
BaseURL string `mapstructure:"base_url"`
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
DefaultModel string `mapstructure:"default_model"`
|
||||||
|
Models []string `mapstructure:"models"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func Load() (*Config, error) {
|
||||||
|
v := viper.New()
|
||||||
|
|
||||||
|
// Defaults
|
||||||
|
v.SetDefault("server.port", 8080)
|
||||||
|
v.SetDefault("server.host", "0.0.0.0")
|
||||||
|
v.SetDefault("server.auth_tokens", []string{})
|
||||||
|
v.SetDefault("database.path", "./data/llm_proxy.db")
|
||||||
|
v.SetDefault("database.max_connections", 10)
|
||||||
|
|
||||||
|
v.SetDefault("providers.openai.api_key_env", "OPENAI_API_KEY")
|
||||||
|
v.SetDefault("providers.openai.base_url", "https://api.openai.com/v1")
|
||||||
|
v.SetDefault("providers.openai.default_model", "gpt-4o")
|
||||||
|
v.SetDefault("providers.openai.enabled", true)
|
||||||
|
|
||||||
|
v.SetDefault("providers.gemini.api_key_env", "GEMINI_API_KEY")
|
||||||
|
v.SetDefault("providers.gemini.base_url", "https://generativelanguage.googleapis.com/v1")
|
||||||
|
v.SetDefault("providers.gemini.default_model", "gemini-2.0-flash")
|
||||||
|
v.SetDefault("providers.gemini.enabled", true)
|
||||||
|
|
||||||
|
v.SetDefault("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")
|
||||||
|
v.SetDefault("providers.deepseek.base_url", "https://api.deepseek.com")
|
||||||
|
v.SetDefault("providers.deepseek.default_model", "deepseek-reasoner")
|
||||||
|
v.SetDefault("providers.deepseek.enabled", true)
|
||||||
|
|
||||||
|
v.SetDefault("providers.moonshot.api_key_env", "MOONSHOT_API_KEY")
|
||||||
|
v.SetDefault("providers.moonshot.base_url", "https://api.moonshot.ai/v1")
|
||||||
|
v.SetDefault("providers.moonshot.default_model", "kimi-k2.5")
|
||||||
|
v.SetDefault("providers.moonshot.enabled", true)
|
||||||
|
|
||||||
|
v.SetDefault("providers.grok.api_key_env", "GROK_API_KEY")
|
||||||
|
v.SetDefault("providers.grok.base_url", "https://api.x.ai/v1")
|
||||||
|
v.SetDefault("providers.grok.default_model", "grok-4-1-fast-non-reasoning")
|
||||||
|
v.SetDefault("providers.grok.enabled", true)
|
||||||
|
|
||||||
|
v.SetDefault("providers.ollama.base_url", "http://localhost:11434/v1")
|
||||||
|
v.SetDefault("providers.ollama.enabled", false)
|
||||||
|
v.SetDefault("providers.ollama.models", []string{})
|
||||||
|
|
||||||
|
// Environment variables
|
||||||
|
v.SetEnvPrefix("LLM_PROXY")
|
||||||
|
v.SetEnvKeyReplacer(strings.NewReplacer(".", "__"))
|
||||||
|
v.AutomaticEnv()
|
||||||
|
|
||||||
|
// Explicitly bind keys that might use double underscores in .env
|
||||||
|
v.BindEnv("encryption_key", "LLM_PROXY__ENCRYPTION_KEY")
|
||||||
|
v.BindEnv("server.port", "LLM_PROXY__SERVER__PORT")
|
||||||
|
v.BindEnv("server.host", "LLM_PROXY__SERVER__HOST")
|
||||||
|
|
||||||
|
// Config file
|
||||||
|
v.SetConfigName("config")
|
||||||
|
v.SetConfigType("toml")
|
||||||
|
v.AddConfigPath(".")
|
||||||
|
if envPath := os.Getenv("LLM_PROXY__CONFIG_PATH"); envPath != "" {
|
||||||
|
v.SetConfigFile(envPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := v.ReadInConfig(); err != nil {
|
||||||
|
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||||||
|
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var cfg Config
|
||||||
|
if err := v.Unmarshal(&cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Debug Config: port from viper=%d, host from viper=%s\n", cfg.Server.Port, cfg.Server.Host)
|
||||||
|
fmt.Printf("Debug Env: LLM_PROXY__SERVER__PORT=%s, LLM_PROXY__SERVER__HOST=%s\n", os.Getenv("LLM_PROXY__SERVER__PORT"), os.Getenv("LLM_PROXY__SERVER__HOST"))
|
||||||
|
|
||||||
|
// Manual overrides for nested keys which Viper doesn't always bind correctly with AutomaticEnv + SetEnvPrefix
|
||||||
|
if port := os.Getenv("LLM_PROXY__SERVER__PORT"); port != "" {
|
||||||
|
fmt.Sscanf(port, "%d", &cfg.Server.Port)
|
||||||
|
fmt.Printf("Overriding port to %d from env\n", cfg.Server.Port)
|
||||||
|
}
|
||||||
|
if host := os.Getenv("LLM_PROXY__SERVER__HOST"); host != "" {
|
||||||
|
cfg.Server.Host = host
|
||||||
|
fmt.Printf("Overriding host to %s from env\n", cfg.Server.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate encryption key
|
||||||
|
if cfg.EncryptionKey == "" {
|
||||||
|
return nil, fmt.Errorf("encryption key is required (LLM_PROXY__ENCRYPTION_KEY)")
|
||||||
|
}
|
||||||
|
|
||||||
|
keyBytes, err := hex.DecodeString(cfg.EncryptionKey)
|
||||||
|
if err != nil {
|
||||||
|
keyBytes, err = base64.StdEncoding.DecodeString(cfg.EncryptionKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("encryption key must be hex or base64 encoded")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(keyBytes) != 32 {
|
||||||
|
return nil, fmt.Errorf("encryption key must be 32 bytes, got %d", len(keyBytes))
|
||||||
|
}
|
||||||
|
cfg.KeyBytes = keyBytes
|
||||||
|
|
||||||
|
return &cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) GetAPIKey(provider string) (string, error) {
|
||||||
|
var envVar string
|
||||||
|
switch provider {
|
||||||
|
case "openai":
|
||||||
|
envVar = c.Providers.OpenAI.APIKeyEnv
|
||||||
|
case "gemini":
|
||||||
|
envVar = c.Providers.Gemini.APIKeyEnv
|
||||||
|
case "deepseek":
|
||||||
|
envVar = c.Providers.DeepSeek.APIKeyEnv
|
||||||
|
case "moonshot":
|
||||||
|
envVar = c.Providers.Moonshot.APIKeyEnv
|
||||||
|
case "grok":
|
||||||
|
envVar = c.Providers.Grok.APIKeyEnv
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("unknown provider: %s", provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
val := os.Getenv(envVar)
|
||||||
|
if val == "" {
|
||||||
|
return "", fmt.Errorf("environment variable %s not set for %s", envVar, provider)
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(val), nil
|
||||||
|
}
|
||||||
264
internal/db/db.go
Normal file
264
internal/db/db.go
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jmoiron/sqlx"
|
||||||
|
_ "modernc.org/sqlite"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DB struct {
|
||||||
|
*sqlx.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func Init(path string) (*DB, error) {
|
||||||
|
// Ensure directory exists
|
||||||
|
dir := filepath.Dir(path)
|
||||||
|
if dir != "." && dir != "/" {
|
||||||
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create database directory: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to SQLite
|
||||||
|
dsn := fmt.Sprintf("file:%s?_pragma=foreign_keys(1)", path)
|
||||||
|
db, err := sqlx.Connect("sqlite", dsn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
instance := &DB{db}
|
||||||
|
|
||||||
|
// Run migrations
|
||||||
|
if err := instance.RunMigrations(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to run migrations: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return instance, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunMigrations creates all tables and indices (idempotently, via
// IF NOT EXISTS) and seeds two rows on first run: a default admin user
// and a default client. Safe to call on every startup.
func (db *DB) RunMigrations() error {
	// Tables creation
	queries := []string{
		`CREATE TABLE IF NOT EXISTS clients (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			client_id TEXT UNIQUE NOT NULL,
			name TEXT,
			description TEXT,
			created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
			updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
			is_active BOOLEAN DEFAULT TRUE,
			rate_limit_per_minute INTEGER DEFAULT 60,
			total_requests INTEGER DEFAULT 0,
			total_tokens INTEGER DEFAULT 0,
			total_cost REAL DEFAULT 0.0
		)`,
		`CREATE TABLE IF NOT EXISTS llm_requests (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
			client_id TEXT,
			provider TEXT,
			model TEXT,
			prompt_tokens INTEGER,
			completion_tokens INTEGER,
			reasoning_tokens INTEGER DEFAULT 0,
			total_tokens INTEGER,
			cost REAL,
			has_images BOOLEAN DEFAULT FALSE,
			status TEXT DEFAULT 'success',
			error_message TEXT,
			duration_ms INTEGER,
			request_body TEXT,
			response_body TEXT,
			cache_read_tokens INTEGER DEFAULT 0,
			cache_write_tokens INTEGER DEFAULT 0,
			FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE SET NULL
		)`,
		`CREATE TABLE IF NOT EXISTS provider_configs (
			id TEXT PRIMARY KEY,
			display_name TEXT NOT NULL,
			enabled BOOLEAN DEFAULT TRUE,
			base_url TEXT,
			api_key TEXT,
			credit_balance REAL DEFAULT 0.0,
			low_credit_threshold REAL DEFAULT 5.0,
			billing_mode TEXT,
			api_key_encrypted BOOLEAN DEFAULT FALSE,
			updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
		)`,
		`CREATE TABLE IF NOT EXISTS model_configs (
			id TEXT PRIMARY KEY,
			provider_id TEXT NOT NULL,
			display_name TEXT,
			enabled BOOLEAN DEFAULT TRUE,
			prompt_cost_per_m REAL,
			completion_cost_per_m REAL,
			cache_read_cost_per_m REAL,
			cache_write_cost_per_m REAL,
			mapping TEXT,
			updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
			FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE
		)`,
		`CREATE TABLE IF NOT EXISTS users (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			username TEXT UNIQUE NOT NULL,
			password_hash TEXT NOT NULL,
			display_name TEXT,
			role TEXT DEFAULT 'admin',
			must_change_password BOOLEAN DEFAULT FALSE,
			created_at DATETIME DEFAULT CURRENT_TIMESTAMP
		)`,
		`CREATE TABLE IF NOT EXISTS client_tokens (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			client_id TEXT NOT NULL,
			token TEXT NOT NULL UNIQUE,
			name TEXT DEFAULT 'default',
			is_active BOOLEAN DEFAULT TRUE,
			created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
			last_used_at DATETIME,
			FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE CASCADE
		)`,
	}

	for _, q := range queries {
		if _, err := db.Exec(q); err != nil {
			return fmt.Errorf("migration failed for query [%s]: %w", q, err)
		}
	}

	// Add indices — covers the dashboard's hot lookup paths (per-client
	// and per-provider time-range scans on llm_requests).
	indices := []string{
		"CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)",
		"CREATE INDEX IF NOT EXISTS idx_clients_created_at ON clients(created_at)",
		"CREATE INDEX IF NOT EXISTS idx_llm_requests_timestamp ON llm_requests(timestamp)",
		"CREATE INDEX IF NOT EXISTS idx_llm_requests_client_id ON llm_requests(client_id)",
		"CREATE INDEX IF NOT EXISTS idx_llm_requests_provider ON llm_requests(provider)",
		"CREATE INDEX IF NOT EXISTS idx_llm_requests_status ON llm_requests(status)",
		"CREATE UNIQUE INDEX IF NOT EXISTS idx_client_tokens_token ON client_tokens(token)",
		"CREATE INDEX IF NOT EXISTS idx_client_tokens_client_id ON client_tokens(client_id)",
		"CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp)",
		"CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp)",
		"CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id)",
	}

	for _, idx := range indices {
		if _, err := db.Exec(idx); err != nil {
			return fmt.Errorf("failed to create index [%s]: %w", idx, err)
		}
	}

	// Default admin user — seeded only when the users table is empty.
	// NOTE(review): the well-known bootstrap password is mitigated by
	// must_change_password=1, but operators should rotate it immediately.
	var count int
	if err := db.Get(&count, "SELECT COUNT(*) FROM users"); err != nil {
		return fmt.Errorf("failed to count users: %w", err)
	}

	if count == 0 {
		// bcrypt cost 12 (above the library default of 10).
		hash, err := bcrypt.GenerateFromPassword([]byte("admin123"), 12)
		if err != nil {
			return fmt.Errorf("failed to hash default password: %w", err)
		}
		_, err = db.Exec("INSERT INTO users (username, password_hash, role, must_change_password) VALUES ('admin', ?, 'admin', 1)", string(hash))
		if err != nil {
			return fmt.Errorf("failed to insert default admin: %w", err)
		}
		log.Println("Created default admin user with password 'admin123' (must change on first login)")
	}

	// Default client — INSERT OR IGNORE makes this idempotent across restarts.
	_, err := db.Exec(`INSERT OR IGNORE INTO clients (client_id, name, description)
		VALUES ('default', 'Default Client', 'Default client for anonymous requests')`)
	if err != nil {
		return fmt.Errorf("failed to insert default client: %w", err)
	}

	return nil
}
|
||||||
|
|
||||||
|
// Data models for DB tables. Pointer fields map to nullable columns;
// value fields map to NOT NULL / defaulted columns.

// Client mirrors one row of the clients table: an API consumer with
// aggregate usage counters and a per-minute rate limit.
type Client struct {
	ID                 int       `db:"id"`
	ClientID           string    `db:"client_id"`
	Name               *string   `db:"name"`
	Description        *string   `db:"description"`
	CreatedAt          time.Time `db:"created_at"`
	UpdatedAt          time.Time `db:"updated_at"`
	IsActive           bool      `db:"is_active"`
	RateLimitPerMinute int       `db:"rate_limit_per_minute"`
	TotalRequests      int       `db:"total_requests"`
	TotalTokens        int       `db:"total_tokens"`
	TotalCost          float64   `db:"total_cost"`
}

// LLMRequest mirrors one row of the llm_requests audit table: a single
// upstream call with token/cost accounting and the raw request/response
// bodies for debugging.
type LLMRequest struct {
	ID               int        `db:"id"`
	Timestamp        time.Time  `db:"timestamp"`
	ClientID         *string    `db:"client_id"`
	Provider         *string    `db:"provider"`
	Model            *string    `db:"model"`
	PromptTokens     *int       `db:"prompt_tokens"`
	CompletionTokens *int       `db:"completion_tokens"`
	ReasoningTokens  int        `db:"reasoning_tokens"`
	TotalTokens      *int       `db:"total_tokens"`
	Cost             *float64   `db:"cost"`
	HasImages        bool       `db:"has_images"`
	Status           string     `db:"status"`
	ErrorMessage     *string    `db:"error_message"`
	DurationMS       *int       `db:"duration_ms"`
	RequestBody      *string    `db:"request_body"`
	ResponseBody     *string    `db:"response_body"`
	CacheReadTokens  int        `db:"cache_read_tokens"`
	CacheWriteTokens int        `db:"cache_write_tokens"`
}

// ProviderConfig mirrors one row of provider_configs: per-provider
// settings, stored API key, and credit-balance tracking.
type ProviderConfig struct {
	ID                 string    `db:"id"`
	DisplayName        string    `db:"display_name"`
	Enabled            bool      `db:"enabled"`
	BaseURL            *string   `db:"base_url"`
	APIKey             *string   `db:"api_key"`
	CreditBalance      float64   `db:"credit_balance"`
	LowCreditThreshold float64   `db:"low_credit_threshold"`
	BillingMode        *string   `db:"billing_mode"`
	APIKeyEncrypted    bool      `db:"api_key_encrypted"`
	UpdatedAt          time.Time `db:"updated_at"`
}

// ModelConfig mirrors one row of model_configs: per-model pricing
// (USD per million tokens, judging by the *_per_m column names —
// TODO confirm) and an optional model-name mapping.
type ModelConfig struct {
	ID                 string     `db:"id"`
	ProviderID         string     `db:"provider_id"`
	DisplayName        *string    `db:"display_name"`
	Enabled            bool       `db:"enabled"`
	PromptCostPerM     *float64   `db:"prompt_cost_per_m"`
	CompletionCostPerM *float64   `db:"completion_cost_per_m"`
	CacheReadCostPerM  *float64   `db:"cache_read_cost_per_m"`
	CacheWriteCostPerM *float64   `db:"cache_write_cost_per_m"`
	Mapping            *string    `db:"mapping"`
	UpdatedAt          time.Time  `db:"updated_at"`
}

// User mirrors one row of the users table. PasswordHash is excluded
// from JSON output via the "-" tag.
type User struct {
	ID                 int       `db:"id" json:"id"`
	Username           string    `db:"username" json:"username"`
	PasswordHash       string    `db:"password_hash" json:"-"`
	DisplayName        *string   `db:"display_name" json:"display_name"`
	Role               string    `db:"role" json:"role"`
	MustChangePassword bool      `db:"must_change_password" json:"must_change_password"`
	CreatedAt          time.Time `db:"created_at" json:"created_at"`
}

// ClientToken mirrors one row of client_tokens: a bearer token bound to
// a client, with last-use tracking.
type ClientToken struct {
	ID         int        `db:"id"`
	ClientID   string     `db:"client_id"`
	Token      string     `db:"token"`
	Name       string     `db:"name"`
	IsActive   bool       `db:"is_active"`
	CreatedAt  time.Time  `db:"created_at"`
	LastUsedAt *time.Time `db:"last_used_at"`
}
|
||||||
52
internal/middleware/auth.go
Normal file
52
internal/middleware/auth.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gophergate/internal/db"
|
||||||
|
"gophergate/internal/models"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthMiddleware resolves a bearer token from the Authorization header
// into an AuthInfo stored in the gin context under key "auth".
//
// Behavior:
//   - No Authorization header, or a header without the "Bearer " prefix:
//     the request passes through unauthenticated (no "auth" key set).
//   - Token found in client_tokens (and active): its last_used_at is
//     touched and the mapped client_id is attached.
//   - Any DB miss OR DB error: falls back to a synthetic client ID
//     derived from the first 8 characters of the token. This is
//     intentional ("matches Rust behavior"), but note it also masks real
//     database failures — NOTE(review): consider distinguishing
//     sql.ErrNoRows from other errors.
//
// The middleware never rejects a request; it only annotates it.
func AuthMiddleware(database *db.DB) gin.HandlerFunc {
	return func(c *gin.Context) {
		authHeader := c.GetHeader("Authorization")
		if authHeader == "" {
			c.Next()
			return
		}

		token := strings.TrimPrefix(authHeader, "Bearer ")
		if token == authHeader { // No "Bearer " prefix
			c.Next()
			return
		}

		// Try to resolve client from database.
		// Single round-trip: touch last_used_at and fetch client_id in one
		// statement. UPDATE ... RETURNING requires SQLite 3.35+ — TODO
		// confirm the bundled driver version satisfies this.
		var clientID string
		err := database.Get(&clientID, "UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = 1 RETURNING client_id", token)

		if err == nil {
			c.Set("auth", models.AuthInfo{
				Token:    token,
				ClientID: clientID,
			})
		} else {
			// Fallback to token-prefix derivation (matches Rust behavior)
			prefixLen := len(token)
			if prefixLen > 8 {
				prefixLen = 8
			}
			clientID = "client_" + token[:prefixLen]
			c.Set("auth", models.AuthInfo{
				Token:    token,
				ClientID: clientID,
			})
			log.Printf("Token not found in DB, using fallback client ID: %s", clientID)
		}

		c.Next()
	}
}
|
||||||
216
internal/models/models.go
Normal file
216
internal/models/models.go
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/go-resty/resty/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenAI-compatible Request/Response Structs
//
// These mirror the OpenAI Chat Completions wire format so the gateway
// can accept requests from any OpenAI-style client.

// ChatCompletionRequest is the inbound /chat/completions payload.
// json.RawMessage fields (Stop, ToolChoice) are kept raw because the
// upstream spec allows multiple JSON shapes (string or array, etc.).
type ChatCompletionRequest struct {
	Model            string          `json:"model"`
	Messages         []ChatMessage   `json:"messages"`
	Temperature      *float64        `json:"temperature,omitempty"`
	TopP             *float64        `json:"top_p,omitempty"`
	TopK             *uint32         `json:"top_k,omitempty"`
	N                *uint32         `json:"n,omitempty"`
	Stop             json.RawMessage `json:"stop,omitempty"` // Can be string or array of strings
	MaxTokens        *uint32         `json:"max_tokens,omitempty"`
	PresencePenalty  *float64        `json:"presence_penalty,omitempty"`
	FrequencyPenalty *float64        `json:"frequency_penalty,omitempty"`
	Stream           *bool           `json:"stream,omitempty"`
	Tools            []Tool          `json:"tools,omitempty"`
	ToolChoice       json.RawMessage `json:"tool_choice,omitempty"`
}

// ChatMessage is one conversation turn. Content is interface{} because
// the wire format allows either a plain string or an array of content
// parts (see ContentPart).
type ChatMessage struct {
	Role             string      `json:"role"` // "system", "user", "assistant", "tool"
	Content          interface{} `json:"content"`
	ReasoningContent *string     `json:"reasoning_content,omitempty"`
	ToolCalls        []ToolCall  `json:"tool_calls,omitempty"`
	Name             *string     `json:"name,omitempty"`
	ToolCallID       *string     `json:"tool_call_id,omitempty"`
}

// ContentPart is one element of a multimodal message content array.
type ContentPart struct {
	Type     string    `json:"type"`
	Text     string    `json:"text,omitempty"`
	ImageUrl *ImageUrl `json:"image_url,omitempty"`
}

// ImageUrl references an image by URL (or data: URI), with an optional
// detail hint.
type ImageUrl struct {
	URL    string  `json:"url"`
	Detail *string `json:"detail,omitempty"`
}

// Tool-Calling Types

// Tool declares a callable tool offered to the model (currently only
// type "function" in the OpenAI spec).
type Tool struct {
	Type     string      `json:"type"`
	Function FunctionDef `json:"function"`
}

// FunctionDef describes a function's name and JSON-schema parameters.
type FunctionDef struct {
	Name        string          `json:"name"`
	Description *string         `json:"description,omitempty"`
	Parameters  json.RawMessage `json:"parameters,omitempty"`
}

// ToolCall is a completed tool invocation emitted by the model.
type ToolCall struct {
	ID       string       `json:"id"`
	Type     string       `json:"type"`
	Function FunctionCall `json:"function"`
}

// FunctionCall carries the function name and its arguments as a raw
// JSON-encoded string (per the OpenAI wire format).
type FunctionCall struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}

// ToolCallDelta is the streaming (incremental) form of ToolCall; fields
// arrive piecemeal across chunks and are keyed by Index.
type ToolCallDelta struct {
	Index    uint32             `json:"index"`
	ID       *string            `json:"id,omitempty"`
	Type     *string            `json:"type,omitempty"`
	Function *FunctionCallDelta `json:"function,omitempty"`
}

// FunctionCallDelta is the streaming form of FunctionCall.
type FunctionCallDelta struct {
	Name      *string `json:"name,omitempty"`
	Arguments *string `json:"arguments,omitempty"`
}
|
||||||
|
|
||||||
|
// OpenAI-compatible Response Structs

// ChatCompletionResponse is the non-streaming /chat/completions reply.
type ChatCompletionResponse struct {
	ID      string       `json:"id"`
	Object  string       `json:"object"`
	Created int64        `json:"created"`
	Model   string       `json:"model"`
	Choices []ChatChoice `json:"choices"`
	Usage   *Usage       `json:"usage,omitempty"`
}

// ChatChoice is one generated alternative in a response.
type ChatChoice struct {
	Index        uint32      `json:"index"`
	Message      ChatMessage `json:"message"`
	FinishReason *string     `json:"finish_reason,omitempty"`
}

// Usage is the gateway's unified token accounting. The optional fields
// normalize provider-specific extras (reasoning tokens, prompt-cache
// reads/writes) into one shape.
type Usage struct {
	PromptTokens     uint32  `json:"prompt_tokens"`
	CompletionTokens uint32  `json:"completion_tokens"`
	TotalTokens      uint32  `json:"total_tokens"`
	ReasoningTokens  *uint32 `json:"reasoning_tokens,omitempty"`
	CacheReadTokens  *uint32 `json:"cache_read_tokens,omitempty"`
	CacheWriteTokens *uint32 `json:"cache_write_tokens,omitempty"`
}

// Streaming Response Structs

// ChatCompletionStreamResponse is one SSE chunk of a streamed reply.
type ChatCompletionStreamResponse struct {
	ID      string             `json:"id"`
	Object  string             `json:"object"`
	Created int64              `json:"created"`
	Model   string             `json:"model"`
	Choices []ChatStreamChoice `json:"choices"`
	Usage   *Usage             `json:"usage,omitempty"`
}

// ChatStreamChoice is one alternative within a streamed chunk.
type ChatStreamChoice struct {
	Index        uint32          `json:"index"`
	Delta        ChatStreamDelta `json:"delta"`
	FinishReason *string         `json:"finish_reason,omitempty"`
}

// ChatStreamDelta carries the incremental piece of the message in this
// chunk; all fields are optional.
type ChatStreamDelta struct {
	Role             *string         `json:"role,omitempty"`
	Content          *string         `json:"content,omitempty"`
	ReasoningContent *string         `json:"reasoning_content,omitempty"`
	ToolCalls        []ToolCallDelta `json:"tool_calls,omitempty"`
}

// StreamUsage is a non-pointer variant of Usage used where all counters
// are always emitted (zero when absent).
type StreamUsage struct {
	PromptTokens     uint32 `json:"prompt_tokens"`
	CompletionTokens uint32 `json:"completion_tokens"`
	TotalTokens      uint32 `json:"total_tokens"`
	ReasoningTokens  uint32 `json:"reasoning_tokens"`
	CacheReadTokens  uint32 `json:"cache_read_tokens"`
	CacheWriteTokens uint32 `json:"cache_write_tokens"`
}
|
||||||
|
|
||||||
|
// Unified Request Format (for internal use)
//
// The gateway normalizes every inbound request into these structs
// before handing it to a provider adapter, so each adapter translates
// from exactly one canonical shape.

// UnifiedRequest is the provider-agnostic request. Unlike the wire
// struct, Stop is always a string slice and Stream a plain bool.
type UnifiedRequest struct {
	ClientID         string
	Model            string
	Messages         []UnifiedMessage
	Temperature      *float64
	TopP             *float64
	TopK             *uint32
	N                *uint32
	Stop             []string
	MaxTokens        *uint32
	PresencePenalty  *float64
	FrequencyPenalty *float64
	Stream           bool
	HasImages        bool
	Tools            []Tool
	ToolChoice       json.RawMessage
}

// UnifiedMessage is one normalized turn; Content is always a parts
// slice (plain-string content is wrapped into a single text part
// upstream — presumably; confirm against the request parser).
type UnifiedMessage struct {
	Role             string
	Content          []UnifiedContentPart
	ReasoningContent *string
	ToolCalls        []ToolCall
	Name             *string
	ToolCallID       *string
}

// UnifiedContentPart is either a text part or an image part.
type UnifiedContentPart struct {
	Type  string
	Text  string
	Image *ImageInput
}
|
||||||
|
|
||||||
|
type ImageInput struct {
|
||||||
|
Base64 string `json:"base64,omitempty"`
|
||||||
|
URL string `json:"url,omitempty"`
|
||||||
|
MimeType string `json:"mime_type,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *ImageInput) ToBase64() (string, string, error) {
|
||||||
|
if i.Base64 != "" {
|
||||||
|
return i.Base64, i.MimeType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if i.URL != "" {
|
||||||
|
client := resty.New()
|
||||||
|
resp, err := client.R().Get(i.URL)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("failed to fetch image: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !resp.IsSuccess() {
|
||||||
|
return "", "", fmt.Errorf("failed to fetch image: HTTP %d", resp.StatusCode())
|
||||||
|
}
|
||||||
|
|
||||||
|
mimeType := resp.Header().Get("Content-Type")
|
||||||
|
if mimeType == "" {
|
||||||
|
mimeType = "image/jpeg"
|
||||||
|
}
|
||||||
|
|
||||||
|
encoded := base64.StdEncoding.EncodeToString(resp.Body())
|
||||||
|
return encoded, mimeType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", "", fmt.Errorf("empty image input")
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthInfo for context
//
// AuthInfo is the resolved identity attached to the gin context by the
// auth middleware: the raw bearer token and the client it maps to.
type AuthInfo struct {
	Token    string
	ClientID string
}
|
||||||
69
internal/models/registry.go
Normal file
69
internal/models/registry.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// ModelRegistry is an in-memory catalog of providers and their model
// metadata, keyed by provider ID.
type ModelRegistry struct {
	Providers map[string]ProviderInfo `json:"-"`
}

// ProviderInfo describes one provider and its models, keyed by model ID.
type ProviderInfo struct {
	ID     string                   `json:"id"`
	Name   string                   `json:"name"`
	Models map[string]ModelMetadata `json:"models"`
}

// ModelMetadata is the catalog entry for a single model: pricing,
// limits, modalities, and capability flags. Pointer fields are optional
// in the source data.
type ModelMetadata struct {
	ID         string           `json:"id"`
	Name       string           `json:"name"`
	Cost       *ModelCost       `json:"cost,omitempty"`
	Limit      *ModelLimit      `json:"limit,omitempty"`
	Modalities *ModelModalities `json:"modalities,omitempty"`
	ToolCall   *bool            `json:"tool_call,omitempty"`
	Reasoning  *bool            `json:"reasoning,omitempty"`
}

// ModelCost holds per-token pricing (units per the upstream catalog —
// presumably USD per million tokens; confirm against the data source).
type ModelCost struct {
	Input      float64  `json:"input"`
	Output     float64  `json:"output"`
	CacheRead  *float64 `json:"cache_read,omitempty"`
	CacheWrite *float64 `json:"cache_write,omitempty"`
}

// ModelLimit holds the model's context window and max output tokens.
type ModelLimit struct {
	Context uint32 `json:"context"`
	Output  uint32 `json:"output"`
}

// ModelModalities lists supported input/output modalities (e.g. "text",
// "image").
type ModelModalities struct {
	Input  []string `json:"input"`
	Output []string `json:"output"`
}
|
||||||
|
|
||||||
|
func (r *ModelRegistry) FindModel(modelID string) *ModelMetadata {
|
||||||
|
// First try exact match in models map
|
||||||
|
for _, provider := range r.Providers {
|
||||||
|
if model, ok := provider.Models[modelID]; ok {
|
||||||
|
return &model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try searching by ID in metadata
|
||||||
|
for _, provider := range r.Providers {
|
||||||
|
for _, model := range provider.Models {
|
||||||
|
if model.ID == modelID {
|
||||||
|
return &model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try fuzzy matching (e.g. gpt-4o-2024-05-13 matching gpt-4o)
|
||||||
|
for _, provider := range r.Providers {
|
||||||
|
for id, model := range provider.Models {
|
||||||
|
if strings.HasPrefix(modelID, id) {
|
||||||
|
return &model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
220
internal/providers/deepseek.go
Normal file
220
internal/providers/deepseek.go
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gophergate/internal/config"
|
||||||
|
"gophergate/internal/models"
|
||||||
|
"github.com/go-resty/resty/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeepSeekProvider adapts the gateway's unified request format to the
// DeepSeek Chat Completions API (OpenAI-compatible wire format).
type DeepSeekProvider struct {
	client *resty.Client       // shared HTTP client for all calls
	config config.DeepSeekConfig // base URL and related settings
	apiKey string              // bearer token sent on every request
}

// NewDeepSeekProvider builds a provider with a fresh resty client.
func NewDeepSeekProvider(cfg config.DeepSeekConfig, apiKey string) *DeepSeekProvider {
	return &DeepSeekProvider{
		client: resty.New(),
		config: cfg,
		apiKey: apiKey,
	}
}

// Name returns the provider's registry identifier.
func (p *DeepSeekProvider) Name() string {
	return "deepseek"
}
|
||||||
|
|
||||||
|
// deepSeekUsage mirrors DeepSeek's usage object, which extends the
// OpenAI shape with prompt-cache hit/miss counters and nested
// reasoning-token details.
type deepSeekUsage struct {
	PromptTokens          uint32 `json:"prompt_tokens"`
	CompletionTokens      uint32 `json:"completion_tokens"`
	TotalTokens           uint32 `json:"total_tokens"`
	PromptCacheHitTokens  uint32 `json:"prompt_cache_hit_tokens"`
	PromptCacheMissTokens uint32 `json:"prompt_cache_miss_tokens"`
	CompletionTokensDetails *struct {
		ReasoningTokens uint32 `json:"reasoning_tokens"`
	} `json:"completion_tokens_details"`
}

// ToUnified converts DeepSeek's usage into the gateway's Usage struct:
// cache hits → CacheReadTokens, cache misses → CacheWriteTokens,
// nested reasoning tokens → ReasoningTokens. Zero-valued extras are
// left nil so they are omitted from JSON output.
func (u *deepSeekUsage) ToUnified() *models.Usage {
	usage := &models.Usage{
		PromptTokens:     u.PromptTokens,
		CompletionTokens: u.CompletionTokens,
		TotalTokens:      u.TotalTokens,
	}
	if u.PromptCacheHitTokens > 0 {
		usage.CacheReadTokens = &u.PromptCacheHitTokens
	}
	if u.PromptCacheMissTokens > 0 {
		usage.CacheWriteTokens = &u.PromptCacheMissTokens
	}
	if u.CompletionTokensDetails != nil && u.CompletionTokensDetails.ReasoningTokens > 0 {
		usage.ReasoningTokens = &u.CompletionTokensDetails.ReasoningTokens
	}
	return usage
}
|
||||||
|
|
||||||
|
// ChatCompletion performs a non-streaming request against DeepSeek's
// OpenAI-compatible /chat/completions endpoint and returns the parsed
// response with DeepSeek-specific usage fields normalized.
func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
	messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
	if err != nil {
		return nil, fmt.Errorf("failed to convert messages: %w", err)
	}

	body := BuildOpenAIBody(req, messagesJSON, false)

	// Sanitize for deepseek-reasoner: the reasoner model rejects sampling
	// parameters, and assistant turns must carry a reasoning_content field
	// and a non-nil content.
	if req.Model == "deepseek-reasoner" {
		delete(body, "temperature")
		delete(body, "top_p")
		delete(body, "presence_penalty")
		delete(body, "frequency_penalty")

		if msgs, ok := body["messages"].([]interface{}); ok {
			for _, m := range msgs {
				if msg, ok := m.(map[string]interface{}); ok {
					if msg["role"] == "assistant" {
						if msg["reasoning_content"] == nil {
							// A single space placeholder; empty string is
							// presumably rejected upstream — TODO confirm.
							msg["reasoning_content"] = " "
						}
						if msg["content"] == nil || msg["content"] == "" {
							msg["content"] = ""
						}
					}
				}
			}
		}
	}

	resp, err := p.client.R().
		SetContext(ctx).
		SetHeader("Authorization", "Bearer "+p.apiKey).
		SetBody(body).
		Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))

	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}

	if !resp.IsSuccess() {
		return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
	}

	var respJSON map[string]interface{}
	if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	result, err := ParseOpenAIResponse(respJSON, req.Model)
	if err != nil {
		return nil, err
	}

	// Fix usage for DeepSeek specifically if details were missing in
	// ParseOpenAIResponse: re-parse the raw usage object through the
	// DeepSeek-shaped struct and overwrite the generic result.
	if usageData, ok := respJSON["usage"]; ok {
		var dUsage deepSeekUsage
		usageBytes, _ := json.Marshal(usageData)
		if err := json.Unmarshal(usageBytes, &dUsage); err == nil {
			result.Usage = dUsage.ToUnified()
		}
	}

	return result, nil
}
|
||||||
|
|
||||||
|
// ChatCompletionStream performs a streaming request against DeepSeek.
// It returns a channel of parsed SSE chunks; the channel is closed when
// the stream ends (or errors). Stream errors after the initial HTTP
// handshake are only printed, not surfaced to the caller — NOTE(review):
// consider an error channel or a terminal error chunk.
func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
	messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
	if err != nil {
		return nil, fmt.Errorf("failed to convert messages: %w", err)
	}

	body := BuildOpenAIBody(req, messagesJSON, true)

	// Sanitize for deepseek-reasoner (same rules as the non-streaming
	// path: strip sampling params, patch assistant turns).
	if req.Model == "deepseek-reasoner" {
		delete(body, "temperature")
		delete(body, "top_p")
		delete(body, "presence_penalty")
		delete(body, "frequency_penalty")

		if msgs, ok := body["messages"].([]interface{}); ok {
			for _, m := range msgs {
				if msg, ok := m.(map[string]interface{}); ok {
					if msg["role"] == "assistant" {
						if msg["reasoning_content"] == nil {
							msg["reasoning_content"] = " "
						}
						if msg["content"] == nil || msg["content"] == "" {
							msg["content"] = ""
						}
					}
				}
			}
		}
	}

	// SetDoNotParseResponse keeps the body as a raw stream so we can scan
	// SSE lines ourselves.
	resp, err := p.client.R().
		SetContext(ctx).
		SetHeader("Authorization", "Bearer "+p.apiKey).
		SetBody(body).
		SetDoNotParseResponse(true).
		Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))

	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}

	if !resp.IsSuccess() {
		return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
	}

	ch := make(chan *models.ChatCompletionStreamResponse)

	go func() {
		defer close(ch)
		// Custom scanner loop to handle DeepSeek specific usage in chunks
		err := StreamDeepSeek(resp.RawBody(), ch)
		if err != nil {
			fmt.Printf("DeepSeek Stream error: %v\n", err)
		}
	}()

	return ch, nil
}
|
||||||
|
|
||||||
|
func StreamDeepSeek(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse) error {
|
||||||
|
defer ctx.Close()
|
||||||
|
scanner := bufio.NewScanner(ctx)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
if line == "" || !strings.HasPrefix(line, "data: ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
data := strings.TrimPrefix(line, "data: ")
|
||||||
|
if data == "[DONE]" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
var chunk models.ChatCompletionStreamResponse
|
||||||
|
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fix DeepSeek specific usage in stream
|
||||||
|
var rawChunk struct {
|
||||||
|
Usage *deepSeekUsage `json:"usage"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
|
||||||
|
chunk.Usage = rawChunk.Usage.ToUnified()
|
||||||
|
}
|
||||||
|
|
||||||
|
ch <- &chunk
|
||||||
|
}
|
||||||
|
return scanner.Err()
|
||||||
|
}
|
||||||
254
internal/providers/gemini.go
Normal file
254
internal/providers/gemini.go
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"gophergate/internal/config"
|
||||||
|
"gophergate/internal/models"
|
||||||
|
"github.com/go-resty/resty/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GeminiProvider adapts the gateway's unified request format to the
// Google Gemini generateContent API.
type GeminiProvider struct {
	client *resty.Client      // shared HTTP client for all calls
	config config.GeminiConfig // base URL and related settings
	apiKey string             // API key appended as a query parameter
}

// NewGeminiProvider builds a provider with a fresh resty client.
func NewGeminiProvider(cfg config.GeminiConfig, apiKey string) *GeminiProvider {
	return &GeminiProvider{
		client: resty.New(),
		config: cfg,
		apiKey: apiKey,
	}
}

// Name returns the provider's registry identifier.
func (p *GeminiProvider) Name() string {
	return "gemini"
}
|
||||||
|
|
||||||
|
// GeminiRequest is the generateContent request envelope.
type GeminiRequest struct {
	Contents []GeminiContent `json:"contents"`
}

// GeminiContent is one conversation turn; Gemini uses roles "user" and
// "model" (see the role mapping in ChatCompletion).
type GeminiContent struct {
	Role  string       `json:"role,omitempty"`
	Parts []GeminiPart `json:"parts"`
}

// GeminiPart is one piece of a turn: text, inline image data, a
// function call, or a function response. Exactly one field is set.
type GeminiPart struct {
	Text             string                  `json:"text,omitempty"`
	InlineData       *GeminiInlineData       `json:"inlineData,omitempty"`
	FunctionCall     *GeminiFunctionCall     `json:"functionCall,omitempty"`
	FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
}

// GeminiInlineData carries base64-encoded media with its MIME type.
type GeminiInlineData struct {
	MimeType string `json:"mimeType"`
	Data     string `json:"data"`
}

// GeminiFunctionCall is a tool invocation emitted by the model; Args is
// a raw JSON object (unlike OpenAI's string-encoded arguments).
type GeminiFunctionCall struct {
	Name string          `json:"name"`
	Args json.RawMessage `json:"args"`
}

// GeminiFunctionResponse is a tool result sent back to the model;
// Response must be a raw JSON object.
type GeminiFunctionResponse struct {
	Name     string          `json:"name"`
	Response json.RawMessage `json:"response"`
}
|
||||||
|
|
||||||
|
// ChatCompletion maps a unified request onto Gemini's generateContent API,
// performs a blocking call, and converts the reply back into the
// OpenAI-shaped ChatCompletionResponse used by the rest of the gateway.
func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
	// Gemini mapping: fold each unified message into a GeminiContent turn.
	var contents []GeminiContent
	for _, msg := range req.Messages {
		role := "user"
		if msg.Role == "assistant" {
			role = "model"
		} else if msg.Role == "tool" {
			role = "user" // Tool results are user-side in Gemini
		}

		var parts []GeminiPart

		// Handle tool responses: only the first content part's text is used.
		if msg.Role == "tool" {
			text := ""
			if len(msg.Content) > 0 {
				text = msg.Content[0].Text
			}

			// Gemini expects functionResponse to be an object
			name := "unknown_function"
			if msg.Name != nil {
				name = *msg.Name
			}

			// NOTE(review): text is passed through as raw JSON; if a tool ever
			// returns plain text (not valid JSON) the request body becomes
			// malformed — confirm upstream guarantees JSON tool output.
			parts = append(parts, GeminiPart{
				FunctionResponse: &GeminiFunctionResponse{
					Name:     name,
					Response: json.RawMessage(text),
				},
			})
		} else {
			// Text and image parts for user/assistant turns.
			for _, cp := range msg.Content {
				if cp.Type == "text" {
					parts = append(parts, GeminiPart{Text: cp.Text})
				} else if cp.Image != nil {
					// NOTE(review): ToBase64 error is silently discarded here,
					// unlike MessagesToOpenAIJSON which propagates it.
					base64Data, mimeType, _ := cp.Image.ToBase64()
					parts = append(parts, GeminiPart{
						InlineData: &GeminiInlineData{
							MimeType: mimeType,
							Data:     base64Data,
						},
					})
				}
			}

			// Handle assistant tool calls
			if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
				for _, tc := range msg.ToolCalls {
					parts = append(parts, GeminiPart{
						FunctionCall: &GeminiFunctionCall{
							Name: tc.Function.Name,
							Args: json.RawMessage(tc.Function.Arguments),
						},
					})
				}
			}
		}

		contents = append(contents, GeminiContent{
			Role:  role,
			Parts: parts,
		})
	}

	body := GeminiRequest{
		Contents: contents,
	}

	// NOTE(review): the API key is embedded in the query string, so it can leak
	// into access logs / error strings; consider the x-goog-api-key header.
	url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)

	resp, err := p.client.R().
		SetContext(ctx).
		SetBody(body).
		Post(url)

	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}

	if !resp.IsSuccess() {
		return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
	}

	// Parse Gemini response and convert to OpenAI format.
	// Only the first candidate's text parts are used.
	var geminiResp struct {
		Candidates []struct {
			Content struct {
				Parts []struct {
					Text string `json:"text"`
				} `json:"parts"`
			} `json:"content"`
			FinishReason string `json:"finishReason"`
		} `json:"candidates"`
		UsageMetadata struct {
			PromptTokenCount     uint32 `json:"promptTokenCount"`
			CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
			TotalTokenCount      uint32 `json:"totalTokenCount"`
		} `json:"usageMetadata"`
	}

	if err := json.Unmarshal(resp.Body(), &geminiResp); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	if len(geminiResp.Candidates) == 0 {
		return nil, fmt.Errorf("no candidates in Gemini response")
	}

	// Concatenate all text parts of the first candidate.
	content := ""
	for _, p := range geminiResp.Candidates[0].Content.Parts {
		content += p.Text
	}

	openAIResp := &models.ChatCompletionResponse{
		ID:      "gemini-" + req.Model,
		Object:  "chat.completion",
		Created: 0, // Should be current timestamp — TODO: set time.Now().Unix() (needs "time" import)
		Model:   req.Model,
		Choices: []models.ChatChoice{
			{
				Index: 0,
				Message: models.ChatMessage{
					Role:    "assistant",
					Content: content,
				},
				// Gemini finish reasons (e.g. "STOP") are passed through
				// verbatim here, unlike StreamGemini which lower-cases them.
				FinishReason: &geminiResp.Candidates[0].FinishReason,
			},
		},
		Usage: &models.Usage{
			PromptTokens:     geminiResp.UsageMetadata.PromptTokenCount,
			CompletionTokens: geminiResp.UsageMetadata.CandidatesTokenCount,
			TotalTokens:      geminiResp.UsageMetadata.TotalTokenCount,
		},
	}

	return openAIResp, nil
}
|
||||||
|
|
||||||
|
// ChatCompletionStream issues a streaming request against Gemini's
// streamGenerateContent endpoint and returns a channel of OpenAI-shaped
// chunks. The channel is closed by the background reader goroutine.
//
// NOTE(review): this path uses a simplified mapping — tool calls, tool
// results and images are dropped (only Content[i].Text survives), unlike
// the non-streaming ChatCompletion above.
func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
	// Simplified Gemini mapping
	var contents []GeminiContent
	for _, msg := range req.Messages {
		role := "user"
		if msg.Role == "assistant" {
			role = "model"
		}

		var parts []GeminiPart
		for _, p := range msg.Content {
			parts = append(parts, GeminiPart{Text: p.Text})
		}

		contents = append(contents, GeminiContent{
			Role:  role,
			Parts: parts,
		})
	}

	body := GeminiRequest{
		Contents: contents,
	}

	// Use streamGenerateContent for streaming
	url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)

	// Response parsing is disabled so the raw body can be streamed below.
	resp, err := p.client.R().
		SetContext(ctx).
		SetBody(body).
		SetDoNotParseResponse(true).
		Post(url)

	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}

	if !resp.IsSuccess() {
		// NOTE(review): with SetDoNotParseResponse the raw body is not closed
		// on this error path, and resp.String() may be empty — verify against
		// resty's documented behavior for unparsed responses.
		return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
	}

	ch := make(chan *models.ChatCompletionStreamResponse)

	go func() {
		defer close(ch)
		// StreamGemini owns (and closes) the raw body.
		err := StreamGemini(resp.RawBody(), ch, req.Model)
		if err != nil {
			fmt.Printf("Gemini Stream error: %v\n", err)
		}
	}()

	return ch, nil
}
|
||||||
95
internal/providers/grok.go
Normal file
95
internal/providers/grok.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"gophergate/internal/config"
|
||||||
|
"gophergate/internal/models"
|
||||||
|
"github.com/go-resty/resty/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GrokProvider talks to xAI's Grok API, which is OpenAI-wire-compatible,
// via the shared OpenAI-format helpers in helpers.go.
type GrokProvider struct {
	client *resty.Client     // reusable HTTP client
	config config.GrokConfig // holds BaseURL
	apiKey string            // bearer token for Authorization header
}
|
||||||
|
|
||||||
|
func NewGrokProvider(cfg config.GrokConfig, apiKey string) *GrokProvider {
|
||||||
|
return &GrokProvider{
|
||||||
|
client: resty.New(),
|
||||||
|
config: cfg,
|
||||||
|
apiKey: apiKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *GrokProvider) Name() string {
|
||||||
|
return "grok"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatCompletion sends a non-streaming OpenAI-format request to Grok's
// /chat/completions endpoint and returns the parsed unified response.
func (p *GrokProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
	messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
	if err != nil {
		return nil, fmt.Errorf("failed to convert messages: %w", err)
	}

	// stream=false: plain request/response round trip.
	body := BuildOpenAIBody(req, messagesJSON, false)

	resp, err := p.client.R().
		SetContext(ctx).
		SetHeader("Authorization", "Bearer "+p.apiKey).
		SetBody(body).
		Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))

	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}

	if !resp.IsSuccess() {
		return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
	}

	var respJSON map[string]interface{}
	if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	return ParseOpenAIResponse(respJSON, req.Model)
}
|
||||||
|
|
||||||
|
// ChatCompletionStream sends a streaming OpenAI-format request to Grok and
// returns a channel of chunks; the channel is closed when the SSE stream ends.
func (p *GrokProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
	messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
	if err != nil {
		return nil, fmt.Errorf("failed to convert messages: %w", err)
	}

	// stream=true also enables stream_options.include_usage (see BuildOpenAIBody).
	body := BuildOpenAIBody(req, messagesJSON, true)

	// Response parsing disabled so the SSE body can be read incrementally.
	resp, err := p.client.R().
		SetContext(ctx).
		SetHeader("Authorization", "Bearer "+p.apiKey).
		SetBody(body).
		SetDoNotParseResponse(true).
		Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))

	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}

	if !resp.IsSuccess() {
		// NOTE(review): raw body is not closed on this error path (unparsed
		// response) — potential connection leak; verify resty semantics.
		return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
	}

	ch := make(chan *models.ChatCompletionStreamResponse)

	go func() {
		defer close(ch)
		// StreamOpenAI owns (and closes) the raw body.
		err := StreamOpenAI(resp.RawBody(), ch)
		if err != nil {
			fmt.Printf("Grok Stream error: %v\n", err)
		}
	}()

	return ch, nil
}
|
||||||
318
internal/providers/helpers.go
Normal file
318
internal/providers/helpers.go
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gophergate/internal/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON, including tools and images.
//
// Behavior visible in this function:
//   - "tool" messages collapse to {role, content, tool_call_id?, name?},
//     taking only the FIRST content part's text;
//   - other messages expand content parts to text / image_url objects,
//     but a single text part is flattened to a plain string;
//   - tool-call IDs (incoming and outgoing) are truncated to 40 chars —
//     presumably a provider-side limit, TODO confirm which provider requires it;
//   - assistant messages with tool_calls but no content parts get content "".
func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) {
	var result []interface{}
	for _, m := range messages {
		// Tool results use a flat string content, not a parts array.
		if m.Role == "tool" {
			text := ""
			if len(m.Content) > 0 {
				text = m.Content[0].Text
			}
			msg := map[string]interface{}{
				"role":    "tool",
				"content": text,
			}
			if m.ToolCallID != nil {
				id := *m.ToolCallID
				if len(id) > 40 {
					id = id[:40] // keep in sync with the truncation below
				}
				msg["tool_call_id"] = id
			}
			if m.Name != nil {
				msg["name"] = *m.Name
			}
			result = append(result, msg)
			continue
		}

		// Build the multi-part content array (text and/or images).
		var parts []interface{}
		for _, p := range m.Content {
			if p.Type == "text" {
				parts = append(parts, map[string]interface{}{
					"type": "text",
					"text": p.Text,
				})
			} else if p.Image != nil {
				base64Data, mimeType, err := p.Image.ToBase64()
				if err != nil {
					return nil, fmt.Errorf("failed to convert image to base64: %w", err)
				}
				parts = append(parts, map[string]interface{}{
					"type": "image_url",
					"image_url": map[string]interface{}{
						"url": fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data),
					},
				})
			}
		}

		// A lone text part is sent as a plain string for wider compatibility.
		var finalContent interface{}
		if len(parts) == 1 {
			if p, ok := parts[0].(map[string]interface{}); ok && p["type"] == "text" {
				finalContent = p["text"]
			} else {
				finalContent = parts
			}
		} else {
			finalContent = parts
		}

		msg := map[string]interface{}{
			"role":    m.Role,
			"content": finalContent,
		}

		if m.ReasoningContent != nil {
			msg["reasoning_content"] = *m.ReasoningContent
		}

		if len(m.ToolCalls) > 0 {
			// Copy before truncating IDs so the caller's slice is not mutated.
			sanitizedCalls := make([]models.ToolCall, len(m.ToolCalls))
			copy(sanitizedCalls, m.ToolCalls)
			for i := range sanitizedCalls {
				if len(sanitizedCalls[i].ID) > 40 {
					sanitizedCalls[i].ID = sanitizedCalls[i].ID[:40]
				}
			}
			msg["tool_calls"] = sanitizedCalls
			// Some providers reject null content on tool-call messages.
			if len(parts) == 0 {
				msg["content"] = ""
			}
		}

		if m.Name != nil {
			msg["name"] = *m.Name
		}

		result = append(result, msg)
	}
	return result, nil
}
|
||||||
|
|
||||||
|
func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{}, stream bool) map[string]interface{} {
|
||||||
|
body := map[string]interface{}{
|
||||||
|
"model": request.Model,
|
||||||
|
"messages": messagesJSON,
|
||||||
|
"stream": stream,
|
||||||
|
}
|
||||||
|
|
||||||
|
if stream {
|
||||||
|
body["stream_options"] = map[string]interface{}{
|
||||||
|
"include_usage": true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.Temperature != nil {
|
||||||
|
body["temperature"] = *request.Temperature
|
||||||
|
}
|
||||||
|
if request.MaxTokens != nil {
|
||||||
|
body["max_tokens"] = *request.MaxTokens
|
||||||
|
}
|
||||||
|
if len(request.Tools) > 0 {
|
||||||
|
body["tools"] = request.Tools
|
||||||
|
}
|
||||||
|
if request.ToolChoice != nil {
|
||||||
|
var toolChoice interface{}
|
||||||
|
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
|
||||||
|
body["tool_choice"] = toolChoice
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// openAIUsage mirrors the OpenAI usage object, including the optional
// nested detail blocks that carry cached-prompt and reasoning token counts.
type openAIUsage struct {
	PromptTokens     uint32 `json:"prompt_tokens"`
	CompletionTokens uint32 `json:"completion_tokens"`
	TotalTokens      uint32 `json:"total_tokens"`
	// Present only on providers that report prompt caching.
	PromptTokensDetails *struct {
		CachedTokens uint32 `json:"cached_tokens"`
	} `json:"prompt_tokens_details"`
	// Present only on providers/models that report reasoning tokens.
	CompletionTokensDetails *struct {
		ReasoningTokens uint32 `json:"reasoning_tokens"`
	} `json:"completion_tokens_details"`
}
|
||||||
|
|
||||||
|
func (u *openAIUsage) ToUnified() *models.Usage {
|
||||||
|
usage := &models.Usage{
|
||||||
|
PromptTokens: u.PromptTokens,
|
||||||
|
CompletionTokens: u.CompletionTokens,
|
||||||
|
TotalTokens: u.TotalTokens,
|
||||||
|
}
|
||||||
|
if u.PromptTokensDetails != nil && u.PromptTokensDetails.CachedTokens > 0 {
|
||||||
|
usage.CacheReadTokens = &u.PromptTokensDetails.CachedTokens
|
||||||
|
}
|
||||||
|
if u.CompletionTokensDetails != nil && u.CompletionTokensDetails.ReasoningTokens > 0 {
|
||||||
|
usage.ReasoningTokens = &u.CompletionTokensDetails.ReasoningTokens
|
||||||
|
}
|
||||||
|
return usage
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseOpenAIResponse converts a decoded OpenAI-style response map into a
// ChatCompletionResponse by round-tripping it through JSON, then re-decodes
// the usage object separately to capture provider-specific detail fields.
//
// NOTE(review): the `model` parameter is currently unused — either use it
// (e.g. as a fallback when the response omits a model) or drop it at the
// next interface-breaking change.
func ParseOpenAIResponse(respJSON map[string]interface{}, model string) (*models.ChatCompletionResponse, error) {
	data, err := json.Marshal(respJSON)
	if err != nil {
		return nil, err
	}

	var resp models.ChatCompletionResponse
	if err := json.Unmarshal(data, &resp); err != nil {
		return nil, err
	}

	// Manually fix usage because ChatCompletionResponse uses the unified Usage struct
	// but the provider might have returned more details.
	if usageData, ok := respJSON["usage"]; ok {
		var oUsage openAIUsage
		usageBytes, _ := json.Marshal(usageData)
		if err := json.Unmarshal(usageBytes, &oUsage); err == nil {
			resp.Usage = oUsage.ToUnified()
		}
	}

	return &resp, nil
}
|
||||||
|
|
||||||
|
// Streaming support
|
||||||
|
|
||||||
|
func ParseOpenAIStreamChunk(line string) (*models.ChatCompletionStreamResponse, bool, error) {
|
||||||
|
if line == "" {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(line, "data: ") {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
data := strings.TrimPrefix(line, "data: ")
|
||||||
|
if data == "[DONE]" {
|
||||||
|
return nil, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var chunk models.ChatCompletionStreamResponse
|
||||||
|
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
|
||||||
|
return nil, false, fmt.Errorf("failed to unmarshal stream chunk: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle specialized usage in stream chunks
|
||||||
|
var rawChunk struct {
|
||||||
|
Usage *openAIUsage `json:"usage"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(data), &rawChunk); err == nil && rawChunk.Usage != nil {
|
||||||
|
chunk.Usage = rawChunk.Usage.ToUnified()
|
||||||
|
}
|
||||||
|
|
||||||
|
return &chunk, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func StreamOpenAI(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse) error {
|
||||||
|
defer ctx.Close()
|
||||||
|
scanner := bufio.NewScanner(ctx)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
chunk, done, err := ParseOpenAIStreamChunk(line)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if chunk != nil {
|
||||||
|
ch <- chunk
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return scanner.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamGemini incrementally decodes Gemini's streaming response — a JSON
// array of chunk objects — and forwards each chunk onto ch in OpenAI
// stream-chunk form. The body is closed before returning; the caller owns
// closing ch.
func StreamGemini(ctx io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error {
	defer ctx.Close()

	dec := json.NewDecoder(ctx)

	// Consume the opening '[' of the array; anything else yields no chunks.
	t, err := dec.Token()
	if err != nil {
		return err
	}
	if delim, ok := t.(json.Delim); ok && delim == '[' {
		for dec.More() {
			var geminiChunk struct {
				Candidates []struct {
					Content struct {
						Parts []struct {
							Text    string `json:"text,omitempty"`
							Thought string `json:"thought,omitempty"`
						} `json:"parts"`
					} `json:"content"`
					FinishReason string `json:"finishReason"`
				} `json:"candidates"`
				UsageMetadata struct {
					PromptTokenCount     uint32 `json:"promptTokenCount"`
					CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
					TotalTokenCount      uint32 `json:"totalTokenCount"`
				} `json:"usageMetadata"`
			}

			if err := dec.Decode(&geminiChunk); err != nil {
				return err
			}

			// Forward chunks that carry content and/or usage totals.
			if len(geminiChunk.Candidates) > 0 || geminiChunk.UsageMetadata.TotalTokenCount > 0 {
				// Accumulate text and "thought" (reasoning) parts of the
				// first candidate; other candidates are ignored.
				content := ""
				var reasoning *string
				if len(geminiChunk.Candidates) > 0 {
					for _, p := range geminiChunk.Candidates[0].Content.Parts {
						if p.Text != "" {
							content += p.Text
						}
						if p.Thought != "" {
							if reasoning == nil {
								reasoning = new(string)
							}
							*reasoning += p.Thought
						}
					}
				}

				// Gemini reasons are upper-case ("STOP"); normalize to lower.
				// NOTE(review): when FinishReason is empty this still sets a
				// pointer to "" rather than nil — confirm downstream treats
				// empty string as "not finished".
				var finishReason *string
				if len(geminiChunk.Candidates) > 0 {
					fr := strings.ToLower(geminiChunk.Candidates[0].FinishReason)
					finishReason = &fr
				}

				ch <- &models.ChatCompletionStreamResponse{
					ID:      "gemini-stream",
					Object:  "chat.completion.chunk",
					Created: 0,
					Model:   model,
					Choices: []models.ChatStreamChoice{
						{
							Index: 0,
							Delta: models.ChatStreamDelta{
								Content:          &content,
								ReasoningContent: reasoning,
							},
							FinishReason: finishReason,
						},
					},
					Usage: &models.Usage{
						PromptTokens:     geminiChunk.UsageMetadata.PromptTokenCount,
						CompletionTokens: geminiChunk.UsageMetadata.CandidatesTokenCount,
						TotalTokens:      geminiChunk.UsageMetadata.TotalTokenCount,
					},
				}
			}
		}
	}

	return nil
}
|
||||||
114
internal/providers/moonshot.go
Normal file
114
internal/providers/moonshot.go
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gophergate/internal/config"
|
||||||
|
"gophergate/internal/models"
|
||||||
|
"github.com/go-resty/resty/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MoonshotProvider talks to Moonshot's (Kimi) OpenAI-compatible API via
// the shared OpenAI-format helpers in helpers.go.
type MoonshotProvider struct {
	client *resty.Client         // reusable HTTP client
	config config.MoonshotConfig // holds BaseURL
	apiKey string                // bearer token, trimmed at construction
}
|
||||||
|
|
||||||
|
func NewMoonshotProvider(cfg config.MoonshotConfig, apiKey string) *MoonshotProvider {
|
||||||
|
return &MoonshotProvider{
|
||||||
|
client: resty.New(),
|
||||||
|
config: cfg,
|
||||||
|
apiKey: strings.TrimSpace(apiKey),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *MoonshotProvider) Name() string {
|
||||||
|
return "moonshot"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatCompletion sends a non-streaming OpenAI-format request to Moonshot.
// For kimi-k2.5 models, max_tokens is renamed to max_completion_tokens
// (same transition as OpenAI's newer models — see openai.go).
func (p *MoonshotProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
	messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
	if err != nil {
		return nil, fmt.Errorf("failed to convert messages: %w", err)
	}

	body := BuildOpenAIBody(req, messagesJSON, false)
	// kimi-k2.5 rejects max_tokens in favor of max_completion_tokens.
	if strings.Contains(strings.ToLower(req.Model), "kimi-k2.5") {
		if maxTokens, ok := body["max_tokens"]; ok {
			delete(body, "max_tokens")
			body["max_completion_tokens"] = maxTokens
		}
	}

	// Normalize trailing slash so the joined URL never contains "//".
	baseURL := strings.TrimRight(p.config.BaseURL, "/")

	resp, err := p.client.R().
		SetContext(ctx).
		SetHeader("Authorization", "Bearer "+p.apiKey).
		SetHeader("Content-Type", "application/json").
		SetHeader("Accept", "application/json").
		SetBody(body).
		Post(fmt.Sprintf("%s/chat/completions", baseURL))

	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}

	if !resp.IsSuccess() {
		return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), resp.String())
	}

	var respJSON map[string]interface{}
	if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	return ParseOpenAIResponse(respJSON, req.Model)
}
|
||||||
|
|
||||||
|
// ChatCompletionStream sends a streaming request to Moonshot and returns a
// channel of chunks; the channel is closed when the SSE stream ends.
func (p *MoonshotProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
	messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
	if err != nil {
		return nil, fmt.Errorf("failed to convert messages: %w", err)
	}

	body := BuildOpenAIBody(req, messagesJSON, true)
	// kimi-k2.5 rejects max_tokens in favor of max_completion_tokens.
	if strings.Contains(strings.ToLower(req.Model), "kimi-k2.5") {
		if maxTokens, ok := body["max_tokens"]; ok {
			delete(body, "max_tokens")
			body["max_completion_tokens"] = maxTokens
		}
	}

	// Normalize trailing slash so the joined URL never contains "//".
	baseURL := strings.TrimRight(p.config.BaseURL, "/")

	// Response parsing disabled so the SSE body can be read incrementally.
	resp, err := p.client.R().
		SetContext(ctx).
		SetHeader("Authorization", "Bearer "+p.apiKey).
		SetHeader("Content-Type", "application/json").
		SetHeader("Accept", "text/event-stream").
		SetBody(body).
		SetDoNotParseResponse(true).
		Post(fmt.Sprintf("%s/chat/completions", baseURL))

	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}

	if !resp.IsSuccess() {
		// NOTE(review): raw body is not closed on this error path (unparsed
		// response) — potential connection leak; verify resty semantics.
		return nil, fmt.Errorf("Moonshot API error (%d): %s", resp.StatusCode(), resp.String())
	}

	ch := make(chan *models.ChatCompletionStreamResponse)
	go func() {
		defer close(ch)
		// StreamOpenAI owns (and closes) the raw body.
		if err := StreamOpenAI(resp.RawBody(), ch); err != nil {
			fmt.Printf("Moonshot Stream error: %v\n", err)
		}
	}()

	return ch, nil
}
|
||||||
113
internal/providers/openai.go
Normal file
113
internal/providers/openai.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gophergate/internal/config"
|
||||||
|
"gophergate/internal/models"
|
||||||
|
"github.com/go-resty/resty/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenAIProvider talks to the OpenAI chat-completions API using the shared
// OpenAI-format helpers in helpers.go.
type OpenAIProvider struct {
	client *resty.Client       // reusable HTTP client
	config config.OpenAIConfig // holds BaseURL
	apiKey string              // bearer token for Authorization header
}
|
||||||
|
|
||||||
|
func NewOpenAIProvider(cfg config.OpenAIConfig, apiKey string) *OpenAIProvider {
|
||||||
|
return &OpenAIProvider{
|
||||||
|
client: resty.New(),
|
||||||
|
config: cfg,
|
||||||
|
apiKey: apiKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OpenAIProvider) Name() string {
|
||||||
|
return "openai"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatCompletion sends a non-streaming request to OpenAI's /chat/completions
// endpoint and returns the parsed unified response. o1-/o3-/gpt-5 models get
// max_tokens renamed to max_completion_tokens, as those models require.
func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
	messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
	if err != nil {
		return nil, fmt.Errorf("failed to convert messages: %w", err)
	}

	body := BuildOpenAIBody(req, messagesJSON, false)

	// Transition: Newer models require max_completion_tokens
	if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
		if maxTokens, ok := body["max_tokens"]; ok {
			delete(body, "max_tokens")
			body["max_completion_tokens"] = maxTokens
		}
	}

	resp, err := p.client.R().
		SetContext(ctx).
		SetHeader("Authorization", "Bearer "+p.apiKey).
		SetBody(body).
		Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))

	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}

	if !resp.IsSuccess() {
		return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
	}

	var respJSON map[string]interface{}
	if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	return ParseOpenAIResponse(respJSON, req.Model)
}
|
||||||
|
|
||||||
|
// ChatCompletionStream sends a streaming request to OpenAI and returns a
// channel of chunks; the channel is closed when the SSE stream ends. The
// same max_completion_tokens rename as ChatCompletion is applied.
func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
	messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
	if err != nil {
		return nil, fmt.Errorf("failed to convert messages: %w", err)
	}

	body := BuildOpenAIBody(req, messagesJSON, true)

	// Transition: Newer models require max_completion_tokens
	if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
		if maxTokens, ok := body["max_tokens"]; ok {
			delete(body, "max_tokens")
			body["max_completion_tokens"] = maxTokens
		}
	}

	// Response parsing disabled so the SSE body can be read incrementally.
	resp, err := p.client.R().
		SetContext(ctx).
		SetHeader("Authorization", "Bearer "+p.apiKey).
		SetBody(body).
		SetDoNotParseResponse(true).
		Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))

	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}

	if !resp.IsSuccess() {
		// NOTE(review): raw body is not closed on this error path (unparsed
		// response) — potential connection leak; verify resty semantics.
		return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
	}

	ch := make(chan *models.ChatCompletionStreamResponse)

	go func() {
		defer close(ch)
		// StreamOpenAI owns (and closes) the raw body.
		err := StreamOpenAI(resp.RawBody(), ch)
		if err != nil {
			// In a real app, you might want to send an error chunk or log it
			fmt.Printf("Stream error: %v\n", err)
		}
	}()

	return ch, nil
}
|
||||||
13
internal/providers/provider.go
Normal file
13
internal/providers/provider.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"gophergate/internal/models"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Provider is the common contract every upstream LLM backend implements:
// a stable name plus blocking and streaming chat-completion calls that
// accept/return the gateway's unified request/response types.
type Provider interface {
	// Name returns the provider's registry identifier (e.g. "openai").
	Name() string
	// ChatCompletion performs a blocking request and returns one response.
	ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error)
	// ChatCompletionStream returns a channel of chunks that is closed by
	// the implementation when the stream ends.
	ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error)
}
|
||||||
1393
internal/server/dashboard.go
Normal file
1393
internal/server/dashboard.go
Normal file
File diff suppressed because it is too large
Load Diff
113
internal/server/logging.go
Normal file
113
internal/server/logging.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gophergate/internal/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestLog is one per-request accounting record: who called, which
// provider/model served it, token/cost accounting, and the outcome. It is
// both broadcast to the dashboard and persisted to the llm_requests table.
type RequestLog struct {
	Timestamp        time.Time `json:"timestamp"`
	ClientID         string    `json:"client_id"`
	Provider         string    `json:"provider"`
	Model            string    `json:"model"`
	PromptTokens     uint32    `json:"prompt_tokens"`
	CompletionTokens uint32    `json:"completion_tokens"`
	ReasoningTokens  uint32    `json:"reasoning_tokens"`
	TotalTokens      uint32    `json:"total_tokens"`
	CacheReadTokens  uint32    `json:"cache_read_tokens"`
	CacheWriteTokens uint32    `json:"cache_write_tokens"`
	Cost             float64   `json:"cost"`
	HasImages        bool      `json:"has_images"`
	Status           string    `json:"status"`
	ErrorMessage     string    `json:"error_message,omitempty"`
	DurationMS       int64     `json:"duration_ms"`
}
|
||||||
|
|
||||||
|
// RequestLogger persists request logs asynchronously: entries are queued on
// a buffered channel and drained by a single goroutine (see Start).
type RequestLogger struct {
	database *db.DB          // destination for log rows and stat updates
	hub      *Hub            // websocket hub for live dashboard broadcasts
	logChan  chan RequestLog // buffered queue (capacity 100); full ⇒ drop
}
|
||||||
|
|
||||||
|
func NewRequestLogger(database *db.DB, hub *Hub) *RequestLogger {
|
||||||
|
return &RequestLogger{
|
||||||
|
database: database,
|
||||||
|
hub: hub,
|
||||||
|
logChan: make(chan RequestLog, 100),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start launches the single background goroutine that drains logChan and
// processes each entry. The goroutine exits when logChan is closed.
func (l *RequestLogger) Start() {
	go func() {
		for entry := range l.logChan {
			l.processLog(entry)
		}
	}()
}
|
||||||
|
|
||||||
|
// LogRequest enqueues an entry without blocking the request path: if the
// buffered channel is full the entry is deliberately dropped (with a log
// line) rather than stalling the caller.
func (l *RequestLogger) LogRequest(entry RequestLog) {
	select {
	case l.logChan <- entry:
	default:
		log.Println("Request log channel full, dropping log entry")
	}
}
|
||||||
|
|
||||||
|
// processLog handles one queued entry: broadcast it to the dashboard, then
// persist it in a single transaction (insert row, bump client aggregates,
// decrement prepaid provider balance). Runs only on the Start goroutine.
func (l *RequestLogger) processLog(entry RequestLog) {
	// Broadcast to dashboard
	// NOTE(review): this is an unguarded channel send — if the hub's
	// broadcast loop stalls, logging stalls with it; confirm Hub drains
	// broadcast promptly or add a select/default here.
	l.hub.broadcast <- map[string]interface{}{
		"type":    "request",
		"channel": "requests",
		"payload": entry,
	}

	// Insert into DB
	tx, err := l.database.Begin()
	if err != nil {
		log.Printf("Failed to begin transaction for logging: %v", err)
		return
	}
	// Rollback is a no-op once Commit succeeds.
	defer tx.Rollback()

	// Ensure client exists (auto-register unknown client IDs; errors ignored
	// deliberately — the FK target either exists already or is created here).
	_, _ = tx.Exec("INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')",
		entry.ClientID, entry.ClientID)

	// Insert log
	_, err = tx.Exec(`
		INSERT INTO llm_requests
		(timestamp, client_id, provider, model, prompt_tokens, completion_tokens, reasoning_tokens, total_tokens, cache_read_tokens, cache_write_tokens, cost, has_images, status, error_message, duration_ms)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
	`, entry.Timestamp, entry.ClientID, entry.Provider, entry.Model,
		entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.TotalTokens,
		entry.CacheReadTokens, entry.CacheWriteTokens, entry.Cost, entry.HasImages,
		entry.Status, entry.ErrorMessage, entry.DurationMS)

	if err != nil {
		log.Printf("Failed to insert request log: %v", err)
		return
	}

	// Update client stats (best effort; errors ignored).
	_, _ = tx.Exec(`
		UPDATE clients SET
		total_requests = total_requests + 1,
		total_tokens = total_tokens + ?,
		total_cost = total_cost + ?,
		updated_at = CURRENT_TIMESTAMP
		WHERE client_id = ?
	`, entry.TotalTokens, entry.Cost, entry.ClientID)

	// Update provider balance — prepaid providers only; postpaid ones are
	// excluded by the billing_mode filter.
	if entry.Cost > 0 {
		_, _ = tx.Exec("UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')",
			entry.Cost, entry.Provider)
	}

	err = tx.Commit()
	if err != nil {
		log.Printf("Failed to commit logging transaction: %v", err)
	}
}
|
||||||
494
internal/server/server.go
Normal file
494
internal/server/server.go
Normal file
@@ -0,0 +1,494 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gophergate/internal/config"
|
||||||
|
"gophergate/internal/db"
|
||||||
|
"gophergate/internal/middleware"
|
||||||
|
"gophergate/internal/models"
|
||||||
|
"gophergate/internal/providers"
|
||||||
|
"gophergate/internal/utils"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server bundles the HTTP router with all shared gateway state: configuration,
// database, the enabled upstream LLM providers, dashboard sessions, the
// websocket hub, the async request logger, and the model/pricing registry.
type Server struct {
	router    *gin.Engine
	cfg       *config.Config
	database  *db.DB
	providers map[string]providers.Provider // enabled upstream providers keyed by id ("openai", "gemini", ...)
	sessions  *SessionManager               // HMAC-signed dashboard sessions
	hub       *Hub                          // websocket broadcast hub for live dashboard events
	logger    *RequestLogger                // asynchronous per-request usage logger
	registry  *models.ModelRegistry         // models.dev metadata; swapped in place by background refreshers
}
|
||||||
|
|
||||||
|
// NewServer constructs the gateway: session manager, websocket hub, async
// request logger, an (initially empty) model registry, providers resolved from
// config + DB, and all HTTP routes.
func NewServer(cfg *config.Config, database *db.DB) *Server {
	router := gin.Default()
	hub := NewHub()

	s := &Server{
		router:    router,
		cfg:       cfg,
		database:  database,
		providers: make(map[string]providers.Provider),
		sessions:  NewSessionManager(cfg.KeyBytes, 24*time.Hour), // dashboard sessions valid for 24h
		hub:       hub,
		logger:    NewRequestLogger(database, hub),
		registry:  &models.ModelRegistry{Providers: make(map[string]models.ProviderInfo)}, // empty until fetched below
	}

	// Fetch registry in background so startup is not blocked on models.dev.
	// NOTE(review): this goroutine assigns s.registry while request handlers
	// may read it concurrently — a data race; consider guarding the field with
	// a mutex or atomic.Pointer. Same pattern exists in Server.Run.
	go func() {
		registry, err := utils.FetchRegistry()
		if err != nil {
			fmt.Printf("Warning: Failed to fetch initial model registry: %v\n", err)
		} else {
			s.registry = registry
		}
	}()

	// Initialize providers from DB and Config. Failure is non-fatal: providers
	// can be refreshed later via the dashboard.
	if err := s.RefreshProviders(); err != nil {
		fmt.Printf("Warning: Failed to initial refresh providers: %v\n", err)
	}

	s.setupRoutes()
	return s
}
|
||||||
|
|
||||||
|
// RefreshProviders rebuilds s.providers from the static config overlaid with
// per-provider overrides stored in the provider_configs table (enabled flag,
// base URL, possibly-encrypted API key). Disabled providers are removed from
// the map. Safe to call again at runtime after dashboard edits.
func (s *Server) RefreshProviders() error {
	var dbConfigs []db.ProviderConfig
	err := s.database.Select(&dbConfigs, "SELECT * FROM provider_configs")
	if err != nil {
		return fmt.Errorf("failed to fetch provider configs from db: %w", err)
	}

	// Index DB rows by provider id for the overlay pass below.
	dbMap := make(map[string]db.ProviderConfig)
	for _, cfg := range dbConfigs {
		dbMap[cfg.ID] = cfg
	}

	providerIDs := []string{"openai", "gemini", "deepseek", "moonshot", "grok"}
	for _, id := range providerIDs {
		// Default values from config
		enabled := false
		baseURL := ""
		apiKey := ""

		switch id {
		case "openai":
			enabled = s.cfg.Providers.OpenAI.Enabled
			baseURL = s.cfg.Providers.OpenAI.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("openai")
		case "gemini":
			enabled = s.cfg.Providers.Gemini.Enabled
			baseURL = s.cfg.Providers.Gemini.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("gemini")
		case "deepseek":
			enabled = s.cfg.Providers.DeepSeek.Enabled
			baseURL = s.cfg.Providers.DeepSeek.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("deepseek")
		case "moonshot":
			enabled = s.cfg.Providers.Moonshot.Enabled
			baseURL = s.cfg.Providers.Moonshot.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("moonshot")
		case "grok":
			enabled = s.cfg.Providers.Grok.Enabled
			baseURL = s.cfg.Providers.Grok.BaseURL
			apiKey, _ = s.cfg.GetAPIKey("grok")
		}

		// Overrides from DB: DB values win over config when present/non-empty.
		if dbCfg, ok := dbMap[id]; ok {
			enabled = dbCfg.Enabled
			if dbCfg.BaseURL != nil && *dbCfg.BaseURL != "" {
				baseURL = *dbCfg.BaseURL
			}
			if dbCfg.APIKey != nil && *dbCfg.APIKey != "" {
				key := *dbCfg.APIKey
				if dbCfg.APIKeyEncrypted {
					decrypted, err := utils.Decrypt(key, s.cfg.KeyBytes)
					if err == nil {
						key = decrypted
					} else {
						// Fall through with the still-encrypted value: upstream
						// auth will fail, but the gateway keeps running.
						fmt.Printf("Warning: Failed to decrypt API key for %s: %v\n", id, err)
					}
				}
				apiKey = key
			}
		}

		if !enabled {
			// Remove disabled providers so stale clients cannot be used.
			delete(s.providers, id)
			continue
		}

		// Initialize provider with the resolved base URL and key. The config
		// struct is copied by value, so mutating BaseURL here does not touch
		// the shared config.
		switch id {
		case "openai":
			cfg := s.cfg.Providers.OpenAI
			cfg.BaseURL = baseURL
			s.providers["openai"] = providers.NewOpenAIProvider(cfg, apiKey)
		case "gemini":
			cfg := s.cfg.Providers.Gemini
			cfg.BaseURL = baseURL
			s.providers["gemini"] = providers.NewGeminiProvider(cfg, apiKey)
		case "deepseek":
			cfg := s.cfg.Providers.DeepSeek
			cfg.BaseURL = baseURL
			s.providers["deepseek"] = providers.NewDeepSeekProvider(cfg, apiKey)
		case "moonshot":
			cfg := s.cfg.Providers.Moonshot
			cfg.BaseURL = baseURL
			s.providers["moonshot"] = providers.NewMoonshotProvider(cfg, apiKey)
		case "grok":
			cfg := s.cfg.Providers.Grok
			cfg.BaseURL = baseURL
			s.providers["grok"] = providers.NewGrokProvider(cfg, apiKey)
		}
	}

	return nil
}
|
||||||
|
|
||||||
|
func (s *Server) setupRoutes() {
|
||||||
|
s.router.Use(middleware.AuthMiddleware(s.database))
|
||||||
|
|
||||||
|
// Static files
|
||||||
|
s.router.StaticFile("/", "./static/index.html")
|
||||||
|
s.router.StaticFile("/favicon.ico", "./static/favicon.ico")
|
||||||
|
s.router.Static("/css", "./static/css")
|
||||||
|
s.router.Static("/js", "./static/js")
|
||||||
|
s.router.Static("/img", "./static/img")
|
||||||
|
|
||||||
|
// WebSocket
|
||||||
|
s.router.GET("/ws", s.handleWebSocket)
|
||||||
|
|
||||||
|
// API V1 (External LLM Access) - Secured with AuthMiddleware
|
||||||
|
v1 := s.router.Group("/v1")
|
||||||
|
v1.Use(middleware.AuthMiddleware(s.database))
|
||||||
|
{
|
||||||
|
v1.POST("/chat/completions", s.handleChatCompletions)
|
||||||
|
v1.GET("/models", s.handleListModels)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dashboard API Group
|
||||||
|
api := s.router.Group("/api")
|
||||||
|
{
|
||||||
|
api.POST("/auth/login", s.handleLogin)
|
||||||
|
api.GET("/auth/status", s.handleAuthStatus)
|
||||||
|
api.POST("/auth/logout", s.handleLogout)
|
||||||
|
api.POST("/auth/change-password", s.handleChangePassword)
|
||||||
|
|
||||||
|
// Protected dashboard routes (need admin session)
|
||||||
|
admin := api.Group("/")
|
||||||
|
admin.Use(s.adminAuthMiddleware())
|
||||||
|
{
|
||||||
|
admin.GET("/usage/summary", s.handleUsageSummary)
|
||||||
|
admin.GET("/usage/time-series", s.handleTimeSeries)
|
||||||
|
admin.GET("/usage/providers", s.handleProvidersUsage)
|
||||||
|
admin.GET("/usage/clients", s.handleClientsUsage)
|
||||||
|
admin.GET("/usage/detailed", s.handleDetailedUsage)
|
||||||
|
admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)
|
||||||
|
|
||||||
|
admin.GET("/clients", s.handleGetClients)
|
||||||
|
admin.POST("/clients", s.handleCreateClient)
|
||||||
|
admin.GET("/clients/:id", s.handleGetClient)
|
||||||
|
admin.PUT("/clients/:id", s.handleUpdateClient)
|
||||||
|
admin.DELETE("/clients/:id", s.handleDeleteClient)
|
||||||
|
|
||||||
|
admin.GET("/clients/:id/tokens", s.handleGetClientTokens)
|
||||||
|
admin.POST("/clients/:id/tokens", s.handleCreateClientToken)
|
||||||
|
admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)
|
||||||
|
|
||||||
|
admin.GET("/providers", s.handleGetProviders)
|
||||||
|
admin.PUT("/providers/:name", s.handleUpdateProvider)
|
||||||
|
admin.POST("/providers/:name/test", s.handleTestProvider)
|
||||||
|
|
||||||
|
admin.GET("/models", s.handleGetModels)
|
||||||
|
admin.PUT("/models/:id", s.handleUpdateModel)
|
||||||
|
|
||||||
|
admin.GET("/users", s.handleGetUsers)
|
||||||
|
admin.POST("/users", s.handleCreateUser)
|
||||||
|
admin.PUT("/users/:id", s.handleUpdateUser)
|
||||||
|
admin.DELETE("/users/:id", s.handleDeleteUser)
|
||||||
|
|
||||||
|
admin.GET("/system/health", s.handleSystemHealth)
|
||||||
|
admin.GET("/system/metrics", s.handleSystemMetrics)
|
||||||
|
admin.GET("/system/settings", s.handleGetSettings)
|
||||||
|
admin.POST("/system/backup", s.handleCreateBackup)
|
||||||
|
admin.GET("/system/logs", s.handleGetLogs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.router.GET("/health", func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleListModels(c *gin.Context) {
|
||||||
|
type OpenAIModel struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
OwnedBy string `json:"owned_by"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var data []OpenAIModel
|
||||||
|
allowedProviders := map[string]bool{
|
||||||
|
"openai": true,
|
||||||
|
"google": true, // Models from models.dev use 'google' ID for Gemini
|
||||||
|
"deepseek": true,
|
||||||
|
"moonshot": true,
|
||||||
|
"xai": true, // Models from models.dev use 'xai' ID for Grok
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.registry != nil {
|
||||||
|
for pID, pInfo := range s.registry.Providers {
|
||||||
|
if !allowedProviders[pID] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for mID := range pInfo.Models {
|
||||||
|
data = append(data, OpenAIModel{
|
||||||
|
ID: mID,
|
||||||
|
Object: "model",
|
||||||
|
Created: 1700000000,
|
||||||
|
OwnedBy: pID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"object": "list",
|
||||||
|
"data": data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleChatCompletions serves POST /v1/chat/completions. It parses an
// OpenAI-style request, routes to a provider by model-name substring, converts
// the payload into the internal UnifiedRequest (including multimodal image
// parts), then either proxies a non-streaming call or relays SSE chunks.
// Every outcome (success, upstream error) is recorded via s.logRequest.
func (s *Server) handleChatCompletions(c *gin.Context) {
	startTime := time.Now()
	var req models.ChatCompletionRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Select provider based on model name. Substring matching: e.g. any model
	// containing "gemini" goes to the gemini provider; unmatched names fall
	// back to openai.
	providerName := "openai" // default
	if strings.Contains(req.Model, "gemini") {
		providerName = "gemini"
	} else if strings.Contains(req.Model, "deepseek") {
		providerName = "deepseek"
	} else if strings.Contains(req.Model, "kimi") || strings.Contains(req.Model, "moonshot") {
		providerName = "moonshot"
	} else if strings.Contains(req.Model, "grok") {
		providerName = "grok"
	}

	provider, ok := s.providers[providerName]
	if !ok {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
		return
	}

	// Convert ChatCompletionRequest to UnifiedRequest (messages filled below).
	unifiedReq := &models.UnifiedRequest{
		Model:            req.Model,
		Messages:         []models.UnifiedMessage{},
		Temperature:      req.Temperature,
		TopP:             req.TopP,
		TopK:             req.TopK,
		N:                req.N,
		MaxTokens:        req.MaxTokens,
		PresencePenalty:  req.PresencePenalty,
		FrequencyPenalty: req.FrequencyPenalty,
		Stream:           req.Stream != nil && *req.Stream,
		Tools:            req.Tools,
		ToolChoice:       req.ToolChoice,
	}

	// Handle Stop sequences: the OpenAI API allows either a JSON array of
	// strings or a single string; try the array form first.
	if req.Stop != nil {
		var stop []string
		if err := json.Unmarshal(req.Stop, &stop); err == nil {
			unifiedReq.Stop = stop
		} else {
			var singleStop string
			if err := json.Unmarshal(req.Stop, &singleStop); err == nil {
				unifiedReq.Stop = []string{singleStop}
			}
		}
	}

	// Convert messages into the unified multi-part representation.
	for _, msg := range req.Messages {
		unifiedMsg := models.UnifiedMessage{
			Role:             msg.Role,
			Content:          []models.UnifiedContentPart{},
			ReasoningContent: msg.ReasoningContent,
			ToolCalls:        msg.ToolCalls,
			Name:             msg.Name,
			ToolCallID:       msg.ToolCallID,
		}

		// Handle multimodal content: Content may be a plain string or an array
		// of {type: "text"|"image_url", ...} parts. Unrecognized part shapes
		// are silently skipped.
		if strContent, ok := msg.Content.(string); ok {
			unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
				Type: "text",
				Text: strContent,
			})
		} else if parts, ok := msg.Content.([]interface{}); ok {
			for _, part := range parts {
				if partMap, ok := part.(map[string]interface{}); ok {
					partType, _ := partMap["type"].(string)
					if partType == "text" {
						text, _ := partMap["text"].(string)
						unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
							Type: "text",
							Text: text,
						})
					} else if partType == "image_url" {
						if imgURLMap, ok := partMap["image_url"].(map[string]interface{}); ok {
							url, _ := imgURLMap["url"].(string)
							imageInput := &models.ImageInput{}
							// data: URLs carry inline base64 images; anything
							// else is treated as a remote URL. A malformed
							// data URL yields an empty ImageInput (best-effort).
							if strings.HasPrefix(url, "data:") {
								mime, data, err := utils.ParseDataURL(url)
								if err == nil {
									imageInput.Base64 = data
									imageInput.MimeType = mime
								}
							} else {
								imageInput.URL = url
							}
							unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
								Type:  "image",
								Image: imageInput,
							})
							unifiedReq.HasImages = true
						}
					}
				}
			}
		}

		unifiedReq.Messages = append(unifiedReq.Messages, unifiedMsg)
	}

	// Attribute the request to the authenticated client; "default" is used
	// when no auth context is present (e.g. auth middleware not applied).
	clientID := "default"
	if auth, ok := c.Get("auth"); ok {
		if authInfo, ok := auth.(models.AuthInfo); ok {
			unifiedReq.ClientID = authInfo.ClientID
			clientID = authInfo.ClientID
		}
	} else {
		unifiedReq.ClientID = clientID
	}

	// Streaming path: relay provider chunks to the client as SSE.
	if unifiedReq.Stream {
		ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq)
		if err != nil {
			s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
			return
		}

		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")

		// Providers typically send usage on the final chunk; remember the last
		// one seen so the request can be logged with token counts.
		var lastUsage *models.Usage
		c.Stream(func(w io.Writer) bool {
			chunk, ok := <-ch
			if !ok {
				// Channel closed: emit the SSE terminator and log the request.
				fmt.Fprintf(w, "data: [DONE]\n\n")
				s.logRequest(startTime, clientID, providerName, req.Model, lastUsage, nil, unifiedReq.HasImages)
				return false
			}
			if chunk.Usage != nil {
				lastUsage = chunk.Usage
			}
			data, err := json.Marshal(chunk)
			if err != nil {
				return false
			}
			fmt.Fprintf(w, "data: %s\n\n", data)
			return true
		})
		return
	}

	// Non-streaming path.
	resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq)
	if err != nil {
		s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage, nil, unifiedReq.HasImages)
	c.JSON(http.StatusOK, resp)
}
|
||||||
|
|
||||||
|
func (s *Server) logRequest(start time.Time, clientID, provider, model string, usage *models.Usage, err error, hasImages bool) {
|
||||||
|
entry := RequestLog{
|
||||||
|
Timestamp: start,
|
||||||
|
ClientID: clientID,
|
||||||
|
Provider: provider,
|
||||||
|
Model: model,
|
||||||
|
Status: "success",
|
||||||
|
DurationMS: time.Since(start).Milliseconds(),
|
||||||
|
HasImages: hasImages,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
entry.Status = "error"
|
||||||
|
entry.ErrorMessage = err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
if usage != nil {
|
||||||
|
entry.PromptTokens = usage.PromptTokens
|
||||||
|
entry.CompletionTokens = usage.CompletionTokens
|
||||||
|
entry.TotalTokens = usage.TotalTokens
|
||||||
|
if usage.ReasoningTokens != nil {
|
||||||
|
entry.ReasoningTokens = *usage.ReasoningTokens
|
||||||
|
}
|
||||||
|
if usage.CacheReadTokens != nil {
|
||||||
|
entry.CacheReadTokens = *usage.CacheReadTokens
|
||||||
|
}
|
||||||
|
if usage.CacheWriteTokens != nil {
|
||||||
|
entry.CacheWriteTokens = *usage.CacheWriteTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate cost using registry
|
||||||
|
entry.Cost = utils.CalculateCost(s.registry, model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.CacheWriteTokens)
|
||||||
|
fmt.Printf("[DEBUG] Request logged: model=%s, prompt=%d, completion=%d, reasoning=%d, cache_read=%d, cost=%f\n",
|
||||||
|
model, entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.CacheReadTokens, entry.Cost)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.LogRequest(entry)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run starts the background workers (websocket hub, request logger, daily
// registry refresh) and then blocks serving HTTP on the configured host:port.
// The error returned is the one from gin's Run (e.g. listen failure).
func (s *Server) Run() error {
	go s.hub.Run()
	s.logger.Start()

	// Start registry refresher: re-fetch models.dev metadata once a day.
	// NOTE(review): the ticker is never stopped (the goroutine lives for the
	// process lifetime), and s.registry is reassigned while handlers read it
	// concurrently — a data race; consider a mutex or atomic.Pointer.
	go func() {
		ticker := time.NewTicker(24 * time.Hour)
		for range ticker.C {
			newRegistry, err := utils.FetchRegistry()
			if err == nil {
				s.registry = newRegistry
			}
		}
	}()

	addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
	return s.router.Run(addr)
}
|
||||||
155
internal/server/sessions.go
Normal file
155
internal/server/sessions.go
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Session is the server-side record of a logged-in dashboard user. It is
// stored in SessionManager's in-memory map; the client holds only the signed
// token referencing SessionID.
type Session struct {
	Username    string    `json:"username"`
	DisplayName string    `json:"display_name"`
	Role        string    `json:"role"` // authorization role, e.g. admin
	CreatedAt   time.Time `json:"created_at"`
	ExpiresAt   time.Time `json:"expires_at"` // CreatedAt + SessionManager.ttl
	SessionID   string    `json:"session_id"` // random UUID, also embedded in the token
}
|
||||||
|
|
||||||
|
// SessionManager issues HMAC-SHA256-signed session tokens and keeps the
// authoritative session state in memory (sessions are lost on restart).
type SessionManager struct {
	sessions map[string]Session // live sessions keyed by SessionID
	mu       sync.RWMutex       // guards sessions
	secret   []byte             // HMAC signing key
	ttl      time.Duration      // session lifetime from creation
}
|
||||||
|
|
||||||
|
// sessionPayload is the JSON structure embedded (base64url) in a signed token:
// token = base64url(payload) + "." + base64url(HMAC-SHA256(payload)).
type sessionPayload struct {
	SessionID   string `json:"session_id"`
	Username    string `json:"username"`
	DisplayName string `json:"display_name"`
	Role        string `json:"role"`
	Exp         int64  `json:"exp"` // expiry as a Unix timestamp (seconds)
}
|
||||||
|
|
||||||
|
func NewSessionManager(secret []byte, ttl time.Duration) *SessionManager {
|
||||||
|
return &SessionManager{
|
||||||
|
sessions: make(map[string]Session),
|
||||||
|
secret: secret,
|
||||||
|
ttl: ttl,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SessionManager) CreateSession(username, displayName, role string) (string, error) {
|
||||||
|
sessionID := uuid.New().String()
|
||||||
|
now := time.Now()
|
||||||
|
expiresAt := now.Add(m.ttl)
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
m.sessions[sessionID] = Session{
|
||||||
|
Username: username,
|
||||||
|
DisplayName: displayName,
|
||||||
|
Role: role,
|
||||||
|
CreatedAt: now,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
SessionID: sessionID,
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
return m.createSignedToken(sessionID, username, displayName, role, expiresAt.Unix())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SessionManager) createSignedToken(sessionID, username, displayName, role string, exp int64) (string, error) {
|
||||||
|
payload := sessionPayload{
|
||||||
|
SessionID: sessionID,
|
||||||
|
Username: username,
|
||||||
|
DisplayName: displayName,
|
||||||
|
Role: role,
|
||||||
|
Exp: exp,
|
||||||
|
}
|
||||||
|
|
||||||
|
payloadJSON, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON)
|
||||||
|
|
||||||
|
h := hmac.New(sha256.New, m.secret)
|
||||||
|
h.Write(payloadJSON)
|
||||||
|
signature := h.Sum(nil)
|
||||||
|
signatureB64 := base64.RawURLEncoding.EncodeToString(signature)
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s.%s", payloadB64, signatureB64), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SessionManager) ValidateSession(token string) (*Session, string, error) {
|
||||||
|
parts := strings.Split(token, ".")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return nil, "", fmt.Errorf("invalid token format")
|
||||||
|
}
|
||||||
|
|
||||||
|
payloadB64 := parts[0]
|
||||||
|
signatureB64 := parts[1]
|
||||||
|
|
||||||
|
payloadJSON, err := base64.RawURLEncoding.DecodeString(payloadB64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := base64.RawURLEncoding.DecodeString(signatureB64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
h := hmac.New(sha256.New, m.secret)
|
||||||
|
h.Write(payloadJSON)
|
||||||
|
if !hmac.Equal(signature, h.Sum(nil)) {
|
||||||
|
return nil, "", fmt.Errorf("invalid signature")
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload sessionPayload
|
||||||
|
if err := json.Unmarshal(payloadJSON, &payload); err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().Unix() > payload.Exp {
|
||||||
|
return nil, "", fmt.Errorf("token expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.RLock()
|
||||||
|
session, ok := m.sessions[payload.SessionID]
|
||||||
|
m.mu.RUnlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil, "", fmt.Errorf("session not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &session, "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SessionManager) RevokeSession(token string) {
|
||||||
|
parts := strings.Split(token, ".")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload sessionPayload
|
||||||
|
if err := json.Unmarshal(payloadJSON, &payload); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
delete(m.sessions, payload.SessionID)
|
||||||
|
m.mu.Unlock()
|
||||||
|
}
|
||||||
107
internal/server/websocket.go
Normal file
107
internal/server/websocket.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
// upgrader converts dashboard HTTP requests into websocket connections.
// NOTE(review): CheckOrigin accepts every origin, which permits cross-site
// websocket connections (CSWSH); restrict to the dashboard origin before
// production use, as the inline comment already suggests.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	CheckOrigin: func(r *http.Request) bool {
		return true // In production, refine this
	},
}
|
||||||
|
|
||||||
|
// Hub fans broadcast messages out to all connected dashboard websockets.
// Registration, unregistration and broadcasting are serialized through the
// channels by the Run loop.
type Hub struct {
	clients     map[*websocket.Conn]bool // connected clients; guarded by mu
	broadcast   chan interface{}         // messages to fan out (unbuffered)
	register    chan *websocket.Conn     // new connections from handleWebSocket
	unregister  chan *websocket.Conn     // departing connections
	mu          sync.Mutex               // guards clients
	clientCount int32                    // atomic mirror of len(clients) for GetClientCount
}
|
||||||
|
|
||||||
|
func NewHub() *Hub {
|
||||||
|
return &Hub{
|
||||||
|
clients: make(map[*websocket.Conn]bool),
|
||||||
|
broadcast: make(chan interface{}),
|
||||||
|
register: make(chan *websocket.Conn),
|
||||||
|
unregister: make(chan *websocket.Conn),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run is the hub's event loop: it serializes registration, unregistration and
// broadcast over the hub's channels. It never returns and is intended to run
// on its own goroutine (started from Server.Run).
func (h *Hub) Run() {
	for {
		select {
		case client := <-h.register:
			h.mu.Lock()
			h.clients[client] = true
			atomic.AddInt32(&h.clientCount, 1)
			h.mu.Unlock()
			log.Println("WebSocket client registered")
		case client := <-h.unregister:
			h.mu.Lock()
			// Membership check guards against double-unregister (the conn may
			// already have been removed after a failed write below).
			if _, ok := h.clients[client]; ok {
				delete(h.clients, client)
				client.Close()
				atomic.AddInt32(&h.clientCount, -1)
			}
			h.mu.Unlock()
			log.Println("WebSocket client unregistered")
		case message := <-h.broadcast:
			h.mu.Lock()
			for client := range h.clients {
				err := client.WriteJSON(message)
				if err != nil {
					// Drop broken connections immediately; deleting while
					// ranging over a Go map is well-defined.
					log.Printf("WebSocket error: %v", err)
					client.Close()
					delete(h.clients, client)
					atomic.AddInt32(&h.clientCount, -1)
				}
			}
			h.mu.Unlock()
			// NOTE(review): handleWebSocket also calls conn.WriteJSON (pong
			// replies) on these same connections from another goroutine;
			// gorilla/websocket does not support concurrent writers on one
			// conn — confirm and serialize writes if needed.
		}
	}
}
|
||||||
|
|
||||||
|
func (h *Hub) GetClientCount() int {
|
||||||
|
return int(atomic.LoadInt32(&h.clientCount))
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleWebSocket upgrades the request to a websocket, registers the
// connection with the hub, sends a greeting, and then serves a ping/pong read
// loop until the client disconnects (any read error ends the loop).
func (s *Server) handleWebSocket(c *gin.Context) {
	conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		log.Printf("Failed to set websocket upgrade: %v", err)
		return
	}

	s.hub.register <- conn

	// Always hand the connection back to the hub on exit; the hub closes it.
	defer func() {
		s.hub.unregister <- conn
	}()

	// Initial message
	conn.WriteJSON(gin.H{
		"type":    "connected",
		"message": "Connected to GopherGate Dashboard",
	})

	for {
		var msg map[string]interface{}
		err := conn.ReadJSON(&msg)
		if err != nil {
			break
		}

		// NOTE(review): this write can race with Hub.Run broadcasting on the
		// same conn from another goroutine; gorilla/websocket forbids
		// concurrent writers — confirm and add a per-connection write lock or
		// writer goroutine if so.
		if msg["type"] == "ping" {
			conn.WriteJSON(gin.H{"type": "pong", "payload": gin.H{}})
		}
	}
}
|
||||||
71
internal/utils/crypto.go
Normal file
71
internal/utils/crypto.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Encrypt encrypts plain text using AES-GCM with the given 32-byte key.
|
||||||
|
func Encrypt(plainText string, key []byte) (string, error) {
|
||||||
|
if len(key) != 32 {
|
||||||
|
return "", fmt.Errorf("encryption key must be 32 bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce := make([]byte, gcm.NonceSize())
|
||||||
|
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// The nonce should be prepended to the ciphertext
|
||||||
|
cipherText := gcm.Seal(nonce, nonce, []byte(plainText), nil)
|
||||||
|
return base64.StdEncoding.EncodeToString(cipherText), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt decrypts base64-encoded cipher text using AES-GCM with the given 32-byte key.
|
||||||
|
func Decrypt(encodedCipherText string, key []byte) (string, error) {
|
||||||
|
if len(key) != 32 {
|
||||||
|
return "", fmt.Errorf("encryption key must be 32 bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
cipherText, err := base64.StdEncoding.DecodeString(encodedCipherText)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(key)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
nonceSize := gcm.NonceSize()
|
||||||
|
if len(cipherText) < nonceSize {
|
||||||
|
return "", fmt.Errorf("cipher text too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce, actualCipherText := cipherText[:nonceSize], cipherText[nonceSize:]
|
||||||
|
plainText, err := gcm.Open(nil, nonce, actualCipherText, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(plainText), nil
|
||||||
|
}
|
||||||
69
internal/utils/registry.go
Normal file
69
internal/utils/registry.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gophergate/internal/models"
|
||||||
|
"github.com/go-resty/resty/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ModelsDevURL is the models.dev endpoint serving provider/model metadata and
// pricing, consumed by FetchRegistry.
const ModelsDevURL = "https://models.dev/api.json"
|
||||||
|
|
||||||
|
func FetchRegistry() (*models.ModelRegistry, error) {
|
||||||
|
log.Printf("Fetching model registry from %s", ModelsDevURL)
|
||||||
|
|
||||||
|
client := resty.New().SetTimeout(10 * time.Second)
|
||||||
|
resp, err := client.R().Get(ModelsDevURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to fetch registry: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !resp.IsSuccess() {
|
||||||
|
return nil, fmt.Errorf("failed to fetch registry: HTTP %d", resp.StatusCode())
|
||||||
|
}
|
||||||
|
|
||||||
|
var providers map[string]models.ProviderInfo
|
||||||
|
if err := json.Unmarshal(resp.Body(), &providers); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal registry: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("Successfully loaded model registry")
|
||||||
|
return &models.ModelRegistry{Providers: providers}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CalculateCost(registry *models.ModelRegistry, modelID string, promptTokens, completionTokens, reasoningTokens, cacheRead, cacheWrite uint32) float64 {
|
||||||
|
meta := registry.FindModel(modelID)
|
||||||
|
if meta == nil || meta.Cost == nil {
|
||||||
|
log.Printf("[DEBUG] CalculateCost: model %s not found or has no cost metadata", modelID)
|
||||||
|
return 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
// promptTokens is usually the TOTAL prompt size.
|
||||||
|
// We subtract cacheRead from it to get the uncached part.
|
||||||
|
uncachedTokens := promptTokens
|
||||||
|
if cacheRead > 0 {
|
||||||
|
if cacheRead > promptTokens {
|
||||||
|
uncachedTokens = 0
|
||||||
|
} else {
|
||||||
|
uncachedTokens = promptTokens - cacheRead
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cost := (float64(uncachedTokens) * meta.Cost.Input / 1000000.0) +
|
||||||
|
(float64(completionTokens) * meta.Cost.Output / 1000000.0)
|
||||||
|
|
||||||
|
if meta.Cost.CacheRead != nil {
|
||||||
|
cost += float64(cacheRead) * (*meta.Cost.CacheRead) / 1000000.0
|
||||||
|
}
|
||||||
|
if meta.Cost.CacheWrite != nil {
|
||||||
|
cost += float64(cacheWrite) * (*meta.Cost.CacheWrite) / 1000000.0
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("[DEBUG] CalculateCost: model=%s, uncached=%d, completion=%d, reasoning=%d, cache_read=%d, cache_write=%d, cost=%f (input_rate=%f, output_rate=%f)",
|
||||||
|
modelID, uncachedTokens, completionTokens, reasoningTokens, cacheRead, cacheWrite, cost, meta.Cost.Input, meta.Cost.Output)
|
||||||
|
|
||||||
|
return cost
|
||||||
|
}
|
||||||
19
internal/utils/utils.go
Normal file
19
internal/utils/utils.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ParseDataURL(dataURL string) (string, string, error) {
|
||||||
|
if !strings.HasPrefix(dataURL, "data:") {
|
||||||
|
return "", "", fmt.Errorf("not a data URL")
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(dataURL[5:], ";base64,")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return "", "", fmt.Errorf("invalid data URL format")
|
||||||
|
}
|
||||||
|
|
||||||
|
return parts[0], parts[1], nil
|
||||||
|
}
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
-- Migration: add billing_mode to provider_configs
|
|
||||||
-- Adds a billing_mode TEXT column with default 'prepaid'
|
|
||||||
-- After applying, set Gemini to postpaid with:
|
|
||||||
-- UPDATE provider_configs SET billing_mode = 'postpaid' WHERE id = 'gemini';
|
|
||||||
|
|
||||||
BEGIN TRANSACTION;
|
|
||||||
|
|
||||||
ALTER TABLE provider_configs ADD COLUMN billing_mode TEXT DEFAULT 'prepaid';
|
|
||||||
|
|
||||||
COMMIT;
|
|
||||||
|
|
||||||
-- NOTE: If you use a production SQLite file, run the following to set Gemini to postpaid:
|
|
||||||
-- sqlite3 /path/to/db.sqlite "UPDATE provider_configs SET billing_mode='postpaid' WHERE id='gemini';"
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
-- Migration: add composite indexes for query performance
|
|
||||||
-- Adds three composite indexes:
|
|
||||||
-- 1. idx_llm_requests_client_timestamp on llm_requests(client_id, timestamp)
|
|
||||||
-- 2. idx_llm_requests_provider_timestamp on llm_requests(provider, timestamp)
|
|
||||||
-- 3. idx_model_configs_provider_id on model_configs(provider_id)
|
|
||||||
|
|
||||||
BEGIN TRANSACTION;
|
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id);
|
|
||||||
|
|
||||||
COMMIT;
|
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
max_width = 120
|
|
||||||
use_field_init_shorthand = true
|
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
use axum::{extract::FromRequestParts, http::request::Parts};
|
|
||||||
|
|
||||||
use crate::errors::AppError;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct AuthInfo {
|
|
||||||
pub token: String,
|
|
||||||
pub client_id: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct AuthenticatedClient {
|
|
||||||
pub info: AuthInfo,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> FromRequestParts<S> for AuthenticatedClient
|
|
||||||
where
|
|
||||||
S: Send + Sync,
|
|
||||||
{
|
|
||||||
type Rejection = AppError;
|
|
||||||
|
|
||||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
|
||||||
// Retrieve AuthInfo from request extensions, where it was placed by rate_limit_middleware
|
|
||||||
let info = parts
|
|
||||||
.extensions
|
|
||||||
.get::<AuthInfo>()
|
|
||||||
.cloned()
|
|
||||||
.ok_or_else(|| AppError::AuthError("Authentication info not found in request".to_string()))?;
|
|
||||||
|
|
||||||
Ok(AuthenticatedClient { info })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::ops::Deref for AuthenticatedClient {
|
|
||||||
type Target = AuthInfo;
|
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
|
||||||
&self.info
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn validate_token(token: &str, valid_tokens: &[String]) -> bool {
|
|
||||||
// Simple validation against list of tokens
|
|
||||||
// In production, use proper token validation (JWT, database lookup, etc.)
|
|
||||||
valid_tokens.contains(&token.to_string())
|
|
||||||
}
|
|
||||||
@@ -1,304 +0,0 @@
|
|||||||
//! Client management for LLM proxy
|
|
||||||
//!
|
|
||||||
//! This module handles:
|
|
||||||
//! 1. Client registration and management
|
|
||||||
//! 2. Client usage tracking
|
|
||||||
//! 3. Client rate limit configuration
|
|
||||||
|
|
||||||
use anyhow::Result;
|
|
||||||
use chrono::{DateTime, Utc};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use sqlx::{Row, SqlitePool};
|
|
||||||
use tracing::{info, warn};
|
|
||||||
|
|
||||||
/// Client information
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct Client {
|
|
||||||
pub id: i64,
|
|
||||||
pub client_id: String,
|
|
||||||
pub name: String,
|
|
||||||
pub description: String,
|
|
||||||
pub created_at: DateTime<Utc>,
|
|
||||||
pub updated_at: DateTime<Utc>,
|
|
||||||
pub is_active: bool,
|
|
||||||
pub rate_limit_per_minute: i64,
|
|
||||||
pub total_requests: i64,
|
|
||||||
pub total_tokens: i64,
|
|
||||||
pub total_cost: f64,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Client creation request
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct CreateClientRequest {
|
|
||||||
pub client_id: String,
|
|
||||||
pub name: String,
|
|
||||||
pub description: String,
|
|
||||||
pub rate_limit_per_minute: Option<i64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Client update request
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct UpdateClientRequest {
|
|
||||||
pub name: Option<String>,
|
|
||||||
pub description: Option<String>,
|
|
||||||
pub is_active: Option<bool>,
|
|
||||||
pub rate_limit_per_minute: Option<i64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Client manager for database operations
|
|
||||||
pub struct ClientManager {
|
|
||||||
db_pool: SqlitePool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ClientManager {
|
|
||||||
pub fn new(db_pool: SqlitePool) -> Self {
|
|
||||||
Self { db_pool }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a new client
|
|
||||||
pub async fn create_client(&self, request: CreateClientRequest) -> Result<Client> {
|
|
||||||
let rate_limit = request.rate_limit_per_minute.unwrap_or(60);
|
|
||||||
|
|
||||||
// First insert the client
|
|
||||||
sqlx::query(
|
|
||||||
r#"
|
|
||||||
INSERT INTO clients (client_id, name, description, rate_limit_per_minute)
|
|
||||||
VALUES (?, ?, ?, ?)
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(&request.client_id)
|
|
||||||
.bind(&request.name)
|
|
||||||
.bind(&request.description)
|
|
||||||
.bind(rate_limit)
|
|
||||||
.execute(&self.db_pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Then fetch the created client
|
|
||||||
let client = self
|
|
||||||
.get_client(&request.client_id)
|
|
||||||
.await?
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Failed to retrieve created client"))?;
|
|
||||||
|
|
||||||
info!("Created client: {} ({})", client.name, client.client_id);
|
|
||||||
Ok(client)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a client by ID
|
|
||||||
pub async fn get_client(&self, client_id: &str) -> Result<Option<Client>> {
|
|
||||||
let row = sqlx::query(
|
|
||||||
r#"
|
|
||||||
SELECT
|
|
||||||
id, client_id, name, description,
|
|
||||||
created_at, updated_at, is_active,
|
|
||||||
rate_limit_per_minute, total_requests, total_tokens, total_cost
|
|
||||||
FROM clients
|
|
||||||
WHERE client_id = ?
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(client_id)
|
|
||||||
.fetch_optional(&self.db_pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if let Some(row) = row {
|
|
||||||
let client = Client {
|
|
||||||
id: row.get("id"),
|
|
||||||
client_id: row.get("client_id"),
|
|
||||||
name: row.get("name"),
|
|
||||||
description: row.get("description"),
|
|
||||||
created_at: row.get("created_at"),
|
|
||||||
updated_at: row.get("updated_at"),
|
|
||||||
is_active: row.get("is_active"),
|
|
||||||
rate_limit_per_minute: row.get("rate_limit_per_minute"),
|
|
||||||
total_requests: row.get("total_requests"),
|
|
||||||
total_tokens: row.get("total_tokens"),
|
|
||||||
total_cost: row.get("total_cost"),
|
|
||||||
};
|
|
||||||
Ok(Some(client))
|
|
||||||
} else {
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Update a client
|
|
||||||
pub async fn update_client(&self, client_id: &str, request: UpdateClientRequest) -> Result<Option<Client>> {
|
|
||||||
// First, get the current client to check if it exists
|
|
||||||
let current_client = self.get_client(client_id).await?;
|
|
||||||
if current_client.is_none() {
|
|
||||||
return Ok(None);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build update query dynamically based on provided fields
|
|
||||||
let mut query_builder = sqlx::QueryBuilder::new("UPDATE clients SET ");
|
|
||||||
let mut has_updates = false;
|
|
||||||
|
|
||||||
if let Some(name) = &request.name {
|
|
||||||
query_builder.push("name = ");
|
|
||||||
query_builder.push_bind(name);
|
|
||||||
has_updates = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(description) = &request.description {
|
|
||||||
if has_updates {
|
|
||||||
query_builder.push(", ");
|
|
||||||
}
|
|
||||||
query_builder.push("description = ");
|
|
||||||
query_builder.push_bind(description);
|
|
||||||
has_updates = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(is_active) = request.is_active {
|
|
||||||
if has_updates {
|
|
||||||
query_builder.push(", ");
|
|
||||||
}
|
|
||||||
query_builder.push("is_active = ");
|
|
||||||
query_builder.push_bind(is_active);
|
|
||||||
has_updates = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(rate_limit) = request.rate_limit_per_minute {
|
|
||||||
if has_updates {
|
|
||||||
query_builder.push(", ");
|
|
||||||
}
|
|
||||||
query_builder.push("rate_limit_per_minute = ");
|
|
||||||
query_builder.push_bind(rate_limit);
|
|
||||||
has_updates = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Always update the updated_at timestamp
|
|
||||||
if has_updates {
|
|
||||||
query_builder.push(", ");
|
|
||||||
}
|
|
||||||
query_builder.push("updated_at = CURRENT_TIMESTAMP");
|
|
||||||
|
|
||||||
if !has_updates {
|
|
||||||
// No updates to make
|
|
||||||
return self.get_client(client_id).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
query_builder.push(" WHERE client_id = ");
|
|
||||||
query_builder.push_bind(client_id);
|
|
||||||
|
|
||||||
let query = query_builder.build();
|
|
||||||
query.execute(&self.db_pool).await?;
|
|
||||||
|
|
||||||
// Fetch the updated client
|
|
||||||
let updated_client = self.get_client(client_id).await?;
|
|
||||||
|
|
||||||
if updated_client.is_some() {
|
|
||||||
info!("Updated client: {}", client_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(updated_client)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// List all clients
|
|
||||||
pub async fn list_clients(&self, limit: Option<i64>, offset: Option<i64>) -> Result<Vec<Client>> {
|
|
||||||
let limit = limit.unwrap_or(100);
|
|
||||||
let offset = offset.unwrap_or(0);
|
|
||||||
|
|
||||||
let rows = sqlx::query(
|
|
||||||
r#"
|
|
||||||
SELECT
|
|
||||||
id, client_id, name, description,
|
|
||||||
created_at, updated_at, is_active,
|
|
||||||
rate_limit_per_minute, total_requests, total_tokens, total_cost
|
|
||||||
FROM clients
|
|
||||||
ORDER BY created_at DESC
|
|
||||||
LIMIT ? OFFSET ?
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(limit)
|
|
||||||
.bind(offset)
|
|
||||||
.fetch_all(&self.db_pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let mut clients = Vec::new();
|
|
||||||
for row in rows {
|
|
||||||
let client = Client {
|
|
||||||
id: row.get("id"),
|
|
||||||
client_id: row.get("client_id"),
|
|
||||||
name: row.get("name"),
|
|
||||||
description: row.get("description"),
|
|
||||||
created_at: row.get("created_at"),
|
|
||||||
updated_at: row.get("updated_at"),
|
|
||||||
is_active: row.get("is_active"),
|
|
||||||
rate_limit_per_minute: row.get("rate_limit_per_minute"),
|
|
||||||
total_requests: row.get("total_requests"),
|
|
||||||
total_tokens: row.get("total_tokens"),
|
|
||||||
total_cost: row.get("total_cost"),
|
|
||||||
};
|
|
||||||
clients.push(client);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(clients)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Delete a client
|
|
||||||
pub async fn delete_client(&self, client_id: &str) -> Result<bool> {
|
|
||||||
let result = sqlx::query("DELETE FROM clients WHERE client_id = ?")
|
|
||||||
.bind(client_id)
|
|
||||||
.execute(&self.db_pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let deleted = result.rows_affected() > 0;
|
|
||||||
|
|
||||||
if deleted {
|
|
||||||
info!("Deleted client: {}", client_id);
|
|
||||||
} else {
|
|
||||||
warn!("Client not found for deletion: {}", client_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(deleted)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Update client usage statistics after a request
|
|
||||||
pub async fn update_client_usage(&self, client_id: &str, tokens: i64, cost: f64) -> Result<()> {
|
|
||||||
sqlx::query(
|
|
||||||
r#"
|
|
||||||
UPDATE clients
|
|
||||||
SET
|
|
||||||
total_requests = total_requests + 1,
|
|
||||||
total_tokens = total_tokens + ?,
|
|
||||||
total_cost = total_cost + ?,
|
|
||||||
updated_at = CURRENT_TIMESTAMP
|
|
||||||
WHERE client_id = ?
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(tokens)
|
|
||||||
.bind(cost)
|
|
||||||
.bind(client_id)
|
|
||||||
.execute(&self.db_pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get client usage statistics
|
|
||||||
pub async fn get_client_usage(&self, client_id: &str) -> Result<Option<(i64, i64, f64)>> {
|
|
||||||
let row = sqlx::query(
|
|
||||||
r#"
|
|
||||||
SELECT total_requests, total_tokens, total_cost
|
|
||||||
FROM clients
|
|
||||||
WHERE client_id = ?
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(client_id)
|
|
||||||
.fetch_optional(&self.db_pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if let Some(row) = row {
|
|
||||||
let total_requests: i64 = row.get("total_requests");
|
|
||||||
let total_tokens: i64 = row.get("total_tokens");
|
|
||||||
let total_cost: f64 = row.get("total_cost");
|
|
||||||
Ok(Some((total_requests, total_tokens, total_cost)))
|
|
||||||
} else {
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if a client exists and is active
|
|
||||||
pub async fn validate_client(&self, client_id: &str) -> Result<bool> {
|
|
||||||
let client = self.get_client(client_id).await?;
|
|
||||||
Ok(client.map(|c| c.is_active).unwrap_or(false))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,260 +0,0 @@
|
|||||||
use anyhow::Result;
|
|
||||||
use base64::{Engine as _};
|
|
||||||
use config::{Config, File, FileFormat};
|
|
||||||
use hex;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::path::PathBuf;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ServerConfig {
|
|
||||||
pub port: u16,
|
|
||||||
pub host: String,
|
|
||||||
#[serde(deserialize_with = "deserialize_vec_or_string")]
|
|
||||||
pub auth_tokens: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct DatabaseConfig {
|
|
||||||
pub path: PathBuf,
|
|
||||||
pub max_connections: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ProviderConfig {
|
|
||||||
pub openai: OpenAIConfig,
|
|
||||||
pub gemini: GeminiConfig,
|
|
||||||
pub deepseek: DeepSeekConfig,
|
|
||||||
pub grok: GrokConfig,
|
|
||||||
pub ollama: OllamaConfig,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct OpenAIConfig {
|
|
||||||
pub api_key_env: String,
|
|
||||||
pub base_url: String,
|
|
||||||
pub default_model: String,
|
|
||||||
pub enabled: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct GeminiConfig {
|
|
||||||
pub api_key_env: String,
|
|
||||||
pub base_url: String,
|
|
||||||
pub default_model: String,
|
|
||||||
pub enabled: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct DeepSeekConfig {
|
|
||||||
pub api_key_env: String,
|
|
||||||
pub base_url: String,
|
|
||||||
pub default_model: String,
|
|
||||||
pub enabled: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct GrokConfig {
|
|
||||||
pub api_key_env: String,
|
|
||||||
pub base_url: String,
|
|
||||||
pub default_model: String,
|
|
||||||
pub enabled: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct OllamaConfig {
|
|
||||||
pub base_url: String,
|
|
||||||
pub enabled: bool,
|
|
||||||
#[serde(deserialize_with = "deserialize_vec_or_string")]
|
|
||||||
pub models: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ModelMappingConfig {
|
|
||||||
pub patterns: Vec<(String, String)>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct PricingConfig {
|
|
||||||
pub openai: Vec<ModelPricing>,
|
|
||||||
pub gemini: Vec<ModelPricing>,
|
|
||||||
pub deepseek: Vec<ModelPricing>,
|
|
||||||
pub grok: Vec<ModelPricing>,
|
|
||||||
pub ollama: Vec<ModelPricing>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ModelPricing {
|
|
||||||
pub model: String,
|
|
||||||
pub prompt_tokens_per_million: f64,
|
|
||||||
pub completion_tokens_per_million: f64,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct AppConfig {
|
|
||||||
pub server: ServerConfig,
|
|
||||||
pub database: DatabaseConfig,
|
|
||||||
pub providers: ProviderConfig,
|
|
||||||
pub model_mapping: ModelMappingConfig,
|
|
||||||
pub pricing: PricingConfig,
|
|
||||||
pub config_path: Option<PathBuf>,
|
|
||||||
pub encryption_key: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AppConfig {
|
|
||||||
pub async fn load() -> Result<Arc<Self>> {
|
|
||||||
Self::load_from_path(None).await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Load configuration from a specific path (for testing)
|
|
||||||
pub async fn load_from_path(config_path: Option<PathBuf>) -> Result<Arc<Self>> {
|
|
||||||
// Load configuration from multiple sources
|
|
||||||
let mut config_builder = Config::builder();
|
|
||||||
|
|
||||||
// Default configuration
|
|
||||||
config_builder = config_builder
|
|
||||||
.set_default("server.port", 8080)?
|
|
||||||
.set_default("server.host", "0.0.0.0")?
|
|
||||||
.set_default("server.auth_tokens", Vec::<String>::new())?
|
|
||||||
.set_default("database.path", "./data/llm_proxy.db")?
|
|
||||||
.set_default("database.max_connections", 10)?
|
|
||||||
.set_default("providers.openai.api_key_env", "OPENAI_API_KEY")?
|
|
||||||
.set_default("providers.openai.base_url", "https://api.openai.com/v1")?
|
|
||||||
.set_default("providers.openai.default_model", "gpt-4o")?
|
|
||||||
.set_default("providers.openai.enabled", true)?
|
|
||||||
.set_default("providers.gemini.api_key_env", "GEMINI_API_KEY")?
|
|
||||||
.set_default(
|
|
||||||
"providers.gemini.base_url",
|
|
||||||
"https://generativelanguage.googleapis.com/v1",
|
|
||||||
)?
|
|
||||||
.set_default("providers.gemini.default_model", "gemini-2.0-flash")?
|
|
||||||
.set_default("providers.gemini.enabled", true)?
|
|
||||||
.set_default("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")?
|
|
||||||
.set_default("providers.deepseek.base_url", "https://api.deepseek.com")?
|
|
||||||
.set_default("providers.deepseek.default_model", "deepseek-reasoner")?
|
|
||||||
.set_default("providers.deepseek.enabled", true)?
|
|
||||||
.set_default("providers.grok.api_key_env", "GROK_API_KEY")?
|
|
||||||
.set_default("providers.grok.base_url", "https://api.x.ai/v1")?
|
|
||||||
.set_default("providers.grok.default_model", "grok-beta")?
|
|
||||||
.set_default("providers.grok.enabled", true)?
|
|
||||||
.set_default("providers.ollama.base_url", "http://localhost:11434/v1")?
|
|
||||||
.set_default("providers.ollama.enabled", false)?
|
|
||||||
.set_default("providers.ollama.models", Vec::<String>::new())?
|
|
||||||
.set_default("encryption_key", "")?;
|
|
||||||
|
|
||||||
// Load from config file if exists
|
|
||||||
// Priority: explicit path arg > LLM_PROXY__CONFIG_PATH env var > ./config.toml
|
|
||||||
let config_path = config_path
|
|
||||||
.or_else(|| std::env::var("LLM_PROXY__CONFIG_PATH").ok().map(PathBuf::from))
|
|
||||||
.unwrap_or_else(|| {
|
|
||||||
std::env::current_dir()
|
|
||||||
.unwrap_or_else(|_| PathBuf::from("."))
|
|
||||||
.join("config.toml")
|
|
||||||
});
|
|
||||||
if config_path.exists() {
|
|
||||||
config_builder = config_builder.add_source(File::from(config_path.clone()).format(FileFormat::Toml));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load from .env file
|
|
||||||
dotenvy::dotenv().ok();
|
|
||||||
|
|
||||||
// Load from environment variables (with prefix "LLM_PROXY_")
|
|
||||||
config_builder = config_builder.add_source(
|
|
||||||
config::Environment::with_prefix("LLM_PROXY")
|
|
||||||
.separator("__")
|
|
||||||
.try_parsing(true),
|
|
||||||
);
|
|
||||||
|
|
||||||
let config = config_builder.build()?;
|
|
||||||
|
|
||||||
// Deserialize configuration
|
|
||||||
let server: ServerConfig = config.get("server")?;
|
|
||||||
let database: DatabaseConfig = config.get("database")?;
|
|
||||||
let providers: ProviderConfig = config.get("providers")?;
|
|
||||||
let encryption_key: String = config.get("encryption_key")?;
|
|
||||||
|
|
||||||
// Validate encryption key length (must be 32 bytes after hex or base64 decoding)
|
|
||||||
if encryption_key.is_empty() {
|
|
||||||
anyhow::bail!("Encryption key is required (LLM_PROXY__ENCRYPTION_KEY environment variable)");
|
|
||||||
}
|
|
||||||
// Try hex decode first, then base64
|
|
||||||
let key_bytes = hex::decode(&encryption_key)
|
|
||||||
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(&encryption_key))
|
|
||||||
.map_err(|e| anyhow::anyhow!("Encryption key must be hex or base64 encoded: {}", e))?;
|
|
||||||
if key_bytes.len() != 32 {
|
|
||||||
anyhow::bail!("Encryption key must be 32 bytes (256 bits), got {} bytes", key_bytes.len());
|
|
||||||
}
|
|
||||||
|
|
||||||
// For now, use empty model mapping and pricing (will be populated later)
|
|
||||||
let model_mapping = ModelMappingConfig { patterns: vec![] };
|
|
||||||
let pricing = PricingConfig {
|
|
||||||
openai: vec![],
|
|
||||||
gemini: vec![],
|
|
||||||
deepseek: vec![],
|
|
||||||
grok: vec![],
|
|
||||||
ollama: vec![],
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Arc::new(AppConfig {
|
|
||||||
server,
|
|
||||||
database,
|
|
||||||
providers,
|
|
||||||
model_mapping,
|
|
||||||
pricing,
|
|
||||||
config_path: Some(config_path),
|
|
||||||
encryption_key,
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_api_key(&self, provider: &str) -> Result<String> {
|
|
||||||
let env_var = match provider {
|
|
||||||
"openai" => &self.providers.openai.api_key_env,
|
|
||||||
"gemini" => &self.providers.gemini.api_key_env,
|
|
||||||
"deepseek" => &self.providers.deepseek.api_key_env,
|
|
||||||
"grok" => &self.providers.grok.api_key_env,
|
|
||||||
_ => return Err(anyhow::anyhow!("Unknown provider: {}", provider)),
|
|
||||||
};
|
|
||||||
|
|
||||||
std::env::var(env_var).map_err(|_| anyhow::anyhow!("Environment variable {} not set for {}", env_var, provider))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Helper function to deserialize a Vec<String> from either a sequence or a comma-separated string
|
|
||||||
fn deserialize_vec_or_string<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
|
|
||||||
where
|
|
||||||
D: serde::Deserializer<'de>,
|
|
||||||
{
|
|
||||||
struct VecOrString;
|
|
||||||
|
|
||||||
impl<'de> serde::de::Visitor<'de> for VecOrString {
|
|
||||||
type Value = Vec<String>;
|
|
||||||
|
|
||||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
|
||||||
formatter.write_str("a sequence or a comma-separated string")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
|
|
||||||
where
|
|
||||||
E: serde::de::Error,
|
|
||||||
{
|
|
||||||
Ok(value
|
|
||||||
.split(',')
|
|
||||||
.map(|s| s.trim().to_string())
|
|
||||||
.filter(|s| !s.is_empty())
|
|
||||||
.collect())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
|
|
||||||
where
|
|
||||||
S: serde::de::SeqAccess<'de>,
|
|
||||||
{
|
|
||||||
let mut vec = Vec::new();
|
|
||||||
while let Some(element) = seq.next_element()? {
|
|
||||||
vec.push(element);
|
|
||||||
}
|
|
||||||
Ok(vec)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
deserializer.deserialize_any(VecOrString)
|
|
||||||
}
|
|
||||||
@@ -1,229 +0,0 @@
|
|||||||
use axum::{extract::State, http::{HeaderMap, HeaderValue}, response::{Json, IntoResponse}};
|
|
||||||
use bcrypt;
|
|
||||||
use serde::Deserialize;
|
|
||||||
use sqlx::Row;
|
|
||||||
use tracing::warn;
|
|
||||||
|
|
||||||
use super::{ApiResponse, DashboardState};
|
|
||||||
|
|
||||||
// Authentication handlers
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub(super) struct LoginRequest {
|
|
||||||
pub(super) username: String,
|
|
||||||
pub(super) password: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) async fn handle_login(
|
|
||||||
State(state): State<DashboardState>,
|
|
||||||
Json(payload): Json<LoginRequest>,
|
|
||||||
) -> Json<ApiResponse<serde_json::Value>> {
|
|
||||||
let pool = &state.app_state.db_pool;
|
|
||||||
|
|
||||||
let user_result = sqlx::query(
|
|
||||||
"SELECT username, password_hash, display_name, role, must_change_password FROM users WHERE username = ?",
|
|
||||||
)
|
|
||||||
.bind(&payload.username)
|
|
||||||
.fetch_optional(pool)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match user_result {
|
|
||||||
Ok(Some(row)) => {
|
|
||||||
let hash = row.get::<String, _>("password_hash");
|
|
||||||
if bcrypt::verify(&payload.password, &hash).unwrap_or(false) {
|
|
||||||
let username = row.get::<String, _>("username");
|
|
||||||
let role = row.get::<String, _>("role");
|
|
||||||
let display_name = row
|
|
||||||
.get::<Option<String>, _>("display_name")
|
|
||||||
.unwrap_or_else(|| username.clone());
|
|
||||||
let must_change_password = row.get::<bool, _>("must_change_password");
|
|
||||||
let token = state
|
|
||||||
.session_manager
|
|
||||||
.create_session(username.clone(), role.clone())
|
|
||||||
.await;
|
|
||||||
Json(ApiResponse::success(serde_json::json!({
|
|
||||||
"token": token,
|
|
||||||
"must_change_password": must_change_password,
|
|
||||||
"user": {
|
|
||||||
"username": username,
|
|
||||||
"name": display_name,
|
|
||||||
"role": role
|
|
||||||
}
|
|
||||||
})))
|
|
||||||
} else {
|
|
||||||
Json(ApiResponse::error("Invalid username or password".to_string()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(None) => Json(ApiResponse::error("Invalid username or password".to_string())),
|
|
||||||
Err(e) => {
|
|
||||||
warn!("Database error during login: {}", e);
|
|
||||||
Json(ApiResponse::error("Login failed due to system error".to_string()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) async fn handle_auth_status(
|
|
||||||
State(state): State<DashboardState>,
|
|
||||||
headers: axum::http::HeaderMap,
|
|
||||||
) -> impl IntoResponse {
|
|
||||||
let token = headers
|
|
||||||
.get("Authorization")
|
|
||||||
.and_then(|v| v.to_str().ok())
|
|
||||||
.and_then(|v| v.strip_prefix("Bearer "));
|
|
||||||
|
|
||||||
if let Some(token) = token
|
|
||||||
&& let Some((session, new_token)) = state.session_manager.validate_session_with_refresh(token).await
|
|
||||||
{
|
|
||||||
// Look up display_name from DB
|
|
||||||
let display_name = sqlx::query_scalar::<_, Option<String>>(
|
|
||||||
"SELECT display_name FROM users WHERE username = ?",
|
|
||||||
)
|
|
||||||
.bind(&session.username)
|
|
||||||
.fetch_optional(&state.app_state.db_pool)
|
|
||||||
.await
|
|
||||||
.ok()
|
|
||||||
.flatten()
|
|
||||||
.flatten()
|
|
||||||
.unwrap_or_else(|| session.username.clone());
|
|
||||||
|
|
||||||
let mut headers = HeaderMap::new();
|
|
||||||
if let Some(refreshed_token) = new_token {
|
|
||||||
if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) {
|
|
||||||
headers.insert("X-Refreshed-Token", header_value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return (headers, Json(ApiResponse::success(serde_json::json!({
|
|
||||||
"authenticated": true,
|
|
||||||
"user": {
|
|
||||||
"username": session.username,
|
|
||||||
"name": display_name,
|
|
||||||
"role": session.role
|
|
||||||
}
|
|
||||||
}))));
|
|
||||||
}
|
|
||||||
|
|
||||||
(HeaderMap::new(), Json(ApiResponse::error("Not authenticated".to_string())))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub(super) struct ChangePasswordRequest {
|
|
||||||
pub(super) current_password: String,
|
|
||||||
pub(super) new_password: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// POST /api/auth/change-password — let the authenticated user rotate their
/// own password.
///
/// Flow: resolve the session from the `Authorization: Bearer` header
/// (propagating a refreshed token via `X-Refreshed-Token` when the session
/// manager issues one), verify `current_password` against the stored bcrypt
/// hash, then store a new bcrypt hash (cost 12) and clear
/// `must_change_password`. Every exit path returns the same
/// `(HeaderMap, Json<ApiResponse>)` shape.
pub(super) async fn handle_change_password(
    State(state): State<DashboardState>,
    headers: axum::http::HeaderMap,
    Json(payload): Json<ChangePasswordRequest>,
) -> impl IntoResponse {
    let pool = &state.app_state.db_pool;

    // Extract the authenticated user from the session token
    let token = headers
        .get("Authorization")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.strip_prefix("Bearer "));

    // A missing/invalid token collapses to (None, None); the error is only
    // surfaced after the refresh header has been prepared below.
    let (session, new_token) = match token {
        Some(t) => match state.session_manager.validate_session_with_refresh(t).await {
            Some((session, new_token)) => (Some(session), new_token),
            None => (None, None),
        },
        None => (None, None),
    };

    // Propagate a refreshed session token even on early-error exits, so the
    // client never loses a rotation performed by the session manager.
    let mut response_headers = HeaderMap::new();
    if let Some(refreshed_token) = new_token {
        if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) {
            response_headers.insert("X-Refreshed-Token", header_value);
        }
    }

    let username = match session {
        Some(s) => s.username,
        None => return (response_headers, Json(ApiResponse::error("Not authenticated".to_string()))),
    };

    let user_result = sqlx::query("SELECT password_hash FROM users WHERE username = ?")
        .bind(&username)
        .fetch_one(pool)
        .await;

    match user_result {
        Ok(row) => {
            let hash = row.get::<String, _>("password_hash");
            // unwrap_or(false): a malformed stored hash is treated as a
            // failed verification rather than a server error.
            if bcrypt::verify(&payload.current_password, &hash).unwrap_or(false) {
                // Cost factor 12 — matches the hash being replaced.
                let new_hash = match bcrypt::hash(&payload.new_password, 12) {
                    Ok(h) => h,
                    Err(_) => return (response_headers, Json(ApiResponse::error("Failed to hash new password".to_string()))),
                };

                // Also clears the forced-change flag set for first logins.
                let update_result = sqlx::query(
                    "UPDATE users SET password_hash = ?, must_change_password = FALSE WHERE username = ?",
                )
                .bind(new_hash)
                .bind(&username)
                .execute(pool)
                .await;

                match update_result {
                    Ok(_) => (response_headers, Json(ApiResponse::success(
                        serde_json::json!({ "message": "Password updated successfully" }),
                    ))),
                    Err(e) => (response_headers, Json(ApiResponse::error(format!("Failed to update database: {}", e)))),
                }
            } else {
                (response_headers, Json(ApiResponse::error("Current password incorrect".to_string())))
            }
        }
        // NOTE(review): this message (and the distinct "incorrect" one above)
        // distinguishes unknown accounts from bad passwords — consider a
        // uniform error if account enumeration is a concern. TODO confirm.
        Err(e) => (response_headers, Json(ApiResponse::error(format!("User not found: {}", e)))),
    }
}
|
|
||||||
|
|
||||||
pub(super) async fn handle_logout(
|
|
||||||
State(state): State<DashboardState>,
|
|
||||||
headers: axum::http::HeaderMap,
|
|
||||||
) -> Json<ApiResponse<serde_json::Value>> {
|
|
||||||
let token = headers
|
|
||||||
.get("Authorization")
|
|
||||||
.and_then(|v| v.to_str().ok())
|
|
||||||
.and_then(|v| v.strip_prefix("Bearer "));
|
|
||||||
|
|
||||||
if let Some(token) = token {
|
|
||||||
state.session_manager.revoke_session(token).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
Json(ApiResponse::success(serde_json::json!({ "message": "Logged out" })))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Helper: Extract and validate a session from the Authorization header.
|
|
||||||
/// Returns the Session and optional new token if refreshed, or an error response.
|
|
||||||
pub(super) async fn extract_session(
|
|
||||||
state: &DashboardState,
|
|
||||||
headers: &axum::http::HeaderMap,
|
|
||||||
) -> Result<(super::sessions::Session, Option<String>), Json<ApiResponse<serde_json::Value>>> {
|
|
||||||
let token = headers
|
|
||||||
.get("Authorization")
|
|
||||||
.and_then(|v| v.to_str().ok())
|
|
||||||
.and_then(|v| v.strip_prefix("Bearer "));
|
|
||||||
|
|
||||||
match token {
|
|
||||||
Some(t) => match state.session_manager.validate_session_with_refresh(t).await {
|
|
||||||
Some((session, new_token)) => Ok((session, new_token)),
|
|
||||||
None => Err(Json(ApiResponse::error("Session expired or invalid".to_string()))),
|
|
||||||
},
|
|
||||||
None => Err(Json(ApiResponse::error("Not authenticated".to_string()))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Helper: Extract session and require admin role.
|
|
||||||
/// Returns session and optional new token if refreshed.
|
|
||||||
pub(super) async fn require_admin(
|
|
||||||
state: &DashboardState,
|
|
||||||
headers: &axum::http::HeaderMap,
|
|
||||||
) -> Result<(super::sessions::Session, Option<String>), Json<ApiResponse<serde_json::Value>>> {
|
|
||||||
let (session, new_token) = extract_session(state, headers).await?;
|
|
||||||
if session.role != "admin" {
|
|
||||||
return Err(Json(ApiResponse::error("Admin access required".to_string())));
|
|
||||||
}
|
|
||||||
Ok((session, new_token))
|
|
||||||
}
|
|
||||||
@@ -1,518 +0,0 @@
|
|||||||
use axum::{
|
|
||||||
extract::{Path, State},
|
|
||||||
response::Json,
|
|
||||||
};
|
|
||||||
use chrono;
|
|
||||||
use rand::Rng;
|
|
||||||
use serde::Deserialize;
|
|
||||||
use serde_json;
|
|
||||||
use sqlx::Row;
|
|
||||||
use tracing::warn;
|
|
||||||
use uuid;
|
|
||||||
|
|
||||||
use super::{ApiResponse, DashboardState};
|
|
||||||
|
|
||||||
/// Generate a random API token: sk-{48 hex chars}
|
|
||||||
fn generate_token() -> String {
|
|
||||||
let mut rng = rand::rng();
|
|
||||||
let bytes: Vec<u8> = (0..24).map(|_| rng.random::<u8>()).collect();
|
|
||||||
format!("sk-{}", hex::encode(bytes))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Request body for `POST /api/clients`.
#[derive(Deserialize)]
pub(super) struct CreateClientRequest {
    /// Human-readable display name for the new client.
    pub(super) name: String,
    /// Optional caller-chosen identifier; when omitted a random
    /// `client-XXXXXXXX` id is generated (see `handle_create_client`).
    pub(super) client_id: Option<String>,
}
|
|
||||||
|
|
||||||
/// Request body for `PUT /api/clients/{id}`. All fields are optional;
/// only the fields present in the payload are written to the database.
#[derive(Deserialize)]
pub(super) struct UpdateClientPayload {
    /// New display name.
    pub(super) name: Option<String>,
    /// New free-form description.
    pub(super) description: Option<String>,
    /// Enable/disable the client.
    pub(super) is_active: Option<bool>,
    /// Per-minute request cap; semantics of `None` vs unset are decided by
    /// the caller (absent fields are simply not updated).
    pub(super) rate_limit_per_minute: Option<i64>,
}
|
|
||||||
|
|
||||||
/// GET /api/clients — list every client with its aggregate usage counters,
/// newest first. No authentication check is performed in this handler.
pub(super) async fn handle_get_clients(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;

    let result = sqlx::query(
        r#"
        SELECT
            client_id as id,
            name,
            description,
            created_at,
            total_requests,
            total_tokens,
            total_cost,
            is_active,
            rate_limit_per_minute
        FROM clients
        ORDER BY created_at DESC
        "#,
    )
    .fetch_all(pool)
    .await;

    match result {
        Ok(rows) => {
            // Map each DB row into the JSON shape the dashboard UI expects.
            let clients: Vec<serde_json::Value> = rows
                .into_iter()
                .map(|row| {
                    serde_json::json!({
                        "id": row.get::<String, _>("id"),
                        // NULL names are presented as "Unnamed".
                        "name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
                        "description": row.get::<Option<String>, _>("description"),
                        "created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
                        "requests_count": row.get::<i64, _>("total_requests"),
                        "total_tokens": row.get::<i64, _>("total_tokens"),
                        "total_cost": row.get::<f64, _>("total_cost"),
                        // is_active (bool) is surfaced as a status string.
                        "status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
                        "rate_limit_per_minute": row.get::<Option<i64>, _>("rate_limit_per_minute"),
                    })
                })
                .collect();

            Json(ApiResponse::success(serde_json::json!(clients)))
        }
        Err(e) => {
            // DB error details are logged, not leaked to the client.
            warn!("Failed to fetch clients: {}", e);
            Json(ApiResponse::error("Failed to fetch clients".to_string()))
        }
    }
}
|
|
||||||
|
|
||||||
/// POST /api/clients — create a client and auto-issue its first API token.
///
/// Admin-only. The generated (or caller-supplied) `client_id` is inserted
/// with `is_active = TRUE`, then a fresh token is inserted into
/// `client_tokens`. The full token is returned exactly once, in this
/// response.
pub(super) async fn handle_create_client(
    State(state): State<DashboardState>,
    headers: axum::http::HeaderMap,
    Json(payload): Json<CreateClientRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
        Ok((session, new_token)) => (session, new_token),
        Err(e) => return e,
    };

    let pool = &state.app_state.db_pool;

    // Default id: "client-" + first 8 chars of a v4 UUID.
    let client_id = payload
        .client_id
        .unwrap_or_else(|| format!("client-{}", &uuid::Uuid::new_v4().to_string()[..8]));

    let result = sqlx::query(
        r#"
        INSERT INTO clients (client_id, name, is_active)
        VALUES (?, ?, TRUE)
        RETURNING *
        "#,
    )
    .bind(&client_id)
    .bind(&payload.name)
    .fetch_one(pool)
    .await;

    match result {
        Ok(row) => {
            // Auto-generate a token for the new client
            let token = generate_token();
            let token_result = sqlx::query(
                "INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, 'default')",
            )
            .bind(&client_id)
            .bind(&token)
            .execute(pool)
            .await;

            // NOTE(review): on token-insert failure the client is still
            // created and the (unpersisted, unusable) token is still returned
            // below — the caller has no signal it won't authenticate.
            // Consider omitting "token" in that case. TODO confirm intent.
            if let Err(e) = token_result {
                warn!("Client created but failed to generate token: {}", e);
            }

            Json(ApiResponse::success(serde_json::json!({
                "id": row.get::<String, _>("client_id"),
                "name": row.get::<Option<String>, _>("name"),
                "created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
                "status": "active",
                "token": token,
            })))
        }
        Err(e) => {
            warn!("Failed to create client: {}", e);
            Json(ApiResponse::error(format!("Failed to create client: {}", e)))
        }
    }
}
|
|
||||||
|
|
||||||
/// GET /api/clients/{id} — fetch one client plus request statistics joined
/// from `llm_requests`. Returns an error envelope when the id is unknown.
pub(super) async fn handle_get_client(
    State(state): State<DashboardState>,
    Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;

    // LEFT JOIN so a client with zero requests still yields one row
    // (total_requests = 0, last_request = NULL).
    let result = sqlx::query(
        r#"
        SELECT
            c.client_id as id,
            c.name,
            c.description,
            c.is_active,
            c.rate_limit_per_minute,
            c.created_at,
            COALESCE(c.total_tokens, 0) as total_tokens,
            COALESCE(c.total_cost, 0.0) as total_cost,
            COUNT(r.id) as total_requests,
            MAX(r.timestamp) as last_request
        FROM clients c
        LEFT JOIN llm_requests r ON c.client_id = r.client_id
        WHERE c.client_id = ?
        GROUP BY c.client_id
        "#,
    )
    .bind(&id)
    .fetch_optional(pool)
    .await;

    match result {
        Ok(Some(row)) => Json(ApiResponse::success(serde_json::json!({
            "id": row.get::<String, _>("id"),
            "name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
            "description": row.get::<Option<String>, _>("description"),
            "is_active": row.get::<bool, _>("is_active"),
            "rate_limit_per_minute": row.get::<Option<i64>, _>("rate_limit_per_minute"),
            "created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
            "total_tokens": row.get::<i64, _>("total_tokens"),
            "total_cost": row.get::<f64, _>("total_cost"),
            "total_requests": row.get::<i64, _>("total_requests"),
            "last_request": row.get::<Option<chrono::DateTime<chrono::Utc>>, _>("last_request"),
            // Redundant string form of is_active, kept for the UI.
            "status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
        }))),
        Ok(None) => Json(ApiResponse::error(format!("Client '{}' not found", id))),
        Err(e) => {
            warn!("Failed to fetch client: {}", e);
            Json(ApiResponse::error(format!("Failed to fetch client: {}", e)))
        }
    }
}
|
|
||||||
|
|
||||||
/// PUT /api/clients/{id} — partial update of a client record.
///
/// Admin-only. Builds a dynamic `UPDATE ... SET` from whichever payload
/// fields are present, executes it, then re-reads and returns the updated
/// row. IMPORTANT: correctness depends on binds being applied in the exact
/// order the `?` placeholders were pushed into `sets` — string fields
/// (name, description) first, then is_active, then rate_limit_per_minute,
/// then the WHERE id. Keep the two sequences in lock-step when editing.
pub(super) async fn handle_update_client(
    State(state): State<DashboardState>,
    headers: axum::http::HeaderMap,
    Path(id): Path<String>,
    Json(payload): Json<UpdateClientPayload>,
) -> Json<ApiResponse<serde_json::Value>> {
    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
        Ok((session, new_token)) => (session, new_token),
        Err(e) => return e,
    };

    let pool = &state.app_state.db_pool;

    // Build dynamic UPDATE query from provided fields
    let mut sets = Vec::new();
    // Only string binds are collected here; bool/i64 binds are re-read from
    // the payload during the bind phase below.
    let mut binds: Vec<String> = Vec::new();

    if let Some(ref name) = payload.name {
        sets.push("name = ?");
        binds.push(name.clone());
    }
    if let Some(ref desc) = payload.description {
        sets.push("description = ?");
        binds.push(desc.clone());
    }
    if payload.is_active.is_some() {
        sets.push("is_active = ?");
    }
    if payload.rate_limit_per_minute.is_some() {
        sets.push("rate_limit_per_minute = ?");
    }

    if sets.is_empty() {
        return Json(ApiResponse::error("No fields to update".to_string()));
    }

    // Always update updated_at
    sets.push("updated_at = CURRENT_TIMESTAMP");

    // `sets` contains only fixed literals — no user input enters the SQL
    // text itself; values travel exclusively through binds.
    let sql = format!("UPDATE clients SET {} WHERE client_id = ?", sets.join(", "));
    let mut query = sqlx::query(&sql);

    // Bind in the same order as sets
    for b in &binds {
        query = query.bind(b);
    }
    if let Some(active) = payload.is_active {
        query = query.bind(active);
    }
    if let Some(rate) = payload.rate_limit_per_minute {
        query = query.bind(rate);
    }
    query = query.bind(&id);

    match query.execute(pool).await {
        Ok(result) => {
            if result.rows_affected() == 0 {
                return Json(ApiResponse::error(format!("Client '{}' not found", id)));
            }
            // Return the updated client
            let row = sqlx::query(
                r#"
                SELECT client_id as id, name, description, is_active, rate_limit_per_minute,
                       created_at, total_requests, total_tokens, total_cost
                FROM clients WHERE client_id = ?
                "#,
            )
            .bind(&id)
            .fetch_one(pool)
            .await;

            match row {
                Ok(row) => Json(ApiResponse::success(serde_json::json!({
                    "id": row.get::<String, _>("id"),
                    "name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
                    "description": row.get::<Option<String>, _>("description"),
                    "is_active": row.get::<bool, _>("is_active"),
                    "rate_limit_per_minute": row.get::<Option<i64>, _>("rate_limit_per_minute"),
                    "created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
                    "total_requests": row.get::<i64, _>("total_requests"),
                    "total_tokens": row.get::<i64, _>("total_tokens"),
                    "total_cost": row.get::<f64, _>("total_cost"),
                    "status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
                }))),
                Err(e) => {
                    warn!("Failed to fetch updated client: {}", e);
                    // Update succeeded but fetch failed — still report success
                    Json(ApiResponse::success(serde_json::json!({ "message": "Client updated" })))
                }
            }
        }
        Err(e) => {
            warn!("Failed to update client: {}", e);
            Json(ApiResponse::error(format!("Failed to update client: {}", e)))
        }
    }
}
|
|
||||||
|
|
||||||
pub(super) async fn handle_delete_client(
|
|
||||||
State(state): State<DashboardState>,
|
|
||||||
headers: axum::http::HeaderMap,
|
|
||||||
Path(id): Path<String>,
|
|
||||||
) -> Json<ApiResponse<serde_json::Value>> {
|
|
||||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
|
||||||
Ok((session, new_token)) => (session, new_token),
|
|
||||||
Err(e) => return e,
|
|
||||||
};
|
|
||||||
|
|
||||||
let pool = &state.app_state.db_pool;
|
|
||||||
|
|
||||||
// Don't allow deleting the default client
|
|
||||||
if id == "default" {
|
|
||||||
return Json(ApiResponse::error("Cannot delete default client".to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
let result = sqlx::query("DELETE FROM clients WHERE client_id = ?")
|
|
||||||
.bind(id)
|
|
||||||
.execute(pool)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Client deleted" }))),
|
|
||||||
Err(e) => Json(ApiResponse::error(format!("Failed to delete client: {}", e))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// GET /api/clients/{id}/usage — per-model/provider usage breakdown for one
/// client, ordered by spend. An unknown client id yields an empty breakdown,
/// not an error.
pub(super) async fn handle_client_usage(
    State(state): State<DashboardState>,
    Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;

    // Get per-model breakdown for this client
    let result = sqlx::query(
        r#"
        SELECT
            model,
            provider,
            COUNT(*) as request_count,
            SUM(prompt_tokens) as prompt_tokens,
            SUM(completion_tokens) as completion_tokens,
            SUM(total_tokens) as total_tokens,
            SUM(cost) as total_cost,
            AVG(duration_ms) as avg_duration_ms
        FROM llm_requests
        WHERE client_id = ?
        GROUP BY model, provider
        ORDER BY total_cost DESC
        "#,
    )
    .bind(&id)
    .fetch_all(pool)
    .await;

    match result {
        Ok(rows) => {
            let breakdown: Vec<serde_json::Value> = rows
                .into_iter()
                .map(|row| {
                    // NOTE(review): SUM/AVG return NULL when every input is
                    // NULL; decoding those as non-optional i64/f64 assumes
                    // the token/cost columns are NOT NULL — verify schema.
                    serde_json::json!({
                        "model": row.get::<String, _>("model"),
                        "provider": row.get::<String, _>("provider"),
                        "request_count": row.get::<i64, _>("request_count"),
                        "prompt_tokens": row.get::<i64, _>("prompt_tokens"),
                        "completion_tokens": row.get::<i64, _>("completion_tokens"),
                        "total_tokens": row.get::<i64, _>("total_tokens"),
                        "total_cost": row.get::<f64, _>("total_cost"),
                        "avg_duration_ms": row.get::<f64, _>("avg_duration_ms"),
                    })
                })
                .collect();

            Json(ApiResponse::success(serde_json::json!({
                "client_id": id,
                "breakdown": breakdown,
            })))
        }
        Err(e) => {
            warn!("Failed to fetch client usage: {}", e);
            Json(ApiResponse::error(format!("Failed to fetch client usage: {}", e)))
        }
    }
}
|
|
||||||
|
|
||||||
// ── Token management endpoints ──────────────────────────────────────
|
|
||||||
|
|
||||||
pub(super) async fn handle_get_client_tokens(
|
|
||||||
State(state): State<DashboardState>,
|
|
||||||
Path(id): Path<String>,
|
|
||||||
) -> Json<ApiResponse<serde_json::Value>> {
|
|
||||||
let pool = &state.app_state.db_pool;
|
|
||||||
|
|
||||||
let result = sqlx::query(
|
|
||||||
r#"
|
|
||||||
SELECT id, token, name, is_active, created_at, last_used_at
|
|
||||||
FROM client_tokens
|
|
||||||
WHERE client_id = ?
|
|
||||||
ORDER BY created_at DESC
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(&id)
|
|
||||||
.fetch_all(pool)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(rows) => {
|
|
||||||
let tokens: Vec<serde_json::Value> = rows
|
|
||||||
.into_iter()
|
|
||||||
.map(|row| {
|
|
||||||
let token: String = row.get("token");
|
|
||||||
// Mask all but last 8 chars: sk-••••abcd1234
|
|
||||||
let masked = if token.len() > 8 {
|
|
||||||
format!("{}••••{}", &token[..3], &token[token.len() - 8..])
|
|
||||||
} else {
|
|
||||||
"••••".to_string()
|
|
||||||
};
|
|
||||||
serde_json::json!({
|
|
||||||
"id": row.get::<i64, _>("id"),
|
|
||||||
"token_masked": masked,
|
|
||||||
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "default".to_string()),
|
|
||||||
"is_active": row.get::<bool, _>("is_active"),
|
|
||||||
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
|
||||||
"last_used_at": row.get::<Option<chrono::DateTime<chrono::Utc>>, _>("last_used_at"),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
Json(ApiResponse::success(serde_json::json!(tokens)))
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
warn!("Failed to fetch client tokens: {}", e);
|
|
||||||
Json(ApiResponse::error(format!("Failed to fetch client tokens: {}", e)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Request body for `POST /api/clients/{id}/tokens`.
#[derive(Deserialize)]
pub(super) struct CreateTokenRequest {
    /// Display label for the token; falls back to "default" when omitted.
    pub(super) name: Option<String>,
}
|
|
||||||
|
|
||||||
/// POST /api/clients/{id}/tokens — mint a new API token for an existing
/// client.
///
/// Admin-only. The full token value is returned only in this response; the
/// listing endpoint masks it afterwards.
pub(super) async fn handle_create_client_token(
    State(state): State<DashboardState>,
    headers: axum::http::HeaderMap,
    Path(id): Path<String>,
    Json(payload): Json<CreateTokenRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
        Ok((session, new_token)) => (session, new_token),
        Err(e) => return e,
    };

    let pool = &state.app_state.db_pool;

    // Verify client exists
    // NOTE(review): unwrap_or(None) makes a DB error indistinguishable from
    // "client missing", so a transient failure reports "not found" — confirm
    // this is acceptable.
    let exists: Option<(i64,)> = sqlx::query_as("SELECT 1 as x FROM clients WHERE client_id = ?")
        .bind(&id)
        .fetch_optional(pool)
        .await
        .unwrap_or(None);

    if exists.is_none() {
        return Json(ApiResponse::error(format!("Client '{}' not found", id)));
    }

    let token = generate_token();
    let token_name = payload.name.unwrap_or_else(|| "default".to_string());

    let result = sqlx::query(
        "INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, ?) RETURNING id, created_at",
    )
    .bind(&id)
    .bind(&token)
    .bind(&token_name)
    .fetch_one(pool)
    .await;

    match result {
        Ok(row) => Json(ApiResponse::success(serde_json::json!({
            "id": row.get::<i64, _>("id"),
            // Plaintext token — shown exactly once, at creation time.
            "token": token,
            "name": token_name,
            "created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
        }))),
        Err(e) => {
            warn!("Failed to create client token: {}", e);
            Json(ApiResponse::error(format!("Failed to create token: {}", e)))
        }
    }
}
|
|
||||||
|
|
||||||
pub(super) async fn handle_delete_client_token(
|
|
||||||
State(state): State<DashboardState>,
|
|
||||||
headers: axum::http::HeaderMap,
|
|
||||||
Path((client_id, token_id)): Path<(String, i64)>,
|
|
||||||
) -> Json<ApiResponse<serde_json::Value>> {
|
|
||||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
|
||||||
Ok((session, new_token)) => (session, new_token),
|
|
||||||
Err(e) => return e,
|
|
||||||
};
|
|
||||||
|
|
||||||
let pool = &state.app_state.db_pool;
|
|
||||||
|
|
||||||
let result = sqlx::query("DELETE FROM client_tokens WHERE id = ? AND client_id = ?")
|
|
||||||
.bind(token_id)
|
|
||||||
.bind(&client_id)
|
|
||||||
.execute(pool)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(r) => {
|
|
||||||
if r.rows_affected() == 0 {
|
|
||||||
Json(ApiResponse::error("Token not found".to_string()))
|
|
||||||
} else {
|
|
||||||
Json(ApiResponse::success(serde_json::json!({ "message": "Token revoked" })))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
warn!("Failed to delete client token: {}", e);
|
|
||||||
Json(ApiResponse::error(format!("Failed to revoke token: {}", e)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,177 +0,0 @@
|
|||||||
// Dashboard module for LLM Proxy Gateway
|
|
||||||
|
|
||||||
mod auth;
|
|
||||||
mod clients;
|
|
||||||
mod models;
|
|
||||||
mod providers;
|
|
||||||
pub mod sessions;
|
|
||||||
mod system;
|
|
||||||
mod usage;
|
|
||||||
mod users;
|
|
||||||
mod websocket;
|
|
||||||
|
|
||||||
use axum::{
|
|
||||||
extract::{Request, State},
|
|
||||||
middleware::Next,
|
|
||||||
response::Response,
|
|
||||||
Router,
|
|
||||||
routing::{delete, get, post, put},
|
|
||||||
};
|
|
||||||
use axum::http::{header, HeaderValue};
|
|
||||||
use serde::Serialize;
|
|
||||||
use tower_http::{
|
|
||||||
limit::RequestBodyLimitLayer,
|
|
||||||
set_header::SetResponseHeaderLayer,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::state::AppState;
|
|
||||||
use sessions::SessionManager;
|
|
||||||
|
|
||||||
// Dashboard state
/// Shared state handed to every dashboard handler: the application-wide
/// state plus the dashboard's in-memory session store.
#[derive(Clone)]
struct DashboardState {
    /// Application-wide state (DB pool, model registry, ...); cloned per
    /// request by axum.
    app_state: AppState,
    /// Issues, validates/refreshes and revokes dashboard login sessions.
    session_manager: SessionManager,
}
|
|
||||||
|
|
||||||
// API Response types
/// Uniform JSON envelope returned by every dashboard API endpoint.
#[derive(Serialize)]
struct ApiResponse<T> {
    /// True for success envelopes, false for error envelopes.
    success: bool,
    /// Payload; populated only by `ApiResponse::success`.
    data: Option<T>,
    /// Human-readable message; populated only by `ApiResponse::error`.
    error: Option<String>,
}
|
|
||||||
|
|
||||||
impl<T> ApiResponse<T> {
|
|
||||||
fn success(data: T) -> Self {
|
|
||||||
Self {
|
|
||||||
success: true,
|
|
||||||
data: Some(data),
|
|
||||||
error: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn error(error: String) -> Self {
|
|
||||||
Self {
|
|
||||||
success: false,
|
|
||||||
data: None,
|
|
||||||
error: Some(error),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Rate limiting middleware for dashboard routes that extracts AppState from DashboardState.
|
|
||||||
async fn dashboard_rate_limit_middleware(
|
|
||||||
State(dashboard_state): State<DashboardState>,
|
|
||||||
request: Request,
|
|
||||||
next: Next,
|
|
||||||
) -> Result<Response, crate::errors::AppError> {
|
|
||||||
// Delegate to the existing rate limit middleware with AppState
|
|
||||||
crate::rate_limiting::middleware::rate_limit_middleware(
|
|
||||||
State(dashboard_state.app_state),
|
|
||||||
request,
|
|
||||||
next,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dashboard routes
/// Build the dashboard `Router`: static assets, the WebSocket feed and the
/// JSON API, wrapped in security-header, body-size and rate-limit layers.
/// NOTE: layers apply bottom-up in axum, so the rate limiter (added last)
/// runs first on each request.
pub fn router(state: AppState) -> Router {
    let session_manager = SessionManager::new(24); // 24-hour session TTL
    let dashboard_state = DashboardState {
        app_state: state,
        session_manager,
    };

    // Security headers
    let csp_header: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
        header::CONTENT_SECURITY_POLICY,
        "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws:;"
            .parse()
            .unwrap(),
    );
    let x_frame_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
        header::X_FRAME_OPTIONS,
        "DENY".parse().unwrap(),
    );
    let x_content_type_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
        header::X_CONTENT_TYPE_OPTIONS,
        "nosniff".parse().unwrap(),
    );
    let strict_transport_security: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
        header::STRICT_TRANSPORT_SECURITY,
        "max-age=31536000; includeSubDomains".parse().unwrap(),
    );

    Router::new()
        // Static file serving
        .fallback_service(tower_http::services::ServeDir::new("static"))
        // WebSocket endpoint
        .route("/ws", get(websocket::handle_websocket))
        // API endpoints
        .route("/api/auth/login", post(auth::handle_login))
        .route("/api/auth/status", get(auth::handle_auth_status))
        .route("/api/auth/logout", post(auth::handle_logout))
        .route("/api/auth/change-password", post(auth::handle_change_password))
        .route(
            "/api/users",
            get(users::handle_get_users).post(users::handle_create_user),
        )
        .route(
            "/api/users/{id}",
            put(users::handle_update_user).delete(users::handle_delete_user),
        )
        .route("/api/usage/summary", get(usage::handle_usage_summary))
        .route("/api/usage/time-series", get(usage::handle_time_series))
        .route("/api/usage/clients", get(usage::handle_clients_usage))
        .route("/api/usage/providers", get(usage::handle_providers_usage))
        .route("/api/usage/detailed", get(usage::handle_detailed_usage))
        .route("/api/analytics/breakdown", get(usage::handle_analytics_breakdown))
        .route("/api/models", get(models::handle_get_models))
        .route("/api/models/{id}", put(models::handle_update_model))
        .route(
            "/api/clients",
            get(clients::handle_get_clients).post(clients::handle_create_client),
        )
        .route(
            "/api/clients/{id}",
            get(clients::handle_get_client)
                .put(clients::handle_update_client)
                .delete(clients::handle_delete_client),
        )
        .route("/api/clients/{id}/usage", get(clients::handle_client_usage))
        .route(
            "/api/clients/{id}/tokens",
            get(clients::handle_get_client_tokens).post(clients::handle_create_client_token),
        )
        .route(
            "/api/clients/{id}/tokens/{token_id}",
            delete(clients::handle_delete_client_token),
        )
        .route("/api/providers", get(providers::handle_get_providers))
        .route(
            "/api/providers/{name}",
            get(providers::handle_get_provider).put(providers::handle_update_provider),
        )
        .route("/api/providers/{name}/test", post(providers::handle_test_provider))
        .route("/api/system/health", get(system::handle_system_health))
        .route("/api/system/metrics", get(system::handle_system_metrics))
        .route("/api/system/logs", get(system::handle_system_logs))
        .route("/api/system/backup", post(system::handle_system_backup))
        .route(
            "/api/system/settings",
            get(system::handle_get_settings).post(system::handle_update_settings),
        )
        // Security layers
        .layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit
        .layer(csp_header)
        .layer(x_frame_options)
        .layer(x_content_type_options)
        .layer(strict_transport_security)
        // Rate limiting middleware
        .layer(axum::middleware::from_fn_with_state(
            dashboard_state.clone(),
            dashboard_rate_limit_middleware,
        ))
        .with_state(dashboard_state)
}
|
|
||||||
@@ -1,205 +0,0 @@
|
|||||||
use axum::{
|
|
||||||
extract::{Path, Query, State},
|
|
||||||
response::Json,
|
|
||||||
};
|
|
||||||
use serde::Deserialize;
|
|
||||||
use serde_json;
|
|
||||||
use sqlx::Row;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
use super::{ApiResponse, DashboardState};
|
|
||||||
use crate::models::registry::{ModelFilter, ModelSortBy, SortOrder};
|
|
||||||
|
|
||||||
/// Request body for `PUT /api/models/{id}`.
#[derive(Deserialize)]
pub(super) struct UpdateModelRequest {
    /// Whether the model is available for routing.
    pub(super) enabled: bool,
    /// Override for prompt-side cost; `None` leaves the value untouched.
    pub(super) prompt_cost: Option<f64>,
    /// Override for completion-side cost; `None` leaves the value untouched.
    pub(super) completion_cost: Option<f64>,
    /// Optional alias/mapping target for this model id.
    pub(super) mapping: Option<String>,
}
|
|
||||||
|
|
||||||
/// Query parameters for `GET /api/models`.
///
/// Every field is optional; absent fields apply no filtering, and sort
/// fields fall back to their `Default` values.
#[derive(Debug, Deserialize, Default)]
pub(super) struct ModelListParams {
    /// Filter by provider ID.
    pub provider: Option<String>,
    /// Text search on model ID or name.
    pub search: Option<String>,
    /// Filter by input modality (e.g. "image").
    pub modality: Option<String>,
    /// Only models that support tool calling.
    pub tool_call: Option<bool>,
    /// Only models that support reasoning.
    pub reasoning: Option<bool>,
    /// Only models that have pricing data.
    pub has_cost: Option<bool>,
    /// Only models that appear in `llm_requests` (see handler).
    pub used_only: Option<bool>,
    /// Sort field (name, id, provider, context_limit, input_cost, output_cost).
    pub sort_by: Option<ModelSortBy>,
    /// Sort direction (asc, desc).
    pub sort_order: Option<SortOrder>,
}
|
|
||||||
|
|
||||||
pub(super) async fn handle_get_models(
|
|
||||||
State(state): State<DashboardState>,
|
|
||||||
Query(params): Query<ModelListParams>,
|
|
||||||
) -> Json<ApiResponse<serde_json::Value>> {
|
|
||||||
let registry = &state.app_state.model_registry;
|
|
||||||
let pool = &state.app_state.db_pool;
|
|
||||||
|
|
||||||
// If used_only, fetch the set of models that appear in llm_requests
|
|
||||||
let used_models: Option<std::collections::HashSet<String>> =
|
|
||||||
if params.used_only.unwrap_or(false) {
|
|
||||||
match sqlx::query_scalar::<_, String>(
|
|
||||||
"SELECT DISTINCT model FROM llm_requests",
|
|
||||||
)
|
|
||||||
.fetch_all(pool)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(models) => Some(models.into_iter().collect()),
|
|
||||||
Err(_) => Some(std::collections::HashSet::new()),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
// Build filter from query params
|
|
||||||
let filter = ModelFilter {
|
|
||||||
provider: params.provider,
|
|
||||||
search: params.search,
|
|
||||||
modality: params.modality,
|
|
||||||
tool_call: params.tool_call,
|
|
||||||
reasoning: params.reasoning,
|
|
||||||
has_cost: params.has_cost,
|
|
||||||
};
|
|
||||||
let sort_by = params.sort_by.unwrap_or_default();
|
|
||||||
let sort_order = params.sort_order.unwrap_or_default();
|
|
||||||
|
|
||||||
// Get filtered and sorted model entries
|
|
||||||
let entries = registry.list_models(&filter, &sort_by, &sort_order);
|
|
||||||
|
|
||||||
// Load overrides from database
|
|
||||||
let db_models_result =
|
|
||||||
sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
|
|
||||||
.fetch_all(pool)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let mut db_models = HashMap::new();
|
|
||||||
if let Ok(rows) = db_models_result {
|
|
||||||
for row in rows {
|
|
||||||
let id: String = row.get("id");
|
|
||||||
db_models.insert(id, row);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut models_json = Vec::new();
|
|
||||||
|
|
||||||
for entry in &entries {
|
|
||||||
let m_key = entry.model_key;
|
|
||||||
|
|
||||||
// Skip models not in the used set (when used_only is active)
|
|
||||||
if let Some(ref used) = used_models {
|
|
||||||
if !used.contains(m_key) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let m_meta = entry.metadata;
|
|
||||||
|
|
||||||
let mut enabled = true;
|
|
||||||
let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
|
||||||
let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
|
||||||
let cache_read_cost = m_meta.cost.as_ref().and_then(|c| c.cache_read);
|
|
||||||
let cache_write_cost = m_meta.cost.as_ref().and_then(|c| c.cache_write);
|
|
||||||
let mut mapping = None::<String>;
|
|
||||||
|
|
||||||
if let Some(row) = db_models.get(m_key) {
|
|
||||||
enabled = row.get("enabled");
|
|
||||||
if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
|
|
||||||
prompt_cost = p;
|
|
||||||
}
|
|
||||||
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
|
|
||||||
completion_cost = c;
|
|
||||||
}
|
|
||||||
mapping = row.get("mapping");
|
|
||||||
}
|
|
||||||
|
|
||||||
models_json.push(serde_json::json!({
|
|
||||||
"id": m_key,
|
|
||||||
"provider": entry.provider_id,
|
|
||||||
"provider_name": entry.provider_name,
|
|
||||||
"name": m_meta.name,
|
|
||||||
"enabled": enabled,
|
|
||||||
"prompt_cost": prompt_cost,
|
|
||||||
"completion_cost": completion_cost,
|
|
||||||
"cache_read_cost": cache_read_cost,
|
|
||||||
"cache_write_cost": cache_write_cost,
|
|
||||||
"mapping": mapping,
|
|
||||||
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
|
|
||||||
"output_limit": m_meta.limit.as_ref().map(|l| l.output).unwrap_or(0),
|
|
||||||
"modalities": m_meta.modalities.as_ref().map(|m| serde_json::json!({
|
|
||||||
"input": m.input,
|
|
||||||
"output": m.output,
|
|
||||||
})),
|
|
||||||
"tool_call": m_meta.tool_call,
|
|
||||||
"reasoning": m_meta.reasoning,
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
Json(ApiResponse::success(serde_json::json!(models_json)))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) async fn handle_update_model(
|
|
||||||
State(state): State<DashboardState>,
|
|
||||||
headers: axum::http::HeaderMap,
|
|
||||||
Path(id): Path<String>,
|
|
||||||
Json(payload): Json<UpdateModelRequest>,
|
|
||||||
) -> Json<ApiResponse<serde_json::Value>> {
|
|
||||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
|
||||||
Ok((session, new_token)) => (session, new_token),
|
|
||||||
Err(e) => return e,
|
|
||||||
};
|
|
||||||
|
|
||||||
let pool = &state.app_state.db_pool;
|
|
||||||
|
|
||||||
// Find provider_id for this model in registry
|
|
||||||
let provider_id = state
|
|
||||||
.app_state
|
|
||||||
.model_registry
|
|
||||||
.providers
|
|
||||||
.iter()
|
|
||||||
.find(|(_, p)| p.models.contains_key(&id))
|
|
||||||
.map(|(id, _)| id.clone())
|
|
||||||
.unwrap_or_else(|| "unknown".to_string());
|
|
||||||
|
|
||||||
let result = sqlx::query(
|
|
||||||
r#"
|
|
||||||
INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?)
|
|
||||||
ON CONFLICT(id) DO UPDATE SET
|
|
||||||
enabled = excluded.enabled,
|
|
||||||
prompt_cost_per_m = excluded.prompt_cost_per_m,
|
|
||||||
completion_cost_per_m = excluded.completion_cost_per_m,
|
|
||||||
mapping = excluded.mapping,
|
|
||||||
updated_at = CURRENT_TIMESTAMP
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(&id)
|
|
||||||
.bind(provider_id)
|
|
||||||
.bind(payload.enabled)
|
|
||||||
.bind(payload.prompt_cost)
|
|
||||||
.bind(payload.completion_cost)
|
|
||||||
.bind(payload.mapping)
|
|
||||||
.execute(pool)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(_) => {
|
|
||||||
// Invalidate the in-memory cache so the proxy picks up the change immediately
|
|
||||||
state.app_state.model_config_cache.invalidate().await;
|
|
||||||
Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" })))
|
|
||||||
}
|
|
||||||
Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,420 +0,0 @@
|
|||||||
use axum::{
|
|
||||||
extract::{Path, State},
|
|
||||||
response::Json,
|
|
||||||
};
|
|
||||||
use serde::Deserialize;
|
|
||||||
use serde_json;
|
|
||||||
use sqlx::Row;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use tracing::warn;
|
|
||||||
|
|
||||||
use super::{ApiResponse, DashboardState};
|
|
||||||
use crate::utils::crypto;
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
pub(super) struct UpdateProviderRequest {
|
|
||||||
pub(super) enabled: bool,
|
|
||||||
pub(super) base_url: Option<String>,
|
|
||||||
pub(super) api_key: Option<String>,
|
|
||||||
pub(super) credit_balance: Option<f64>,
|
|
||||||
pub(super) low_credit_threshold: Option<f64>,
|
|
||||||
pub(super) billing_mode: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) async fn handle_get_providers(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
|
|
||||||
let registry = &state.app_state.model_registry;
|
|
||||||
let config = &state.app_state.config;
|
|
||||||
let pool = &state.app_state.db_pool;
|
|
||||||
|
|
||||||
// Load all overrides from database (including billing_mode)
|
|
||||||
let db_configs_result = sqlx::query(
|
|
||||||
"SELECT id, enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs",
|
|
||||||
)
|
|
||||||
.fetch_all(pool)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let mut db_configs = HashMap::new();
|
|
||||||
if let Ok(rows) = db_configs_result {
|
|
||||||
for row in rows {
|
|
||||||
let id: String = row.get("id");
|
|
||||||
let enabled: bool = row.get("enabled");
|
|
||||||
let base_url: Option<String> = row.get("base_url");
|
|
||||||
let balance: f64 = row.get("credit_balance");
|
|
||||||
let threshold: f64 = row.get("low_credit_threshold");
|
|
||||||
let billing_mode: Option<String> = row.get("billing_mode");
|
|
||||||
db_configs.insert(id, (enabled, base_url, balance, threshold, billing_mode));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut providers_json = Vec::new();
|
|
||||||
|
|
||||||
// Define the list of providers we support
|
|
||||||
let provider_ids = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
|
|
||||||
|
|
||||||
for id in provider_ids {
|
|
||||||
// Get base config
|
|
||||||
let (mut enabled, mut base_url, display_name) = match id {
|
|
||||||
"openai" => (
|
|
||||||
config.providers.openai.enabled,
|
|
||||||
config.providers.openai.base_url.clone(),
|
|
||||||
"OpenAI",
|
|
||||||
),
|
|
||||||
"gemini" => (
|
|
||||||
config.providers.gemini.enabled,
|
|
||||||
config.providers.gemini.base_url.clone(),
|
|
||||||
"Google Gemini",
|
|
||||||
),
|
|
||||||
"deepseek" => (
|
|
||||||
config.providers.deepseek.enabled,
|
|
||||||
config.providers.deepseek.base_url.clone(),
|
|
||||||
"DeepSeek",
|
|
||||||
),
|
|
||||||
"grok" => (
|
|
||||||
config.providers.grok.enabled,
|
|
||||||
config.providers.grok.base_url.clone(),
|
|
||||||
"xAI Grok",
|
|
||||||
),
|
|
||||||
"ollama" => (
|
|
||||||
config.providers.ollama.enabled,
|
|
||||||
config.providers.ollama.base_url.clone(),
|
|
||||||
"Ollama",
|
|
||||||
),
|
|
||||||
_ => (false, "".to_string(), "Unknown"),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut balance = 0.0;
|
|
||||||
let mut threshold = 5.0;
|
|
||||||
let mut billing_mode: Option<String> = None;
|
|
||||||
|
|
||||||
// Apply database overrides
|
|
||||||
if let Some((db_enabled, db_url, db_balance, db_threshold, db_billing)) = db_configs.get(id) {
|
|
||||||
enabled = *db_enabled;
|
|
||||||
if let Some(url) = db_url {
|
|
||||||
base_url = url.clone();
|
|
||||||
}
|
|
||||||
balance = *db_balance;
|
|
||||||
threshold = *db_threshold;
|
|
||||||
billing_mode = db_billing.clone();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find models for this provider in registry
|
|
||||||
// NOTE: registry provider IDs differ from internal IDs for some providers.
|
|
||||||
let registry_key = match id {
|
|
||||||
"gemini" => "google",
|
|
||||||
"grok" => "xai",
|
|
||||||
_ => id,
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut models = Vec::new();
|
|
||||||
if let Some(p_info) = registry.providers.get(registry_key) {
|
|
||||||
models = p_info.models.keys().cloned().collect();
|
|
||||||
} else if id == "ollama" {
|
|
||||||
models = config.providers.ollama.models.clone();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine status
|
|
||||||
let status = if !enabled {
|
|
||||||
"disabled"
|
|
||||||
} else {
|
|
||||||
// Check if it's actually initialized in the provider manager
|
|
||||||
if state.app_state.provider_manager.get_provider(id).await.is_some() {
|
|
||||||
// Check circuit breaker
|
|
||||||
if state
|
|
||||||
.app_state
|
|
||||||
.rate_limit_manager
|
|
||||||
.check_provider_request(id)
|
|
||||||
.await
|
|
||||||
.unwrap_or(true)
|
|
||||||
{
|
|
||||||
"online"
|
|
||||||
} else {
|
|
||||||
"degraded"
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
"error" // Enabled but failed to initialize (e.g. missing API key)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
providers_json.push(serde_json::json!({
|
|
||||||
"id": id,
|
|
||||||
"name": display_name,
|
|
||||||
"enabled": enabled,
|
|
||||||
"status": status,
|
|
||||||
"models": models,
|
|
||||||
"base_url": base_url,
|
|
||||||
"credit_balance": balance,
|
|
||||||
"low_credit_threshold": threshold,
|
|
||||||
"billing_mode": billing_mode,
|
|
||||||
"last_used": None::<String>,
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
Json(ApiResponse::success(serde_json::json!(providers_json)))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) async fn handle_get_provider(
|
|
||||||
State(state): State<DashboardState>,
|
|
||||||
Path(name): Path<String>,
|
|
||||||
) -> Json<ApiResponse<serde_json::Value>> {
|
|
||||||
let registry = &state.app_state.model_registry;
|
|
||||||
let config = &state.app_state.config;
|
|
||||||
let pool = &state.app_state.db_pool;
|
|
||||||
|
|
||||||
// Validate provider name
|
|
||||||
let (mut enabled, mut base_url, display_name) = match name.as_str() {
|
|
||||||
"openai" => (
|
|
||||||
config.providers.openai.enabled,
|
|
||||||
config.providers.openai.base_url.clone(),
|
|
||||||
"OpenAI",
|
|
||||||
),
|
|
||||||
"gemini" => (
|
|
||||||
config.providers.gemini.enabled,
|
|
||||||
config.providers.gemini.base_url.clone(),
|
|
||||||
"Google Gemini",
|
|
||||||
),
|
|
||||||
"deepseek" => (
|
|
||||||
config.providers.deepseek.enabled,
|
|
||||||
config.providers.deepseek.base_url.clone(),
|
|
||||||
"DeepSeek",
|
|
||||||
),
|
|
||||||
"grok" => (
|
|
||||||
config.providers.grok.enabled,
|
|
||||||
config.providers.grok.base_url.clone(),
|
|
||||||
"xAI Grok",
|
|
||||||
),
|
|
||||||
"ollama" => (
|
|
||||||
config.providers.ollama.enabled,
|
|
||||||
config.providers.ollama.base_url.clone(),
|
|
||||||
"Ollama",
|
|
||||||
),
|
|
||||||
_ => return Json(ApiResponse::error(format!("Unknown provider '{}'", name))),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut balance = 0.0;
|
|
||||||
let mut threshold = 5.0;
|
|
||||||
let mut billing_mode: Option<String> = None;
|
|
||||||
|
|
||||||
// Apply database overrides
|
|
||||||
let db_config = sqlx::query(
|
|
||||||
"SELECT enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs WHERE id = ?",
|
|
||||||
)
|
|
||||||
.bind(&name)
|
|
||||||
.fetch_optional(pool)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
if let Ok(Some(row)) = db_config {
|
|
||||||
enabled = row.get::<bool, _>("enabled");
|
|
||||||
if let Some(url) = row.get::<Option<String>, _>("base_url") {
|
|
||||||
base_url = url;
|
|
||||||
}
|
|
||||||
balance = row.get::<f64, _>("credit_balance");
|
|
||||||
threshold = row.get::<f64, _>("low_credit_threshold");
|
|
||||||
billing_mode = row.get::<Option<String>, _>("billing_mode");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find models for this provider
|
|
||||||
// NOTE: registry provider IDs differ from internal IDs for some providers.
|
|
||||||
let registry_key = match name.as_str() {
|
|
||||||
"gemini" => "google",
|
|
||||||
"grok" => "xai",
|
|
||||||
_ => name.as_str(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut models = Vec::new();
|
|
||||||
if let Some(p_info) = registry.providers.get(registry_key) {
|
|
||||||
models = p_info.models.keys().cloned().collect();
|
|
||||||
} else if name == "ollama" {
|
|
||||||
models = config.providers.ollama.models.clone();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine status
|
|
||||||
let status = if !enabled {
|
|
||||||
"disabled"
|
|
||||||
} else if state.app_state.provider_manager.get_provider(&name).await.is_some() {
|
|
||||||
if state
|
|
||||||
.app_state
|
|
||||||
.rate_limit_manager
|
|
||||||
.check_provider_request(&name)
|
|
||||||
.await
|
|
||||||
.unwrap_or(true)
|
|
||||||
{
|
|
||||||
"online"
|
|
||||||
} else {
|
|
||||||
"degraded"
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
"error"
|
|
||||||
};
|
|
||||||
|
|
||||||
Json(ApiResponse::success(serde_json::json!({
|
|
||||||
"id": name,
|
|
||||||
"name": display_name,
|
|
||||||
"enabled": enabled,
|
|
||||||
"status": status,
|
|
||||||
"models": models,
|
|
||||||
"base_url": base_url,
|
|
||||||
"credit_balance": balance,
|
|
||||||
"low_credit_threshold": threshold,
|
|
||||||
"billing_mode": billing_mode,
|
|
||||||
"last_used": None::<String>,
|
|
||||||
})))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) async fn handle_update_provider(
|
|
||||||
State(state): State<DashboardState>,
|
|
||||||
headers: axum::http::HeaderMap,
|
|
||||||
Path(name): Path<String>,
|
|
||||||
Json(payload): Json<UpdateProviderRequest>,
|
|
||||||
) -> Json<ApiResponse<serde_json::Value>> {
|
|
||||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
|
||||||
Ok((session, new_token)) => (session, new_token),
|
|
||||||
Err(e) => return e,
|
|
||||||
};
|
|
||||||
|
|
||||||
let pool = &state.app_state.db_pool;
|
|
||||||
|
|
||||||
// Prepare API key encryption if provided
|
|
||||||
let (api_key_to_store, api_key_encrypted_flag) = match &payload.api_key {
|
|
||||||
Some(key) if !key.is_empty() => {
|
|
||||||
match crypto::encrypt(key) {
|
|
||||||
Ok(encrypted) => (Some(encrypted), Some(true)),
|
|
||||||
Err(e) => {
|
|
||||||
warn!("Failed to encrypt API key for provider {}: {}", name, e);
|
|
||||||
return Json(ApiResponse::error(format!("Failed to encrypt API key: {}", e)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Some(_) => {
|
|
||||||
// Empty string means clear the key
|
|
||||||
(None, Some(false))
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
// Keep existing key, we'll rely on COALESCE in SQL
|
|
||||||
(None, None)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Update or insert into database (include billing_mode and api_key_encrypted)
|
|
||||||
let result = sqlx::query(
|
|
||||||
r#"
|
|
||||||
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold, billing_mode)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
||||||
ON CONFLICT(id) DO UPDATE SET
|
|
||||||
enabled = excluded.enabled,
|
|
||||||
base_url = excluded.base_url,
|
|
||||||
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
|
|
||||||
api_key_encrypted = COALESCE(excluded.api_key_encrypted, provider_configs.api_key_encrypted),
|
|
||||||
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
|
|
||||||
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
|
|
||||||
billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
|
|
||||||
updated_at = CURRENT_TIMESTAMP
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(&name)
|
|
||||||
.bind(name.to_uppercase())
|
|
||||||
.bind(payload.enabled)
|
|
||||||
.bind(&payload.base_url)
|
|
||||||
.bind(&api_key_to_store)
|
|
||||||
.bind(api_key_encrypted_flag)
|
|
||||||
.bind(payload.credit_balance)
|
|
||||||
.bind(payload.low_credit_threshold)
|
|
||||||
.bind(payload.billing_mode)
|
|
||||||
.execute(pool)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(_) => {
|
|
||||||
// Re-initialize provider in manager
|
|
||||||
if let Err(e) = state
|
|
||||||
.app_state
|
|
||||||
.provider_manager
|
|
||||||
.initialize_provider(&name, &state.app_state.config, &state.app_state.db_pool)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
warn!("Failed to re-initialize provider {}: {}", name, e);
|
|
||||||
return Json(ApiResponse::error(format!(
|
|
||||||
"Provider settings saved but initialization failed: {}",
|
|
||||||
e
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
|
|
||||||
Json(ApiResponse::success(
|
|
||||||
serde_json::json!({ "message": "Provider updated and re-initialized" }),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
warn!("Failed to update provider config: {}", e);
|
|
||||||
Json(ApiResponse::error(format!("Failed to update provider: {}", e)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) async fn handle_test_provider(
|
|
||||||
State(state): State<DashboardState>,
|
|
||||||
Path(name): Path<String>,
|
|
||||||
) -> Json<ApiResponse<serde_json::Value>> {
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
|
|
||||||
let provider = match state.app_state.provider_manager.get_provider(&name).await {
|
|
||||||
Some(p) => p,
|
|
||||||
None => {
|
|
||||||
return Json(ApiResponse::error(format!(
|
|
||||||
"Provider '{}' not found or not enabled",
|
|
||||||
name
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Pick a real model for this provider from the registry
|
|
||||||
// NOTE: registry provider IDs differ from internal IDs for some providers.
|
|
||||||
let registry_key = match name.as_str() {
|
|
||||||
"gemini" => "google",
|
|
||||||
"grok" => "xai",
|
|
||||||
_ => name.as_str(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let test_model = state
|
|
||||||
.app_state
|
|
||||||
.model_registry
|
|
||||||
.providers
|
|
||||||
.get(registry_key)
|
|
||||||
.and_then(|p| p.models.keys().next().cloned())
|
|
||||||
.unwrap_or_else(|| name.clone());
|
|
||||||
|
|
||||||
let test_request = crate::models::UnifiedRequest {
|
|
||||||
client_id: "system-test".to_string(),
|
|
||||||
model: test_model,
|
|
||||||
messages: vec![crate::models::UnifiedMessage {
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: vec![crate::models::ContentPart::Text { text: "Hi".to_string() }],
|
|
||||||
reasoning_content: None,
|
|
||||||
tool_calls: None,
|
|
||||||
name: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
}],
|
|
||||||
temperature: None,
|
|
||||||
top_p: None,
|
|
||||||
top_k: None,
|
|
||||||
n: None,
|
|
||||||
stop: None,
|
|
||||||
max_tokens: Some(5),
|
|
||||||
presence_penalty: None,
|
|
||||||
frequency_penalty: None,
|
|
||||||
stream: false,
|
|
||||||
has_images: false,
|
|
||||||
tools: None,
|
|
||||||
tool_choice: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
match provider.chat_completion(test_request).await {
|
|
||||||
Ok(_) => {
|
|
||||||
let latency = start.elapsed().as_millis();
|
|
||||||
Json(ApiResponse::success(serde_json::json!({
|
|
||||||
"success": true,
|
|
||||||
"latency": latency,
|
|
||||||
"message": "Connection test successful"
|
|
||||||
})))
|
|
||||||
}
|
|
||||||
Err(e) => Json(ApiResponse::error(format!("Provider test failed: {}", e))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,311 +0,0 @@
|
|||||||
use chrono::{DateTime, Duration, Utc};
|
|
||||||
use hmac::{Hmac, Mac};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use sha2::{Sha256, digest::generic_array::GenericArray};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::env;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use tokio::sync::RwLock;
|
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
use base64::{engine::general_purpose::URL_SAFE, Engine as _};
|
|
||||||
|
|
||||||
const TOKEN_VERSION: &str = "v2";
|
|
||||||
const REFRESH_WINDOW_MINUTES: i64 = 15; // refresh if token expires within 15 minutes
|
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
|
||||||
pub struct Session {
|
|
||||||
pub username: String,
|
|
||||||
pub role: String,
|
|
||||||
pub created_at: DateTime<Utc>,
|
|
||||||
pub expires_at: DateTime<Utc>,
|
|
||||||
pub session_id: String, // unique identifier for the session (UUID)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct SessionManager {
|
|
||||||
sessions: Arc<RwLock<HashMap<String, Session>>>, // key = session_id
|
|
||||||
ttl_hours: i64,
|
|
||||||
secret: Vec<u8>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
struct SessionPayload {
|
|
||||||
session_id: String,
|
|
||||||
username: String,
|
|
||||||
role: String,
|
|
||||||
iat: i64, // issued at (Unix timestamp)
|
|
||||||
exp: i64, // expiry (Unix timestamp)
|
|
||||||
version: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SessionManager {
|
|
||||||
pub fn new(ttl_hours: i64) -> Self {
|
|
||||||
let secret = load_session_secret();
|
|
||||||
Self {
|
|
||||||
sessions: Arc::new(RwLock::new(HashMap::new())),
|
|
||||||
ttl_hours,
|
|
||||||
secret,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a new session and return a signed session token.
|
|
||||||
pub async fn create_session(&self, username: String, role: String) -> String {
|
|
||||||
let session_id = Uuid::new_v4().to_string();
|
|
||||||
let now = Utc::now();
|
|
||||||
let expires_at = now + Duration::hours(self.ttl_hours);
|
|
||||||
let session = Session {
|
|
||||||
username: username.clone(),
|
|
||||||
role: role.clone(),
|
|
||||||
created_at: now,
|
|
||||||
expires_at,
|
|
||||||
session_id: session_id.clone(),
|
|
||||||
};
|
|
||||||
// Store session by session_id
|
|
||||||
self.sessions.write().await.insert(session_id.clone(), session);
|
|
||||||
// Create signed token
|
|
||||||
self.create_signed_token(&session_id, &username, &role, now.timestamp(), expires_at.timestamp())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Validate a session token and return the session if valid and not expired.
|
|
||||||
/// If the token is within the refresh window, returns a new token as well.
|
|
||||||
pub async fn validate_session(&self, token: &str) -> Option<Session> {
|
|
||||||
self.validate_session_with_refresh(token).await.map(|(session, _)| session)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Validate a session token and return (session, optional new token if refreshed).
|
|
||||||
pub async fn validate_session_with_refresh(&self, token: &str) -> Option<(Session, Option<String>)> {
|
|
||||||
// Legacy token format (UUID)
|
|
||||||
if token.starts_with("session-") {
|
|
||||||
let sessions = self.sessions.read().await;
|
|
||||||
return sessions.get(token).and_then(|s| {
|
|
||||||
if s.expires_at > Utc::now() {
|
|
||||||
Some((s.clone(), None))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Signed token format
|
|
||||||
let payload = match verify_signed_token(token, &self.secret) {
|
|
||||||
Ok(p) => p,
|
|
||||||
Err(_) => return None,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Check expiry
|
|
||||||
let now = Utc::now().timestamp();
|
|
||||||
if payload.exp <= now {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Look up session by session_id
|
|
||||||
let sessions = self.sessions.read().await;
|
|
||||||
let session = match sessions.get(&payload.session_id) {
|
|
||||||
Some(s) => s.clone(),
|
|
||||||
None => return None, // session revoked or not found
|
|
||||||
};
|
|
||||||
|
|
||||||
// Ensure session username/role matches (should always match)
|
|
||||||
if session.username != payload.username || session.role != payload.role {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if token is within refresh window (last REFRESH_WINDOW_MINUTES of validity)
|
|
||||||
let refresh_threshold = payload.exp - REFRESH_WINDOW_MINUTES * 60;
|
|
||||||
let new_token = if now >= refresh_threshold {
|
|
||||||
// Generate a new token with same session data but updated iat/exp?
|
|
||||||
// According to activity-based refresh, we should extend the session expiry.
|
|
||||||
// We'll extend from now by ttl_hours (or keep original expiry?).
|
|
||||||
// Let's extend from now by ttl_hours (sliding window).
|
|
||||||
let new_exp = Utc::now() + Duration::hours(self.ttl_hours);
|
|
||||||
// Update session expiry in store
|
|
||||||
drop(sessions); // release read lock before acquiring write lock
|
|
||||||
self.update_session_expiry(&payload.session_id, new_exp).await;
|
|
||||||
// Create new token with updated iat/exp
|
|
||||||
let new_token = self.create_signed_token(
|
|
||||||
&payload.session_id,
|
|
||||||
&payload.username,
|
|
||||||
&payload.role,
|
|
||||||
now,
|
|
||||||
new_exp.timestamp(),
|
|
||||||
);
|
|
||||||
Some(new_token)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
Some((session, new_token))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Revoke (delete) a session by token.
|
|
||||||
/// Supports both legacy tokens (token is key) and signed tokens (extract session_id).
|
|
||||||
pub async fn revoke_session(&self, token: &str) {
|
|
||||||
if token.starts_with("session-") {
|
|
||||||
self.sessions.write().await.remove(token);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// For signed token, try to extract session_id
|
|
||||||
if let Ok(payload) = verify_signed_token(token, &self.secret) {
|
|
||||||
self.sessions.write().await.remove(&payload.session_id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Remove all expired sessions from the store.
|
|
||||||
pub async fn cleanup_expired(&self) {
|
|
||||||
let now = Utc::now();
|
|
||||||
self.sessions.write().await.retain(|_, s| s.expires_at > now);
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- Private helpers ---
|
|
||||||
|
|
||||||
fn create_signed_token(&self, session_id: &str, username: &str, role: &str, iat: i64, exp: i64) -> String {
|
|
||||||
let payload = SessionPayload {
|
|
||||||
session_id: session_id.to_string(),
|
|
||||||
username: username.to_string(),
|
|
||||||
role: role.to_string(),
|
|
||||||
iat,
|
|
||||||
exp,
|
|
||||||
version: TOKEN_VERSION.to_string(),
|
|
||||||
};
|
|
||||||
sign_token(&payload, &self.secret)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn update_session_expiry(&self, session_id: &str, new_expires_at: DateTime<Utc>) {
|
|
||||||
let mut sessions = self.sessions.write().await;
|
|
||||||
if let Some(session) = sessions.get_mut(session_id) {
|
|
||||||
session.expires_at = new_expires_at;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Load session secret from environment variable SESSION_SECRET (hex or base64 encoded).
|
|
||||||
/// If not set, generates a random 32-byte secret and logs a warning.
|
|
||||||
fn load_session_secret() -> Vec<u8> {
|
|
||||||
let secret_str = env::var("SESSION_SECRET").unwrap_or_else(|_| {
|
|
||||||
// Also check LLM_PROXY__SESSION_SECRET for consistency with config prefix
|
|
||||||
env::var("LLM_PROXY__SESSION_SECRET").unwrap_or_else(|_| {
|
|
||||||
// Generate a random secret (32 bytes) and encode as hex
|
|
||||||
use rand::RngCore;
|
|
||||||
let mut bytes = [0u8; 32];
|
|
||||||
rand::rng().fill_bytes(&mut bytes);
|
|
||||||
let hex_secret = hex::encode(bytes);
|
|
||||||
tracing::warn!(
|
|
||||||
"SESSION_SECRET environment variable not set. Using a randomly generated secret. \
|
|
||||||
This will invalidate all sessions on restart. Set SESSION_SECRET to a fixed hex or base64 encoded 32-byte value."
|
|
||||||
);
|
|
||||||
hex_secret
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
// Decode hex or base64
|
|
||||||
hex::decode(&secret_str)
|
|
||||||
.or_else(|_| URL_SAFE.decode(&secret_str))
|
|
||||||
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(&secret_str))
|
|
||||||
.unwrap_or_else(|_| {
|
|
||||||
panic!("SESSION_SECRET must be hex or base64 encoded (32 bytes)");
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sign a session payload and return a token string in format base64_url(payload).base64_url(signature).
|
|
||||||
fn sign_token(payload: &SessionPayload, secret: &[u8]) -> String {
|
|
||||||
let json = serde_json::to_vec(payload).expect("Failed to serialize payload");
|
|
||||||
let payload_b64 = URL_SAFE.encode(&json);
|
|
||||||
let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
|
|
||||||
mac.update(&json);
|
|
||||||
let signature = mac.finalize().into_bytes();
|
|
||||||
let signature_b64 = URL_SAFE.encode(signature);
|
|
||||||
format!("{}.{}", payload_b64, signature_b64)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Verify a signed token and return the decoded payload if valid.
|
|
||||||
fn verify_signed_token(token: &str, secret: &[u8]) -> Result<SessionPayload, TokenError> {
|
|
||||||
let parts: Vec<&str> = token.split('.').collect();
|
|
||||||
if parts.len() != 2 {
|
|
||||||
return Err(TokenError::InvalidFormat);
|
|
||||||
}
|
|
||||||
let payload_b64 = parts[0];
|
|
||||||
let signature_b64 = parts[1];
|
|
||||||
|
|
||||||
let json = URL_SAFE.decode(payload_b64).map_err(|_| TokenError::InvalidFormat)?;
|
|
||||||
let signature = URL_SAFE.decode(signature_b64).map_err(|_| TokenError::InvalidFormat)?;
|
|
||||||
|
|
||||||
// Verify HMAC
|
|
||||||
let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
|
|
||||||
mac.update(&json);
|
|
||||||
// Convert signature slice to GenericArray
|
|
||||||
let tag = GenericArray::from_slice(&signature);
|
|
||||||
mac.verify(tag).map_err(|_| TokenError::InvalidSignature)?;
|
|
||||||
|
|
||||||
// Deserialize payload
|
|
||||||
let payload: SessionPayload = serde_json::from_slice(&json).map_err(|_| TokenError::InvalidPayload)?;
|
|
||||||
Ok(payload)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Failure modes of `verify_signed_token`.
#[derive(Debug)]
enum TokenError {
    // Token does not split into exactly two '.'-separated parts, or a part
    // is not valid URL-safe base64.
    InvalidFormat,
    // HMAC verification failed — payload was tampered with or the wrong
    // secret was used.
    InvalidSignature,
    // Signature checked out but the payload JSON did not deserialize into
    // `SessionPayload`.
    InvalidPayload,
}
|
|
||||||
|
|
||||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    // Round-trips a payload through sign_token/verify_signed_token and
    // checks that every field survives unchanged.
    #[test]
    fn test_sign_and_verify_token() {
        let secret = b"test-secret-must-be-32-bytes-long!";
        let payload = SessionPayload {
            session_id: "test-session".to_string(),
            username: "testuser".to_string(),
            role: "user".to_string(),
            iat: 1000,
            exp: 2000,
            version: TOKEN_VERSION.to_string(),
        };
        let token = sign_token(&payload, secret);
        let verified = verify_signed_token(&token, secret).unwrap();
        assert_eq!(verified.session_id, payload.session_id);
        assert_eq!(verified.username, payload.username);
        assert_eq!(verified.role, payload.role);
        assert_eq!(verified.iat, payload.iat);
        assert_eq!(verified.exp, payload.exp);
        assert_eq!(verified.version, payload.version);
    }

    // Flipping bits in the payload half of the token must invalidate the
    // signature check.
    #[test]
    fn test_tampered_token() {
        let secret = b"test-secret-must-be-32-bytes-long!";
        let payload = SessionPayload {
            session_id: "test-session".to_string(),
            username: "testuser".to_string(),
            role: "user".to_string(),
            iat: 1000,
            exp: 2000,
            version: TOKEN_VERSION.to_string(),
        };
        let mut token = sign_token(&payload, secret);
        // Tamper with payload part
        let mut parts: Vec<&str> = token.split('.').collect();
        let mut payload_bytes = URL_SAFE.decode(parts[0]).unwrap();
        payload_bytes[0] ^= 0xFF; // flip some bits
        let tampered_payload = URL_SAFE.encode(payload_bytes);
        parts[0] = &tampered_payload;
        token = parts.join(".");
        assert!(verify_signed_token(&token, secret).is_err());
    }

    // NOTE(review): this test mutates process-global environment state; it
    // could race with other tests if the suite runs in parallel — consider
    // serializing env-dependent tests.
    #[test]
    fn test_load_session_secret_from_env() {
        unsafe {
            env::set_var("SESSION_SECRET", hex::encode([0xAA; 32]));
        }
        let secret = load_session_secret();
        assert_eq!(secret, vec![0xAA; 32]);
        unsafe {
            env::remove_var("SESSION_SECRET");
        }
    }
}
|
|
||||||
@@ -1,367 +0,0 @@
|
|||||||
use axum::{extract::State, response::Json};
|
|
||||||
use chrono;
|
|
||||||
use serde_json;
|
|
||||||
use sqlx::Row;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use tracing::warn;
|
|
||||||
|
|
||||||
use super::{ApiResponse, DashboardState};
|
|
||||||
|
|
||||||
/// Read a value from /proc files, returning None on any failure.
fn read_proc_file(path: &str) -> Option<String> {
    // Any I/O error (missing file, permissions, non-UTF-8) collapses to None.
    match std::fs::read_to_string(path) {
        Ok(contents) => Some(contents),
        Err(_) => None,
    }
}
|
|
||||||
|
|
||||||
/// GET handler: aggregated health snapshot — per-component online/degraded
/// statuses plus process memory and DB pool metrics.
pub(super) async fn handle_system_health(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
    // api_server and database are reported "online" unconditionally — if this
    // handler runs at all, both are reachable.
    let mut components = HashMap::new();
    components.insert("api_server".to_string(), "online".to_string());
    components.insert("database".to_string(), "online".to_string());

    // Check provider health via circuit breakers
    let provider_ids: Vec<String> = state
        .app_state
        .provider_manager
        .get_all_providers()
        .await
        .iter()
        .map(|p| p.name().to_string())
        .collect();

    for p_id in provider_ids {
        // A passing rate-limit check — or an error, via unwrap_or(true) —
        // counts as "online"; only an explicit false marks "degraded".
        if state
            .app_state
            .rate_limit_manager
            .check_provider_request(&p_id)
            .await
            .unwrap_or(true)
        {
            components.insert(p_id, "online".to_string());
        } else {
            components.insert(p_id, "degraded".to_string());
        }
    }

    // Read real memory usage from /proc/self/status
    // (VmRSS is reported in kB; converted to MB below). Linux-only; on other
    // platforms this falls back to 0.0.
    let memory_mb = read_proc_file("/proc/self/status")
        .and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string()))
        .and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::<f64>().ok()))
        .map(|kb| kb / 1024.0)
        .unwrap_or(0.0);

    // Get real database pool stats
    let db_pool_size = state.app_state.db_pool.size();
    let db_pool_idle = state.app_state.db_pool.num_idle();

    Json(ApiResponse::success(serde_json::json!({
        "status": "healthy",
        "timestamp": chrono::Utc::now().to_rfc3339(),
        "components": components,
        "metrics": {
            // Rounded to one decimal place.
            "memory_usage_mb": (memory_mb * 10.0).round() / 10.0,
            "db_connections_active": db_pool_size - db_pool_idle as u32,
            "db_connections_idle": db_pool_idle,
        }
    })))
}
|
|
||||||
|
|
||||||
/// Real system metrics from /proc (Linux only).
///
/// Gathers CPU, memory, disk, network, uptime, load-average, DB-pool and
/// WebSocket-listener figures into one JSON document. Every probe degrades
/// gracefully to zero values on non-Linux hosts or read failures.
pub(super) async fn handle_system_metrics(
    State(state): State<DashboardState>,
) -> Json<ApiResponse<serde_json::Value>> {
    // --- CPU usage (aggregate across all cores) ---
    // /proc/stat first line: cpu user nice system idle iowait irq softirq steal guest guest_nice
    // NOTE(review): /proc/stat counters are cumulative since boot, so a single
    // sample yields the *average* utilization since boot, not the current
    // instantaneous rate — confirm this is the intended semantics.
    let cpu_percent = read_proc_file("/proc/stat")
        .and_then(|s| {
            let line = s.lines().find(|l| l.starts_with("cpu "))?.to_string();
            let fields: Vec<u64> = line
                .split_whitespace()
                .skip(1)
                .filter_map(|v| v.parse().ok())
                .collect();
            if fields.len() >= 4 {
                let idle = fields[3];
                let total: u64 = fields.iter().sum();
                if total > 0 {
                    // Busy fraction, rounded to one decimal place.
                    Some(((total - idle) as f64 / total as f64 * 100.0 * 10.0).round() / 10.0)
                } else {
                    None
                }
            } else {
                None
            }
        })
        .unwrap_or(0.0);

    // --- Memory (system-wide from /proc/meminfo) ---
    let meminfo = read_proc_file("/proc/meminfo").unwrap_or_default();
    // Extracts the kB value for a given meminfo key ("MemTotal:" etc.).
    let parse_meminfo = |key: &str| -> u64 {
        meminfo
            .lines()
            .find(|l| l.starts_with(key))
            .and_then(|l| l.split_whitespace().nth(1))
            .and_then(|v| v.parse::<u64>().ok())
            .unwrap_or(0)
    };
    let mem_total_kb = parse_meminfo("MemTotal:");
    let mem_available_kb = parse_meminfo("MemAvailable:");
    // saturating_sub guards against MemAvailable > MemTotal edge readings.
    let mem_used_kb = mem_total_kb.saturating_sub(mem_available_kb);
    let mem_total_mb = mem_total_kb as f64 / 1024.0;
    let mem_used_mb = mem_used_kb as f64 / 1024.0;
    let mem_percent = if mem_total_kb > 0 {
        (mem_used_kb as f64 / mem_total_kb as f64 * 100.0 * 10.0).round() / 10.0
    } else {
        0.0
    };

    // --- Process-specific memory (VmRSS) ---
    let process_rss_mb = read_proc_file("/proc/self/status")
        .and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string()))
        .and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::<f64>().ok()))
        .map(|kb| (kb / 1024.0 * 10.0).round() / 10.0)
        .unwrap_or(0.0);

    // --- Disk usage of the data directory ---
    let (disk_total_gb, disk_used_gb, disk_percent) = {
        // statvfs via libc would be ideal; use df as a simple fallback
        // NOTE(review): "-BM --output=..." are GNU df flags; on non-GNU
        // systems this silently falls back to zeros. Measures the cwd (".").
        std::process::Command::new("df")
            .args(["-BM", "--output=size,used,pcent", "."])
            .output()
            .ok()
            .and_then(|o| {
                let out = String::from_utf8_lossy(&o.stdout);
                // Skip the header row; parse "sizeM usedM pct%".
                let line = out.lines().nth(1)?.to_string();
                let parts: Vec<&str> = line.split_whitespace().collect();
                if parts.len() >= 3 {
                    // MB -> GB.
                    let total = parts[0].trim_end_matches('M').parse::<f64>().unwrap_or(0.0) / 1024.0;
                    let used = parts[1].trim_end_matches('M').parse::<f64>().unwrap_or(0.0) / 1024.0;
                    let pct = parts[2].trim_end_matches('%').parse::<f64>().unwrap_or(0.0);
                    Some(((total * 10.0).round() / 10.0, (used * 10.0).round() / 10.0, pct))
                } else {
                    None
                }
            })
            .unwrap_or((0.0, 0.0, 0.0))
    };

    // --- Uptime ---
    // First field of /proc/uptime is seconds since boot (fractional).
    let uptime_seconds = read_proc_file("/proc/uptime")
        .and_then(|s| s.split_whitespace().next().and_then(|v| v.parse::<f64>().ok()))
        .unwrap_or(0.0) as u64;

    // --- Load average ---
    let (load_1, load_5, load_15) = read_proc_file("/proc/loadavg")
        .and_then(|s| {
            let parts: Vec<&str> = s.split_whitespace().collect();
            if parts.len() >= 3 {
                Some((
                    parts[0].parse::<f64>().unwrap_or(0.0),
                    parts[1].parse::<f64>().unwrap_or(0.0),
                    parts[2].parse::<f64>().unwrap_or(0.0),
                ))
            } else {
                None
            }
        })
        .unwrap_or((0.0, 0.0, 0.0));

    // --- Network (from /proc/net/dev, aggregate non-lo interfaces) ---
    let (net_rx_bytes, net_tx_bytes) = read_proc_file("/proc/net/dev")
        .map(|s| {
            s.lines()
                .skip(2) // skip header lines
                .filter(|l| !l.trim().starts_with("lo:"))
                .fold((0u64, 0u64), |(rx, tx), line| {
                    // Columns: iface: rx_bytes ... (index 1) / tx_bytes (index 9).
                    let parts: Vec<&str> = line.split_whitespace().collect();
                    if parts.len() >= 10 {
                        let r = parts[1].parse::<u64>().unwrap_or(0);
                        let t = parts[9].parse::<u64>().unwrap_or(0);
                        (rx + r, tx + t)
                    } else {
                        (rx, tx)
                    }
                })
        })
        .unwrap_or((0, 0));

    // --- Database pool ---
    let db_pool_size = state.app_state.db_pool.size();
    let db_pool_idle = state.app_state.db_pool.num_idle();

    // --- Active WebSocket listeners ---
    let ws_listeners = state.app_state.dashboard_tx.receiver_count();

    Json(ApiResponse::success(serde_json::json!({
        "cpu": {
            "usage_percent": cpu_percent,
            "load_average": [load_1, load_5, load_15],
        },
        "memory": {
            "total_mb": (mem_total_mb * 10.0).round() / 10.0,
            "used_mb": (mem_used_mb * 10.0).round() / 10.0,
            "usage_percent": mem_percent,
            "process_rss_mb": process_rss_mb,
        },
        "disk": {
            "total_gb": disk_total_gb,
            "used_gb": disk_used_gb,
            "usage_percent": disk_percent,
        },
        "network": {
            "rx_bytes": net_rx_bytes,
            "tx_bytes": net_tx_bytes,
        },
        "uptime_seconds": uptime_seconds,
        "connections": {
            "db_active": db_pool_size - db_pool_idle as u32,
            "db_idle": db_pool_idle,
            "websocket_listeners": ws_listeners,
        },
        "timestamp": chrono::Utc::now().to_rfc3339(),
    })))
}
|
|
||||||
|
|
||||||
/// GET handler: the 100 most recent LLM request records, newest first,
/// serialized as a JSON array of log entries.
pub(super) async fn handle_system_logs(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;

    // NOTE(review): prompt_tokens and completion_tokens are selected here but
    // never included in the JSON output below — drop from SELECT or add to
    // the response.
    let result = sqlx::query(
        r#"
        SELECT
            id,
            timestamp,
            client_id,
            provider,
            model,
            prompt_tokens,
            completion_tokens,
            total_tokens,
            cost,
            status,
            error_message,
            duration_ms
        FROM llm_requests
        ORDER BY timestamp DESC
        LIMIT 100
        "#,
    )
    .fetch_all(pool)
    .await;

    match result {
        Ok(rows) => {
            let logs: Vec<serde_json::Value> = rows
                .into_iter()
                .map(|row| {
                    serde_json::json!({
                        "id": row.get::<i64, _>("id"),
                        "timestamp": row.get::<chrono::DateTime<chrono::Utc>, _>("timestamp"),
                        "client_id": row.get::<String, _>("client_id"),
                        "provider": row.get::<String, _>("provider"),
                        "model": row.get::<String, _>("model"),
                        "tokens": row.get::<i64, _>("total_tokens"),
                        "cost": row.get::<f64, _>("cost"),
                        "status": row.get::<String, _>("status"),
                        "error": row.get::<Option<String>, _>("error_message"),
                        "duration": row.get::<i64, _>("duration_ms"),
                    })
                })
                .collect();

            Json(ApiResponse::success(serde_json::json!(logs)))
        }
        Err(e) => {
            // Log the detailed error server-side; return a generic message.
            warn!("Failed to fetch system logs: {}", e);
            Json(ApiResponse::error("Failed to fetch system logs".to_string()))
        }
    }
}
|
|
||||||
|
|
||||||
/// POST handler (admin only): snapshot the SQLite database into
/// data/backup-<unix-ts>.db using VACUUM INTO and report the result.
pub(super) async fn handle_system_backup(
    State(state): State<DashboardState>,
    headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
    // Admin gate: propagate the auth error response verbatim on failure.
    let (_session, _) = match super::auth::require_admin(&state, &headers).await {
        Ok((session, new_token)) => (session, new_token),
        Err(e) => return e,
    };

    let pool = &state.app_state.db_pool;
    // Backup name derives from the current Unix timestamp, so it is
    // server-generated (no user input reaches the SQL below).
    let backup_id = format!("backup-{}", chrono::Utc::now().timestamp());
    let backup_path = format!("data/{}.db", backup_id);

    // Ensure the data directory exists
    if let Err(e) = std::fs::create_dir_all("data") {
        return Json(ApiResponse::error(format!("Failed to create backup directory: {}", e)));
    }

    // Use SQLite VACUUM INTO for a consistent backup
    // NOTE(review): the path is interpolated into SQL; safe only because it is
    // fully server-generated — keep it that way if this ever takes user input.
    let result = sqlx::query(&format!("VACUUM INTO '{}'", backup_path))
        .execute(pool)
        .await;

    match result {
        Ok(_) => {
            // Get backup file size
            let size_bytes = std::fs::metadata(&backup_path).map(|m| m.len()).unwrap_or(0);

            Json(ApiResponse::success(serde_json::json!({
                "success": true,
                "message": "Backup completed successfully",
                "backup_id": backup_id,
                "backup_path": backup_path,
                "size_bytes": size_bytes,
            })))
        }
        Err(e) => {
            warn!("Database backup failed: {}", e);
            Json(ApiResponse::error(format!("Backup failed: {}", e)))
        }
    }
}
|
|
||||||
|
|
||||||
/// GET handler: read-only view of server settings — masked auth tokens,
/// package version, registry counts, and the database type.
pub(super) async fn handle_get_settings(State(state): State<DashboardState>) -> Json<ApiResponse<serde_json::Value>> {
    let registry = &state.app_state.model_registry;
    let provider_count = registry.providers.len();
    // Total models across all providers.
    let model_count: usize = registry.providers.values().map(|p| p.models.len()).sum();

    Json(ApiResponse::success(serde_json::json!({
        "server": {
            // Tokens are masked via mask_token so raw secrets never reach
            // the dashboard.
            "auth_tokens": state.app_state.auth_tokens.iter().map(|t| mask_token(t)).collect::<Vec<_>>(),
            // Compile-time crate version.
            "version": env!("CARGO_PKG_VERSION"),
        },
        "registry": {
            "provider_count": provider_count,
            "model_count": model_count,
        },
        "database": {
            "type": "SQLite",
        }
    })))
}
|
|
||||||
|
|
||||||
pub(super) async fn handle_update_settings(
|
|
||||||
State(state): State<DashboardState>,
|
|
||||||
headers: axum::http::HeaderMap,
|
|
||||||
) -> Json<ApiResponse<serde_json::Value>> {
|
|
||||||
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
|
|
||||||
Ok((session, new_token)) => (session, new_token),
|
|
||||||
Err(e) => return e,
|
|
||||||
};
|
|
||||||
|
|
||||||
Json(ApiResponse::error(
|
|
||||||
"Changing settings at runtime is not yet supported. Please update your config file and restart the server."
|
|
||||||
.to_string(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper functions
|
|
||||||
/// Mask an auth token for display: at most 8 asterisks followed by the last
/// 4 characters; tokens of 8 characters or fewer become "*****".
///
/// BUG FIX: the previous implementation sliced by *byte* index
/// (`&token[token.len() - 4..]`), which panics if the cut lands inside a
/// multi-byte UTF-8 character. Counting characters instead is panic-free and
/// produces identical output for ASCII tokens.
fn mask_token(token: &str) -> String {
    let char_count = token.chars().count();
    if char_count <= 8 {
        return "*****".to_string();
    }

    // Cap the displayed length at 12 so very long tokens don't leak length.
    let masked_len = char_count.min(12);
    let visible_len = 4;
    let mask_len = masked_len - visible_len;

    let tail: String = token.chars().skip(char_count - visible_len).collect();
    format!("{}{}", "*".repeat(mask_len), tail)
}
|
|
||||||
@@ -1,482 +0,0 @@
|
|||||||
use axum::{
|
|
||||||
extract::{Query, State},
|
|
||||||
response::Json,
|
|
||||||
};
|
|
||||||
use chrono;
|
|
||||||
use serde::Deserialize;
|
|
||||||
use serde_json;
|
|
||||||
use sqlx::Row;
|
|
||||||
use tracing::warn;
|
|
||||||
|
|
||||||
use super::{ApiResponse, DashboardState};
|
|
||||||
|
|
||||||
/// Query parameters for time-based filtering on usage endpoints.
///
/// `Default` yields an unfiltered query (period "all").
#[derive(Debug, Deserialize, Default)]
pub(super) struct UsagePeriodFilter {
    /// Preset period: "today", "24h", "7d", "30d", "all" (default: "all").
    /// "custom" activates the `from`/`to` fields instead.
    pub period: Option<String>,
    /// Custom range start (ISO 8601, e.g. "2025-06-01T00:00:00Z")
    pub from: Option<String>,
    /// Custom range end (ISO 8601)
    pub to: Option<String>,
}
|
|
||||||
|
|
||||||
impl UsagePeriodFilter {
    /// Returns `(sql_fragment, bind_values)` for a WHERE clause.
    /// The fragment is either empty (no filter) or " AND timestamp >= ? [AND timestamp <= ?]".
    fn to_sql(&self) -> (String, Vec<String>) {
        let period = self.period.as_deref().unwrap_or("all");

        if period == "custom" {
            // Custom range: include whichever bound(s) the caller supplied.
            // With neither `from` nor `to` this yields an empty clause.
            let mut clause = String::new();
            let mut binds = Vec::new();
            if let Some(ref from) = self.from {
                clause.push_str(" AND timestamp >= ?");
                binds.push(from.clone());
            }
            if let Some(ref to) = self.to {
                clause.push_str(" AND timestamp <= ?");
                binds.push(to.clone());
            }
            return (clause, binds);
        }

        let now = chrono::Utc::now();
        // Preset periods reduce to a single lower-bound timestamp.
        let cutoff = match period {
            "today" => {
                // Start of today (UTC)
                let today = now.format("%Y-%m-%dT00:00:00Z").to_string();
                Some(today)
            }
            "24h" => Some((now - chrono::Duration::hours(24)).to_rfc3339()),
            "7d" => Some((now - chrono::Duration::days(7)).to_rfc3339()),
            "30d" => Some((now - chrono::Duration::days(30)).to_rfc3339()),
            _ => None, // "all" or unrecognized
        };

        // NOTE(review): bounds are compared as strings in SQLite; this relies
        // on stored timestamps being lexicographically ordered ISO 8601 — confirm
        // the llm_requests.timestamp storage format matches to_rfc3339 output.
        match cutoff {
            Some(ts) => (" AND timestamp >= ?".to_string(), vec![ts]),
            None => (String::new(), vec![]),
        }
    }

    /// Determine the time-series granularity label for grouping.
    fn granularity(&self) -> &'static str {
        // Short windows get hourly buckets; everything else gets daily.
        match self.period.as_deref().unwrap_or("all") {
            "today" | "24h" => "hour",
            _ => "day",
        }
    }
}
|
|
||||||
|
|
||||||
/// GET handler: aggregate usage statistics. Totals honor the period filter;
/// "today" figures, error rate, and average response time are computed over
/// their own fixed windows regardless of the filter.
pub(super) async fn handle_usage_summary(
    State(state): State<DashboardState>,
    Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;
    let (period_clause, period_binds) = filter.to_sql();

    // Total stats (filtered by period)
    let period_sql = format!(
        r#"
        SELECT
            COUNT(*) as total_requests,
            COALESCE(SUM(total_tokens), 0) as total_tokens,
            COALESCE(SUM(cost), 0.0) as total_cost,
            COUNT(DISTINCT client_id) as active_clients,
            COALESCE(SUM(cache_read_tokens), 0) as total_cache_read,
            COALESCE(SUM(cache_write_tokens), 0) as total_cache_write
        FROM llm_requests
        WHERE 1=1 {}
        "#,
        period_clause
    );
    let mut q = sqlx::query(&period_sql);
    for b in &period_binds {
        q = q.bind(b);
    }
    // Futures are built here and awaited together via tokio::join! below.
    let total_stats = q.fetch_one(pool);

    // Today's stats
    // NOTE(review): today_tokens is selected by this query but never read below.
    let today = chrono::Utc::now().format("%Y-%m-%d").to_string();
    let today_stats = sqlx::query(
        r#"
        SELECT
            COUNT(*) as today_requests,
            COALESCE(SUM(total_tokens), 0) as today_tokens,
            COALESCE(SUM(cost), 0.0) as today_cost
        FROM llm_requests
        WHERE strftime('%Y-%m-%d', timestamp) = ?
        "#,
    )
    .bind(today)
    .fetch_one(pool);

    // Error stats (all-time, not period-filtered)
    let error_stats = sqlx::query(
        r#"
        SELECT
            COUNT(*) as total,
            SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as errors
        FROM llm_requests
        "#,
    )
    .fetch_one(pool);

    // Average response time (successful requests only, all-time)
    let avg_response = sqlx::query(
        r#"
        SELECT COALESCE(AVG(duration_ms), 0.0) as avg_duration
        FROM llm_requests
        WHERE status = 'success'
        "#,
    )
    .fetch_one(pool);

    // Run all four queries concurrently; any single failure yields the
    // generic error response in the catch-all arm.
    match tokio::join!(total_stats, today_stats, error_stats, avg_response) {
        (Ok(t), Ok(d), Ok(e), Ok(a)) => {
            let total_requests: i64 = t.get("total_requests");
            let total_tokens: i64 = t.get("total_tokens");
            let total_cost: f64 = t.get("total_cost");
            let active_clients: i64 = t.get("active_clients");
            let total_cache_read: i64 = t.get("total_cache_read");
            let total_cache_write: i64 = t.get("total_cache_write");

            let today_requests: i64 = d.get("today_requests");
            let today_cost: f64 = d.get("today_cost");

            let total_count: i64 = e.get("total");
            let error_count: i64 = e.get("errors");
            // Percentage; guards the empty-table division.
            let error_rate = if total_count > 0 {
                (error_count as f64 / total_count as f64) * 100.0
            } else {
                0.0
            };

            let avg_response_time: f64 = a.get("avg_duration");

            Json(ApiResponse::success(serde_json::json!({
                "total_requests": total_requests,
                "total_tokens": total_tokens,
                "total_cost": total_cost,
                "active_clients": active_clients,
                "today_requests": today_requests,
                "today_cost": today_cost,
                "error_rate": error_rate,
                "avg_response_time": avg_response_time,
                "total_cache_read_tokens": total_cache_read,
                "total_cache_write_tokens": total_cache_write,
            })))
        }
        _ => Json(ApiResponse::error("Failed to fetch usage statistics".to_string())),
    }
}
|
|
||||||
|
|
||||||
/// GET handler: request/token/cost time series bucketed by hour or day
/// depending on the period filter. Unfiltered requests get a default
/// lookback (24h for hourly, 30d for daily buckets).
pub(super) async fn handle_time_series(
    State(state): State<DashboardState>,
    Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;
    let (period_clause, period_binds) = filter.to_sql();
    let granularity = filter.granularity();

    // Determine the strftime format and default lookback
    // Hourly buckets are labeled "HH:00"; daily buckets "YYYY-MM-DD".
    let (strftime_fmt, _label_key, default_lookback) = match granularity {
        "hour" => ("%H:00", "hour", chrono::Duration::hours(24)),
        _ => ("%Y-%m-%d", "day", chrono::Duration::days(30)),
    };

    // If no period filter, apply a sensible default lookback
    let (clause, binds) = if period_clause.is_empty() {
        let cutoff = (chrono::Utc::now() - default_lookback).to_rfc3339();
        (" AND timestamp >= ?".to_string(), vec![cutoff])
    } else {
        (period_clause, period_binds)
    };

    let sql = format!(
        r#"
        SELECT
            strftime('{strftime_fmt}', timestamp) as bucket,
            COUNT(*) as requests,
            COALESCE(SUM(total_tokens), 0) as tokens,
            COALESCE(SUM(cost), 0.0) as cost
        FROM llm_requests
        WHERE 1=1 {clause}
        GROUP BY bucket
        ORDER BY bucket
        "#,
    );

    let mut q = sqlx::query(&sql);
    for b in &binds {
        q = q.bind(b);
    }

    let result = q.fetch_all(pool).await;

    match result {
        Ok(rows) => {
            let mut series = Vec::new();

            for row in rows {
                let bucket: String = row.get("bucket");
                let requests: i64 = row.get("requests");
                let tokens: i64 = row.get("tokens");
                let cost: f64 = row.get("cost");

                series.push(serde_json::json!({
                    "time": bucket,
                    "requests": requests,
                    "tokens": tokens,
                    "cost": cost,
                }));
            }

            Json(ApiResponse::success(serde_json::json!({
                "series": series,
                "period": filter.period.as_deref().unwrap_or("all"),
                "granularity": granularity,
            })))
        }
        Err(e) => {
            warn!("Failed to fetch time series data: {}", e);
            Json(ApiResponse::error("Failed to fetch time series data".to_string()))
        }
    }
}
|
|
||||||
|
|
||||||
/// GET handler: per-client usage totals (requests, tokens, cost, last
/// request time) within the period filter, ordered by request count.
pub(super) async fn handle_clients_usage(
    State(state): State<DashboardState>,
    Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;
    let (period_clause, period_binds) = filter.to_sql();

    // period_clause comes from UsagePeriodFilter::to_sql (bind placeholders
    // only), so this format! does not interpolate user data into SQL.
    let sql = format!(
        r#"
        SELECT
            client_id,
            COUNT(*) as requests,
            COALESCE(SUM(total_tokens), 0) as tokens,
            COALESCE(SUM(cost), 0.0) as cost,
            MAX(timestamp) as last_request
        FROM llm_requests
        WHERE 1=1 {}
        GROUP BY client_id
        ORDER BY requests DESC
        "#,
        period_clause
    );

    let mut q = sqlx::query(&sql);
    for b in &period_binds {
        q = q.bind(b);
    }

    let result = q.fetch_all(pool).await;

    match result {
        Ok(rows) => {
            let mut client_usage = Vec::new();

            for row in rows {
                let client_id: String = row.get("client_id");
                let requests: i64 = row.get("requests");
                let tokens: i64 = row.get("tokens");
                let cost: f64 = row.get("cost");
                let last_request: Option<chrono::DateTime<chrono::Utc>> = row.get("last_request");

                client_usage.push(serde_json::json!({
                    "client_id": client_id,
                    // No separate display-name source: the id doubles as name.
                    "client_name": client_id,
                    "requests": requests,
                    "tokens": tokens,
                    "cost": cost,
                    "last_request": last_request,
                }));
            }

            Json(ApiResponse::success(serde_json::json!(client_usage)))
        }
        Err(e) => {
            warn!("Failed to fetch client usage data: {}", e);
            Json(ApiResponse::error("Failed to fetch client usage data".to_string()))
        }
    }
}
|
|
||||||
|
|
||||||
/// GET handler: per-provider usage totals (requests, tokens, cost, cache
/// read/write tokens) within the period filter, ordered by request count.
pub(super) async fn handle_providers_usage(
    State(state): State<DashboardState>,
    Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;
    let (period_clause, period_binds) = filter.to_sql();

    // period_clause contains only bind placeholders — no user data in SQL.
    let sql = format!(
        r#"
        SELECT
            provider,
            COUNT(*) as requests,
            COALESCE(SUM(total_tokens), 0) as tokens,
            COALESCE(SUM(cost), 0.0) as cost,
            COALESCE(SUM(cache_read_tokens), 0) as cache_read,
            COALESCE(SUM(cache_write_tokens), 0) as cache_write
        FROM llm_requests
        WHERE 1=1 {}
        GROUP BY provider
        ORDER BY requests DESC
        "#,
        period_clause
    );

    let mut q = sqlx::query(&sql);
    for b in &period_binds {
        q = q.bind(b);
    }

    let result = q.fetch_all(pool).await;

    match result {
        Ok(rows) => {
            let mut provider_usage = Vec::new();

            for row in rows {
                let provider: String = row.get("provider");
                let requests: i64 = row.get("requests");
                let tokens: i64 = row.get("tokens");
                let cost: f64 = row.get("cost");
                let cache_read: i64 = row.get("cache_read");
                let cache_write: i64 = row.get("cache_write");

                provider_usage.push(serde_json::json!({
                    "provider": provider,
                    "requests": requests,
                    "tokens": tokens,
                    "cost": cost,
                    "cache_read_tokens": cache_read,
                    "cache_write_tokens": cache_write,
                }));
            }

            Json(ApiResponse::success(serde_json::json!(provider_usage)))
        }
        Err(e) => {
            warn!("Failed to fetch provider usage data: {}", e);
            Json(ApiResponse::error("Failed to fetch provider usage data".to_string()))
        }
    }
}
|
|
||||||
|
|
||||||
/// GET handler: usage broken down by (date, client, provider, model) within
/// the period filter, newest dates first, capped at 200 rows.
pub(super) async fn handle_detailed_usage(
    State(state): State<DashboardState>,
    Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;
    let (period_clause, period_binds) = filter.to_sql();

    // period_clause contains only bind placeholders — no user data in SQL.
    let sql = format!(
        r#"
        SELECT
            strftime('%Y-%m-%d', timestamp) as date,
            client_id,
            provider,
            model,
            COUNT(*) as requests,
            COALESCE(SUM(total_tokens), 0) as tokens,
            COALESCE(SUM(cost), 0.0) as cost,
            COALESCE(SUM(cache_read_tokens), 0) as cache_read,
            COALESCE(SUM(cache_write_tokens), 0) as cache_write
        FROM llm_requests
        WHERE 1=1 {}
        GROUP BY date, client_id, provider, model
        ORDER BY date DESC
        LIMIT 200
        "#,
        period_clause
    );

    let mut q = sqlx::query(&sql);
    for b in &period_binds {
        q = q.bind(b);
    }

    let result = q.fetch_all(pool).await;

    match result {
        Ok(rows) => {
            let usage: Vec<serde_json::Value> = rows
                .into_iter()
                .map(|row| {
                    serde_json::json!({
                        "date": row.get::<String, _>("date"),
                        "client": row.get::<String, _>("client_id"),
                        "provider": row.get::<String, _>("provider"),
                        "model": row.get::<String, _>("model"),
                        "requests": row.get::<i64, _>("requests"),
                        "tokens": row.get::<i64, _>("tokens"),
                        "cost": row.get::<f64, _>("cost"),
                        "cache_read_tokens": row.get::<i64, _>("cache_read"),
                        "cache_write_tokens": row.get::<i64, _>("cache_write"),
                    })
                })
                .collect();

            Json(ApiResponse::success(serde_json::json!(usage)))
        }
        Err(e) => {
            warn!("Failed to fetch detailed usage: {}", e);
            Json(ApiResponse::error("Failed to fetch detailed usage".to_string()))
        }
    }
}
|
|
||||||
|
|
||||||
/// GET handler: request-count breakdowns by model and by client (label/value
/// pairs, descending by count) within the period filter; both queries run
/// concurrently.
pub(super) async fn handle_analytics_breakdown(
    State(state): State<DashboardState>,
    Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
    let pool = &state.app_state.db_pool;
    let (period_clause, period_binds) = filter.to_sql();

    // Model breakdown
    let model_sql = format!(
        "SELECT model as label, COUNT(*) as value FROM llm_requests WHERE 1=1 {} GROUP BY model ORDER BY value DESC",
        period_clause
    );
    let mut mq = sqlx::query(&model_sql);
    for b in &period_binds {
        mq = mq.bind(b);
    }
    // Future; awaited below together with the client query.
    let models = mq.fetch_all(pool);

    // Client breakdown
    let client_sql = format!(
        "SELECT client_id as label, COUNT(*) as value FROM llm_requests WHERE 1=1 {} GROUP BY client_id ORDER BY value DESC",
        period_clause
    );
    let mut cq = sqlx::query(&client_sql);
    for b in &period_binds {
        cq = cq.bind(b);
    }
    let clients = cq.fetch_all(pool);

    match tokio::join!(models, clients) {
        (Ok(m_rows), Ok(c_rows)) => {
            let model_breakdown: Vec<serde_json::Value> = m_rows
                .into_iter()
                .map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
                .collect();

            let client_breakdown: Vec<serde_json::Value> = c_rows
                .into_iter()
                .map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
                .collect();

            Json(ApiResponse::success(serde_json::json!({
                "models": model_breakdown,
                "clients": client_breakdown
            })))
        }
        // Either query failing collapses to a single generic error.
        _ => Json(ApiResponse::error("Failed to fetch analytics breakdown".to_string())),
    }
}
|
|
||||||
@@ -1,290 +0,0 @@
|
|||||||
use axum::{
|
|
||||||
extract::{Path, State},
|
|
||||||
response::Json,
|
|
||||||
};
|
|
||||||
use serde::Deserialize;
|
|
||||||
use sqlx::Row;
|
|
||||||
use tracing::warn;
|
|
||||||
|
|
||||||
use super::{ApiResponse, DashboardState, auth};
|
|
||||||
|
|
||||||
// ── User management endpoints (admin-only) ──────────────────────────
|
|
||||||
|
|
||||||
pub(super) async fn handle_get_users(
|
|
||||||
State(state): State<DashboardState>,
|
|
||||||
headers: axum::http::HeaderMap,
|
|
||||||
) -> Json<ApiResponse<serde_json::Value>> {
|
|
||||||
let (_session, _) = match auth::require_admin(&state, &headers).await {
|
|
||||||
Ok((session, new_token)) => (session, new_token),
|
|
||||||
Err(e) => return e,
|
|
||||||
};
|
|
||||||
|
|
||||||
let pool = &state.app_state.db_pool;
|
|
||||||
|
|
||||||
let result = sqlx::query(
|
|
||||||
"SELECT id, username, display_name, role, must_change_password, created_at FROM users ORDER BY created_at ASC",
|
|
||||||
)
|
|
||||||
.fetch_all(pool)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(rows) => {
|
|
||||||
let users: Vec<serde_json::Value> = rows
|
|
||||||
.into_iter()
|
|
||||||
.map(|row| {
|
|
||||||
let username: String = row.get("username");
|
|
||||||
let display_name: Option<String> = row.get("display_name");
|
|
||||||
serde_json::json!({
|
|
||||||
"id": row.get::<i64, _>("id"),
|
|
||||||
"username": &username,
|
|
||||||
"display_name": display_name.as_deref().unwrap_or(&username),
|
|
||||||
"role": row.get::<String, _>("role"),
|
|
||||||
"must_change_password": row.get::<bool, _>("must_change_password"),
|
|
||||||
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
Json(ApiResponse::success(serde_json::json!(users)))
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
warn!("Failed to fetch users: {}", e);
|
|
||||||
Json(ApiResponse::error("Failed to fetch users".to_string()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// JSON payload for the admin-only "create user" endpoint.
#[derive(Deserialize)]
pub(super) struct CreateUserRequest {
    // Trimmed and validated to 1-64 characters by the handler.
    pub(super) username: String,
    // Plaintext on the wire; bcrypt-hashed (cost 12) before storage. Min 4 chars.
    pub(super) password: String,
    // Optional; handlers fall back to the username for display when absent.
    pub(super) display_name: Option<String>,
    // Optional; defaults to "viewer" in the handler. Must be "admin" or "viewer".
    pub(super) role: Option<String>,
}
|
|
||||||
|
|
||||||
/// Create a new dashboard user (admin-only).
///
/// Validates the role ("admin"/"viewer", default "viewer"), the trimmed
/// username (1-64 chars) and the password (>= 4 chars), stores a bcrypt hash
/// (cost 12), and forces a password change on first login
/// (`must_change_password = TRUE`). Returns the created row as JSON.
pub(super) async fn handle_create_user(
    State(state): State<DashboardState>,
    headers: axum::http::HeaderMap,
    Json(payload): Json<CreateUserRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
    // Admin gate: propagate the auth error response if the caller is not an admin.
    let (_session, _) = match auth::require_admin(&state, &headers).await {
        Ok((session, new_token)) => (session, new_token),
        Err(e) => return e,
    };

    let pool = &state.app_state.db_pool;

    // Validate role
    let role = payload.role.as_deref().unwrap_or("viewer");
    if role != "admin" && role != "viewer" {
        return Json(ApiResponse::error("Role must be 'admin' or 'viewer'".to_string()));
    }

    // Validate username
    let username = payload.username.trim();
    if username.is_empty() || username.len() > 64 {
        return Json(ApiResponse::error("Username must be 1-64 characters".to_string()));
    }

    // Validate password
    if payload.password.len() < 4 {
        return Json(ApiResponse::error("Password must be at least 4 characters".to_string()));
    }

    // Hash before storing; bcrypt failure is surfaced as a generic error.
    let password_hash = match bcrypt::hash(&payload.password, 12) {
        Ok(h) => h,
        Err(_) => return Json(ApiResponse::error("Failed to hash password".to_string())),
    };

    // RETURNING lets us echo the created row without a second query.
    let result = sqlx::query(
        r#"
        INSERT INTO users (username, password_hash, display_name, role, must_change_password)
        VALUES (?, ?, ?, ?, TRUE)
        RETURNING id, username, display_name, role, must_change_password, created_at
        "#,
    )
    .bind(username)
    .bind(&password_hash)
    .bind(&payload.display_name)
    .bind(role)
    .fetch_one(pool)
    .await;

    match result {
        Ok(row) => {
            let uname: String = row.get("username");
            let display_name: Option<String> = row.get("display_name");
            Json(ApiResponse::success(serde_json::json!({
                "id": row.get::<i64, _>("id"),
                "username": &uname,
                // Fall back to the username when no display name is set.
                "display_name": display_name.as_deref().unwrap_or(&uname),
                "role": row.get::<String, _>("role"),
                "must_change_password": row.get::<bool, _>("must_change_password"),
                "created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
            })))
        }
        Err(e) => {
            // Heuristic duplicate detection: SQLite unique-violation messages
            // contain "UNIQUE". Anything else is reported verbatim.
            let msg = if e.to_string().contains("UNIQUE") {
                format!("Username '{}' already exists", username)
            } else {
                format!("Failed to create user: {}", e)
            };
            warn!("Failed to create user: {}", e);
            Json(ApiResponse::error(msg))
        }
    }
}
|
|
||||||
|
|
||||||
/// JSON payload for the admin-only "update user" endpoint.
/// All fields are optional; only the fields present are written.
#[derive(Deserialize)]
pub(super) struct UpdateUserRequest {
    pub(super) display_name: Option<String>,
    // Must be "admin" or "viewer" when present.
    pub(super) role: Option<String>,
    // New password (min 4 chars); bcrypt-hashed before storage.
    pub(super) password: Option<String>,
    pub(super) must_change_password: Option<bool>,
}
|
|
||||||
|
|
||||||
pub(super) async fn handle_update_user(
|
|
||||||
State(state): State<DashboardState>,
|
|
||||||
headers: axum::http::HeaderMap,
|
|
||||||
Path(id): Path<i64>,
|
|
||||||
Json(payload): Json<UpdateUserRequest>,
|
|
||||||
) -> Json<ApiResponse<serde_json::Value>> {
|
|
||||||
let (_session, _) = match auth::require_admin(&state, &headers).await {
|
|
||||||
Ok((session, new_token)) => (session, new_token),
|
|
||||||
Err(e) => return e,
|
|
||||||
};
|
|
||||||
|
|
||||||
let pool = &state.app_state.db_pool;
|
|
||||||
|
|
||||||
// Validate role if provided
|
|
||||||
if let Some(ref role) = payload.role {
|
|
||||||
if role != "admin" && role != "viewer" {
|
|
||||||
return Json(ApiResponse::error("Role must be 'admin' or 'viewer'".to_string()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build dynamic update
|
|
||||||
let mut sets = Vec::new();
|
|
||||||
let mut string_binds: Vec<String> = Vec::new();
|
|
||||||
let mut has_password = false;
|
|
||||||
let mut has_must_change = false;
|
|
||||||
|
|
||||||
if let Some(ref display_name) = payload.display_name {
|
|
||||||
sets.push("display_name = ?");
|
|
||||||
string_binds.push(display_name.clone());
|
|
||||||
}
|
|
||||||
if let Some(ref role) = payload.role {
|
|
||||||
sets.push("role = ?");
|
|
||||||
string_binds.push(role.clone());
|
|
||||||
}
|
|
||||||
if let Some(ref password) = payload.password {
|
|
||||||
if password.len() < 4 {
|
|
||||||
return Json(ApiResponse::error("Password must be at least 4 characters".to_string()));
|
|
||||||
}
|
|
||||||
let hash = match bcrypt::hash(password, 12) {
|
|
||||||
Ok(h) => h,
|
|
||||||
Err(_) => return Json(ApiResponse::error("Failed to hash password".to_string())),
|
|
||||||
};
|
|
||||||
sets.push("password_hash = ?");
|
|
||||||
string_binds.push(hash);
|
|
||||||
has_password = true;
|
|
||||||
}
|
|
||||||
if let Some(mcp) = payload.must_change_password {
|
|
||||||
sets.push("must_change_password = ?");
|
|
||||||
has_must_change = true;
|
|
||||||
let _ = mcp; // used below in bind
|
|
||||||
}
|
|
||||||
|
|
||||||
if sets.is_empty() {
|
|
||||||
return Json(ApiResponse::error("No fields to update".to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
let sql = format!("UPDATE users SET {} WHERE id = ?", sets.join(", "));
|
|
||||||
let mut query = sqlx::query(&sql);
|
|
||||||
|
|
||||||
for b in &string_binds {
|
|
||||||
query = query.bind(b);
|
|
||||||
}
|
|
||||||
if has_must_change {
|
|
||||||
query = query.bind(payload.must_change_password.unwrap());
|
|
||||||
}
|
|
||||||
let _ = has_password; // consumed above via string_binds
|
|
||||||
query = query.bind(id);
|
|
||||||
|
|
||||||
match query.execute(pool).await {
|
|
||||||
Ok(result) => {
|
|
||||||
if result.rows_affected() == 0 {
|
|
||||||
return Json(ApiResponse::error("User not found".to_string()));
|
|
||||||
}
|
|
||||||
// Fetch updated user
|
|
||||||
let row = sqlx::query(
|
|
||||||
"SELECT id, username, display_name, role, must_change_password, created_at FROM users WHERE id = ?",
|
|
||||||
)
|
|
||||||
.bind(id)
|
|
||||||
.fetch_one(pool)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match row {
|
|
||||||
Ok(row) => {
|
|
||||||
let uname: String = row.get("username");
|
|
||||||
let display_name: Option<String> = row.get("display_name");
|
|
||||||
Json(ApiResponse::success(serde_json::json!({
|
|
||||||
"id": row.get::<i64, _>("id"),
|
|
||||||
"username": &uname,
|
|
||||||
"display_name": display_name.as_deref().unwrap_or(&uname),
|
|
||||||
"role": row.get::<String, _>("role"),
|
|
||||||
"must_change_password": row.get::<bool, _>("must_change_password"),
|
|
||||||
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
|
|
||||||
})))
|
|
||||||
}
|
|
||||||
Err(_) => Json(ApiResponse::success(serde_json::json!({ "message": "User updated" }))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
warn!("Failed to update user: {}", e);
|
|
||||||
Json(ApiResponse::error(format!("Failed to update user: {}", e)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) async fn handle_delete_user(
|
|
||||||
State(state): State<DashboardState>,
|
|
||||||
headers: axum::http::HeaderMap,
|
|
||||||
Path(id): Path<i64>,
|
|
||||||
) -> Json<ApiResponse<serde_json::Value>> {
|
|
||||||
let (session, _) = match auth::require_admin(&state, &headers).await {
|
|
||||||
Ok((session, new_token)) => (session, new_token),
|
|
||||||
Err(e) => return e,
|
|
||||||
};
|
|
||||||
|
|
||||||
let pool = &state.app_state.db_pool;
|
|
||||||
|
|
||||||
// Don't allow deleting yourself
|
|
||||||
let target_username: Option<String> =
|
|
||||||
sqlx::query_scalar::<_, String>("SELECT username FROM users WHERE id = ?")
|
|
||||||
.bind(id)
|
|
||||||
.fetch_optional(pool)
|
|
||||||
.await
|
|
||||||
.unwrap_or(None);
|
|
||||||
|
|
||||||
match target_username {
|
|
||||||
None => return Json(ApiResponse::error("User not found".to_string())),
|
|
||||||
Some(ref uname) if uname == &session.username => {
|
|
||||||
return Json(ApiResponse::error("Cannot delete your own account".to_string()));
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
|
|
||||||
let result = sqlx::query("DELETE FROM users WHERE id = ?")
|
|
||||||
.bind(id)
|
|
||||||
.execute(pool)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "User deleted" }))),
|
|
||||||
Err(e) => {
|
|
||||||
warn!("Failed to delete user: {}", e);
|
|
||||||
Json(ApiResponse::error(format!("Failed to delete user: {}", e)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,75 +0,0 @@
|
|||||||
use axum::{
|
|
||||||
extract::{
|
|
||||||
State,
|
|
||||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
|
||||||
},
|
|
||||||
response::IntoResponse,
|
|
||||||
};
|
|
||||||
use serde_json;
|
|
||||||
use tracing::info;
|
|
||||||
|
|
||||||
use super::DashboardState;
|
|
||||||
|
|
||||||
// WebSocket handler
/// Upgrade the HTTP request to a WebSocket and hand the socket to the
/// per-connection loop.
pub(super) async fn handle_websocket(ws: WebSocketUpgrade, State(state): State<DashboardState>) -> impl IntoResponse {
    ws.on_upgrade(|socket| handle_websocket_connection(socket, state))
}
|
|
||||||
|
|
||||||
/// Per-connection task: forwards dashboard broadcast events to this socket and
/// consumes inbound frames, until either direction fails or the client closes.
pub(super) async fn handle_websocket_connection(mut socket: WebSocket, state: DashboardState) {
    info!("WebSocket connection established");

    // Subscribe to events from the global bus
    let mut rx = state.app_state.dashboard_tx.subscribe();

    // Send initial connection message (failure ignored; the loop below will
    // notice a dead socket on the next send).
    let _ = socket
        .send(Message::Text(
            serde_json::json!({
                "type": "connected",
                "message": "Connected to LLM Proxy Dashboard"
            })
            .to_string()
            .into(),
        ))
        .await;

    // Handle incoming messages and broadcast events
    loop {
        tokio::select! {
            // Receive broadcast events
            Ok(event) = rx.recv() => {
                // Skip events that fail to serialize rather than closing the socket.
                let Ok(json_str) = serde_json::to_string(&event) else {
                    continue;
                };
                let message = Message::Text(json_str.into());
                // A failed send means the client went away; end the task.
                if socket.send(message).await.is_err() {
                    break;
                }
            }

            // Receive WebSocket messages
            result = socket.recv() => {
                match result {
                    Some(Ok(Message::Text(text))) => {
                        handle_websocket_message(&text, &state).await;
                    }
                    // Close frames, errors, and any non-text frame end the loop.
                    _ => break,
                }
            }
        }
    }

    info!("WebSocket connection closed");
}
|
|
||||||
|
|
||||||
/// Handle one inbound text frame. Currently only understands `{"type":"ping"}`,
/// to which it publishes a pong event.
// NOTE(review): the pong goes out on the shared broadcast bus, so EVERY
// connected dashboard client receives it, not just the sender — confirm this
// fan-out is intended.
pub(super) async fn handle_websocket_message(text: &str, state: &DashboardState) {
    // Parse and handle WebSocket messages
    if let Ok(data) = serde_json::from_str::<serde_json::Value>(text)
        && data.get("type").and_then(|v| v.as_str()) == Some("ping")
    {
        // Send result ignored: a full/closed channel just drops the pong.
        let _ = state.app_state.dashboard_tx.send(serde_json::json!({
            "type": "pong",
            "payload": {}
        }));
    }
}
|
|
||||||
@@ -1,261 +0,0 @@
|
|||||||
use anyhow::Result;
|
|
||||||
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool};
|
|
||||||
use std::str::FromStr;
|
|
||||||
use tracing::info;
|
|
||||||
|
|
||||||
use crate::config::DatabaseConfig;
|
|
||||||
|
|
||||||
/// Connection pool type used throughout the application (SQLite-backed).
pub type DbPool = SqlitePool;
|
|
||||||
|
|
||||||
pub async fn init(config: &DatabaseConfig) -> Result<DbPool> {
|
|
||||||
// Ensure the database directory exists
|
|
||||||
if let Some(parent) = config.path.parent()
|
|
||||||
&& !parent.as_os_str().is_empty()
|
|
||||||
{
|
|
||||||
tokio::fs::create_dir_all(parent).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let database_path = config.path.to_string_lossy().to_string();
|
|
||||||
info!("Connecting to database at {}", database_path);
|
|
||||||
|
|
||||||
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?
|
|
||||||
.create_if_missing(true)
|
|
||||||
.pragma("foreign_keys", "ON");
|
|
||||||
|
|
||||||
let pool = SqlitePool::connect_with(options).await?;
|
|
||||||
|
|
||||||
// Run migrations
|
|
||||||
run_migrations(&pool).await?;
|
|
||||||
info!("Database migrations completed");
|
|
||||||
|
|
||||||
Ok(pool)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create all tables and indices if missing and apply additive column
/// migrations. Safe to call on every startup: DDL uses IF NOT EXISTS and the
/// ALTERs are best-effort, so an already-migrated database is left unchanged.
pub async fn run_migrations(pool: &DbPool) -> Result<()> {
    // Create clients table if it doesn't exist
    sqlx::query(
        r#"
        CREATE TABLE IF NOT EXISTS clients (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            client_id TEXT UNIQUE NOT NULL,
            name TEXT,
            description TEXT,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            is_active BOOLEAN DEFAULT TRUE,
            rate_limit_per_minute INTEGER DEFAULT 60,
            total_requests INTEGER DEFAULT 0,
            total_tokens INTEGER DEFAULT 0,
            total_cost REAL DEFAULT 0.0
        )
        "#,
    )
    .execute(pool)
    .await?;

    // Create llm_requests table if it doesn't exist
    sqlx::query(
        r#"
        CREATE TABLE IF NOT EXISTS llm_requests (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
            client_id TEXT,
            provider TEXT,
            model TEXT,
            prompt_tokens INTEGER,
            completion_tokens INTEGER,
            total_tokens INTEGER,
            cost REAL,
            has_images BOOLEAN DEFAULT FALSE,
            status TEXT DEFAULT 'success',
            error_message TEXT,
            duration_ms INTEGER,
            request_body TEXT,
            response_body TEXT,
            FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE SET NULL
        )
        "#,
    )
    .execute(pool)
    .await?;

    // Create provider_configs table
    sqlx::query(
        r#"
        CREATE TABLE IF NOT EXISTS provider_configs (
            id TEXT PRIMARY KEY,
            display_name TEXT NOT NULL,
            enabled BOOLEAN DEFAULT TRUE,
            base_url TEXT,
            api_key TEXT,
            credit_balance REAL DEFAULT 0.0,
            low_credit_threshold REAL DEFAULT 5.0,
            billing_mode TEXT,
            api_key_encrypted BOOLEAN DEFAULT FALSE,
            updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
        )
        "#,
    )
    .execute(pool)
    .await?;

    // Create model_configs table
    sqlx::query(
        r#"
        CREATE TABLE IF NOT EXISTS model_configs (
            id TEXT PRIMARY KEY,
            provider_id TEXT NOT NULL,
            display_name TEXT,
            enabled BOOLEAN DEFAULT TRUE,
            prompt_cost_per_m REAL,
            completion_cost_per_m REAL,
            mapping TEXT,
            updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE
        )
        "#,
    )
    .execute(pool)
    .await?;

    // Create users table for dashboard access
    sqlx::query(
        r#"
        CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            username TEXT UNIQUE NOT NULL,
            password_hash TEXT NOT NULL,
            display_name TEXT,
            role TEXT DEFAULT 'admin',
            must_change_password BOOLEAN DEFAULT FALSE,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP
        )
        "#,
    )
    .execute(pool)
    .await?;

    // Create client_tokens table for DB-based token auth
    sqlx::query(
        r#"
        CREATE TABLE IF NOT EXISTS client_tokens (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            client_id TEXT NOT NULL,
            token TEXT NOT NULL UNIQUE,
            name TEXT DEFAULT 'default',
            is_active BOOLEAN DEFAULT TRUE,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            last_used_at DATETIME,
            FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE CASCADE
        )
        "#,
    )
    .execute(pool)
    .await?;

    // Additive column migrations below: SQLite's ALTER TABLE fails if the
    // column already exists, so the results are deliberately discarded —
    // a failure here means "already migrated".

    // Add must_change_password column if it doesn't exist (migration for existing DBs)
    let _ = sqlx::query("ALTER TABLE users ADD COLUMN must_change_password BOOLEAN DEFAULT FALSE")
        .execute(pool)
        .await;

    // Add display_name column if it doesn't exist (migration for existing DBs)
    let _ = sqlx::query("ALTER TABLE users ADD COLUMN display_name TEXT")
        .execute(pool)
        .await;

    // Add cache token columns if they don't exist (migration for existing DBs)
    let _ = sqlx::query("ALTER TABLE llm_requests ADD COLUMN cache_read_tokens INTEGER DEFAULT 0")
        .execute(pool)
        .await;
    let _ = sqlx::query("ALTER TABLE llm_requests ADD COLUMN cache_write_tokens INTEGER DEFAULT 0")
        .execute(pool)
        .await;

    // Add billing_mode column if it doesn't exist (migration for existing DBs)
    let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN billing_mode TEXT")
        .execute(pool)
        .await;
    // Add api_key_encrypted column if it doesn't exist (migration for existing DBs)
    let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN api_key_encrypted BOOLEAN DEFAULT FALSE")
        .execute(pool)
        .await;

    // Insert default admin user if none exists (default password: admin)
    let user_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users").fetch_one(pool).await?;

    if user_count.0 == 0 {
        // 'admin' hashed with default cost (12)
        let default_admin_hash =
            bcrypt::hash("admin", 12).map_err(|e| anyhow::anyhow!("Failed to hash default password: {}", e))?;
        sqlx::query(
            "INSERT INTO users (username, password_hash, role, must_change_password) VALUES ('admin', ?, 'admin', TRUE)"
        )
        .bind(default_admin_hash)
        .execute(pool)
        .await?;
        info!("Created default admin user with password 'admin' (must change on first login)");
    }

    // Create indices
    sqlx::query("CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)")
        .execute(pool)
        .await?;

    sqlx::query("CREATE INDEX IF NOT EXISTS idx_clients_created_at ON clients(created_at)")
        .execute(pool)
        .await?;

    sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_timestamp ON llm_requests(timestamp)")
        .execute(pool)
        .await?;

    sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_client_id ON llm_requests(client_id)")
        .execute(pool)
        .await?;

    sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_provider ON llm_requests(provider)")
        .execute(pool)
        .await?;

    sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_status ON llm_requests(status)")
        .execute(pool)
        .await?;

    sqlx::query("CREATE UNIQUE INDEX IF NOT EXISTS idx_client_tokens_token ON client_tokens(token)")
        .execute(pool)
        .await?;

    sqlx::query("CREATE INDEX IF NOT EXISTS idx_client_tokens_client_id ON client_tokens(client_id)")
        .execute(pool)
        .await?;

    // Composite indexes for performance
    sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp)")
        .execute(pool)
        .await?;

    sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp)")
        .execute(pool)
        .await?;

    sqlx::query("CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id)")
        .execute(pool)
        .await?;

    // Insert default client if none exists
    sqlx::query(
        r#"
        INSERT OR IGNORE INTO clients (client_id, name, description)
        VALUES ('default', 'Default Client', 'Default client for anonymous requests')
        "#,
    )
    .execute(pool)
    .await?;

    Ok(())
}
|
|
||||||
|
|
||||||
/// Health check: execute a trivial statement to confirm the pool is usable.
pub async fn test_connection(pool: &DbPool) -> Result<()> {
    sqlx::query("SELECT 1").execute(pool).await?;
    Ok(())
}
|
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
use thiserror::Error;
|
|
||||||
|
|
||||||
/// Application-wide error type. Each variant carries a human-readable detail
/// string; `#[error]` supplies the Display prefix. HTTP status mapping lives
/// in the `IntoResponse` impl in this file.
#[derive(Error, Debug, Clone)]
pub enum AppError {
    // Mapped to 401 Unauthorized.
    #[error("Authentication failed: {0}")]
    AuthError(String),

    #[error("Configuration error: {0}")]
    ConfigError(String),

    #[error("Database error: {0}")]
    DatabaseError(String),

    #[error("Provider error: {0}")]
    ProviderError(String),

    // Mapped to 400 Bad Request.
    #[error("Validation error: {0}")]
    ValidationError(String),

    #[error("Multimodal processing error: {0}")]
    MultimodalError(String),

    // Mapped to 429 Too Many Requests.
    #[error("Rate limit exceeded: {0}")]
    RateLimitError(String),

    // All unmapped variants (including this one) become 500 Internal Server Error.
    #[error("Internal server error: {0}")]
    InternalError(String),
}
|
|
||||||
|
|
||||||
// Collapse any sqlx error into DatabaseError, keeping only the message text.
impl From<sqlx::Error> for AppError {
    fn from(err: sqlx::Error) -> Self {
        AppError::DatabaseError(err.to_string())
    }
}
|
|
||||||
|
|
||||||
// Catch-all conversion: any anyhow error becomes InternalError (HTTP 500).
impl From<anyhow::Error> for AppError {
    fn from(err: anyhow::Error) -> Self {
        AppError::InternalError(err.to_string())
    }
}
|
|
||||||
|
|
||||||
impl axum::response::IntoResponse for AppError {
|
|
||||||
fn into_response(self) -> axum::response::Response {
|
|
||||||
let status = match self {
|
|
||||||
AppError::AuthError(_) => axum::http::StatusCode::UNAUTHORIZED,
|
|
||||||
AppError::RateLimitError(_) => axum::http::StatusCode::TOO_MANY_REQUESTS,
|
|
||||||
AppError::ValidationError(_) => axum::http::StatusCode::BAD_REQUEST,
|
|
||||||
_ => axum::http::StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
};
|
|
||||||
|
|
||||||
let body = axum::Json(serde_json::json!({
|
|
||||||
"error": self.to_string(),
|
|
||||||
"type": format!("{:?}", self)
|
|
||||||
}));
|
|
||||||
|
|
||||||
(status, body).into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
328
src/lib.rs
328
src/lib.rs
@@ -1,328 +0,0 @@
|
|||||||
//! LLM Proxy Library
|
|
||||||
//!
|
|
||||||
//! This library provides the core functionality for the LLM proxy gateway,
|
|
||||||
//! including provider integration, token tracking, and API endpoints.
|
|
||||||
|
|
||||||
pub mod auth;
|
|
||||||
pub mod client;
|
|
||||||
pub mod config;
|
|
||||||
pub mod dashboard;
|
|
||||||
pub mod database;
|
|
||||||
pub mod errors;
|
|
||||||
pub mod logging;
|
|
||||||
pub mod models;
|
|
||||||
pub mod multimodal;
|
|
||||||
pub mod providers;
|
|
||||||
pub mod rate_limiting;
|
|
||||||
pub mod server;
|
|
||||||
pub mod state;
|
|
||||||
pub mod utils;
|
|
||||||
|
|
||||||
// Re-exports for convenience
|
|
||||||
pub use auth::{AuthenticatedClient, validate_token};
|
|
||||||
pub use config::{
|
|
||||||
AppConfig, DatabaseConfig, DeepSeekConfig, GeminiConfig, GrokConfig, ModelMappingConfig, ModelPricing,
|
|
||||||
OllamaConfig, OpenAIConfig, PricingConfig, ProviderConfig, ServerConfig,
|
|
||||||
};
|
|
||||||
pub use database::{DbPool, init as init_db, test_connection};
|
|
||||||
pub use errors::AppError;
|
|
||||||
pub use logging::{LoggingContext, RequestLog, RequestLogger};
|
|
||||||
pub use models::{
|
|
||||||
ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage,
|
|
||||||
ChatStreamChoice, ChatStreamDelta, ContentPart, ContentPartValue, FromOpenAI, ImageUrl, MessageContent,
|
|
||||||
OpenAIContentPart, OpenAIMessage, OpenAIRequest, ToOpenAI, UnifiedMessage, UnifiedRequest, Usage,
|
|
||||||
};
|
|
||||||
pub use providers::{Provider, ProviderManager, ProviderResponse, ProviderStreamChunk};
|
|
||||||
pub use server::router;
|
|
||||||
pub use state::AppState;
|
|
||||||
|
|
||||||
/// Test utilities for integration testing
#[cfg(test)]
pub mod test_utils {
    use std::sync::Arc;

    use crate::{client::ClientManager, providers::ProviderManager, rate_limiting::RateLimitManager, state::AppState, utils::crypto, database::run_migrations};
    use sqlx::sqlite::SqlitePool;

    /// Create a test application state
    ///
    /// Builds a fully self-contained `AppState`: an in-memory SQLite pool
    /// with migrations applied, default rate limiting, an empty model
    /// registry, all providers enabled with blank URLs/models, and a fixed
    /// 32-byte hex encryption key.
    pub async fn create_test_state() -> AppState {
        // Create in-memory database
        let pool = SqlitePool::connect("sqlite::memory:")
            .await
            .expect("Failed to create test database");

        // Run migrations on the pool
        run_migrations(&pool).await.expect("Failed to run migrations");

        let rate_limit_manager = RateLimitManager::new(
            crate::rate_limiting::RateLimiterConfig::default(),
            crate::rate_limiting::CircuitBreakerConfig::default(),
        );

        // NOTE(review): constructed but never passed to AppState::new below —
        // confirm whether this is still needed.
        let client_manager = Arc::new(ClientManager::new(pool.clone()));

        // Create provider manager
        let provider_manager = ProviderManager::new();

        let model_registry = crate::models::registry::ModelRegistry {
            providers: std::collections::HashMap::new(),
        };

        let (dashboard_tx, _) = tokio::sync::broadcast::channel::<serde_json::Value>(100);

        let config = Arc::new(crate::config::AppConfig {
            server: crate::config::ServerConfig {
                port: 8080,
                host: "127.0.0.1".to_string(),
                auth_tokens: vec![],
            },
            database: crate::config::DatabaseConfig {
                path: std::path::PathBuf::from(":memory:"),
                max_connections: 5,
            },
            providers: crate::config::ProviderConfig {
                openai: crate::config::OpenAIConfig {
                    api_key_env: "OPENAI_API_KEY".to_string(),
                    base_url: "".to_string(),
                    default_model: "".to_string(),
                    enabled: true,
                },
                gemini: crate::config::GeminiConfig {
                    api_key_env: "GEMINI_API_KEY".to_string(),
                    base_url: "".to_string(),
                    default_model: "".to_string(),
                    enabled: true,
                },
                deepseek: crate::config::DeepSeekConfig {
                    api_key_env: "DEEPSEEK_API_KEY".to_string(),
                    base_url: "".to_string(),
                    default_model: "".to_string(),
                    enabled: true,
                },
                grok: crate::config::GrokConfig {
                    api_key_env: "GROK_API_KEY".to_string(),
                    base_url: "".to_string(),
                    default_model: "".to_string(),
                    enabled: true,
                },
                ollama: crate::config::OllamaConfig {
                    base_url: "".to_string(),
                    enabled: true,
                    models: vec![],
                },
            },
            model_mapping: crate::config::ModelMappingConfig { patterns: vec![] },
            pricing: crate::config::PricingConfig {
                openai: vec![],
                gemini: vec![],
                deepseek: vec![],
                grok: vec![],
                ollama: vec![],
            },
            config_path: None,
            encryption_key: "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f".to_string(),
        });

        // Initialize encryption with the test key
        crypto::init_with_key(&config.encryption_key).expect("failed to initialize crypto");

        AppState::new(
            config,
            provider_manager,
            pool,
            rate_limit_manager,
            model_registry,
            vec![], // auth_tokens
        )
    }

    /// Create a test HTTP client
    ///
    /// Plain reqwest client with a 30-second request timeout.
    pub fn create_test_client() -> reqwest::Client {
        reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(30))
            .build()
            .expect("Failed to create test HTTP client")
    }
}
|
|
||||||
|
|
||||||
#[cfg(test)]
mod integration_tests {
    use super::test_utils::*;
    use crate::{
        models::{ChatCompletionRequest, ChatMessage},
        server::router,
        utils::crypto,
    };
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use mockito::Server;
    use serde_json::json;
    use sqlx::Row;
    use tower::util::ServiceExt;

    /// End-to-end test: a provider API key stored *encrypted* in the database
    /// must be decrypted by the proxy and forwarded to the upstream provider
    /// (asserted via the mock's Authorization header match), and the resulting
    /// usage must be logged to `llm_requests` and rolled up into `clients`.
    #[tokio::test]
    async fn test_encrypted_provider_key_integration() {
        // Step 1: Setup test database and state
        let state = create_test_state().await;
        let pool = state.db_pool.clone();

        // Step 2: Insert provider with encrypted API key
        let test_api_key = "test-openai-key-12345";
        let encrypted_key = crypto::encrypt(test_api_key).expect("Failed to encrypt test key");

        sqlx::query(
            r#"
            INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            "#,
        )
        .bind("openai")
        .bind("OpenAI")
        .bind(true)
        .bind("http://localhost:1234") // Placeholder; replaced with the mock server URL below
        .bind(&encrypted_key)
        .bind(true) // api_key_encrypted flag
        .bind(100.0)
        .bind(5.0)
        .execute(&pool)
        .await
        // BUGFIX: this expect message previously said "Failed to update provider
        // URL" (copy-pasted from the UPDATE below); it is the INSERT that runs here.
        .expect("Failed to insert provider config");

        // Step 3: Initialize the provider so it picks up the encrypted key
        state
            .provider_manager
            .initialize_provider("openai", &state.config, &pool)
            .await
            .expect("Failed to re-initialize provider");

        // Step 4: Mock OpenAI API server. The header matcher asserts that the
        // *decrypted* key arrives upstream — the core claim of this test.
        let mut server = Server::new_async().await;
        let mock = server
            .mock("POST", "/chat/completions")
            .match_header("authorization", format!("Bearer {}", test_api_key).as_str())
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "id": "chatcmpl-test",
                    "object": "chat.completion",
                    "created": 1234567890,
                    "model": "gpt-3.5-turbo",
                    "choices": [{
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "Hello, world!"
                        },
                        "finish_reason": "stop"
                    }],
                    "usage": {
                        "prompt_tokens": 10,
                        "completion_tokens": 5,
                        "total_tokens": 15
                    }
                })
                .to_string(),
            )
            .create_async()
            .await;

        // Update provider base URL to use mock server
        sqlx::query("UPDATE provider_configs SET base_url = ? WHERE id = 'openai'")
            .bind(&server.url())
            .execute(&pool)
            .await
            .expect("Failed to update provider URL");

        // Re-initialize provider with new URL
        state
            .provider_manager
            .initialize_provider("openai", &state.config, &pool)
            .await
            .expect("Failed to re-initialize provider");

        // Step 5: Create test router and make request
        let app = router(state);

        let request_body = ChatCompletionRequest {
            model: "gpt-3.5-turbo".to_string(),
            messages: vec![ChatMessage {
                role: "user".to_string(),
                content: crate::models::MessageContent::Text {
                    content: "Hello".to_string(),
                },
                reasoning_content: None,
                tool_calls: None,
                name: None,
                tool_call_id: None,
            }],
            temperature: None,
            top_p: None,
            top_k: None,
            n: None,
            stop: None,
            max_tokens: Some(100),
            presence_penalty: None,
            frequency_penalty: None,
            stream: Some(false),
            tools: None,
            tool_choice: None,
        };

        let request = Request::builder()
            .method("POST")
            .uri("/v1/chat/completions")
            .header("content-type", "application/json")
            .header("authorization", "Bearer test-token")
            .body(Body::from(serde_json::to_string(&request_body).unwrap()))
            .unwrap();

        // Step 6: Execute request through proxy
        let response = app
            .oneshot(request)
            .await
            .expect("Failed to execute request");

        let status = response.status();
        println!("Response status: {}", status);

        // On failure, dump the body so CI logs show the proxy's error payload
        // before panicking.
        if status != StatusCode::OK {
            let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
            let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
            println!("Response body: {}", body_str);
            panic!("Response status is not OK: {}", status);
        }

        assert_eq!(status, StatusCode::OK);

        // Verify the mock was called
        mock.assert_async().await;

        // Give the async logging task time to complete
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        // Step 7: Verify usage was logged in database
        let log_row = sqlx::query("SELECT * FROM llm_requests WHERE client_id = 'client_test-tok' ORDER BY id DESC LIMIT 1")
            .fetch_one(&pool)
            .await
            .expect("Request log not found");

        assert_eq!(log_row.get::<String, _>("provider"), "openai");
        assert_eq!(log_row.get::<String, _>("model"), "gpt-3.5-turbo");
        assert_eq!(log_row.get::<i64, _>("prompt_tokens"), 10);
        assert_eq!(log_row.get::<i64, _>("completion_tokens"), 5);
        assert_eq!(log_row.get::<i64, _>("total_tokens"), 15);
        assert_eq!(log_row.get::<String, _>("status"), "success");

        // Verify client usage was updated
        let client_row = sqlx::query("SELECT total_requests, total_tokens, total_cost FROM clients WHERE client_id = 'client_test-tok'")
            .fetch_one(&pool)
            .await
            .expect("Client not found");

        assert_eq!(client_row.get::<i64, _>("total_requests"), 1);
        assert_eq!(client_row.get::<i64, _>("total_tokens"), 15);
    }
}
|
|
||||||
@@ -1,240 +0,0 @@
|
|||||||
use chrono::{DateTime, Utc};
|
|
||||||
use serde::Serialize;
|
|
||||||
use sqlx::SqlitePool;
|
|
||||||
use tokio::sync::broadcast;
|
|
||||||
use tracing::{info, warn};
|
|
||||||
|
|
||||||
use crate::errors::AppError;
|
|
||||||
|
|
||||||
/// Request log entry for database storage.
///
/// One instance is produced per proxied LLM request and written to the
/// `llm_requests` table by `RequestLogger::insert_log`; it is also
/// broadcast as JSON to dashboard WebSocket listeners.
#[derive(Debug, Clone, Serialize)]
pub struct RequestLog {
    /// When the request completed (UTC).
    pub timestamp: DateTime<Utc>,
    /// Client that issued the request (FK into the `clients` table).
    pub client_id: String,
    /// Upstream provider id, e.g. "openai".
    pub provider: String,
    /// Model name the request was served with.
    pub model: String,
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    /// Tokens served from the provider's prompt cache.
    pub cache_read_tokens: u32,
    /// Tokens written into the provider's prompt cache.
    pub cache_write_tokens: u32,
    /// Cost in the provider's billing currency (presumably USD — confirm).
    pub cost: f64,
    /// True when the request contained image content parts.
    pub has_images: bool,
    pub status: String, // "success", "error"
    /// Populated only when `status` is "error".
    pub error_message: Option<String>,
    /// Wall-clock duration of the proxied request.
    pub duration_ms: u64,
}
|
|
||||||
|
|
||||||
/// Database operations for request logging.
///
/// Holds a SQLite pool for persistence and a broadcast channel used to
/// push each log entry to connected dashboard WebSocket clients.
pub struct RequestLogger {
    // Connection pool shared with the rest of the application.
    db_pool: SqlitePool,
    // Fan-out channel to live dashboard listeners; sends fail harmlessly
    // when no listener is connected.
    dashboard_tx: broadcast::Sender<serde_json::Value>,
}
|
|
||||||
|
|
||||||
impl RequestLogger {
    /// Construct a logger backed by `db_pool`, broadcasting each entry to
    /// dashboard listeners via `dashboard_tx`.
    pub fn new(db_pool: SqlitePool, dashboard_tx: broadcast::Sender<serde_json::Value>) -> Self {
        Self { db_pool, dashboard_tx }
    }

    /// Log a request to the database (async, spawns a task).
    ///
    /// Fire-and-forget: the spawned task first broadcasts the entry to the
    /// dashboard, then persists it. Database failures are logged with `warn!`
    /// but never surface to the caller.
    pub fn log_request(&self, log: RequestLog) {
        let pool = self.db_pool.clone();
        let tx = self.dashboard_tx.clone();

        // Spawn async task to log without blocking response
        tokio::spawn(async move {
            // Broadcast to dashboard
            let broadcast_result = tx.send(serde_json::json!({
                "type": "request",
                "channel": "requests",
                "payload": log
            }));
            match broadcast_result {
                Ok(receivers) => info!("Broadcast request log to {} dashboard listeners", receivers),
                Err(_) => {} // No active WebSocket clients — expected when dashboard isn't open
            }

            match Self::insert_log(&pool, log).await {
                Ok(()) => info!("Request logged to database successfully"),
                Err(e) => warn!("Failed to log request to database: {}", e),
            }
        });
    }

    /// Insert a log entry into the database.
    ///
    /// Runs four statements inside one transaction so the request row, the
    /// client roll-up, and the provider credit deduction commit atomically:
    ///   1. ensure the client row exists (FK target),
    ///   2. insert the `llm_requests` row,
    ///   3. bump the client's usage counters,
    ///   4. deduct the cost from the provider's credit balance (prepaid only).
    async fn insert_log(pool: &SqlitePool, log: RequestLog) -> Result<(), sqlx::Error> {
        let mut tx = pool.begin().await?;

        // Ensure the client row exists (FK constraint requires it)
        sqlx::query(
            "INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')",
        )
        .bind(&log.client_id)
        // client_id doubles as the display name for auto-created rows.
        .bind(&log.client_id)
        .execute(&mut *tx)
        .await?;

        sqlx::query(
            r#"
            INSERT INTO llm_requests
            (timestamp, client_id, provider, model, prompt_tokens, completion_tokens, total_tokens, cache_read_tokens, cache_write_tokens, cost, has_images, status, error_message, duration_ms, request_body, response_body)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            "#,
        )
        .bind(log.timestamp)
        .bind(&log.client_id)
        .bind(&log.provider)
        .bind(&log.model)
        // Token counts are widened to i64 because SQLite integers are signed 64-bit.
        .bind(log.prompt_tokens as i64)
        .bind(log.completion_tokens as i64)
        .bind(log.total_tokens as i64)
        .bind(log.cache_read_tokens as i64)
        .bind(log.cache_write_tokens as i64)
        .bind(log.cost)
        .bind(log.has_images)
        .bind(&log.status)
        .bind(log.error_message)
        .bind(log.duration_ms as i64)
        .bind(None::<String>) // request_body - optional, not stored to save disk space
        .bind(None::<String>) // response_body - optional, not stored to save disk space
        .execute(&mut *tx)
        .await?;

        // Update client usage statistics
        sqlx::query(
            r#"
            UPDATE clients SET
                total_requests = total_requests + 1,
                total_tokens = total_tokens + ?,
                total_cost = total_cost + ?,
                updated_at = CURRENT_TIMESTAMP
            WHERE client_id = ?
            "#,
        )
        .bind(log.total_tokens as i64)
        .bind(log.cost)
        .bind(&log.client_id)
        .execute(&mut *tx)
        .await?;

        // Deduct from provider balance if successful.
        // Providers configured with billing_mode = 'postpaid' will not have their
        // credit_balance decremented. Use a conditional UPDATE so we don't need
        // a prior SELECT and avoid extra round-trips.
        if log.cost > 0.0 {
            sqlx::query(
                "UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')",
            )
            .bind(log.cost)
            .bind(&log.provider)
            .execute(&mut *tx)
            .await?;
        }

        tx.commit().await?;

        Ok(())
    }
}
|
|
||||||
|
|
||||||
// /// Middleware to log LLM API requests
|
|
||||||
// /// TODO: Implement proper middleware that can extract response body details
|
|
||||||
// pub async fn request_logging_middleware(
|
|
||||||
// // Extract the authenticated client (if available)
|
|
||||||
// auth_result: Result<AuthenticatedClient, AppError>,
|
|
||||||
// request: Request,
|
|
||||||
// next: Next,
|
|
||||||
// ) -> Response {
|
|
||||||
// let start_time = std::time::Instant::now();
|
|
||||||
//
|
|
||||||
// // Extract client_id from auth or use "unknown"
|
|
||||||
// let client_id = match auth_result {
|
|
||||||
// Ok(auth) => auth.client_id,
|
|
||||||
// Err(_) => "unknown".to_string(),
|
|
||||||
// };
|
|
||||||
//
|
|
||||||
// // Try to extract request details
|
|
||||||
// let (request_parts, request_body) = request.into_parts();
|
|
||||||
//
|
|
||||||
// // Clone request parts for logging
|
|
||||||
// let path = request_parts.uri.path().to_string();
|
|
||||||
//
|
|
||||||
// // Check if this is a chat completion request
|
|
||||||
// let is_chat_completion = path == "/v1/chat/completions";
|
|
||||||
//
|
|
||||||
// // Reconstruct request for downstream handlers
|
|
||||||
// let request = Request::from_parts(request_parts, request_body);
|
|
||||||
//
|
|
||||||
// // Process request and get response
|
|
||||||
// let response = next.run(request).await;
|
|
||||||
//
|
|
||||||
// // Calculate duration
|
|
||||||
// let duration = start_time.elapsed();
|
|
||||||
// let duration_ms = duration.as_millis() as u64;
|
|
||||||
//
|
|
||||||
// // Log basic request info
|
|
||||||
// info!(
|
|
||||||
// "Request from {} to {} - Status: {} - Duration: {}ms",
|
|
||||||
// client_id,
|
|
||||||
// path,
|
|
||||||
// response.status().as_u16(),
|
|
||||||
// duration_ms
|
|
||||||
// );
|
|
||||||
//
|
|
||||||
// // TODO: Extract more details from request/response for logging
|
|
||||||
// // For now, we'll need to modify the server handler to pass additional context
|
|
||||||
//
|
|
||||||
// response
|
|
||||||
// }
|
|
||||||
|
|
||||||
/// Context for request logging that can be passed through extensions.
///
/// Accumulates per-request accounting data (token counts, cost, error)
/// as the request flows through the handler; consumed when building the
/// final `RequestLog`.
#[derive(Clone)]
pub struct LoggingContext {
    pub client_id: String,
    pub provider_name: String,
    pub model: String,
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    // Kept consistent with prompt + completion by `with_token_counts`.
    pub total_tokens: u32,
    pub cost: f64,
    pub has_images: bool,
    // Set via `with_error` when the upstream call failed.
    pub error: Option<AppError>,
}
|
|
||||||
|
|
||||||
impl LoggingContext {
|
|
||||||
pub fn new(client_id: String, provider_name: String, model: String) -> Self {
|
|
||||||
Self {
|
|
||||||
client_id,
|
|
||||||
provider_name,
|
|
||||||
model,
|
|
||||||
prompt_tokens: 0,
|
|
||||||
completion_tokens: 0,
|
|
||||||
total_tokens: 0,
|
|
||||||
cost: 0.0,
|
|
||||||
has_images: false,
|
|
||||||
error: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_token_counts(mut self, prompt_tokens: u32, completion_tokens: u32) -> Self {
|
|
||||||
self.prompt_tokens = prompt_tokens;
|
|
||||||
self.completion_tokens = completion_tokens;
|
|
||||||
self.total_tokens = prompt_tokens + completion_tokens;
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_cost(mut self, cost: f64) -> Self {
|
|
||||||
self.cost = cost;
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_images(mut self, has_images: bool) -> Self {
|
|
||||||
self.has_images = has_images;
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_error(mut self, error: AppError) -> Self {
|
|
||||||
self.error = Some(error);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
96
src/main.rs
96
src/main.rs
@@ -1,96 +0,0 @@
|
|||||||
use anyhow::Result;
|
|
||||||
use axum::{Router, routing::get};
|
|
||||||
use std::net::SocketAddr;
|
|
||||||
use tracing::{error, info};
|
|
||||||
|
|
||||||
use llm_proxy::{
|
|
||||||
config::AppConfig,
|
|
||||||
dashboard, database,
|
|
||||||
providers::ProviderManager,
|
|
||||||
rate_limiting::{CircuitBreakerConfig, RateLimitManager, RateLimiterConfig},
|
|
||||||
server,
|
|
||||||
state::AppState,
|
|
||||||
utils::crypto,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Application entry point.
///
/// Startup order matters: logging → config → encryption (needed to decrypt
/// provider keys) → database → providers → rate limiting → model registry →
/// shared state → HTTP server. Provider-init and registry-fetch failures are
/// logged and tolerated; everything else aborts startup via `?`.
#[tokio::main]
async fn main() -> Result<()> {
    // Initialize tracing (logging)
    tracing_subscriber::fmt()
        .with_max_level(tracing::Level::INFO)
        .with_target(false)
        .init();

    info!("Starting LLM Proxy Gateway v{}", env!("CARGO_PKG_VERSION"));

    // Load configuration
    let config = AppConfig::load().await?;
    info!("Configuration loaded from {:?}", config.config_path);

    // Initialize encryption — must precede provider init, which reads
    // encrypted API keys from the database.
    crypto::init_with_key(&config.encryption_key)?;
    info!("Encryption initialized");

    // Initialize database connection pool
    let db_pool = database::init(&config.database).await?;
    info!("Database initialized at {:?}", config.database.path);

    // Initialize provider manager with configured providers
    let provider_manager = ProviderManager::new();

    // Initialize all supported providers (they handle their own enabled check).
    // A single provider failing does not abort startup.
    let supported_providers = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
    for name in supported_providers {
        if let Err(e) = provider_manager.initialize_provider(name, &config, &db_pool).await {
            error!("Failed to initialize provider {}: {}", name, e);
        }
    }

    // Create rate limit manager
    let rate_limit_manager = RateLimitManager::new(RateLimiterConfig::default(), CircuitBreakerConfig::default());

    // Fetch model registry from models.dev; fall back to an empty registry
    // so the gateway still starts without network access.
    let model_registry = match llm_proxy::utils::registry::fetch_registry().await {
        Ok(registry) => registry,
        Err(e) => {
            error!("Failed to fetch model registry: {}. Using empty registry.", e);
            llm_proxy::models::registry::ModelRegistry {
                providers: std::collections::HashMap::new(),
            }
        }
    };

    // Create application state
    let state = AppState::new(
        config.clone(),
        provider_manager,
        db_pool,
        rate_limit_manager,
        model_registry,
        config.server.auth_tokens.clone(),
    );

    // Initialize model config cache and start background refresh (every 30s)
    state.model_config_cache.refresh().await;
    state.model_config_cache.clone().start_refresh_task(30);
    info!("Model config cache initialized");

    // Create application router: health probe + API routes + dashboard routes.
    let app = Router::new()
        .route("/health", get(health_check))
        .merge(server::router(state.clone()))
        .merge(dashboard::router(state.clone()));

    // Start server — binds all interfaces on the configured port.
    let addr = SocketAddr::from(([0, 0, 0, 0], config.server.port));
    info!("Server listening on http://{}", addr);

    let listener = tokio::net::TcpListener::bind(&addr).await?;
    axum::serve(listener, app).await?;

    Ok(())
}
|
|
||||||
|
|
||||||
/// Liveness probe handler for `/health`; always responds with the
/// static body "OK".
async fn health_check() -> &'static str {
    "OK"
}
|
|
||||||
@@ -1,377 +0,0 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use serde_json::Value;
|
|
||||||
|
|
||||||
pub mod registry;
|
|
||||||
|
|
||||||
// ========== OpenAI-compatible Request/Response Structs ==========
|
|
||||||
|
|
||||||
/// OpenAI-compatible chat completion request body accepted at
/// `/v1/chat/completions`. Optional fields default to `None` when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
    /// Requested model name (mapped to a provider model downstream).
    pub model: String,
    /// Conversation messages, oldest first.
    pub messages: Vec<ChatMessage>,
    #[serde(default)]
    pub temperature: Option<f64>,
    #[serde(default)]
    pub top_p: Option<f64>,
    // Not part of the OpenAI spec; forwarded to providers that support it.
    #[serde(default)]
    pub top_k: Option<u32>,
    #[serde(default)]
    pub n: Option<u32>,
    #[serde(default)]
    pub stop: Option<Value>, // Can be string or array of strings
    #[serde(default)]
    pub max_tokens: Option<u32>,
    #[serde(default)]
    pub presence_penalty: Option<f64>,
    #[serde(default)]
    pub frequency_penalty: Option<f64>,
    #[serde(default)]
    pub stream: Option<bool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
|
|
||||||
|
|
||||||
/// A single chat message in OpenAI wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    pub role: String, // "system", "user", "assistant", "tool"
    // Flattened so `content` deserializes as either a plain string or an
    // array of typed parts (see `MessageContent`).
    #[serde(flatten)]
    pub content: MessageContent,
    // Accepts the provider-specific field names "reasoning" and "thought"
    // as aliases on input.
    #[serde(alias = "reasoning", alias = "thought", skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    /// Tool invocations requested by an assistant message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// For `role == "tool"` messages: which tool call this result answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
|
|
||||||
|
|
||||||
/// Message content in either of OpenAI's two wire shapes: a bare string
/// or an array of typed parts (text / image_url). `untagged` lets serde
/// pick the variant by shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
    Text { content: String },
    Parts { content: Vec<ContentPartValue> },
    None, // Handle cases where content might be null but reasoning is present
}
|
|
||||||
|
|
||||||
/// One element of a multi-part message content array, discriminated by
/// the `"type"` field ("text" or "image_url").
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPartValue {
    Text { text: String },
    ImageUrl { image_url: ImageUrl },
}

/// Image reference inside a content part; `url` may be an http(s) URL or
/// a data URI (presumably — confirm against callers).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrl {
    pub url: String,
    /// OpenAI detail hint ("low"/"high"/"auto"); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
|
|
||||||
|
|
||||||
// ========== Tool-Calling Types ==========
|
|
||||||
|
|
||||||
/// A tool the model may call. Currently only `"function"` tools exist in
/// the OpenAI API; `tool_type` carries that discriminator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: FunctionDef,
}

/// Declaration of a callable function: name, optional description, and an
/// optional JSON-Schema `parameters` object.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDef {
    pub name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<Value>,
}

/// `tool_choice` accepts either a mode string or an object naming one
/// specific function; `untagged` picks the variant by JSON shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    Mode(String), // "auto", "none", "required"
    Specific(ToolChoiceSpecific),
}

/// The object form of `tool_choice`: forces a particular function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceSpecific {
    #[serde(rename = "type")]
    pub choice_type: String,
    pub function: ToolChoiceFunction,
}

/// Function selector inside the object form of `tool_choice`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceFunction {
    pub name: String,
}

/// A tool invocation emitted by the model in a non-streaming response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    pub id: String,
    #[serde(rename = "type")]
    pub call_type: String,
    pub function: FunctionCall,
}

/// Concrete function invocation: `arguments` is a JSON-encoded string,
/// per the OpenAI wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
    pub name: String,
    pub arguments: String,
}

/// Incremental tool-call fragment in a streaming response; fields arrive
/// piecemeal and are correlated by `index`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallDelta {
    pub index: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub call_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<FunctionCallDelta>,
}

/// Partial function payload inside a `ToolCallDelta`; `arguments`
/// fragments are concatenated by the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
|
|
||||||
|
|
||||||
// ========== OpenAI-compatible Response Structs ==========
|
|
||||||
|
|
||||||
/// OpenAI-compatible non-streaming chat completion response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    /// Always "chat.completion" in the OpenAI format.
    pub object: String,
    /// Unix timestamp (seconds).
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatChoice>,
    pub usage: Option<Usage>,
}

/// One completion alternative within a response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatChoice {
    pub index: u32,
    pub message: ChatMessage,
    /// e.g. "stop", "length", "tool_calls"; may be absent.
    pub finish_reason: Option<String>,
}

/// Token accounting for a completed request. The cache fields are an
/// extension beyond the OpenAI schema and are omitted when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_tokens: Option<u32>,
}
|
|
||||||
|
|
||||||
// ========== Streaming Response Structs ==========
|
|
||||||
|
|
||||||
/// One SSE chunk of a streaming chat completion (OpenAI
/// "chat.completion.chunk" format).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    pub object: String,
    /// Unix timestamp (seconds).
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatStreamChoice>,
}

/// One choice within a streaming chunk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatStreamChoice {
    pub index: u32,
    pub delta: ChatStreamDelta,
    /// Set only on the terminal chunk of a choice.
    pub finish_reason: Option<String>,
}

/// Incremental message fragment; each field is present only when that
/// part of the message advanced in this chunk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatStreamDelta {
    pub role: Option<String>,
    pub content: Option<String>,
    // Accepts the provider-specific field names "reasoning" and "thought"
    // as aliases on input.
    #[serde(alias = "reasoning", alias = "thought", skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
}
|
|
||||||
|
|
||||||
// ========== Unified Request Format (for internal use) ==========
|
|
||||||
|
|
||||||
/// Provider-agnostic request produced from a `ChatCompletionRequest`
/// (see `TryFrom` impl); providers translate this into their own wire
/// formats. Not (de)serialized directly.
#[derive(Debug, Clone)]
pub struct UnifiedRequest {
    /// Filled in by the auth middleware; empty at conversion time.
    pub client_id: String,
    pub model: String,
    pub messages: Vec<UnifiedMessage>,
    pub temperature: Option<f64>,
    pub top_p: Option<f64>,
    pub top_k: Option<u32>,
    pub n: Option<u32>,
    /// Normalized from the string-or-array `stop` field.
    pub stop: Option<Vec<String>>,
    pub max_tokens: Option<u32>,
    pub presence_penalty: Option<f64>,
    pub frequency_penalty: Option<f64>,
    pub stream: bool,
    /// True when any message contains an image content part.
    pub has_images: bool,
    pub tools: Option<Vec<Tool>>,
    pub tool_choice: Option<ToolChoice>,
}

/// Message in the unified internal format; content is always a list of
/// typed parts (never a bare string).
#[derive(Debug, Clone)]
pub struct UnifiedMessage {
    pub role: String,
    pub content: Vec<ContentPart>,
    pub reasoning_content: Option<String>,
    pub tool_calls: Option<Vec<ToolCall>>,
    pub name: Option<String>,
    pub tool_call_id: Option<String>,
}

/// Unified content part: plain text or an image (URL or base64; see
/// `crate::multimodal::ImageInput`).
#[derive(Debug, Clone)]
pub enum ContentPart {
    Text { text: String },
    Image(crate::multimodal::ImageInput),
}
|
|
||||||
|
|
||||||
// ========== Provider-specific Structs ==========
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize)]
|
|
||||||
pub struct OpenAIRequest {
|
|
||||||
pub model: String,
|
|
||||||
pub messages: Vec<OpenAIMessage>,
|
|
||||||
pub temperature: Option<f64>,
|
|
||||||
pub max_tokens: Option<u32>,
|
|
||||||
pub stream: Option<bool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize)]
|
|
||||||
pub struct OpenAIMessage {
|
|
||||||
pub role: String,
|
|
||||||
pub content: Vec<OpenAIContentPart>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize)]
|
|
||||||
#[serde(tag = "type", rename_all = "snake_case")]
|
|
||||||
pub enum OpenAIContentPart {
|
|
||||||
Text { text: String },
|
|
||||||
ImageUrl { image_url: ImageUrl },
|
|
||||||
}
|
|
||||||
|
|
||||||
// Note: ImageUrl struct is defined earlier in the file
|
|
||||||
|
|
||||||
// ========== Conversion Traits ==========
|
|
||||||
|
|
||||||
/// Conversion into OpenAI wire format; implemented by internal request
/// types that need to be sent to OpenAI-compatible providers.
pub trait ToOpenAI {
    fn to_openai(&self) -> Result<OpenAIRequest, anyhow::Error>;
}

/// Inverse conversion: build an internal type from an `OpenAIRequest`.
pub trait FromOpenAI {
    fn from_openai(request: &OpenAIRequest) -> Result<Self, anyhow::Error>
    where
        Self: Sized;
}
|
|
||||||
|
|
||||||
impl UnifiedRequest {
    /// Hydrate all image content by fetching URLs and converting to base64/bytes.
    ///
    /// Mutates image parts in place: every `ImageInput::Url` is replaced with
    /// an `ImageInput::Base64` holding the fetched data and MIME type. No-op
    /// when `has_images` is false (fast path). Propagates any fetch/decode
    /// error from `ImageInput::to_base64`.
    pub async fn hydrate_images(&mut self) -> anyhow::Result<()> {
        if !self.has_images {
            return Ok(());
        }

        for msg in &mut self.messages {
            for part in &mut msg.content {
                if let ContentPart::Image(image_input) = part {
                    // Pre-fetch and validate if it's a URL; Base64 inputs are
                    // already hydrated and left untouched.
                    if let crate::multimodal::ImageInput::Url(_url) = image_input {
                        let (base64_data, mime_type) = image_input.to_base64().await?;
                        *image_input = crate::multimodal::ImageInput::Base64 {
                            data: base64_data,
                            mime_type,
                        };
                    }
                }
            }
        }
        Ok(())
    }
}
|
|
||||||
|
|
||||||
impl TryFrom<ChatCompletionRequest> for UnifiedRequest {
|
|
||||||
type Error = anyhow::Error;
|
|
||||||
|
|
||||||
fn try_from(req: ChatCompletionRequest) -> Result<Self, Self::Error> {
|
|
||||||
let mut has_images = false;
|
|
||||||
|
|
||||||
// Convert OpenAI-compatible request to unified format
|
|
||||||
let messages = req
|
|
||||||
.messages
|
|
||||||
.into_iter()
|
|
||||||
.map(|msg| {
|
|
||||||
let (content, _images_in_message) = match msg.content {
|
|
||||||
MessageContent::Text { content } => (vec![ContentPart::Text { text: content }], false),
|
|
||||||
MessageContent::Parts { content } => {
|
|
||||||
let mut unified_content = Vec::new();
|
|
||||||
let mut has_images_in_msg = false;
|
|
||||||
|
|
||||||
for part in content {
|
|
||||||
match part {
|
|
||||||
ContentPartValue::Text { text } => {
|
|
||||||
unified_content.push(ContentPart::Text { text });
|
|
||||||
}
|
|
||||||
ContentPartValue::ImageUrl { image_url } => {
|
|
||||||
has_images_in_msg = true;
|
|
||||||
has_images = true;
|
|
||||||
unified_content.push(ContentPart::Image(crate::multimodal::ImageInput::from_url(
|
|
||||||
image_url.url,
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
(unified_content, has_images_in_msg)
|
|
||||||
}
|
|
||||||
MessageContent::None => (vec![], false),
|
|
||||||
};
|
|
||||||
|
|
||||||
UnifiedMessage {
|
|
||||||
role: msg.role,
|
|
||||||
content,
|
|
||||||
reasoning_content: msg.reasoning_content,
|
|
||||||
tool_calls: msg.tool_calls,
|
|
||||||
name: msg.name,
|
|
||||||
tool_call_id: msg.tool_call_id,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let stop = match req.stop {
|
|
||||||
Some(Value::String(s)) => Some(vec![s]),
|
|
||||||
Some(Value::Array(a)) => Some(
|
|
||||||
a.into_iter()
|
|
||||||
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
|
||||||
.collect(),
|
|
||||||
),
|
|
||||||
_ => None,
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(UnifiedRequest {
|
|
||||||
client_id: String::new(), // Will be populated by auth middleware
|
|
||||||
model: req.model,
|
|
||||||
messages,
|
|
||||||
temperature: req.temperature,
|
|
||||||
top_p: req.top_p,
|
|
||||||
top_k: req.top_k,
|
|
||||||
n: req.n,
|
|
||||||
stop,
|
|
||||||
max_tokens: req.max_tokens,
|
|
||||||
presence_penalty: req.presence_penalty,
|
|
||||||
frequency_penalty: req.frequency_penalty,
|
|
||||||
stream: req.stream.unwrap_or(false),
|
|
||||||
has_images,
|
|
||||||
tools: req.tools,
|
|
||||||
tool_choice: req.tool_choice,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,219 +0,0 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ModelRegistry {
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub providers: HashMap<String, ProviderInfo>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ProviderInfo {
|
|
||||||
pub id: String,
|
|
||||||
pub name: String,
|
|
||||||
pub models: HashMap<String, ModelMetadata>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ModelMetadata {
|
|
||||||
pub id: String,
|
|
||||||
pub name: String,
|
|
||||||
pub cost: Option<ModelCost>,
|
|
||||||
pub limit: Option<ModelLimit>,
|
|
||||||
pub modalities: Option<ModelModalities>,
|
|
||||||
pub tool_call: Option<bool>,
|
|
||||||
pub reasoning: Option<bool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ModelCost {
|
|
||||||
pub input: f64,
|
|
||||||
pub output: f64,
|
|
||||||
pub cache_read: Option<f64>,
|
|
||||||
pub cache_write: Option<f64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ModelLimit {
|
|
||||||
pub context: u32,
|
|
||||||
pub output: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct ModelModalities {
|
|
||||||
pub input: Vec<String>,
|
|
||||||
pub output: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A model entry paired with its provider ID, returned by listing/filtering methods.
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct ModelEntry<'a> {
|
|
||||||
pub model_key: &'a str,
|
|
||||||
pub provider_id: &'a str,
|
|
||||||
pub provider_name: &'a str,
|
|
||||||
pub metadata: &'a ModelMetadata,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Filter criteria for listing models. All fields are optional; `None` means no filter.
|
|
||||||
#[derive(Debug, Default, Clone, Deserialize)]
|
|
||||||
pub struct ModelFilter {
|
|
||||||
/// Filter by provider ID (exact match).
|
|
||||||
pub provider: Option<String>,
|
|
||||||
/// Text search on model ID or name (case-insensitive substring).
|
|
||||||
pub search: Option<String>,
|
|
||||||
/// Filter by input modality (e.g. "image", "text").
|
|
||||||
pub modality: Option<String>,
|
|
||||||
/// Only models that support tool calling.
|
|
||||||
pub tool_call: Option<bool>,
|
|
||||||
/// Only models that support reasoning.
|
|
||||||
pub reasoning: Option<bool>,
|
|
||||||
/// Only models that have pricing data.
|
|
||||||
pub has_cost: Option<bool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sort field for model listings.
|
|
||||||
#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum ModelSortBy {
|
|
||||||
#[default]
|
|
||||||
Name,
|
|
||||||
Id,
|
|
||||||
Provider,
|
|
||||||
ContextLimit,
|
|
||||||
InputCost,
|
|
||||||
OutputCost,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sort direction.
|
|
||||||
#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum SortOrder {
|
|
||||||
#[default]
|
|
||||||
Asc,
|
|
||||||
Desc,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ModelRegistry {
|
|
||||||
/// Find a model by its ID (searching across all providers)
|
|
||||||
pub fn find_model(&self, model_id: &str) -> Option<&ModelMetadata> {
|
|
||||||
// First try exact match if the key in models map matches the ID
|
|
||||||
for provider in self.providers.values() {
|
|
||||||
if let Some(model) = provider.models.get(model_id) {
|
|
||||||
return Some(model);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try searching for the model ID inside the metadata if the key was different
|
|
||||||
for provider in self.providers.values() {
|
|
||||||
for model in provider.models.values() {
|
|
||||||
if model.id == model_id {
|
|
||||||
return Some(model);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
/// List all models with optional filtering and sorting.
|
|
||||||
pub fn list_models(
|
|
||||||
&self,
|
|
||||||
filter: &ModelFilter,
|
|
||||||
sort_by: &ModelSortBy,
|
|
||||||
sort_order: &SortOrder,
|
|
||||||
) -> Vec<ModelEntry<'_>> {
|
|
||||||
let mut entries: Vec<ModelEntry<'_>> = Vec::new();
|
|
||||||
|
|
||||||
for (p_id, p_info) in &self.providers {
|
|
||||||
// Provider filter
|
|
||||||
if let Some(ref prov) = filter.provider {
|
|
||||||
if p_id != prov {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (m_key, m_meta) in &p_info.models {
|
|
||||||
// Text search filter
|
|
||||||
if let Some(ref search) = filter.search {
|
|
||||||
let search_lower = search.to_lowercase();
|
|
||||||
if !m_meta.id.to_lowercase().contains(&search_lower)
|
|
||||||
&& !m_meta.name.to_lowercase().contains(&search_lower)
|
|
||||||
&& !m_key.to_lowercase().contains(&search_lower)
|
|
||||||
{
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Modality filter
|
|
||||||
if let Some(ref modality) = filter.modality {
|
|
||||||
let has_modality = m_meta
|
|
||||||
.modalities
|
|
||||||
.as_ref()
|
|
||||||
.is_some_and(|m| m.input.iter().any(|i| i.eq_ignore_ascii_case(modality)));
|
|
||||||
if !has_modality {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tool call filter
|
|
||||||
if let Some(tc) = filter.tool_call {
|
|
||||||
if m_meta.tool_call.unwrap_or(false) != tc {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reasoning filter
|
|
||||||
if let Some(r) = filter.reasoning {
|
|
||||||
if m_meta.reasoning.unwrap_or(false) != r {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Has cost filter
|
|
||||||
if let Some(hc) = filter.has_cost {
|
|
||||||
if hc != m_meta.cost.is_some() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
entries.push(ModelEntry {
|
|
||||||
model_key: m_key,
|
|
||||||
provider_id: p_id,
|
|
||||||
provider_name: &p_info.name,
|
|
||||||
metadata: m_meta,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sort
|
|
||||||
entries.sort_by(|a, b| {
|
|
||||||
let cmp = match sort_by {
|
|
||||||
ModelSortBy::Name => a.metadata.name.to_lowercase().cmp(&b.metadata.name.to_lowercase()),
|
|
||||||
ModelSortBy::Id => a.model_key.cmp(b.model_key),
|
|
||||||
ModelSortBy::Provider => a.provider_id.cmp(b.provider_id),
|
|
||||||
ModelSortBy::ContextLimit => {
|
|
||||||
let a_ctx = a.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0);
|
|
||||||
let b_ctx = b.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0);
|
|
||||||
a_ctx.cmp(&b_ctx)
|
|
||||||
}
|
|
||||||
ModelSortBy::InputCost => {
|
|
||||||
let a_cost = a.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
|
||||||
let b_cost = b.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
|
|
||||||
a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
|
|
||||||
}
|
|
||||||
ModelSortBy::OutputCost => {
|
|
||||||
let a_cost = a.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
|
||||||
let b_cost = b.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
|
|
||||||
a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
match sort_order {
|
|
||||||
SortOrder::Asc => cmp,
|
|
||||||
SortOrder::Desc => cmp.reverse(),
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
entries
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,299 +0,0 @@
|
|||||||
//! Multimodal support for image processing and conversion
|
|
||||||
//!
|
|
||||||
//! This module handles:
|
|
||||||
//! 1. Image format detection and conversion
|
|
||||||
//! 2. Base64 encoding/decoding
|
|
||||||
//! 3. URL fetching for images
|
|
||||||
//! 4. Provider-specific image format conversion
|
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
|
||||||
use base64::{Engine as _, engine::general_purpose};
|
|
||||||
use std::sync::LazyLock;
|
|
||||||
use tracing::{info, warn};
|
|
||||||
|
|
||||||
/// Shared HTTP client for image fetching — avoids creating a new TCP+TLS
|
|
||||||
/// connection for every image URL.
|
|
||||||
static IMAGE_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
|
||||||
reqwest::Client::builder()
|
|
||||||
.connect_timeout(std::time::Duration::from_secs(5))
|
|
||||||
.timeout(std::time::Duration::from_secs(30))
|
|
||||||
.pool_idle_timeout(std::time::Duration::from_secs(60))
|
|
||||||
.build()
|
|
||||||
.expect("Failed to build image HTTP client")
|
|
||||||
});
|
|
||||||
|
|
||||||
/// An image supplied to a multimodal request, in one of three forms.
#[derive(Debug, Clone)]
pub enum ImageInput {
    /// Already base64-encoded payload plus its MIME type.
    Base64 { data: String, mime_type: String },
    /// Remote URL; fetched lazily when the image is actually needed.
    Url(String),
    /// Raw bytes plus their MIME type.
    Bytes { data: Vec<u8>, mime_type: String },
}
|
|
||||||
|
|
||||||
impl ImageInput {
|
|
||||||
/// Create ImageInput from base64 string
|
|
||||||
pub fn from_base64(data: String, mime_type: String) -> Self {
|
|
||||||
Self::Base64 { data, mime_type }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create ImageInput from URL
|
|
||||||
pub fn from_url(url: String) -> Self {
|
|
||||||
Self::Url(url)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create ImageInput from raw bytes
|
|
||||||
pub fn from_bytes(data: Vec<u8>, mime_type: String) -> Self {
|
|
||||||
Self::Bytes { data, mime_type }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get MIME type if available
|
|
||||||
pub fn mime_type(&self) -> Option<&str> {
|
|
||||||
match self {
|
|
||||||
Self::Base64 { mime_type, .. } => Some(mime_type),
|
|
||||||
Self::Bytes { mime_type, .. } => Some(mime_type),
|
|
||||||
Self::Url(_) => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convert to base64 if not already
|
|
||||||
pub async fn to_base64(&self) -> Result<(String, String)> {
|
|
||||||
match self {
|
|
||||||
Self::Base64 { data, mime_type } => Ok((data.clone(), mime_type.clone())),
|
|
||||||
Self::Bytes { data, mime_type } => {
|
|
||||||
let base64_data = general_purpose::STANDARD.encode(data);
|
|
||||||
Ok((base64_data, mime_type.clone()))
|
|
||||||
}
|
|
||||||
Self::Url(url) => {
|
|
||||||
// Fetch image from URL using shared client
|
|
||||||
info!("Fetching image from URL: {}", url);
|
|
||||||
let response = IMAGE_CLIENT
|
|
||||||
.get(url)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.context("Failed to fetch image from URL")?;
|
|
||||||
|
|
||||||
if !response.status().is_success() {
|
|
||||||
anyhow::bail!("Failed to fetch image: HTTP {}", response.status());
|
|
||||||
}
|
|
||||||
|
|
||||||
let mime_type = response
|
|
||||||
.headers()
|
|
||||||
.get(reqwest::header::CONTENT_TYPE)
|
|
||||||
.and_then(|h| h.to_str().ok())
|
|
||||||
.unwrap_or("image/jpeg")
|
|
||||||
.to_string();
|
|
||||||
|
|
||||||
let bytes = response.bytes().await.context("Failed to read image bytes")?;
|
|
||||||
|
|
||||||
let base64_data = general_purpose::STANDARD.encode(&bytes);
|
|
||||||
Ok((base64_data, mime_type))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get image dimensions (width, height)
|
|
||||||
pub async fn get_dimensions(&self) -> Result<(u32, u32)> {
|
|
||||||
let bytes = match self {
|
|
||||||
Self::Base64 { data, .. } => general_purpose::STANDARD
|
|
||||||
.decode(data)
|
|
||||||
.context("Failed to decode base64")?,
|
|
||||||
Self::Bytes { data, .. } => data.clone(),
|
|
||||||
Self::Url(_) => {
|
|
||||||
let (base64_data, _) = self.to_base64().await?;
|
|
||||||
general_purpose::STANDARD
|
|
||||||
.decode(&base64_data)
|
|
||||||
.context("Failed to decode base64")?
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let img = image::load_from_memory(&bytes).context("Failed to load image from bytes")?;
|
|
||||||
Ok((img.width(), img.height()))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Validate image size and format
|
|
||||||
pub async fn validate(&self, max_size_mb: f64) -> Result<()> {
|
|
||||||
let (width, height) = self.get_dimensions().await?;
|
|
||||||
|
|
||||||
// Check dimensions
|
|
||||||
if width > 4096 || height > 4096 {
|
|
||||||
warn!("Image dimensions too large: {}x{}", width, height);
|
|
||||||
// Continue anyway, but log warning
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check file size
|
|
||||||
let size_bytes = match self {
|
|
||||||
Self::Base64 { data, .. } => {
|
|
||||||
// Base64 size is ~4/3 of original
|
|
||||||
(data.len() as f64 * 0.75) as usize
|
|
||||||
}
|
|
||||||
Self::Bytes { data, .. } => data.len(),
|
|
||||||
Self::Url(_) => {
|
|
||||||
// For URLs, we'd need to fetch to check size
|
|
||||||
// Skip size check for URLs for now
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let size_mb = size_bytes as f64 / (1024.0 * 1024.0);
|
|
||||||
if size_mb > max_size_mb {
|
|
||||||
anyhow::bail!("Image too large: {:.2}MB > {:.2}MB limit", size_mb, max_size_mb);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Provider-specific image format conversion
|
|
||||||
pub struct ImageConverter;
|
|
||||||
|
|
||||||
impl ImageConverter {
|
|
||||||
/// Convert image to OpenAI-compatible format
|
|
||||||
pub async fn to_openai_format(image: &ImageInput) -> Result<serde_json::Value> {
|
|
||||||
let (base64_data, mime_type) = image.to_base64().await?;
|
|
||||||
|
|
||||||
// OpenAI expects data URL format: "data:image/jpeg;base64,{data}"
|
|
||||||
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
|
|
||||||
|
|
||||||
Ok(serde_json::json!({
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": data_url,
|
|
||||||
"detail": "auto" // Can be "low", "high", or "auto"
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convert image to Gemini-compatible format
|
|
||||||
pub async fn to_gemini_format(image: &ImageInput) -> Result<serde_json::Value> {
|
|
||||||
let (base64_data, mime_type) = image.to_base64().await?;
|
|
||||||
|
|
||||||
// Gemini expects inline data format
|
|
||||||
Ok(serde_json::json!({
|
|
||||||
"inline_data": {
|
|
||||||
"mime_type": mime_type,
|
|
||||||
"data": base64_data
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convert image to DeepSeek-compatible format
|
|
||||||
pub async fn to_deepseek_format(image: &ImageInput) -> Result<serde_json::Value> {
|
|
||||||
// DeepSeek uses OpenAI-compatible format for vision models
|
|
||||||
Self::to_openai_format(image).await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Detect if a model supports multimodal input
|
|
||||||
pub fn model_supports_multimodal(model: &str) -> bool {
|
|
||||||
// OpenAI vision models
|
|
||||||
if (model.starts_with("gpt-4") && (model.contains("vision") || model.contains("-v") || model.contains("4o")))
|
|
||||||
|| model.starts_with("o1-")
|
|
||||||
|| model.starts_with("o3-")
|
|
||||||
{
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Gemini vision models
|
|
||||||
if model.starts_with("gemini") {
|
|
||||||
// Most Gemini models support vision
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeepSeek vision models
|
|
||||||
if model.starts_with("deepseek-vl") {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Parse OpenAI-compatible multimodal message content
|
|
||||||
pub fn parse_openai_content(content: &serde_json::Value) -> Result<Vec<(String, Option<ImageInput>)>> {
|
|
||||||
let mut parts = Vec::new();
|
|
||||||
|
|
||||||
if let Some(content_str) = content.as_str() {
|
|
||||||
// Simple text content
|
|
||||||
parts.push((content_str.to_string(), None));
|
|
||||||
} else if let Some(content_array) = content.as_array() {
|
|
||||||
// Array of content parts (text and/or images)
|
|
||||||
for part in content_array {
|
|
||||||
if let Some(part_obj) = part.as_object()
|
|
||||||
&& let Some(part_type) = part_obj.get("type").and_then(|t| t.as_str())
|
|
||||||
{
|
|
||||||
match part_type {
|
|
||||||
"text" => {
|
|
||||||
if let Some(text) = part_obj.get("text").and_then(|t| t.as_str()) {
|
|
||||||
parts.push((text.to_string(), None));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"image_url" => {
|
|
||||||
if let Some(image_url_obj) = part_obj.get("image_url").and_then(|o| o.as_object())
|
|
||||||
&& let Some(url) = image_url_obj.get("url").and_then(|u| u.as_str())
|
|
||||||
{
|
|
||||||
if url.starts_with("data:") {
|
|
||||||
// Parse data URL
|
|
||||||
if let Some((mime_type, data)) = parse_data_url(url) {
|
|
||||||
let image_input = ImageInput::from_base64(data, mime_type);
|
|
||||||
parts.push(("".to_string(), Some(image_input)));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Regular URL
|
|
||||||
let image_input = ImageInput::from_url(url.to_string());
|
|
||||||
parts.push(("".to_string(), Some(image_input)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
warn!("Unknown content part type: {}", part_type);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(parts)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Split a base64 data URL (`data:<mime>;base64,<payload>`) into its MIME
/// type and payload. Returns `None` for anything that is not exactly a
/// single-`;base64,` data URL.
fn parse_data_url(data_url: &str) -> Option<(String, String)> {
    let rest = data_url.strip_prefix("data:")?;

    // Exactly one ";base64," separator is required; anything else is rejected.
    let pieces: Vec<&str> = rest.split(";base64,").collect();
    match pieces.as_slice() {
        [mime, payload] => Some(((*mime).to_string(), (*payload).to_string())),
        _ => None,
    }
}
|
|
||||||
|
|
||||||
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_parse_data_url() {
        // "Hello World" base64-encoded.
        let url = "data:image/jpeg;base64,SGVsbG8gV29ybGQ=";
        let (mime, payload) = parse_data_url(url).unwrap();

        assert_eq!(mime, "image/jpeg");
        assert_eq!(payload, "SGVsbG8gV29ybGQ=");
    }

    #[tokio::test]
    async fn test_model_supports_multimodal() {
        // OpenAI vision-capable names.
        assert!(ImageConverter::model_supports_multimodal("gpt-4-vision-preview"));
        assert!(ImageConverter::model_supports_multimodal("gpt-4o"));
        // Gemini is treated as vision-capable across the board.
        assert!(ImageConverter::model_supports_multimodal("gemini-pro-vision"));
        assert!(ImageConverter::model_supports_multimodal("gemini-pro"));
        // Text-only models.
        assert!(!ImageConverter::model_supports_multimodal("gpt-3.5-turbo"));
        assert!(!ImageConverter::model_supports_multimodal("claude-3-opus"));
    }
}
|
|
||||||
@@ -1,251 +0,0 @@
|
|||||||
use anyhow::Result;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use futures::stream::BoxStream;
|
|
||||||
use futures::StreamExt;
|
|
||||||
|
|
||||||
use super::helpers;
|
|
||||||
use super::{ProviderResponse, ProviderStreamChunk};
|
|
||||||
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
|
|
||||||
|
|
||||||
pub struct DeepSeekProvider {
|
|
||||||
client: reqwest::Client,
|
|
||||||
config: crate::config::DeepSeekConfig,
|
|
||||||
api_key: String,
|
|
||||||
pricing: Vec<crate::config::ModelPricing>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DeepSeekProvider {
|
|
||||||
pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result<Self> {
|
|
||||||
let api_key = app_config.get_api_key("deepseek")?;
|
|
||||||
Self::new_with_key(config, app_config, api_key)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn new_with_key(
|
|
||||||
config: &crate::config::DeepSeekConfig,
|
|
||||||
app_config: &AppConfig,
|
|
||||||
api_key: String,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let client = reqwest::Client::builder()
|
|
||||||
.connect_timeout(std::time::Duration::from_secs(5))
|
|
||||||
.timeout(std::time::Duration::from_secs(300))
|
|
||||||
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
|
||||||
.pool_max_idle_per_host(4)
|
|
||||||
.tcp_keepalive(std::time::Duration::from_secs(30))
|
|
||||||
.build()?;
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
client,
|
|
||||||
config: config.clone(),
|
|
||||||
api_key,
|
|
||||||
pricing: app_config.pricing.deepseek.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl super::Provider for DeepSeekProvider {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"deepseek"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn supports_model(&self, model: &str) -> bool {
|
|
||||||
model.starts_with("deepseek-") || model.contains("deepseek")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn supports_multimodal(&self) -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
|
||||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
|
||||||
let mut body = helpers::build_openai_body(&request, messages_json, false);
|
|
||||||
|
|
||||||
// Sanitize and fix for deepseek-reasoner (R1)
|
|
||||||
if request.model == "deepseek-reasoner" {
|
|
||||||
if let Some(obj) = body.as_object_mut() {
|
|
||||||
// Remove unsupported parameters
|
|
||||||
obj.remove("temperature");
|
|
||||||
obj.remove("top_p");
|
|
||||||
obj.remove("presence_penalty");
|
|
||||||
obj.remove("frequency_penalty");
|
|
||||||
obj.remove("logit_bias");
|
|
||||||
obj.remove("logprobs");
|
|
||||||
obj.remove("top_logprobs");
|
|
||||||
|
|
||||||
// ENSURE: EVERY assistant message must have reasoning_content and valid content
|
|
||||||
if let Some(messages) = obj.get_mut("messages").and_then(|m| m.as_array_mut()) {
|
|
||||||
for m in messages {
|
|
||||||
if m["role"].as_str() == Some("assistant") {
|
|
||||||
// DeepSeek R1 requires reasoning_content for consistency in history.
|
|
||||||
if m.get("reasoning_content").is_none() || m["reasoning_content"].is_null() {
|
|
||||||
m["reasoning_content"] = serde_json::json!(" ");
|
|
||||||
}
|
|
||||||
// DeepSeek R1 often requires content to be a string, not null/array
|
|
||||||
if m.get("content").is_none() || m["content"].is_null() || m["content"].is_array() {
|
|
||||||
m["content"] = serde_json::json!("");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = self
|
|
||||||
.client
|
|
||||||
.post(format!("{}/chat/completions", self.config.base_url))
|
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
||||||
.json(&body)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
||||||
|
|
||||||
if !response.status().is_success() {
|
|
||||||
let status = response.status();
|
|
||||||
let error_text = response.text().await.unwrap_or_default();
|
|
||||||
tracing::error!("DeepSeek API error ({}): {}", status, error_text);
|
|
||||||
tracing::error!("Offending DeepSeek Request Body: {}", serde_json::to_string(&body).unwrap_or_default());
|
|
||||||
return Err(AppError::ProviderError(format!("DeepSeek API error ({}): {}", status, error_text)));
|
|
||||||
}
|
|
||||||
|
|
||||||
let resp_json: serde_json::Value = response
|
|
||||||
.json()
|
|
||||||
.await
|
|
||||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
||||||
|
|
||||||
helpers::parse_openai_response(&resp_json, request.model)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
|
||||||
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calculate_cost(
|
|
||||||
&self,
|
|
||||||
model: &str,
|
|
||||||
prompt_tokens: u32,
|
|
||||||
completion_tokens: u32,
|
|
||||||
cache_read_tokens: u32,
|
|
||||||
cache_write_tokens: u32,
|
|
||||||
registry: &crate::models::registry::ModelRegistry,
|
|
||||||
) -> f64 {
|
|
||||||
helpers::calculate_cost_with_registry(
|
|
||||||
model,
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
cache_read_tokens,
|
|
||||||
cache_write_tokens,
|
|
||||||
registry,
|
|
||||||
&self.pricing,
|
|
||||||
0.14,
|
|
||||||
0.28,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn chat_completion_stream(
|
|
||||||
&self,
|
|
||||||
request: UnifiedRequest,
|
|
||||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
|
||||||
// DeepSeek doesn't support images in streaming, use text-only
|
|
||||||
let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?;
|
|
||||||
let mut body = helpers::build_openai_body(&request, messages_json, true);
|
|
||||||
|
|
||||||
// Sanitize and fix for deepseek-reasoner (R1)
|
|
||||||
if request.model == "deepseek-reasoner" {
|
|
||||||
if let Some(obj) = body.as_object_mut() {
|
|
||||||
obj.remove("stream_options");
|
|
||||||
obj.remove("temperature");
|
|
||||||
obj.remove("top_p");
|
|
||||||
obj.remove("presence_penalty");
|
|
||||||
obj.remove("frequency_penalty");
|
|
||||||
obj.remove("logit_bias");
|
|
||||||
obj.remove("logprobs");
|
|
||||||
obj.remove("top_logprobs");
|
|
||||||
|
|
||||||
// ENSURE: EVERY assistant message must have reasoning_content and valid content
|
|
||||||
if let Some(messages) = obj.get_mut("messages").and_then(|m| m.as_array_mut()) {
|
|
||||||
for m in messages {
|
|
||||||
if m["role"].as_str() == Some("assistant") {
|
|
||||||
// DeepSeek R1 requires reasoning_content for consistency in history.
|
|
||||||
if m.get("reasoning_content").is_none() || m["reasoning_content"].is_null() {
|
|
||||||
m["reasoning_content"] = serde_json::json!(" ");
|
|
||||||
}
|
|
||||||
// DeepSeek R1 often requires content to be a string, not null/array
|
|
||||||
if m.get("content").is_none() || m["content"].is_null() || m["content"].is_array() {
|
|
||||||
m["content"] = serde_json::json!("");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// For standard deepseek-chat, keep it clean
|
|
||||||
if let Some(obj) = body.as_object_mut() {
|
|
||||||
obj.remove("stream_options");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let url = format!("{}/chat/completions", self.config.base_url);
|
|
||||||
let api_key = self.api_key.clone();
|
|
||||||
let probe_client = self.client.clone();
|
|
||||||
let probe_body = body.clone();
|
|
||||||
let model = request.model.clone();
|
|
||||||
|
|
||||||
let es = reqwest_eventsource::EventSource::new(
|
|
||||||
self.client
|
|
||||||
.post(&url)
|
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
||||||
.json(&body),
|
|
||||||
)
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
|
||||||
|
|
||||||
let stream = async_stream::try_stream! {
|
|
||||||
let mut es = es;
|
|
||||||
while let Some(event) = es.next().await {
|
|
||||||
match event {
|
|
||||||
Ok(reqwest_eventsource::Event::Message(msg)) => {
|
|
||||||
if msg.data == "[DONE]" {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
let chunk: serde_json::Value = serde_json::from_str(&msg.data)
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
|
||||||
|
|
||||||
if let Some(p_chunk) = helpers::parse_openai_stream_chunk(&chunk, &model, None) {
|
|
||||||
yield p_chunk?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(_) => continue,
|
|
||||||
Err(e) => {
|
|
||||||
// Attempt to probe for the actual error body
|
|
||||||
let probe_resp = probe_client
|
|
||||||
.post(&url)
|
|
||||||
.header("Authorization", format!("Bearer {}", api_key))
|
|
||||||
.json(&probe_body)
|
|
||||||
.send()
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match probe_resp {
|
|
||||||
Ok(r) if !r.status().is_success() => {
|
|
||||||
let status = r.status();
|
|
||||||
let error_body = r.text().await.unwrap_or_default();
|
|
||||||
tracing::error!("DeepSeek Stream Error Probe ({}): {}", status, error_body);
|
|
||||||
// Log the offending request body at ERROR level so it shows up in standard logs
|
|
||||||
tracing::error!("Offending DeepSeek Request Body: {}", serde_json::to_string(&probe_body).unwrap_or_default());
|
|
||||||
Err(AppError::ProviderError(format!("DeepSeek API error ({}): {}", status, error_body)))?;
|
|
||||||
}
|
|
||||||
Ok(_) => {
|
|
||||||
Err(AppError::ProviderError(format!("Stream error (probe returned 200): {}", e)))?;
|
|
||||||
}
|
|
||||||
Err(probe_err) => {
|
|
||||||
tracing::error!("DeepSeek Stream Error Probe failed: {}", probe_err);
|
|
||||||
Err(AppError::ProviderError(format!("Stream error (probe failed: {}): {}", probe_err, e)))?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Box::pin(stream))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,123 +0,0 @@
|
|||||||
use anyhow::Result;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use futures::stream::BoxStream;
|
|
||||||
|
|
||||||
use super::helpers;
|
|
||||||
use super::{ProviderResponse, ProviderStreamChunk};
|
|
||||||
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
|
|
||||||
|
|
||||||
pub struct GrokProvider {
|
|
||||||
client: reqwest::Client,
|
|
||||||
config: crate::config::GrokConfig,
|
|
||||||
api_key: String,
|
|
||||||
pricing: Vec<crate::config::ModelPricing>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GrokProvider {
    /// Construct a provider, resolving the API key from the app config
    /// under the "grok" provider name.
    pub fn new(config: &crate::config::GrokConfig, app_config: &AppConfig) -> Result<Self> {
        let api_key = app_config.get_api_key("grok")?;
        Self::new_with_key(config, app_config, api_key)
    }

    /// Construct a provider with an explicitly supplied API key
    /// (used when a database override provides the key).
    pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
        // Short connect timeout so unreachable hosts fail fast; long overall
        // timeout (300s) to accommodate slow streaming completions.
        let client = reqwest::Client::builder()
            .connect_timeout(std::time::Duration::from_secs(5))
            .timeout(std::time::Duration::from_secs(300))
            .pool_idle_timeout(std::time::Duration::from_secs(90))
            .pool_max_idle_per_host(4)
            .tcp_keepalive(std::time::Duration::from_secs(30))
            .build()?;

        Ok(Self {
            client,
            config: config.clone(),
            api_key,
            pricing: app_config.pricing.grok.clone(),
        })
    }
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl super::Provider for GrokProvider {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"grok"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn supports_model(&self, model: &str) -> bool {
|
|
||||||
model.starts_with("grok-")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn supports_multimodal(&self) -> bool {
|
|
||||||
true
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
|
||||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
|
||||||
let body = helpers::build_openai_body(&request, messages_json, false);
|
|
||||||
|
|
||||||
let response = self
|
|
||||||
.client
|
|
||||||
.post(format!("{}/chat/completions", self.config.base_url))
|
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
||||||
.json(&body)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
||||||
|
|
||||||
if !response.status().is_success() {
|
|
||||||
let error_text = response.text().await.unwrap_or_default();
|
|
||||||
return Err(AppError::ProviderError(format!("Grok API error: {}", error_text)));
|
|
||||||
}
|
|
||||||
|
|
||||||
let resp_json: serde_json::Value = response
|
|
||||||
.json()
|
|
||||||
.await
|
|
||||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
||||||
|
|
||||||
helpers::parse_openai_response(&resp_json, request.model)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
|
||||||
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calculate_cost(
|
|
||||||
&self,
|
|
||||||
model: &str,
|
|
||||||
prompt_tokens: u32,
|
|
||||||
completion_tokens: u32,
|
|
||||||
cache_read_tokens: u32,
|
|
||||||
cache_write_tokens: u32,
|
|
||||||
registry: &crate::models::registry::ModelRegistry,
|
|
||||||
) -> f64 {
|
|
||||||
helpers::calculate_cost_with_registry(
|
|
||||||
model,
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
cache_read_tokens,
|
|
||||||
cache_write_tokens,
|
|
||||||
registry,
|
|
||||||
&self.pricing,
|
|
||||||
5.0,
|
|
||||||
15.0,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn chat_completion_stream(
|
|
||||||
&self,
|
|
||||||
request: UnifiedRequest,
|
|
||||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
|
||||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
|
||||||
let body = helpers::build_openai_body(&request, messages_json, true);
|
|
||||||
|
|
||||||
let es = reqwest_eventsource::EventSource::new(
|
|
||||||
self.client
|
|
||||||
.post(format!("{}/chat/completions", self.config.base_url))
|
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
||||||
.json(&body),
|
|
||||||
)
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
|
||||||
|
|
||||||
Ok(helpers::create_openai_stream(es, request.model, None))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,445 +0,0 @@
|
|||||||
use super::{ProviderResponse, ProviderStreamChunk, StreamUsage};
|
|
||||||
use crate::errors::AppError;
|
|
||||||
use crate::models::{ContentPart, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest};
|
|
||||||
use futures::stream::{BoxStream, StreamExt};
|
|
||||||
use serde_json::Value;
|
|
||||||
|
|
||||||
/// Convert messages to OpenAI-compatible JSON, resolving images asynchronously.
|
|
||||||
///
|
|
||||||
/// This avoids the deadlock caused by `futures::executor::block_on` inside a
|
|
||||||
/// Tokio async context. All image base64 conversions are awaited properly.
|
|
||||||
/// Handles tool-calling messages: assistant messages with tool_calls, and
|
|
||||||
/// tool-role messages with tool_call_id/name.
|
|
||||||
pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result<Vec<serde_json::Value>, AppError> {
|
|
||||||
let mut result = Vec::new();
|
|
||||||
for m in messages {
|
|
||||||
// Tool-role messages: { role: "tool", content: "...", tool_call_id: "...", name: "..." }
|
|
||||||
if m.role == "tool" {
|
|
||||||
let text_content = m
|
|
||||||
.content
|
|
||||||
.first()
|
|
||||||
.map(|p| match p {
|
|
||||||
ContentPart::Text { text } => text.clone(),
|
|
||||||
ContentPart::Image(_) => "[Image]".to_string(),
|
|
||||||
})
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
let mut msg = serde_json::json!({
|
|
||||||
"role": "tool",
|
|
||||||
"content": text_content
|
|
||||||
});
|
|
||||||
if let Some(tool_call_id) = &m.tool_call_id {
|
|
||||||
// OpenAI and others have a 40-char limit for tool_call_id.
|
|
||||||
// Gemini signatures (56 chars) must be shortened for compatibility.
|
|
||||||
let id = if tool_call_id.len() > 40 {
|
|
||||||
&tool_call_id[..40]
|
|
||||||
} else {
|
|
||||||
tool_call_id
|
|
||||||
};
|
|
||||||
msg["tool_call_id"] = serde_json::json!(id);
|
|
||||||
}
|
|
||||||
if let Some(name) = &m.name {
|
|
||||||
msg["name"] = serde_json::json!(name);
|
|
||||||
}
|
|
||||||
result.push(msg);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build content parts for non-tool messages
|
|
||||||
let mut parts = Vec::new();
|
|
||||||
for p in &m.content {
|
|
||||||
match p {
|
|
||||||
ContentPart::Text { text } => {
|
|
||||||
parts.push(serde_json::json!({ "type": "text", "text": text }));
|
|
||||||
}
|
|
||||||
ContentPart::Image(image_input) => {
|
|
||||||
let (base64_data, mime_type) = image_input
|
|
||||||
.to_base64()
|
|
||||||
.await
|
|
||||||
.map_err(|e| AppError::MultimodalError(e.to_string()))?;
|
|
||||||
parts.push(serde_json::json!({
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut msg = serde_json::json!({ "role": m.role });
|
|
||||||
|
|
||||||
// Include reasoning_content if present (DeepSeek R1/reasoner requires this in history)
|
|
||||||
if let Some(reasoning) = &m.reasoning_content {
|
|
||||||
msg["reasoning_content"] = serde_json::json!(reasoning);
|
|
||||||
}
|
|
||||||
|
|
||||||
// For assistant messages with tool_calls, content can be empty string
|
|
||||||
if let Some(tool_calls) = &m.tool_calls {
|
|
||||||
// Sanitize tool call IDs for OpenAI compatibility (max 40 chars)
|
|
||||||
let sanitized_calls: Vec<_> = tool_calls.iter().map(|tc| {
|
|
||||||
let mut sanitized = tc.clone();
|
|
||||||
if sanitized.id.len() > 40 {
|
|
||||||
sanitized.id = sanitized.id[..40].to_string();
|
|
||||||
}
|
|
||||||
sanitized
|
|
||||||
}).collect();
|
|
||||||
|
|
||||||
if parts.is_empty() {
|
|
||||||
msg["content"] = serde_json::json!("");
|
|
||||||
} else {
|
|
||||||
msg["content"] = serde_json::json!(parts);
|
|
||||||
}
|
|
||||||
msg["tool_calls"] = serde_json::json!(sanitized_calls);
|
|
||||||
} else {
|
|
||||||
msg["content"] = serde_json::json!(parts);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(name) = &m.name {
|
|
||||||
msg["name"] = serde_json::json!(name);
|
|
||||||
}
|
|
||||||
|
|
||||||
result.push(msg);
|
|
||||||
}
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convert messages to OpenAI-compatible JSON, but replace images with a
|
|
||||||
/// text placeholder "[Image]". Useful for providers that don't support
|
|
||||||
/// multimodal in streaming mode or at all.
|
|
||||||
///
|
|
||||||
/// Handles tool-calling messages identically to `messages_to_openai_json`:
|
|
||||||
/// assistant messages with `tool_calls`, and tool-role messages with
|
|
||||||
/// `tool_call_id`/`name`.
|
|
||||||
pub async fn messages_to_openai_json_text_only(
|
|
||||||
messages: &[UnifiedMessage],
|
|
||||||
) -> Result<Vec<serde_json::Value>, AppError> {
|
|
||||||
let mut result = Vec::new();
|
|
||||||
for m in messages {
|
|
||||||
// Tool-role messages: { role: "tool", content: "...", tool_call_id: "...", name: "..." }
|
|
||||||
if m.role == "tool" {
|
|
||||||
let text_content = m
|
|
||||||
.content
|
|
||||||
.first()
|
|
||||||
.map(|p| match p {
|
|
||||||
ContentPart::Text { text } => text.clone(),
|
|
||||||
ContentPart::Image(_) => "[Image]".to_string(),
|
|
||||||
})
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
let mut msg = serde_json::json!({
|
|
||||||
"role": "tool",
|
|
||||||
"content": text_content
|
|
||||||
});
|
|
||||||
if let Some(tool_call_id) = &m.tool_call_id {
|
|
||||||
// OpenAI and others have a 40-char limit for tool_call_id.
|
|
||||||
let id = if tool_call_id.len() > 40 {
|
|
||||||
&tool_call_id[..40]
|
|
||||||
} else {
|
|
||||||
tool_call_id
|
|
||||||
};
|
|
||||||
msg["tool_call_id"] = serde_json::json!(id);
|
|
||||||
}
|
|
||||||
if let Some(name) = &m.name {
|
|
||||||
msg["name"] = serde_json::json!(name);
|
|
||||||
}
|
|
||||||
result.push(msg);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build content parts for non-tool messages (images become "[Image]" text)
|
|
||||||
let mut parts = Vec::new();
|
|
||||||
for p in &m.content {
|
|
||||||
match p {
|
|
||||||
ContentPart::Text { text } => {
|
|
||||||
parts.push(serde_json::json!({ "type": "text", "text": text }));
|
|
||||||
}
|
|
||||||
ContentPart::Image(_) => {
|
|
||||||
parts.push(serde_json::json!({ "type": "text", "text": "[Image]" }));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut msg = serde_json::json!({ "role": m.role });
|
|
||||||
|
|
||||||
// Include reasoning_content if present (DeepSeek R1/reasoner requires this in history)
|
|
||||||
if let Some(reasoning) = &m.reasoning_content {
|
|
||||||
msg["reasoning_content"] = serde_json::json!(reasoning);
|
|
||||||
}
|
|
||||||
|
|
||||||
// For assistant messages with tool_calls, content can be empty string
|
|
||||||
if let Some(tool_calls) = &m.tool_calls {
|
|
||||||
// Sanitize tool call IDs for OpenAI compatibility (max 40 chars)
|
|
||||||
let sanitized_calls: Vec<_> = tool_calls.iter().map(|tc| {
|
|
||||||
let mut sanitized = tc.clone();
|
|
||||||
if sanitized.id.len() > 40 {
|
|
||||||
sanitized.id = sanitized.id[..40].to_string();
|
|
||||||
}
|
|
||||||
sanitized
|
|
||||||
}).collect();
|
|
||||||
|
|
||||||
if parts.is_empty() {
|
|
||||||
msg["content"] = serde_json::json!("");
|
|
||||||
} else {
|
|
||||||
msg["content"] = serde_json::json!(parts);
|
|
||||||
}
|
|
||||||
msg["tool_calls"] = serde_json::json!(sanitized_calls);
|
|
||||||
} else {
|
|
||||||
msg["content"] = serde_json::json!(parts);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(name) = &m.name {
|
|
||||||
msg["name"] = serde_json::json!(name);
|
|
||||||
}
|
|
||||||
|
|
||||||
result.push(msg);
|
|
||||||
}
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Build an OpenAI-compatible request body from a UnifiedRequest and pre-converted messages.
|
|
||||||
/// Includes tools and tool_choice when present.
|
|
||||||
/// When streaming, adds `stream_options.include_usage: true` so providers report
|
|
||||||
/// token counts in the final SSE chunk.
|
|
||||||
pub fn build_openai_body(
|
|
||||||
request: &UnifiedRequest,
|
|
||||||
messages_json: Vec<serde_json::Value>,
|
|
||||||
stream: bool,
|
|
||||||
) -> serde_json::Value {
|
|
||||||
let mut body = serde_json::json!({
|
|
||||||
"model": request.model,
|
|
||||||
"messages": messages_json,
|
|
||||||
"stream": stream,
|
|
||||||
});
|
|
||||||
|
|
||||||
if stream {
|
|
||||||
body["stream_options"] = serde_json::json!({ "include_usage": true });
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(temp) = request.temperature {
|
|
||||||
body["temperature"] = serde_json::json!(temp);
|
|
||||||
}
|
|
||||||
if let Some(max_tokens) = request.max_tokens {
|
|
||||||
body["max_tokens"] = serde_json::json!(max_tokens);
|
|
||||||
}
|
|
||||||
if let Some(tools) = &request.tools {
|
|
||||||
body["tools"] = serde_json::json!(tools);
|
|
||||||
}
|
|
||||||
if let Some(tool_choice) = &request.tool_choice {
|
|
||||||
body["tool_choice"] = serde_json::json!(tool_choice);
|
|
||||||
}
|
|
||||||
|
|
||||||
body
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Parse an OpenAI-compatible chat completion response JSON into a ProviderResponse.
|
|
||||||
/// Extracts tool_calls from the message when present.
|
|
||||||
/// Extracts cache token counts from:
|
|
||||||
/// - OpenAI/Grok: `usage.prompt_tokens_details.cached_tokens`
|
|
||||||
/// - DeepSeek: `usage.prompt_cache_hit_tokens` / `usage.prompt_cache_miss_tokens`
|
|
||||||
pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<ProviderResponse, AppError> {
|
|
||||||
let choice = resp_json["choices"]
|
|
||||||
.get(0)
|
|
||||||
.ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
|
|
||||||
let message = &choice["message"];
|
|
||||||
|
|
||||||
let content = message["content"].as_str().unwrap_or_default().to_string();
|
|
||||||
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
|
|
||||||
|
|
||||||
// Parse tool_calls from the response message
|
|
||||||
let tool_calls: Option<Vec<ToolCall>> = message
|
|
||||||
.get("tool_calls")
|
|
||||||
.and_then(|tc| serde_json::from_value(tc.clone()).ok());
|
|
||||||
|
|
||||||
let usage = &resp_json["usage"];
|
|
||||||
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
|
|
||||||
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
|
|
||||||
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
|
|
||||||
|
|
||||||
// Extract cache tokens — try OpenAI/Grok format first, then DeepSeek format
|
|
||||||
let cache_read_tokens = usage["prompt_tokens_details"]["cached_tokens"]
|
|
||||||
.as_u64()
|
|
||||||
// DeepSeek uses a different field name
|
|
||||||
.or_else(|| usage["prompt_cache_hit_tokens"].as_u64())
|
|
||||||
.unwrap_or(0) as u32;
|
|
||||||
|
|
||||||
// DeepSeek reports cache_write as prompt_cache_miss_tokens (tokens written to cache for future use).
|
|
||||||
// OpenAI doesn't report cache_write in this location, but may in the future.
|
|
||||||
let cache_write_tokens = usage["prompt_cache_miss_tokens"].as_u64().unwrap_or(0) as u32;
|
|
||||||
|
|
||||||
Ok(ProviderResponse {
|
|
||||||
content,
|
|
||||||
reasoning_content,
|
|
||||||
tool_calls,
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
total_tokens,
|
|
||||||
cache_read_tokens,
|
|
||||||
cache_write_tokens,
|
|
||||||
model,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Parse a single OpenAI-compatible stream chunk into a ProviderStreamChunk.
/// Returns None if the chunk should be skipped (e.g. promptFeedback).
pub fn parse_openai_stream_chunk(
    chunk: &Value,
    model: &str,
    reasoning_field: Option<&'static str>,
) -> Option<Result<ProviderStreamChunk, AppError>> {
    // Parse usage from the final chunk (sent when stream_options.include_usage is true).
    // This chunk may have an empty `choices` array.
    let stream_usage = chunk.get("usage").and_then(|u| {
        if u.is_null() {
            return None;
        }
        let prompt_tokens = u["prompt_tokens"].as_u64().unwrap_or(0) as u32;
        let completion_tokens = u["completion_tokens"].as_u64().unwrap_or(0) as u32;
        let total_tokens = u["total_tokens"].as_u64().unwrap_or(0) as u32;

        // Cache reads: OpenAI/Grok field first, then DeepSeek's field name.
        let cache_read_tokens = u["prompt_tokens_details"]["cached_tokens"]
            .as_u64()
            .or_else(|| u["prompt_cache_hit_tokens"].as_u64())
            .unwrap_or(0) as u32;

        // DeepSeek-only: tokens written to the prompt cache.
        let cache_write_tokens = u["prompt_cache_miss_tokens"]
            .as_u64()
            .unwrap_or(0) as u32;

        Some(StreamUsage {
            prompt_tokens,
            completion_tokens,
            total_tokens,
            cache_read_tokens,
            cache_write_tokens,
        })
    });

    if let Some(choice) = chunk["choices"].get(0) {
        let delta = &choice["delta"];
        let content = delta["content"].as_str().unwrap_or_default().to_string();
        // Reasoning may arrive under "reasoning_content" or under a
        // provider-specific field name supplied via `reasoning_field`
        // (e.g. "thought" for Ollama).
        let reasoning_content = delta["reasoning_content"]
            .as_str()
            .or_else(|| reasoning_field.and_then(|f| delta[f].as_str()))
            .map(|s| s.to_string());
        let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());

        // Parse tool_calls deltas from the stream chunk
        let tool_calls: Option<Vec<ToolCallDelta>> = delta
            .get("tool_calls")
            .and_then(|tc| serde_json::from_value(tc.clone()).ok());

        Some(Ok(ProviderStreamChunk {
            content,
            reasoning_content,
            finish_reason,
            tool_calls,
            model: model.to_string(),
            usage: stream_usage,
        }))
    } else if stream_usage.is_some() {
        // Final usage-only chunk (empty choices array) — yield it so
        // AggregatingStream can capture the real token counts.
        Some(Ok(ProviderStreamChunk {
            content: String::new(),
            reasoning_content: None,
            finish_reason: None,
            tool_calls: None,
            model: model.to_string(),
            usage: stream_usage,
        }))
    } else {
        // Neither deltas nor usage: nothing worth yielding.
        None
    }
}
|
|
||||||
|
|
||||||
/// Create an SSE stream that parses OpenAI-compatible streaming chunks.
///
/// The optional `reasoning_field` allows overriding the field name for
/// reasoning content (e.g., "thought" for Ollama).
/// Parses tool_calls deltas from streaming chunks when present.
/// When `stream_options.include_usage: true` was sent, the provider sends a
/// final chunk with `usage` data — this is parsed into `StreamUsage` and
/// attached to the yielded `ProviderStreamChunk`.
pub fn create_openai_stream(
    es: reqwest_eventsource::EventSource,
    model: String,
    reasoning_field: Option<&'static str>,
) -> BoxStream<'static, Result<ProviderStreamChunk, AppError>> {
    use reqwest_eventsource::Event;

    let stream = async_stream::try_stream! {
        let mut es = es;
        while let Some(event) = es.next().await {
            match event {
                Ok(Event::Message(msg)) => {
                    // Terminating sentinel of the OpenAI SSE protocol.
                    if msg.data == "[DONE]" {
                        break;
                    }

                    let chunk: Value = serde_json::from_str(&msg.data)
                        .map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;

                    // None means "skip this chunk" (nothing useful in it).
                    if let Some(p_chunk) = parse_openai_stream_chunk(&chunk, &model, reasoning_field) {
                        yield p_chunk?;
                    }
                }
                // Ignore non-message events (e.g. the initial Open event).
                Ok(_) => continue,
                Err(e) => {
                    // try_stream!: `Err(..)?` terminates the stream with the error.
                    Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
                }
            }
        }
    };

    Box::pin(stream)
}
|
|
||||||
|
|
||||||
/// Calculate cost using the model registry first, then falling back to provider pricing config.
///
/// When the registry provides `cache_read` / `cache_write` rates, the formula is:
///   (prompt_tokens - cache_read_tokens) * input_rate
///   + cache_read_tokens * cache_read_rate
///   + cache_write_tokens * cache_write_rate (if applicable)
///   + completion_tokens * output_rate
///
/// All rates are per-million-token (hence the 1_000_000 divisors below).
pub fn calculate_cost_with_registry(
    model: &str,
    prompt_tokens: u32,
    completion_tokens: u32,
    cache_read_tokens: u32,
    cache_write_tokens: u32,
    registry: &crate::models::registry::ModelRegistry,
    pricing: &[crate::config::ModelPricing],
    default_prompt_rate: f64,
    default_completion_rate: f64,
) -> f64 {
    // Let-chain: registry entry AND its cost data must both exist.
    if let Some(metadata) = registry.find_model(model)
        && let Some(cost) = &metadata.cost
    {
        // saturating_sub guards against a provider reporting more cached
        // tokens than total prompt tokens.
        let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
        let mut total = (non_cached_prompt as f64 * cost.input / 1_000_000.0)
            + (completion_tokens as f64 * cost.output / 1_000_000.0);

        if let Some(cache_read_rate) = cost.cache_read {
            total += cache_read_tokens as f64 * cache_read_rate / 1_000_000.0;
        } else {
            // No cache_read rate — charge cached tokens at full input rate
            total += cache_read_tokens as f64 * cost.input / 1_000_000.0;
        }

        if let Some(cache_write_rate) = cost.cache_write {
            total += cache_write_tokens as f64 * cache_write_rate / 1_000_000.0;
        }

        return total;
    }

    // Fallback: no registry entry — use provider pricing config (no cache awareness).
    // NOTE(review): substring match picks the first pricing row whose model
    // name occurs anywhere in `model`; confirm list ordering keeps this
    // unambiguous for overlapping names.
    let (prompt_rate, completion_rate) = pricing
        .iter()
        .find(|p| model.contains(&p.model))
        .map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
        .unwrap_or((default_prompt_rate, default_completion_rate));

    (prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
}
|
|
||||||
@@ -1,363 +0,0 @@
|
|||||||
use anyhow::Result;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use futures::stream::BoxStream;
|
|
||||||
use sqlx::Row;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use crate::errors::AppError;
|
|
||||||
use crate::models::UnifiedRequest;
|
|
||||||
|
|
||||||
|
|
||||||
pub mod deepseek;
|
|
||||||
pub mod gemini;
|
|
||||||
pub mod grok;
|
|
||||||
pub mod helpers;
|
|
||||||
pub mod ollama;
|
|
||||||
pub mod openai;
|
|
||||||
|
|
||||||
#[async_trait]
pub trait Provider: Send + Sync {
    /// Get provider name (e.g., "openai", "gemini")
    fn name(&self) -> &str;

    /// Check if provider supports a specific model
    fn supports_model(&self, model: &str) -> bool;

    /// Check if provider supports multimodal (images, etc.)
    fn supports_multimodal(&self) -> bool;

    /// Process a chat completion request
    async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError>;

    /// Process a chat request using provider-specific "responses" style endpoint.
    /// Default implementation falls back to `chat_completion` for providers
    /// that do not implement a dedicated responses endpoint.
    async fn chat_responses(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
        self.chat_completion(request).await
    }

    /// Process a streaming chat completion request
    async fn chat_completion_stream(
        &self,
        request: UnifiedRequest,
    ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError>;

    /// Process a streaming chat request using provider-specific "responses" style endpoint.
    /// Default implementation falls back to `chat_completion_stream` for providers
    /// that do not implement a dedicated responses endpoint.
    async fn chat_responses_stream(
        &self,
        request: UnifiedRequest,
    ) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
        self.chat_completion_stream(request).await
    }

    /// Estimate token count for a request (for cost calculation)
    fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32>;

    /// Calculate cost based on token usage and model using the registry.
    /// `cache_read_tokens` / `cache_write_tokens` allow cache-aware pricing
    /// when the registry provides `cache_read` / `cache_write` rates.
    fn calculate_cost(
        &self,
        model: &str,
        prompt_tokens: u32,
        completion_tokens: u32,
        cache_read_tokens: u32,
        cache_write_tokens: u32,
        registry: &crate::models::registry::ModelRegistry,
    ) -> f64;
}
|
|
||||||
|
|
||||||
/// Unified result of a non-streaming chat completion from any provider.
pub struct ProviderResponse {
    // Assistant message text.
    pub content: String,
    // Chain-of-thought text, when the provider returns it (e.g. DeepSeek R1).
    pub reasoning_content: Option<String>,
    // Tool invocations requested by the model, if any.
    pub tool_calls: Option<Vec<crate::models::ToolCall>>,
    // Token accounting as reported by the provider.
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    // Prompt tokens served from / written to the provider's prompt cache.
    pub cache_read_tokens: u32,
    pub cache_write_tokens: u32,
    // Model name the response is attributed to.
    pub model: String,
}
|
|
||||||
|
|
||||||
/// Usage data from the final streaming chunk (when providers report real token counts).
#[derive(Debug, Clone, Default)]
pub struct StreamUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    // Prompt tokens served from / written to the provider's prompt cache.
    pub cache_read_tokens: u32,
    pub cache_write_tokens: u32,
}
|
|
||||||
|
|
||||||
/// One parsed chunk of a streaming chat completion.
#[derive(Debug, Clone)]
pub struct ProviderStreamChunk {
    // Text delta for this chunk (may be empty, e.g. on a usage-only chunk).
    pub content: String,
    // Reasoning-text delta, when the provider streams it separately.
    pub reasoning_content: Option<String>,
    // Set on the last content chunk (e.g. "stop", "tool_calls").
    pub finish_reason: Option<String>,
    // Incremental tool-call fragments to be accumulated by the caller.
    pub tool_calls: Option<Vec<crate::models::ToolCallDelta>>,
    pub model: String,
    /// Populated only on the final chunk when providers report usage (e.g. stream_options.include_usage).
    pub usage: Option<StreamUsage>,
}
|
|
||||||
|
|
||||||
use tokio::sync::RwLock;
|
|
||||||
|
|
||||||
use crate::config::AppConfig;
|
|
||||||
use crate::providers::{
|
|
||||||
deepseek::DeepSeekProvider, gemini::GeminiProvider, grok::GrokProvider, ollama::OllamaProvider,
|
|
||||||
openai::OpenAIProvider,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Registry of active providers; cheap to clone (shared Arc inside).
#[derive(Clone)]
pub struct ProviderManager {
    // Providers behind an async RwLock so they can be added/removed at
    // runtime (e.g. when a database config override changes).
    providers: Arc<RwLock<Vec<Arc<dyn Provider>>>>,
}

impl Default for ProviderManager {
    fn default() -> Self {
        Self::new()
    }
}
|
|
||||||
|
|
||||||
impl ProviderManager {
    /// Create an empty manager; providers are registered via
    /// `initialize_provider` / `add_provider`.
    pub fn new() -> Self {
        Self {
            providers: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Initialize a provider by name using config and database overrides.
    /// A database row (if present) wins over AppConfig defaults; a disabled
    /// provider is removed from the registry.
    pub async fn initialize_provider(
        &self,
        name: &str,
        app_config: &AppConfig,
        db_pool: &crate::database::DbPool,
    ) -> Result<()> {
        // Load override from database
        let db_config = sqlx::query("SELECT enabled, base_url, api_key, api_key_encrypted FROM provider_configs WHERE id = ?")
            .bind(name)
            .fetch_optional(db_pool)
            .await?;

        let (enabled, base_url, api_key) = if let Some(row) = db_config {
            let enabled = row.get::<bool, _>("enabled");
            let base_url = row.get::<Option<String>, _>("base_url");
            let api_key_encrypted = row.get::<bool, _>("api_key_encrypted");
            let api_key = row.get::<Option<String>, _>("api_key");
            // Decrypt API key if encrypted; a failed decrypt is logged and
            // treated as "no key" rather than failing initialization.
            let api_key = match (api_key, api_key_encrypted) {
                (Some(key), true) => {
                    match crate::utils::crypto::decrypt(&key) {
                        Ok(decrypted) => Some(decrypted),
                        Err(e) => {
                            tracing::error!("Failed to decrypt API key for provider {}: {}", name, e);
                            None
                        }
                    }
                }
                (Some(key), false) => {
                    // Plaintext key - optionally encrypt and update database (lazy migration)
                    // TODO(review): lazy encryption migration not implemented; plaintext used as-is.
                    Some(key)
                }
                (None, _) => None,
            };
            (enabled, base_url, api_key)
        } else {
            // No database override, use defaults from AppConfig
            match name {
                "openai" => (
                    app_config.providers.openai.enabled,
                    Some(app_config.providers.openai.base_url.clone()),
                    None,
                ),
                "gemini" => (
                    app_config.providers.gemini.enabled,
                    Some(app_config.providers.gemini.base_url.clone()),
                    None,
                ),
                "deepseek" => (
                    app_config.providers.deepseek.enabled,
                    Some(app_config.providers.deepseek.base_url.clone()),
                    None,
                ),
                "grok" => (
                    app_config.providers.grok.enabled,
                    Some(app_config.providers.grok.base_url.clone()),
                    None,
                ),
                "ollama" => (
                    app_config.providers.ollama.enabled,
                    Some(app_config.providers.ollama.base_url.clone()),
                    None,
                ),
                _ => (false, None, None),
            }
        };

        // Disabled providers are dropped from the registry (idempotent).
        if !enabled {
            self.remove_provider(name).await;
            return Ok(());
        }

        // Create provider instance with merged config: DB base_url overrides
        // the AppConfig default; an explicit DB api_key takes the
        // `new_with_key` path, otherwise the provider resolves its own key.
        let provider: Arc<dyn Provider> = match name {
            "openai" => {
                let mut cfg = app_config.providers.openai.clone();
                if let Some(url) = base_url {
                    cfg.base_url = url;
                }
                let p = if let Some(key) = api_key {
                    OpenAIProvider::new_with_key(&cfg, app_config, key)?
                } else {
                    OpenAIProvider::new(&cfg, app_config)?
                };
                Arc::new(p)
            }
            "ollama" => {
                // Ollama is local and keyless — no api_key handling needed.
                let mut cfg = app_config.providers.ollama.clone();
                if let Some(url) = base_url {
                    cfg.base_url = url;
                }
                Arc::new(OllamaProvider::new(&cfg, app_config)?)
            }
            "gemini" => {
                let mut cfg = app_config.providers.gemini.clone();
                if let Some(url) = base_url {
                    cfg.base_url = url;
                }
                let p = if let Some(key) = api_key {
                    GeminiProvider::new_with_key(&cfg, app_config, key)?
                } else {
                    GeminiProvider::new(&cfg, app_config)?
                };
                Arc::new(p)
            }
            "deepseek" => {
                let mut cfg = app_config.providers.deepseek.clone();
                if let Some(url) = base_url {
                    cfg.base_url = url;
                }
                let p = if let Some(key) = api_key {
                    DeepSeekProvider::new_with_key(&cfg, app_config, key)?
                } else {
                    DeepSeekProvider::new(&cfg, app_config)?
                };
                Arc::new(p)
            }
            "grok" => {
                let mut cfg = app_config.providers.grok.clone();
                if let Some(url) = base_url {
                    cfg.base_url = url;
                }
                let p = if let Some(key) = api_key {
                    GrokProvider::new_with_key(&cfg, app_config, key)?
                } else {
                    GrokProvider::new(&cfg, app_config)?
                };
                Arc::new(p)
            }
            _ => return Err(anyhow::anyhow!("Unknown provider: {}", name)),
        };

        self.add_provider(provider).await;
        Ok(())
    }

    /// Register a provider; replaces any existing provider with the same name.
    pub async fn add_provider(&self, provider: Arc<dyn Provider>) {
        let mut providers = self.providers.write().await;
        // If provider with same name exists, replace it
        if let Some(index) = providers.iter().position(|p| p.name() == provider.name()) {
            providers[index] = provider;
        } else {
            providers.push(provider);
        }
    }

    /// Remove a provider by name (no-op if absent).
    pub async fn remove_provider(&self, name: &str) {
        let mut providers = self.providers.write().await;
        providers.retain(|p| p.name() != name);
    }

    /// First registered provider that claims support for `model`
    /// (registration order is the priority order).
    pub async fn get_provider_for_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
        let providers = self.providers.read().await;
        providers.iter().find(|p| p.supports_model(model)).map(Arc::clone)
    }

    /// Look up a provider by its exact name.
    pub async fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
        let providers = self.providers.read().await;
        providers.iter().find(|p| p.name() == name).map(Arc::clone)
    }

    /// Snapshot of all registered providers (cloned Arcs).
    pub async fn get_all_providers(&self) -> Vec<Arc<dyn Provider>> {
        let providers = self.providers.read().await;
        providers.clone()
    }
}
|
|
||||||
|
|
||||||
// Create placeholder provider implementations
|
|
||||||
pub mod placeholder {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
pub struct PlaceholderProvider {
|
|
||||||
name: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl PlaceholderProvider {
|
|
||||||
pub fn new(name: &str) -> Self {
|
|
||||||
Self { name: name.to_string() }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Provider for PlaceholderProvider {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
&self.name
|
|
||||||
}
|
|
||||||
|
|
||||||
fn supports_model(&self, _model: &str) -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
fn supports_multimodal(&self) -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn chat_completion_stream(
|
|
||||||
&self,
|
|
||||||
_request: UnifiedRequest,
|
|
||||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
|
||||||
Err(AppError::ProviderError(
|
|
||||||
"Streaming not supported for placeholder provider".to_string(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn chat_completion(&self, _request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
|
||||||
Err(AppError::ProviderError(format!(
|
|
||||||
"Provider {} not implemented",
|
|
||||||
self.name
|
|
||||||
)))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn estimate_tokens(&self, _request: &UnifiedRequest) -> Result<u32> {
|
|
||||||
Ok(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calculate_cost(
|
|
||||||
&self,
|
|
||||||
_model: &str,
|
|
||||||
_prompt_tokens: u32,
|
|
||||||
_completion_tokens: u32,
|
|
||||||
_cache_read_tokens: u32,
|
|
||||||
_cache_write_tokens: u32,
|
|
||||||
_registry: &crate::models::registry::ModelRegistry,
|
|
||||||
) -> f64 {
|
|
||||||
0.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,140 +0,0 @@
|
|||||||
use anyhow::Result;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use futures::stream::BoxStream;
|
|
||||||
|
|
||||||
use super::helpers;
|
|
||||||
use super::{ProviderResponse, ProviderStreamChunk};
|
|
||||||
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
|
|
||||||
|
|
||||||
pub struct OllamaProvider {
|
|
||||||
client: reqwest::Client,
|
|
||||||
config: crate::config::OllamaConfig,
|
|
||||||
pricing: Vec<crate::config::ModelPricing>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl OllamaProvider {
|
|
||||||
pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
|
|
||||||
let client = reqwest::Client::builder()
|
|
||||||
.connect_timeout(std::time::Duration::from_secs(5))
|
|
||||||
.timeout(std::time::Duration::from_secs(300))
|
|
||||||
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
|
||||||
.pool_max_idle_per_host(4)
|
|
||||||
.tcp_keepalive(std::time::Duration::from_secs(30))
|
|
||||||
.build()?;
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
client,
|
|
||||||
config: config.clone(),
|
|
||||||
pricing: app_config.pricing.ollama.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl super::Provider for OllamaProvider {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"ollama"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn supports_model(&self, model: &str) -> bool {
|
|
||||||
self.config.models.iter().any(|m| m == model) || model.starts_with("ollama/")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn supports_multimodal(&self) -> bool {
|
|
||||||
true
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn chat_completion(&self, mut request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
|
||||||
// Strip "ollama/" prefix if present for the API call
|
|
||||||
let api_model = request
|
|
||||||
.model
|
|
||||||
.strip_prefix("ollama/")
|
|
||||||
.unwrap_or(&request.model)
|
|
||||||
.to_string();
|
|
||||||
let original_model = request.model.clone();
|
|
||||||
request.model = api_model;
|
|
||||||
|
|
||||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
|
||||||
let body = helpers::build_openai_body(&request, messages_json, false);
|
|
||||||
|
|
||||||
let response = self
|
|
||||||
.client
|
|
||||||
.post(format!("{}/chat/completions", self.config.base_url))
|
|
||||||
.json(&body)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
||||||
|
|
||||||
if !response.status().is_success() {
|
|
||||||
let error_text = response.text().await.unwrap_or_default();
|
|
||||||
return Err(AppError::ProviderError(format!("Ollama API error: {}", error_text)));
|
|
||||||
}
|
|
||||||
|
|
||||||
let resp_json: serde_json::Value = response
|
|
||||||
.json()
|
|
||||||
.await
|
|
||||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
||||||
|
|
||||||
// Ollama also supports "thought" as an alias for reasoning_content
|
|
||||||
let mut result = helpers::parse_openai_response(&resp_json, original_model)?;
|
|
||||||
if result.reasoning_content.is_none() {
|
|
||||||
result.reasoning_content = resp_json["choices"]
|
|
||||||
.get(0)
|
|
||||||
.and_then(|c| c["message"]["thought"].as_str())
|
|
||||||
.map(|s| s.to_string());
|
|
||||||
}
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
|
||||||
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calculate_cost(
|
|
||||||
&self,
|
|
||||||
model: &str,
|
|
||||||
prompt_tokens: u32,
|
|
||||||
completion_tokens: u32,
|
|
||||||
cache_read_tokens: u32,
|
|
||||||
cache_write_tokens: u32,
|
|
||||||
registry: &crate::models::registry::ModelRegistry,
|
|
||||||
) -> f64 {
|
|
||||||
helpers::calculate_cost_with_registry(
|
|
||||||
model,
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
cache_read_tokens,
|
|
||||||
cache_write_tokens,
|
|
||||||
registry,
|
|
||||||
&self.pricing,
|
|
||||||
0.0,
|
|
||||||
0.0,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn chat_completion_stream(
|
|
||||||
&self,
|
|
||||||
mut request: UnifiedRequest,
|
|
||||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
|
||||||
let api_model = request
|
|
||||||
.model
|
|
||||||
.strip_prefix("ollama/")
|
|
||||||
.unwrap_or(&request.model)
|
|
||||||
.to_string();
|
|
||||||
let original_model = request.model.clone();
|
|
||||||
request.model = api_model;
|
|
||||||
|
|
||||||
let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?;
|
|
||||||
let body = helpers::build_openai_body(&request, messages_json, true);
|
|
||||||
|
|
||||||
let es = reqwest_eventsource::EventSource::new(
|
|
||||||
self.client
|
|
||||||
.post(format!("{}/chat/completions", self.config.base_url))
|
|
||||||
.json(&body),
|
|
||||||
)
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
|
||||||
|
|
||||||
// Ollama uses "thought" as an alternative field for reasoning content
|
|
||||||
Ok(helpers::create_openai_stream(es, original_model, Some("thought")))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,509 +0,0 @@
|
|||||||
use anyhow::Result;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use futures::stream::BoxStream;
|
|
||||||
use futures::StreamExt;
|
|
||||||
|
|
||||||
use super::helpers;
|
|
||||||
use super::{ProviderResponse, ProviderStreamChunk};
|
|
||||||
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
|
|
||||||
|
|
||||||
pub struct OpenAIProvider {
|
|
||||||
client: reqwest::Client,
|
|
||||||
config: crate::config::OpenAIConfig,
|
|
||||||
api_key: String,
|
|
||||||
pricing: Vec<crate::config::ModelPricing>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl OpenAIProvider {
|
|
||||||
pub fn new(config: &crate::config::OpenAIConfig, app_config: &AppConfig) -> Result<Self> {
|
|
||||||
let api_key = app_config.get_api_key("openai")?;
|
|
||||||
Self::new_with_key(config, app_config, api_key)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
|
|
||||||
let client = reqwest::Client::builder()
|
|
||||||
.connect_timeout(std::time::Duration::from_secs(5))
|
|
||||||
.timeout(std::time::Duration::from_secs(300))
|
|
||||||
.pool_idle_timeout(std::time::Duration::from_secs(90))
|
|
||||||
.pool_max_idle_per_host(4)
|
|
||||||
.tcp_keepalive(std::time::Duration::from_secs(30))
|
|
||||||
.build()?;
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
client,
|
|
||||||
config: config.clone(),
|
|
||||||
api_key,
|
|
||||||
pricing: app_config.pricing.openai.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl super::Provider for OpenAIProvider {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"openai"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn supports_model(&self, model: &str) -> bool {
|
|
||||||
model.starts_with("gpt-") ||
|
|
||||||
model.starts_with("o1-") ||
|
|
||||||
model.starts_with("o2-") ||
|
|
||||||
model.starts_with("o3-") ||
|
|
||||||
model.starts_with("o4-") ||
|
|
||||||
model.starts_with("o5-") ||
|
|
||||||
model.contains("gpt-5")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn supports_multimodal(&self) -> bool {
|
|
||||||
true
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
|
||||||
// Allow proactive routing to Responses API based on heuristic
|
|
||||||
let model_lc = request.model.to_lowercase();
|
|
||||||
if model_lc.contains("gpt-5") || model_lc.contains("codex") {
|
|
||||||
return self.chat_responses(request).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
|
||||||
let mut body = helpers::build_openai_body(&request, messages_json, false);
|
|
||||||
|
|
||||||
// Transition: Newer OpenAI models (o1, o3, gpt-5) require max_completion_tokens
|
|
||||||
// instead of the legacy max_tokens parameter.
|
|
||||||
if request.model.starts_with("o1-") || request.model.starts_with("o3-") || request.model.contains("gpt-5") {
|
|
||||||
if let Some(max_tokens) = body.as_object_mut().and_then(|obj| obj.remove("max_tokens")) {
|
|
||||||
body["max_completion_tokens"] = max_tokens;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = self
|
|
||||||
.client
|
|
||||||
.post(format!("{}/chat/completions", self.config.base_url))
|
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
||||||
.json(&body)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
||||||
|
|
||||||
if !response.status().is_success() {
|
|
||||||
let status = response.status();
|
|
||||||
let error_text = response.text().await.unwrap_or_default();
|
|
||||||
|
|
||||||
// Read error body to diagnose. If the model requires the Responses
|
|
||||||
// API (v1/responses), retry against that endpoint.
|
|
||||||
if error_text.to_lowercase().contains("v1/responses") || error_text.to_lowercase().contains("only supported in v1/responses") {
|
|
||||||
// Build a simple `input` string by concatenating message parts.
|
|
||||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
|
||||||
let mut inputs: Vec<String> = Vec::new();
|
|
||||||
for m in &messages_json {
|
|
||||||
let role = m["role"].as_str().unwrap_or("");
|
|
||||||
let parts = m.get("content").and_then(|c| c.as_array()).cloned().unwrap_or_default();
|
|
||||||
let mut text_parts = Vec::new();
|
|
||||||
for p in parts {
|
|
||||||
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
|
|
||||||
text_parts.push(t.to_string());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
inputs.push(format!("{}: {}", role, text_parts.join("")));
|
|
||||||
}
|
|
||||||
let input_text = inputs.join("\n");
|
|
||||||
|
|
||||||
let resp = self
|
|
||||||
.client
|
|
||||||
.post(format!("{}/responses", self.config.base_url))
|
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
||||||
.json(&serde_json::json!({ "model": request.model, "input": input_text }))
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
||||||
|
|
||||||
if !resp.status().is_success() {
|
|
||||||
let err = resp.text().await.unwrap_or_default();
|
|
||||||
return Err(AppError::ProviderError(format!("OpenAI Responses API error: {}", err)));
|
|
||||||
}
|
|
||||||
|
|
||||||
let resp_json: serde_json::Value = resp.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
||||||
// Try to normalize: if it's chat-style, use existing parser
|
|
||||||
if resp_json.get("choices").is_some() {
|
|
||||||
return helpers::parse_openai_response(&resp_json, request.model);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Responses API: try to extract text from `output` or `candidates`
|
|
||||||
let mut content_text = String::new();
|
|
||||||
if let Some(output) = resp_json.get("output").and_then(|o| o.as_array()) {
|
|
||||||
if let Some(first) = output.get(0) {
|
|
||||||
if let Some(contents) = first.get("content").and_then(|c| c.as_array()) {
|
|
||||||
for item in contents {
|
|
||||||
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
|
|
||||||
if !content_text.is_empty() { content_text.push_str("\n"); }
|
|
||||||
content_text.push_str(text);
|
|
||||||
} else if let Some(parts) = item.get("parts").and_then(|p| p.as_array()) {
|
|
||||||
for p in parts {
|
|
||||||
if let Some(t) = p.as_str() {
|
|
||||||
if !content_text.is_empty() { content_text.push_str("\n"); }
|
|
||||||
content_text.push_str(t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if content_text.is_empty() {
|
|
||||||
if let Some(cands) = resp_json.get("candidates").and_then(|c| c.as_array()) {
|
|
||||||
if let Some(c0) = cands.get(0) {
|
|
||||||
if let Some(content) = c0.get("content") {
|
|
||||||
if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) {
|
|
||||||
for p in parts {
|
|
||||||
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
|
|
||||||
if !content_text.is_empty() { content_text.push_str("\n"); }
|
|
||||||
content_text.push_str(t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let prompt_tokens = resp_json.get("usage").and_then(|u| u.get("prompt_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
|
|
||||||
let completion_tokens = resp_json.get("usage").and_then(|u| u.get("completion_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
|
|
||||||
let total_tokens = resp_json.get("usage").and_then(|u| u.get("total_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
|
|
||||||
|
|
||||||
return Ok(ProviderResponse {
|
|
||||||
content: content_text,
|
|
||||||
reasoning_content: None,
|
|
||||||
tool_calls: None,
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
total_tokens,
|
|
||||||
cache_read_tokens: 0,
|
|
||||||
cache_write_tokens: 0,
|
|
||||||
model: request.model,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
tracing::error!("OpenAI API error ({}): {}", status, error_text);
|
|
||||||
return Err(AppError::ProviderError(format!("OpenAI API error ({}): {}", status, error_text)));
|
|
||||||
}
|
|
||||||
|
|
||||||
let resp_json: serde_json::Value = response
|
|
||||||
.json()
|
|
||||||
.await
|
|
||||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
||||||
|
|
||||||
helpers::parse_openai_response(&resp_json, request.model)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn chat_responses(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
|
|
||||||
// Build a simple `input` string by concatenating message parts.
|
|
||||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
|
||||||
let mut inputs: Vec<String> = Vec::new();
|
|
||||||
for m in &messages_json {
|
|
||||||
let role = m["role"].as_str().unwrap_or("");
|
|
||||||
let parts = m.get("content").and_then(|c| c.as_array()).cloned().unwrap_or_default();
|
|
||||||
let mut text_parts = Vec::new();
|
|
||||||
for p in parts {
|
|
||||||
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
|
|
||||||
text_parts.push(t.to_string());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
inputs.push(format!("{}: {}", role, text_parts.join("")));
|
|
||||||
}
|
|
||||||
let input_text = inputs.join("\n");
|
|
||||||
|
|
||||||
let resp = self
|
|
||||||
.client
|
|
||||||
.post(format!("{}/responses", self.config.base_url))
|
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
||||||
.json(&serde_json::json!({ "model": request.model, "input": input_text }))
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
||||||
|
|
||||||
if !resp.status().is_success() {
|
|
||||||
let err = resp.text().await.unwrap_or_default();
|
|
||||||
return Err(AppError::ProviderError(format!("OpenAI Responses API error: {}", err)));
|
|
||||||
}
|
|
||||||
|
|
||||||
let resp_json: serde_json::Value = resp.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
|
|
||||||
|
|
||||||
// Normalize Responses API output into ProviderResponse
|
|
||||||
let mut content_text = String::new();
|
|
||||||
if let Some(output) = resp_json.get("output").and_then(|o| o.as_array()) {
|
|
||||||
if let Some(first) = output.get(0) {
|
|
||||||
if let Some(contents) = first.get("content").and_then(|c| c.as_array()) {
|
|
||||||
for item in contents {
|
|
||||||
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
|
|
||||||
if !content_text.is_empty() { content_text.push_str("\n"); }
|
|
||||||
content_text.push_str(text);
|
|
||||||
} else if let Some(parts) = item.get("parts").and_then(|p| p.as_array()) {
|
|
||||||
for p in parts {
|
|
||||||
if let Some(t) = p.as_str() {
|
|
||||||
if !content_text.is_empty() { content_text.push_str("\n"); }
|
|
||||||
content_text.push_str(t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if content_text.is_empty() {
|
|
||||||
if let Some(cands) = resp_json.get("candidates").and_then(|c| c.as_array()) {
|
|
||||||
if let Some(c0) = cands.get(0) {
|
|
||||||
if let Some(content) = c0.get("content") {
|
|
||||||
if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) {
|
|
||||||
for p in parts {
|
|
||||||
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
|
|
||||||
if !content_text.is_empty() { content_text.push_str("\n"); }
|
|
||||||
content_text.push_str(t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let prompt_tokens = resp_json.get("usage").and_then(|u| u.get("prompt_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
|
|
||||||
let completion_tokens = resp_json.get("usage").and_then(|u| u.get("completion_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
|
|
||||||
let total_tokens = resp_json.get("usage").and_then(|u| u.get("total_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
|
|
||||||
|
|
||||||
Ok(ProviderResponse {
|
|
||||||
content: content_text,
|
|
||||||
reasoning_content: None,
|
|
||||||
tool_calls: None,
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
total_tokens,
|
|
||||||
cache_read_tokens: 0,
|
|
||||||
cache_write_tokens: 0,
|
|
||||||
model: request.model,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
|
|
||||||
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn calculate_cost(
|
|
||||||
&self,
|
|
||||||
model: &str,
|
|
||||||
prompt_tokens: u32,
|
|
||||||
completion_tokens: u32,
|
|
||||||
cache_read_tokens: u32,
|
|
||||||
cache_write_tokens: u32,
|
|
||||||
registry: &crate::models::registry::ModelRegistry,
|
|
||||||
) -> f64 {
|
|
||||||
helpers::calculate_cost_with_registry(
|
|
||||||
model,
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
cache_read_tokens,
|
|
||||||
cache_write_tokens,
|
|
||||||
registry,
|
|
||||||
&self.pricing,
|
|
||||||
0.15,
|
|
||||||
0.60,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn chat_completion_stream(
|
|
||||||
&self,
|
|
||||||
request: UnifiedRequest,
|
|
||||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
|
||||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
|
||||||
let mut body = helpers::build_openai_body(&request, messages_json, true);
|
|
||||||
|
|
||||||
// Standard OpenAI cleanup
|
|
||||||
if let Some(obj) = body.as_object_mut() {
|
|
||||||
obj.remove("stream_options");
|
|
||||||
|
|
||||||
// Transition: Newer OpenAI models (o1, o3, gpt-5) require max_completion_tokens
|
|
||||||
if request.model.starts_with("o1-") || request.model.starts_with("o3-") || request.model.contains("gpt-5") {
|
|
||||||
if let Some(max_tokens) = obj.remove("max_tokens") {
|
|
||||||
obj.insert("max_completion_tokens".to_string(), max_tokens);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let url = format!("{}/chat/completions", self.config.base_url);
|
|
||||||
let api_key = self.api_key.clone();
|
|
||||||
let probe_client = self.client.clone();
|
|
||||||
let probe_body = body.clone();
|
|
||||||
let model = request.model.clone();
|
|
||||||
|
|
||||||
let es = reqwest_eventsource::EventSource::new(
|
|
||||||
self.client
|
|
||||||
.post(&url)
|
|
||||||
.header("Authorization", format!("Bearer {}", self.api_key))
|
|
||||||
.json(&body),
|
|
||||||
)
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
|
|
||||||
|
|
||||||
let stream = async_stream::try_stream! {
|
|
||||||
let mut es = es;
|
|
||||||
while let Some(event) = es.next().await {
|
|
||||||
match event {
|
|
||||||
Ok(reqwest_eventsource::Event::Message(msg)) => {
|
|
||||||
if msg.data == "[DONE]" {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
let chunk: serde_json::Value = serde_json::from_str(&msg.data)
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
|
|
||||||
|
|
||||||
if let Some(p_chunk) = helpers::parse_openai_stream_chunk(&chunk, &model, None) {
|
|
||||||
yield p_chunk?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(_) => continue,
|
|
||||||
Err(e) => {
|
|
||||||
// Attempt to probe for the actual error body
|
|
||||||
let probe_resp = probe_client
|
|
||||||
.post(&url)
|
|
||||||
.header("Authorization", format!("Bearer {}", api_key))
|
|
||||||
.json(&probe_body)
|
|
||||||
.send()
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match probe_resp {
|
|
||||||
Ok(r) if !r.status().is_success() => {
|
|
||||||
let status = r.status();
|
|
||||||
let error_body = r.text().await.unwrap_or_default();
|
|
||||||
tracing::error!("OpenAI Stream Error Probe ({}): {}", status, error_body);
|
|
||||||
tracing::debug!("Offending OpenAI Request Body: {}", serde_json::to_string(&probe_body).unwrap_or_default());
|
|
||||||
Err(AppError::ProviderError(format!("OpenAI API error ({}): {}", status, error_body)))?;
|
|
||||||
}
|
|
||||||
Ok(_) => {
|
|
||||||
// Probe returned success? This is unexpected if the original stream failed.
|
|
||||||
Err(AppError::ProviderError(format!("Stream error (probe returned 200): {}", e)))?;
|
|
||||||
}
|
|
||||||
Err(probe_err) => {
|
|
||||||
// Probe itself failed
|
|
||||||
tracing::error!("OpenAI Stream Error Probe failed: {}", probe_err);
|
|
||||||
Err(AppError::ProviderError(format!("Stream error (probe failed: {}): {}", probe_err, e)))?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Box::pin(stream))
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn chat_responses_stream(
|
|
||||||
&self,
|
|
||||||
request: UnifiedRequest,
|
|
||||||
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
|
||||||
// Build a simple `input` string by concatenating message parts.
|
|
||||||
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
|
|
||||||
let mut inputs: Vec<String> = Vec::new();
|
|
||||||
for m in &messages_json {
|
|
||||||
let role = m["role"].as_str().unwrap_or("");
|
|
||||||
let parts = m.get("content").and_then(|c| c.as_array()).cloned().unwrap_or_default();
|
|
||||||
let mut text_parts = Vec::new();
|
|
||||||
for p in parts {
|
|
||||||
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
|
|
||||||
text_parts.push(t.to_string());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
inputs.push(format!("{}: {}", role, text_parts.join("")));
|
|
||||||
}
|
|
||||||
let input_text = inputs.join("\n");
|
|
||||||
|
|
||||||
let body = serde_json::json!({
|
|
||||||
"model": request.model,
|
|
||||||
"input": input_text,
|
|
||||||
"stream": true
|
|
||||||
});
|
|
||||||
|
|
||||||
let url = format!("{}/responses", self.config.base_url);
|
|
||||||
let api_key = self.api_key.clone();
|
|
||||||
let model = request.model.clone();
|
|
||||||
|
|
||||||
let es = reqwest_eventsource::EventSource::new(
|
|
||||||
self.client
|
|
||||||
.post(&url)
|
|
||||||
.header("Authorization", format!("Bearer {}", api_key))
|
|
||||||
.json(&body),
|
|
||||||
)
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource for Responses API: {}", e)))?;
|
|
||||||
|
|
||||||
let stream = async_stream::try_stream! {
|
|
||||||
let mut es = es;
|
|
||||||
while let Some(event) = es.next().await {
|
|
||||||
match event {
|
|
||||||
Ok(reqwest_eventsource::Event::Message(msg)) => {
|
|
||||||
if msg.data == "[DONE]" {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
let chunk: serde_json::Value = serde_json::from_str(&msg.data)
|
|
||||||
.map_err(|e| AppError::ProviderError(format!("Failed to parse Responses stream chunk: {}", e)))?;
|
|
||||||
|
|
||||||
// Try standard OpenAI parsing first
|
|
||||||
if let Some(p_chunk) = helpers::parse_openai_stream_chunk(&chunk, &model, None) {
|
|
||||||
yield p_chunk?;
|
|
||||||
} else {
|
|
||||||
// Responses API specific parsing for streaming
|
|
||||||
// Often it follows a similar structure to the non-streaming response but in chunks
|
|
||||||
let mut content = String::new();
|
|
||||||
|
|
||||||
// Check for output[0].content[0].text (similar to non-stream)
|
|
||||||
if let Some(output) = chunk.get("output").and_then(|o| o.as_array()) {
|
|
||||||
if let Some(first) = output.get(0) {
|
|
||||||
if let Some(contents) = first.get("content").and_then(|c| c.as_array()) {
|
|
||||||
for item in contents {
|
|
||||||
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
|
|
||||||
content.push_str(text);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for candidates[0].content.parts[0].text (Gemini-like, which OpenAI sometimes uses for v1/responses)
|
|
||||||
if content.is_empty() {
|
|
||||||
if let Some(cands) = chunk.get("candidates").and_then(|c| c.as_array()) {
|
|
||||||
if let Some(c0) = cands.get(0) {
|
|
||||||
if let Some(content_obj) = c0.get("content") {
|
|
||||||
if let Some(parts) = content_obj.get("parts").and_then(|p| p.as_array()) {
|
|
||||||
for p in parts {
|
|
||||||
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
|
|
||||||
content.push_str(t);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !content.is_empty() {
|
|
||||||
yield ProviderStreamChunk {
|
|
||||||
content,
|
|
||||||
reasoning_content: None,
|
|
||||||
finish_reason: None,
|
|
||||||
tool_calls: None,
|
|
||||||
model: model.clone(),
|
|
||||||
usage: None,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(_) => continue,
|
|
||||||
Err(e) => {
|
|
||||||
Err(AppError::ProviderError(format!("Responses stream error: {}", e)))?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(Box::pin(stream))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,351 +0,0 @@
|
|||||||
//! Rate limiting and circuit breaking for LLM proxy
|
|
||||||
//!
|
|
||||||
//! This module provides:
|
|
||||||
//! 1. Per-client rate limiting using governor crate
|
|
||||||
//! 2. Provider circuit breaking to handle API failures
|
|
||||||
//! 3. Global rate limiting for overall system protection
|
|
||||||
|
|
||||||
use anyhow::Result;
|
|
||||||
use governor::{Quota, RateLimiter, DefaultDirectRateLimiter};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::num::NonZeroU32;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use tokio::sync::RwLock;
|
|
||||||
use tracing::{info, warn};
|
|
||||||
|
|
||||||
type GovRateLimiter = DefaultDirectRateLimiter;
|
|
||||||
|
|
||||||
/// Rate limiter configuration
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct RateLimiterConfig {
|
|
||||||
/// Requests per minute per client
|
|
||||||
pub requests_per_minute: u32,
|
|
||||||
/// Burst size (maximum burst capacity)
|
|
||||||
pub burst_size: u32,
|
|
||||||
/// Global requests per minute (across all clients)
|
|
||||||
pub global_requests_per_minute: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for RateLimiterConfig {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
requests_per_minute: 60, // 1 request per second per client
|
|
||||||
burst_size: 10, // Allow bursts of up to 10 requests
|
|
||||||
global_requests_per_minute: 600, // 10 requests per second globally
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Circuit breaker state
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
|
||||||
pub enum CircuitState {
|
|
||||||
Closed, // Normal operation
|
|
||||||
Open, // Circuit is open, requests fail fast
|
|
||||||
HalfOpen, // Testing if service has recovered
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Circuit breaker configuration.
///
/// Governs the Closed -> Open -> HalfOpen -> Closed transitions implemented
/// by `ProviderCircuitBreaker`.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of failures (within the failure window) that opens the circuit.
    pub failure_threshold: u32,
    /// Sliding window for failure counting, in seconds. Failures older than
    /// this reset the count (see `ProviderCircuitBreaker::record_failure`).
    pub failure_window_secs: u64,
    /// Seconds to stay open before allowing a half-open probe request.
    pub reset_timeout_secs: u64,
    /// Consecutive successes in half-open state required to close the circuit.
    pub success_threshold: u32,
}
|
|
||||||
|
|
||||||
impl Default for CircuitBreakerConfig {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
failure_threshold: 5, // 5 failures
|
|
||||||
failure_window_secs: 60, // within 60 seconds
|
|
||||||
reset_timeout_secs: 30, // wait 30 seconds before half-open
|
|
||||||
success_threshold: 3, // 3 successes to close circuit
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/// Circuit breaker for a single upstream provider.
///
/// Not internally synchronized: all mutation goes through `&mut self`, and
/// `RateLimitManager` wraps instances in a `tokio::sync::RwLock`ed map.
#[derive(Debug)]
pub struct ProviderCircuitBreaker {
    // Current breaker state; see `CircuitState`.
    state: CircuitState,
    // Failures observed inside the current failure window.
    failure_count: u32,
    // Consecutive successes observed while half-open.
    success_count: u32,
    // Instant of the most recent failure; `None` after a success in Closed state.
    last_failure_time: Option<std::time::Instant>,
    // Instant of the most recent state transition; used for the open->half-open
    // reset timeout.
    last_state_change: std::time::Instant,
    // Thresholds and timeouts controlling transitions.
    config: CircuitBreakerConfig,
}
|
|
||||||
|
|
||||||
impl ProviderCircuitBreaker {
    /// Create a breaker in the Closed (healthy) state with zeroed counters.
    pub fn new(config: CircuitBreakerConfig) -> Self {
        Self {
            state: CircuitState::Closed,
            failure_count: 0,
            success_count: 0,
            last_failure_time: None,
            last_state_change: std::time::Instant::now(),
            config,
        }
    }

    /// Check if a request is allowed, possibly transitioning Open -> HalfOpen.
    ///
    /// Returns `false` only while Open and inside the reset timeout. Note that
    /// HalfOpen admits every request, not just a single probe — concurrent
    /// callers may all pass through until enough successes close the circuit.
    pub fn allow_request(&mut self) -> bool {
        match self.state {
            CircuitState::Closed => true,
            CircuitState::Open => {
                // Check if reset timeout has passed since the circuit opened.
                let elapsed = self.last_state_change.elapsed();
                if elapsed.as_secs() >= self.config.reset_timeout_secs {
                    self.state = CircuitState::HalfOpen;
                    self.last_state_change = std::time::Instant::now();
                    info!("Circuit breaker transitioning to half-open state");
                    true
                } else {
                    false
                }
            }
            CircuitState::HalfOpen => true,
        }
    }

    /// Record a successful request.
    ///
    /// In Closed state this clears the failure window; in HalfOpen it counts
    /// toward `success_threshold` and closes the circuit when reached.
    pub fn record_success(&mut self) {
        match self.state {
            CircuitState::Closed => {
                // Reset failure count on success so old failures don't accumulate.
                self.failure_count = 0;
                self.last_failure_time = None;
            }
            CircuitState::HalfOpen => {
                self.success_count += 1;
                if self.success_count >= self.config.success_threshold {
                    self.state = CircuitState::Closed;
                    self.success_count = 0;
                    self.failure_count = 0;
                    self.last_state_change = std::time::Instant::now();
                    info!("Circuit breaker closed after successful requests");
                }
            }
            CircuitState::Open => {
                // Should not happen (Open rejects requests), but handle gracefully.
            }
        }
    }

    /// Record a failed request.
    ///
    /// Opens the circuit once `failure_threshold` failures accumulate within
    /// `failure_window_secs` while Closed; any failure while HalfOpen
    /// immediately re-opens the circuit.
    pub fn record_failure(&mut self) {
        let now = std::time::Instant::now();

        // Check if the failure window has expired since the previous failure.
        if let Some(last_failure) = self.last_failure_time
            && now.duration_since(last_failure).as_secs() > self.config.failure_window_secs
        {
            // Reset failure count if window expired — only recent failures count.
            self.failure_count = 0;
        }

        self.failure_count += 1;
        self.last_failure_time = Some(now);

        if self.failure_count >= self.config.failure_threshold && self.state == CircuitState::Closed {
            self.state = CircuitState::Open;
            self.last_state_change = now;
            warn!("Circuit breaker opened due to {} failures", self.failure_count);
        } else if self.state == CircuitState::HalfOpen {
            // Failure in half-open state: the provider is still unhealthy,
            // go back to open and restart the success count.
            self.state = CircuitState::Open;
            self.success_count = 0;
            self.last_state_change = now;
            warn!("Circuit breaker re-opened after failure in half-open state");
        }
    }

    /// Get the current breaker state (read-only snapshot).
    pub fn state(&self) -> CircuitState {
        self.state
    }
}
|
|
||||||
|
|
||||||
/// Rate limiting and circuit breaking manager.
///
/// Owns one governor rate limiter per client (created lazily), one global
/// limiter shared by all clients, and one circuit breaker per provider
/// (also created lazily). Cheap to share via `Arc` (see `AppState`).
#[derive(Debug)]
pub struct RateLimitManager {
    // Per-client limiters, keyed by client_id. NOTE(review): entries are
    // never evicted, so this map grows with the number of distinct client
    // ids ever seen — confirm that's acceptable for the deployment.
    client_buckets: Arc<RwLock<HashMap<String, GovRateLimiter>>>,
    // System-wide limiter checked before any per-client limiter.
    global_bucket: Arc<GovRateLimiter>,
    // Per-provider circuit breakers, keyed by provider name.
    circuit_breakers: Arc<RwLock<HashMap<String, ProviderCircuitBreaker>>>,
    // Template for newly created per-client limiters.
    config: RateLimiterConfig,
    // Template for newly created circuit breakers.
    circuit_config: CircuitBreakerConfig,
}
|
|
||||||
|
|
||||||
impl RateLimitManager {
|
|
||||||
pub fn new(config: RateLimiterConfig, circuit_config: CircuitBreakerConfig) -> Self {
|
|
||||||
// Create global rate limiter quota
|
|
||||||
let global_quota = Quota::per_minute(
|
|
||||||
NonZeroU32::new(config.global_requests_per_minute).expect("global_requests_per_minute must be positive")
|
|
||||||
)
|
|
||||||
.allow_burst(NonZeroU32::new(config.burst_size).expect("burst_size must be positive"));
|
|
||||||
let global_bucket = RateLimiter::direct(global_quota);
|
|
||||||
|
|
||||||
Self {
|
|
||||||
client_buckets: Arc::new(RwLock::new(HashMap::new())),
|
|
||||||
global_bucket: Arc::new(global_bucket),
|
|
||||||
circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
|
|
||||||
config,
|
|
||||||
circuit_config,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if a client request is allowed
|
|
||||||
pub async fn check_client_request(&self, client_id: &str) -> Result<bool> {
|
|
||||||
// Check global rate limit first (1 token per request)
|
|
||||||
if self.global_bucket.check().is_err() {
|
|
||||||
warn!("Global rate limit exceeded");
|
|
||||||
return Ok(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check client-specific rate limit
|
|
||||||
let mut buckets = self.client_buckets.write().await;
|
|
||||||
let bucket = buckets.entry(client_id.to_string()).or_insert_with(|| {
|
|
||||||
let quota = Quota::per_minute(
|
|
||||||
NonZeroU32::new(self.config.requests_per_minute).expect("requests_per_minute must be positive")
|
|
||||||
)
|
|
||||||
.allow_burst(NonZeroU32::new(self.config.burst_size).expect("burst_size must be positive"));
|
|
||||||
RateLimiter::direct(quota)
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(bucket.check().is_ok())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if provider requests are allowed (circuit breaker)
|
|
||||||
pub async fn check_provider_request(&self, provider_name: &str) -> Result<bool> {
|
|
||||||
let mut breakers = self.circuit_breakers.write().await;
|
|
||||||
let breaker = breakers
|
|
||||||
.entry(provider_name.to_string())
|
|
||||||
.or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone()));
|
|
||||||
|
|
||||||
Ok(breaker.allow_request())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Record provider success
|
|
||||||
pub async fn record_provider_success(&self, provider_name: &str) {
|
|
||||||
let mut breakers = self.circuit_breakers.write().await;
|
|
||||||
if let Some(breaker) = breakers.get_mut(provider_name) {
|
|
||||||
breaker.record_success();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Record provider failure
|
|
||||||
pub async fn record_provider_failure(&self, provider_name: &str) {
|
|
||||||
let mut breakers = self.circuit_breakers.write().await;
|
|
||||||
let breaker = breakers
|
|
||||||
.entry(provider_name.to_string())
|
|
||||||
.or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone()));
|
|
||||||
|
|
||||||
breaker.record_failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get provider circuit state
|
|
||||||
pub async fn get_provider_state(&self, provider_name: &str) -> CircuitState {
|
|
||||||
let breakers = self.circuit_breakers.read().await;
|
|
||||||
breakers
|
|
||||||
.get(provider_name)
|
|
||||||
.map(|b| b.state())
|
|
||||||
.unwrap_or(CircuitState::Closed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Axum middleware for rate limiting
|
|
||||||
pub mod middleware {
|
|
||||||
use super::*;
|
|
||||||
use crate::errors::AppError;
|
|
||||||
use crate::state::AppState;
|
|
||||||
use crate::auth::AuthInfo;
|
|
||||||
use axum::{
|
|
||||||
extract::{Request, State},
|
|
||||||
middleware::Next,
|
|
||||||
response::Response,
|
|
||||||
};
|
|
||||||
use sqlx;
|
|
||||||
|
|
||||||
/// Rate limiting middleware
|
|
||||||
pub async fn rate_limit_middleware(
|
|
||||||
State(state): State<AppState>,
|
|
||||||
mut request: Request,
|
|
||||||
next: Next,
|
|
||||||
) -> Result<Response, AppError> {
|
|
||||||
// Extract token synchronously from headers (avoids holding &Request across await)
|
|
||||||
let token = extract_bearer_token(&request);
|
|
||||||
|
|
||||||
// Resolve client_id and populate AuthInfo: DB token lookup, then prefix fallback
|
|
||||||
let auth_info = resolve_auth_info(token, &state).await;
|
|
||||||
let client_id = auth_info.client_id.clone();
|
|
||||||
|
|
||||||
// Check rate limits
|
|
||||||
if !state.rate_limit_manager.check_client_request(&client_id).await? {
|
|
||||||
return Err(AppError::RateLimitError("Rate limit exceeded".to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store AuthInfo in request extensions for extractors and downstream handlers
|
|
||||||
request.extensions_mut().insert(auth_info);
|
|
||||||
|
|
||||||
Ok(next.run(request).await)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Synchronously extract bearer token from request headers
|
|
||||||
fn extract_bearer_token(request: &Request) -> Option<String> {
|
|
||||||
request.headers().get("Authorization")
|
|
||||||
.and_then(|v| v.to_str().ok())
|
|
||||||
.and_then(|s| s.strip_prefix("Bearer "))
|
|
||||||
.map(|t| t.to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Resolve auth info: try DB token first, then fall back to token-prefix derivation
|
|
||||||
async fn resolve_auth_info(token: Option<String>, state: &AppState) -> AuthInfo {
|
|
||||||
if let Some(token) = token {
|
|
||||||
// Try DB token lookup first
|
|
||||||
match sqlx::query_scalar::<_, String>(
|
|
||||||
"UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = TRUE RETURNING client_id",
|
|
||||||
)
|
|
||||||
.bind(&token)
|
|
||||||
.fetch_optional(&state.db_pool)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(Some(cid)) => {
|
|
||||||
return AuthInfo {
|
|
||||||
token,
|
|
||||||
client_id: cid,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
warn!("DB error during token lookup: {}", e);
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback to token-prefix derivation (env tokens / permissive mode)
|
|
||||||
let client_id = format!("client_{}", &token[..8.min(token.len())]);
|
|
||||||
return AuthInfo { token, client_id };
|
|
||||||
}
|
|
||||||
|
|
||||||
// No token — anonymous
|
|
||||||
AuthInfo {
|
|
||||||
token: String::new(),
|
|
||||||
client_id: "anonymous".to_string(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Circuit breaker middleware for provider requests
|
|
||||||
pub async fn circuit_breaker_middleware(provider_name: &str, state: &AppState) -> Result<(), AppError> {
|
|
||||||
if !state.rate_limit_manager.check_provider_request(provider_name).await? {
|
|
||||||
return Err(AppError::ProviderError(format!(
|
|
||||||
"Provider {} is currently unavailable (circuit breaker open)",
|
|
||||||
provider_name
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,456 +0,0 @@
|
|||||||
use axum::{
|
|
||||||
Json, Router,
|
|
||||||
extract::State,
|
|
||||||
response::IntoResponse,
|
|
||||||
response::sse::{Event, Sse},
|
|
||||||
routing::{get, post},
|
|
||||||
};
|
|
||||||
use axum::http::{header, HeaderValue};
|
|
||||||
use tower_http::{
|
|
||||||
limit::RequestBodyLimitLayer,
|
|
||||||
set_header::SetResponseHeaderLayer,
|
|
||||||
};
|
|
||||||
|
|
||||||
use futures::StreamExt;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use uuid::Uuid;
|
|
||||||
use tracing::{info, warn};
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
auth::AuthenticatedClient,
|
|
||||||
errors::AppError,
|
|
||||||
models::{
|
|
||||||
ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage,
|
|
||||||
ChatStreamChoice, ChatStreamDelta, Usage,
|
|
||||||
},
|
|
||||||
rate_limiting,
|
|
||||||
state::AppState,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Build the public API router: chat completions and model listing, wrapped
/// with a request body size limit, security response headers, and the
/// rate-limiting middleware.
pub fn router(state: AppState) -> Router {
    // Security headers applied to every response on this router.
    let csp_header: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
        header::CONTENT_SECURITY_POLICY,
        "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws:;"
            .parse()
            .unwrap(),
    );
    let x_frame_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
        header::X_FRAME_OPTIONS,
        "DENY".parse().unwrap(),
    );
    let x_content_type_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
        header::X_CONTENT_TYPE_OPTIONS,
        "nosniff".parse().unwrap(),
    );
    let strict_transport_security: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
        header::STRICT_TRANSPORT_SECURITY,
        "max-age=31536000; includeSubDomains".parse().unwrap(),
    );

    // Layer order matters in axum: layers wrap the routes bottom-up, so the
    // rate-limit middleware (added last) runs first on incoming requests.
    Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit
        .layer(csp_header)
        .layer(x_frame_options)
        .layer(x_content_type_options)
        .layer(strict_transport_security)
        .layer(axum::middleware::from_fn_with_state(
            state.clone(),
            rate_limiting::middleware::rate_limit_middleware,
        ))
        .with_state(state)
}
|
|
||||||
|
|
||||||
|
|
||||||
/// GET /v1/models — OpenAI-compatible model listing.
///
/// Returns all models from enabled providers so clients like Open WebUI can
/// discover which models are available through the proxy. Models explicitly
/// disabled in the model-config cache are skipped; models with no config
/// entry are treated as enabled.
async fn list_models(
    State(state): State<AppState>,
    _auth: AuthenticatedClient,
) -> Result<Json<serde_json::Value>, AppError> {
    let registry = &state.model_registry;
    let providers = state.provider_manager.get_all_providers().await;

    let mut models = Vec::new();

    for provider in &providers {
        let provider_name = provider.name();

        // Map internal provider names to registry provider IDs (the
        // registry keys Gemini under "google" and Grok under "xai").
        let registry_key = match provider_name {
            "gemini" => "google",
            "grok" => "xai",
            _ => provider_name,
        };

        // Find this provider's models in the registry.
        if let Some(provider_info) = registry.providers.get(registry_key) {
            for (model_id, meta) in &provider_info.models {
                // Skip disabled models via the config cache; a missing cache
                // entry means the model is unconfigured and stays listed.
                if let Some(cfg) = state.model_config_cache.get(model_id).await {
                    if !cfg.enabled {
                        continue;
                    }
                }

                models.push(serde_json::json!({
                    "id": model_id,
                    "object": "model",
                    "created": 0,
                    "owned_by": provider_name,
                    "name": meta.name,
                }));
            }
        }

        // For Ollama, models are configured in the TOML, not the registry.
        // NOTE(review): these entries bypass the enabled/disabled check and
        // carry no "name" field — confirm both are intentional.
        if provider_name == "ollama" {
            for model_id in &state.config.providers.ollama.models {
                models.push(serde_json::json!({
                    "id": model_id,
                    "object": "model",
                    "created": 0,
                    "owned_by": "ollama",
                }));
            }
        }
    }

    Ok(Json(serde_json::json!({
        "object": "list",
        "data": models
    })))
}
|
|
||||||
|
|
||||||
async fn get_model_cost(
|
|
||||||
model: &str,
|
|
||||||
prompt_tokens: u32,
|
|
||||||
completion_tokens: u32,
|
|
||||||
cache_read_tokens: u32,
|
|
||||||
cache_write_tokens: u32,
|
|
||||||
provider: &Arc<dyn crate::providers::Provider>,
|
|
||||||
state: &AppState,
|
|
||||||
) -> f64 {
|
|
||||||
// Check in-memory cache for cost overrides (no SQLite hit)
|
|
||||||
if let Some(cached) = state.model_config_cache.get(model).await {
|
|
||||||
if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
|
|
||||||
// Manual overrides don't have cache-specific rates, so use simple formula
|
|
||||||
return (prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback to provider's registry-based calculation (cache-aware)
|
|
||||||
provider.calculate_cost(model, prompt_tokens, completion_tokens, cache_read_tokens, cache_write_tokens, &state.model_registry)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// POST /v1/chat/completions — OpenAI-compatible chat completion endpoint.
///
/// Flow: validate auth token -> check model enabled + apply model mapping ->
/// route to a provider -> circuit-breaker check -> convert to the unified
/// request format -> dispatch either a streaming (SSE) or non-streaming
/// response, recording provider success/failure and logging the request.
async fn chat_completions(
    State(state): State<AppState>,
    auth: AuthenticatedClient,
    Json(mut request): Json<ChatCompletionRequest>,
) -> Result<axum::response::Response, AppError> {
    let client_id = auth.client_id.clone();
    let token = auth.token.clone();

    // Verify the token if env tokens are configured.
    if !state.auth_tokens.is_empty() && !state.auth_tokens.contains(&token) {
        // If not in env tokens, check if it was a DB token: DB-resolved
        // client ids do not carry the "client_" prefix used by the
        // token-prefix fallback, so a prefixed id here means the token was
        // never validated anywhere.
        if client_id.starts_with("client_") {
            return Err(AppError::AuthError("Invalid authentication token".to_string()));
        }
    }

    let start_time = std::time::Instant::now();
    let model = request.model.clone();

    info!("Chat completion request from client {} for model {}", client_id, model);

    // Check if the model is enabled via the in-memory cache (no SQLite hit).
    let cached_config = state.model_config_cache.get(&model).await;

    // Unconfigured models default to enabled with no mapping.
    let (model_enabled, model_mapping) = match cached_config {
        Some(cfg) => (cfg.enabled, cfg.mapping),
        None => (true, None),
    };

    if !model_enabled {
        return Err(AppError::ValidationError(format!(
            "Model {} is currently disabled",
            model
        )));
    }

    // Apply the model mapping if present (e.g. alias -> concrete model).
    // Note: `model` keeps the client-requested name for logging; routing
    // below uses the (possibly remapped) `request.model`.
    if let Some(target_model) = model_mapping {
        info!("Mapping model {} to {}", model, target_model);
        request.model = target_model;
    }

    // Find the appropriate provider for the (mapped) model.
    let provider = state
        .provider_manager
        .get_provider_for_model(&request.model)
        .await
        .ok_or_else(|| AppError::ProviderError(format!("No provider found for model: {}", request.model)))?;

    let provider_name = provider.name().to_string();

    // Check the circuit breaker for this provider; fails fast when open.
    rate_limiting::middleware::circuit_breaker_middleware(&provider_name, &state).await?;

    // Convert to the unified request format shared by all providers.
    let mut unified_request =
        crate::models::UnifiedRequest::try_from(request).map_err(|e| AppError::ValidationError(e.to_string()))?;

    // Set client_id from authentication.
    unified_request.client_id = client_id.clone();

    // Hydrate images if present (may perform fetching/decoding).
    if unified_request.has_images {
        unified_request
            .hydrate_images()
            .await
            .map_err(|e| AppError::ValidationError(format!("Failed to process images: {}", e)))?;
    }

    let has_images = unified_request.has_images;

    // Measure proxy overhead (time spent before contacting the upstream).
    let proxy_overhead = start_time.elapsed();

    // Check if streaming was requested.
    if unified_request.stream {
        // Estimate prompt tokens now; the stream consumes the request.
        let prompt_tokens = crate::utils::tokens::estimate_request_tokens(&model, &unified_request);

        // Handle the streaming response. Some OpenAI models prefer the
        // Responses API; use the registry heuristic to pick the endpoint.
        let use_responses = provider.name() == "openai"
            && crate::utils::registry::model_prefers_responses(&state.model_registry, &unified_request.model);

        let stream_result = if use_responses {
            provider.chat_responses_stream(unified_request).await
        } else {
            provider.chat_completion_stream(unified_request).await
        };

        match stream_result {
            Ok(stream) => {
                // Record provider success (the stream opened successfully;
                // per-chunk errors are only logged below).
                state.rate_limit_manager.record_provider_success(&provider_name).await;

                info!(
                    "Streaming started for {} (proxy overhead: {}ms)",
                    model,
                    proxy_overhead.as_millis()
                );

                // Wrap with AggregatingStream for token counting and
                // database logging once the stream completes.
                let aggregating_stream = crate::utils::streaming::AggregatingStream::new(
                    stream,
                    crate::utils::streaming::StreamConfig {
                        client_id: client_id.clone(),
                        provider: provider.clone(),
                        model: model.clone(),
                        prompt_tokens,
                        has_images,
                        logger: state.request_logger.clone(),
                        model_registry: state.model_registry.clone(),
                        model_config_cache: state.model_config_cache.clone(),
                    },
                );

                // One id/timestamp pair shared by every chunk of this stream.
                let stream_id = format!("chatcmpl-{}", Uuid::new_v4());
                let stream_created = chrono::Utc::now().timestamp() as u64;
                let stream_id_sse = stream_id.clone();

                // Build an SSE stream that yields events wrapped in Result.
                let stream = async_stream::stream! {
                    let mut aggregator = Box::pin(aggregating_stream);
                    let mut first_chunk = true;

                    while let Some(chunk_result) = aggregator.next().await {
                        match chunk_result {
                            Ok(chunk) => {
                                // OpenAI convention: only the first delta
                                // carries the assistant role.
                                let role = if first_chunk {
                                    first_chunk = false;
                                    Some("assistant".to_string())
                                } else {
                                    None
                                };

                                let response = ChatCompletionStreamResponse {
                                    id: stream_id_sse.clone(),
                                    object: "chat.completion.chunk".to_string(),
                                    created: stream_created,
                                    model: chunk.model.clone(),
                                    choices: vec![ChatStreamChoice {
                                        index: 0,
                                        delta: ChatStreamDelta {
                                            role,
                                            content: Some(chunk.content),
                                            reasoning_content: chunk.reasoning_content,
                                            tool_calls: chunk.tool_calls,
                                        },
                                        finish_reason: chunk.finish_reason,
                                    }],
                                };

                                // Use axum's Event directly, wrapped in Ok.
                                // Serialization failures drop the chunk but
                                // keep the stream alive.
                                match Event::default().json_data(response) {
                                    Ok(event) => yield Ok::<_, crate::errors::AppError>(event),
                                    Err(e) => {
                                        warn!("Failed to serialize SSE: {}", e);
                                    }
                                }
                            }
                            Err(e) => {
                                // Mid-stream provider errors are logged and
                                // skipped; the stream continues.
                                warn!("Stream error: {}", e);
                            }
                        }
                    }

                    // Yield the OpenAI-style [DONE] sentinel at the end.
                    yield Ok::<_, crate::errors::AppError>(Event::default().data("[DONE]"));
                };

                Ok(Sse::new(stream).into_response())
            }
            Err(e) => {
                // Record the provider failure for the circuit breaker.
                state.rate_limit_manager.record_provider_failure(&provider_name).await;

                // Log the failed request.
                // NOTE(review): unlike the non-streaming error path below,
                // this path does not write a RequestLog row — confirm that
                // streaming setup failures should be absent from the DB log.
                let duration = start_time.elapsed();
                warn!("Streaming request failed after {:?}: {}", duration, e);

                Err(e)
            }
        }
    } else {
        // Handle the non-streaming response.
        // Allow provider-specific routing: for OpenAI, some models prefer the
        // Responses API (/v1/responses). Use the model registry heuristic to
        // choose chat_responses vs chat_completion automatically.
        let use_responses = provider.name() == "openai"
            && crate::utils::registry::model_prefers_responses(&state.model_registry, &unified_request.model);

        let result = if use_responses {
            provider.chat_responses(unified_request).await
        } else {
            provider.chat_completion(unified_request).await
        };

        match result {
            Ok(response) => {
                // Record provider success for the circuit breaker.
                state.rate_limit_manager.record_provider_success(&provider_name).await;

                let duration = start_time.elapsed();
                let cost = get_model_cost(
                    &response.model,
                    response.prompt_tokens,
                    response.completion_tokens,
                    response.cache_read_tokens,
                    response.cache_write_tokens,
                    &provider,
                    &state,
                )
                .await;
                // Log the request to the database.
                state.request_logger.log_request(crate::logging::RequestLog {
                    timestamp: chrono::Utc::now(),
                    client_id: client_id.clone(),
                    provider: provider_name.clone(),
                    model: response.model.clone(),
                    prompt_tokens: response.prompt_tokens,
                    completion_tokens: response.completion_tokens,
                    total_tokens: response.total_tokens,
                    cache_read_tokens: response.cache_read_tokens,
                    cache_write_tokens: response.cache_write_tokens,
                    cost,
                    has_images,
                    status: "success".to_string(),
                    error_message: None,
                    duration_ms: duration.as_millis() as u64,
                });

                // Convert ProviderResponse to ChatCompletionResponse.
                let finish_reason = if response.tool_calls.is_some() {
                    "tool_calls".to_string()
                } else {
                    "stop".to_string()
                };

                let chat_response = ChatCompletionResponse {
                    id: format!("chatcmpl-{}", Uuid::new_v4()),
                    object: "chat.completion".to_string(),
                    created: chrono::Utc::now().timestamp() as u64,
                    model: response.model,
                    choices: vec![ChatChoice {
                        index: 0,
                        message: ChatMessage {
                            role: "assistant".to_string(),
                            content: crate::models::MessageContent::Text {
                                content: response.content,
                            },
                            reasoning_content: response.reasoning_content,
                            tool_calls: response.tool_calls,
                            name: None,
                            tool_call_id: None,
                        },
                        finish_reason: Some(finish_reason),
                    }],
                    usage: Some(Usage {
                        prompt_tokens: response.prompt_tokens,
                        completion_tokens: response.completion_tokens,
                        total_tokens: response.total_tokens,
                        // Cache fields are omitted from the payload when zero.
                        cache_read_tokens: if response.cache_read_tokens > 0 { Some(response.cache_read_tokens) } else { None },
                        cache_write_tokens: if response.cache_write_tokens > 0 { Some(response.cache_write_tokens) } else { None },
                    }),
                };

                // Log success with proxy-overhead breakdown. Safe subtraction:
                // `duration` was measured after `proxy_overhead` from the same
                // start instant, so it is always >= proxy_overhead.
                let upstream_ms = duration.as_millis() as u64 - proxy_overhead.as_millis() as u64;
                info!(
                    "Request completed in {:?} (proxy: {}ms, upstream: {}ms)",
                    duration,
                    proxy_overhead.as_millis(),
                    upstream_ms
                );

                Ok(Json(chat_response).into_response())
            }
            Err(e) => {
                // Record the provider failure for the circuit breaker.
                state.rate_limit_manager.record_provider_failure(&provider_name).await;

                // Log the failed request to the database with zeroed usage.
                let duration = start_time.elapsed();
                state.request_logger.log_request(crate::logging::RequestLog {
                    timestamp: chrono::Utc::now(),
                    client_id: client_id.clone(),
                    provider: provider_name.clone(),
                    model: model.clone(),
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                    cache_read_tokens: 0,
                    cache_write_tokens: 0,
                    cost: 0.0,
                    has_images: false,
                    status: "error".to_string(),
                    error_message: Some(e.to_string()),
                    duration_ms: duration.as_millis() as u64,
                });

                warn!("Request failed after {:?}: {}", duration, e);

                Err(e)
            }
        }
    }
}
|
|
||||||
129
src/state/mod.rs
129
src/state/mod.rs
@@ -1,129 +0,0 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use tokio::sync::{broadcast, RwLock};
|
|
||||||
use tracing::warn;
|
|
||||||
|
|
||||||
use crate::{
|
|
||||||
client::ClientManager, config::AppConfig, database::DbPool, logging::RequestLogger,
|
|
||||||
models::registry::ModelRegistry, providers::ProviderManager, rate_limiting::RateLimitManager,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Cached model configuration entry — one row of the `model_configs` table
/// materialized in memory (see `ModelConfigCache`).
#[derive(Debug, Clone)]
pub struct CachedModelConfig {
    /// Whether the model may be served; disabled models are rejected/hidden.
    pub enabled: bool,
    /// Optional alias target: requests for this model are rewritten to the
    /// mapped model id.
    pub mapping: Option<String>,
    /// Manual cost override: dollars per million prompt tokens. Overrides
    /// only apply when both cost fields are `Some`.
    pub prompt_cost_per_m: Option<f64>,
    /// Manual cost override: dollars per million completion tokens.
    pub completion_cost_per_m: Option<f64>,
}
|
|
||||||
|
|
||||||
/// In-memory cache for the `model_configs` table.
/// Refreshes periodically (see `start_refresh_task`) to avoid hitting SQLite
/// on every request; dashboard writes should call `invalidate`.
#[derive(Clone)]
pub struct ModelConfigCache {
    // model id -> cached config; replaced wholesale on each refresh.
    cache: Arc<RwLock<HashMap<String, CachedModelConfig>>>,
    // Pool used by `refresh` to reload the table.
    db_pool: DbPool,
}
|
|
||||||
|
|
||||||
impl ModelConfigCache {
|
|
||||||
pub fn new(db_pool: DbPool) -> Self {
|
|
||||||
Self {
|
|
||||||
cache: Arc::new(RwLock::new(HashMap::new())),
|
|
||||||
db_pool,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Load all model configs from the database into cache
|
|
||||||
pub async fn refresh(&self) {
|
|
||||||
match sqlx::query_as::<_, (String, bool, Option<String>, Option<f64>, Option<f64>)>(
|
|
||||||
"SELECT id, enabled, mapping, prompt_cost_per_m, completion_cost_per_m FROM model_configs",
|
|
||||||
)
|
|
||||||
.fetch_all(&self.db_pool)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(rows) => {
|
|
||||||
let mut map = HashMap::with_capacity(rows.len());
|
|
||||||
for (id, enabled, mapping, prompt_cost, completion_cost) in rows {
|
|
||||||
map.insert(
|
|
||||||
id,
|
|
||||||
CachedModelConfig {
|
|
||||||
enabled,
|
|
||||||
mapping,
|
|
||||||
prompt_cost_per_m: prompt_cost,
|
|
||||||
completion_cost_per_m: completion_cost,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
*self.cache.write().await = map;
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
warn!("Failed to refresh model config cache: {}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a cached model config. Returns None if not in cache (model is unconfigured).
|
|
||||||
pub async fn get(&self, model: &str) -> Option<CachedModelConfig> {
|
|
||||||
self.cache.read().await.get(model).cloned()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invalidate cache — call this after dashboard writes to model_configs
|
|
||||||
pub async fn invalidate(&self) {
|
|
||||||
self.refresh().await;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Start a background task that refreshes the cache every `interval` seconds
|
|
||||||
pub fn start_refresh_task(self, interval_secs: u64) {
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs));
|
|
||||||
loop {
|
|
||||||
interval.tick().await;
|
|
||||||
self.refresh().await;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Shared application state
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct AppState {
|
|
||||||
pub config: Arc<AppConfig>,
|
|
||||||
pub provider_manager: ProviderManager,
|
|
||||||
pub db_pool: DbPool,
|
|
||||||
pub rate_limit_manager: Arc<RateLimitManager>,
|
|
||||||
pub client_manager: Arc<ClientManager>,
|
|
||||||
pub request_logger: Arc<RequestLogger>,
|
|
||||||
pub model_registry: Arc<ModelRegistry>,
|
|
||||||
pub model_config_cache: ModelConfigCache,
|
|
||||||
pub dashboard_tx: broadcast::Sender<serde_json::Value>,
|
|
||||||
pub auth_tokens: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AppState {
|
|
||||||
pub fn new(
|
|
||||||
config: Arc<AppConfig>,
|
|
||||||
provider_manager: ProviderManager,
|
|
||||||
db_pool: DbPool,
|
|
||||||
rate_limit_manager: RateLimitManager,
|
|
||||||
model_registry: ModelRegistry,
|
|
||||||
auth_tokens: Vec<String>,
|
|
||||||
) -> Self {
|
|
||||||
let client_manager = Arc::new(ClientManager::new(db_pool.clone()));
|
|
||||||
let (dashboard_tx, _) = broadcast::channel(100);
|
|
||||||
let request_logger = Arc::new(RequestLogger::new(db_pool.clone(), dashboard_tx.clone()));
|
|
||||||
let model_config_cache = ModelConfigCache::new(db_pool.clone());
|
|
||||||
|
|
||||||
Self {
|
|
||||||
config,
|
|
||||||
provider_manager,
|
|
||||||
db_pool,
|
|
||||||
rate_limit_manager: Arc::new(rate_limit_manager),
|
|
||||||
client_manager,
|
|
||||||
request_logger,
|
|
||||||
model_registry: Arc::new(model_registry),
|
|
||||||
model_config_cache,
|
|
||||||
dashboard_tx,
|
|
||||||
auth_tokens,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,171 +0,0 @@
|
|||||||
use aes_gcm::{
|
|
||||||
aead::{Aead, AeadCore, KeyInit, OsRng},
|
|
||||||
Aes256Gcm, Key, Nonce,
|
|
||||||
};
|
|
||||||
use anyhow::{anyhow, Context, Result};
|
|
||||||
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
|
|
||||||
use std::env;
|
|
||||||
use std::sync::OnceLock;
|
|
||||||
|
|
||||||
static RAW_KEY: OnceLock<[u8; 32]> = OnceLock::new();
|
|
||||||
|
|
||||||
/// Initialize the encryption key from a hex or base64 encoded string.
|
|
||||||
/// Must be called before any encryption/decryption operations.
|
|
||||||
/// Returns error if the key is invalid or already initialized with a different key.
|
|
||||||
pub fn init_with_key(key_str: &str) -> Result<()> {
|
|
||||||
let key_bytes = hex::decode(key_str)
|
|
||||||
.or_else(|_| BASE64.decode(key_str))
|
|
||||||
.context("Encryption key must be hex or base64 encoded")?;
|
|
||||||
if key_bytes.len() != 32 {
|
|
||||||
anyhow::bail!(
|
|
||||||
"Encryption key must be 32 bytes (256 bits), got {} bytes",
|
|
||||||
key_bytes.len()
|
|
||||||
);
|
|
||||||
}
|
|
||||||
let key_array: [u8; 32] = key_bytes.try_into().unwrap(); // safe due to length check
|
|
||||||
// Check if already initialized with same key
|
|
||||||
if let Some(existing) = RAW_KEY.get() {
|
|
||||||
if existing == &key_array {
|
|
||||||
// Same key already initialized, okay
|
|
||||||
return Ok(());
|
|
||||||
} else {
|
|
||||||
anyhow::bail!("Encryption key already initialized with a different key");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Store raw key bytes
|
|
||||||
RAW_KEY
|
|
||||||
.set(key_array)
|
|
||||||
.map_err(|_| anyhow::anyhow!("Encryption key already initialized"))?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Initialize the encryption key from the environment variable `LLM_PROXY__ENCRYPTION_KEY`.
|
|
||||||
/// Must be called before any encryption/decryption operations.
|
|
||||||
/// Panics if the environment variable is missing or invalid.
|
|
||||||
pub fn init_from_env() -> Result<()> {
|
|
||||||
let key_str =
|
|
||||||
env::var("LLM_PROXY__ENCRYPTION_KEY").context("LLM_PROXY__ENCRYPTION_KEY environment variable not set")?;
|
|
||||||
init_with_key(&key_str)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the encryption key bytes, panicking if not initialized.
|
|
||||||
fn get_key() -> &'static [u8; 32] {
|
|
||||||
RAW_KEY
|
|
||||||
.get()
|
|
||||||
.expect("Encryption key not initialized. Call crypto::init_with_key() or crypto::init_from_env() first.")
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Encrypt a plaintext string and return a base64-encoded ciphertext (nonce || ciphertext || tag).
|
|
||||||
pub fn encrypt(plaintext: &str) -> Result<String> {
|
|
||||||
let key = Key::<Aes256Gcm>::from_slice(get_key());
|
|
||||||
let cipher = Aes256Gcm::new(key);
|
|
||||||
let nonce = Aes256Gcm::generate_nonce(&mut OsRng); // 12 bytes
|
|
||||||
let ciphertext = cipher
|
|
||||||
.encrypt(&nonce, plaintext.as_bytes())
|
|
||||||
.map_err(|e| anyhow!("Encryption failed: {}", e))?;
|
|
||||||
// Combine nonce and ciphertext (ciphertext already includes tag)
|
|
||||||
let mut combined = Vec::with_capacity(nonce.len() + ciphertext.len());
|
|
||||||
combined.extend_from_slice(&nonce);
|
|
||||||
combined.extend_from_slice(&ciphertext);
|
|
||||||
Ok(BASE64.encode(combined))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Decrypt a base64-encoded ciphertext (nonce || ciphertext || tag) to a plaintext string.
|
|
||||||
pub fn decrypt(ciphertext_b64: &str) -> Result<String> {
|
|
||||||
let key = Key::<Aes256Gcm>::from_slice(get_key());
|
|
||||||
let cipher = Aes256Gcm::new(key);
|
|
||||||
let combined = BASE64.decode(ciphertext_b64).context("Invalid base64 ciphertext")?;
|
|
||||||
if combined.len() < 12 {
|
|
||||||
anyhow::bail!("Ciphertext too short");
|
|
||||||
}
|
|
||||||
let (nonce_bytes, ciphertext_and_tag) = combined.split_at(12);
|
|
||||||
let nonce = Nonce::from_slice(nonce_bytes);
|
|
||||||
let plaintext_bytes = cipher
|
|
||||||
.decrypt(nonce, ciphertext_and_tag)
|
|
||||||
.map_err(|e| anyhow!("Decryption failed (invalid key or corrupted ciphertext): {}", e))?;
|
|
||||||
String::from_utf8(plaintext_bytes).context("Decrypted bytes are not valid UTF-8")
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
const TEST_KEY_HEX: &str = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f";
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_encrypt_decrypt() {
|
|
||||||
init_with_key(TEST_KEY_HEX).unwrap();
|
|
||||||
let plaintext = "super secret api key";
|
|
||||||
let ciphertext = encrypt(plaintext).unwrap();
|
|
||||||
assert_ne!(ciphertext, plaintext);
|
|
||||||
let decrypted = decrypt(&ciphertext).unwrap();
|
|
||||||
assert_eq!(decrypted, plaintext);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_different_inputs_produce_different_ciphertexts() {
|
|
||||||
init_with_key(TEST_KEY_HEX).unwrap();
|
|
||||||
let plaintext = "same";
|
|
||||||
let cipher1 = encrypt(plaintext).unwrap();
|
|
||||||
let cipher2 = encrypt(plaintext).unwrap();
|
|
||||||
assert_ne!(cipher1, cipher2, "Nonce should make ciphertexts differ");
|
|
||||||
assert_eq!(decrypt(&cipher1).unwrap(), plaintext);
|
|
||||||
assert_eq!(decrypt(&cipher2).unwrap(), plaintext);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_invalid_key_length() {
|
|
||||||
let result = init_with_key("tooshort");
|
|
||||||
assert!(result.is_err());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_init_from_env() {
|
|
||||||
unsafe { std::env::set_var("LLM_PROXY__ENCRYPTION_KEY", TEST_KEY_HEX) };
|
|
||||||
let result = init_from_env();
|
|
||||||
assert!(result.is_ok());
|
|
||||||
// Ensure encryption works
|
|
||||||
let ciphertext = encrypt("test").unwrap();
|
|
||||||
let decrypted = decrypt(&ciphertext).unwrap();
|
|
||||||
assert_eq!(decrypted, "test");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_missing_env_key() {
|
|
||||||
unsafe { std::env::remove_var("LLM_PROXY__ENCRYPTION_KEY") };
|
|
||||||
let result = init_from_env();
|
|
||||||
assert!(result.is_err());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_key_hex_and_base64() {
|
|
||||||
// Hex key works
|
|
||||||
init_with_key(TEST_KEY_HEX).unwrap();
|
|
||||||
// Base64 key (same bytes encoded as base64)
|
|
||||||
let base64_key = BASE64.encode(hex::decode(TEST_KEY_HEX).unwrap());
|
|
||||||
// Re-initialization with same key (different encoding) is allowed
|
|
||||||
let result = init_with_key(&base64_key);
|
|
||||||
assert!(result.is_ok());
|
|
||||||
// Encryption should still work
|
|
||||||
let ciphertext = encrypt("test").unwrap();
|
|
||||||
let decrypted = decrypt(&ciphertext).unwrap();
|
|
||||||
assert_eq!(decrypted, "test");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
#[ignore] // conflicts with global state from other tests
|
|
||||||
fn test_already_initialized() {
|
|
||||||
init_with_key(TEST_KEY_HEX).unwrap();
|
|
||||||
let result = init_with_key(TEST_KEY_HEX);
|
|
||||||
assert!(result.is_ok()); // same key allowed
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
#[ignore] // conflicts with global state from other tests
|
|
||||||
fn test_already_initialized_different_key() {
|
|
||||||
init_with_key(TEST_KEY_HEX).unwrap();
|
|
||||||
let different_key = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e20";
|
|
||||||
let result = init_with_key(different_key);
|
|
||||||
assert!(result.is_err());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
pub mod crypto;
|
|
||||||
pub mod registry;
|
|
||||||
pub mod streaming;
|
|
||||||
pub mod tokens;
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
use crate::models::registry::ModelRegistry;
|
|
||||||
use anyhow::Result;
|
|
||||||
use tracing::info;
|
|
||||||
|
|
||||||
const MODELS_DEV_URL: &str = "https://models.dev/api.json";
|
|
||||||
|
|
||||||
pub async fn fetch_registry() -> Result<ModelRegistry> {
|
|
||||||
info!("Fetching model registry from {}", MODELS_DEV_URL);
|
|
||||||
|
|
||||||
let client = reqwest::Client::builder()
|
|
||||||
.timeout(std::time::Duration::from_secs(10))
|
|
||||||
.build()?;
|
|
||||||
|
|
||||||
let response = client.get(MODELS_DEV_URL).send().await?;
|
|
||||||
|
|
||||||
if !response.status().is_success() {
|
|
||||||
return Err(anyhow::anyhow!("Failed to fetch registry: HTTP {}", response.status()));
|
|
||||||
}
|
|
||||||
|
|
||||||
let registry: ModelRegistry = response.json().await?;
|
|
||||||
info!("Successfully loaded model registry");
|
|
||||||
|
|
||||||
Ok(registry)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Heuristic: decide whether a model should be routed to OpenAI Responses API
|
|
||||||
/// instead of the legacy chat/completions endpoint.
|
|
||||||
///
|
|
||||||
/// Currently this uses simple patterns (codex, gpt-5 series) and also checks
|
|
||||||
/// the loaded registry metadata name for the substring "codex" as a hint.
|
|
||||||
pub fn model_prefers_responses(registry: &ModelRegistry, model: &str) -> bool {
|
|
||||||
let model_lc = model.to_lowercase();
|
|
||||||
|
|
||||||
if model_lc.contains("codex") {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if model_lc.starts_with("gpt-5") || model_lc.contains("gpt-5.") {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(meta) = registry.find_model(model) {
|
|
||||||
if meta.name.to_lowercase().contains("codex") {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
false
|
|
||||||
}
|
|
||||||
@@ -1,323 +0,0 @@
|
|||||||
|
|
||||||
use crate::errors::AppError;
|
|
||||||
use crate::logging::{RequestLog, RequestLogger};
|
|
||||||
use crate::models::ToolCall;
|
|
||||||
use crate::providers::{Provider, ProviderStreamChunk, StreamUsage};
|
|
||||||
use crate::state::ModelConfigCache;
|
|
||||||
use crate::utils::tokens::estimate_completion_tokens;
|
|
||||||
use futures::stream::Stream;
|
|
||||||
use std::pin::Pin;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::task::{Context, Poll};
|
|
||||||
|
|
||||||
/// Configuration for creating an AggregatingStream.
|
|
||||||
pub struct StreamConfig {
|
|
||||||
pub client_id: String,
|
|
||||||
pub provider: Arc<dyn Provider>,
|
|
||||||
pub model: String,
|
|
||||||
pub prompt_tokens: u32,
|
|
||||||
pub has_images: bool,
|
|
||||||
pub logger: Arc<RequestLogger>,
|
|
||||||
pub model_registry: Arc<crate::models::registry::ModelRegistry>,
|
|
||||||
pub model_config_cache: ModelConfigCache,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct AggregatingStream<S> {
|
|
||||||
inner: S,
|
|
||||||
client_id: String,
|
|
||||||
provider: Arc<dyn Provider>,
|
|
||||||
model: String,
|
|
||||||
prompt_tokens: u32,
|
|
||||||
has_images: bool,
|
|
||||||
accumulated_content: String,
|
|
||||||
accumulated_reasoning: String,
|
|
||||||
accumulated_tool_calls: Vec<ToolCall>,
|
|
||||||
/// Real usage data from the provider's final stream chunk (when available).
|
|
||||||
real_usage: Option<StreamUsage>,
|
|
||||||
logger: Arc<RequestLogger>,
|
|
||||||
model_registry: Arc<crate::models::registry::ModelRegistry>,
|
|
||||||
model_config_cache: ModelConfigCache,
|
|
||||||
start_time: std::time::Instant,
|
|
||||||
has_logged: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> AggregatingStream<S>
|
|
||||||
where
|
|
||||||
S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin,
|
|
||||||
{
|
|
||||||
pub fn new(inner: S, config: StreamConfig) -> Self {
|
|
||||||
Self {
|
|
||||||
inner,
|
|
||||||
client_id: config.client_id,
|
|
||||||
provider: config.provider,
|
|
||||||
model: config.model,
|
|
||||||
prompt_tokens: config.prompt_tokens,
|
|
||||||
has_images: config.has_images,
|
|
||||||
accumulated_content: String::new(),
|
|
||||||
accumulated_reasoning: String::new(),
|
|
||||||
accumulated_tool_calls: Vec::new(),
|
|
||||||
real_usage: None,
|
|
||||||
logger: config.logger,
|
|
||||||
model_registry: config.model_registry,
|
|
||||||
model_config_cache: config.model_config_cache,
|
|
||||||
start_time: std::time::Instant::now(),
|
|
||||||
has_logged: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn finalize(&mut self) {
|
|
||||||
if self.has_logged {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
self.has_logged = true;
|
|
||||||
|
|
||||||
let duration = self.start_time.elapsed();
|
|
||||||
let client_id = self.client_id.clone();
|
|
||||||
let provider_name = self.provider.name().to_string();
|
|
||||||
let model = self.model.clone();
|
|
||||||
let logger = self.logger.clone();
|
|
||||||
let provider = self.provider.clone();
|
|
||||||
let estimated_prompt_tokens = self.prompt_tokens;
|
|
||||||
let has_images = self.has_images;
|
|
||||||
let registry = self.model_registry.clone();
|
|
||||||
let config_cache = self.model_config_cache.clone();
|
|
||||||
let real_usage = self.real_usage.take();
|
|
||||||
|
|
||||||
// Estimate completion tokens (including reasoning if present)
|
|
||||||
let estimated_content_tokens = estimate_completion_tokens(&self.accumulated_content, &model);
|
|
||||||
let estimated_reasoning_tokens = if !self.accumulated_reasoning.is_empty() {
|
|
||||||
estimate_completion_tokens(&self.accumulated_reasoning, &model)
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
};
|
|
||||||
|
|
||||||
let estimated_completion = estimated_content_tokens + estimated_reasoning_tokens;
|
|
||||||
|
|
||||||
// Spawn a background task to log the completion
|
|
||||||
tokio::spawn(async move {
|
|
||||||
// Use real usage from the provider when available, otherwise fall back to estimates
|
|
||||||
let (prompt_tokens, completion_tokens, total_tokens, cache_read_tokens, cache_write_tokens) =
|
|
||||||
if let Some(usage) = &real_usage {
|
|
||||||
(
|
|
||||||
usage.prompt_tokens,
|
|
||||||
usage.completion_tokens,
|
|
||||||
usage.total_tokens,
|
|
||||||
usage.cache_read_tokens,
|
|
||||||
usage.cache_write_tokens,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
(
|
|
||||||
estimated_prompt_tokens,
|
|
||||||
estimated_completion,
|
|
||||||
estimated_prompt_tokens + estimated_completion,
|
|
||||||
0u32,
|
|
||||||
0u32,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
// Check in-memory cache for cost overrides (no SQLite hit)
|
|
||||||
let cost = if let Some(cached) = config_cache.get(&model).await {
|
|
||||||
if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
|
|
||||||
// Cost override doesn't have cache-aware pricing, use simple formula
|
|
||||||
(prompt_tokens as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0)
|
|
||||||
} else {
|
|
||||||
provider.calculate_cost(
|
|
||||||
&model,
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
cache_read_tokens,
|
|
||||||
cache_write_tokens,
|
|
||||||
®istry,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
provider.calculate_cost(
|
|
||||||
&model,
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
cache_read_tokens,
|
|
||||||
cache_write_tokens,
|
|
||||||
®istry,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
// Log to database
|
|
||||||
logger.log_request(RequestLog {
|
|
||||||
timestamp: chrono::Utc::now(),
|
|
||||||
client_id: client_id.clone(),
|
|
||||||
provider: provider_name,
|
|
||||||
model,
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
total_tokens,
|
|
||||||
cache_read_tokens,
|
|
||||||
cache_write_tokens,
|
|
||||||
cost,
|
|
||||||
has_images,
|
|
||||||
status: "success".to_string(),
|
|
||||||
error_message: None,
|
|
||||||
duration_ms: duration.as_millis() as u64,
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S> Stream for AggregatingStream<S>
|
|
||||||
where
|
|
||||||
S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin,
|
|
||||||
{
|
|
||||||
type Item = Result<ProviderStreamChunk, AppError>;
|
|
||||||
|
|
||||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
|
||||||
let result = Pin::new(&mut self.inner).poll_next(cx);
|
|
||||||
|
|
||||||
match &result {
|
|
||||||
Poll::Ready(Some(Ok(chunk))) => {
|
|
||||||
self.accumulated_content.push_str(&chunk.content);
|
|
||||||
if let Some(reasoning) = &chunk.reasoning_content {
|
|
||||||
self.accumulated_reasoning.push_str(reasoning);
|
|
||||||
}
|
|
||||||
// Capture real usage from the provider when present (typically on the final chunk)
|
|
||||||
if let Some(usage) = &chunk.usage {
|
|
||||||
self.real_usage = Some(usage.clone());
|
|
||||||
}
|
|
||||||
// Accumulate tool call deltas into complete tool calls
|
|
||||||
if let Some(deltas) = &chunk.tool_calls {
|
|
||||||
for delta in deltas {
|
|
||||||
let idx = delta.index as usize;
|
|
||||||
// Grow the accumulated_tool_calls vec if needed
|
|
||||||
while self.accumulated_tool_calls.len() <= idx {
|
|
||||||
self.accumulated_tool_calls.push(ToolCall {
|
|
||||||
id: String::new(),
|
|
||||||
call_type: "function".to_string(),
|
|
||||||
function: crate::models::FunctionCall {
|
|
||||||
name: String::new(),
|
|
||||||
arguments: String::new(),
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
let tc = &mut self.accumulated_tool_calls[idx];
|
|
||||||
if let Some(id) = &delta.id {
|
|
||||||
tc.id.clone_from(id);
|
|
||||||
}
|
|
||||||
if let Some(ct) = &delta.call_type {
|
|
||||||
tc.call_type.clone_from(ct);
|
|
||||||
}
|
|
||||||
if let Some(f) = &delta.function {
|
|
||||||
if let Some(name) = &f.name {
|
|
||||||
tc.function.name.push_str(name);
|
|
||||||
}
|
|
||||||
if let Some(args) = &f.arguments {
|
|
||||||
tc.function.arguments.push_str(args);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Poll::Ready(Some(Err(_))) => {
|
|
||||||
// If there's an error, we might still want to log what we got so far?
|
|
||||||
// For now, just finalize if we have content
|
|
||||||
if !self.accumulated_content.is_empty() {
|
|
||||||
self.finalize();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Poll::Ready(None) => {
|
|
||||||
self.finalize();
|
|
||||||
}
|
|
||||||
Poll::Pending => {}
|
|
||||||
}
|
|
||||||
|
|
||||||
result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use anyhow::Result;
|
|
||||||
use futures::stream::{self, StreamExt};
|
|
||||||
|
|
||||||
// Simple mock provider for testing
|
|
||||||
struct MockProvider;
|
|
||||||
#[async_trait::async_trait]
|
|
||||||
impl Provider for MockProvider {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"mock"
|
|
||||||
}
|
|
||||||
fn supports_model(&self, _model: &str) -> bool {
|
|
||||||
true
|
|
||||||
}
|
|
||||||
fn supports_multimodal(&self) -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
async fn chat_completion(
|
|
||||||
&self,
|
|
||||||
_req: crate::models::UnifiedRequest,
|
|
||||||
) -> Result<crate::providers::ProviderResponse, AppError> {
|
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
async fn chat_completion_stream(
|
|
||||||
&self,
|
|
||||||
_req: crate::models::UnifiedRequest,
|
|
||||||
) -> Result<futures::stream::BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
|
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
fn estimate_tokens(&self, _req: &crate::models::UnifiedRequest) -> Result<u32> {
|
|
||||||
Ok(10)
|
|
||||||
}
|
|
||||||
fn calculate_cost(&self, _model: &str, _p: u32, _c: u32, _cr: u32, _cw: u32, _r: &crate::models::registry::ModelRegistry) -> f64 {
|
|
||||||
0.05
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_aggregating_stream() {
|
|
||||||
let chunks = vec![
|
|
||||||
Ok(ProviderStreamChunk {
|
|
||||||
content: "Hello".to_string(),
|
|
||||||
reasoning_content: None,
|
|
||||||
finish_reason: None,
|
|
||||||
tool_calls: None,
|
|
||||||
model: "test".to_string(),
|
|
||||||
usage: None,
|
|
||||||
}),
|
|
||||||
Ok(ProviderStreamChunk {
|
|
||||||
content: " World".to_string(),
|
|
||||||
reasoning_content: None,
|
|
||||||
finish_reason: Some("stop".to_string()),
|
|
||||||
tool_calls: None,
|
|
||||||
model: "test".to_string(),
|
|
||||||
usage: None,
|
|
||||||
}),
|
|
||||||
];
|
|
||||||
let inner_stream = stream::iter(chunks);
|
|
||||||
|
|
||||||
let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
|
|
||||||
let (dashboard_tx, _) = tokio::sync::broadcast::channel(16);
|
|
||||||
let logger = Arc::new(RequestLogger::new(pool.clone(), dashboard_tx));
|
|
||||||
let registry = Arc::new(crate::models::registry::ModelRegistry {
|
|
||||||
providers: std::collections::HashMap::new(),
|
|
||||||
});
|
|
||||||
|
|
||||||
let mut agg_stream = AggregatingStream::new(
|
|
||||||
inner_stream,
|
|
||||||
StreamConfig {
|
|
||||||
client_id: "client_1".to_string(),
|
|
||||||
provider: Arc::new(MockProvider),
|
|
||||||
model: "test".to_string(),
|
|
||||||
prompt_tokens: 10,
|
|
||||||
has_images: false,
|
|
||||||
logger,
|
|
||||||
model_registry: registry,
|
|
||||||
model_config_cache: ModelConfigCache::new(pool.clone()),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
while let Some(item) = agg_stream.next().await {
|
|
||||||
assert!(item.is_ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
assert_eq!(agg_stream.accumulated_content, "Hello World");
|
|
||||||
assert!(agg_stream.has_logged);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,69 +0,0 @@
|
|||||||
use crate::models::UnifiedRequest;
|
|
||||||
use tiktoken_rs::get_bpe_from_model;
|
|
||||||
|
|
||||||
/// Count tokens for a given model and text
|
|
||||||
pub fn count_tokens(model: &str, text: &str) -> u32 {
|
|
||||||
// If we can't get the bpe for the model, fallback to a safe default (cl100k_base for GPT-4/o1)
|
|
||||||
let bpe = get_bpe_from_model(model)
|
|
||||||
.unwrap_or_else(|_| tiktoken_rs::cl100k_base().expect("Failed to get cl100k_base encoding"));
|
|
||||||
|
|
||||||
bpe.encode_with_special_tokens(text).len() as u32
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Estimate tokens for a unified request.
|
|
||||||
/// Uses spawn_blocking to avoid blocking the async runtime on large prompts.
|
|
||||||
pub fn estimate_request_tokens(model: &str, request: &UnifiedRequest) -> u32 {
|
|
||||||
let mut total_text = String::new();
|
|
||||||
let msg_count = request.messages.len();
|
|
||||||
|
|
||||||
// Base tokens per message for OpenAI (approximate)
|
|
||||||
let tokens_per_message: u32 = 3;
|
|
||||||
|
|
||||||
for msg in &request.messages {
|
|
||||||
for part in &msg.content {
|
|
||||||
match part {
|
|
||||||
crate::models::ContentPart::Text { text } => {
|
|
||||||
total_text.push_str(text);
|
|
||||||
total_text.push('\n');
|
|
||||||
}
|
|
||||||
crate::models::ContentPart::Image { .. } => {
|
|
||||||
// Vision models usually have a fixed cost or calculation based on size
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Quick heuristic for small inputs (< 1KB) — avoid spawn_blocking overhead
|
|
||||||
if total_text.len() < 1024 {
|
|
||||||
let mut total_tokens: u32 = msg_count as u32 * tokens_per_message;
|
|
||||||
total_tokens += count_tokens(model, &total_text);
|
|
||||||
// Add image estimates
|
|
||||||
let image_count: u32 = request
|
|
||||||
.messages
|
|
||||||
.iter()
|
|
||||||
.flat_map(|m| m.content.iter())
|
|
||||||
.filter(|p| matches!(p, crate::models::ContentPart::Image { .. }))
|
|
||||||
.count() as u32;
|
|
||||||
total_tokens += image_count * 1000;
|
|
||||||
total_tokens += 3; // assistant reply header
|
|
||||||
return total_tokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
// For large inputs, use the fast heuristic (chars / 4) to avoid blocking
|
|
||||||
// the async runtime. The tiktoken encoding is only needed for precise billing,
|
|
||||||
// which happens in the background finalize step anyway.
|
|
||||||
let estimated_text_tokens = (total_text.len() as u32) / 4;
|
|
||||||
let image_count: u32 = request
|
|
||||||
.messages
|
|
||||||
.iter()
|
|
||||||
.flat_map(|m| m.content.iter())
|
|
||||||
.filter(|p| matches!(p, crate::models::ContentPart::Image { .. }))
|
|
||||||
.count() as u32;
|
|
||||||
|
|
||||||
(msg_count as u32 * tokens_per_message) + estimated_text_tokens + (image_count * 1000) + 3
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Estimate tokens for completion text
|
|
||||||
pub fn estimate_completion_tokens(text: &str, model: &str) -> u32 {
|
|
||||||
count_tokens(model, text)
|
|
||||||
}
|
|
||||||
@@ -61,9 +61,9 @@
|
|||||||
--text-white: var(--fg0);
|
--text-white: var(--fg0);
|
||||||
|
|
||||||
/* Borders */
|
/* Borders */
|
||||||
--border-color: var(--bg2);
|
--border-color: var(--bg3);
|
||||||
--border-radius: 8px;
|
--border-radius: 0px;
|
||||||
--border-radius-sm: 4px;
|
--border-radius-sm: 0px;
|
||||||
|
|
||||||
/* Spacing System */
|
/* Spacing System */
|
||||||
--spacing-xs: 0.25rem;
|
--spacing-xs: 0.25rem;
|
||||||
@@ -72,15 +72,15 @@
|
|||||||
--spacing-lg: 1.5rem;
|
--spacing-lg: 1.5rem;
|
||||||
--spacing-xl: 2rem;
|
--spacing-xl: 2rem;
|
||||||
|
|
||||||
/* Shadows */
|
/* Shadows - Retro Block Style */
|
||||||
--shadow-sm: 0 1px 2px 0 rgba(0, 0, 0, 0.2);
|
--shadow-sm: 2px 2px 0px rgba(0, 0, 0, 0.4);
|
||||||
--shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.3);
|
--shadow: 4px 4px 0px rgba(0, 0, 0, 0.5);
|
||||||
--shadow-md: 0 10px 15px -3px rgba(0, 0, 0, 0.4);
|
--shadow-md: 6px 6px 0px rgba(0, 0, 0, 0.6);
|
||||||
--shadow-lg: 0 20px 25px -5px rgba(0, 0, 0, 0.5);
|
--shadow-lg: 8px 8px 0px rgba(0, 0, 0, 0.7);
|
||||||
}
|
}
|
||||||
|
|
||||||
body {
|
body {
|
||||||
font-family: 'Inter', -apple-system, sans-serif;
|
font-family: 'JetBrains Mono', 'Fira Code', 'Courier New', monospace;
|
||||||
background-color: var(--bg-primary);
|
background-color: var(--bg-primary);
|
||||||
color: var(--text-primary);
|
color: var(--text-primary);
|
||||||
line-height: 1.6;
|
line-height: 1.6;
|
||||||
@@ -105,12 +105,12 @@ body {
|
|||||||
|
|
||||||
.login-card {
|
.login-card {
|
||||||
background: var(--bg1);
|
background: var(--bg1);
|
||||||
border-radius: 24px;
|
border-radius: var(--border-radius);
|
||||||
padding: 4rem 2.5rem 3rem;
|
padding: 4rem 2.5rem 3rem;
|
||||||
width: 100%;
|
width: 100%;
|
||||||
max-width: 440px;
|
max-width: 440px;
|
||||||
box-shadow: var(--shadow-lg);
|
box-shadow: var(--shadow-lg);
|
||||||
border: 1px solid var(--bg2);
|
border: 2px solid var(--bg3);
|
||||||
text-align: center;
|
text-align: center;
|
||||||
animation: slideUp 0.6s cubic-bezier(0.34, 1.56, 0.64, 1);
|
animation: slideUp 0.6s cubic-bezier(0.34, 1.56, 0.64, 1);
|
||||||
position: relative;
|
position: relative;
|
||||||
@@ -148,22 +148,54 @@ body {
|
|||||||
width: 80px;
|
width: 80px;
|
||||||
height: 80px;
|
height: 80px;
|
||||||
margin: 0 auto 1.25rem;
|
margin: 0 auto 1.25rem;
|
||||||
border-radius: 16px;
|
|
||||||
background: var(--bg2);
|
|
||||||
display: flex;
|
display: flex;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
justify-content: center;
|
justify-content: center;
|
||||||
color: var(--orange);
|
background: rgba(254, 128, 25, 0.15);
|
||||||
font-size: 2rem;
|
color: var(--primary);
|
||||||
|
border-radius: 12px;
|
||||||
|
font-size: 2.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* GopherGate Logo Icon */
|
||||||
|
.logo-icon-container {
|
||||||
|
width: 60px;
|
||||||
|
height: 60px;
|
||||||
|
background: var(--blue-light);
|
||||||
|
border-radius: 12px;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
box-shadow: var(--shadow);
|
box-shadow: var(--shadow);
|
||||||
|
border: 2px solid var(--fg1);
|
||||||
|
margin: 0 auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.logo-icon-container.small {
|
||||||
|
width: 32px;
|
||||||
|
height: 32px;
|
||||||
|
border-radius: 6px;
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.logo-icon-text {
|
||||||
|
font-family: 'JetBrains Mono', monospace;
|
||||||
|
font-weight: 700;
|
||||||
|
color: var(--bg0);
|
||||||
|
font-size: 1.8rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.logo-icon-container.small .logo-icon-text {
|
||||||
|
font-size: 1rem;
|
||||||
}
|
}
|
||||||
|
|
||||||
.login-header h1 {
|
.login-header h1 {
|
||||||
font-size: 1.75rem;
|
font-size: 2rem;
|
||||||
font-weight: 800;
|
font-weight: 800;
|
||||||
color: var(--fg0);
|
color: var(--primary-light);
|
||||||
margin-bottom: 0.5rem;
|
margin-bottom: 0.5rem;
|
||||||
letter-spacing: -0.025em;
|
letter-spacing: -0.025em;
|
||||||
|
text-transform: uppercase;
|
||||||
}
|
}
|
||||||
|
|
||||||
.login-subtitle {
|
.login-subtitle {
|
||||||
@@ -191,7 +223,7 @@ body {
|
|||||||
color: var(--fg3);
|
color: var(--fg3);
|
||||||
pointer-events: none;
|
pointer-events: none;
|
||||||
transition: all 0.25s ease;
|
transition: all 0.25s ease;
|
||||||
background: var(--bg1);
|
background: transparent;
|
||||||
padding: 0 0.375rem;
|
padding: 0 0.375rem;
|
||||||
z-index: 2;
|
z-index: 2;
|
||||||
font-weight: 500;
|
font-weight: 500;
|
||||||
@@ -202,30 +234,32 @@ body {
|
|||||||
|
|
||||||
.form-group input:focus ~ label,
|
.form-group input:focus ~ label,
|
||||||
.form-group input:not(:placeholder-shown) ~ label {
|
.form-group input:not(:placeholder-shown) ~ label {
|
||||||
top: -0.625rem;
|
top: 0;
|
||||||
left: 0.875rem;
|
left: 0.875rem;
|
||||||
font-size: 0.7rem;
|
font-size: 0.75rem;
|
||||||
color: var(--orange);
|
color: var(--orange);
|
||||||
font-weight: 600;
|
font-weight: 600;
|
||||||
transform: translateY(0);
|
transform: translateY(-50%);
|
||||||
|
background: linear-gradient(180deg, var(--bg1) 50%, var(--bg0) 50%);
|
||||||
}
|
}
|
||||||
|
|
||||||
.form-group input {
|
.form-group input {
|
||||||
padding: 1rem 1.25rem;
|
padding: 1rem 1.25rem;
|
||||||
background: var(--bg0);
|
background: var(--bg0);
|
||||||
border: 2px solid var(--bg3);
|
border: 2px solid var(--bg3);
|
||||||
border-radius: 12px;
|
border-radius: var(--border-radius);
|
||||||
|
font-family: inherit;
|
||||||
font-size: 1rem;
|
font-size: 1rem;
|
||||||
color: var(--fg1);
|
color: var(--fg1);
|
||||||
transition: all 0.3s;
|
transition: all 0.2s;
|
||||||
width: 100%;
|
width: 100%;
|
||||||
box-sizing: border-box;
|
box-sizing: border-box;
|
||||||
}
|
}
|
||||||
|
|
||||||
.form-group input:focus {
|
.form-group input:focus {
|
||||||
border-color: var(--orange);
|
border-color: var(--orange);
|
||||||
box-shadow: 0 0 0 4px rgba(214, 93, 14, 0.2);
|
|
||||||
outline: none;
|
outline: none;
|
||||||
|
box-shadow: 4px 4px 0px rgba(214, 93, 14, 0.4);
|
||||||
}
|
}
|
||||||
|
|
||||||
.login-btn {
|
.login-btn {
|
||||||
@@ -295,6 +329,25 @@ body {
|
|||||||
font-size: 1.125rem;
|
font-size: 1.125rem;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Badges */
|
||||||
|
.badge {
|
||||||
|
display: inline-block;
|
||||||
|
padding: 0.25rem 0.5rem;
|
||||||
|
font-size: 0.75rem;
|
||||||
|
font-weight: 600;
|
||||||
|
line-height: 1;
|
||||||
|
text-align: center;
|
||||||
|
white-space: nowrap;
|
||||||
|
vertical-align: baseline;
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.badge-success { background-color: rgba(152, 151, 26, 0.15); color: var(--green-light); border: 1px solid var(--green); }
|
||||||
|
.badge-info { background-color: rgba(69, 133, 136, 0.15); color: var(--blue-light); border: 1px solid var(--blue); }
|
||||||
|
.badge-warning { background-color: rgba(215, 153, 33, 0.15); color: var(--yellow-light); border: 1px solid var(--yellow); }
|
||||||
|
.badge-danger { background-color: rgba(204, 36, 29, 0.15); color: var(--red-light); border: 1px solid var(--red); }
|
||||||
|
.badge-client { background-color: var(--bg2); color: var(--fg1); border: 1px solid var(--bg3); padding: 2px 6px; font-size: 0.7rem; text-transform: uppercase; }
|
||||||
|
|
||||||
/* Responsive Login */
|
/* Responsive Login */
|
||||||
@media (max-width: 480px) {
|
@media (max-width: 480px) {
|
||||||
.login-card {
|
.login-card {
|
||||||
@@ -373,11 +426,15 @@ body {
|
|||||||
}
|
}
|
||||||
|
|
||||||
.sidebar.collapsed .logo {
|
.sidebar.collapsed .logo {
|
||||||
|
display: flex;
|
||||||
|
}
|
||||||
|
|
||||||
|
.sidebar.collapsed .logo span {
|
||||||
display: none;
|
display: none;
|
||||||
}
|
}
|
||||||
|
|
||||||
.sidebar.collapsed .sidebar-toggle {
|
.sidebar.collapsed .sidebar-toggle {
|
||||||
opacity: 1;
|
margin-left: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
.logo {
|
.logo {
|
||||||
@@ -392,6 +449,7 @@ body {
|
|||||||
white-space: nowrap;
|
white-space: nowrap;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
.sidebar-logo {
|
.sidebar-logo {
|
||||||
width: 32px;
|
width: 32px;
|
||||||
height: 32px;
|
height: 32px;
|
||||||
@@ -586,17 +644,48 @@ body {
|
|||||||
|
|
||||||
/* Main Content Area */
|
/* Main Content Area */
|
||||||
.main-content {
|
.main-content {
|
||||||
margin-left: 260px;
|
padding-left: 260px;
|
||||||
flex: 1;
|
flex: 1;
|
||||||
min-height: 100vh;
|
min-height: 100vh;
|
||||||
transition: all 0.3s;
|
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
|
||||||
display: flex;
|
display: flex;
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
background-color: var(--bg-primary);
|
background-color: var(--bg-primary);
|
||||||
}
|
}
|
||||||
|
|
||||||
.sidebar.collapsed ~ .main-content {
|
.sidebar.collapsed + .main-content {
|
||||||
margin-left: 80px;
|
padding-left: 80px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.top-bar {
|
||||||
|
height: 70px;
|
||||||
|
background: var(--bg0);
|
||||||
|
border-bottom: 1px solid var(--bg2);
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: space-between;
|
||||||
|
padding: 0 var(--spacing-xl);
|
||||||
|
position: sticky;
|
||||||
|
top: 0;
|
||||||
|
z-index: 100;
|
||||||
|
}
|
||||||
|
|
||||||
|
.top-bar .page-title h2 {
|
||||||
|
font-size: 1.25rem;
|
||||||
|
font-weight: 700;
|
||||||
|
color: var(--fg0);
|
||||||
|
}
|
||||||
|
|
||||||
|
.top-bar-actions {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: var(--spacing-lg);
|
||||||
|
}
|
||||||
|
|
||||||
|
.content-body {
|
||||||
|
padding: var(--spacing-xl);
|
||||||
|
flex: 1;
|
||||||
|
position: relative;
|
||||||
}
|
}
|
||||||
|
|
||||||
.top-nav {
|
.top-nav {
|
||||||
@@ -732,11 +821,11 @@ body {
|
|||||||
.stat-change.positive { color: var(--green-light); }
|
.stat-change.positive { color: var(--green-light); }
|
||||||
.stat-change.negative { color: var(--red-light); }
|
.stat-change.negative { color: var(--red-light); }
|
||||||
|
|
||||||
/* Generic Cards */
|
/* Cards */
|
||||||
.card {
|
.card {
|
||||||
background: var(--bg1);
|
background: var(--bg1);
|
||||||
border-radius: var(--border-radius);
|
border-radius: var(--border-radius);
|
||||||
border: 1px solid var(--bg2);
|
border: 1px solid var(--bg3);
|
||||||
box-shadow: var(--shadow-sm);
|
box-shadow: var(--shadow-sm);
|
||||||
margin-bottom: 1.5rem;
|
margin-bottom: 1.5rem;
|
||||||
display: flex;
|
display: flex;
|
||||||
@@ -749,6 +838,15 @@ body {
|
|||||||
display: flex;
|
display: flex;
|
||||||
justify-content: space-between;
|
justify-content: space-between;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
gap: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.card-actions {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.5rem;
|
||||||
|
flex-wrap: wrap;
|
||||||
}
|
}
|
||||||
|
|
||||||
.card-title {
|
.card-title {
|
||||||
@@ -817,25 +915,26 @@ body {
|
|||||||
/* Badges */
|
/* Badges */
|
||||||
.status-badge {
|
.status-badge {
|
||||||
padding: 0.25rem 0.75rem;
|
padding: 0.25rem 0.75rem;
|
||||||
border-radius: 9999px;
|
border-radius: var(--border-radius);
|
||||||
font-size: 0.7rem;
|
font-size: 0.7rem;
|
||||||
font-weight: 700;
|
font-weight: 700;
|
||||||
text-transform: uppercase;
|
text-transform: uppercase;
|
||||||
display: inline-flex;
|
display: inline-flex;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
gap: 0.375rem;
|
gap: 0.375rem;
|
||||||
|
border: 1px solid transparent;
|
||||||
}
|
}
|
||||||
|
|
||||||
.status-badge.online, .status-badge.success { background: rgba(184, 187, 38, 0.2); color: var(--green-light); }
|
.status-badge.online, .status-badge.success { background: rgba(184, 187, 38, 0.2); color: var(--green-light); border-color: rgba(184, 187, 38, 0.4); }
|
||||||
.status-badge.offline, .status-badge.danger { background: rgba(251, 73, 52, 0.2); color: var(--red-light); }
|
.status-badge.offline, .status-badge.danger { background: rgba(251, 73, 52, 0.2); color: var(--red-light); border-color: rgba(251, 73, 52, 0.4); }
|
||||||
.status-badge.warning { background: rgba(250, 189, 47, 0.2); color: var(--yellow-light); }
|
.status-badge.warning { background: rgba(250, 189, 47, 0.2); color: var(--yellow-light); border-color: rgba(250, 189, 47, 0.4); }
|
||||||
|
|
||||||
.badge-client {
|
.badge-client {
|
||||||
background: var(--bg2);
|
background: var(--bg2);
|
||||||
color: var(--blue-light);
|
color: var(--blue-light);
|
||||||
padding: 2px 8px;
|
padding: 2px 8px;
|
||||||
border-radius: 6px;
|
border-radius: var(--border-radius);
|
||||||
font-family: monospace;
|
font-family: inherit;
|
||||||
font-size: 0.85rem;
|
font-size: 0.85rem;
|
||||||
border: 1px solid var(--bg3);
|
border: 1px solid var(--bg3);
|
||||||
}
|
}
|
||||||
@@ -889,7 +988,7 @@ body {
|
|||||||
width: 100%;
|
width: 100%;
|
||||||
background: var(--bg0);
|
background: var(--bg0);
|
||||||
border: 1px solid var(--bg3);
|
border: 1px solid var(--bg3);
|
||||||
border-radius: 8px;
|
border-radius: var(--border-radius);
|
||||||
padding: 0.75rem;
|
padding: 0.75rem;
|
||||||
font-family: inherit;
|
font-family: inherit;
|
||||||
font-size: 0.875rem;
|
font-size: 0.875rem;
|
||||||
@@ -900,7 +999,7 @@ body {
|
|||||||
.form-control input:focus, .form-control textarea:focus, .form-control select:focus {
|
.form-control input:focus, .form-control textarea:focus, .form-control select:focus {
|
||||||
outline: none;
|
outline: none;
|
||||||
border-color: var(--orange);
|
border-color: var(--orange);
|
||||||
box-shadow: 0 0 0 2px rgba(214, 93, 14, 0.2);
|
box-shadow: 2px 2px 0px rgba(214, 93, 14, 0.4);
|
||||||
}
|
}
|
||||||
|
|
||||||
.btn {
|
.btn {
|
||||||
@@ -908,21 +1007,27 @@ body {
|
|||||||
align-items: center;
|
align-items: center;
|
||||||
gap: 0.5rem;
|
gap: 0.5rem;
|
||||||
padding: 0.625rem 1.25rem;
|
padding: 0.625rem 1.25rem;
|
||||||
border-radius: 8px;
|
border-radius: var(--border-radius);
|
||||||
font-weight: 600;
|
font-weight: 600;
|
||||||
font-size: 0.875rem;
|
font-size: 0.875rem;
|
||||||
cursor: pointer;
|
cursor: pointer;
|
||||||
transition: all 0.2s;
|
transition: all 0.1s;
|
||||||
border: 1px solid transparent;
|
border: 1px solid transparent;
|
||||||
|
text-transform: uppercase;
|
||||||
}
|
}
|
||||||
|
|
||||||
.btn-primary { background: var(--orange); color: var(--bg0); }
|
.btn:active {
|
||||||
|
transform: translate(2px, 2px);
|
||||||
|
box-shadow: none !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
.btn-primary { background: var(--orange); color: var(--bg0); box-shadow: 2px 2px 0px var(--bg4); }
|
||||||
.btn-primary:hover { background: var(--orange-light); }
|
.btn-primary:hover { background: var(--orange-light); }
|
||||||
|
|
||||||
.btn-secondary { background: var(--bg2); border-color: var(--bg3); color: var(--fg1); }
|
.btn-secondary { background: var(--bg2); border-color: var(--bg3); color: var(--fg1); box-shadow: 2px 2px 0px var(--bg0); }
|
||||||
.btn-secondary:hover { background: var(--bg3); color: var(--fg0); }
|
.btn-secondary:hover { background: var(--bg3); color: var(--fg0); }
|
||||||
|
|
||||||
.btn-danger { background: var(--red); color: var(--fg0); }
|
.btn-danger { background: var(--red); color: var(--fg0); box-shadow: 2px 2px 0px var(--bg4); }
|
||||||
.btn-danger:hover { background: var(--red-light); }
|
.btn-danger:hover { background: var(--red-light); }
|
||||||
|
|
||||||
/* Small inline action buttons (edit, delete, copy) */
|
/* Small inline action buttons (edit, delete, copy) */
|
||||||
@@ -981,13 +1086,13 @@ body {
|
|||||||
|
|
||||||
.modal-content {
|
.modal-content {
|
||||||
background: var(--bg1);
|
background: var(--bg1);
|
||||||
border-radius: 16px;
|
border-radius: var(--border-radius);
|
||||||
width: 90%;
|
width: 90%;
|
||||||
max-width: 500px;
|
max-width: 500px;
|
||||||
box-shadow: var(--shadow-lg);
|
box-shadow: var(--shadow-lg);
|
||||||
border: 1px solid var(--bg3);
|
border: 2px solid var(--bg3);
|
||||||
transform: translateY(20px);
|
transform: translateY(20px);
|
||||||
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
|
transition: all 0.2s;
|
||||||
}
|
}
|
||||||
|
|
||||||
.modal.active .modal-content {
|
.modal.active .modal-content {
|
||||||
@@ -1029,6 +1134,53 @@ body {
|
|||||||
gap: 0.75rem;
|
gap: 0.75rem;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Connection Status Indicator */
|
||||||
|
.status-indicator {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.75rem;
|
||||||
|
padding: 0.5rem 0.875rem;
|
||||||
|
background: var(--bg1);
|
||||||
|
border: 1px solid var(--bg3);
|
||||||
|
border-radius: 6px;
|
||||||
|
font-size: 0.8rem;
|
||||||
|
font-weight: 600;
|
||||||
|
color: var(--fg3);
|
||||||
|
transition: all 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-dot {
|
||||||
|
width: 8px;
|
||||||
|
height: 8px;
|
||||||
|
border-radius: 50%;
|
||||||
|
background: var(--fg4);
|
||||||
|
position: relative;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-dot.connected {
|
||||||
|
background: var(--green-light);
|
||||||
|
box-shadow: 0 0 0 0 rgba(184, 187, 38, 0.4);
|
||||||
|
animation: status-pulse 2s infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-dot.disconnected {
|
||||||
|
background: var(--red-light);
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-dot.connecting {
|
||||||
|
background: var(--yellow-light);
|
||||||
|
}
|
||||||
|
|
||||||
|
.status-dot.error {
|
||||||
|
background: var(--red);
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes status-pulse {
|
||||||
|
0% { box-shadow: 0 0 0 0 rgba(184, 187, 38, 0.4); }
|
||||||
|
70% { box-shadow: 0 0 0 6px rgba(184, 187, 38, 0); }
|
||||||
|
100% { box-shadow: 0 0 0 0 rgba(184, 187, 38, 0); }
|
||||||
|
}
|
||||||
|
|
||||||
/* WebSocket Dot Pulse */
|
/* WebSocket Dot Pulse */
|
||||||
@keyframes ws-pulse {
|
@keyframes ws-pulse {
|
||||||
0% { box-shadow: 0 0 0 0 rgba(184, 187, 38, 0.4); }
|
0% { box-shadow: 0 0 0 0 rgba(184, 187, 38, 0.4); }
|
||||||
|
|||||||
BIN
static/favicon.ico
Normal file
BIN
static/favicon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1002 B |
@@ -3,50 +3,38 @@
|
|||||||
<head>
|
<head>
|
||||||
<meta charset="UTF-8">
|
<meta charset="UTF-8">
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
<title>LLM Proxy Gateway - Admin Dashboard</title>
|
<title>GopherGate - Admin Dashboard</title>
|
||||||
<link rel="stylesheet" href="/css/dashboard.css?v=7">
|
<link rel="stylesheet" href="/css/dashboard.css?v=11">
|
||||||
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
|
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
|
||||||
<link rel="icon" href="img/logo-icon.png" type="image/png" sizes="any">
|
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||||
<link rel="apple-touch-icon" href="img/logo-icon.png">
|
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||||||
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet">
|
<link href="https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;700&family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
|
||||||
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
|
|
||||||
<script src="https://cdn.jsdelivr.net/npm/luxon@3.4.4/build/global/luxon.min.js"></script>
|
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body class="gruvbox-dark">
|
||||||
<!-- Login Screen -->
|
<!-- Auth Page -->
|
||||||
<div id="login-screen" class="login-container">
|
<div id="auth-page" class="login-container">
|
||||||
<div class="login-card">
|
<div class="login-card">
|
||||||
<div class="login-header">
|
<div class="login-header">
|
||||||
<img src="img/logo-full.png" alt="LLM Proxy Logo" class="login-logo" onerror="this.style.display='none'; this.nextElementSibling.style.display='block';">
|
<div class="logo-icon-container">
|
||||||
<i class="fas fa-robot login-logo-fallback" style="display: none;"></i>
|
<span class="logo-icon-text">GG</span>
|
||||||
<h1>LLM Proxy Gateway</h1>
|
</div>
|
||||||
<p class="login-subtitle">Admin Dashboard</p>
|
<h1>GopherGate</h1>
|
||||||
|
<p class="login-subtitle">Secure LLM Gateway & Management</p>
|
||||||
</div>
|
</div>
|
||||||
<form id="login-form" class="login-form">
|
<form id="login-form">
|
||||||
<div class="form-group">
|
<div class="form-control">
|
||||||
<input type="text" id="username" name="username" placeholder=" " required>
|
<label for="username">Username</label>
|
||||||
<label for="username">
|
<input type="text" id="username" name="username" required autocomplete="username">
|
||||||
<i class="fas fa-user"></i> Username
|
|
||||||
</label>
|
|
||||||
</div>
|
</div>
|
||||||
<div class="form-group">
|
<div class="form-control">
|
||||||
<input type="password" id="password" name="password" placeholder=" " required>
|
<label for="password">Password</label>
|
||||||
<label for="password">
|
<input type="password" id="password" name="password" required autocomplete="current-password">
|
||||||
<i class="fas fa-lock"></i> Password
|
|
||||||
</label>
|
|
||||||
</div>
|
|
||||||
<div class="form-group">
|
|
||||||
<button type="submit" class="login-btn">
|
|
||||||
<i class="fas fa-sign-in-alt"></i> Sign In
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
<div class="login-footer">
|
|
||||||
<p>Default: <code>admin</code> / <code>admin</code> (change in Settings > Security)</p>
|
|
||||||
</div>
|
</div>
|
||||||
|
<button type="submit" id="login-btn" class="btn btn-primary btn-block">Sign In</button>
|
||||||
</form>
|
</form>
|
||||||
<div id="login-error" class="error-message" style="display: none;">
|
<div id="login-error" class="error-message" style="display: none;">
|
||||||
<i class="fas fa-exclamation-circle"></i>
|
<i class="fas fa-exclamation-circle"></i>
|
||||||
<span>Invalid credentials. Please try again.</span>
|
<span></span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -57,9 +45,10 @@
|
|||||||
<nav class="sidebar">
|
<nav class="sidebar">
|
||||||
<div class="sidebar-header">
|
<div class="sidebar-header">
|
||||||
<div class="logo">
|
<div class="logo">
|
||||||
<img src="img/logo-icon.png" alt="LLM Proxy" class="sidebar-logo" onerror="this.style.display='none'; this.nextElementSibling.style.display='inline-block';">
|
<div class="logo-icon-container small">
|
||||||
<i class="fas fa-shield-alt logo-fallback" style="display: none;"></i>
|
<span class="logo-icon-text">GG</span>
|
||||||
<span>LLM Proxy</span>
|
</div>
|
||||||
|
<span>GopherGate</span>
|
||||||
</div>
|
</div>
|
||||||
<button class="sidebar-toggle" id="sidebar-toggle">
|
<button class="sidebar-toggle" id="sidebar-toggle">
|
||||||
<i class="fas fa-bars"></i>
|
<i class="fas fa-bars"></i>
|
||||||
@@ -69,68 +58,74 @@
|
|||||||
<div class="sidebar-menu">
|
<div class="sidebar-menu">
|
||||||
<div class="menu-section">
|
<div class="menu-section">
|
||||||
<h3 class="menu-title">MAIN</h3>
|
<h3 class="menu-title">MAIN</h3>
|
||||||
<a href="#overview" class="menu-item active" data-page="overview" data-tooltip="Dashboard Overview">
|
<ul class="menu-list">
|
||||||
<i class="fas fa-th-large"></i>
|
<li class="menu-item active" data-page="overview">
|
||||||
<span>Overview</span>
|
<i class="fas fa-th-large"></i>
|
||||||
</a>
|
<span>Overview</span>
|
||||||
<a href="#analytics" class="menu-item" data-page="analytics" data-tooltip="Usage Analytics">
|
</li>
|
||||||
<i class="fas fa-chart-line"></i>
|
<li class="menu-item" data-page="analytics">
|
||||||
<span>Analytics</span>
|
<i class="fas fa-chart-bar"></i>
|
||||||
</a>
|
<span>Analytics</span>
|
||||||
<a href="#costs" class="menu-item" data-page="costs" data-tooltip="Cost Tracking">
|
</li>
|
||||||
<i class="fas fa-dollar-sign"></i>
|
<li class="menu-item" data-page="costs">
|
||||||
<span>Cost Management</span>
|
<i class="fas fa-dollar-sign"></i>
|
||||||
</a>
|
<span>Costs & Billing</span>
|
||||||
|
</li>
|
||||||
|
</ul>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="menu-section">
|
<div class="menu-section">
|
||||||
<h3 class="menu-title">MANAGEMENT</h3>
|
<h3 class="menu-title">MANAGEMENT</h3>
|
||||||
<a href="#clients" class="menu-item" data-page="clients" data-tooltip="API Clients">
|
<ul class="menu-list">
|
||||||
<i class="fas fa-users"></i>
|
<li class="menu-item" data-page="clients">
|
||||||
<span>Client Management</span>
|
<i class="fas fa-users"></i>
|
||||||
</a>
|
<span>Clients</span>
|
||||||
<a href="#providers" class="menu-item" data-page="providers" data-tooltip="Model Providers">
|
</li>
|
||||||
<i class="fas fa-server"></i>
|
<li class="menu-item" data-page="providers">
|
||||||
<span>Providers</span>
|
<i class="fas fa-server"></i>
|
||||||
</a>
|
<span>Providers</span>
|
||||||
<a href="#models" class="menu-item" data-page="models" data-tooltip="Manage Models">
|
</li>
|
||||||
<i class="fas fa-cube"></i>
|
<li class="menu-item" data-page="models">
|
||||||
<span>Models</span>
|
<i class="fas fa-brain"></i>
|
||||||
</a>
|
<span>Models</span>
|
||||||
<a href="#monitoring" class="menu-item" data-page="monitoring" data-tooltip="Live Monitoring">
|
</li>
|
||||||
<i class="fas fa-heartbeat"></i>
|
</ul>
|
||||||
<span>Real-time Monitoring</span>
|
|
||||||
</a>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="menu-section">
|
<div class="menu-section">
|
||||||
<h3 class="menu-title">SYSTEM</h3>
|
<h3 class="menu-title">SYSTEM</h3>
|
||||||
<a href="#users" class="menu-item admin-only" data-page="users" data-tooltip="User Accounts">
|
<ul class="menu-list">
|
||||||
<i class="fas fa-user-shield"></i>
|
<li class="menu-item" data-page="monitoring">
|
||||||
<span>User Management</span>
|
<i class="fas fa-activity"></i>
|
||||||
</a>
|
<span>Live Monitoring</span>
|
||||||
<a href="#settings" class="menu-item admin-only" data-page="settings" data-tooltip="System Settings">
|
</li>
|
||||||
<i class="fas fa-cog"></i>
|
<li class="menu-item" data-page="logs">
|
||||||
<span>Settings</span>
|
<i class="fas fa-list-alt"></i>
|
||||||
</a>
|
<span>Logs</span>
|
||||||
<a href="#logs" class="menu-item" data-page="logs" data-tooltip="System Logs">
|
</li>
|
||||||
<i class="fas fa-list-alt"></i>
|
<li class="menu-item" data-page="users">
|
||||||
<span>System Logs</span>
|
<i class="fas fa-user-shield"></i>
|
||||||
</a>
|
<span>Admin Users</span>
|
||||||
|
</li>
|
||||||
|
<li class="menu-item" data-page="settings">
|
||||||
|
<i class="fas fa-cog"></i>
|
||||||
|
<span>Settings</span>
|
||||||
|
</li>
|
||||||
|
</ul>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div class="sidebar-footer">
|
<div class="sidebar-footer">
|
||||||
<div class="user-info">
|
<div class="user-info">
|
||||||
<div class="user-avatar">
|
<div class="user-avatar">
|
||||||
<i class="fas fa-user-circle"></i>
|
<i class="fas fa-user"></i>
|
||||||
</div>
|
</div>
|
||||||
<div class="user-details">
|
<div class="user-details">
|
||||||
<span class="user-name">Loading...</span>
|
<div class="user-name" id="display-username">Admin</div>
|
||||||
<span class="user-role">...</span>
|
<div class="user-role" id="display-role">Administrator</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<button class="logout-btn" id="logout-btn" title="Logout">
|
<button id="logout-btn" class="btn-icon" title="Logout">
|
||||||
<i class="fas fa-sign-out-alt"></i>
|
<i class="fas fa-sign-out-alt"></i>
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
@@ -138,43 +133,40 @@
|
|||||||
|
|
||||||
<!-- Main Content -->
|
<!-- Main Content -->
|
||||||
<main class="main-content">
|
<main class="main-content">
|
||||||
<!-- Top Navigation -->
|
<header class="top-bar">
|
||||||
<header class="top-nav">
|
<div class="page-title">
|
||||||
<div class="nav-left">
|
<h2 id="current-page-title">Overview</h2>
|
||||||
<h1 class="page-title" id="page-title">Dashboard Overview</h1>
|
|
||||||
</div>
|
</div>
|
||||||
<div class="nav-right">
|
<div class="top-bar-actions">
|
||||||
<div class="nav-item" id="ws-status-nav" title="WebSocket Connection Status">
|
<div id="connection-status" class="status-indicator">
|
||||||
<div class="ws-dot"></div>
|
<span class="status-dot"></span>
|
||||||
<span class="ws-text">Connecting...</span>
|
<span class="status-text">Disconnected</span>
|
||||||
</div>
|
</div>
|
||||||
<div class="nav-item" title="Refresh Current Page">
|
<div class="theme-toggle" id="theme-toggle">
|
||||||
<i class="fas fa-sync-alt" id="refresh-btn"></i>
|
<i class="fas fa-moon"></i>
|
||||||
</div>
|
|
||||||
<div class="nav-item">
|
|
||||||
<span id="current-time">Loading...</span>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</header>
|
</header>
|
||||||
|
|
||||||
<!-- Page Content -->
|
<div id="page-content" class="content-body">
|
||||||
<div class="page-content" id="page-content">
|
<!-- Content will be loaded dynamically -->
|
||||||
<!-- Dynamic content container -->
|
<div class="loader-container">
|
||||||
</div>
|
<div class="loader"></div>
|
||||||
|
</div>
|
||||||
<!-- Global Spinner -->
|
|
||||||
<div class="spinner-container">
|
|
||||||
<div class="spinner"></div>
|
|
||||||
</div>
|
</div>
|
||||||
</main>
|
</main>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<!-- Scripts (cache-busted with version query params) -->
|
<!-- Scripts -->
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/luxon@3.3.0/build/global/luxon.min.js"></script>
|
||||||
<script src="/js/api.js?v=7"></script>
|
<script src="/js/api.js?v=7"></script>
|
||||||
<script src="/js/auth.js?v=7"></script>
|
<script src="/js/auth.js?v=7"></script>
|
||||||
<script src="/js/dashboard.js?v=7"></script>
|
|
||||||
<script src="/js/websocket.js?v=7"></script>
|
|
||||||
<script src="/js/charts.js?v=7"></script>
|
<script src="/js/charts.js?v=7"></script>
|
||||||
|
<script src="/js/websocket.js?v=7"></script>
|
||||||
|
<script src="/js/dashboard.js?v=7"></script>
|
||||||
|
|
||||||
|
<!-- Page Modules -->
|
||||||
<script src="/js/pages/overview.js?v=7"></script>
|
<script src="/js/pages/overview.js?v=7"></script>
|
||||||
<script src="/js/pages/analytics.js?v=7"></script>
|
<script src="/js/pages/analytics.js?v=7"></script>
|
||||||
<script src="/js/pages/costs.js?v=7"></script>
|
<script src="/js/pages/costs.js?v=7"></script>
|
||||||
@@ -186,4 +178,4 @@
|
|||||||
<script src="/js/pages/logs.js?v=7"></script>
|
<script src="/js/pages/logs.js?v=7"></script>
|
||||||
<script src="/js/pages/users.js?v=7"></script>
|
<script src="/js/pages/users.js?v=7"></script>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
|
|||||||
@@ -32,6 +32,17 @@ class ApiClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (!response.ok || !result.success) {
|
if (!response.ok || !result.success) {
|
||||||
|
// Handle authentication errors (session expired, server restarted, etc.)
|
||||||
|
if (response.status === 401 ||
|
||||||
|
result.error === 'Session expired or invalid' ||
|
||||||
|
result.error === 'Not authenticated' ||
|
||||||
|
result.error === 'Admin access required') {
|
||||||
|
|
||||||
|
if (window.authManager) {
|
||||||
|
// Try to logout to clear local state and show login screen
|
||||||
|
window.authManager.logout();
|
||||||
|
}
|
||||||
|
}
|
||||||
throw new Error(result.error || `HTTP error! status: ${response.status}`);
|
throw new Error(result.error || `HTTP error! status: ${response.status}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// Authentication Module for LLM Proxy Dashboard
|
// Authentication Module for GopherGate Dashboard
|
||||||
|
|
||||||
class AuthManager {
|
class AuthManager {
|
||||||
constructor() {
|
constructor() {
|
||||||
@@ -58,7 +58,7 @@ class AuthManager {
|
|||||||
|
|
||||||
async login(username, password) {
|
async login(username, password) {
|
||||||
const errorElement = document.getElementById('login-error');
|
const errorElement = document.getElementById('login-error');
|
||||||
const loginBtn = document.querySelector('.login-btn');
|
const loginBtn = document.getElementById('login-btn');
|
||||||
|
|
||||||
try {
|
try {
|
||||||
loginBtn.innerHTML = '<i class="fas fa-spinner fa-spin"></i> Authenticating...';
|
loginBtn.innerHTML = '<i class="fas fa-spinner fa-spin"></i> Authenticating...';
|
||||||
@@ -124,7 +124,7 @@ class AuthManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
showLogin() {
|
showLogin() {
|
||||||
const loginScreen = document.getElementById('login-screen');
|
const loginScreen = document.getElementById('auth-page');
|
||||||
const dashboard = document.getElementById('dashboard');
|
const dashboard = document.getElementById('dashboard');
|
||||||
|
|
||||||
if (loginScreen) loginScreen.style.display = 'flex';
|
if (loginScreen) loginScreen.style.display = 'flex';
|
||||||
@@ -139,7 +139,7 @@ class AuthManager {
|
|||||||
if (errorElement) errorElement.style.display = 'none';
|
if (errorElement) errorElement.style.display = 'none';
|
||||||
|
|
||||||
// Reset button
|
// Reset button
|
||||||
const loginBtn = document.querySelector('.login-btn');
|
const loginBtn = document.getElementById('login-btn');
|
||||||
if (loginBtn) {
|
if (loginBtn) {
|
||||||
loginBtn.innerHTML = '<i class="fas fa-sign-in-alt"></i> Sign In';
|
loginBtn.innerHTML = '<i class="fas fa-sign-in-alt"></i> Sign In';
|
||||||
loginBtn.disabled = false;
|
loginBtn.disabled = false;
|
||||||
@@ -147,7 +147,7 @@ class AuthManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
showDashboard() {
|
showDashboard() {
|
||||||
const loginScreen = document.getElementById('login-screen');
|
const loginScreen = document.getElementById('auth-page');
|
||||||
const dashboard = document.getElementById('dashboard');
|
const dashboard = document.getElementById('dashboard');
|
||||||
|
|
||||||
if (loginScreen) loginScreen.style.display = 'none';
|
if (loginScreen) loginScreen.style.display = 'none';
|
||||||
@@ -167,7 +167,7 @@ class AuthManager {
|
|||||||
const userRoleElement = document.querySelector('.user-role');
|
const userRoleElement = document.querySelector('.user-role');
|
||||||
|
|
||||||
if (userNameElement && this.user) {
|
if (userNameElement && this.user) {
|
||||||
userNameElement.textContent = this.user.name || this.user.username || 'User';
|
userNameElement.textContent = this.user.display_name || this.user.username || 'User';
|
||||||
}
|
}
|
||||||
|
|
||||||
if (userRoleElement && this.user) {
|
if (userRoleElement && this.user) {
|
||||||
|
|||||||
@@ -285,7 +285,30 @@ class Dashboard {
|
|||||||
<p class="card-subtitle">Manage model availability and custom pricing</p>
|
<p class="card-subtitle">Manage model availability and custom pricing</p>
|
||||||
</div>
|
</div>
|
||||||
<div class="card-actions">
|
<div class="card-actions">
|
||||||
<input type="text" id="model-search" placeholder="Search models..." class="form-control" style="margin-bottom: 0; padding: 4px 8px; width: 250px;">
|
<select id="model-provider-filter" class="form-control" style="margin-bottom: 0; padding: 4px 8px; width: auto;">
|
||||||
|
<option value="">All Providers</option>
|
||||||
|
<option value="openai">OpenAI</option>
|
||||||
|
<option value="anthropic">Anthropic / Gemini</option>
|
||||||
|
<option value="google">Google</option>
|
||||||
|
<option value="deepseek">DeepSeek</option>
|
||||||
|
<option value="xai">xAI</option>
|
||||||
|
<option value="meta">Meta</option>
|
||||||
|
<option value="cohere">Cohere</option>
|
||||||
|
<option value="mistral">Mistral</option>
|
||||||
|
<option value="other">Other</option>
|
||||||
|
</select>
|
||||||
|
<select id="model-modality-filter" class="form-control" style="margin-bottom: 0; padding: 4px 8px; width: auto;">
|
||||||
|
<option value="">All Modalities</option>
|
||||||
|
<option value="text">Text</option>
|
||||||
|
<option value="image">Vision/Image</option>
|
||||||
|
<option value="audio">Audio</option>
|
||||||
|
</select>
|
||||||
|
<select id="model-capability-filter" class="form-control" style="margin-bottom: 0; padding: 4px 8px; width: auto;">
|
||||||
|
<option value="">All Capabilities</option>
|
||||||
|
<option value="tool_call">Tool Calling</option>
|
||||||
|
<option value="reasoning">Reasoning</option>
|
||||||
|
</select>
|
||||||
|
<input type="text" id="model-search" placeholder="Search models..." class="form-control" style="margin-bottom: 0; padding: 4px 8px; width: 200px;">
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="table-container">
|
<div class="table-container">
|
||||||
|
|||||||
@@ -38,6 +38,24 @@ class LogsPage {
|
|||||||
const statusClass = log.status === 'success' ? 'success' : 'danger';
|
const statusClass = log.status === 'success' ? 'success' : 'danger';
|
||||||
const timestamp = luxon.DateTime.fromISO(log.timestamp).toFormat('yyyy-MM-dd HH:mm:ss');
|
const timestamp = luxon.DateTime.fromISO(log.timestamp).toFormat('yyyy-MM-dd HH:mm:ss');
|
||||||
|
|
||||||
|
let tokenDetails = `${log.tokens} total tokens`;
|
||||||
|
if (log.status === 'success') {
|
||||||
|
const parts = [];
|
||||||
|
parts.push(`${log.prompt_tokens} in`);
|
||||||
|
|
||||||
|
let completionStr = `${log.completion_tokens} out`;
|
||||||
|
if (log.reasoning_tokens > 0) {
|
||||||
|
completionStr += ` (${log.reasoning_tokens} reasoning)`;
|
||||||
|
}
|
||||||
|
parts.push(completionStr);
|
||||||
|
|
||||||
|
if (log.cache_read_tokens > 0) {
|
||||||
|
parts.push(`${log.cache_read_tokens} cache-hit`);
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenDetails = parts.join(', ');
|
||||||
|
}
|
||||||
|
|
||||||
return `
|
return `
|
||||||
<tr class="log-row">
|
<tr class="log-row">
|
||||||
<td class="whitespace-nowrap">${timestamp}</td>
|
<td class="whitespace-nowrap">${timestamp}</td>
|
||||||
@@ -55,7 +73,7 @@ class LogsPage {
|
|||||||
<td>
|
<td>
|
||||||
<div class="log-message-container">
|
<div class="log-message-container">
|
||||||
<code class="log-model">${log.model}</code>
|
<code class="log-model">${log.model}</code>
|
||||||
<span class="log-tokens">${log.tokens} tokens</span>
|
<span class="log-tokens" title="${log.tokens} total tokens">${tokenDetails}</span>
|
||||||
<span class="log-duration">${log.duration}ms</span>
|
<span class="log-duration">${log.duration}ms</span>
|
||||||
${log.error ? `<div class="log-error-msg">${log.error}</div>` : ''}
|
${log.error ? `<div class="log-error-msg">${log.error}</div>` : ''}
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -31,13 +31,58 @@ class ModelsPage {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const searchInput = document.getElementById('model-search');
|
||||||
|
const providerFilter = document.getElementById('model-provider-filter');
|
||||||
|
const modalityFilter = document.getElementById('model-modality-filter');
|
||||||
|
const capabilityFilter = document.getElementById('model-capability-filter');
|
||||||
|
|
||||||
|
const q = searchInput ? searchInput.value.toLowerCase() : '';
|
||||||
|
const providerVal = providerFilter ? providerFilter.value : '';
|
||||||
|
const modalityVal = modalityFilter ? modalityFilter.value : '';
|
||||||
|
const capabilityVal = capabilityFilter ? capabilityFilter.value : '';
|
||||||
|
|
||||||
|
// Apply filters non-destructively
|
||||||
|
let filteredModels = this.models.filter(m => {
|
||||||
|
// Text search
|
||||||
|
if (q && !(m.id.toLowerCase().includes(q) || m.name.toLowerCase().includes(q) || m.provider.toLowerCase().includes(q))) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provider filter
|
||||||
|
if (providerVal) {
|
||||||
|
if (providerVal === 'other') {
|
||||||
|
const known = ['openai', 'anthropic', 'google', 'deepseek', 'xai', 'meta', 'cohere', 'mistral'];
|
||||||
|
if (known.includes(m.provider.toLowerCase())) return false;
|
||||||
|
} else if (m.provider.toLowerCase() !== providerVal) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modality filter
|
||||||
|
if (modalityVal) {
|
||||||
|
const mods = m.modalities && m.modalities.input ? m.modalities.input.map(x => x.toLowerCase()) : [];
|
||||||
|
if (!mods.includes(modalityVal)) return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capability filter
|
||||||
|
if (capabilityVal === 'tool_call' && !m.tool_call) return false;
|
||||||
|
if (capabilityVal === 'reasoning' && !m.reasoning) return false;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (filteredModels.length === 0) {
|
||||||
|
tableBody.innerHTML = '<tr><td colspan="7" class="text-center">No models match the selected filters</td></tr>';
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Sort by provider then name
|
// Sort by provider then name
|
||||||
this.models.sort((a, b) => {
|
filteredModels.sort((a, b) => {
|
||||||
if (a.provider !== b.provider) return a.provider.localeCompare(b.provider);
|
if (a.provider !== b.provider) return a.provider.localeCompare(b.provider);
|
||||||
return a.name.localeCompare(b.name);
|
return a.name.localeCompare(b.name);
|
||||||
});
|
});
|
||||||
|
|
||||||
tableBody.innerHTML = this.models.map(model => {
|
tableBody.innerHTML = filteredModels.map(model => {
|
||||||
const statusClass = model.enabled ? 'success' : 'secondary';
|
const statusClass = model.enabled ? 'success' : 'secondary';
|
||||||
const statusIcon = model.enabled ? 'check-circle' : 'ban';
|
const statusIcon = model.enabled ? 'check-circle' : 'ban';
|
||||||
|
|
||||||
@@ -99,6 +144,14 @@ class ModelsPage {
|
|||||||
<input type="number" id="model-completion-cost" value="${model.completion_cost}" step="0.01">
|
<input type="number" id="model-completion-cost" value="${model.completion_cost}" step="0.01">
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
<div class="form-control">
|
||||||
|
<label for="model-cache-read-cost">Cache Read Cost (per 1M tokens)</label>
|
||||||
|
<input type="number" id="model-cache-read-cost" value="${model.cache_read_cost || 0}" step="0.01">
|
||||||
|
</div>
|
||||||
|
<div class="form-control">
|
||||||
|
<label for="model-cache-write-cost">Cache Write Cost (per 1M tokens)</label>
|
||||||
|
<input type="number" id="model-cache-write-cost" value="${model.cache_write_cost || 0}" step="0.01">
|
||||||
|
</div>
|
||||||
<div class="form-control">
|
<div class="form-control">
|
||||||
<label for="model-mapping">Internal Mapping (Optional)</label>
|
<label for="model-mapping">Internal Mapping (Optional)</label>
|
||||||
<input type="text" id="model-mapping" value="${model.mapping || ''}" placeholder="e.g. gpt-4o-2024-05-13">
|
<input type="text" id="model-mapping" value="${model.mapping || ''}" placeholder="e.g. gpt-4o-2024-05-13">
|
||||||
@@ -118,6 +171,8 @@ class ModelsPage {
|
|||||||
const enabled = modal.querySelector('#model-enabled').checked;
|
const enabled = modal.querySelector('#model-enabled').checked;
|
||||||
const promptCost = parseFloat(modal.querySelector('#model-prompt-cost').value);
|
const promptCost = parseFloat(modal.querySelector('#model-prompt-cost').value);
|
||||||
const completionCost = parseFloat(modal.querySelector('#model-completion-cost').value);
|
const completionCost = parseFloat(modal.querySelector('#model-completion-cost').value);
|
||||||
|
const cacheReadCost = parseFloat(modal.querySelector('#model-cache-read-cost').value);
|
||||||
|
const cacheWriteCost = parseFloat(modal.querySelector('#model-cache-write-cost').value);
|
||||||
const mapping = modal.querySelector('#model-mapping').value;
|
const mapping = modal.querySelector('#model-mapping').value;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@@ -125,6 +180,8 @@ class ModelsPage {
|
|||||||
enabled,
|
enabled,
|
||||||
prompt_cost: promptCost,
|
prompt_cost: promptCost,
|
||||||
completion_cost: completionCost,
|
completion_cost: completionCost,
|
||||||
|
cache_read_cost: isNaN(cacheReadCost) ? null : cacheReadCost,
|
||||||
|
cache_write_cost: isNaN(cacheWriteCost) ? null : cacheWriteCost,
|
||||||
mapping: mapping || null
|
mapping: mapping || null
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -138,27 +195,18 @@ class ModelsPage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
setupEventListeners() {
|
setupEventListeners() {
|
||||||
const searchInput = document.getElementById('model-search');
|
const attachFilter = (id) => {
|
||||||
if (searchInput) {
|
const el = document.getElementById(id);
|
||||||
searchInput.oninput = (e) => this.filterModels(e.target.value);
|
if (el) {
|
||||||
}
|
el.addEventListener('input', () => this.renderModelsTable());
|
||||||
}
|
el.addEventListener('change', () => this.renderModelsTable());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
filterModels(query) {
|
attachFilter('model-search');
|
||||||
if (!query) {
|
attachFilter('model-provider-filter');
|
||||||
this.renderModelsTable();
|
attachFilter('model-modality-filter');
|
||||||
return;
|
attachFilter('model-capability-filter');
|
||||||
}
|
|
||||||
|
|
||||||
const q = query.toLowerCase();
|
|
||||||
const originalModels = this.models;
|
|
||||||
this.models = this.models.filter(m =>
|
|
||||||
m.id.toLowerCase().includes(q) ||
|
|
||||||
m.name.toLowerCase().includes(q) ||
|
|
||||||
m.provider.toLowerCase().includes(q)
|
|
||||||
);
|
|
||||||
this.renderModelsTable();
|
|
||||||
this.models = originalModels;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -492,7 +492,7 @@ class MonitoringPage {
|
|||||||
simulateRequest() {
|
simulateRequest() {
|
||||||
const clients = ['client-1', 'client-2', 'client-3', 'client-4', 'client-5'];
|
const clients = ['client-1', 'client-2', 'client-3', 'client-4', 'client-5'];
|
||||||
const providers = ['OpenAI', 'Gemini', 'DeepSeek', 'Grok'];
|
const providers = ['OpenAI', 'Gemini', 'DeepSeek', 'Grok'];
|
||||||
const models = ['gpt-4', 'gpt-3.5-turbo', 'gemini-pro', 'deepseek-chat', 'grok-beta'];
|
const models = ['gpt-4o', 'gpt-4o-mini', 'gemini-2.0-flash', 'deepseek-chat', 'grok-4-1-fast-non-reasoning'];
|
||||||
const statuses = ['success', 'success', 'success', 'error', 'warning']; // Mostly success
|
const statuses = ['success', 'success', 'success', 'error', 'warning']; // Mostly success
|
||||||
|
|
||||||
const request = {
|
const request = {
|
||||||
|
|||||||
@@ -248,21 +248,19 @@ class WebSocketManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
updateStatus(status) {
|
updateStatus(status) {
|
||||||
const statusElement = document.getElementById('ws-status-nav');
|
const statusElement = document.getElementById('connection-status');
|
||||||
if (!statusElement) return;
|
if (!statusElement) return;
|
||||||
|
|
||||||
const dot = statusElement.querySelector('.ws-dot');
|
const dot = statusElement.querySelector('.status-dot');
|
||||||
const text = statusElement.querySelector('.ws-text');
|
const text = statusElement.querySelector('.status-text');
|
||||||
|
|
||||||
if (!dot || !text) return;
|
if (!dot || !text) return;
|
||||||
|
|
||||||
// Remove all status classes
|
// Remove all status classes
|
||||||
dot.classList.remove('connected', 'disconnected');
|
dot.classList.remove('connected', 'disconnected', 'error', 'connecting');
|
||||||
statusElement.classList.remove('connected', 'disconnected');
|
|
||||||
|
|
||||||
// Add new status class
|
// Add new status class
|
||||||
dot.classList.add(status);
|
dot.classList.add(status);
|
||||||
statusElement.classList.add(status);
|
|
||||||
|
|
||||||
// Update text
|
// Update text
|
||||||
const statusText = {
|
const statusText = {
|
||||||
|
|||||||
@@ -1,38 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# Test script for LLM Proxy Dashboard
|
|
||||||
|
|
||||||
echo "Building LLM Proxy Gateway..."
|
|
||||||
cargo build --release
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "Starting server in background..."
|
|
||||||
./target/release/llm-proxy &
|
|
||||||
SERVER_PID=$!
|
|
||||||
|
|
||||||
# Wait for server to start
|
|
||||||
sleep 3
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "Testing dashboard endpoints..."
|
|
||||||
|
|
||||||
# Test health endpoint
|
|
||||||
echo "1. Testing health endpoint:"
|
|
||||||
curl -s http://localhost:8080/health
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "2. Testing dashboard static files:"
|
|
||||||
curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "3. Testing API endpoints:"
|
|
||||||
curl -s http://localhost:8080/api/auth/status | jq . 2>/dev/null || echo "JSON response received"
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "Dashboard should be available at: http://localhost:8080"
|
|
||||||
echo "Default login: admin / admin"
|
|
||||||
echo ""
|
|
||||||
echo "Press Ctrl+C to stop the server"
|
|
||||||
|
|
||||||
# Keep script running
|
|
||||||
wait $SERVER_PID
|
|
||||||
@@ -1,75 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# Test script for LLM Proxy Gateway
|
|
||||||
|
|
||||||
echo "Building LLM Proxy Gateway..."
|
|
||||||
cargo build --release
|
|
||||||
|
|
||||||
if [ $? -ne 0 ]; then
|
|
||||||
echo "Build failed!"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "Build successful!"
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "Project Structure Summary:"
|
|
||||||
echo "=========================="
|
|
||||||
echo "Core Components:"
|
|
||||||
echo " - main.rs: Application entry point with server setup"
|
|
||||||
echo " - config/: Configuration management"
|
|
||||||
echo " - server/: API route handlers"
|
|
||||||
echo " - auth/: Bearer token authentication"
|
|
||||||
echo " - database/: SQLite database setup"
|
|
||||||
echo " - models/: Data structures (OpenAI-compatible)"
|
|
||||||
echo " - providers/: LLM provider implementations (OpenAI, Gemini, DeepSeek, Grok)"
|
|
||||||
echo " - errors/: Custom error types"
|
|
||||||
echo " - dashboard/: Admin dashboard with WebSocket support"
|
|
||||||
echo " - logging/: Request logging middleware"
|
|
||||||
echo " - state/: Shared application state"
|
|
||||||
echo " - multimodal/: Image processing support (basic structure)"
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "Key Features Implemented:"
|
|
||||||
echo "=========================="
|
|
||||||
echo "✓ OpenAI-compatible API endpoint (/v1/chat/completions)"
|
|
||||||
echo "✓ Bearer token authentication"
|
|
||||||
echo "✓ SQLite database for request tracking"
|
|
||||||
echo "✓ Request logging with token/cost calculation"
|
|
||||||
echo "✓ Provider abstraction layer"
|
|
||||||
echo "✓ Admin dashboard with real-time monitoring"
|
|
||||||
echo "✓ WebSocket support for live updates"
|
|
||||||
echo "✓ Configuration management (config.toml, .env, env vars)"
|
|
||||||
echo "✓ Multimodal support structure (images)"
|
|
||||||
echo "✓ Error handling with proper HTTP status codes"
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "Next Steps Needed:"
|
|
||||||
echo "=================="
|
|
||||||
echo "1. Add API keys to .env file:"
|
|
||||||
echo " OPENAI_API_KEY=your_key_here"
|
|
||||||
echo " GEMINI_API_KEY=your_key_here"
|
|
||||||
echo " DEEPSEEK_API_KEY=your_key_here"
|
|
||||||
echo " GROK_API_KEY=your_key_here (optional)"
|
|
||||||
echo ""
|
|
||||||
echo "2. Create config.toml for custom configuration (optional)"
|
|
||||||
echo ""
|
|
||||||
echo "3. Run the server:"
|
|
||||||
echo " cargo run"
|
|
||||||
echo ""
|
|
||||||
echo "4. Access dashboard at: http://localhost:8080"
|
|
||||||
echo ""
|
|
||||||
echo "5. Test API with curl:"
|
|
||||||
echo " curl -X POST http://localhost:8080/v1/chat/completions \\"
|
|
||||||
echo " -H 'Authorization: Bearer your_token' \\"
|
|
||||||
echo " -H 'Content-Type: application/json' \\"
|
|
||||||
echo " -d '{\"model\": \"gpt-4\", \"messages\": [{\"role\": \"user\", \"content\": \"Hello\"}]}'"
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "Deployment Notes:"
|
|
||||||
echo "================="
|
|
||||||
echo "Memory: Designed for 512MB RAM (LXC container)"
|
|
||||||
echo "Database: SQLite (./data/llm_proxy.db)"
|
|
||||||
echo "Port: 8080 (configurable)"
|
|
||||||
echo "Authentication: Single Bearer token (configurable)"
|
|
||||||
echo "Providers: OpenAI, Gemini, DeepSeek, Grok (disabled by default)"
|
|
||||||
14
timeline.mmd
14
timeline.mmd
@@ -1,14 +0,0 @@
|
|||||||
gantt
|
|
||||||
title LLM Proxy Project Timeline
|
|
||||||
dateFormat YYYY-MM-DD
|
|
||||||
section Frontend
|
|
||||||
Standardize Escaping (users.js) :a1, 2026-03-06, 1d
|
|
||||||
section Backend Cleanup
|
|
||||||
Remove Unused Imports :b1, 2026-03-06, 1d
|
|
||||||
section HMAC Migration
|
|
||||||
Architecture Design :c1, 2026-03-07, 1d
|
|
||||||
Backend Implementation :c2, after c1, 2d
|
|
||||||
Session Refresh Logic :c3, after c2, 1d
|
|
||||||
section Testing
|
|
||||||
Integration Test (Encrypted Keys) :d1, 2026-03-09, 2d
|
|
||||||
HMAC Verification Tests :d2, after c3, 1d
|
|
||||||
Reference in New Issue
Block a user