Compare commits

..

1 Commit

Author SHA1 Message Date
6b10d4249c feat: migrate backend from rust to go
Some checks failed
CI / Check (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Formatting (push) Has been cancelled
CI / Test (push) Has been cancelled
CI / Release Build (push) Has been cancelled
This commit replaces the Axum/Rust backend with a Gin/Go implementation. The original Rust code has been archived in the 'rust' branch.
2026-03-19 10:30:05 -04:00
69 changed files with 3460 additions and 15096 deletions


@@ -1,31 +1,28 @@
# LLM Proxy Gateway Environment Variables
# Copy to .env and fill in your API keys
# MANDATORY: Encryption key for sessions and stored API keys
# Must be a 32-byte hex or base64 encoded string
# Example (hex): LLM_PROXY__ENCRYPTION_KEY=0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef
LLM_PROXY__ENCRYPTION_KEY=your_secure_32_byte_key_here
# LLM Provider API Keys (Standard Environment Variables)
# OpenAI
OPENAI_API_KEY=your_openai_api_key_here
# Google Gemini
GEMINI_API_KEY=your_gemini_api_key_here
# DeepSeek
DEEPSEEK_API_KEY=your_deepseek_api_key_here
# xAI Grok (not yet available)
GROK_API_KEY=your_grok_api_key_here
# Provider Overrides (Optional)
# LLM_PROXY__PROVIDERS__OPENAI__BASE_URL=https://api.openai.com/v1
# LLM_PROXY__PROVIDERS__GEMINI__ENABLED=true
# Ollama (local server)
# LLM_PROXY__PROVIDERS__OLLAMA__BASE_URL=http://localhost:11434/v1
# LLM_PROXY__PROVIDERS__OLLAMA__ENABLED=true
# LLM_PROXY__PROVIDERS__OLLAMA__MODELS=llama3,mistral,llava
# Authentication tokens (comma-separated list)
LLM_PROXY__SERVER__AUTH_TOKENS=your_bearer_token_here,another_token
# Server Configuration
LLM_PROXY__SERVER__PORT=8080
LLM_PROXY__SERVER__HOST=0.0.0.0
# Database Configuration
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
LLM_PROXY__DATABASE__MAX_CONNECTIONS=10

61
BACKEND_ARCHITECTURE.md Normal file

@@ -0,0 +1,61 @@
# Backend Architecture (Go)
The LLM Proxy backend is implemented in Go, focusing on high performance, clear concurrency patterns, and maintainability.
## Core Technologies
- **Runtime:** Go 1.22+
- **Web Framework:** [Gin Gonic](https://github.com/gin-gonic/gin) - Fast and lightweight HTTP routing.
- **Database:** [sqlx](https://github.com/jmoiron/sqlx) - Lightweight wrapper for standard `database/sql`.
- **SQLite Driver:** [modernc.org/sqlite](https://modernc.org/sqlite) - CGO-free SQLite implementation for ease of cross-compilation.
- **Config:** [Viper](https://github.com/spf13/viper) - Robust configuration management supporting environment variables and files.
## Project Structure
```text
├── cmd/
│   └── llm-proxy/    # Entry point (main.go)
├── internal/
│   ├── config/       # Configuration loading and validation
│   ├── db/           # Database schema, migrations, and models
│   ├── middleware/   # Auth and logging middleware
│   ├── models/       # Unified request/response structs
│   ├── providers/    # LLM provider implementations (OpenAI, Gemini, etc.)
│   ├── server/       # HTTP server, dashboard handlers, and WebSocket hub
│   └── utils/        # Common utilities (multimodal, etc.)
└── static/           # Frontend assets (served by the backend)
```
## Key Components
### 1. Provider Interface (`internal/providers/provider.go`)
Standardized interface for all LLM backends:
```go
type Provider interface {
	Name() string
	ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error)
	ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error)
}
```
### 2. Asynchronous Logging (`internal/server/logging.go`)
Uses a buffered channel and background worker to log every request to SQLite without blocking the client response. It also broadcasts logs to the WebSocket hub for real-time dashboard updates.
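The pattern can be sketched as a stand-alone snippet. Assumptions: the real worker writes to SQLite and broadcasts to the WebSocket hub, whereas here the sink is an injected function, and the `RequestLog` fields are placeholders:

```go
package main

import (
	"fmt"
	"sync"
)

// RequestLog is a hypothetical log record; the real struct lives in internal/server.
type RequestLog struct {
	Model  string
	Tokens int
}

// AsyncLogger buffers entries in a channel so request handlers never block on
// database writes; a single background worker drains the channel.
type AsyncLogger struct {
	ch   chan RequestLog
	wg   sync.WaitGroup
	save func(RequestLog) // stands in for the SQLite insert
}

func NewAsyncLogger(buf int, save func(RequestLog)) *AsyncLogger {
	l := &AsyncLogger{ch: make(chan RequestLog, buf), save: save}
	l.wg.Add(1)
	go func() {
		defer l.wg.Done()
		for entry := range l.ch { // single writer: no lock contention on the DB
			l.save(entry)
		}
	}()
	return l
}

// Log enqueues without blocking; if the buffer is full the entry is dropped
// rather than stalling the response path.
func (l *AsyncLogger) Log(entry RequestLog) {
	select {
	case l.ch <- entry:
	default:
	}
}

// Close stops accepting entries and waits for the worker to drain the buffer.
func (l *AsyncLogger) Close() {
	close(l.ch)
	l.wg.Wait()
}

func main() {
	var saved []RequestLog
	logger := NewAsyncLogger(64, func(r RequestLog) { saved = append(saved, r) })
	logger.Log(RequestLog{Model: "gpt-4o", Tokens: 42})
	logger.Close()
	fmt.Println(len(saved))
}
```

Dropping on a full buffer is one policy choice; a production logger might instead block briefly or count dropped entries.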
### 3. Session Management (`internal/server/sessions.go`)
Implements HMAC-SHA256 signed tokens for dashboard authentication. Sessions are stored in-memory with configurable TTL.
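A minimal sketch of the signing scheme, assuming a `payload.signature` layout with base64url-encoded MACs (the actual token format in `sessions.go` may differ):

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"strings"
)

// signToken produces "payload.signature", where the signature is
// HMAC-SHA256 over the payload under the server's secret key.
func signToken(secret []byte, payload string) string {
	mac := hmac.New(sha256.New, secret)
	mac.Write([]byte(payload))
	sig := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
	return payload + "." + sig
}

// verifyToken recomputes the MAC and compares it in constant time,
// returning the payload only if the signature checks out.
func verifyToken(secret []byte, token string) (string, bool) {
	i := strings.LastIndex(token, ".")
	if i < 0 {
		return "", false
	}
	payload, sig := token[:i], token[i+1:]
	mac := hmac.New(sha256.New, secret)
	mac.Write([]byte(payload))
	expected := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
	if !hmac.Equal([]byte(sig), []byte(expected)) {
		return "", false
	}
	return payload, true
}

func main() {
	secret := []byte("0123456789abcdef0123456789abcdef") // 32 bytes
	tok := signToken(secret, "user=admin")
	payload, ok := verifyToken(secret, tok)
	fmt.Println(payload, ok)
}
```

Because the token is self-authenticating, the server only needs the in-memory session entry for TTL bookkeeping; tampering with the payload invalidates the signature.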
### 4. WebSocket Hub (`internal/server/websocket.go`)
A centralized hub for managing WebSocket connections, allowing real-time broadcast of system events and request logs to the dashboard.
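The hub pattern can be sketched with plain channels standing in for WebSocket connections; the real hub wraps `gorilla/websocket` connections and handles ping/pong and write deadlines, which this omits:

```go
package main

import "fmt"

// Hub fans out events to all registered subscribers. The run loop owns the
// subscriber set, so no mutex is needed.
type Hub struct {
	register   chan chan string
	unregister chan chan string
	broadcast  chan string
	done       chan struct{}
}

func NewHub() *Hub {
	h := &Hub{
		register:   make(chan chan string),
		unregister: make(chan chan string),
		broadcast:  make(chan string),
		done:       make(chan struct{}),
	}
	go h.run()
	return h
}

func (h *Hub) run() {
	subs := map[chan string]struct{}{}
	for {
		select {
		case c := <-h.register:
			subs[c] = struct{}{}
		case c := <-h.unregister:
			delete(subs, c)
			close(c)
		case msg := <-h.broadcast:
			for c := range subs {
				select {
				case c <- msg:
				default: // slow client: skip rather than block the hub
				}
			}
		case <-h.done:
			return
		}
	}
}

func main() {
	h := NewHub()
	sub := make(chan string, 1)
	h.register <- sub
	h.broadcast <- "request logged"
	fmt.Println(<-sub)
	close(h.done)
}
```

Skipping slow subscribers on broadcast keeps one stalled dashboard tab from delaying everyone else's live log stream.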
## Concurrency Model
Go's goroutines and channels are used extensively:
- **Streaming:** Each streaming request uses a goroutine to read and parse the provider's response, feeding chunks into a channel.
- **Logging:** A single background worker processes the `logChan` to perform database writes.
- **WebSocket:** The `Hub` runs in a dedicated goroutine, handling registration and broadcasting.
## Security
- **Encryption Key:** A mandatory 32-byte key is used for both session signing and encryption of sensitive data in the database.
- **Auth Middleware:** Verifies client API keys against the database before proxying requests to LLM providers.
- **Bcrypt:** Passwords for dashboard users are hashed using Bcrypt with a work factor of 12.

4139
Cargo.lock generated

File diff suppressed because it is too large


@@ -1,75 +0,0 @@
[package]
name = "llm-proxy"
version = "0.1.0"
edition = "2024"
rust-version = "1.87"
description = "Unified LLM proxy gateway supporting OpenAI, Gemini, DeepSeek, and Grok with token tracking and cost calculation"
authors = ["newkirk"]
license = "MIT OR Apache-2.0"
repository = ""
[dependencies]
# ========== Web Framework & Async Runtime ==========
axum = { version = "0.8", features = ["macros", "ws"] }
tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "net", "time", "signal", "fs"] }
tower = "0.5"
tower-http = { version = "0.6", features = ["trace", "cors", "compression-gzip", "fs", "set-header", "limit"] }
governor = "0.7"
# ========== HTTP Clients ==========
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
tiktoken-rs = "0.9"
# ========== Database & ORM ==========
sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite", "macros", "migrate", "chrono"] }
# ========== Authentication & Middleware ==========
axum-extra = { version = "0.12", features = ["typed-header"] }
headers = "0.4"
# ========== Configuration Management ==========
config = "0.13"
dotenvy = "0.15"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
toml = "0.8"
# ========== Logging & Monitoring ==========
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
# ========== Multimodal & Image Processing ==========
base64 = "0.21"
image = { version = "0.25", default-features = false, features = ["jpeg", "png", "webp"] }
mime = "0.3"
# ========== Error Handling & Utilities ==========
anyhow = "1.0"
thiserror = "1.0"
bcrypt = "0.15"
aes-gcm = "0.10"
hmac = "0.12"
sha2 = "0.10"
chrono = { version = "0.4", features = ["serde"] }
uuid = { version = "1.0", features = ["v4", "serde"] }
futures = "0.3"
async-trait = "0.1"
async-stream = "0.3"
reqwest-eventsource = "0.6"
rand = "0.9"
hex = "0.4"
[dev-dependencies]
tokio-test = "0.4"
mockito = "1.0"
tempfile = "3.10"
assert_cmd = "2.0"
insta = "1.39"
anyhow = "1.0"
[profile.release]
opt-level = 3
lto = true
codegen-units = 1
strip = true
panic = "abort"


@@ -1,35 +1,34 @@
# Build stage
FROM golang:1.22-alpine AS builder
WORKDIR /app

# Copy go mod and sum files
COPY go.mod go.sum ./
RUN go mod download

# Copy the source code
COPY . .

# Build the application
RUN CGO_ENABLED=0 GOOS=linux go build -o llm-proxy ./cmd/llm-proxy

# Final stage
FROM alpine:latest
RUN apk --no-cache add ca-certificates tzdata
WORKDIR /app

# Copy the binary from the builder stage
COPY --from=builder /app/llm-proxy .
COPY --from=builder /app/static ./static

# Create data directory
RUN mkdir -p /app/data

# Expose port
EXPOSE 8080

# Run the application
CMD ["./llm-proxy"]


@@ -1,114 +1,108 @@
# LLM Proxy Gateway
A unified, high-performance LLM proxy gateway built in Go. It provides a single OpenAI-compatible API to access multiple providers (OpenAI, Gemini, DeepSeek, Grok, Ollama) with built-in token tracking, real-time cost calculation, multi-user authentication, and a management dashboard.
## Features
- **Unified API:** OpenAI-compatible `/v1/chat/completions` and `/v1/models` endpoints.
- **Multi-Provider Support:**
- **OpenAI:** GPT-4o, GPT-4o Mini, o1, o3 reasoning models.
- **Google Gemini:** Gemini 2.0 Flash, Pro, and vision models (with native CoT support).
- **DeepSeek:** DeepSeek Chat and Reasoner (R1) models.
- **xAI Grok:** Grok-beta models.
- **Ollama:** Local LLMs running on your network.
- **Observability & Tracking:**
- **Asynchronous Logging:** Non-blocking request logging to SQLite using background workers.
- **Token Counting:** Precise estimation and tracking of prompt, completion, and reasoning tokens.
- **Database Persistence:** Every request logged to SQLite for historical analysis and dashboard analytics.
- **Streaming Support:** Full SSE (Server-Sent Events) support for all providers.
- **Multimodal (Vision):** Image processing (Base64 and remote URLs) across compatible providers.
- **Multi-User Access Control:**
- **Admin Role:** Full access to all dashboard features, user management, and system configuration.
- **Viewer Role:** Read-only access to usage analytics, costs, and monitoring.
- **Client API Keys:** Create and manage multiple client tokens for external integrations.
- **Reliability:**
- **Circuit Breaking:** Automatically protects when providers are down (coming soon).
- **Rate Limiting:** Per-client and global rate limits (coming soon).
## Security
LLM Proxy is designed with security in mind:
- **Signed Session Tokens:** Management dashboard sessions are secured using HMAC-SHA256 signed tokens.
- **Encrypted Storage:** Support for encrypted provider API keys in the database.
- **Auth Middleware:** Secure client authentication via database-backed API keys.
**Note:** You must define an `LLM_PROXY__ENCRYPTION_KEY` in your `.env` file for secure session signing and encryption.
## Tech Stack
- **Runtime:** Go 1.22+
- **Web Framework:** Gin Gonic
- **Database:** sqlx with SQLite (CGO-free via `modernc.org/sqlite`)
- **Frontend:** Vanilla JS/CSS with Chart.js for visualizations
## Getting Started
### Prerequisites
- Go (1.22+)
- SQLite3 (optional, driver is built-in)
- Docker (optional, for containerized deployment)
### Quick Start
1. Clone and build:
```bash
git clone ssh://git.dustin.coffee:2222/hobokenchicken/llm-proxy.git
cd llm-proxy
go build -o llm-proxy ./cmd/llm-proxy
```
2. Configure environment:
```bash
cp .env.example .env
# Edit .env and add your configuration:
# LLM_PROXY__ENCRYPTION_KEY=... (32-byte hex or base64 string)
# OPENAI_API_KEY=sk-...
# GEMINI_API_KEY=AIza...
```
3. Run the proxy:
```bash
./llm-proxy
```
The server starts on `http://0.0.0.0:8080` by default.
### Deployment (Docker)
A multi-stage `Dockerfile` is provided for efficient deployment:
```bash
# Build the container
docker build -t llm-proxy .
# Run the container
docker run -p 8080:8080 \
-e LLM_PROXY__ENCRYPTION_KEY=your-secure-key \
-v ./data:/app/data \
llm-proxy
```
## Management Dashboard
Access the dashboard at `http://localhost:8080`.
- **Auth:** Login, session management, and status tracking.
- **Usage:** Summary stats, time-series analytics, and provider breakdown.
- **Clients:** API key management and per-client usage tracking.
- **Providers:** Provider configuration and status monitoring.
- **Users:** Admin-only user management for dashboard access.
- **Monitoring:** Live request stream via WebSocket.
### Default Credentials
- **Username:** `admin`
- **Password:** `admin` (change it in the dashboard after first login)
## API Usage
@@ -131,4 +125,4 @@ response = client.chat.completions.create(
## License
MIT


@@ -1,58 +0,0 @@
# LLM Proxy Rust Backend Code Review Report
## Executive Summary
This code review examines the `llm-proxy` Rust backend, focusing on architectural soundness, performance characteristics, and adherence to Rust idioms. The codebase demonstrates solid engineering with well-structured modular design, comprehensive error handling, and thoughtful async patterns. However, several areas require attention for production readiness, particularly around thread safety, memory efficiency, and error recovery.
## 1. Core Proxy Logic Review
### Strengths
- Clean provider abstraction (`Provider` trait).
- Streaming support with `AggregatingStream` for token counting.
- Model mapping and caching system.
### Issues Found
- **Provider Manager Thread Safety Risk:** O(n) lookups using `Vec` with `RwLock`. Use `DashMap` instead.
- **Streaming Memory Inefficiency:** Accumulates complete response content in memory.
- **Model Registry Cache Invalidation:** No strategy when config changes via dashboard.
## 2. State Management Review
- **Token Bucket Algorithm Flaw:** Custom implementation lacks thread-safe refill sync. Use `governor` crate.
- **Broadcast Channel Unbounded Growth Risk:** Fixed-size (100) may drop messages.
- **Database Connection Pool Contention:** SQLite connections shared without differentiation.
## 3. Error Handling Review
- **Error Recovery Missing:** No circuit breaker or retry logic for provider calls.
- **Stream Error Logging Gap:** Stream errors swallowed without logging partial usage.
## 4. Rust Idioms and Performance
- **Unnecessary String Cloning:** Frequent cloning in authentication hot paths.
- **JSON Parsing Inefficiency:** Multiple passes with `serde_json::Value`. Use typed structs.
- **Missing `#[derive(Copy)]`:** For small enums like `CircuitState`.
## 5. Async Performance
- **Blocking Calls:** Token estimation may block async runtime.
- **Missing Connection Timeouts:** Only overall timeout, no separate read/write timeouts.
- **Unbounded Task Spawn:** For client usage updates under load.
## 6. Security Considerations
- **Token Leakage:** Redact tokens in `Debug` and `Display` impls.
- **No Request Size Limits:** Vulnerable to memory exhaustion.
## 7. Testability
- **Mocking Difficulty:** Tight coupling to concrete provider implementations.
- **Missing Integration Tests:** No E2E tests for streaming.
## 8. Summary of Critical Actions
### High Priority
1. Replace custom token bucket with `governor` crate.
2. Fix provider lookup O(n) scaling issue.
3. Implement proper error recovery with retries.
4. Add request size limits and timeout configurations.
### Medium Priority
1. Reduce string cloning in hot paths.
2. Implement cache invalidation for model configs.
3. Add connection pooling separation.
4. Improve streaming memory efficiency.
---
*Review conducted by: Senior Principal Engineer (Code Reviewer)*

48
TODO.md Normal file

@@ -0,0 +1,48 @@
# Migration TODO List
## Completed Tasks
- [x] Initial Go project setup
- [x] Database schema & migrations
- [x] Configuration loader (Viper)
- [x] Auth Middleware
- [x] Basic Provider implementations (OpenAI, Gemini, DeepSeek, Grok)
- [x] Streaming Support (SSE & Gemini custom streaming)
- [x] Move Rust files to `rust_backup`
- [x] Enhanced `helpers.go` for Multimodal & Tool Calling (OpenAI compatible)
- [x] Enhanced `server.go` for robust request conversion
- [x] Dashboard Management APIs (Clients, Tokens, Users, Providers)
- [x] Dashboard Analytics & Usage Summary
- [x] WebSocket for real-time dashboard updates
## Feature Parity Checklist (High Priority)
### OpenAI Provider
- [x] Tool Calling
- [x] Multimodal (Images) support
- [ ] Reasoning Content (CoT) support for `o1`, `o3` (need to ensure it's parsed in responses)
- [ ] Support for `/v1/responses` API (required for some gpt-5/o1 models)
### Gemini Provider
- [x] Tool Calling (mapping to Gemini format)
- [x] Multimodal (Images) support
- [x] Reasoning/Thought support
- [x] Handle Tool Response role in unified format
### DeepSeek Provider
- [x] Reasoning Content (CoT) support
- [x] Parameter sanitization for `deepseek-reasoner`
- [x] Tool Calling support
### Grok Provider
- [x] Tool Calling support
- [x] Multimodal support
## Infrastructure & Middleware
- [ ] Implement Request Logging to SQLite (asynchronous)
- [ ] Implement Rate Limiting (`golang.org/x/time/rate`)
- [ ] Implement Circuit Breaker (`github.com/sony/gobreaker`)
- [ ] Implement Model Cost Calculation logic
## Verification
- [ ] Unit tests for feature-specific mapping (CoT, Tools, Images)
- [ ] Integration tests with live LLM APIs


@@ -1 +0,0 @@
too-many-arguments-threshold = 8

39
cmd/llm-proxy/main.go Normal file

@@ -0,0 +1,39 @@
package main

import (
	"log"

	"llm-proxy/internal/config"
	"llm-proxy/internal/db"
	"llm-proxy/internal/server"

	"github.com/joho/godotenv"
)

func main() {
	// Load environment variables
	if err := godotenv.Load(); err != nil {
		log.Println("No .env file found")
	}

	// Load configuration
	cfg, err := config.Load()
	if err != nil {
		log.Fatalf("Failed to load configuration: %v", err)
	}

	// Initialize database
	database, err := db.Init(cfg.Database.Path)
	if err != nil {
		log.Fatalf("Failed to initialize database: %v", err)
	}

	// Initialize server
	s := server.NewServer(cfg, database)

	// Run server
	log.Printf("Starting LLM Proxy on %s:%d", cfg.Server.Host, cfg.Server.Port)
	if err := s.Run(); err != nil {
		log.Fatalf("Server failed: %v", err)
	}
}


@@ -1,322 +1,52 @@
# Deployment Guide (Go)
## Overview
This guide covers deploying the Go-based LLM Proxy Gateway: a unified proxy supporting OpenAI, Google Gemini, DeepSeek, and xAI Grok with token tracking, cost calculation, and an admin dashboard.
## System Requirements
- **CPU**: 2 cores minimum
- **RAM**: 512MB minimum (1GB recommended)
- **Storage**: 10GB minimum
- **OS**: Linux (tested on Arch Linux, Ubuntu, Debian)
- **Runtime**: Go 1.22+
## Environment Setup
## Deployment Options
### Option 1: Docker (Recommended)
```dockerfile
FROM golang:1.22-alpine AS builder
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN CGO_ENABLED=0 GOOS=linux go build -o llm-proxy ./cmd/llm-proxy

FROM alpine:latest
RUN apk --no-cache add ca-certificates
WORKDIR /app
COPY --from=builder /app/llm-proxy .
COPY --from=builder /app/static ./static
EXPOSE 8080
CMD ["./llm-proxy"]
```
### Option 2: Systemd Service (Bare Metal/LXC)
```ini
# /etc/systemd/system/llm-proxy.service
[Unit]
Description=LLM Proxy Gateway
After=network.target
[Service]
Type=simple
User=llmproxy
Group=llmproxy
WorkingDirectory=/opt/llm-proxy
ExecStart=/opt/llm-proxy/llm-proxy
Restart=always
RestartSec=10
Environment="LLM_PROXY__SERVER__PORT=8080"
Environment="LLM_PROXY__SERVER__AUTH_TOKENS=sk-test-123,sk-test-456"
[Install]
WantedBy=multi-user.target
```
### Option 3: LXC Container (Proxmox)
1. Create Alpine Linux LXC container
2. Install Go: `apk add go`
3. Copy application files
4. Build: `go build -o llm-proxy ./cmd/llm-proxy`
5. Run: `./llm-proxy`
## Configuration
### Environment Variables
```bash
# Required API Keys
OPENAI_API_KEY=sk-...
GEMINI_API_KEY=AIza...
DEEPSEEK_API_KEY=sk-...
GROK_API_KEY=gk-... # Optional
# Server Configuration (with LLM_PROXY__ prefix)
LLM_PROXY__SERVER__PORT=8080
LLM_PROXY__SERVER__HOST=0.0.0.0
LLM_PROXY__SERVER__AUTH_TOKENS=sk-test-123,sk-test-456
# Database Configuration
LLM_PROXY__DATABASE__PATH=./data/llm_proxy.db
LLM_PROXY__DATABASE__MAX_CONNECTIONS=10
# Provider Configuration
LLM_PROXY__PROVIDERS__OPENAI__ENABLED=true
LLM_PROXY__PROVIDERS__GEMINI__ENABLED=true
LLM_PROXY__PROVIDERS__DEEPSEEK__ENABLED=true
LLM_PROXY__PROVIDERS__GROK__ENABLED=false
```
### Configuration File (config.toml)
Create `config.toml` in the application directory:
```toml
[server]
port = 8080
host = "0.0.0.0"
auth_tokens = ["sk-test-123", "sk-test-456"]
[database]
path = "./data/llm_proxy.db"
max_connections = 10
[providers.openai]
enabled = true
base_url = "https://api.openai.com/v1"
default_model = "gpt-4o"
[providers.gemini]
enabled = true
base_url = "https://generativelanguage.googleapis.com/v1"
default_model = "gemini-2.0-flash"
[providers.deepseek]
enabled = true
base_url = "https://api.deepseek.com"
default_model = "deepseek-reasoner"
[providers.grok]
enabled = false
base_url = "https://api.x.ai/v1"
default_model = "grok-beta"
```
## Nginx Reverse Proxy Configuration
**Important for SSE/Streaming:** Disable buffering and configure timeouts for proper SSE support.
```nginx
server {
listen 80;
server_name llm-proxy.yourdomain.com;
location / {
proxy_pass http://localhost:8080;
proxy_http_version 1.1;
# SSE/Streaming support
proxy_buffering off;
chunked_transfer_encoding on;
proxy_set_header Connection '';
# Timeouts for long-running streams
proxy_connect_timeout 7200s;
proxy_read_timeout 7200s;
proxy_send_timeout 7200s;
# Disable gzip for streaming
gzip off;
# Headers
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# SSL configuration (recommended)
listen 443 ssl http2;
ssl_certificate /etc/letsencrypt/live/llm-proxy.yourdomain.com/fullchain.pem;
ssl_certificate_key /etc/letsencrypt/live/llm-proxy.yourdomain.com/privkey.pem;
}
```
### NGINX Proxy Manager
If using NGINX Proxy Manager, add this to **Advanced Settings**:
```nginx
proxy_buffering off;
proxy_http_version 1.1;
proxy_set_header Connection '';
chunked_transfer_encoding on;
proxy_connect_timeout 7200s;
proxy_read_timeout 7200s;
proxy_send_timeout 7200s;
gzip off;
```
## Security Considerations
### 1. Authentication
- Use strong Bearer tokens
- Rotate tokens regularly
- Consider implementing JWT for production
### 2. Rate Limiting
- Implement per-client rate limiting
- Consider using `golang.org/x/time/rate` for advanced rate limiting
### 3. Network Security
- Run behind reverse proxy (nginx)
- Enable HTTPS
- Restrict access by IP if needed
- Use firewall rules
### 4. Data Security
- Database encryption (SQLCipher for SQLite)
- Secure API key storage
- Regular backups
## Monitoring & Maintenance
### Logging
- Application logs: written to stdout/stderr; capture via the systemd journal or Docker logs
- Access logs via nginx
- Database logs for audit trail
### Health Checks
```bash
# Health endpoint
curl http://localhost:8080/health
# Database check
sqlite3 ./data/llm_proxy.db "SELECT COUNT(*) FROM llm_requests;"
```
### Backup Strategy
```bash
#!/bin/bash
# backup.sh
BACKUP_DIR="/backups/llm-proxy"
DATE=$(date +%Y%m%d_%H%M%S)
# Backup database
sqlite3 ./data/llm_proxy.db ".backup $BACKUP_DIR/llm_proxy_$DATE.db"
# Backup configuration
cp config.toml $BACKUP_DIR/config_$DATE.toml
# Rotate old backups (keep 30 days)
find $BACKUP_DIR -name "*.db" -mtime +30 -delete
find $BACKUP_DIR -name "*.toml" -mtime +30 -delete
```
## Performance Tuning
### Database Optimization
```sql
-- Run these SQL commands periodically
VACUUM;
ANALYZE;
```
### Memory Management
- Monitor memory usage with `htop` or `ps aux`
- Adjust `max_connections` based on load
- Consider connection pooling for high traffic
### Scaling
1. **Vertical Scaling**: Increase container resources
2. **Horizontal Scaling**: Deploy multiple instances behind load balancer
3. **Database**: Migrate to PostgreSQL for high-volume usage
## Troubleshooting
### Common Issues
1. **Mandatory Configuration:**
   Create a `.env` file from the example:
   ```bash
   cp .env.example .env
   ```
   Ensure `LLM_PROXY__ENCRYPTION_KEY` is set to a secure 32-byte string.
2. **Data Directory:**
   The proxy stores its database in `./data/llm_proxy.db` by default. Ensure this directory exists and is writable.
3. **Port already in use**
   ```bash
   netstat -tulpn | grep :8080
   kill <PID>  # or change port in config
   ```
4. **Database permissions**
   ```bash
   chown -R llmproxy:llmproxy /opt/llm-proxy/data
   chmod 600 /opt/llm-proxy/data/llm_proxy.db
   ```
5. **API key errors**
   - Verify environment variables are set
   - Check provider status (dashboard)
   - Test connectivity: `curl https://api.openai.com/v1/models`
6. **High memory usage**
   - Check for memory leaks
   - Reduce `max_connections`
   - Implement connection timeouts
### Debug Mode
```bash
# Check system logs
journalctl -u llm-proxy -f
```
## Binary Deployment
### 1. Build
```bash
go build -o llm-proxy ./cmd/llm-proxy
```
### 2. Run
```bash
./llm-proxy
```
## Docker Deployment
The project includes a multi-stage `Dockerfile` for minimal image size.
### 1. Build Image
```bash
docker build -t llm-proxy .
```
### 2. Run Container
```bash
docker run -d \
  --name llm-proxy \
  -p 8080:8080 \
  -v $(pwd)/data:/app/data \
  --env-file .env \
  llm-proxy
```
## Integration
### Open-WebUI Compatibility
The proxy provides an OpenAI-compatible API, so configure Open-WebUI with:
```
API Base URL: http://your-proxy-address:8080
API Key: sk-test-123 (or your configured token)
```
### Custom Clients
```python
import openai

client = openai.OpenAI(
    base_url="http://localhost:8080/v1",
    api_key="sk-test-123"
)

response = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "Hello"}]
)
```
## Updates & Upgrades
1. **Backup** current configuration and database
2. **Stop** the service: `systemctl stop llm-proxy`
3. **Update** code: `git pull` or copy new binaries
4. **Migrate** database if needed (check migrations/)
5. **Restart**: `systemctl start llm-proxy`
6. **Verify**: Check logs and test endpoints
## Production Considerations
- **SSL/TLS:** It is recommended to run the proxy behind a reverse proxy like Nginx or Caddy for SSL termination.
- **Backups:** Regularly back up the `data/llm_proxy.db` file.
- **Monitoring:** Monitor the `/health` endpoint for system status.
## Support
- Monitor the dashboard at `http://your-server:8080`
- Review database metrics in the dashboard
- Check service logs: `journalctl -u llm-proxy -f`

62
go.mod Normal file

@@ -0,0 +1,62 @@
module llm-proxy
go 1.26.1

require (
	github.com/gin-gonic/gin v1.12.0
	github.com/go-resty/resty/v2 v2.17.2
	github.com/google/uuid v1.6.0
	github.com/gorilla/websocket v1.5.3
	github.com/jmoiron/sqlx v1.4.0
	github.com/joho/godotenv v1.5.1
	github.com/spf13/viper v1.21.0
	golang.org/x/crypto v0.48.0
	modernc.org/sqlite v1.47.0
)

require (
	github.com/bytedance/gopkg v0.1.3 // indirect
	github.com/bytedance/sonic v1.15.0 // indirect
	github.com/bytedance/sonic/loader v0.5.0 // indirect
	github.com/cloudwego/base64x v0.1.6 // indirect
	github.com/dustin/go-humanize v1.0.1 // indirect
	github.com/fsnotify/fsnotify v1.9.0 // indirect
	github.com/gabriel-vasile/mimetype v1.4.12 // indirect
	github.com/gin-contrib/sse v1.1.0 // indirect
	github.com/go-playground/locales v0.14.1 // indirect
	github.com/go-playground/universal-translator v0.18.1 // indirect
	github.com/go-playground/validator/v10 v10.30.1 // indirect
	github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
	github.com/goccy/go-json v0.10.5 // indirect
	github.com/goccy/go-yaml v1.19.2 // indirect
	github.com/json-iterator/go v1.1.12 // indirect
	github.com/klauspost/cpuid/v2 v2.3.0 // indirect
	github.com/leodido/go-urn v1.4.0 // indirect
	github.com/mattn/go-isatty v0.0.20 // indirect
	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
	github.com/modern-go/reflect2 v1.0.2 // indirect
	github.com/ncruces/go-strftime v1.0.0 // indirect
	github.com/pelletier/go-toml/v2 v2.2.4 // indirect
	github.com/quic-go/qpack v0.6.0 // indirect
	github.com/quic-go/quic-go v0.59.0 // indirect
	github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
	github.com/sagikazarmark/locafero v0.11.0 // indirect
	github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
	github.com/spf13/afero v1.15.0 // indirect
	github.com/spf13/cast v1.10.0 // indirect
	github.com/spf13/pflag v1.0.10 // indirect
	github.com/subosito/gotenv v1.6.0 // indirect
	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
	github.com/ugorji/go/codec v1.3.1 // indirect
	go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
	go.yaml.in/yaml/v3 v3.0.4 // indirect
	golang.org/x/arch v0.22.0 // indirect
	golang.org/x/net v0.51.0 // indirect
	golang.org/x/sys v0.42.0 // indirect
	golang.org/x/text v0.34.0 // indirect
	golang.org/x/time v0.15.0 // indirect
	google.golang.org/protobuf v1.36.10 // indirect
	modernc.org/libc v1.70.0 // indirect
	modernc.org/mathutil v1.7.1 // indirect
	modernc.org/memory v1.11.0 // indirect
)

183
go.sum Normal file

@@ -0,0 +1,183 @@
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8=
github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w=
github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM=
github.com/go-resty/resty/v2 v2.17.2 h1:FQW5oHYcIlkCNrMD2lloGScxcHJ0gkjshV3qcQAyHQk=
github.com/go-resty/resty/v2 v2.17.2/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY=
github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI=
golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.32.0 h1:hjG66bI/kqIPX1b2yT6fr/jt+QedtP2fqojG2VrFuVw=
modernc.org/ccgo/v4 v4.32.0/go.mod h1:6F08EBCx5uQc38kMGl+0Nm0oWczoo1c7cgpzEry7Uc0=
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw=
modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.47.0 h1:R1XyaNpoW4Et9yly+I2EeX7pBza/w+pmYee/0HJDyKk=
modernc.org/sqlite v1.47.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=

internal/config/config.go Normal file

@@ -0,0 +1,174 @@
package config
import (
"encoding/base64"
"encoding/hex"
"fmt"
"os"
"strings"
"github.com/spf13/viper"
)
type Config struct {
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Providers ProviderConfig `mapstructure:"providers"`
EncryptionKey string `mapstructure:"encryption_key"`
KeyBytes []byte
}
type ServerConfig struct {
Port int `mapstructure:"port"`
Host string `mapstructure:"host"`
AuthTokens []string `mapstructure:"auth_tokens"`
}
type DatabaseConfig struct {
Path string `mapstructure:"path"`
MaxConnections int `mapstructure:"max_connections"`
}
type ProviderConfig struct {
OpenAI OpenAIConfig `mapstructure:"openai"`
Gemini GeminiConfig `mapstructure:"gemini"`
DeepSeek DeepSeekConfig `mapstructure:"deepseek"`
Grok GrokConfig `mapstructure:"grok"`
Ollama OllamaConfig `mapstructure:"ollama"`
}
type OpenAIConfig struct {
APIKeyEnv string `mapstructure:"api_key_env"`
BaseURL string `mapstructure:"base_url"`
DefaultModel string `mapstructure:"default_model"`
Enabled bool `mapstructure:"enabled"`
}
type GeminiConfig struct {
APIKeyEnv string `mapstructure:"api_key_env"`
BaseURL string `mapstructure:"base_url"`
DefaultModel string `mapstructure:"default_model"`
Enabled bool `mapstructure:"enabled"`
}
type DeepSeekConfig struct {
APIKeyEnv string `mapstructure:"api_key_env"`
BaseURL string `mapstructure:"base_url"`
DefaultModel string `mapstructure:"default_model"`
Enabled bool `mapstructure:"enabled"`
}
type GrokConfig struct {
APIKeyEnv string `mapstructure:"api_key_env"`
BaseURL string `mapstructure:"base_url"`
DefaultModel string `mapstructure:"default_model"`
Enabled bool `mapstructure:"enabled"`
}
type OllamaConfig struct {
BaseURL string `mapstructure:"base_url"`
Enabled bool `mapstructure:"enabled"`
DefaultModel string `mapstructure:"default_model"`
Models []string `mapstructure:"models"`
}
func Load() (*Config, error) {
v := viper.New()
// Defaults
v.SetDefault("server.port", 8080)
v.SetDefault("server.host", "0.0.0.0")
v.SetDefault("server.auth_tokens", []string{})
v.SetDefault("database.path", "./data/llm_proxy.db")
v.SetDefault("database.max_connections", 10)
v.SetDefault("providers.openai.api_key_env", "OPENAI_API_KEY")
v.SetDefault("providers.openai.base_url", "https://api.openai.com/v1")
v.SetDefault("providers.openai.default_model", "gpt-4o")
v.SetDefault("providers.openai.enabled", true)
v.SetDefault("providers.gemini.api_key_env", "GEMINI_API_KEY")
v.SetDefault("providers.gemini.base_url", "https://generativelanguage.googleapis.com/v1")
v.SetDefault("providers.gemini.default_model", "gemini-2.0-flash")
v.SetDefault("providers.gemini.enabled", true)
v.SetDefault("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")
v.SetDefault("providers.deepseek.base_url", "https://api.deepseek.com")
v.SetDefault("providers.deepseek.default_model", "deepseek-reasoner")
v.SetDefault("providers.deepseek.enabled", true)
v.SetDefault("providers.grok.api_key_env", "GROK_API_KEY")
v.SetDefault("providers.grok.base_url", "https://api.x.ai/v1")
v.SetDefault("providers.grok.default_model", "grok-beta")
v.SetDefault("providers.grok.enabled", true)
v.SetDefault("providers.ollama.base_url", "http://localhost:11434/v1")
v.SetDefault("providers.ollama.enabled", false)
v.SetDefault("providers.ollama.models", []string{})
// Environment variables: keys map to LLM_PROXY__SECTION__KEY.
// Viper joins the prefix and key with a single "_", so the prefix needs a
// trailing underscore to produce the documented double-underscore form.
v.SetEnvPrefix("LLM_PROXY_")
v.SetEnvKeyReplacer(strings.NewReplacer(".", "__"))
// Register encryption_key so the env override is visible to Unmarshal
// (viper only resolves env overrides for keys it already knows about).
v.SetDefault("encryption_key", "")
v.AutomaticEnv()
// Config file
v.SetConfigName("config")
v.SetConfigType("toml")
v.AddConfigPath(".")
if envPath := os.Getenv("LLM_PROXY__CONFIG_PATH"); envPath != "" {
v.SetConfigFile(envPath)
}
if err := v.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
}
var cfg Config
if err := v.Unmarshal(&cfg); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
// Validate encryption key
if cfg.EncryptionKey == "" {
return nil, fmt.Errorf("encryption key is required (LLM_PROXY__ENCRYPTION_KEY)")
}
keyBytes, err := hex.DecodeString(cfg.EncryptionKey)
if err != nil {
keyBytes, err = base64.StdEncoding.DecodeString(cfg.EncryptionKey)
if err != nil {
return nil, fmt.Errorf("encryption key must be hex or base64 encoded")
}
}
if len(keyBytes) != 32 {
return nil, fmt.Errorf("encryption key must be 32 bytes, got %d", len(keyBytes))
}
cfg.KeyBytes = keyBytes
return &cfg, nil
}
func (c *Config) GetAPIKey(provider string) (string, error) {
var envVar string
switch provider {
case "openai":
envVar = c.Providers.OpenAI.APIKeyEnv
case "gemini":
envVar = c.Providers.Gemini.APIKeyEnv
case "deepseek":
envVar = c.Providers.DeepSeek.APIKeyEnv
case "grok":
envVar = c.Providers.Grok.APIKeyEnv
default:
return "", fmt.Errorf("unknown provider: %s", provider)
}
val := os.Getenv(envVar)
if val == "" {
return "", fmt.Errorf("environment variable %s not set for %s", envVar, provider)
}
return val, nil
}

internal/db/db.go Normal file

@@ -0,0 +1,264 @@
package db
import (
"fmt"
"log"
"os"
"path/filepath"
"time"
"github.com/jmoiron/sqlx"
_ "modernc.org/sqlite"
"golang.org/x/crypto/bcrypt"
)
type DB struct {
*sqlx.DB
}
func Init(path string) (*DB, error) {
// Ensure directory exists
dir := filepath.Dir(path)
if dir != "." && dir != "/" {
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create database directory: %w", err)
}
}
// Connect to SQLite
dsn := fmt.Sprintf("file:%s?_pragma=foreign_keys(1)", path)
db, err := sqlx.Connect("sqlite", dsn)
if err != nil {
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
instance := &DB{db}
// Run migrations
if err := instance.RunMigrations(); err != nil {
return nil, fmt.Errorf("failed to run migrations: %w", err)
}
return instance, nil
}
func (db *DB) RunMigrations() error {
// Tables creation
queries := []string{
`CREATE TABLE IF NOT EXISTS clients (
id INTEGER PRIMARY KEY AUTOINCREMENT,
client_id TEXT UNIQUE NOT NULL,
name TEXT,
description TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
is_active BOOLEAN DEFAULT TRUE,
rate_limit_per_minute INTEGER DEFAULT 60,
total_requests INTEGER DEFAULT 0,
total_tokens INTEGER DEFAULT 0,
total_cost REAL DEFAULT 0.0
)`,
`CREATE TABLE IF NOT EXISTS llm_requests (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
client_id TEXT,
provider TEXT,
model TEXT,
prompt_tokens INTEGER,
completion_tokens INTEGER,
reasoning_tokens INTEGER DEFAULT 0,
total_tokens INTEGER,
cost REAL,
has_images BOOLEAN DEFAULT FALSE,
status TEXT DEFAULT 'success',
error_message TEXT,
duration_ms INTEGER,
request_body TEXT,
response_body TEXT,
cache_read_tokens INTEGER DEFAULT 0,
cache_write_tokens INTEGER DEFAULT 0,
FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE SET NULL
)`,
`CREATE TABLE IF NOT EXISTS provider_configs (
id TEXT PRIMARY KEY,
display_name TEXT NOT NULL,
enabled BOOLEAN DEFAULT TRUE,
base_url TEXT,
api_key TEXT,
credit_balance REAL DEFAULT 0.0,
low_credit_threshold REAL DEFAULT 5.0,
billing_mode TEXT,
api_key_encrypted BOOLEAN DEFAULT FALSE,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`,
`CREATE TABLE IF NOT EXISTS model_configs (
id TEXT PRIMARY KEY,
provider_id TEXT NOT NULL,
display_name TEXT,
enabled BOOLEAN DEFAULT TRUE,
prompt_cost_per_m REAL,
completion_cost_per_m REAL,
cache_read_cost_per_m REAL,
cache_write_cost_per_m REAL,
mapping TEXT,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE
)`,
`CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL,
display_name TEXT,
role TEXT DEFAULT 'admin',
must_change_password BOOLEAN DEFAULT FALSE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`,
`CREATE TABLE IF NOT EXISTS client_tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
client_id TEXT NOT NULL,
token TEXT NOT NULL UNIQUE,
name TEXT DEFAULT 'default',
is_active BOOLEAN DEFAULT TRUE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
last_used_at DATETIME,
FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE CASCADE
)`,
}
for _, q := range queries {
if _, err := db.Exec(q); err != nil {
return fmt.Errorf("migration failed for query [%s]: %w", q, err)
}
}
// Add indices
indices := []string{
"CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)",
"CREATE INDEX IF NOT EXISTS idx_clients_created_at ON clients(created_at)",
"CREATE INDEX IF NOT EXISTS idx_llm_requests_timestamp ON llm_requests(timestamp)",
"CREATE INDEX IF NOT EXISTS idx_llm_requests_client_id ON llm_requests(client_id)",
"CREATE INDEX IF NOT EXISTS idx_llm_requests_provider ON llm_requests(provider)",
"CREATE INDEX IF NOT EXISTS idx_llm_requests_status ON llm_requests(status)",
"CREATE UNIQUE INDEX IF NOT EXISTS idx_client_tokens_token ON client_tokens(token)",
"CREATE INDEX IF NOT EXISTS idx_client_tokens_client_id ON client_tokens(client_id)",
"CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp)",
"CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp)",
"CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id)",
}
for _, idx := range indices {
if _, err := db.Exec(idx); err != nil {
return fmt.Errorf("failed to create index [%s]: %w", idx, err)
}
}
// Default admin user
var count int
if err := db.Get(&count, "SELECT COUNT(*) FROM users"); err != nil {
return fmt.Errorf("failed to count users: %w", err)
}
if count == 0 {
hash, err := bcrypt.GenerateFromPassword([]byte("admin"), 12)
if err != nil {
return fmt.Errorf("failed to hash default password: %w", err)
}
_, err = db.Exec("INSERT INTO users (username, password_hash, role, must_change_password) VALUES ('admin', ?, 'admin', 1)", string(hash))
if err != nil {
return fmt.Errorf("failed to insert default admin: %w", err)
}
log.Println("Created default admin user with password 'admin' (must change on first login)")
}
// Default client
_, err := db.Exec(`INSERT OR IGNORE INTO clients (client_id, name, description)
VALUES ('default', 'Default Client', 'Default client for anonymous requests')`)
if err != nil {
return fmt.Errorf("failed to insert default client: %w", err)
}
return nil
}
// Data models for DB tables
type Client struct {
ID int `db:"id"`
ClientID string `db:"client_id"`
Name *string `db:"name"`
Description *string `db:"description"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
IsActive bool `db:"is_active"`
RateLimitPerMinute int `db:"rate_limit_per_minute"`
TotalRequests int `db:"total_requests"`
TotalTokens int `db:"total_tokens"`
TotalCost float64 `db:"total_cost"`
}
type LLMRequest struct {
ID int `db:"id"`
Timestamp time.Time `db:"timestamp"`
ClientID *string `db:"client_id"`
Provider *string `db:"provider"`
Model *string `db:"model"`
PromptTokens *int `db:"prompt_tokens"`
CompletionTokens *int `db:"completion_tokens"`
ReasoningTokens int `db:"reasoning_tokens"`
TotalTokens *int `db:"total_tokens"`
Cost *float64 `db:"cost"`
HasImages bool `db:"has_images"`
Status string `db:"status"`
ErrorMessage *string `db:"error_message"`
DurationMS *int `db:"duration_ms"`
RequestBody *string `db:"request_body"`
ResponseBody *string `db:"response_body"`
CacheReadTokens int `db:"cache_read_tokens"`
CacheWriteTokens int `db:"cache_write_tokens"`
}
type ProviderConfig struct {
ID string `db:"id"`
DisplayName string `db:"display_name"`
Enabled bool `db:"enabled"`
BaseURL *string `db:"base_url"`
APIKey *string `db:"api_key"`
CreditBalance float64 `db:"credit_balance"`
LowCreditThreshold float64 `db:"low_credit_threshold"`
BillingMode *string `db:"billing_mode"`
APIKeyEncrypted bool `db:"api_key_encrypted"`
UpdatedAt time.Time `db:"updated_at"`
}
type ModelConfig struct {
ID string `db:"id"`
ProviderID string `db:"provider_id"`
DisplayName *string `db:"display_name"`
Enabled bool `db:"enabled"`
PromptCostPerM *float64 `db:"prompt_cost_per_m"`
CompletionCostPerM *float64 `db:"completion_cost_per_m"`
CacheReadCostPerM *float64 `db:"cache_read_cost_per_m"`
CacheWriteCostPerM *float64 `db:"cache_write_cost_per_m"`
Mapping *string `db:"mapping"`
UpdatedAt time.Time `db:"updated_at"`
}
type User struct {
ID int `db:"id"`
Username string `db:"username"`
PasswordHash string `db:"password_hash"`
DisplayName *string `db:"display_name"`
Role string `db:"role"`
MustChangePassword bool `db:"must_change_password"`
CreatedAt time.Time `db:"created_at"`
}
type ClientToken struct {
ID int `db:"id"`
ClientID string `db:"client_id"`
Token string `db:"token"`
Name string `db:"name"`
IsActive bool `db:"is_active"`
CreatedAt time.Time `db:"created_at"`
LastUsedAt *time.Time `db:"last_used_at"`
}
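Init() above enables foreign keys through a `_pragma` query parameter in the modernc.org/sqlite DSN. A minimal, stdlib-only sketch of that DSN construction (the helper name is illustrative):

```go
package main

import "fmt"

// buildDSN mirrors Init's connection string: a file: URI with foreign-key
// enforcement enabled via the modernc.org/sqlite _pragma query parameter.
func buildDSN(path string) string {
	return fmt.Sprintf("file:%s?_pragma=foreign_keys(1)", path)
}

func main() {
	fmt.Println(buildDSN("./data/llm_proxy.db"))
	// -> file:./data/llm_proxy.db?_pragma=foreign_keys(1)
}
```

Enabling foreign keys at connect time matters here because llm_requests, model_configs, and client_tokens all declare FOREIGN KEY constraints with ON DELETE actions.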


@@ -0,0 +1,52 @@
package middleware
import (
"log"
"strings"
"llm-proxy/internal/db"
"llm-proxy/internal/models"
"github.com/gin-gonic/gin"
)
func AuthMiddleware(database *db.DB) gin.HandlerFunc {
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.Next()
return
}
token := strings.TrimPrefix(authHeader, "Bearer ")
if token == authHeader { // No "Bearer " prefix
c.Next()
return
}
// Try to resolve client from database
var clientID string
err := database.Get(&clientID, "UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = 1 RETURNING client_id", token)
if err == nil {
c.Set("auth", models.AuthInfo{
Token: token,
ClientID: clientID,
})
} else {
// Fallback to token-prefix derivation (matches Rust behavior)
prefixLen := len(token)
if prefixLen > 8 {
prefixLen = 8
}
clientID = "client_" + token[:prefixLen]
c.Set("auth", models.AuthInfo{
Token: token,
ClientID: clientID,
})
log.Printf("Token not found in DB, using fallback client ID: %s", clientID)
}
c.Next()
}
}
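When a bearer token is not found in client_tokens, the middleware above derives a deterministic client ID from the token prefix. A standalone sketch of that derivation (the function name is illustrative):

```go
package main

import "fmt"

// fallbackClientID mirrors the middleware fallback: "client_" plus the
// first eight characters of the bearer token (the whole token if shorter).
func fallbackClientID(token string) string {
	prefixLen := len(token)
	if prefixLen > 8 {
		prefixLen = 8
	}
	return "client_" + token[:prefixLen]
}

func main() {
	fmt.Println(fallbackClientID("sk-abcdef123456")) // client_sk-abcde
	fmt.Println(fallbackClientID("tok"))             // client_tok
}
```

Because the same token always maps to the same derived ID, usage rows logged under the fallback stay attributable even before the token is registered in the database.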

internal/models/models.go Normal file

@@ -0,0 +1,216 @@
package models
import (
"encoding/base64"
"encoding/json"
"fmt"
"github.com/go-resty/resty/v2"
)
// OpenAI-compatible Request/Response Structs
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK *uint32 `json:"top_k,omitempty"`
N *uint32 `json:"n,omitempty"`
Stop json.RawMessage `json:"stop,omitempty"` // Can be string or array of strings
MaxTokens *uint32 `json:"max_tokens,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
Stream *bool `json:"stream,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
}
type ChatMessage struct {
Role string `json:"role"` // "system", "user", "assistant", "tool"
Content interface{} `json:"content"`
ReasoningContent *string `json:"reasoning_content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Name *string `json:"name,omitempty"`
ToolCallID *string `json:"tool_call_id,omitempty"`
}
type ContentPart struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ImageUrl *ImageUrl `json:"image_url,omitempty"`
}
type ImageUrl struct {
URL string `json:"url"`
Detail *string `json:"detail,omitempty"`
}
// Tool-Calling Types
type Tool struct {
Type string `json:"type"`
Function FunctionDef `json:"function"`
}
type FunctionDef struct {
Name string `json:"name"`
Description *string `json:"description,omitempty"`
Parameters json.RawMessage `json:"parameters,omitempty"`
}
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type"`
Function FunctionCall `json:"function"`
}
type FunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
type ToolCallDelta struct {
Index uint32 `json:"index"`
ID *string `json:"id,omitempty"`
Type *string `json:"type,omitempty"`
Function *FunctionCallDelta `json:"function,omitempty"`
}
type FunctionCallDelta struct {
Name *string `json:"name,omitempty"`
Arguments *string `json:"arguments,omitempty"`
}
// OpenAI-compatible Response Structs
type ChatCompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatChoice `json:"choices"`
Usage *Usage `json:"usage,omitempty"`
}
type ChatChoice struct {
Index uint32 `json:"index"`
Message ChatMessage `json:"message"`
FinishReason *string `json:"finish_reason,omitempty"`
}
type Usage struct {
PromptTokens uint32 `json:"prompt_tokens"`
CompletionTokens uint32 `json:"completion_tokens"`
TotalTokens uint32 `json:"total_tokens"`
ReasoningTokens *uint32 `json:"reasoning_tokens,omitempty"`
CacheReadTokens *uint32 `json:"cache_read_tokens,omitempty"`
CacheWriteTokens *uint32 `json:"cache_write_tokens,omitempty"`
}
// Streaming Response Structs
type ChatCompletionStreamResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatStreamChoice `json:"choices"`
Usage *Usage `json:"usage,omitempty"`
}
type ChatStreamChoice struct {
Index uint32 `json:"index"`
Delta ChatStreamDelta `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
}
type ChatStreamDelta struct {
Role *string `json:"role,omitempty"`
Content *string `json:"content,omitempty"`
ReasoningContent *string `json:"reasoning_content,omitempty"`
ToolCalls []ToolCallDelta `json:"tool_calls,omitempty"`
}
type StreamUsage struct {
PromptTokens uint32 `json:"prompt_tokens"`
CompletionTokens uint32 `json:"completion_tokens"`
TotalTokens uint32 `json:"total_tokens"`
ReasoningTokens uint32 `json:"reasoning_tokens"`
CacheReadTokens uint32 `json:"cache_read_tokens"`
CacheWriteTokens uint32 `json:"cache_write_tokens"`
}
// Unified Request Format (for internal use)
type UnifiedRequest struct {
ClientID string
Model string
Messages []UnifiedMessage
Temperature *float64
TopP *float64
TopK *uint32
N *uint32
Stop []string
MaxTokens *uint32
PresencePenalty *float64
FrequencyPenalty *float64
Stream bool
HasImages bool
Tools []Tool
ToolChoice json.RawMessage
}
type UnifiedMessage struct {
Role string
Content []UnifiedContentPart
ReasoningContent *string
ToolCalls []ToolCall
Name *string
ToolCallID *string
}
type UnifiedContentPart struct {
Type string
Text string
Image *ImageInput
}
type ImageInput struct {
Base64 string `json:"base64,omitempty"`
URL string `json:"url,omitempty"`
MimeType string `json:"mime_type,omitempty"`
}
func (i *ImageInput) ToBase64() (string, string, error) {
if i.Base64 != "" {
return i.Base64, i.MimeType, nil
}
if i.URL != "" {
client := resty.New()
resp, err := client.R().Get(i.URL)
if err != nil {
return "", "", fmt.Errorf("failed to fetch image: %w", err)
}
if !resp.IsSuccess() {
return "", "", fmt.Errorf("failed to fetch image: HTTP %d", resp.StatusCode())
}
mimeType := resp.Header().Get("Content-Type")
if mimeType == "" {
mimeType = "image/jpeg"
}
encoded := base64.StdEncoding.EncodeToString(resp.Body())
return encoded, mimeType, nil
}
return "", "", fmt.Errorf("empty image input")
}
// AuthInfo for context
type AuthInfo struct {
Token string
ClientID string
}


@@ -0,0 +1,143 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"llm-proxy/internal/config"
"llm-proxy/internal/models"
"github.com/go-resty/resty/v2"
)
type DeepSeekProvider struct {
client *resty.Client
config config.DeepSeekConfig
apiKey string
}
func NewDeepSeekProvider(cfg config.DeepSeekConfig, apiKey string) *DeepSeekProvider {
return &DeepSeekProvider{
client: resty.New(),
config: cfg,
apiKey: apiKey,
}
}
func (p *DeepSeekProvider) Name() string {
return "deepseek"
}
func (p *DeepSeekProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
}
body := BuildOpenAIBody(req, messagesJSON, false)
// Sanitize for deepseek-reasoner
if req.Model == "deepseek-reasoner" {
delete(body, "temperature")
delete(body, "top_p")
delete(body, "presence_penalty")
delete(body, "frequency_penalty")
// Ensure assistant messages have content and reasoning_content
if msgs, ok := body["messages"].([]interface{}); ok {
for _, m := range msgs {
if msg, ok := m.(map[string]interface{}); ok {
if msg["role"] == "assistant" {
if msg["reasoning_content"] == nil {
msg["reasoning_content"] = " "
}
if msg["content"] == nil || msg["content"] == "" {
msg["content"] = ""
}
}
}
}
}
}
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetBody(body).
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
}
var respJSON map[string]interface{}
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return ParseOpenAIResponse(respJSON, req.Model)
}
func (p *DeepSeekProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
}
body := BuildOpenAIBody(req, messagesJSON, true)
// Sanitize for deepseek-reasoner
if req.Model == "deepseek-reasoner" {
delete(body, "temperature")
delete(body, "top_p")
delete(body, "presence_penalty")
delete(body, "frequency_penalty")
// Ensure assistant messages have content and reasoning_content
if msgs, ok := body["messages"].([]interface{}); ok {
for _, m := range msgs {
if msg, ok := m.(map[string]interface{}); ok {
if msg["role"] == "assistant" {
if msg["reasoning_content"] == nil {
msg["reasoning_content"] = " "
}
if msg["content"] == nil || msg["content"] == "" {
msg["content"] = ""
}
}
}
}
}
}
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetBody(body).
SetDoNotParseResponse(true).
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("DeepSeek API error (%d): %s", resp.StatusCode(), resp.String())
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
err := StreamOpenAI(resp.RawBody(), ch)
if err != nil {
fmt.Printf("DeepSeek Stream error: %v\n", err)
}
}()
return ch, nil
}


@@ -0,0 +1,254 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"llm-proxy/internal/config"
"llm-proxy/internal/models"
"github.com/go-resty/resty/v2"
)
type GeminiProvider struct {
client *resty.Client
config config.GeminiConfig
apiKey string
}
func NewGeminiProvider(cfg config.GeminiConfig, apiKey string) *GeminiProvider {
return &GeminiProvider{
client: resty.New(),
config: cfg,
apiKey: apiKey,
}
}
func (p *GeminiProvider) Name() string {
return "gemini"
}
type GeminiRequest struct {
Contents []GeminiContent `json:"contents"`
}
type GeminiContent struct {
Role string `json:"role,omitempty"`
Parts []GeminiPart `json:"parts"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *GeminiFunctionCall `json:"functionCall,omitempty"`
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
}
type GeminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
type GeminiFunctionCall struct {
Name string `json:"name"`
Args json.RawMessage `json:"args"`
}
type GeminiFunctionResponse struct {
Name string `json:"name"`
Response json.RawMessage `json:"response"`
}
func (p *GeminiProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
// Gemini mapping
var contents []GeminiContent
for _, msg := range req.Messages {
role := "user"
if msg.Role == "assistant" {
role = "model"
} else if msg.Role == "tool" {
role = "user" // Tool results are user-side in Gemini
}
var parts []GeminiPart
// Handle tool responses
if msg.Role == "tool" {
text := ""
if len(msg.Content) > 0 {
text = msg.Content[0].Text
}
// Gemini expects functionResponse.response to be a JSON object;
// wrap plain-text tool output so the request body stays valid JSON.
name := "unknown_function"
if msg.Name != nil {
name = *msg.Name
}
response := json.RawMessage(text)
if !json.Valid(response) {
wrapped, _ := json.Marshal(map[string]string{"result": text})
response = json.RawMessage(wrapped)
}
parts = append(parts, GeminiPart{
FunctionResponse: &GeminiFunctionResponse{
Name: name,
Response: response,
},
})
} else {
for _, cp := range msg.Content {
if cp.Type == "text" {
parts = append(parts, GeminiPart{Text: cp.Text})
} else if cp.Image != nil {
base64Data, mimeType, err := cp.Image.ToBase64()
if err != nil {
return nil, fmt.Errorf("failed to convert image: %w", err)
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: base64Data,
},
})
}
}
// Handle assistant tool calls
if msg.Role == "assistant" && len(msg.ToolCalls) > 0 {
for _, tc := range msg.ToolCalls {
parts = append(parts, GeminiPart{
FunctionCall: &GeminiFunctionCall{
Name: tc.Function.Name,
Args: json.RawMessage(tc.Function.Arguments),
},
})
}
}
}
contents = append(contents, GeminiContent{
Role: role,
Parts: parts,
})
}
body := GeminiRequest{
Contents: contents,
}
url := fmt.Sprintf("%s/models/%s:generateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)
resp, err := p.client.R().
SetContext(ctx).
SetBody(body).
Post(url)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
}
// Parse Gemini response and convert to OpenAI format
var geminiResp struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
} `json:"content"`
FinishReason string `json:"finishReason"`
} `json:"candidates"`
UsageMetadata struct {
PromptTokenCount uint32 `json:"promptTokenCount"`
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
TotalTokenCount uint32 `json:"totalTokenCount"`
} `json:"usageMetadata"`
}
if err := json.Unmarshal(resp.Body(), &geminiResp); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
if len(geminiResp.Candidates) == 0 {
return nil, fmt.Errorf("no candidates in Gemini response")
}
content := ""
for _, p := range geminiResp.Candidates[0].Content.Parts {
content += p.Text
}
openAIResp := &models.ChatCompletionResponse{
ID: "gemini-" + req.Model,
Object: "chat.completion",
Created: 0, // Should be current timestamp
Model: req.Model,
Choices: []models.ChatChoice{
{
Index: 0,
Message: models.ChatMessage{
Role: "assistant",
Content: content,
},
FinishReason: &geminiResp.Candidates[0].FinishReason,
},
},
Usage: &models.Usage{
PromptTokens: geminiResp.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiResp.UsageMetadata.CandidatesTokenCount,
TotalTokens: geminiResp.UsageMetadata.TotalTokenCount,
},
}
return openAIResp, nil
}
func (p *GeminiProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
// Simplified Gemini mapping
var contents []GeminiContent
for _, msg := range req.Messages {
role := "user"
if msg.Role == "assistant" {
role = "model"
}
var parts []GeminiPart
for _, p := range msg.Content {
parts = append(parts, GeminiPart{Text: p.Text})
}
contents = append(contents, GeminiContent{
Role: role,
Parts: parts,
})
}
body := GeminiRequest{
Contents: contents,
}
// Use streamGenerateContent for streaming
url := fmt.Sprintf("%s/models/%s:streamGenerateContent?key=%s", p.config.BaseURL, req.Model, p.apiKey)
resp, err := p.client.R().
SetContext(ctx).
SetBody(body).
SetDoNotParseResponse(true).
Post(url)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Gemini API error (%d): %s", resp.StatusCode(), resp.String())
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
err := StreamGemini(resp.RawBody(), ch, req.Model)
if err != nil {
fmt.Printf("Gemini Stream error: %v\n", err)
}
}()
return ch, nil
}


@@ -0,0 +1,95 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"llm-proxy/internal/config"
"llm-proxy/internal/models"
"github.com/go-resty/resty/v2"
)
type GrokProvider struct {
client *resty.Client
config config.GrokConfig
apiKey string
}
func NewGrokProvider(cfg config.GrokConfig, apiKey string) *GrokProvider {
return &GrokProvider{
client: resty.New(),
config: cfg,
apiKey: apiKey,
}
}
func (p *GrokProvider) Name() string {
return "grok"
}
func (p *GrokProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
}
body := BuildOpenAIBody(req, messagesJSON, false)
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetBody(body).
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
}
var respJSON map[string]interface{}
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return ParseOpenAIResponse(respJSON, req.Model)
}
func (p *GrokProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
}
body := BuildOpenAIBody(req, messagesJSON, true)
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetBody(body).
SetDoNotParseResponse(true).
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("Grok API error (%d): %s", resp.StatusCode(), resp.String())
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
err := StreamOpenAI(resp.RawBody(), ch)
if err != nil {
fmt.Printf("Grok Stream error: %v\n", err)
}
}()
return ch, nil
}


@@ -0,0 +1,259 @@
package providers
import (
"bufio"
"encoding/json"
"fmt"
"io"
"strings"
"llm-proxy/internal/models"
)
// MessagesToOpenAIJSON converts unified messages to OpenAI-compatible JSON, including tools and images.
func MessagesToOpenAIJSON(messages []models.UnifiedMessage) ([]interface{}, error) {
var result []interface{}
for _, m := range messages {
if m.Role == "tool" {
text := ""
if len(m.Content) > 0 {
text = m.Content[0].Text
}
msg := map[string]interface{}{
"role": "tool",
"content": text,
}
if m.ToolCallID != nil {
id := *m.ToolCallID
if len(id) > 40 {
id = id[:40]
}
msg["tool_call_id"] = id
}
if m.Name != nil {
msg["name"] = *m.Name
}
result = append(result, msg)
continue
}
var parts []interface{}
for _, p := range m.Content {
if p.Type == "text" {
parts = append(parts, map[string]interface{}{
"type": "text",
"text": p.Text,
})
} else if p.Image != nil {
base64Data, mimeType, err := p.Image.ToBase64()
if err != nil {
return nil, fmt.Errorf("failed to convert image to base64: %w", err)
}
parts = append(parts, map[string]interface{}{
"type": "image_url",
"image_url": map[string]interface{}{
"url": fmt.Sprintf("data:%s;base64,%s", mimeType, base64Data),
},
})
}
}
msg := map[string]interface{}{
"role": m.Role,
"content": parts,
}
if m.ReasoningContent != nil {
msg["reasoning_content"] = *m.ReasoningContent
}
if len(m.ToolCalls) > 0 {
sanitizedCalls := make([]models.ToolCall, len(m.ToolCalls))
copy(sanitizedCalls, m.ToolCalls)
for i := range sanitizedCalls {
if len(sanitizedCalls[i].ID) > 40 {
sanitizedCalls[i].ID = sanitizedCalls[i].ID[:40]
}
}
msg["tool_calls"] = sanitizedCalls
if len(parts) == 0 {
msg["content"] = ""
}
}
if m.Name != nil {
msg["name"] = *m.Name
}
result = append(result, msg)
}
return result, nil
}
func BuildOpenAIBody(request *models.UnifiedRequest, messagesJSON []interface{}, stream bool) map[string]interface{} {
body := map[string]interface{}{
"model": request.Model,
"messages": messagesJSON,
"stream": stream,
}
if stream {
body["stream_options"] = map[string]interface{}{
"include_usage": true,
}
}
if request.Temperature != nil {
body["temperature"] = *request.Temperature
}
if request.MaxTokens != nil {
body["max_tokens"] = *request.MaxTokens
}
if len(request.Tools) > 0 {
body["tools"] = request.Tools
}
if request.ToolChoice != nil {
var toolChoice interface{}
if err := json.Unmarshal(request.ToolChoice, &toolChoice); err == nil {
body["tool_choice"] = toolChoice
}
}
return body
}
func ParseOpenAIResponse(respJSON map[string]interface{}, model string) (*models.ChatCompletionResponse, error) {
data, err := json.Marshal(respJSON)
if err != nil {
return nil, err
}
var resp models.ChatCompletionResponse
if err := json.Unmarshal(data, &resp); err != nil {
return nil, err
}
// Some providers omit the model field; fall back to the requested one.
if resp.Model == "" {
resp.Model = model
}
return &resp, nil
}
// Streaming support
func ParseOpenAIStreamChunk(line string) (*models.ChatCompletionStreamResponse, bool, error) {
if line == "" {
return nil, false, nil
}
if !strings.HasPrefix(line, "data: ") {
return nil, false, nil
}
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
return nil, true, nil
}
var chunk models.ChatCompletionStreamResponse
if err := json.Unmarshal([]byte(data), &chunk); err != nil {
return nil, false, fmt.Errorf("failed to unmarshal stream chunk: %w", err)
}
return &chunk, false, nil
}
func StreamOpenAI(body io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse) error {
defer body.Close()
scanner := bufio.NewScanner(body)
// The default 64 KiB line limit can truncate large SSE chunks; raise it.
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
line := scanner.Text()
chunk, done, err := ParseOpenAIStreamChunk(line)
if err != nil {
return err
}
if done {
break
}
if chunk != nil {
ch <- chunk
}
}
return scanner.Err()
}
func StreamGemini(body io.ReadCloser, ch chan<- *models.ChatCompletionStreamResponse, model string) error {
defer body.Close()
dec := json.NewDecoder(body)
t, err := dec.Token()
if err != nil {
return err
}
if delim, ok := t.(json.Delim); ok && delim == '[' {
for dec.More() {
var geminiChunk struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text,omitempty"`
Thought string `json:"thought,omitempty"`
} `json:"parts"`
} `json:"content"`
FinishReason string `json:"finishReason"`
} `json:"candidates"`
UsageMetadata struct {
PromptTokenCount uint32 `json:"promptTokenCount"`
CandidatesTokenCount uint32 `json:"candidatesTokenCount"`
TotalTokenCount uint32 `json:"totalTokenCount"`
} `json:"usageMetadata"`
}
if err := dec.Decode(&geminiChunk); err != nil {
return err
}
if len(geminiChunk.Candidates) > 0 {
content := ""
var reasoning *string
for _, p := range geminiChunk.Candidates[0].Content.Parts {
if p.Text != "" {
content += p.Text
}
if p.Thought != "" {
if reasoning == nil {
reasoning = new(string)
}
*reasoning += p.Thought
}
}
// Gemini only sets finishReason on the final chunk; map it to the
// OpenAI-style value and leave it nil on intermediate chunks.
var finishReason *string
if fr := geminiChunk.Candidates[0].FinishReason; fr != "" {
mapped := strings.ToLower(fr)
if mapped == "max_tokens" {
mapped = "length"
}
finishReason = &mapped
}
ch <- &models.ChatCompletionStreamResponse{
ID: "gemini-stream",
Object: "chat.completion.chunk",
Created: 0,
Model: model,
Choices: []models.ChatStreamChoice{
{
Index: 0,
Delta: models.ChatStreamDelta{
Content: &content,
ReasoningContent: reasoning,
},
FinishReason: finishReason,
},
},
Usage: &models.Usage{
PromptTokens: geminiChunk.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiChunk.UsageMetadata.CandidatesTokenCount,
TotalTokens: geminiChunk.UsageMetadata.TotalTokenCount,
},
}
}
}
}
return nil
}


@@ -0,0 +1,113 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"strings"
"llm-proxy/internal/config"
"llm-proxy/internal/models"
"github.com/go-resty/resty/v2"
)
type OpenAIProvider struct {
client *resty.Client
config config.OpenAIConfig
apiKey string
}
func NewOpenAIProvider(cfg config.OpenAIConfig, apiKey string) *OpenAIProvider {
return &OpenAIProvider{
client: resty.New(),
config: cfg,
apiKey: apiKey,
}
}
func (p *OpenAIProvider) Name() string {
return "openai"
}
func (p *OpenAIProvider) ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
}
body := BuildOpenAIBody(req, messagesJSON, false)
// Transition: Newer models require max_completion_tokens
if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
if maxTokens, ok := body["max_tokens"]; ok {
delete(body, "max_tokens")
body["max_completion_tokens"] = maxTokens
}
}
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetBody(body).
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
}
var respJSON map[string]interface{}
if err := json.Unmarshal(resp.Body(), &respJSON); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return ParseOpenAIResponse(respJSON, req.Model)
}
func (p *OpenAIProvider) ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error) {
messagesJSON, err := MessagesToOpenAIJSON(req.Messages)
if err != nil {
return nil, fmt.Errorf("failed to convert messages: %w", err)
}
body := BuildOpenAIBody(req, messagesJSON, true)
// Transition: Newer models require max_completion_tokens
if strings.HasPrefix(req.Model, "o1-") || strings.HasPrefix(req.Model, "o3-") || strings.Contains(req.Model, "gpt-5") {
if maxTokens, ok := body["max_tokens"]; ok {
delete(body, "max_tokens")
body["max_completion_tokens"] = maxTokens
}
}
resp, err := p.client.R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+p.apiKey).
SetBody(body).
SetDoNotParseResponse(true).
Post(fmt.Sprintf("%s/chat/completions", p.config.BaseURL))
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
if !resp.IsSuccess() {
return nil, fmt.Errorf("OpenAI API error (%d): %s", resp.StatusCode(), resp.String())
}
ch := make(chan *models.ChatCompletionStreamResponse)
go func() {
defer close(ch)
err := StreamOpenAI(resp.RawBody(), ch)
if err != nil {
// In a real app, you might want to send an error chunk or log it
fmt.Printf("Stream error: %v\n", err)
}
}()
return ch, nil
}


@@ -0,0 +1,13 @@
package providers
import (
"context"
"llm-proxy/internal/models"
)
type Provider interface {
Name() string
ChatCompletion(ctx context.Context, req *models.UnifiedRequest) (*models.ChatCompletionResponse, error)
ChatCompletionStream(ctx context.Context, req *models.UnifiedRequest) (<-chan *models.ChatCompletionStreamResponse, error)
}


@@ -0,0 +1,675 @@
package server
import (
"fmt"
"net/http"
"strings"
"time"
"llm-proxy/internal/db"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
type ApiResponse struct {
Success bool `json:"success"`
Data interface{} `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
func SuccessResponse(data interface{}) ApiResponse {
return ApiResponse{Success: true, Data: data}
}
func ErrorResponse(err string) ApiResponse {
return ApiResponse{Success: false, Error: err}
}
func (s *Server) adminAuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
token := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
if token == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse("Not authenticated"))
return
}
session, _, err := s.sessions.ValidateSession(token)
if err != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, ErrorResponse("Session expired or invalid"))
return
}
if session.Role != "admin" {
c.AbortWithStatusJSON(http.StatusForbidden, ErrorResponse("Admin access required"))
return
}
c.Set("session", session)
c.Next()
}
}
type LoginRequest struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
}
func (s *Server) handleLogin(c *gin.Context) {
var req LoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
return
}
var user db.User
err := s.database.Get(&user, "SELECT username, password_hash, display_name, role, must_change_password FROM users WHERE username = ?", req.Username)
if err != nil {
c.JSON(http.StatusUnauthorized, ErrorResponse("Invalid username or password"))
return
}
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password)); err != nil {
c.JSON(http.StatusUnauthorized, ErrorResponse("Invalid username or password"))
return
}
token, err := s.sessions.CreateSession(user.Username, user.Role)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse("Failed to create session"))
return
}
displayName := user.Username
if user.DisplayName != nil {
displayName = *user.DisplayName
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{
"token": token,
"must_change_password": user.MustChangePassword,
"user": gin.H{
"username": user.Username,
"name": displayName,
"role": user.Role,
},
}))
}
func (s *Server) handleAuthStatus(c *gin.Context) {
token := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
session, _, err := s.sessions.ValidateSession(token)
if err != nil {
c.JSON(http.StatusUnauthorized, ErrorResponse("Not authenticated"))
return
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{
"authenticated": true,
"user": gin.H{
"username": session.Username,
"role": session.Role,
},
}))
}
func (s *Server) handleLogout(c *gin.Context) {
token := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
s.sessions.RevokeSession(token)
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Logged out"}))
}
type UsagePeriodFilter struct {
Period string `form:"period"`
From string `form:"from"`
To string `form:"to"`
}
func (f *UsagePeriodFilter) ToSQL() (string, []interface{}) {
period := f.Period
if period == "" {
period = "all"
}
if period == "custom" {
var clauses []string
var binds []interface{}
if f.From != "" {
clauses = append(clauses, "timestamp >= ?")
binds = append(binds, f.From)
}
if f.To != "" {
clauses = append(clauses, "timestamp <= ?")
binds = append(binds, f.To)
}
if len(clauses) > 0 {
return " AND " + strings.Join(clauses, " AND "), binds
}
return "", nil
}
now := time.Now().UTC()
var cutoff time.Time
switch period {
case "today":
cutoff = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)
case "24h":
cutoff = now.Add(-24 * time.Hour)
case "7d":
cutoff = now.Add(-7 * 24 * time.Hour)
case "30d":
cutoff = now.Add(-30 * 24 * time.Hour)
default:
return "", nil
}
return " AND timestamp >= ?", []interface{}{cutoff.Format(time.RFC3339)}
}
func (s *Server) handleUsageSummary(c *gin.Context) {
var filter UsagePeriodFilter
_ = c.ShouldBindQuery(&filter) // best-effort bind; zero values act as defaults
clause, binds := filter.ToSQL()
query := fmt.Sprintf(`
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(total_tokens), 0) as total_tokens,
COALESCE(SUM(cost), 0.0) as total_cost,
COUNT(DISTINCT client_id) as active_clients
FROM llm_requests
WHERE 1=1 %s
`, clause)
var stats struct {
TotalRequests int `db:"total_requests"`
TotalTokens int `db:"total_tokens"`
TotalCost float64 `db:"total_cost"`
ActiveClients int `db:"active_clients"`
}
err := s.database.Get(&stats, query, binds...)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(stats))
}
func (s *Server) handleTimeSeries(c *gin.Context) {
var filter UsagePeriodFilter
_ = c.ShouldBindQuery(&filter) // best-effort bind; zero values act as defaults
clause, binds := filter.ToSQL()
if clause == "" {
cutoff := time.Now().UTC().Add(-30 * 24 * time.Hour)
clause = " AND timestamp >= ?"
binds = []interface{}{cutoff.Format(time.RFC3339)}
}
query := fmt.Sprintf(`
SELECT
strftime('%%Y-%%m-%%d', timestamp) as bucket,
COUNT(*) as requests,
COALESCE(SUM(total_tokens), 0) as tokens,
COALESCE(SUM(cost), 0.0) as cost
FROM llm_requests
WHERE 1=1 %s
GROUP BY bucket
ORDER BY bucket
`, clause)
var rows []struct {
Bucket string `db:"bucket"`
Requests int `db:"requests"`
Tokens int `db:"tokens"`
Cost float64 `db:"cost"`
}
err := s.database.Select(&rows, query, binds...)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
series := make([]gin.H, len(rows))
for i, r := range rows {
series[i] = gin.H{
"time": r.Bucket,
"requests": r.Requests,
"tokens": r.Tokens,
"cost": r.Cost,
}
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{
"series": series,
}))
}
func (s *Server) handleAnalyticsBreakdown(c *gin.Context) {
var filter UsagePeriodFilter
_ = c.ShouldBindQuery(&filter) // best-effort bind; zero values act as defaults
clause, binds := filter.ToSQL()
var models []struct {
Label string `db:"label"`
Value int `db:"value"`
}
err := s.database.Select(&models, fmt.Sprintf("SELECT model as label, COUNT(*) as value FROM llm_requests WHERE 1=1 %s GROUP BY model ORDER BY value DESC", clause), binds...)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
var clients []struct {
Label string `db:"label"`
Value int `db:"value"`
}
err = s.database.Select(&clients, fmt.Sprintf("SELECT client_id as label, COUNT(*) as value FROM llm_requests WHERE 1=1 %s GROUP BY client_id ORDER BY value DESC", clause), binds...)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{
"models": models,
"clients": clients,
}))
}
func (s *Server) handleGetClients(c *gin.Context) {
var clients []db.Client
err := s.database.Select(&clients, "SELECT * FROM clients ORDER BY created_at DESC")
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(clients))
}
type CreateClientRequest struct {
Name string `json:"name" binding:"required"`
ClientID *string `json:"client_id"`
}
func (s *Server) handleCreateClient(c *gin.Context) {
var req CreateClientRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
return
}
clientID := ""
if req.ClientID != nil {
clientID = *req.ClientID
} else {
clientID = "client-" + uuid.New().String()[:8]
}
_, err := s.database.Exec("INSERT INTO clients (client_id, name, is_active) VALUES (?, ?, 1)", clientID, req.Name)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
token := "sk-" + uuid.New().String() + uuid.New().String()
token = token[:51]
_, err = s.database.Exec("INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, 'default')", clientID, token)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse("client created but token creation failed: "+err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{
"id": clientID,
"name": req.Name,
"status": "active",
"token": token,
"created_at": time.Now(),
}))
}
func (s *Server) handleDeleteClient(c *gin.Context) {
id := c.Param("id")
if id == "default" {
c.JSON(http.StatusBadRequest, ErrorResponse("Cannot delete default client"))
return
}
_, err := s.database.Exec("DELETE FROM clients WHERE client_id = ?", id)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Client deleted"}))
}
func (s *Server) handleGetClientTokens(c *gin.Context) {
id := c.Param("id")
var tokens []db.ClientToken
err := s.database.Select(&tokens, "SELECT * FROM client_tokens WHERE client_id = ? ORDER BY created_at DESC", id)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
type MaskedToken struct {
ID int `json:"id"`
TokenMasked string `json:"token_masked"`
Name string `json:"name"`
IsActive bool `json:"is_active"`
CreatedAt time.Time `json:"created_at"`
LastUsedAt *time.Time `json:"last_used_at"`
}
masked := make([]MaskedToken, len(tokens))
for i, t := range tokens {
maskedToken := "••••"
if len(t.Token) > 8 {
maskedToken = t.Token[:3] + "••••" + t.Token[len(t.Token)-8:]
}
masked[i] = MaskedToken{
ID: t.ID,
TokenMasked: maskedToken,
Name: t.Name,
IsActive: t.IsActive,
CreatedAt: t.CreatedAt,
LastUsedAt: t.LastUsedAt,
}
}
c.JSON(http.StatusOK, SuccessResponse(masked))
}
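The masking rule above (keep the first 3 and last 8 characters when the token is long enough, otherwise redact everything) can be isolated as a pure helper. A minimal standalone sketch, using hypothetical token values, not the server's actual code:

```go
package main

import "fmt"

// maskToken mirrors the masking rule in handleGetClientTokens:
// tokens longer than 8 bytes keep their first 3 and last 8 characters;
// anything shorter is fully redacted.
func maskToken(token string) string {
	if len(token) > 8 {
		return token[:3] + "••••" + token[len(token)-8:]
	}
	return "••••"
}

func main() {
	fmt.Println(maskToken("sk-1234567890")) // sk-••••34567890
	fmt.Println(maskToken("short"))         // ••••
}
```

Note that for the ~51-character `sk-…` tokens issued above, revealing the last 8 characters still leaves far too little to reconstruct the secret.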
type CreateTokenRequest struct {
Name string `json:"name"`
}
func (s *Server) handleCreateClientToken(c *gin.Context) {
clientID := c.Param("id")
var req CreateTokenRequest
if err := c.ShouldBindJSON(&req); err != nil {
// Request body is optional; on bind failure fall through to the default name.
}
name := "default"
if req.Name != "" {
name = req.Name
}
token := "sk-" + uuid.New().String() + uuid.New().String()
token = token[:51]
_, err := s.database.Exec("INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, ?)", clientID, token, name)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{
"token": token,
"name": name,
"created_at": time.Now(),
}))
}
func (s *Server) handleDeleteClientToken(c *gin.Context) {
tokenID := c.Param("token_id")
_, err := s.database.Exec("DELETE FROM client_tokens WHERE id = ?", tokenID)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Token revoked"}))
}
func (s *Server) handleGetProviders(c *gin.Context) {
var dbConfigs []db.ProviderConfig
err := s.database.Select(&dbConfigs, "SELECT id, enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs")
if err != nil {
// DB overrides unavailable; fall back to the static config values below.
}
dbMap := make(map[string]db.ProviderConfig)
for _, cfg := range dbConfigs {
dbMap[cfg.ID] = cfg
}
providerIDs := []string{"openai", "gemini", "deepseek", "grok", "ollama"}
var result []gin.H
for _, id := range providerIDs {
var name string
var enabled bool
var baseURL string
switch id {
case "openai":
name = "OpenAI"
enabled = s.cfg.Providers.OpenAI.Enabled
baseURL = s.cfg.Providers.OpenAI.BaseURL
case "gemini":
name = "Google Gemini"
enabled = s.cfg.Providers.Gemini.Enabled
baseURL = s.cfg.Providers.Gemini.BaseURL
case "deepseek":
name = "DeepSeek"
enabled = s.cfg.Providers.DeepSeek.Enabled
baseURL = s.cfg.Providers.DeepSeek.BaseURL
case "grok":
name = "xAI Grok"
enabled = s.cfg.Providers.Grok.Enabled
baseURL = s.cfg.Providers.Grok.BaseURL
case "ollama":
name = "Ollama"
enabled = s.cfg.Providers.Ollama.Enabled
baseURL = s.cfg.Providers.Ollama.BaseURL
}
var balance float64
var threshold float64 = 5.0
var billingMode string
if dbCfg, ok := dbMap[id]; ok {
enabled = dbCfg.Enabled
if dbCfg.BaseURL != nil {
baseURL = *dbCfg.BaseURL
}
balance = dbCfg.CreditBalance
threshold = dbCfg.LowCreditThreshold
if dbCfg.BillingMode != nil {
billingMode = *dbCfg.BillingMode
}
}
status := "disabled"
if enabled {
if _, ok := s.providers[id]; ok {
status = "online"
} else {
status = "error"
}
}
result = append(result, gin.H{
"id": id,
"name": name,
"enabled": enabled,
"status": status,
"base_url": baseURL,
"credit_balance": balance,
"low_credit_threshold": threshold,
"billing_mode": billingMode,
})
}
c.JSON(http.StatusOK, SuccessResponse(result))
}
type UpdateProviderRequest struct {
Enabled bool `json:"enabled"`
BaseURL *string `json:"base_url"`
APIKey *string `json:"api_key"`
CreditBalance *float64 `json:"credit_balance"`
LowCreditThreshold *float64 `json:"low_credit_threshold"`
BillingMode *string `json:"billing_mode"`
}
func (s *Server) handleUpdateProvider(c *gin.Context) {
name := c.Param("name")
var req UpdateProviderRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
return
}
_, err := s.database.Exec(`
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, credit_balance, low_credit_threshold, billing_mode)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
enabled = excluded.enabled,
base_url = COALESCE(excluded.base_url, provider_configs.base_url),
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
updated_at = CURRENT_TIMESTAMP
`, name, strings.ToUpper(name), req.Enabled, req.BaseURL, req.APIKey, req.CreditBalance, req.LowCreditThreshold, req.BillingMode)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "Provider updated"}))
}
func (s *Server) handleGetModels(c *gin.Context) {
var models []db.ModelConfig
err := s.database.Select(&models, "SELECT * FROM model_configs")
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(models))
}
func (s *Server) handleGetUsers(c *gin.Context) {
var users []db.User
err := s.database.Select(&users, "SELECT id, username, display_name, role, must_change_password, created_at FROM users")
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(users))
}
type CreateUserRequest struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
DisplayName *string `json:"display_name"`
Role *string `json:"role"`
}
func (s *Server) handleCreateUser(c *gin.Context) {
var req CreateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
return
}
hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), 12)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse("Failed to hash password"))
return
}
role := "viewer"
if req.Role != nil {
role = *req.Role
}
_, err = s.database.Exec("INSERT INTO users (username, password_hash, display_name, role, must_change_password) VALUES (?, ?, ?, ?, 1)",
req.Username, string(hash), req.DisplayName, role)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User created"}))
}
type UpdateUserRequest struct {
DisplayName *string `json:"display_name"`
Role *string `json:"role"`
Password *string `json:"password"`
MustChangePassword *bool `json:"must_change_password"`
}
func (s *Server) handleUpdateUser(c *gin.Context) {
id := c.Param("id")
var req UpdateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse("Invalid request"))
return
}
if req.DisplayName != nil {
s.database.Exec("UPDATE users SET display_name = ? WHERE id = ?", req.DisplayName, id)
}
if req.Role != nil {
s.database.Exec("UPDATE users SET role = ? WHERE id = ?", req.Role, id)
}
if req.MustChangePassword != nil {
s.database.Exec("UPDATE users SET must_change_password = ? WHERE id = ?", req.MustChangePassword, id)
}
if req.Password != nil {
hash, _ := bcrypt.GenerateFromPassword([]byte(*req.Password), 12)
s.database.Exec("UPDATE users SET password_hash = ? WHERE id = ?", string(hash), id)
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User updated"}))
}
func (s *Server) handleDeleteUser(c *gin.Context) {
id := c.Param("id")
session, _ := c.Get("session")
if sess, ok := session.(*Session); ok {
var username string
s.database.Get(&username, "SELECT username FROM users WHERE id = ?", id)
if username == sess.Username {
c.JSON(http.StatusBadRequest, ErrorResponse("Cannot delete your own account"))
return
}
}
_, err := s.database.Exec("DELETE FROM users WHERE id = ?", id)
if err != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse(err.Error()))
return
}
c.JSON(http.StatusOK, SuccessResponse(gin.H{"message": "User deleted"}))
}
func (s *Server) handleSystemHealth(c *gin.Context) {
c.JSON(http.StatusOK, SuccessResponse(gin.H{
"status": "ok",
"db": "connected",
}))
}

113
internal/server/logging.go Normal file

@@ -0,0 +1,113 @@
package server
import (
"log"
"time"
"llm-proxy/internal/db"
)
type RequestLog struct {
Timestamp time.Time `json:"timestamp"`
ClientID string `json:"client_id"`
Provider string `json:"provider"`
Model string `json:"model"`
PromptTokens uint32 `json:"prompt_tokens"`
CompletionTokens uint32 `json:"completion_tokens"`
ReasoningTokens uint32 `json:"reasoning_tokens"`
TotalTokens uint32 `json:"total_tokens"`
CacheReadTokens uint32 `json:"cache_read_tokens"`
CacheWriteTokens uint32 `json:"cache_write_tokens"`
Cost float64 `json:"cost"`
HasImages bool `json:"has_images"`
Status string `json:"status"`
ErrorMessage string `json:"error_message,omitempty"`
DurationMS int64 `json:"duration_ms"`
}
type RequestLogger struct {
database *db.DB
hub *Hub
logChan chan RequestLog
}
func NewRequestLogger(database *db.DB, hub *Hub) *RequestLogger {
return &RequestLogger{
database: database,
hub: hub,
logChan: make(chan RequestLog, 100),
}
}
func (l *RequestLogger) Start() {
go func() {
for entry := range l.logChan {
l.processLog(entry)
}
}()
}
func (l *RequestLogger) LogRequest(entry RequestLog) {
select {
case l.logChan <- entry:
default:
log.Println("Request log channel full, dropping log entry")
}
}
func (l *RequestLogger) processLog(entry RequestLog) {
// Broadcast to dashboard
l.hub.broadcast <- map[string]interface{}{
"type": "request",
"channel": "requests",
"payload": entry,
}
// Insert into DB
tx, err := l.database.Begin()
if err != nil {
log.Printf("Failed to begin transaction for logging: %v", err)
return
}
defer tx.Rollback()
// Ensure client exists
_, _ = tx.Exec("INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')",
entry.ClientID, entry.ClientID)
// Insert log
_, err = tx.Exec(`
INSERT INTO llm_requests
(timestamp, client_id, provider, model, prompt_tokens, completion_tokens, reasoning_tokens, total_tokens, cache_read_tokens, cache_write_tokens, cost, has_images, status, error_message, duration_ms)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`, entry.Timestamp, entry.ClientID, entry.Provider, entry.Model,
entry.PromptTokens, entry.CompletionTokens, entry.ReasoningTokens, entry.TotalTokens,
entry.CacheReadTokens, entry.CacheWriteTokens, entry.Cost, entry.HasImages,
entry.Status, entry.ErrorMessage, entry.DurationMS)
if err != nil {
log.Printf("Failed to insert request log: %v", err)
return
}
// Update client stats
_, _ = tx.Exec(`
UPDATE clients SET
total_requests = total_requests + 1,
total_tokens = total_tokens + ?,
total_cost = total_cost + ?,
updated_at = CURRENT_TIMESTAMP
WHERE client_id = ?
`, entry.TotalTokens, entry.Cost, entry.ClientID)
// Update provider balance
if entry.Cost > 0 {
_, _ = tx.Exec("UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')",
entry.Cost, entry.Provider)
}
err = tx.Commit()
if err != nil {
log.Printf("Failed to commit logging transaction: %v", err)
}
}

326
internal/server/server.go Normal file

@@ -0,0 +1,326 @@
package server
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"llm-proxy/internal/config"
"llm-proxy/internal/db"
"llm-proxy/internal/middleware"
"llm-proxy/internal/models"
"llm-proxy/internal/providers"
"llm-proxy/internal/utils"
"github.com/gin-gonic/gin"
)
type Server struct {
router *gin.Engine
cfg *config.Config
database *db.DB
providers map[string]providers.Provider
sessions *SessionManager
hub *Hub
logger *RequestLogger
}
func NewServer(cfg *config.Config, database *db.DB) *Server {
router := gin.Default()
hub := NewHub()
s := &Server{
router: router,
cfg: cfg,
database: database,
providers: make(map[string]providers.Provider),
sessions: NewSessionManager(cfg.KeyBytes, 24*time.Hour),
hub: hub,
logger: NewRequestLogger(database, hub),
}
// Initialize providers
if cfg.Providers.OpenAI.Enabled {
apiKey, _ := cfg.GetAPIKey("openai")
s.providers["openai"] = providers.NewOpenAIProvider(cfg.Providers.OpenAI, apiKey)
}
if cfg.Providers.Gemini.Enabled {
apiKey, _ := cfg.GetAPIKey("gemini")
s.providers["gemini"] = providers.NewGeminiProvider(cfg.Providers.Gemini, apiKey)
}
if cfg.Providers.DeepSeek.Enabled {
apiKey, _ := cfg.GetAPIKey("deepseek")
s.providers["deepseek"] = providers.NewDeepSeekProvider(cfg.Providers.DeepSeek, apiKey)
}
if cfg.Providers.Grok.Enabled {
apiKey, _ := cfg.GetAPIKey("grok")
s.providers["grok"] = providers.NewGrokProvider(cfg.Providers.Grok, apiKey)
}
s.setupRoutes()
return s
}
func (s *Server) setupRoutes() {
s.router.Use(middleware.AuthMiddleware(s.database))
// Static files
s.router.Static("/static", "./static")
s.router.StaticFile("/", "./static/index.html")
s.router.StaticFile("/favicon.ico", "./static/favicon.ico")
// WebSocket
s.router.GET("/ws", s.handleWebSocket)
v1 := s.router.Group("/v1")
{
v1.POST("/chat/completions", s.handleChatCompletions)
}
// Dashboard API Group
api := s.router.Group("/api")
{
api.POST("/auth/login", s.handleLogin)
api.GET("/auth/status", s.handleAuthStatus)
api.POST("/auth/logout", s.handleLogout)
// Protected dashboard routes (need admin session)
admin := api.Group("/")
admin.Use(s.adminAuthMiddleware())
{
admin.GET("/usage/summary", s.handleUsageSummary)
admin.GET("/usage/time-series", s.handleTimeSeries)
admin.GET("/analytics/breakdown", s.handleAnalyticsBreakdown)
admin.GET("/clients", s.handleGetClients)
admin.POST("/clients", s.handleCreateClient)
admin.DELETE("/clients/:id", s.handleDeleteClient)
admin.GET("/clients/:id/tokens", s.handleGetClientTokens)
admin.POST("/clients/:id/tokens", s.handleCreateClientToken)
admin.DELETE("/clients/:id/tokens/:token_id", s.handleDeleteClientToken)
admin.GET("/providers", s.handleGetProviders)
admin.PUT("/providers/:name", s.handleUpdateProvider)
admin.GET("/models", s.handleGetModels)
admin.GET("/users", s.handleGetUsers)
admin.POST("/users", s.handleCreateUser)
admin.PUT("/users/:id", s.handleUpdateUser)
admin.DELETE("/users/:id", s.handleDeleteUser)
admin.GET("/system/health", s.handleSystemHealth)
}
}
s.router.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
}
func (s *Server) handleChatCompletions(c *gin.Context) {
startTime := time.Now()
var req models.ChatCompletionRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Select provider based on model name
providerName := "openai" // default
if strings.Contains(req.Model, "gemini") {
providerName = "gemini"
} else if strings.Contains(req.Model, "deepseek") {
providerName = "deepseek"
} else if strings.Contains(req.Model, "grok") {
providerName = "grok"
}
provider, ok := s.providers[providerName]
if !ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Provider %s not enabled or supported", providerName)})
return
}
// Convert ChatCompletionRequest to UnifiedRequest
unifiedReq := &models.UnifiedRequest{
Model: req.Model,
Messages: []models.UnifiedMessage{},
Temperature: req.Temperature,
TopP: req.TopP,
TopK: req.TopK,
N: req.N,
MaxTokens: req.MaxTokens,
PresencePenalty: req.PresencePenalty,
FrequencyPenalty: req.FrequencyPenalty,
Stream: req.Stream != nil && *req.Stream,
Tools: req.Tools,
ToolChoice: req.ToolChoice,
}
// Handle Stop sequences
if req.Stop != nil {
var stop []string
if err := json.Unmarshal(req.Stop, &stop); err == nil {
unifiedReq.Stop = stop
} else {
var singleStop string
if err := json.Unmarshal(req.Stop, &singleStop); err == nil {
unifiedReq.Stop = []string{singleStop}
}
}
}
// Convert messages
for _, msg := range req.Messages {
unifiedMsg := models.UnifiedMessage{
Role: msg.Role,
Content: []models.UnifiedContentPart{},
ReasoningContent: msg.ReasoningContent,
ToolCalls: msg.ToolCalls,
Name: msg.Name,
ToolCallID: msg.ToolCallID,
}
// Handle multimodal content
if strContent, ok := msg.Content.(string); ok {
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
Type: "text",
Text: strContent,
})
} else if parts, ok := msg.Content.([]interface{}); ok {
for _, part := range parts {
if partMap, ok := part.(map[string]interface{}); ok {
partType, _ := partMap["type"].(string)
if partType == "text" {
text, _ := partMap["text"].(string)
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
Type: "text",
Text: text,
})
} else if partType == "image_url" {
if imgURLMap, ok := partMap["image_url"].(map[string]interface{}); ok {
url, _ := imgURLMap["url"].(string)
imageInput := &models.ImageInput{}
if strings.HasPrefix(url, "data:") {
mime, data, err := utils.ParseDataURL(url)
if err == nil {
imageInput.Base64 = data
imageInput.MimeType = mime
}
} else {
imageInput.URL = url
}
unifiedMsg.Content = append(unifiedMsg.Content, models.UnifiedContentPart{
Type: "image",
Image: imageInput,
})
unifiedReq.HasImages = true
}
}
}
}
}
unifiedReq.Messages = append(unifiedReq.Messages, unifiedMsg)
}
clientID := "default"
if auth, ok := c.Get("auth"); ok {
if authInfo, ok := auth.(models.AuthInfo); ok {
unifiedReq.ClientID = authInfo.ClientID
clientID = authInfo.ClientID
}
} else {
unifiedReq.ClientID = clientID
}
if unifiedReq.Stream {
ch, err := provider.ChatCompletionStream(c.Request.Context(), unifiedReq)
if err != nil {
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
var lastUsage *models.Usage
c.Stream(func(w io.Writer) bool {
chunk, ok := <-ch
if !ok {
fmt.Fprintf(w, "data: [DONE]\n\n")
s.logRequest(startTime, clientID, providerName, req.Model, lastUsage, nil, unifiedReq.HasImages)
return false
}
if chunk.Usage != nil {
lastUsage = chunk.Usage
}
data, err := json.Marshal(chunk)
if err != nil {
return false
}
fmt.Fprintf(w, "data: %s\n\n", data)
return true
})
return
}
resp, err := provider.ChatCompletion(c.Request.Context(), unifiedReq)
if err != nil {
s.logRequest(startTime, clientID, providerName, req.Model, nil, err, unifiedReq.HasImages)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
s.logRequest(startTime, clientID, providerName, req.Model, resp.Usage, nil, unifiedReq.HasImages)
c.JSON(http.StatusOK, resp)
}
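The provider selection at the top of `handleChatCompletions` is a plain substring dispatch on the model name. Factored out as a helper, the same logic looks like this (a sketch, not the server's actual function):

```go
package main

import (
	"fmt"
	"strings"
)

// providerForModel applies the same substring dispatch as
// handleChatCompletions: a model name containing a provider keyword
// routes to that provider; everything else defaults to OpenAI.
func providerForModel(model string) string {
	switch {
	case strings.Contains(model, "gemini"):
		return "gemini"
	case strings.Contains(model, "deepseek"):
		return "deepseek"
	case strings.Contains(model, "grok"):
		return "grok"
	default:
		return "openai"
	}
}

func main() {
	fmt.Println(providerForModel("gemini-2.0-flash")) // gemini
	fmt.Println(providerForModel("gpt-4o"))           // openai
}
```

A table-driven lookup (model prefix → provider) would be easier to extend than the switch, but the behavior above matches the handler as written.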
func (s *Server) logRequest(start time.Time, clientID, provider, model string, usage *models.Usage, err error, hasImages bool) {
entry := RequestLog{
Timestamp: start,
ClientID: clientID,
Provider: provider,
Model: model,
Status: "success",
DurationMS: time.Since(start).Milliseconds(),
HasImages: hasImages,
}
if err != nil {
entry.Status = "error"
entry.ErrorMessage = err.Error()
}
if usage != nil {
entry.PromptTokens = usage.PromptTokens
entry.CompletionTokens = usage.CompletionTokens
entry.TotalTokens = usage.TotalTokens
if usage.ReasoningTokens != nil {
entry.ReasoningTokens = *usage.ReasoningTokens
}
if usage.CacheReadTokens != nil {
entry.CacheReadTokens = *usage.CacheReadTokens
}
if usage.CacheWriteTokens != nil {
entry.CacheWriteTokens = *usage.CacheWriteTokens
}
// TODO: Calculate cost properly based on pricing
entry.Cost = 0.0
}
s.logger.LogRequest(entry)
}
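The `TODO` above leaves `entry.Cost` at zero. A common approach is per-million-token pricing with separate input and output rates; a sketch under made-up rates (real values would come from the `model_configs` table, not these constants):

```go
package main

import "fmt"

// Hypothetical per-million-token rates for illustration only.
const (
	inputPerM  = 0.50 // USD per 1M prompt tokens (assumed)
	outputPerM = 1.50 // USD per 1M completion tokens (assumed)
)

// estimateCost converts token counts into a dollar cost using
// per-million-token rates.
func estimateCost(promptTokens, completionTokens uint32) float64 {
	return float64(promptTokens)/1e6*inputPerM +
		float64(completionTokens)/1e6*outputPerM
}

func main() {
	fmt.Printf("%.4f\n", estimateCost(1000, 2000))
}
```

Cached and reasoning tokens are often billed at different rates, so a full implementation would extend this with the `CacheReadTokens`/`ReasoningTokens` fields already tracked in `RequestLog`.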
func (s *Server) Run() error {
go s.hub.Run()
s.logger.Start()
addr := fmt.Sprintf("%s:%d", s.cfg.Server.Host, s.cfg.Server.Port)
return s.router.Run(addr)
}

151
internal/server/sessions.go Normal file

@@ -0,0 +1,151 @@
package server
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"strings"
"sync"
"time"
"github.com/google/uuid"
)
type Session struct {
Username string `json:"username"`
Role string `json:"role"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
SessionID string `json:"session_id"`
}
type SessionManager struct {
sessions map[string]Session
mu sync.RWMutex
secret []byte
ttl time.Duration
}
type sessionPayload struct {
SessionID string `json:"session_id"`
Username string `json:"username"`
Role string `json:"role"`
Exp int64 `json:"exp"`
}
func NewSessionManager(secret []byte, ttl time.Duration) *SessionManager {
return &SessionManager{
sessions: make(map[string]Session),
secret: secret,
ttl: ttl,
}
}
func (m *SessionManager) CreateSession(username, role string) (string, error) {
sessionID := uuid.New().String()
now := time.Now()
expiresAt := now.Add(m.ttl)
m.mu.Lock()
m.sessions[sessionID] = Session{
Username: username,
Role: role,
CreatedAt: now,
ExpiresAt: expiresAt,
SessionID: sessionID,
}
m.mu.Unlock()
return m.createSignedToken(sessionID, username, role, expiresAt.Unix())
}
func (m *SessionManager) createSignedToken(sessionID, username, role string, exp int64) (string, error) {
payload := sessionPayload{
SessionID: sessionID,
Username: username,
Role: role,
Exp: exp,
}
payloadJSON, err := json.Marshal(payload)
if err != nil {
return "", err
}
payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON)
h := hmac.New(sha256.New, m.secret)
h.Write(payloadJSON)
signature := h.Sum(nil)
signatureB64 := base64.RawURLEncoding.EncodeToString(signature)
return fmt.Sprintf("%s.%s", payloadB64, signatureB64), nil
}
func (m *SessionManager) ValidateSession(token string) (*Session, string, error) {
parts := strings.Split(token, ".")
if len(parts) != 2 {
return nil, "", fmt.Errorf("invalid token format")
}
payloadB64 := parts[0]
signatureB64 := parts[1]
payloadJSON, err := base64.RawURLEncoding.DecodeString(payloadB64)
if err != nil {
return nil, "", err
}
signature, err := base64.RawURLEncoding.DecodeString(signatureB64)
if err != nil {
return nil, "", err
}
h := hmac.New(sha256.New, m.secret)
h.Write(payloadJSON)
if !hmac.Equal(signature, h.Sum(nil)) {
return nil, "", fmt.Errorf("invalid signature")
}
var payload sessionPayload
if err := json.Unmarshal(payloadJSON, &payload); err != nil {
return nil, "", err
}
if time.Now().Unix() > payload.Exp {
return nil, "", fmt.Errorf("token expired")
}
m.mu.RLock()
session, ok := m.sessions[payload.SessionID]
m.mu.RUnlock()
if !ok {
return nil, "", fmt.Errorf("session not found")
}
return &session, "", nil
}
func (m *SessionManager) RevokeSession(token string) {
parts := strings.Split(token, ".")
if len(parts) != 2 {
return
}
payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return
}
var payload sessionPayload
if err := json.Unmarshal(payloadJSON, &payload); err != nil {
return
}
m.mu.Lock()
delete(m.sessions, payload.SessionID)
m.mu.Unlock()
}
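The token format above is `base64url(payload) + "." + base64url(HMAC-SHA256(payload))`. This standalone sketch reproduces the sign/verify round-trip with an assumed secret, mirroring `createSignedToken` and the signature half of `ValidateSession` (expiry and session-store checks omitted):

```go
package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"fmt"
	"strings"
)

// sign produces the same two-part token shape as createSignedToken.
func sign(payload, secret []byte) string {
	h := hmac.New(sha256.New, secret)
	h.Write(payload)
	return base64.RawURLEncoding.EncodeToString(payload) + "." +
		base64.RawURLEncoding.EncodeToString(h.Sum(nil))
}

// verify checks the signature and returns the payload on success.
func verify(token string, secret []byte) ([]byte, bool) {
	parts := strings.Split(token, ".")
	if len(parts) != 2 {
		return nil, false
	}
	payload, err := base64.RawURLEncoding.DecodeString(parts[0])
	if err != nil {
		return nil, false
	}
	sig, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		return nil, false
	}
	h := hmac.New(sha256.New, secret)
	h.Write(payload)
	if !hmac.Equal(sig, h.Sum(nil)) { // constant-time comparison
		return nil, false
	}
	return payload, true
}

func main() {
	secret := []byte("assumed-32-byte-secret-for-demo!") // hypothetical key
	tok := sign([]byte(`{"username":"alice"}`), secret)
	payload, ok := verify(tok, secret)
	fmt.Println(ok, string(payload)) // true {"username":"alice"}
}
```

Because `ValidateSession` also requires the session ID to exist in the in-memory map, a signed token alone is not sufficient after a server restart or an explicit `RevokeSession`.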


@@ -0,0 +1,98 @@
package server
import (
"log"
"net/http"
"sync"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true // In production, refine this
},
}
type Hub struct {
clients map[*websocket.Conn]bool
broadcast chan interface{}
register chan *websocket.Conn
unregister chan *websocket.Conn
mu sync.Mutex
}
func NewHub() *Hub {
return &Hub{
clients: make(map[*websocket.Conn]bool),
broadcast: make(chan interface{}),
register: make(chan *websocket.Conn),
unregister: make(chan *websocket.Conn),
}
}
func (h *Hub) Run() {
for {
select {
case client := <-h.register:
h.mu.Lock()
h.clients[client] = true
h.mu.Unlock()
log.Println("WebSocket client registered")
case client := <-h.unregister:
h.mu.Lock()
if _, ok := h.clients[client]; ok {
delete(h.clients, client)
client.Close()
}
h.mu.Unlock()
log.Println("WebSocket client unregistered")
case message := <-h.broadcast:
h.mu.Lock()
for client := range h.clients {
err := client.WriteJSON(message)
if err != nil {
log.Printf("WebSocket error: %v", err)
client.Close()
delete(h.clients, client)
}
}
h.mu.Unlock()
}
}
}
func (s *Server) handleWebSocket(c *gin.Context) {
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
log.Printf("Failed to set websocket upgrade: %v", err)
return
}
s.hub.register <- conn
defer func() {
s.hub.unregister <- conn
}()
// Initial message
conn.WriteJSON(gin.H{
"type": "connected",
"message": "Connected to LLM Proxy Dashboard",
})
for {
var msg map[string]interface{}
err := conn.ReadJSON(&msg)
if err != nil {
break
}
if msg["type"] == "ping" {
conn.WriteJSON(gin.H{"type": "pong", "payload": gin.H{}})
}
}
}

19
internal/utils/utils.go Normal file

@@ -0,0 +1,19 @@
package utils
import (
"fmt"
"strings"
)
// ParseDataURL splits a base64 data URL ("data:<mime>;base64,<payload>")
// into its MIME type and base64-encoded payload.
func ParseDataURL(dataURL string) (string, string, error) {
if !strings.HasPrefix(dataURL, "data:") {
return "", "", fmt.Errorf("not a data URL")
}
parts := strings.Split(dataURL[5:], ";base64,")
if len(parts) != 2 {
return "", "", fmt.Errorf("invalid data URL format")
}
return parts[0], parts[1], nil
}
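`ParseDataURL` can be exercised in isolation; a quick usage sketch (the function body is copied from above so the snippet runs standalone):

```go
package main

import (
	"fmt"
	"strings"
)

// ParseDataURL splits "data:<mime>;base64,<payload>" into (mime, payload).
func ParseDataURL(dataURL string) (string, string, error) {
	if !strings.HasPrefix(dataURL, "data:") {
		return "", "", fmt.Errorf("not a data URL")
	}
	parts := strings.Split(dataURL[5:], ";base64,")
	if len(parts) != 2 {
		return "", "", fmt.Errorf("invalid data URL format")
	}
	return parts[0], parts[1], nil
}

func main() {
	mime, data, err := ParseDataURL("data:image/png;base64,iVBORw0KGgo=")
	fmt.Println(mime, data, err) // image/png iVBORw0KGgo= <nil>
}
```

Splitting on `";base64,"` is safe here because the standard base64 alphabet cannot contain `;` or `,`, so the separator can only occur once.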


@@ -1,2 +0,0 @@
max_width = 120
use_field_init_shorthand = true


@@ -1,45 +0,0 @@
use axum::{extract::FromRequestParts, http::request::Parts};
use crate::errors::AppError;
#[derive(Debug, Clone)]
pub struct AuthInfo {
pub token: String,
pub client_id: String,
}
pub struct AuthenticatedClient {
pub info: AuthInfo,
}
impl<S> FromRequestParts<S> for AuthenticatedClient
where
S: Send + Sync,
{
type Rejection = AppError;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
// Retrieve AuthInfo from request extensions, where it was placed by rate_limit_middleware
let info = parts
.extensions
.get::<AuthInfo>()
.cloned()
.ok_or_else(|| AppError::AuthError("Authentication info not found in request".to_string()))?;
Ok(AuthenticatedClient { info })
}
}
impl std::ops::Deref for AuthenticatedClient {
type Target = AuthInfo;
fn deref(&self) -> &Self::Target {
&self.info
}
}
pub fn validate_token(token: &str, valid_tokens: &[String]) -> bool {
// Simple validation against list of tokens
// In production, use proper token validation (JWT, database lookup, etc.)
valid_tokens.contains(&token.to_string())
}


@@ -1,304 +0,0 @@
//! Client management for LLM proxy
//!
//! This module handles:
//! 1. Client registration and management
//! 2. Client usage tracking
//! 3. Client rate limit configuration
use anyhow::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{Row, SqlitePool};
use tracing::{info, warn};
/// Client information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Client {
pub id: i64,
pub client_id: String,
pub name: String,
pub description: String,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
pub is_active: bool,
pub rate_limit_per_minute: i64,
pub total_requests: i64,
pub total_tokens: i64,
pub total_cost: f64,
}
/// Client creation request
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateClientRequest {
pub client_id: String,
pub name: String,
pub description: String,
pub rate_limit_per_minute: Option<i64>,
}
/// Client update request
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateClientRequest {
pub name: Option<String>,
pub description: Option<String>,
pub is_active: Option<bool>,
pub rate_limit_per_minute: Option<i64>,
}
/// Client manager for database operations
pub struct ClientManager {
db_pool: SqlitePool,
}
impl ClientManager {
pub fn new(db_pool: SqlitePool) -> Self {
Self { db_pool }
}
/// Create a new client
pub async fn create_client(&self, request: CreateClientRequest) -> Result<Client> {
let rate_limit = request.rate_limit_per_minute.unwrap_or(60);
// First insert the client
sqlx::query(
r#"
INSERT INTO clients (client_id, name, description, rate_limit_per_minute)
VALUES (?, ?, ?, ?)
"#,
)
.bind(&request.client_id)
.bind(&request.name)
.bind(&request.description)
.bind(rate_limit)
.execute(&self.db_pool)
.await?;
// Then fetch the created client
let client = self
.get_client(&request.client_id)
.await?
.ok_or_else(|| anyhow::anyhow!("Failed to retrieve created client"))?;
info!("Created client: {} ({})", client.name, client.client_id);
Ok(client)
}
/// Get a client by ID
pub async fn get_client(&self, client_id: &str) -> Result<Option<Client>> {
let row = sqlx::query(
r#"
SELECT
id, client_id, name, description,
created_at, updated_at, is_active,
rate_limit_per_minute, total_requests, total_tokens, total_cost
FROM clients
WHERE client_id = ?
"#,
)
.bind(client_id)
.fetch_optional(&self.db_pool)
.await?;
if let Some(row) = row {
let client = Client {
id: row.get("id"),
client_id: row.get("client_id"),
name: row.get("name"),
description: row.get("description"),
created_at: row.get("created_at"),
updated_at: row.get("updated_at"),
is_active: row.get("is_active"),
rate_limit_per_minute: row.get("rate_limit_per_minute"),
total_requests: row.get("total_requests"),
total_tokens: row.get("total_tokens"),
total_cost: row.get("total_cost"),
};
Ok(Some(client))
} else {
Ok(None)
}
}
/// Update a client
pub async fn update_client(&self, client_id: &str, request: UpdateClientRequest) -> Result<Option<Client>> {
// First, get the current client to check if it exists
let current_client = self.get_client(client_id).await?;
if current_client.is_none() {
return Ok(None);
}
// Build update query dynamically based on provided fields
let mut query_builder = sqlx::QueryBuilder::new("UPDATE clients SET ");
let mut has_updates = false;
if let Some(name) = &request.name {
query_builder.push("name = ");
query_builder.push_bind(name);
has_updates = true;
}
if let Some(description) = &request.description {
if has_updates {
query_builder.push(", ");
}
query_builder.push("description = ");
query_builder.push_bind(description);
has_updates = true;
}
if let Some(is_active) = request.is_active {
if has_updates {
query_builder.push(", ");
}
query_builder.push("is_active = ");
query_builder.push_bind(is_active);
has_updates = true;
}
if let Some(rate_limit) = request.rate_limit_per_minute {
if has_updates {
query_builder.push(", ");
}
query_builder.push("rate_limit_per_minute = ");
query_builder.push_bind(rate_limit);
has_updates = true;
}
// Always update the updated_at timestamp
if has_updates {
query_builder.push(", ");
}
query_builder.push("updated_at = CURRENT_TIMESTAMP");
if !has_updates {
// No updates to make
return self.get_client(client_id).await;
}
query_builder.push(" WHERE client_id = ");
query_builder.push_bind(client_id);
let query = query_builder.build();
query.execute(&self.db_pool).await?;
// Fetch the updated client
let updated_client = self.get_client(client_id).await?;
if updated_client.is_some() {
info!("Updated client: {}", client_id);
}
Ok(updated_client)
}
/// List all clients
pub async fn list_clients(&self, limit: Option<i64>, offset: Option<i64>) -> Result<Vec<Client>> {
let limit = limit.unwrap_or(100);
let offset = offset.unwrap_or(0);
let rows = sqlx::query(
r#"
SELECT
id, client_id, name, description,
created_at, updated_at, is_active,
rate_limit_per_minute, total_requests, total_tokens, total_cost
FROM clients
ORDER BY created_at DESC
LIMIT ? OFFSET ?
"#,
)
.bind(limit)
.bind(offset)
.fetch_all(&self.db_pool)
.await?;
let mut clients = Vec::new();
for row in rows {
let client = Client {
id: row.get("id"),
client_id: row.get("client_id"),
name: row.get("name"),
description: row.get("description"),
created_at: row.get("created_at"),
updated_at: row.get("updated_at"),
is_active: row.get("is_active"),
rate_limit_per_minute: row.get("rate_limit_per_minute"),
total_requests: row.get("total_requests"),
total_tokens: row.get("total_tokens"),
total_cost: row.get("total_cost"),
};
clients.push(client);
}
Ok(clients)
}
/// Delete a client
pub async fn delete_client(&self, client_id: &str) -> Result<bool> {
let result = sqlx::query("DELETE FROM clients WHERE client_id = ?")
.bind(client_id)
.execute(&self.db_pool)
.await?;
let deleted = result.rows_affected() > 0;
if deleted {
info!("Deleted client: {}", client_id);
} else {
warn!("Client not found for deletion: {}", client_id);
}
Ok(deleted)
}
/// Update client usage statistics after a request
pub async fn update_client_usage(&self, client_id: &str, tokens: i64, cost: f64) -> Result<()> {
sqlx::query(
r#"
UPDATE clients
SET
total_requests = total_requests + 1,
total_tokens = total_tokens + ?,
total_cost = total_cost + ?,
updated_at = CURRENT_TIMESTAMP
WHERE client_id = ?
"#,
)
.bind(tokens)
.bind(cost)
.bind(client_id)
.execute(&self.db_pool)
.await?;
Ok(())
}
/// Get client usage statistics
pub async fn get_client_usage(&self, client_id: &str) -> Result<Option<(i64, i64, f64)>> {
let row = sqlx::query(
r#"
SELECT total_requests, total_tokens, total_cost
FROM clients
WHERE client_id = ?
"#,
)
.bind(client_id)
.fetch_optional(&self.db_pool)
.await?;
if let Some(row) = row {
let total_requests: i64 = row.get("total_requests");
let total_tokens: i64 = row.get("total_tokens");
let total_cost: f64 = row.get("total_cost");
Ok(Some((total_requests, total_tokens, total_cost)))
} else {
Ok(None)
}
}
/// Check if a client exists and is active
pub async fn validate_client(&self, client_id: &str) -> Result<bool> {
let client = self.get_client(client_id).await?;
Ok(client.map(|c| c.is_active).unwrap_or(false))
}
}
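The `update_client_usage` query above maintains three running totals in one atomic SQL statement. A stdlib-only sketch of the same accumulation logic (the `Usage` struct and `accumulate` names are illustrative, not part of this crate):

```rust
// Illustrative mirror of the SQL UPDATE in update_client_usage:
// each request bumps the request count and adds tokens and cost.
struct Usage {
    requests: i64,
    tokens: i64,
    cost: f64,
}

fn accumulate(u: &mut Usage, tokens: i64, cost: f64) {
    u.requests += 1;
    u.tokens += tokens;
    u.cost += cost;
}

fn main() {
    let mut u = Usage { requests: 0, tokens: 0, cost: 0.0 };
    accumulate(&mut u, 1200, 0.003);
    accumulate(&mut u, 800, 0.002);
    assert_eq!((u.requests, u.tokens), (2, 2000));
    assert!((u.cost - 0.005).abs() < 1e-9);
}
```

Doing the arithmetic in SQL rather than read-modify-write in Rust keeps the counters correct under concurrent requests.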

View File

@@ -1,260 +0,0 @@
use anyhow::Result;
use base64::Engine as _;
use config::{Config, File, FileFormat};
use hex;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
pub port: u16,
pub host: String,
#[serde(deserialize_with = "deserialize_vec_or_string")]
pub auth_tokens: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
pub path: PathBuf,
pub max_connections: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
pub openai: OpenAIConfig,
pub gemini: GeminiConfig,
pub deepseek: DeepSeekConfig,
pub grok: GrokConfig,
pub ollama: OllamaConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIConfig {
pub api_key_env: String,
pub base_url: String,
pub default_model: String,
pub enabled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeminiConfig {
pub api_key_env: String,
pub base_url: String,
pub default_model: String,
pub enabled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeepSeekConfig {
pub api_key_env: String,
pub base_url: String,
pub default_model: String,
pub enabled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrokConfig {
pub api_key_env: String,
pub base_url: String,
pub default_model: String,
pub enabled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaConfig {
pub base_url: String,
pub enabled: bool,
#[serde(deserialize_with = "deserialize_vec_or_string")]
pub models: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMappingConfig {
pub patterns: Vec<(String, String)>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PricingConfig {
pub openai: Vec<ModelPricing>,
pub gemini: Vec<ModelPricing>,
pub deepseek: Vec<ModelPricing>,
pub grok: Vec<ModelPricing>,
pub ollama: Vec<ModelPricing>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPricing {
pub model: String,
pub prompt_tokens_per_million: f64,
pub completion_tokens_per_million: f64,
}
#[derive(Debug, Clone)]
pub struct AppConfig {
pub server: ServerConfig,
pub database: DatabaseConfig,
pub providers: ProviderConfig,
pub model_mapping: ModelMappingConfig,
pub pricing: PricingConfig,
pub config_path: Option<PathBuf>,
pub encryption_key: String,
}
impl AppConfig {
pub async fn load() -> Result<Arc<Self>> {
Self::load_from_path(None).await
}
/// Load configuration from a specific path (for testing)
pub async fn load_from_path(config_path: Option<PathBuf>) -> Result<Arc<Self>> {
// Load configuration from multiple sources
let mut config_builder = Config::builder();
// Default configuration
config_builder = config_builder
.set_default("server.port", 8080)?
.set_default("server.host", "0.0.0.0")?
.set_default("server.auth_tokens", Vec::<String>::new())?
.set_default("database.path", "./data/llm_proxy.db")?
.set_default("database.max_connections", 10)?
.set_default("providers.openai.api_key_env", "OPENAI_API_KEY")?
.set_default("providers.openai.base_url", "https://api.openai.com/v1")?
.set_default("providers.openai.default_model", "gpt-4o")?
.set_default("providers.openai.enabled", true)?
.set_default("providers.gemini.api_key_env", "GEMINI_API_KEY")?
.set_default(
"providers.gemini.base_url",
"https://generativelanguage.googleapis.com/v1",
)?
.set_default("providers.gemini.default_model", "gemini-2.0-flash")?
.set_default("providers.gemini.enabled", true)?
.set_default("providers.deepseek.api_key_env", "DEEPSEEK_API_KEY")?
.set_default("providers.deepseek.base_url", "https://api.deepseek.com")?
.set_default("providers.deepseek.default_model", "deepseek-reasoner")?
.set_default("providers.deepseek.enabled", true)?
.set_default("providers.grok.api_key_env", "GROK_API_KEY")?
.set_default("providers.grok.base_url", "https://api.x.ai/v1")?
.set_default("providers.grok.default_model", "grok-beta")?
.set_default("providers.grok.enabled", true)?
.set_default("providers.ollama.base_url", "http://localhost:11434/v1")?
.set_default("providers.ollama.enabled", false)?
.set_default("providers.ollama.models", Vec::<String>::new())?
.set_default("encryption_key", "")?;
// Load from config file if exists
// Priority: explicit path arg > LLM_PROXY__CONFIG_PATH env var > ./config.toml
let config_path = config_path
.or_else(|| std::env::var("LLM_PROXY__CONFIG_PATH").ok().map(PathBuf::from))
.unwrap_or_else(|| {
std::env::current_dir()
.unwrap_or_else(|_| PathBuf::from("."))
.join("config.toml")
});
if config_path.exists() {
config_builder = config_builder.add_source(File::from(config_path.clone()).format(FileFormat::Toml));
}
// Load from .env file
dotenvy::dotenv().ok();
// Load from environment variables (with prefix "LLM_PROXY_")
config_builder = config_builder.add_source(
config::Environment::with_prefix("LLM_PROXY")
.separator("__")
.try_parsing(true),
);
let config = config_builder.build()?;
// Deserialize configuration
let server: ServerConfig = config.get("server")?;
let database: DatabaseConfig = config.get("database")?;
let providers: ProviderConfig = config.get("providers")?;
let encryption_key: String = config.get("encryption_key")?;
// Validate encryption key length (must be 32 bytes after hex or base64 decoding)
if encryption_key.is_empty() {
anyhow::bail!("Encryption key is required (LLM_PROXY__ENCRYPTION_KEY environment variable)");
}
// Try hex decode first, then base64
let key_bytes = hex::decode(&encryption_key)
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(&encryption_key))
.map_err(|e| anyhow::anyhow!("Encryption key must be hex or base64 encoded: {}", e))?;
if key_bytes.len() != 32 {
anyhow::bail!("Encryption key must be 32 bytes (256 bits), got {} bytes", key_bytes.len());
}
// For now, use empty model mapping and pricing (will be populated later)
let model_mapping = ModelMappingConfig { patterns: vec![] };
let pricing = PricingConfig {
openai: vec![],
gemini: vec![],
deepseek: vec![],
grok: vec![],
ollama: vec![],
};
Ok(Arc::new(AppConfig {
server,
database,
providers,
model_mapping,
pricing,
config_path: Some(config_path),
encryption_key,
}))
}
pub fn get_api_key(&self, provider: &str) -> Result<String> {
let env_var = match provider {
"openai" => &self.providers.openai.api_key_env,
"gemini" => &self.providers.gemini.api_key_env,
"deepseek" => &self.providers.deepseek.api_key_env,
"grok" => &self.providers.grok.api_key_env,
_ => return Err(anyhow::anyhow!("Unknown provider: {}", provider)),
};
std::env::var(env_var).map_err(|_| anyhow::anyhow!("Environment variable {} not set for {}", env_var, provider))
}
}
/// Helper function to deserialize a Vec<String> from either a sequence or a comma-separated string
fn deserialize_vec_or_string<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
D: serde::Deserializer<'de>,
{
struct VecOrString;
impl<'de> serde::de::Visitor<'de> for VecOrString {
type Value = Vec<String>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a sequence or a comma-separated string")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(value
.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect())
}
fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
where
S: serde::de::SeqAccess<'de>,
{
let mut vec = Vec::new();
while let Some(element) = seq.next_element()? {
vec.push(element);
}
Ok(vec)
}
}
deserializer.deserialize_any(VecOrString)
}
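The visitor above lets a `Vec<String>` field deserialize from either a TOML array or a comma-separated string, which is how env vars like `LLM_PROXY__SERVER__AUTH_TOKENS=a,b` are parsed. A minimal sketch of the `visit_str` arm (the `parse_list` name and test values are illustrative):

```rust
// Core of the visit_str arm: split on commas, trim whitespace,
// and drop empty segments so trailing commas are harmless.
fn parse_list(value: &str) -> Vec<String> {
    value
        .split(',')
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty())
        .collect()
}

fn main() {
    // Whitespace is trimmed and empty entries are dropped.
    assert_eq!(parse_list("tok-a, tok-b,,"), vec!["tok-a", "tok-b"]);
    assert_eq!(parse_list("llama3,mistral,llava").len(), 3);
    assert!(parse_list(" , ,").is_empty());
}
```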

View File

@@ -1,229 +0,0 @@
use axum::{extract::State, http::{HeaderMap, HeaderValue}, response::{Json, IntoResponse}};
use bcrypt;
use serde::Deserialize;
use sqlx::Row;
use tracing::warn;
use super::{ApiResponse, DashboardState};
// Authentication handlers
#[derive(Deserialize)]
pub(super) struct LoginRequest {
pub(super) username: String,
pub(super) password: String,
}
pub(super) async fn handle_login(
State(state): State<DashboardState>,
Json(payload): Json<LoginRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let user_result = sqlx::query(
"SELECT username, password_hash, display_name, role, must_change_password FROM users WHERE username = ?",
)
.bind(&payload.username)
.fetch_optional(pool)
.await;
match user_result {
Ok(Some(row)) => {
let hash = row.get::<String, _>("password_hash");
if bcrypt::verify(&payload.password, &hash).unwrap_or(false) {
let username = row.get::<String, _>("username");
let role = row.get::<String, _>("role");
let display_name = row
.get::<Option<String>, _>("display_name")
.unwrap_or_else(|| username.clone());
let must_change_password = row.get::<bool, _>("must_change_password");
let token = state
.session_manager
.create_session(username.clone(), role.clone())
.await;
Json(ApiResponse::success(serde_json::json!({
"token": token,
"must_change_password": must_change_password,
"user": {
"username": username,
"name": display_name,
"role": role
}
})))
} else {
Json(ApiResponse::error("Invalid username or password".to_string()))
}
}
Ok(None) => Json(ApiResponse::error("Invalid username or password".to_string())),
Err(e) => {
warn!("Database error during login: {}", e);
Json(ApiResponse::error("Login failed due to system error".to_string()))
}
}
}
pub(super) async fn handle_auth_status(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> impl IntoResponse {
let token = headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
if let Some(token) = token
&& let Some((session, new_token)) = state.session_manager.validate_session_with_refresh(token).await
{
// Look up display_name from DB
let display_name = sqlx::query_scalar::<_, Option<String>>(
"SELECT display_name FROM users WHERE username = ?",
)
.bind(&session.username)
.fetch_optional(&state.app_state.db_pool)
.await
.ok()
.flatten()
.flatten()
.unwrap_or_else(|| session.username.clone());
let mut headers = HeaderMap::new();
if let Some(refreshed_token) = new_token {
if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) {
headers.insert("X-Refreshed-Token", header_value);
}
}
return (headers, Json(ApiResponse::success(serde_json::json!({
"authenticated": true,
"user": {
"username": session.username,
"name": display_name,
"role": session.role
}
}))));
}
(HeaderMap::new(), Json(ApiResponse::error("Not authenticated".to_string())))
}
#[derive(Deserialize)]
pub(super) struct ChangePasswordRequest {
pub(super) current_password: String,
pub(super) new_password: String,
}
pub(super) async fn handle_change_password(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Json(payload): Json<ChangePasswordRequest>,
) -> impl IntoResponse {
let pool = &state.app_state.db_pool;
// Extract the authenticated user from the session token
let token = headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
let (session, new_token) = match token {
Some(t) => match state.session_manager.validate_session_with_refresh(t).await {
Some((session, new_token)) => (Some(session), new_token),
None => (None, None),
},
None => (None, None),
};
let mut response_headers = HeaderMap::new();
if let Some(refreshed_token) = new_token {
if let Ok(header_value) = HeaderValue::from_str(&refreshed_token) {
response_headers.insert("X-Refreshed-Token", header_value);
}
}
let username = match session {
Some(s) => s.username,
None => return (response_headers, Json(ApiResponse::error("Not authenticated".to_string()))),
};
let user_result = sqlx::query("SELECT password_hash FROM users WHERE username = ?")
.bind(&username)
.fetch_one(pool)
.await;
match user_result {
Ok(row) => {
let hash = row.get::<String, _>("password_hash");
if bcrypt::verify(&payload.current_password, &hash).unwrap_or(false) {
let new_hash = match bcrypt::hash(&payload.new_password, 12) {
Ok(h) => h,
Err(_) => return (response_headers, Json(ApiResponse::error("Failed to hash new password".to_string()))),
};
let update_result = sqlx::query(
"UPDATE users SET password_hash = ?, must_change_password = FALSE WHERE username = ?",
)
.bind(new_hash)
.bind(&username)
.execute(pool)
.await;
match update_result {
Ok(_) => (response_headers, Json(ApiResponse::success(
serde_json::json!({ "message": "Password updated successfully" }),
))),
Err(e) => (response_headers, Json(ApiResponse::error(format!("Failed to update database: {}", e)))),
}
} else {
(response_headers, Json(ApiResponse::error("Current password incorrect".to_string())))
}
}
Err(e) => (response_headers, Json(ApiResponse::error(format!("User not found: {}", e)))),
}
}
pub(super) async fn handle_logout(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let token = headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
if let Some(token) = token {
state.session_manager.revoke_session(token).await;
}
Json(ApiResponse::success(serde_json::json!({ "message": "Logged out" })))
}
/// Helper: Extract and validate a session from the Authorization header.
/// Returns the Session and optional new token if refreshed, or an error response.
pub(super) async fn extract_session(
state: &DashboardState,
headers: &axum::http::HeaderMap,
) -> Result<(super::sessions::Session, Option<String>), Json<ApiResponse<serde_json::Value>>> {
let token = headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
match token {
Some(t) => match state.session_manager.validate_session_with_refresh(t).await {
Some((session, new_token)) => Ok((session, new_token)),
None => Err(Json(ApiResponse::error("Session expired or invalid".to_string()))),
},
None => Err(Json(ApiResponse::error("Not authenticated".to_string()))),
}
}
/// Helper: Extract session and require admin role.
/// Returns session and optional new token if refreshed.
pub(super) async fn require_admin(
state: &DashboardState,
headers: &axum::http::HeaderMap,
) -> Result<(super::sessions::Session, Option<String>), Json<ApiResponse<serde_json::Value>>> {
let (session, new_token) = extract_session(state, headers).await?;
if session.role != "admin" {
return Err(Json(ApiResponse::error("Admin access required".to_string())));
}
Ok((session, new_token))
}
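Every handler in this file repeats the same Authorization-header dance before touching the session manager. The core of it, as a standalone sketch (the `bearer_token` helper is illustrative; the handlers above inline this logic):

```rust
// Pull the token out of a raw Authorization header value by
// stripping the "Bearer " scheme prefix; anything else is rejected.
fn bearer_token(auth_header: Option<&str>) -> Option<&str> {
    auth_header.and_then(|v| v.strip_prefix("Bearer "))
}

fn main() {
    assert_eq!(bearer_token(Some("Bearer abc123")), Some("abc123"));
    assert_eq!(bearer_token(Some("Basic abc123")), None); // wrong scheme
    assert_eq!(bearer_token(None), None); // header absent
}
```

Note that the match is case-sensitive and requires the single space after `Bearer`, matching the `strip_prefix("Bearer ")` calls in the handlers.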

View File

@@ -1,538 +0,0 @@
use axum::{
extract::{Path, State},
response::Json,
};
use chrono;
use rand::Rng;
use serde::Deserialize;
use serde_json;
use sqlx::Row;
use tracing::warn;
use uuid;
use super::{ApiResponse, DashboardState};
/// Generate a random API token: sk-{48 hex chars}
fn generate_token() -> String {
let mut rng = rand::rng();
let bytes: Vec<u8> = (0..24).map(|_| rng.random::<u8>()).collect();
format!("sk-{}", hex::encode(bytes))
}
#[derive(Deserialize)]
pub(super) struct CreateClientRequest {
pub(super) name: String,
pub(super) client_id: Option<String>,
}
#[derive(Deserialize)]
pub(super) struct UpdateClientPayload {
pub(super) name: Option<String>,
pub(super) description: Option<String>,
pub(super) is_active: Option<bool>,
pub(super) rate_limit_per_minute: Option<i64>,
}
pub(super) async fn handle_get_clients(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
client_id as id,
name,
description,
created_at,
total_requests,
total_tokens,
total_cost,
is_active,
rate_limit_per_minute
FROM clients
ORDER BY created_at DESC
"#,
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let clients: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
serde_json::json!({
"id": row.get::<String, _>("id"),
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
"description": row.get::<Option<String>, _>("description"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
"requests_count": row.get::<i64, _>("total_requests"),
"total_tokens": row.get::<i64, _>("total_tokens"),
"total_cost": row.get::<f64, _>("total_cost"),
"status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
"rate_limit_per_minute": row.get::<Option<i64>, _>("rate_limit_per_minute"),
})
})
.collect();
Json(ApiResponse::success(serde_json::json!(clients)))
}
Err(e) => {
warn!("Failed to fetch clients: {}", e);
Json(ApiResponse::error("Failed to fetch clients".to_string()))
}
}
}
pub(super) async fn handle_create_client(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Json(payload): Json<CreateClientRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let client_id = payload
.client_id
.unwrap_or_else(|| format!("client-{}", &uuid::Uuid::new_v4().to_string()[..8]));
let result = sqlx::query(
r#"
INSERT INTO clients (client_id, name, is_active)
VALUES (?, ?, TRUE)
RETURNING *
"#,
)
.bind(&client_id)
.bind(&payload.name)
.fetch_one(pool)
.await;
match result {
Ok(row) => {
// Auto-generate a token for the new client
let token = generate_token();
let token_result = sqlx::query(
"INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, 'default')",
)
.bind(&client_id)
.bind(&token)
.execute(pool)
.await;
if let Err(e) = token_result {
warn!("Client created but failed to generate token: {}", e);
}
Json(ApiResponse::success(serde_json::json!({
"id": row.get::<String, _>("client_id"),
"name": row.get::<Option<String>, _>("name"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
"status": "active",
"token": token,
})))
}
Err(e) => {
warn!("Failed to create client: {}", e);
Json(ApiResponse::error(format!("Failed to create client: {}", e)))
}
}
}
pub(super) async fn handle_get_client(
State(state): State<DashboardState>,
Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
c.client_id as id,
c.name,
c.description,
c.is_active,
c.rate_limit_per_minute,
c.created_at,
COALESCE(c.total_tokens, 0) as total_tokens,
COALESCE(c.total_cost, 0.0) as total_cost,
COUNT(r.id) as total_requests,
MAX(r.timestamp) as last_request
FROM clients c
LEFT JOIN llm_requests r ON c.client_id = r.client_id
WHERE c.client_id = ?
GROUP BY c.client_id
"#,
)
.bind(&id)
.fetch_optional(pool)
.await;
match result {
Ok(Some(row)) => Json(ApiResponse::success(serde_json::json!({
"id": row.get::<String, _>("id"),
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
"description": row.get::<Option<String>, _>("description"),
"is_active": row.get::<bool, _>("is_active"),
"rate_limit_per_minute": row.get::<Option<i64>, _>("rate_limit_per_minute"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
"total_tokens": row.get::<i64, _>("total_tokens"),
"total_cost": row.get::<f64, _>("total_cost"),
"total_requests": row.get::<i64, _>("total_requests"),
"last_request": row.get::<Option<chrono::DateTime<chrono::Utc>>, _>("last_request"),
"status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
}))),
Ok(None) => Json(ApiResponse::error(format!("Client '{}' not found", id))),
Err(e) => {
warn!("Failed to fetch client: {}", e);
Json(ApiResponse::error(format!("Failed to fetch client: {}", e)))
}
}
}
pub(super) async fn handle_update_client(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(id): Path<String>,
Json(payload): Json<UpdateClientPayload>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Build dynamic UPDATE query from provided fields
let mut sets = Vec::new();
let mut binds: Vec<String> = Vec::new();
if let Some(ref name) = payload.name {
sets.push("name = ?");
binds.push(name.clone());
}
if let Some(ref desc) = payload.description {
sets.push("description = ?");
binds.push(desc.clone());
}
if payload.is_active.is_some() {
sets.push("is_active = ?");
}
if payload.rate_limit_per_minute.is_some() {
sets.push("rate_limit_per_minute = ?");
}
if sets.is_empty() {
return Json(ApiResponse::error("No fields to update".to_string()));
}
// Always update updated_at
sets.push("updated_at = CURRENT_TIMESTAMP");
let sql = format!("UPDATE clients SET {} WHERE client_id = ?", sets.join(", "));
let mut query = sqlx::query(&sql);
// Bind in the same order as sets
for b in &binds {
query = query.bind(b);
}
if let Some(active) = payload.is_active {
query = query.bind(active);
}
if let Some(rate) = payload.rate_limit_per_minute {
query = query.bind(rate);
}
query = query.bind(&id);
match query.execute(pool).await {
Ok(result) => {
if result.rows_affected() == 0 {
return Json(ApiResponse::error(format!("Client '{}' not found", id)));
}
// Return the updated client
let row = sqlx::query(
r#"
SELECT client_id as id, name, description, is_active, rate_limit_per_minute,
created_at, total_requests, total_tokens, total_cost
FROM clients WHERE client_id = ?
"#,
)
.bind(&id)
.fetch_one(pool)
.await;
match row {
Ok(row) => Json(ApiResponse::success(serde_json::json!({
"id": row.get::<String, _>("id"),
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "Unnamed".to_string()),
"description": row.get::<Option<String>, _>("description"),
"is_active": row.get::<bool, _>("is_active"),
"rate_limit_per_minute": row.get::<Option<i64>, _>("rate_limit_per_minute"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
"total_requests": row.get::<i64, _>("total_requests"),
"total_tokens": row.get::<i64, _>("total_tokens"),
"total_cost": row.get::<f64, _>("total_cost"),
"status": if row.get::<bool, _>("is_active") { "active" } else { "inactive" },
}))),
Err(e) => {
warn!("Failed to fetch updated client: {}", e);
// Update succeeded but fetch failed — still report success
Json(ApiResponse::success(serde_json::json!({ "message": "Client updated" })))
}
}
}
Err(e) => {
warn!("Failed to update client: {}", e);
Json(ApiResponse::error(format!("Failed to update client: {}", e)))
}
}
}
pub(super) async fn handle_delete_client(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Don't allow deleting the default client
if id == "default" {
return Json(ApiResponse::error("Cannot delete default client".to_string()));
}
let result = sqlx::query("DELETE FROM clients WHERE client_id = ?")
.bind(id)
.execute(pool)
.await;
match result {
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "Client deleted" }))),
Err(e) => Json(ApiResponse::error(format!("Failed to delete client: {}", e))),
}
}
pub(super) async fn handle_client_usage(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Get per-model breakdown for this client
let result = sqlx::query(
r#"
SELECT
model,
provider,
COUNT(*) as request_count,
SUM(prompt_tokens) as prompt_tokens,
SUM(completion_tokens) as completion_tokens,
SUM(total_tokens) as total_tokens,
SUM(cost) as total_cost,
AVG(duration_ms) as avg_duration_ms
FROM llm_requests
WHERE client_id = ?
GROUP BY model, provider
ORDER BY total_cost DESC
"#,
)
.bind(&id)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let breakdown: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
serde_json::json!({
"model": row.get::<String, _>("model"),
"provider": row.get::<String, _>("provider"),
"request_count": row.get::<i64, _>("request_count"),
"prompt_tokens": row.get::<i64, _>("prompt_tokens"),
"completion_tokens": row.get::<i64, _>("completion_tokens"),
"total_tokens": row.get::<i64, _>("total_tokens"),
"total_cost": row.get::<f64, _>("total_cost"),
"avg_duration_ms": row.get::<f64, _>("avg_duration_ms"),
})
})
.collect();
Json(ApiResponse::success(serde_json::json!({
"client_id": id,
"breakdown": breakdown,
})))
}
Err(e) => {
warn!("Failed to fetch client usage: {}", e);
Json(ApiResponse::error(format!("Failed to fetch client usage: {}", e)))
}
}
}
// ── Token management endpoints ──────────────────────────────────────
pub(super) async fn handle_get_client_tokens(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(id): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT id, token, name, is_active, created_at, last_used_at
FROM client_tokens
WHERE client_id = ?
ORDER BY created_at DESC
"#,
)
.bind(&id)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let tokens: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
let token: String = row.get("token");
// Keep the "sk-" prefix and the last 8 chars: sk-••••abcd1234.
// Require length > 11 so the prefix and suffix slices never overlap
// (with the old > 8 threshold, a 9-11 char token was fully revealed).
let masked = if token.len() > 11 {
format!("{}••••{}", &token[..3], &token[token.len() - 8..])
} else {
"••••".to_string()
};
serde_json::json!({
"id": row.get::<i64, _>("id"),
"token_masked": masked,
"name": row.get::<Option<String>, _>("name").unwrap_or_else(|| "default".to_string()),
"is_active": row.get::<bool, _>("is_active"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
"last_used_at": row.get::<Option<chrono::DateTime<chrono::Utc>>, _>("last_used_at"),
})
})
.collect();
Json(ApiResponse::success(serde_json::json!(tokens)))
}
Err(e) => {
warn!("Failed to fetch client tokens: {}", e);
Json(ApiResponse::error(format!("Failed to fetch client tokens: {}", e)))
}
}
}
#[derive(Deserialize)]
pub(super) struct CreateTokenRequest {
pub(super) name: Option<String>,
}
pub(super) async fn handle_create_client_token(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(id): Path<String>,
Json(payload): Json<CreateTokenRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Verify the client exists (a DB error here is swallowed and treated as not-found)
let exists: Option<(i64,)> = sqlx::query_as("SELECT 1 as x FROM clients WHERE client_id = ?")
.bind(&id)
.fetch_optional(pool)
.await
.unwrap_or(None);
if exists.is_none() {
return Json(ApiResponse::error(format!("Client '{}' not found", id)));
}
let token = generate_token();
let token_name = payload.name.unwrap_or_else(|| "default".to_string());
let result = sqlx::query(
"INSERT INTO client_tokens (client_id, token, name) VALUES (?, ?, ?) RETURNING id, created_at",
)
.bind(&id)
.bind(&token)
.bind(&token_name)
.fetch_one(pool)
.await;
match result {
Ok(row) => Json(ApiResponse::success(serde_json::json!({
"id": row.get::<i64, _>("id"),
"token": token,
"name": token_name,
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
}))),
Err(e) => {
warn!("Failed to create client token: {}", e);
Json(ApiResponse::error(format!("Failed to create token: {}", e)))
}
}
}
pub(super) async fn handle_delete_client_token(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path((client_id, token_id)): Path<(String, i64)>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let result = sqlx::query("DELETE FROM client_tokens WHERE id = ? AND client_id = ?")
.bind(token_id)
.bind(&client_id)
.execute(pool)
.await;
match result {
Ok(r) => {
if r.rows_affected() == 0 {
Json(ApiResponse::error("Token not found".to_string()))
} else {
Json(ApiResponse::success(serde_json::json!({ "message": "Token revoked" })))
}
}
Err(e) => {
warn!("Failed to delete client token: {}", e);
Json(ApiResponse::error(format!("Failed to revoke token: {}", e)))
}
}
}
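`generate_token` at the top of this file yields `sk-` plus 48 hex characters, two digits per random byte. A stdlib-only sketch of the formatting step, with a fixed byte array standing in for the rng output (the `format_token` name is illustrative):

```rust
// Format 24 bytes as "sk-" followed by 48 lowercase hex characters.
fn format_token(bytes: &[u8; 24]) -> String {
    let hex: String = bytes.iter().map(|b| format!("{b:02x}")).collect();
    format!("sk-{hex}")
}

fn main() {
    let t = format_token(&[0xab; 24]);
    assert_eq!(t.len(), 51); // 3-char prefix + 48 hex digits
    assert!(t.starts_with("sk-ab"));
}
```

The 51-char length is what makes the masking in `handle_get_client_tokens` safe: real tokens are always long enough to show a prefix and suffix without overlap.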

View File

@@ -1,174 +0,0 @@
// Dashboard module for LLM Proxy Gateway
mod auth;
mod clients;
mod models;
mod providers;
pub mod sessions;
mod system;
mod usage;
mod users;
mod websocket;
use axum::{
extract::{Request, State},
middleware::Next,
response::Response,
Router,
routing::{delete, get, post, put},
};
use axum::http::{header, HeaderValue};
use serde::Serialize;
use tower_http::{
limit::RequestBodyLimitLayer,
set_header::SetResponseHeaderLayer,
};
use crate::state::AppState;
use sessions::SessionManager;
// Dashboard state
#[derive(Clone)]
struct DashboardState {
app_state: AppState,
session_manager: SessionManager,
}
// API Response types
#[derive(Serialize)]
struct ApiResponse<T> {
success: bool,
data: Option<T>,
error: Option<String>,
}
impl<T> ApiResponse<T> {
fn success(data: T) -> Self {
Self {
success: true,
data: Some(data),
error: None,
}
}
fn error(error: String) -> Self {
Self {
success: false,
data: None,
error: Some(error),
}
}
}
/// Rate limiting middleware for dashboard routes
async fn dashboard_rate_limit_middleware(
State(_dashboard_state): State<DashboardState>,
request: Request,
next: Next,
) -> Result<Response, crate::errors::AppError> {
// Bypass rate limiting for dashboard routes to prevent "Failed to load statistics"
// when the UI makes many concurrent requests on load.
// Dashboard endpoints are already secured via auth::require_admin.
Ok(next.run(request).await)
}
// Dashboard routes
pub fn router(state: AppState) -> Router {
let session_manager = SessionManager::new(24); // 24-hour session TTL
let dashboard_state = DashboardState {
app_state: state,
session_manager,
};
// Security headers
let csp_header: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::CONTENT_SECURITY_POLICY,
"default-src 'self'; script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; style-src 'self' 'unsafe-inline' https://cdnjs.cloudflare.com https://fonts.googleapis.com; font-src 'self' https://cdnjs.cloudflare.com https://fonts.gstatic.com; img-src 'self' data:; connect-src 'self' ws:;"
.parse()
.unwrap(),
);
let x_frame_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_FRAME_OPTIONS,
"DENY".parse().unwrap(),
);
let x_content_type_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_CONTENT_TYPE_OPTIONS,
"nosniff".parse().unwrap(),
);
let strict_transport_security: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::STRICT_TRANSPORT_SECURITY,
"max-age=31536000; includeSubDomains".parse().unwrap(),
);
Router::new()
// Static file serving
.fallback_service(tower_http::services::ServeDir::new("static"))
// WebSocket endpoint
.route("/ws", get(websocket::handle_websocket))
// API endpoints
.route("/api/auth/login", post(auth::handle_login))
.route("/api/auth/status", get(auth::handle_auth_status))
.route("/api/auth/logout", post(auth::handle_logout))
.route("/api/auth/change-password", post(auth::handle_change_password))
.route(
"/api/users",
get(users::handle_get_users).post(users::handle_create_user),
)
.route(
"/api/users/{id}",
put(users::handle_update_user).delete(users::handle_delete_user),
)
.route("/api/usage/summary", get(usage::handle_usage_summary))
.route("/api/usage/time-series", get(usage::handle_time_series))
.route("/api/usage/clients", get(usage::handle_clients_usage))
.route("/api/usage/providers", get(usage::handle_providers_usage))
.route("/api/usage/detailed", get(usage::handle_detailed_usage))
.route("/api/analytics/breakdown", get(usage::handle_analytics_breakdown))
.route("/api/models", get(models::handle_get_models))
.route("/api/models/{id}", put(models::handle_update_model))
.route(
"/api/clients",
get(clients::handle_get_clients).post(clients::handle_create_client),
)
.route(
"/api/clients/{id}",
get(clients::handle_get_client)
.put(clients::handle_update_client)
.delete(clients::handle_delete_client),
)
.route("/api/clients/{id}/usage", get(clients::handle_client_usage))
.route(
"/api/clients/{id}/tokens",
get(clients::handle_get_client_tokens).post(clients::handle_create_client_token),
)
.route(
"/api/clients/{id}/tokens/{token_id}",
delete(clients::handle_delete_client_token),
)
.route("/api/providers", get(providers::handle_get_providers))
.route(
"/api/providers/{name}",
get(providers::handle_get_provider).put(providers::handle_update_provider),
)
.route("/api/providers/{name}/test", post(providers::handle_test_provider))
.route("/api/system/health", get(system::handle_system_health))
.route("/api/system/metrics", get(system::handle_system_metrics))
.route("/api/system/logs", get(system::handle_system_logs))
.route("/api/system/backup", post(system::handle_system_backup))
.route(
"/api/system/settings",
get(system::handle_get_settings).post(system::handle_update_settings),
)
// Security layers
.layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit
.layer(csp_header)
.layer(x_frame_options)
.layer(x_content_type_options)
.layer(strict_transport_security)
// Rate limiting middleware
.layer(axum::middleware::from_fn_with_state(
dashboard_state.clone(),
dashboard_rate_limit_middleware,
))
.with_state(dashboard_state)
}
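For reference, the `ApiResponse<T>` envelope defined above serializes with serde's defaults (a `None` field becomes `null`), so dashboard endpoints return shapes like the following (payload values illustrative):

```json
{ "success": true,  "data": { "message": "Model updated" }, "error": null }
```

and on failure:

```json
{ "success": false, "data": null, "error": "Failed to update model: ..." }
```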


@@ -1,254 +0,0 @@
use axum::{
extract::{Path, Query, State},
response::Json,
};
use serde::Deserialize;
use serde_json;
use sqlx::Row;
use std::collections::HashMap;
use super::{ApiResponse, DashboardState};
use crate::models::registry::{ModelFilter, ModelSortBy, SortOrder};
#[derive(Deserialize)]
pub(super) struct UpdateModelRequest {
pub(super) enabled: bool,
pub(super) prompt_cost: Option<f64>,
pub(super) completion_cost: Option<f64>,
pub(super) mapping: Option<String>,
}
/// Query parameters for `GET /api/models`.
#[derive(Debug, Deserialize, Default)]
pub(super) struct ModelListParams {
/// Filter by provider ID.
pub provider: Option<String>,
/// Text search on model ID or name.
pub search: Option<String>,
/// Filter by input modality (e.g. "image").
pub modality: Option<String>,
/// Only models that support tool calling.
pub tool_call: Option<bool>,
/// Only models that support reasoning.
pub reasoning: Option<bool>,
/// Only models that have pricing data.
pub has_cost: Option<bool>,
/// Only models that have been used in requests.
pub used_only: Option<bool>,
/// Sort field (name, id, provider, context_limit, input_cost, output_cost).
pub sort_by: Option<ModelSortBy>,
/// Sort direction (asc, desc).
pub sort_order: Option<SortOrder>,
}
pub(super) async fn handle_get_models(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Query(params): Query<ModelListParams>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let registry = &state.app_state.model_registry;
let pool = &state.app_state.db_pool;
// Load overrides from database
let db_models_result =
sqlx::query("SELECT id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping FROM model_configs")
.fetch_all(pool)
.await;
let mut db_models = HashMap::new();
if let Ok(rows) = db_models_result {
for row in rows {
let id: String = row.get("id");
db_models.insert(id, row);
}
}
let mut models_json = Vec::new();
if params.used_only.unwrap_or(false) {
// Used-models view: only (provider, model) pairs that appear in llm_requests.
let used_pairs_result = sqlx::query(
"SELECT DISTINCT provider, model FROM llm_requests",
)
.fetch_all(pool)
.await;
if let Ok(rows) = used_pairs_result {
for row in rows {
let provider: String = row.get("provider");
let m_key: String = row.get("model");
let provider_name = match provider.as_str() {
"openai" => "OpenAI",
"gemini" => "Google Gemini",
"deepseek" => "DeepSeek",
"grok" => "xAI Grok",
"ollama" => "Ollama",
_ => provider.as_str(),
}.to_string();
let m_meta = registry.find_model(&m_key);
let mut enabled = true;
let mut prompt_cost = m_meta.and_then(|m| m.cost.as_ref().map(|c| c.input)).unwrap_or(0.0);
let mut completion_cost = m_meta.and_then(|m| m.cost.as_ref().map(|c| c.output)).unwrap_or(0.0);
let cache_read_cost = m_meta.and_then(|m| m.cost.as_ref().and_then(|c| c.cache_read));
let cache_write_cost = m_meta.and_then(|m| m.cost.as_ref().and_then(|c| c.cache_write));
let mut mapping = None::<String>;
if let Some(db_row) = db_models.get(&m_key) {
enabled = db_row.get("enabled");
if let Some(p) = db_row.get::<Option<f64>, _>("prompt_cost_per_m") {
prompt_cost = p;
}
if let Some(c) = db_row.get::<Option<f64>, _>("completion_cost_per_m") {
completion_cost = c;
}
mapping = db_row.get("mapping");
}
models_json.push(serde_json::json!({
"id": m_key,
"provider": provider,
"provider_name": provider_name,
"name": m_meta.map(|m| m.name.clone()).unwrap_or_else(|| m_key.clone()),
"enabled": enabled,
"prompt_cost": prompt_cost,
"completion_cost": completion_cost,
"cache_read_cost": cache_read_cost,
"cache_write_cost": cache_write_cost,
"mapping": mapping,
"context_limit": m_meta.and_then(|m| m.limit.as_ref().map(|l| l.context)).unwrap_or(0),
"output_limit": m_meta.and_then(|m| m.limit.as_ref().map(|l| l.output)).unwrap_or(0),
"modalities": m_meta.and_then(|m| m.modalities.as_ref().map(|mo| serde_json::json!({
"input": mo.input,
"output": mo.output,
}))),
"tool_call": m_meta.and_then(|m| m.tool_call),
"reasoning": m_meta.and_then(|m| m.reasoning),
}));
}
}
} else {
// Registry listing: the full catalogue, filtered and sorted by query params.
// Build filter from query params
let filter = ModelFilter {
provider: params.provider,
search: params.search,
modality: params.modality,
tool_call: params.tool_call,
reasoning: params.reasoning,
has_cost: params.has_cost,
};
let sort_by = params.sort_by.unwrap_or_default();
let sort_order = params.sort_order.unwrap_or_default();
// Get filtered and sorted model entries
let entries = registry.list_models(&filter, &sort_by, &sort_order);
for entry in &entries {
let m_key = entry.model_key;
let m_meta = entry.metadata;
let mut enabled = true;
let mut prompt_cost = m_meta.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
let mut completion_cost = m_meta.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
let cache_read_cost = m_meta.cost.as_ref().and_then(|c| c.cache_read);
let cache_write_cost = m_meta.cost.as_ref().and_then(|c| c.cache_write);
let mut mapping = None::<String>;
if let Some(row) = db_models.get(m_key) {
enabled = row.get("enabled");
if let Some(p) = row.get::<Option<f64>, _>("prompt_cost_per_m") {
prompt_cost = p;
}
if let Some(c) = row.get::<Option<f64>, _>("completion_cost_per_m") {
completion_cost = c;
}
mapping = row.get("mapping");
}
models_json.push(serde_json::json!({
"id": m_key,
"provider": entry.provider_id,
"provider_name": entry.provider_name,
"name": m_meta.name,
"enabled": enabled,
"prompt_cost": prompt_cost,
"completion_cost": completion_cost,
"cache_read_cost": cache_read_cost,
"cache_write_cost": cache_write_cost,
"mapping": mapping,
"context_limit": m_meta.limit.as_ref().map(|l| l.context).unwrap_or(0),
"output_limit": m_meta.limit.as_ref().map(|l| l.output).unwrap_or(0),
"modalities": m_meta.modalities.as_ref().map(|m| serde_json::json!({
"input": m.input,
"output": m.output,
})),
"tool_call": m_meta.tool_call,
"reasoning": m_meta.reasoning,
}));
}
}
Json(ApiResponse::success(serde_json::json!(models_json)))
}
pub(super) async fn handle_update_model(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(id): Path<String>,
Json(payload): Json<UpdateModelRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Find provider_id for this model in registry
let provider_id = state
.app_state
.model_registry
.providers
.iter()
.find(|(_, p)| p.models.contains_key(&id))
.map(|(id, _)| id.clone())
.unwrap_or_else(|| "unknown".to_string());
let result = sqlx::query(
r#"
INSERT INTO model_configs (id, provider_id, enabled, prompt_cost_per_m, completion_cost_per_m, mapping)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
enabled = excluded.enabled,
prompt_cost_per_m = excluded.prompt_cost_per_m,
completion_cost_per_m = excluded.completion_cost_per_m,
mapping = excluded.mapping,
updated_at = CURRENT_TIMESTAMP
"#,
)
.bind(&id)
.bind(provider_id)
.bind(payload.enabled)
.bind(payload.prompt_cost)
.bind(payload.completion_cost)
.bind(payload.mapping)
.execute(pool)
.await;
match result {
Ok(_) => {
// Invalidate the in-memory cache so the proxy picks up the change immediately
state.app_state.model_config_cache.invalidate().await;
Json(ApiResponse::success(serde_json::json!({ "message": "Model updated" })))
}
Err(e) => Json(ApiResponse::error(format!("Failed to update model: {}", e))),
}
}
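A single entry produced by `handle_get_models` above carries the fields below (values are illustrative, not from the registry; per the `*_per_m` column names, costs are per million tokens):

```json
{
  "id": "example-model",
  "provider": "openai",
  "provider_name": "OpenAI",
  "name": "Example Model",
  "enabled": true,
  "prompt_cost": 0.15,
  "completion_cost": 0.6,
  "cache_read_cost": null,
  "cache_write_cost": null,
  "mapping": null,
  "context_limit": 128000,
  "output_limit": 16384,
  "modalities": { "input": ["text", "image"], "output": ["text"] },
  "tool_call": true,
  "reasoning": false
}
```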


@@ -1,440 +0,0 @@
use axum::{
extract::{Path, State},
response::Json,
};
use serde::Deserialize;
use serde_json;
use sqlx::Row;
use std::collections::HashMap;
use tracing::warn;
use super::{ApiResponse, DashboardState};
use crate::utils::crypto;
#[derive(Deserialize)]
pub(super) struct UpdateProviderRequest {
pub(super) enabled: bool,
pub(super) base_url: Option<String>,
pub(super) api_key: Option<String>,
pub(super) credit_balance: Option<f64>,
pub(super) low_credit_threshold: Option<f64>,
pub(super) billing_mode: Option<String>,
}
pub(super) async fn handle_get_providers(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let registry = &state.app_state.model_registry;
let config = &state.app_state.config;
let pool = &state.app_state.db_pool;
// Load all overrides from database (including billing_mode)
let db_configs_result = sqlx::query(
"SELECT id, enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs",
)
.fetch_all(pool)
.await;
let mut db_configs = HashMap::new();
if let Ok(rows) = db_configs_result {
for row in rows {
let id: String = row.get("id");
let enabled: bool = row.get("enabled");
let base_url: Option<String> = row.get("base_url");
let balance: f64 = row.get("credit_balance");
let threshold: f64 = row.get("low_credit_threshold");
let billing_mode: Option<String> = row.get("billing_mode");
db_configs.insert(id, (enabled, base_url, balance, threshold, billing_mode));
}
}
let mut providers_json = Vec::new();
// Define the list of providers we support
let provider_ids = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
for id in provider_ids {
// Get base config
let (mut enabled, mut base_url, display_name) = match id {
"openai" => (
config.providers.openai.enabled,
config.providers.openai.base_url.clone(),
"OpenAI",
),
"gemini" => (
config.providers.gemini.enabled,
config.providers.gemini.base_url.clone(),
"Google Gemini",
),
"deepseek" => (
config.providers.deepseek.enabled,
config.providers.deepseek.base_url.clone(),
"DeepSeek",
),
"grok" => (
config.providers.grok.enabled,
config.providers.grok.base_url.clone(),
"xAI Grok",
),
"ollama" => (
config.providers.ollama.enabled,
config.providers.ollama.base_url.clone(),
"Ollama",
),
_ => (false, "".to_string(), "Unknown"),
};
let mut balance = 0.0;
let mut threshold = 5.0;
let mut billing_mode: Option<String> = None;
// Apply database overrides
if let Some((db_enabled, db_url, db_balance, db_threshold, db_billing)) = db_configs.get(id) {
enabled = *db_enabled;
if let Some(url) = db_url {
base_url = url.clone();
}
balance = *db_balance;
threshold = *db_threshold;
billing_mode = db_billing.clone();
}
// Find models for this provider in registry
// NOTE: registry provider IDs differ from internal IDs for some providers.
let registry_key = match id {
"gemini" => "google",
"grok" => "xai",
_ => id,
};
let mut models = Vec::new();
if let Some(p_info) = registry.providers.get(registry_key) {
models = p_info.models.keys().cloned().collect();
} else if id == "ollama" {
models = config.providers.ollama.models.clone();
}
// Determine status
let status = if !enabled {
"disabled"
} else {
// Check if it's actually initialized in the provider manager
if state.app_state.provider_manager.get_provider(id).await.is_some() {
// Check circuit breaker
if state
.app_state
.rate_limit_manager
.check_provider_request(id)
.await
.unwrap_or(true)
{
"online"
} else {
"degraded"
}
} else {
"error" // Enabled but failed to initialize (e.g. missing API key)
}
};
providers_json.push(serde_json::json!({
"id": id,
"name": display_name,
"enabled": enabled,
"status": status,
"models": models,
"base_url": base_url,
"credit_balance": balance,
"low_credit_threshold": threshold,
"billing_mode": billing_mode,
"last_used": None::<String>,
}));
}
Json(ApiResponse::success(serde_json::json!(providers_json)))
}
pub(super) async fn handle_get_provider(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(name): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let registry = &state.app_state.model_registry;
let config = &state.app_state.config;
let pool = &state.app_state.db_pool;
// Validate provider name
let (mut enabled, mut base_url, display_name) = match name.as_str() {
"openai" => (
config.providers.openai.enabled,
config.providers.openai.base_url.clone(),
"OpenAI",
),
"gemini" => (
config.providers.gemini.enabled,
config.providers.gemini.base_url.clone(),
"Google Gemini",
),
"deepseek" => (
config.providers.deepseek.enabled,
config.providers.deepseek.base_url.clone(),
"DeepSeek",
),
"grok" => (
config.providers.grok.enabled,
config.providers.grok.base_url.clone(),
"xAI Grok",
),
"ollama" => (
config.providers.ollama.enabled,
config.providers.ollama.base_url.clone(),
"Ollama",
),
_ => return Json(ApiResponse::error(format!("Unknown provider '{}'", name))),
};
let mut balance = 0.0;
let mut threshold = 5.0;
let mut billing_mode: Option<String> = None;
// Apply database overrides
let db_config = sqlx::query(
"SELECT enabled, base_url, credit_balance, low_credit_threshold, billing_mode FROM provider_configs WHERE id = ?",
)
.bind(&name)
.fetch_optional(pool)
.await;
if let Ok(Some(row)) = db_config {
enabled = row.get::<bool, _>("enabled");
if let Some(url) = row.get::<Option<String>, _>("base_url") {
base_url = url;
}
balance = row.get::<f64, _>("credit_balance");
threshold = row.get::<f64, _>("low_credit_threshold");
billing_mode = row.get::<Option<String>, _>("billing_mode");
}
// Find models for this provider
// NOTE: registry provider IDs differ from internal IDs for some providers.
let registry_key = match name.as_str() {
"gemini" => "google",
"grok" => "xai",
_ => name.as_str(),
};
let mut models = Vec::new();
if let Some(p_info) = registry.providers.get(registry_key) {
models = p_info.models.keys().cloned().collect();
} else if name == "ollama" {
models = config.providers.ollama.models.clone();
}
// Determine status
let status = if !enabled {
"disabled"
} else if state.app_state.provider_manager.get_provider(&name).await.is_some() {
if state
.app_state
.rate_limit_manager
.check_provider_request(&name)
.await
.unwrap_or(true)
{
"online"
} else {
"degraded"
}
} else {
"error"
};
Json(ApiResponse::success(serde_json::json!({
"id": name,
"name": display_name,
"enabled": enabled,
"status": status,
"models": models,
"base_url": base_url,
"credit_balance": balance,
"low_credit_threshold": threshold,
"billing_mode": billing_mode,
"last_used": None::<String>,
})))
}
pub(super) async fn handle_update_provider(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(name): Path<String>,
Json(payload): Json<UpdateProviderRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Prepare API key encryption if provided
let (api_key_to_store, api_key_encrypted_flag) = match &payload.api_key {
Some(key) if !key.is_empty() => {
match crypto::encrypt(key) {
Ok(encrypted) => (Some(encrypted), Some(true)),
Err(e) => {
warn!("Failed to encrypt API key for provider {}: {}", name, e);
return Json(ApiResponse::error(format!("Failed to encrypt API key: {}", e)));
}
}
}
Some(_) => {
// Empty string means clear the key
(None, Some(false))
}
None => {
// Keep existing key, we'll rely on COALESCE in SQL
(None, None)
}
};
// Update or insert into database (include billing_mode and api_key_encrypted)
let result = sqlx::query(
r#"
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold, billing_mode)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
enabled = excluded.enabled,
base_url = excluded.base_url,
api_key = COALESCE(excluded.api_key, provider_configs.api_key),
api_key_encrypted = COALESCE(excluded.api_key_encrypted, provider_configs.api_key_encrypted),
credit_balance = COALESCE(excluded.credit_balance, provider_configs.credit_balance),
low_credit_threshold = COALESCE(excluded.low_credit_threshold, provider_configs.low_credit_threshold),
billing_mode = COALESCE(excluded.billing_mode, provider_configs.billing_mode),
updated_at = CURRENT_TIMESTAMP
"#,
)
.bind(&name)
.bind(name.to_uppercase())
.bind(payload.enabled)
.bind(&payload.base_url)
.bind(&api_key_to_store)
.bind(api_key_encrypted_flag)
.bind(payload.credit_balance)
.bind(payload.low_credit_threshold)
.bind(payload.billing_mode)
.execute(pool)
.await;
match result {
Ok(_) => {
// Re-initialize provider in manager
if let Err(e) = state
.app_state
.provider_manager
.initialize_provider(&name, &state.app_state.config, &state.app_state.db_pool)
.await
{
warn!("Failed to re-initialize provider {}: {}", name, e);
return Json(ApiResponse::error(format!(
"Provider settings saved but initialization failed: {}",
e
)));
}
Json(ApiResponse::success(
serde_json::json!({ "message": "Provider updated and re-initialized" }),
))
}
Err(e) => {
warn!("Failed to update provider config: {}", e);
Json(ApiResponse::error(format!("Failed to update provider: {}", e)))
}
}
}
pub(super) async fn handle_test_provider(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(name): Path<String>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let start = std::time::Instant::now();
let provider = match state.app_state.provider_manager.get_provider(&name).await {
Some(p) => p,
None => {
return Json(ApiResponse::error(format!(
"Provider '{}' not found or not enabled",
name
)));
}
};
// Pick a real model for this provider from the registry
// NOTE: registry provider IDs differ from internal IDs for some providers.
let registry_key = match name.as_str() {
"gemini" => "google",
"grok" => "xai",
_ => name.as_str(),
};
let test_model = state
.app_state
.model_registry
.providers
.get(registry_key)
.and_then(|p| p.models.keys().next().cloned())
.unwrap_or_else(|| name.clone());
let test_request = crate::models::UnifiedRequest {
client_id: "system-test".to_string(),
model: test_model,
messages: vec![crate::models::UnifiedMessage {
role: "user".to_string(),
content: vec![crate::models::ContentPart::Text { text: "Hi".to_string() }],
reasoning_content: None,
tool_calls: None,
name: None,
tool_call_id: None,
}],
temperature: None,
top_p: None,
top_k: None,
n: None,
stop: None,
max_tokens: Some(5),
presence_penalty: None,
frequency_penalty: None,
stream: false,
has_images: false,
tools: None,
tool_choice: None,
};
match provider.chat_completion(test_request).await {
Ok(_) => {
let latency = start.elapsed().as_millis();
Json(ApiResponse::success(serde_json::json!({
"success": true,
"latency": latency,
"message": "Connection test successful"
})))
}
Err(e) => Json(ApiResponse::error(format!("Provider test failed: {}", e))),
}
}
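The three-way API-key handling in `handle_update_provider` above (new key, explicit clear, or keep existing via the SQL `COALESCE`) can be isolated as a small std-only decision function; this is a sketch with encryption elided, and `KeyAction` is a name introduced here for illustration:

```rust
// Models the api_key branch of handle_update_provider, minus encryption.
#[derive(Debug, PartialEq)]
enum KeyAction {
    Store(String), // non-empty key: encrypt and store it
    Clear,         // empty string: remove the stored key
    Keep,          // field absent: COALESCE keeps the existing key
}

fn key_action(api_key: Option<&str>) -> KeyAction {
    match api_key {
        Some(k) if !k.is_empty() => KeyAction::Store(k.to_string()),
        Some(_) => KeyAction::Clear,
        None => KeyAction::Keep,
    }
}

fn main() {
    assert_eq!(key_action(Some("sk-new")), KeyAction::Store("sk-new".into()));
    assert_eq!(key_action(Some("")), KeyAction::Clear);
    assert_eq!(key_action(None), KeyAction::Keep);
    println!("ok");
}
```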


@@ -1,311 +0,0 @@
use chrono::{DateTime, Duration, Utc};
use hmac::{Hmac, Mac};
use serde::{Deserialize, Serialize};
use sha2::{Sha256, digest::generic_array::GenericArray};
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
use base64::{engine::general_purpose::URL_SAFE, Engine as _};
const TOKEN_VERSION: &str = "v2";
const REFRESH_WINDOW_MINUTES: i64 = 15; // refresh if token expires within 15 minutes
#[derive(Clone, Debug)]
pub struct Session {
pub username: String,
pub role: String,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
pub session_id: String, // unique identifier for the session (UUID)
}
#[derive(Clone)]
pub struct SessionManager {
sessions: Arc<RwLock<HashMap<String, Session>>>, // key = session_id
ttl_hours: i64,
secret: Vec<u8>,
}
#[derive(Debug, Serialize, Deserialize)]
struct SessionPayload {
session_id: String,
username: String,
role: String,
iat: i64, // issued at (Unix timestamp)
exp: i64, // expiry (Unix timestamp)
version: String,
}
impl SessionManager {
pub fn new(ttl_hours: i64) -> Self {
let secret = load_session_secret();
Self {
sessions: Arc::new(RwLock::new(HashMap::new())),
ttl_hours,
secret,
}
}
/// Create a new session and return a signed session token.
pub async fn create_session(&self, username: String, role: String) -> String {
let session_id = Uuid::new_v4().to_string();
let now = Utc::now();
let expires_at = now + Duration::hours(self.ttl_hours);
let session = Session {
username: username.clone(),
role: role.clone(),
created_at: now,
expires_at,
session_id: session_id.clone(),
};
// Store session by session_id
self.sessions.write().await.insert(session_id.clone(), session);
// Create signed token
self.create_signed_token(&session_id, &username, &role, now.timestamp(), expires_at.timestamp())
}
/// Validate a session token and return the session if valid and not expired.
/// If the token is within the refresh window, returns a new token as well.
pub async fn validate_session(&self, token: &str) -> Option<Session> {
self.validate_session_with_refresh(token).await.map(|(session, _)| session)
}
/// Validate a session token and return (session, optional new token if refreshed).
pub async fn validate_session_with_refresh(&self, token: &str) -> Option<(Session, Option<String>)> {
// Legacy token format (UUID)
if token.starts_with("session-") {
let sessions = self.sessions.read().await;
return sessions.get(token).and_then(|s| {
if s.expires_at > Utc::now() {
Some((s.clone(), None))
} else {
None
}
});
}
// Signed token format
let payload = match verify_signed_token(token, &self.secret) {
Ok(p) => p,
Err(_) => return None,
};
// Check expiry
let now = Utc::now().timestamp();
if payload.exp <= now {
return None;
}
// Look up session by session_id
let sessions = self.sessions.read().await;
let session = match sessions.get(&payload.session_id) {
Some(s) => s.clone(),
None => return None, // session revoked or not found
};
// Ensure session username/role matches (should always match)
if session.username != payload.username || session.role != payload.role {
return None;
}
// Check if token is within refresh window (last REFRESH_WINDOW_MINUTES of validity)
let refresh_threshold = payload.exp - REFRESH_WINDOW_MINUTES * 60;
let new_token = if now >= refresh_threshold {
// Activity-based (sliding-window) refresh: extend the session expiry
// from now by ttl_hours and issue a fresh token with updated iat/exp.
let new_exp = Utc::now() + Duration::hours(self.ttl_hours);
// Update session expiry in store
drop(sessions); // release read lock before acquiring write lock
self.update_session_expiry(&payload.session_id, new_exp).await;
// Create new token with updated iat/exp
let new_token = self.create_signed_token(
&payload.session_id,
&payload.username,
&payload.role,
now,
new_exp.timestamp(),
);
Some(new_token)
} else {
None
};
Some((session, new_token))
}
/// Revoke (delete) a session by token.
/// Supports both legacy tokens (token is key) and signed tokens (extract session_id).
pub async fn revoke_session(&self, token: &str) {
if token.starts_with("session-") {
self.sessions.write().await.remove(token);
return;
}
// For signed token, try to extract session_id
if let Ok(payload) = verify_signed_token(token, &self.secret) {
self.sessions.write().await.remove(&payload.session_id);
}
}
/// Remove all expired sessions from the store.
pub async fn cleanup_expired(&self) {
let now = Utc::now();
self.sessions.write().await.retain(|_, s| s.expires_at > now);
}
// --- Private helpers ---
fn create_signed_token(&self, session_id: &str, username: &str, role: &str, iat: i64, exp: i64) -> String {
let payload = SessionPayload {
session_id: session_id.to_string(),
username: username.to_string(),
role: role.to_string(),
iat,
exp,
version: TOKEN_VERSION.to_string(),
};
sign_token(&payload, &self.secret)
}
async fn update_session_expiry(&self, session_id: &str, new_expires_at: DateTime<Utc>) {
let mut sessions = self.sessions.write().await;
if let Some(session) = sessions.get_mut(session_id) {
session.expires_at = new_expires_at;
}
}
}
/// Load session secret from environment variable SESSION_SECRET (hex or base64 encoded).
/// If not set, generates a random 32-byte secret and logs a warning.
fn load_session_secret() -> Vec<u8> {
let secret_str = env::var("SESSION_SECRET").unwrap_or_else(|_| {
// Also check LLM_PROXY__SESSION_SECRET for consistency with config prefix
env::var("LLM_PROXY__SESSION_SECRET").unwrap_or_else(|_| {
// Generate a random secret (32 bytes) and encode as hex
use rand::RngCore;
let mut bytes = [0u8; 32];
rand::rng().fill_bytes(&mut bytes);
let hex_secret = hex::encode(bytes);
tracing::warn!(
"SESSION_SECRET environment variable not set. Using a randomly generated secret. \
This will invalidate all sessions on restart. Set SESSION_SECRET to a fixed hex or base64 encoded 32-byte value."
);
hex_secret
})
});
// Decode hex or base64
hex::decode(&secret_str)
.or_else(|_| URL_SAFE.decode(&secret_str))
.or_else(|_| base64::engine::general_purpose::STANDARD.decode(&secret_str))
.unwrap_or_else(|_| {
panic!("SESSION_SECRET must be hex or base64 encoded (32 bytes)");
})
}
/// Sign a session payload and return a token string in format base64_url(payload).base64_url(signature).
fn sign_token(payload: &SessionPayload, secret: &[u8]) -> String {
let json = serde_json::to_vec(payload).expect("Failed to serialize payload");
let payload_b64 = URL_SAFE.encode(&json);
let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
mac.update(&json);
let signature = mac.finalize().into_bytes();
let signature_b64 = URL_SAFE.encode(signature);
format!("{}.{}", payload_b64, signature_b64)
}
/// Verify a signed token and return the decoded payload if valid.
fn verify_signed_token(token: &str, secret: &[u8]) -> Result<SessionPayload, TokenError> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 2 {
return Err(TokenError::InvalidFormat);
}
let payload_b64 = parts[0];
let signature_b64 = parts[1];
let json = URL_SAFE.decode(payload_b64).map_err(|_| TokenError::InvalidFormat)?;
let signature = URL_SAFE.decode(signature_b64).map_err(|_| TokenError::InvalidFormat)?;
// Verify HMAC
let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC can take key of any size");
mac.update(&json);
// Convert signature slice to GenericArray
let tag = GenericArray::from_slice(&signature);
mac.verify(tag).map_err(|_| TokenError::InvalidSignature)?;
// Deserialize payload
let payload: SessionPayload = serde_json::from_slice(&json).map_err(|_| TokenError::InvalidPayload)?;
Ok(payload)
}
#[derive(Debug)]
enum TokenError {
InvalidFormat,
InvalidSignature,
InvalidPayload,
}
#[cfg(test)]
mod tests {
use super::*;
use std::env;
#[test]
fn test_sign_and_verify_token() {
let secret = b"test-secret-must-be-32-bytes-long!";
let payload = SessionPayload {
session_id: "test-session".to_string(),
username: "testuser".to_string(),
role: "user".to_string(),
iat: 1000,
exp: 2000,
version: TOKEN_VERSION.to_string(),
};
let token = sign_token(&payload, secret);
let verified = verify_signed_token(&token, secret).unwrap();
assert_eq!(verified.session_id, payload.session_id);
assert_eq!(verified.username, payload.username);
assert_eq!(verified.role, payload.role);
assert_eq!(verified.iat, payload.iat);
assert_eq!(verified.exp, payload.exp);
assert_eq!(verified.version, payload.version);
}
#[test]
fn test_tampered_token() {
let secret = b"test-secret-must-be-32-bytes-long!";
let payload = SessionPayload {
session_id: "test-session".to_string(),
username: "testuser".to_string(),
role: "user".to_string(),
iat: 1000,
exp: 2000,
version: TOKEN_VERSION.to_string(),
};
let mut token = sign_token(&payload, secret);
// Tamper with payload part
let mut parts: Vec<&str> = token.split('.').collect();
let mut payload_bytes = URL_SAFE.decode(parts[0]).unwrap();
payload_bytes[0] ^= 0xFF; // flip some bits
let tampered_payload = URL_SAFE.encode(payload_bytes);
parts[0] = &tampered_payload;
token = parts.join(".");
assert!(verify_signed_token(&token, secret).is_err());
}
#[test]
fn test_load_session_secret_from_env() {
unsafe {
env::set_var("SESSION_SECRET", hex::encode([0xAA; 32]));
}
let secret = load_session_secret();
assert_eq!(secret, vec![0xAA; 32]);
unsafe {
env::remove_var("SESSION_SECRET");
}
}
}
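The sliding-window refresh in `validate_session_with_refresh` above reduces to plain timestamp arithmetic; here is a std-only sketch, where the constants mirror `REFRESH_WINDOW_MINUTES` and the 24-hour TTL used when the router constructs the `SessionManager` (both assumptions of the default setup):

```rust
const REFRESH_WINDOW_SECS: i64 = 15 * 60; // mirrors REFRESH_WINDOW_MINUTES
const TTL_SECS: i64 = 24 * 3600;          // mirrors the 24-hour session TTL

/// If the token expiring at `exp` is inside the refresh window at `now`,
/// return the new expiry to re-issue it with; otherwise keep the old token.
fn maybe_refresh(now: i64, exp: i64) -> Option<i64> {
    if now >= exp - REFRESH_WINDOW_SECS {
        Some(now + TTL_SECS) // extend from *now*, not from the old expiry
    } else {
        None
    }
}

fn main() {
    let exp = 100_000;
    assert_eq!(maybe_refresh(exp - 3_600, exp), None); // an hour left: no-op
    assert_eq!(maybe_refresh(exp - 60, exp), Some(exp - 60 + TTL_SECS));
    println!("ok");
}
```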


@@ -1,405 +0,0 @@
use axum::{extract::State, response::Json};
use chrono;
use serde_json;
use sqlx::Row;
use std::collections::HashMap;
use tracing::warn;
use super::{ApiResponse, DashboardState};
/// Read a value from /proc files, returning None on any failure.
fn read_proc_file(path: &str) -> Option<String> {
std::fs::read_to_string(path).ok()
}
pub(super) async fn handle_system_health(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let mut components = HashMap::new();
components.insert("api_server".to_string(), "online".to_string());
components.insert("database".to_string(), "online".to_string());
// Check provider health via circuit breakers
let provider_ids: Vec<String> = state
.app_state
.provider_manager
.get_all_providers()
.await
.iter()
.map(|p| p.name().to_string())
.collect();
for p_id in provider_ids {
if state
.app_state
.rate_limit_manager
.check_provider_request(&p_id)
.await
.unwrap_or(true)
{
components.insert(p_id, "online".to_string());
} else {
components.insert(p_id, "degraded".to_string());
}
}
// Read real memory usage from /proc/self/status
let memory_mb = read_proc_file("/proc/self/status")
.and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string()))
.and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::<f64>().ok()))
.map(|kb| kb / 1024.0)
.unwrap_or(0.0);
// Get real database pool stats
let db_pool_size = state.app_state.db_pool.size();
let db_pool_idle = state.app_state.db_pool.num_idle();
Json(ApiResponse::success(serde_json::json!({
"status": "healthy",
"timestamp": chrono::Utc::now().to_rfc3339(),
"components": components,
"metrics": {
"memory_usage_mb": (memory_mb * 10.0).round() / 10.0,
"db_connections_active": db_pool_size - db_pool_idle as u32,
"db_connections_idle": db_pool_idle,
}
})))
}
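The `VmRSS` extraction in the handler above can be exercised in isolation against a canned `/proc/self/status` excerpt (the real handler reads the live file, so the field names and units here mirror the Linux format):

```rust
// Sketch of the VmRSS parsing used in handle_system_health, run against
// a canned status string. VmRSS is reported in kB; we convert to MiB.
fn parse_vmrss_mb(status: &str) -> Option<f64> {
    status
        .lines()
        .find(|l| l.starts_with("VmRSS:"))? // e.g. "VmRSS:    102400 kB"
        .split_whitespace()
        .nth(1)?
        .parse::<f64>()
        .ok()
        .map(|kb| kb / 1024.0)
}

fn main() {
    let sample = "VmPeak:\t 200000 kB\nVmRSS:\t 102400 kB\n";
    let mb = parse_vmrss_mb(sample).unwrap();
    assert!((mb - 100.0).abs() < 1e-9);
    // A missing field falls back to None, matching the handler's unwrap_or(0.0).
    assert!(parse_vmrss_mb("no such line").is_none());
}
```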
/// Real system metrics from /proc (Linux only).
pub(super) async fn handle_system_metrics(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
// --- CPU usage (aggregate across all cores) ---
// /proc/stat first line: cpu user nice system idle iowait irq softirq steal guest guest_nice
let cpu_percent = read_proc_file("/proc/stat")
.and_then(|s| {
let line = s.lines().find(|l| l.starts_with("cpu "))?.to_string();
let fields: Vec<u64> = line
.split_whitespace()
.skip(1)
.filter_map(|v| v.parse().ok())
.collect();
if fields.len() >= 4 {
let idle = fields[3];
let total: u64 = fields.iter().sum();
if total > 0 {
Some(((total - idle) as f64 / total as f64 * 100.0 * 10.0).round() / 10.0)
} else {
None
}
} else {
None
}
})
.unwrap_or(0.0);
// --- Memory (system-wide from /proc/meminfo) ---
let meminfo = read_proc_file("/proc/meminfo").unwrap_or_default();
let parse_meminfo = |key: &str| -> u64 {
meminfo
.lines()
.find(|l| l.starts_with(key))
.and_then(|l| l.split_whitespace().nth(1))
.and_then(|v| v.parse::<u64>().ok())
.unwrap_or(0)
};
let mem_total_kb = parse_meminfo("MemTotal:");
let mem_available_kb = parse_meminfo("MemAvailable:");
let mem_used_kb = mem_total_kb.saturating_sub(mem_available_kb);
let mem_total_mb = mem_total_kb as f64 / 1024.0;
let mem_used_mb = mem_used_kb as f64 / 1024.0;
let mem_percent = if mem_total_kb > 0 {
(mem_used_kb as f64 / mem_total_kb as f64 * 100.0 * 10.0).round() / 10.0
} else {
0.0
};
// --- Process-specific memory (VmRSS) ---
let process_rss_mb = read_proc_file("/proc/self/status")
.and_then(|s| s.lines().find(|l| l.starts_with("VmRSS:")).map(|l| l.to_string()))
.and_then(|l| l.split_whitespace().nth(1).and_then(|v| v.parse::<f64>().ok()))
.map(|kb| (kb / 1024.0 * 10.0).round() / 10.0)
.unwrap_or(0.0);
// --- Disk usage of the data directory ---
let (disk_total_gb, disk_used_gb, disk_percent) = {
// statvfs via libc would be ideal; use df as a simple fallback
std::process::Command::new("df")
.args(["-BM", "--output=size,used,pcent", "."])
.output()
.ok()
.and_then(|o| {
let out = String::from_utf8_lossy(&o.stdout);
let line = out.lines().nth(1)?.to_string();
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() >= 3 {
let total = parts[0].trim_end_matches('M').parse::<f64>().unwrap_or(0.0) / 1024.0;
let used = parts[1].trim_end_matches('M').parse::<f64>().unwrap_or(0.0) / 1024.0;
let pct = parts[2].trim_end_matches('%').parse::<f64>().unwrap_or(0.0);
Some(((total * 10.0).round() / 10.0, (used * 10.0).round() / 10.0, pct))
} else {
None
}
})
.unwrap_or((0.0, 0.0, 0.0))
};
// --- Uptime ---
let uptime_seconds = read_proc_file("/proc/uptime")
.and_then(|s| s.split_whitespace().next().and_then(|v| v.parse::<f64>().ok()))
.unwrap_or(0.0) as u64;
// --- Load average ---
let (load_1, load_5, load_15) = read_proc_file("/proc/loadavg")
.and_then(|s| {
let parts: Vec<&str> = s.split_whitespace().collect();
if parts.len() >= 3 {
Some((
parts[0].parse::<f64>().unwrap_or(0.0),
parts[1].parse::<f64>().unwrap_or(0.0),
parts[2].parse::<f64>().unwrap_or(0.0),
))
} else {
None
}
})
.unwrap_or((0.0, 0.0, 0.0));
// --- Network (from /proc/net/dev, aggregate non-lo interfaces) ---
let (net_rx_bytes, net_tx_bytes) = read_proc_file("/proc/net/dev")
.map(|s| {
s.lines()
.skip(2) // skip header lines
.filter(|l| !l.trim().starts_with("lo:"))
.fold((0u64, 0u64), |(rx, tx), line| {
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() >= 10 {
let r = parts[1].parse::<u64>().unwrap_or(0);
let t = parts[9].parse::<u64>().unwrap_or(0);
(rx + r, tx + t)
} else {
(rx, tx)
}
})
})
.unwrap_or((0, 0));
// --- Database pool ---
let db_pool_size = state.app_state.db_pool.size();
let db_pool_idle = state.app_state.db_pool.num_idle();
// --- Active WebSocket listeners ---
let ws_listeners = state.app_state.dashboard_tx.receiver_count();
Json(ApiResponse::success(serde_json::json!({
"cpu": {
"usage_percent": cpu_percent,
"load_average": [load_1, load_5, load_15],
},
"memory": {
"total_mb": (mem_total_mb * 10.0).round() / 10.0,
"used_mb": (mem_used_mb * 10.0).round() / 10.0,
"usage_percent": mem_percent,
"process_rss_mb": process_rss_mb,
},
"disk": {
"total_gb": disk_total_gb,
"used_gb": disk_used_gb,
"usage_percent": disk_percent,
},
"network": {
"rx_bytes": net_rx_bytes,
"tx_bytes": net_tx_bytes,
},
"uptime_seconds": uptime_seconds,
"connections": {
"db_active": db_pool_size - db_pool_idle as u32,
"db_idle": db_pool_idle,
"websocket_listeners": ws_listeners,
},
"timestamp": chrono::Utc::now().to_rfc3339(),
})))
}
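The aggregate-CPU calculation above divides non-idle jiffies by the total across all fields of the first `/proc/stat` line. Because the counters are cumulative since boot, this yields a since-boot average rather than an instantaneous reading. A minimal sketch against a canned line:

```rust
// Same arithmetic as handle_system_metrics, applied to a canned
// /proc/stat "cpu" line: busy = total - idle (4th field), rounded to 0.1%.
fn cpu_percent(stat_line: &str) -> Option<f64> {
    let fields: Vec<u64> = stat_line
        .split_whitespace()
        .skip(1) // skip the "cpu" label
        .filter_map(|v| v.parse().ok())
        .collect();
    if fields.len() < 4 {
        return None;
    }
    let idle = fields[3];
    let total: u64 = fields.iter().sum();
    (total > 0).then(|| ((total - idle) as f64 / total as f64 * 100.0 * 10.0).round() / 10.0)
}

fn main() {
    // user=10 nice=0 system=20 idle=70 -> 30 busy out of 100 jiffies
    assert_eq!(cpu_percent("cpu 10 0 20 70"), Some(30.0));
    assert_eq!(cpu_percent("cpu 1 2"), None);
}
```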
pub(super) async fn handle_system_logs(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let result = sqlx::query(
r#"
SELECT
id,
timestamp,
client_id,
provider,
model,
prompt_tokens,
completion_tokens,
reasoning_tokens,
total_tokens,
cache_read_tokens,
cache_write_tokens,
cost,
status,
error_message,
duration_ms
FROM llm_requests
ORDER BY timestamp DESC
LIMIT 100
"#,
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let logs: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
serde_json::json!({
"id": row.get::<i64, _>("id"),
"timestamp": row.get::<chrono::DateTime<chrono::Utc>, _>("timestamp"),
"client_id": row.get::<String, _>("client_id"),
"provider": row.get::<String, _>("provider"),
"model": row.get::<String, _>("model"),
"prompt_tokens": row.get::<i64, _>("prompt_tokens"),
"completion_tokens": row.get::<i64, _>("completion_tokens"),
"reasoning_tokens": row.get::<i64, _>("reasoning_tokens"),
"cache_read_tokens": row.get::<i64, _>("cache_read_tokens"),
"cache_write_tokens": row.get::<i64, _>("cache_write_tokens"),
"tokens": row.get::<i64, _>("total_tokens"),
"cost": row.get::<f64, _>("cost"),
"status": row.get::<String, _>("status"),
"error": row.get::<Option<String>, _>("error_message"),
"duration": row.get::<i64, _>("duration_ms"),
})
})
.collect();
Json(ApiResponse::success(serde_json::json!(logs)))
}
Err(e) => {
warn!("Failed to fetch system logs: {}", e);
Json(ApiResponse::error("Failed to fetch system logs".to_string()))
}
}
}
pub(super) async fn handle_system_backup(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let backup_id = format!("backup-{}", chrono::Utc::now().timestamp());
let backup_path = format!("data/{}.db", backup_id);
// Ensure the data directory exists
if let Err(e) = std::fs::create_dir_all("data") {
return Json(ApiResponse::error(format!("Failed to create backup directory: {}", e)));
}
// Use SQLite VACUUM INTO for a consistent backup
let result = sqlx::query(&format!("VACUUM INTO '{}'", backup_path))
.execute(pool)
.await;
match result {
Ok(_) => {
// Get backup file size
let size_bytes = std::fs::metadata(&backup_path).map(|m| m.len()).unwrap_or(0);
Json(ApiResponse::success(serde_json::json!({
"success": true,
"message": "Backup completed successfully",
"backup_id": backup_id,
"backup_path": backup_path,
"size_bytes": size_bytes,
})))
}
Err(e) => {
warn!("Database backup failed: {}", e);
Json(ApiResponse::error(format!("Backup failed: {}", e)))
}
}
}
pub(super) async fn handle_get_settings(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let registry = &state.app_state.model_registry;
let provider_count = registry.providers.len();
let model_count: usize = registry.providers.values().map(|p| p.models.len()).sum();
Json(ApiResponse::success(serde_json::json!({
"server": {
"auth_tokens": state.app_state.auth_tokens.iter().map(|t| mask_token(t)).collect::<Vec<_>>(),
"version": env!("CARGO_PKG_VERSION"),
},
"registry": {
"provider_count": provider_count,
"model_count": model_count,
},
"database": {
"type": "SQLite",
}
})))
}
pub(super) async fn handle_update_settings(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
Json(ApiResponse::error(
"Changing settings at runtime is not yet supported. Please update your config file and restart the server."
.to_string(),
))
}
// Helper functions
fn mask_token(token: &str) -> String {
if token.len() <= 8 {
return "*****".to_string();
}
let masked_len = token.len().min(12);
let visible_len = 4;
let mask_len = masked_len - visible_len;
format!("{}{}", "*".repeat(mask_len), &token[token.len() - visible_len..])
}


@@ -1,526 +0,0 @@
use axum::{
extract::{Query, State},
response::Json,
};
use chrono;
use serde::Deserialize;
use serde_json;
use sqlx::Row;
use tracing::warn;
use super::{ApiResponse, DashboardState};
/// Query parameters for time-based filtering on usage endpoints.
#[derive(Debug, Deserialize, Default)]
pub(super) struct UsagePeriodFilter {
/// Preset period: "today", "24h", "7d", "30d", "all" (default: "all")
pub period: Option<String>,
/// Custom range start (ISO 8601, e.g. "2025-06-01T00:00:00Z")
pub from: Option<String>,
/// Custom range end (ISO 8601)
pub to: Option<String>,
}
impl UsagePeriodFilter {
/// Returns `(sql_fragment, bind_values)` for a WHERE clause.
/// The fragment is either empty (no filter) or " AND timestamp >= ? [AND timestamp <= ?]".
fn to_sql(&self) -> (String, Vec<String>) {
let period = self.period.as_deref().unwrap_or("all");
if period == "custom" {
let mut clause = String::new();
let mut binds = Vec::new();
if let Some(ref from) = self.from {
clause.push_str(" AND timestamp >= ?");
binds.push(from.clone());
}
if let Some(ref to) = self.to {
clause.push_str(" AND timestamp <= ?");
binds.push(to.clone());
}
return (clause, binds);
}
let now = chrono::Utc::now();
let cutoff = match period {
"today" => {
// Start of today (UTC)
let today = now.format("%Y-%m-%dT00:00:00Z").to_string();
Some(today)
}
"24h" => Some((now - chrono::Duration::hours(24)).to_rfc3339()),
"7d" => Some((now - chrono::Duration::days(7)).to_rfc3339()),
"30d" => Some((now - chrono::Duration::days(30)).to_rfc3339()),
_ => None, // "all" or unrecognized
};
match cutoff {
Some(ts) => (" AND timestamp >= ?".to_string(), vec![ts]),
None => (String::new(), vec![]),
}
}
/// Determine the time-series granularity label for grouping.
fn granularity(&self) -> &'static str {
match self.period.as_deref().unwrap_or("all") {
"today" | "24h" => "hour",
_ => "day",
}
}
}
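The clause building in `to_sql()` can be sketched without chrono for the "custom" branch: each provided bound appends one `?` placeholder and one bind value, so the fragment and the bind list always stay in lockstep (simplified sketch, not the full impl):

```rust
// Chrono-free sketch of the "custom" branch of UsagePeriodFilter::to_sql:
// optional from/to bounds each contribute a placeholder plus a bind.
fn custom_clause(from: Option<&str>, to: Option<&str>) -> (String, Vec<String>) {
    let mut clause = String::new();
    let mut binds = Vec::new();
    if let Some(f) = from {
        clause.push_str(" AND timestamp >= ?");
        binds.push(f.to_string());
    }
    if let Some(t) = to {
        clause.push_str(" AND timestamp <= ?");
        binds.push(t.to_string());
    }
    (clause, binds)
}

fn main() {
    let (clause, binds) = custom_clause(Some("2025-06-01T00:00:00Z"), None);
    assert_eq!(clause, " AND timestamp >= ?");
    assert_eq!(binds, vec!["2025-06-01T00:00:00Z".to_string()]);
    // No bounds: empty fragment, no binds -- the query runs unfiltered.
    assert_eq!(custom_clause(None, None), (String::new(), vec![]));
}
```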
pub(super) async fn handle_usage_summary(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let (period_clause, period_binds) = filter.to_sql();
// Total stats (filtered by period)
let period_sql = format!(
r#"
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(total_tokens), 0) as total_tokens,
COALESCE(SUM(cost), 0.0) as total_cost,
COUNT(DISTINCT client_id) as active_clients,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read,
COALESCE(SUM(cache_write_tokens), 0) as total_cache_write
FROM llm_requests
WHERE 1=1 {}
"#,
period_clause
);
let mut q = sqlx::query(&period_sql);
for b in &period_binds {
q = q.bind(b);
}
let total_stats = q.fetch_one(pool);
// Today's stats
let today = chrono::Utc::now().format("%Y-%m-%d").to_string();
let today_stats = sqlx::query(
r#"
SELECT
COUNT(*) as today_requests,
COALESCE(SUM(total_tokens), 0) as today_tokens,
COALESCE(SUM(cost), 0.0) as today_cost
FROM llm_requests
WHERE strftime('%Y-%m-%d', timestamp) = ?
"#,
)
.bind(today)
.fetch_one(pool);
// Error stats
let error_stats = sqlx::query(
r#"
SELECT
COUNT(*) as total,
SUM(CASE WHEN status = 'error' THEN 1 ELSE 0 END) as errors
FROM llm_requests
"#,
)
.fetch_one(pool);
// Average response time
let avg_response = sqlx::query(
r#"
SELECT COALESCE(AVG(duration_ms), 0.0) as avg_duration
FROM llm_requests
WHERE status = 'success'
"#,
)
.fetch_one(pool);
let results = tokio::join!(total_stats, today_stats, error_stats, avg_response);
match results {
(Ok(t), Ok(d), Ok(e), Ok(a)) => {
let total_requests: i64 = t.get("total_requests");
let total_tokens: i64 = t.get("total_tokens");
let total_cost: f64 = t.get("total_cost");
let active_clients: i64 = t.get("active_clients");
let total_cache_read: i64 = t.get("total_cache_read");
let total_cache_write: i64 = t.get("total_cache_write");
let today_requests: i64 = d.get("today_requests");
let today_cost: f64 = d.get("today_cost");
let total_count: i64 = e.get("total");
let error_count: i64 = e.get("errors");
let error_rate = if total_count > 0 {
(error_count as f64 / total_count as f64) * 100.0
} else {
0.0
};
let avg_response_time: f64 = a.get("avg_duration");
Json(ApiResponse::success(serde_json::json!({
"total_requests": total_requests,
"total_tokens": total_tokens,
"total_cost": total_cost,
"active_clients": active_clients,
"today_requests": today_requests,
"today_cost": today_cost,
"error_rate": error_rate,
"avg_response_time": avg_response_time,
"total_cache_read_tokens": total_cache_read,
"total_cache_write_tokens": total_cache_write,
})))
}
(t_res, d_res, e_res, a_res) => {
if let Err(e) = t_res { warn!("Total stats query failed: {}", e); }
if let Err(e) = d_res { warn!("Today stats query failed: {}", e); }
if let Err(e) = e_res { warn!("Error stats query failed: {}", e); }
if let Err(e) = a_res { warn!("Avg response query failed: {}", e); }
Json(ApiResponse::error("Failed to fetch usage statistics".to_string()))
}
}
}
pub(super) async fn handle_time_series(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let (period_clause, period_binds) = filter.to_sql();
let granularity = filter.granularity();
// Determine the strftime format and default lookback
let (strftime_fmt, _label_key, default_lookback) = match granularity {
"hour" => ("%H:00", "hour", chrono::Duration::hours(24)),
_ => ("%Y-%m-%d", "day", chrono::Duration::days(30)),
};
// If no period filter, apply a sensible default lookback
let (clause, binds) = if period_clause.is_empty() {
let cutoff = (chrono::Utc::now() - default_lookback).to_rfc3339();
(" AND timestamp >= ?".to_string(), vec![cutoff])
} else {
(period_clause, period_binds)
};
let sql = format!(
r#"
SELECT
strftime('{strftime_fmt}', timestamp) as bucket,
COUNT(*) as requests,
COALESCE(SUM(total_tokens), 0) as tokens,
COALESCE(SUM(cost), 0.0) as cost
FROM llm_requests
WHERE 1=1 {clause}
GROUP BY bucket
ORDER BY bucket
"#,
);
let mut q = sqlx::query(&sql);
for b in &binds {
q = q.bind(b);
}
let result = q.fetch_all(pool).await;
match result {
Ok(rows) => {
let mut series = Vec::new();
for row in rows {
let bucket: String = row.get("bucket");
let requests: i64 = row.get("requests");
let tokens: i64 = row.get("tokens");
let cost: f64 = row.get("cost");
series.push(serde_json::json!({
"time": bucket,
"requests": requests,
"tokens": tokens,
"cost": cost,
}));
}
Json(ApiResponse::success(serde_json::json!({
"series": series,
"period": filter.period.as_deref().unwrap_or("all"),
"granularity": granularity,
})))
}
Err(e) => {
warn!("Failed to fetch time series data: {}", e);
Json(ApiResponse::error("Failed to fetch time series data".to_string()))
}
}
}
pub(super) async fn handle_clients_usage(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let (period_clause, period_binds) = filter.to_sql();
let sql = format!(
r#"
SELECT
client_id,
COUNT(*) as requests,
COALESCE(SUM(total_tokens), 0) as tokens,
COALESCE(SUM(cost), 0.0) as cost,
MAX(timestamp) as last_request
FROM llm_requests
WHERE 1=1 {}
GROUP BY client_id
ORDER BY requests DESC
"#,
period_clause
);
let mut q = sqlx::query(&sql);
for b in &period_binds {
q = q.bind(b);
}
let result = q.fetch_all(pool).await;
match result {
Ok(rows) => {
let mut client_usage = Vec::new();
for row in rows {
let client_id: String = row.get("client_id");
let requests: i64 = row.get("requests");
let tokens: i64 = row.get("tokens");
let cost: f64 = row.get("cost");
let last_request: Option<chrono::DateTime<chrono::Utc>> = row.get("last_request");
client_usage.push(serde_json::json!({
"client_id": client_id,
"client_name": client_id,
"requests": requests,
"tokens": tokens,
"cost": cost,
"last_request": last_request,
}));
}
Json(ApiResponse::success(serde_json::json!(client_usage)))
}
Err(e) => {
warn!("Failed to fetch client usage data: {}", e);
Json(ApiResponse::error("Failed to fetch client usage data".to_string()))
}
}
}
pub(super) async fn handle_providers_usage(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let (period_clause, period_binds) = filter.to_sql();
let sql = format!(
r#"
SELECT
provider,
COUNT(*) as requests,
COALESCE(SUM(total_tokens), 0) as tokens,
COALESCE(SUM(cost), 0.0) as cost,
COALESCE(SUM(cache_read_tokens), 0) as cache_read,
COALESCE(SUM(cache_write_tokens), 0) as cache_write
FROM llm_requests
WHERE 1=1 {}
GROUP BY provider
ORDER BY requests DESC
"#,
period_clause
);
let mut q = sqlx::query(&sql);
for b in &period_binds {
q = q.bind(b);
}
let result = q.fetch_all(pool).await;
match result {
Ok(rows) => {
let mut provider_usage = Vec::new();
for row in rows {
let provider: String = row.get("provider");
let requests: i64 = row.get("requests");
let tokens: i64 = row.get("tokens");
let cost: f64 = row.get("cost");
let cache_read: i64 = row.get("cache_read");
let cache_write: i64 = row.get("cache_write");
provider_usage.push(serde_json::json!({
"provider": provider,
"requests": requests,
"tokens": tokens,
"cost": cost,
"cache_read_tokens": cache_read,
"cache_write_tokens": cache_write,
}));
}
Json(ApiResponse::success(serde_json::json!(provider_usage)))
}
Err(e) => {
warn!("Failed to fetch provider usage data: {}", e);
Json(ApiResponse::error("Failed to fetch provider usage data".to_string()))
}
}
}
pub(super) async fn handle_detailed_usage(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let (period_clause, period_binds) = filter.to_sql();
let sql = format!(
r#"
SELECT
strftime('%Y-%m-%d', timestamp) as date,
client_id,
provider,
model,
COUNT(*) as requests,
COALESCE(SUM(total_tokens), 0) as tokens,
COALESCE(SUM(cost), 0.0) as cost,
COALESCE(SUM(cache_read_tokens), 0) as cache_read,
COALESCE(SUM(cache_write_tokens), 0) as cache_write
FROM llm_requests
WHERE 1=1 {}
GROUP BY date, client_id, provider, model
ORDER BY date DESC
LIMIT 200
"#,
period_clause
);
let mut q = sqlx::query(&sql);
for b in &period_binds {
q = q.bind(b);
}
let result = q.fetch_all(pool).await;
match result {
Ok(rows) => {
let usage: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
serde_json::json!({
"date": row.get::<String, _>("date"),
"client": row.get::<String, _>("client_id"),
"provider": row.get::<String, _>("provider"),
"model": row.get::<String, _>("model"),
"requests": row.get::<i64, _>("requests"),
"tokens": row.get::<i64, _>("tokens"),
"cost": row.get::<f64, _>("cost"),
"cache_read_tokens": row.get::<i64, _>("cache_read"),
"cache_write_tokens": row.get::<i64, _>("cache_write"),
})
})
.collect();
Json(ApiResponse::success(serde_json::json!(usage)))
}
Err(e) => {
warn!("Failed to fetch detailed usage: {}", e);
Json(ApiResponse::error("Failed to fetch detailed usage".to_string()))
}
}
}
pub(super) async fn handle_analytics_breakdown(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Query(filter): Query<UsagePeriodFilter>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match super::auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let (period_clause, period_binds) = filter.to_sql();
// Model breakdown
let model_sql = format!(
"SELECT model as label, COUNT(*) as value FROM llm_requests WHERE 1=1 {} GROUP BY model ORDER BY value DESC",
period_clause
);
let mut mq = sqlx::query(&model_sql);
for b in &period_binds {
mq = mq.bind(b);
}
let models = mq.fetch_all(pool);
// Client breakdown
let client_sql = format!(
"SELECT client_id as label, COUNT(*) as value FROM llm_requests WHERE 1=1 {} GROUP BY client_id ORDER BY value DESC",
period_clause
);
let mut cq = sqlx::query(&client_sql);
for b in &period_binds {
cq = cq.bind(b);
}
let clients = cq.fetch_all(pool);
match tokio::join!(models, clients) {
(Ok(m_rows), Ok(c_rows)) => {
let model_breakdown: Vec<serde_json::Value> = m_rows
.into_iter()
.map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
.collect();
let client_breakdown: Vec<serde_json::Value> = c_rows
.into_iter()
.map(|r| serde_json::json!({ "label": r.get::<String, _>("label"), "value": r.get::<i64, _>("value") }))
.collect();
Json(ApiResponse::success(serde_json::json!({
"models": model_breakdown,
"clients": client_breakdown
})))
}
_ => Json(ApiResponse::error("Failed to fetch analytics breakdown".to_string())),
}
}


@@ -1,290 +0,0 @@
use axum::{
extract::{Path, State},
response::Json,
};
use serde::Deserialize;
use sqlx::Row;
use tracing::warn;
use super::{ApiResponse, DashboardState, auth};
// ── User management endpoints (admin-only) ──────────────────────────
pub(super) async fn handle_get_users(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
let result = sqlx::query(
"SELECT id, username, display_name, role, must_change_password, created_at FROM users ORDER BY created_at ASC",
)
.fetch_all(pool)
.await;
match result {
Ok(rows) => {
let users: Vec<serde_json::Value> = rows
.into_iter()
.map(|row| {
let username: String = row.get("username");
let display_name: Option<String> = row.get("display_name");
serde_json::json!({
"id": row.get::<i64, _>("id"),
"username": &username,
"display_name": display_name.as_deref().unwrap_or(&username),
"role": row.get::<String, _>("role"),
"must_change_password": row.get::<bool, _>("must_change_password"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
})
})
.collect();
Json(ApiResponse::success(serde_json::json!(users)))
}
Err(e) => {
warn!("Failed to fetch users: {}", e);
Json(ApiResponse::error("Failed to fetch users".to_string()))
}
}
}
#[derive(Deserialize)]
pub(super) struct CreateUserRequest {
pub(super) username: String,
pub(super) password: String,
pub(super) display_name: Option<String>,
pub(super) role: Option<String>,
}
pub(super) async fn handle_create_user(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Json(payload): Json<CreateUserRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Validate role
let role = payload.role.as_deref().unwrap_or("viewer");
if role != "admin" && role != "viewer" {
return Json(ApiResponse::error("Role must be 'admin' or 'viewer'".to_string()));
}
// Validate username
let username = payload.username.trim();
if username.is_empty() || username.len() > 64 {
return Json(ApiResponse::error("Username must be 1-64 characters".to_string()));
}
// Validate password
if payload.password.len() < 4 {
return Json(ApiResponse::error("Password must be at least 4 characters".to_string()));
}
let password_hash = match bcrypt::hash(&payload.password, 12) {
Ok(h) => h,
Err(_) => return Json(ApiResponse::error("Failed to hash password".to_string())),
};
let result = sqlx::query(
r#"
INSERT INTO users (username, password_hash, display_name, role, must_change_password)
VALUES (?, ?, ?, ?, TRUE)
RETURNING id, username, display_name, role, must_change_password, created_at
"#,
)
.bind(username)
.bind(&password_hash)
.bind(&payload.display_name)
.bind(role)
.fetch_one(pool)
.await;
match result {
Ok(row) => {
let uname: String = row.get("username");
let display_name: Option<String> = row.get("display_name");
Json(ApiResponse::success(serde_json::json!({
"id": row.get::<i64, _>("id"),
"username": &uname,
"display_name": display_name.as_deref().unwrap_or(&uname),
"role": row.get::<String, _>("role"),
"must_change_password": row.get::<bool, _>("must_change_password"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
})))
}
Err(e) => {
let msg = if e.to_string().contains("UNIQUE") {
format!("Username '{}' already exists", username)
} else {
format!("Failed to create user: {}", e)
};
warn!("Failed to create user: {}", e);
Json(ApiResponse::error(msg))
}
}
}
#[derive(Deserialize)]
pub(super) struct UpdateUserRequest {
pub(super) display_name: Option<String>,
pub(super) role: Option<String>,
pub(super) password: Option<String>,
pub(super) must_change_password: Option<bool>,
}
pub(super) async fn handle_update_user(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(id): Path<i64>,
Json(payload): Json<UpdateUserRequest>,
) -> Json<ApiResponse<serde_json::Value>> {
let (_session, _) = match auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Validate role if provided
if let Some(ref role) = payload.role {
if role != "admin" && role != "viewer" {
return Json(ApiResponse::error("Role must be 'admin' or 'viewer'".to_string()));
}
}
// Build dynamic update
let mut sets = Vec::new();
let mut string_binds: Vec<String> = Vec::new();
if let Some(ref display_name) = payload.display_name {
sets.push("display_name = ?");
string_binds.push(display_name.clone());
}
if let Some(ref role) = payload.role {
sets.push("role = ?");
string_binds.push(role.clone());
}
if let Some(ref password) = payload.password {
if password.len() < 4 {
return Json(ApiResponse::error("Password must be at least 4 characters".to_string()));
}
let hash = match bcrypt::hash(password, 12) {
Ok(h) => h,
Err(_) => return Json(ApiResponse::error("Failed to hash password".to_string())),
};
sets.push("password_hash = ?");
string_binds.push(hash);
}
if payload.must_change_password.is_some() {
sets.push("must_change_password = ?");
}
if sets.is_empty() {
return Json(ApiResponse::error("No fields to update".to_string()));
}
let sql = format!("UPDATE users SET {} WHERE id = ?", sets.join(", "));
let mut query = sqlx::query(&sql);
// String binds first, in the same order their placeholders were pushed
for b in &string_binds {
query = query.bind(b);
}
// The boolean bind comes after all string binds, matching its placeholder
if let Some(mcp) = payload.must_change_password {
query = query.bind(mcp);
}
query = query.bind(id);
match query.execute(pool).await {
Ok(result) => {
if result.rows_affected() == 0 {
return Json(ApiResponse::error("User not found".to_string()));
}
// Fetch updated user
let row = sqlx::query(
"SELECT id, username, display_name, role, must_change_password, created_at FROM users WHERE id = ?",
)
.bind(id)
.fetch_one(pool)
.await;
match row {
Ok(row) => {
let uname: String = row.get("username");
let display_name: Option<String> = row.get("display_name");
Json(ApiResponse::success(serde_json::json!({
"id": row.get::<i64, _>("id"),
"username": &uname,
"display_name": display_name.as_deref().unwrap_or(&uname),
"role": row.get::<String, _>("role"),
"must_change_password": row.get::<bool, _>("must_change_password"),
"created_at": row.get::<chrono::DateTime<chrono::Utc>, _>("created_at"),
})))
}
Err(_) => Json(ApiResponse::success(serde_json::json!({ "message": "User updated" }))),
}
}
Err(e) => {
warn!("Failed to update user: {}", e);
Json(ApiResponse::error(format!("Failed to update user: {}", e)))
}
}
}
pub(super) async fn handle_delete_user(
State(state): State<DashboardState>,
headers: axum::http::HeaderMap,
Path(id): Path<i64>,
) -> Json<ApiResponse<serde_json::Value>> {
let (session, _) = match auth::require_admin(&state, &headers).await {
Ok((session, new_token)) => (session, new_token),
Err(e) => return e,
};
let pool = &state.app_state.db_pool;
// Don't allow deleting yourself
let target_username: Option<String> =
sqlx::query_scalar::<_, String>("SELECT username FROM users WHERE id = ?")
.bind(id)
.fetch_optional(pool)
.await
.unwrap_or(None);
match target_username {
None => return Json(ApiResponse::error("User not found".to_string())),
Some(ref uname) if uname == &session.username => {
return Json(ApiResponse::error("Cannot delete your own account".to_string()));
}
_ => {}
}
let result = sqlx::query("DELETE FROM users WHERE id = ?")
.bind(id)
.execute(pool)
.await;
match result {
Ok(_) => Json(ApiResponse::success(serde_json::json!({ "message": "User deleted" }))),
Err(e) => {
warn!("Failed to delete user: {}", e);
Json(ApiResponse::error(format!("Failed to delete user: {}", e)))
}
}
}


@@ -1,75 +0,0 @@
use axum::{
extract::{
State,
ws::{Message, WebSocket, WebSocketUpgrade},
},
response::IntoResponse,
};
use serde_json;
use tracing::info;
use super::DashboardState;
// WebSocket handler
pub(super) async fn handle_websocket(ws: WebSocketUpgrade, State(state): State<DashboardState>) -> impl IntoResponse {
ws.on_upgrade(|socket| handle_websocket_connection(socket, state))
}
pub(super) async fn handle_websocket_connection(mut socket: WebSocket, state: DashboardState) {
info!("WebSocket connection established");
// Subscribe to events from the global bus
let mut rx = state.app_state.dashboard_tx.subscribe();
// Send initial connection message
let _ = socket
.send(Message::Text(
serde_json::json!({
"type": "connected",
"message": "Connected to LLM Proxy Dashboard"
})
.to_string()
.into(),
))
.await;
// Handle incoming messages and broadcast events
loop {
tokio::select! {
// Receive broadcast events
Ok(event) = rx.recv() => {
let Ok(json_str) = serde_json::to_string(&event) else {
continue;
};
let message = Message::Text(json_str.into());
if socket.send(message).await.is_err() {
break;
}
}
// Receive WebSocket messages
result = socket.recv() => {
match result {
Some(Ok(Message::Text(text))) => {
handle_websocket_message(&text, &state).await;
}
_ => break,
}
}
}
}
info!("WebSocket connection closed");
}
pub(super) async fn handle_websocket_message(text: &str, state: &DashboardState) {
// Parse and handle WebSocket messages
if let Ok(data) = serde_json::from_str::<serde_json::Value>(text)
&& data.get("type").and_then(|v| v.as_str()) == Some("ping")
{
let _ = state.app_state.dashboard_tx.send(serde_json::json!({
"type": "pong",
"payload": {}
}));
}
}


@@ -1,275 +0,0 @@
use anyhow::Result;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePool};
use std::str::FromStr;
use tracing::info;
use crate::config::DatabaseConfig;
pub type DbPool = SqlitePool;
pub async fn init(config: &DatabaseConfig) -> Result<DbPool> {
// Ensure the database directory exists
if let Some(parent) = config.path.parent()
&& !parent.as_os_str().is_empty()
{
tokio::fs::create_dir_all(parent).await?;
}
let database_path = config.path.to_string_lossy().to_string();
info!("Connecting to database at {}", database_path);
let options = SqliteConnectOptions::from_str(&format!("sqlite:{}", database_path))?
.create_if_missing(true)
.pragma("foreign_keys", "ON");
let pool = SqlitePool::connect_with(options).await?;
// Run migrations
run_migrations(&pool).await?;
info!("Database migrations completed");
Ok(pool)
}
pub async fn run_migrations(pool: &DbPool) -> Result<()> {
// Create clients table if it doesn't exist
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS clients (
id INTEGER PRIMARY KEY AUTOINCREMENT,
client_id TEXT UNIQUE NOT NULL,
name TEXT,
description TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
is_active BOOLEAN DEFAULT TRUE,
rate_limit_per_minute INTEGER DEFAULT 60,
total_requests INTEGER DEFAULT 0,
total_tokens INTEGER DEFAULT 0,
total_cost REAL DEFAULT 0.0
)
"#,
)
.execute(pool)
.await?;
// Create llm_requests table if it doesn't exist
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS llm_requests (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
client_id TEXT,
provider TEXT,
model TEXT,
prompt_tokens INTEGER,
completion_tokens INTEGER,
reasoning_tokens INTEGER DEFAULT 0,
total_tokens INTEGER,
cost REAL,
has_images BOOLEAN DEFAULT FALSE,
status TEXT DEFAULT 'success',
error_message TEXT,
duration_ms INTEGER,
request_body TEXT,
response_body TEXT,
FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE SET NULL
)
"#,
)
.execute(pool)
.await?;
// Create provider_configs table
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS provider_configs (
id TEXT PRIMARY KEY,
display_name TEXT NOT NULL,
enabled BOOLEAN DEFAULT TRUE,
base_url TEXT,
api_key TEXT,
credit_balance REAL DEFAULT 0.0,
low_credit_threshold REAL DEFAULT 5.0,
billing_mode TEXT,
api_key_encrypted BOOLEAN DEFAULT FALSE,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
"#,
)
.execute(pool)
.await?;
// Create model_configs table
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS model_configs (
id TEXT PRIMARY KEY,
provider_id TEXT NOT NULL,
display_name TEXT,
enabled BOOLEAN DEFAULT TRUE,
prompt_cost_per_m REAL,
completion_cost_per_m REAL,
cache_read_cost_per_m REAL,
cache_write_cost_per_m REAL,
mapping TEXT,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (provider_id) REFERENCES provider_configs(id) ON DELETE CASCADE
)
"#,
)
.execute(pool)
.await?;
// Create users table for dashboard access
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL,
display_name TEXT,
role TEXT DEFAULT 'admin',
must_change_password BOOLEAN DEFAULT FALSE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
"#,
)
.execute(pool)
.await?;
// Create client_tokens table for DB-based token auth
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS client_tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
client_id TEXT NOT NULL,
token TEXT NOT NULL UNIQUE,
name TEXT DEFAULT 'default',
is_active BOOLEAN DEFAULT TRUE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
last_used_at DATETIME,
FOREIGN KEY (client_id) REFERENCES clients(client_id) ON DELETE CASCADE
)
"#,
)
.execute(pool)
.await?;
// Add must_change_password column if it doesn't exist (migration for existing DBs)
let _ = sqlx::query("ALTER TABLE users ADD COLUMN must_change_password BOOLEAN DEFAULT FALSE")
.execute(pool)
.await;
// Add display_name column if it doesn't exist (migration for existing DBs)
let _ = sqlx::query("ALTER TABLE users ADD COLUMN display_name TEXT")
.execute(pool)
.await;
// Add cache token columns if they don't exist (migration for existing DBs)
let _ = sqlx::query("ALTER TABLE llm_requests ADD COLUMN cache_read_tokens INTEGER DEFAULT 0")
.execute(pool)
.await;
let _ = sqlx::query("ALTER TABLE llm_requests ADD COLUMN cache_write_tokens INTEGER DEFAULT 0")
.execute(pool)
.await;
let _ = sqlx::query("ALTER TABLE llm_requests ADD COLUMN reasoning_tokens INTEGER DEFAULT 0")
.execute(pool)
.await;
// Add billing_mode column if it doesn't exist (migration for existing DBs)
let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN billing_mode TEXT")
.execute(pool)
.await;
// Add api_key_encrypted column if it doesn't exist (migration for existing DBs)
let _ = sqlx::query("ALTER TABLE provider_configs ADD COLUMN api_key_encrypted BOOLEAN DEFAULT FALSE")
.execute(pool)
.await;
// Add manual cache cost columns to model_configs if they don't exist
let _ = sqlx::query("ALTER TABLE model_configs ADD COLUMN cache_read_cost_per_m REAL")
.execute(pool)
.await;
let _ = sqlx::query("ALTER TABLE model_configs ADD COLUMN cache_write_cost_per_m REAL")
.execute(pool)
.await;
// Insert default admin user if none exists (default password: admin)
let user_count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users").fetch_one(pool).await?;
if user_count.0 == 0 {
// 'admin' hashed with default cost (12)
let default_admin_hash =
bcrypt::hash("admin", 12).map_err(|e| anyhow::anyhow!("Failed to hash default password: {}", e))?;
sqlx::query(
"INSERT INTO users (username, password_hash, role, must_change_password) VALUES ('admin', ?, 'admin', TRUE)"
)
.bind(default_admin_hash)
.execute(pool)
.await?;
info!("Created default admin user with password 'admin' (must change on first login)");
}
// Create indices
sqlx::query("CREATE INDEX IF NOT EXISTS idx_clients_client_id ON clients(client_id)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_clients_created_at ON clients(created_at)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_timestamp ON llm_requests(timestamp)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_client_id ON llm_requests(client_id)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_provider ON llm_requests(provider)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_status ON llm_requests(status)")
.execute(pool)
.await?;
sqlx::query("CREATE UNIQUE INDEX IF NOT EXISTS idx_client_tokens_token ON client_tokens(token)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_client_tokens_client_id ON client_tokens(client_id)")
.execute(pool)
.await?;
// Composite indexes for performance
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_client_timestamp ON llm_requests(client_id, timestamp)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_llm_requests_provider_timestamp ON llm_requests(provider, timestamp)")
.execute(pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_model_configs_provider_id ON model_configs(provider_id)")
.execute(pool)
.await?;
// Insert default client if none exists
sqlx::query(
r#"
INSERT OR IGNORE INTO clients (client_id, name, description)
VALUES ('default', 'Default Client', 'Default client for anonymous requests')
"#,
)
.execute(pool)
.await?;
Ok(())
}
pub async fn test_connection(pool: &DbPool) -> Result<()> {
sqlx::query("SELECT 1").execute(pool).await?;
Ok(())
}


@@ -1,58 +0,0 @@
use thiserror::Error;
#[derive(Error, Debug, Clone)]
pub enum AppError {
#[error("Authentication failed: {0}")]
AuthError(String),
#[error("Configuration error: {0}")]
ConfigError(String),
#[error("Database error: {0}")]
DatabaseError(String),
#[error("Provider error: {0}")]
ProviderError(String),
#[error("Validation error: {0}")]
ValidationError(String),
#[error("Multimodal processing error: {0}")]
MultimodalError(String),
#[error("Rate limit exceeded: {0}")]
RateLimitError(String),
#[error("Internal server error: {0}")]
InternalError(String),
}
impl From<sqlx::Error> for AppError {
fn from(err: sqlx::Error) -> Self {
AppError::DatabaseError(err.to_string())
}
}
impl From<anyhow::Error> for AppError {
fn from(err: anyhow::Error) -> Self {
AppError::InternalError(err.to_string())
}
}
impl axum::response::IntoResponse for AppError {
fn into_response(self) -> axum::response::Response {
let status = match self {
AppError::AuthError(_) => axum::http::StatusCode::UNAUTHORIZED,
AppError::RateLimitError(_) => axum::http::StatusCode::TOO_MANY_REQUESTS,
AppError::ValidationError(_) => axum::http::StatusCode::BAD_REQUEST,
_ => axum::http::StatusCode::INTERNAL_SERVER_ERROR,
};
let body = axum::Json(serde_json::json!({
"error": self.to_string(),
"type": format!("{:?}", self)
}));
(status, body).into_response()
}
}
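The status-code mapping in `into_response` can be exercised in isolation. A minimal std-only sketch, re-declaring a stripped-down enum (the real `AppError` depends on `thiserror` and `axum`, so this mirror uses plain `u16` codes):

```rust
// Stripped-down mirror of AppError's status mapping; the catch-all arm
// corresponds to the `_ => INTERNAL_SERVER_ERROR` branch above.
#[derive(Debug)]
enum AppErrorKind {
    Auth(String),
    RateLimit(String),
    Validation(String),
    Internal(String),
}

fn status_code(err: &AppErrorKind) -> u16 {
    match err {
        AppErrorKind::Auth(_) => 401,       // UNAUTHORIZED
        AppErrorKind::RateLimit(_) => 429,  // TOO_MANY_REQUESTS
        AppErrorKind::Validation(_) => 400, // BAD_REQUEST
        _ => 500,                           // everything else
    }
}

fn main() {
    assert_eq!(status_code(&AppErrorKind::Auth("bad token".into())), 401);
    assert_eq!(status_code(&AppErrorKind::RateLimit("slow down".into())), 429);
    assert_eq!(status_code(&AppErrorKind::Validation("missing model".into())), 400);
    assert_eq!(status_code(&AppErrorKind::Internal("oops".into())), 500);
    println!("all mappings ok");
}
```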


@@ -1,328 +0,0 @@
//! LLM Proxy Library
//!
//! This library provides the core functionality for the LLM proxy gateway,
//! including provider integration, token tracking, and API endpoints.
pub mod auth;
pub mod client;
pub mod config;
pub mod dashboard;
pub mod database;
pub mod errors;
pub mod logging;
pub mod models;
pub mod multimodal;
pub mod providers;
pub mod rate_limiting;
pub mod server;
pub mod state;
pub mod utils;
// Re-exports for convenience
pub use auth::{AuthenticatedClient, validate_token};
pub use config::{
AppConfig, DatabaseConfig, DeepSeekConfig, GeminiConfig, GrokConfig, ModelMappingConfig, ModelPricing,
OllamaConfig, OpenAIConfig, PricingConfig, ProviderConfig, ServerConfig,
};
pub use database::{DbPool, init as init_db, test_connection};
pub use errors::AppError;
pub use logging::{LoggingContext, RequestLog, RequestLogger};
pub use models::{
ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage,
ChatStreamChoice, ChatStreamDelta, ContentPart, ContentPartValue, FromOpenAI, ImageUrl, MessageContent,
OpenAIContentPart, OpenAIMessage, OpenAIRequest, ToOpenAI, UnifiedMessage, UnifiedRequest, Usage,
};
pub use providers::{Provider, ProviderManager, ProviderResponse, ProviderStreamChunk};
pub use server::router;
pub use state::AppState;
/// Test utilities for integration testing
#[cfg(test)]
pub mod test_utils {
use std::sync::Arc;
use crate::{client::ClientManager, providers::ProviderManager, rate_limiting::RateLimitManager, state::AppState, utils::crypto, database::run_migrations};
use sqlx::sqlite::SqlitePool;
/// Create a test application state
pub async fn create_test_state() -> AppState {
// Create in-memory database
let pool = SqlitePool::connect("sqlite::memory:")
.await
.expect("Failed to create test database");
// Run migrations on the pool
run_migrations(&pool).await.expect("Failed to run migrations");
let rate_limit_manager = RateLimitManager::new(
crate::rate_limiting::RateLimiterConfig::default(),
crate::rate_limiting::CircuitBreakerConfig::default(),
);
let client_manager = Arc::new(ClientManager::new(pool.clone()));
// Create provider manager
let provider_manager = ProviderManager::new();
let model_registry = crate::models::registry::ModelRegistry {
providers: std::collections::HashMap::new(),
};
let (dashboard_tx, _) = tokio::sync::broadcast::channel::<serde_json::Value>(100);
let config = Arc::new(crate::config::AppConfig {
server: crate::config::ServerConfig {
port: 8080,
host: "127.0.0.1".to_string(),
auth_tokens: vec![],
},
database: crate::config::DatabaseConfig {
path: std::path::PathBuf::from(":memory:"),
max_connections: 5,
},
providers: crate::config::ProviderConfig {
openai: crate::config::OpenAIConfig {
api_key_env: "OPENAI_API_KEY".to_string(),
base_url: "".to_string(),
default_model: "".to_string(),
enabled: true,
},
gemini: crate::config::GeminiConfig {
api_key_env: "GEMINI_API_KEY".to_string(),
base_url: "".to_string(),
default_model: "".to_string(),
enabled: true,
},
deepseek: crate::config::DeepSeekConfig {
api_key_env: "DEEPSEEK_API_KEY".to_string(),
base_url: "".to_string(),
default_model: "".to_string(),
enabled: true,
},
grok: crate::config::GrokConfig {
api_key_env: "GROK_API_KEY".to_string(),
base_url: "".to_string(),
default_model: "".to_string(),
enabled: true,
},
ollama: crate::config::OllamaConfig {
base_url: "".to_string(),
enabled: true,
models: vec![],
},
},
model_mapping: crate::config::ModelMappingConfig { patterns: vec![] },
pricing: crate::config::PricingConfig {
openai: vec![],
gemini: vec![],
deepseek: vec![],
grok: vec![],
ollama: vec![],
},
config_path: None,
encryption_key: "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f".to_string(),
});
// Initialize encryption with the test key
crypto::init_with_key(&config.encryption_key).expect("failed to initialize crypto");
AppState::new(
config,
provider_manager,
pool,
rate_limit_manager,
model_registry,
vec![], // auth_tokens
)
}
/// Create a test HTTP client
pub fn create_test_client() -> reqwest::Client {
reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.expect("Failed to create test HTTP client")
}
}
#[cfg(test)]
mod integration_tests {
use super::test_utils::*;
use crate::{
models::{ChatCompletionRequest, ChatMessage},
server::router,
utils::crypto,
};
use axum::{
body::Body,
http::{Request, StatusCode},
};
use mockito::Server;
use serde_json::json;
use sqlx::Row;
use tower::util::ServiceExt;
#[tokio::test]
async fn test_encrypted_provider_key_integration() {
// Step 1: Setup test database and state
let state = create_test_state().await;
let pool = state.db_pool.clone();
// Step 2: Insert provider with encrypted API key
let test_api_key = "test-openai-key-12345";
let encrypted_key = crypto::encrypt(test_api_key).expect("Failed to encrypt test key");
sqlx::query(
r#"
INSERT INTO provider_configs (id, display_name, enabled, base_url, api_key, api_key_encrypted, credit_balance, low_credit_threshold)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind("openai")
.bind("OpenAI")
.bind(true)
.bind("http://localhost:1234") // Mock server URL
.bind(&encrypted_key)
.bind(true) // api_key_encrypted flag
.bind(100.0)
.bind(5.0)
.execute(&pool)
.await
.expect("Failed to insert provider config");
// Step 3: Initialize the provider so it loads and decrypts the API key from the database
state
.provider_manager
.initialize_provider("openai", &state.config, &pool)
.await
.expect("Failed to re-initialize provider");
// Step 4: Mock OpenAI API server
let mut server = Server::new_async().await;
let mock = server
.mock("POST", "/chat/completions")
.match_header("authorization", format!("Bearer {}", test_api_key).as_str())
.with_status(200)
.with_header("content-type", "application/json")
.with_body(
json!({
"id": "chatcmpl-test",
"object": "chat.completion",
"created": 1234567890,
"model": "gpt-3.5-turbo",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello, world!"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15
}
})
.to_string(),
)
.create_async()
.await;
// Update provider base URL to use mock server
sqlx::query("UPDATE provider_configs SET base_url = ? WHERE id = 'openai'")
.bind(&server.url())
.execute(&pool)
.await
.expect("Failed to update provider URL");
// Re-initialize provider with new URL
state
.provider_manager
.initialize_provider("openai", &state.config, &pool)
.await
.expect("Failed to re-initialize provider");
// Step 5: Create test router and make request
let app = router(state);
let request_body = ChatCompletionRequest {
model: "gpt-3.5-turbo".to_string(),
messages: vec![ChatMessage {
role: "user".to_string(),
content: crate::models::MessageContent::Text {
content: "Hello".to_string(),
},
reasoning_content: None,
tool_calls: None,
name: None,
tool_call_id: None,
}],
temperature: None,
top_p: None,
top_k: None,
n: None,
stop: None,
max_tokens: Some(100),
presence_penalty: None,
frequency_penalty: None,
stream: Some(false),
tools: None,
tool_choice: None,
};
let request = Request::builder()
.method("POST")
.uri("/v1/chat/completions")
.header("content-type", "application/json")
.header("authorization", "Bearer test-token")
.body(Body::from(serde_json::to_string(&request_body).unwrap()))
.unwrap();
// Step 6: Execute request through proxy
let response = app
.oneshot(request)
.await
.expect("Failed to execute request");
let status = response.status();
println!("Response status: {}", status);
if status != StatusCode::OK {
let body_bytes = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
println!("Response body: {}", body_str);
panic!("Response status is not OK: {}", status);
}
assert_eq!(status, StatusCode::OK);
// Verify the mock was called
mock.assert_async().await;
// Give the async logging task time to complete
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Step 7: Verify usage was logged in database
let log_row = sqlx::query("SELECT * FROM llm_requests WHERE client_id = 'client_test-tok' ORDER BY id DESC LIMIT 1")
.fetch_one(&pool)
.await
.expect("Request log not found");
assert_eq!(log_row.get::<String, _>("provider"), "openai");
assert_eq!(log_row.get::<String, _>("model"), "gpt-3.5-turbo");
assert_eq!(log_row.get::<i64, _>("prompt_tokens"), 10);
assert_eq!(log_row.get::<i64, _>("completion_tokens"), 5);
assert_eq!(log_row.get::<i64, _>("total_tokens"), 15);
assert_eq!(log_row.get::<String, _>("status"), "success");
// Verify client usage was updated
let client_row = sqlx::query("SELECT total_requests, total_tokens, total_cost FROM clients WHERE client_id = 'client_test-tok'")
.fetch_one(&pool)
.await
.expect("Client not found");
assert_eq!(client_row.get::<i64, _>("total_requests"), 1);
assert_eq!(client_row.get::<i64, _>("total_tokens"), 15);
}
}
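The test asserts that bearer token `test-token` maps to client `client_test-tok`, which suggests the client id is derived by prefixing `client_` to the first eight characters of the token. A hypothetical reconstruction of that derivation (the crate's actual helper lives elsewhere and may differ):

```rust
// Hypothetical sketch of the client-id derivation implied by the test above:
// "test-token" -> "client_test-tok". Not the crate's actual code.
fn client_id_from_token(token: &str) -> String {
    let prefix: String = token.chars().take(8).collect();
    format!("client_{}", prefix)
}

fn main() {
    assert_eq!(client_id_from_token("test-token"), "client_test-tok");
    // Tokens shorter than eight characters pass through whole.
    assert_eq!(client_id_from_token("abc"), "client_abc");
}
```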


@@ -1,242 +0,0 @@
use chrono::{DateTime, Utc};
use serde::Serialize;
use sqlx::SqlitePool;
use tokio::sync::broadcast;
use tracing::{info, warn};
use crate::errors::AppError;
/// Request log entry for database storage
#[derive(Debug, Clone, Serialize)]
pub struct RequestLog {
pub timestamp: DateTime<Utc>,
pub client_id: String,
pub provider: String,
pub model: String,
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub reasoning_tokens: u32,
pub total_tokens: u32,
pub cache_read_tokens: u32,
pub cache_write_tokens: u32,
pub cost: f64,
pub has_images: bool,
pub status: String, // "success", "error"
pub error_message: Option<String>,
pub duration_ms: u64,
}
/// Database operations for request logging
pub struct RequestLogger {
db_pool: SqlitePool,
dashboard_tx: broadcast::Sender<serde_json::Value>,
}
impl RequestLogger {
pub fn new(db_pool: SqlitePool, dashboard_tx: broadcast::Sender<serde_json::Value>) -> Self {
Self { db_pool, dashboard_tx }
}
/// Log a request to the database (async, spawns a task)
pub fn log_request(&self, log: RequestLog) {
let pool = self.db_pool.clone();
let tx = self.dashboard_tx.clone();
// Spawn async task to log without blocking response
tokio::spawn(async move {
// Broadcast to dashboard
let broadcast_result = tx.send(serde_json::json!({
"type": "request",
"channel": "requests",
"payload": log
}));
match broadcast_result {
Ok(receivers) => info!("Broadcast request log to {} dashboard listeners", receivers),
Err(_) => {} // No active WebSocket clients — expected when dashboard isn't open
}
match Self::insert_log(&pool, log).await {
Ok(()) => info!("Request logged to database successfully"),
Err(e) => warn!("Failed to log request to database: {}", e),
}
});
}
/// Insert a log entry into the database
async fn insert_log(pool: &SqlitePool, log: RequestLog) -> Result<(), sqlx::Error> {
let mut tx = pool.begin().await?;
// Ensure the client row exists (FK constraint requires it)
sqlx::query(
"INSERT OR IGNORE INTO clients (client_id, name, description) VALUES (?, ?, 'Auto-created from request')",
)
.bind(&log.client_id)
.bind(&log.client_id)
.execute(&mut *tx)
.await?;
sqlx::query(
r#"
INSERT INTO llm_requests
(timestamp, client_id, provider, model, prompt_tokens, completion_tokens, reasoning_tokens, total_tokens, cache_read_tokens, cache_write_tokens, cost, has_images, status, error_message, duration_ms, request_body, response_body)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
)
.bind(log.timestamp)
.bind(&log.client_id)
.bind(&log.provider)
.bind(&log.model)
.bind(log.prompt_tokens as i64)
.bind(log.completion_tokens as i64)
.bind(log.reasoning_tokens as i64)
.bind(log.total_tokens as i64)
.bind(log.cache_read_tokens as i64)
.bind(log.cache_write_tokens as i64)
.bind(log.cost)
.bind(log.has_images)
.bind(&log.status)
.bind(log.error_message)
.bind(log.duration_ms as i64)
.bind(None::<String>) // request_body - optional, not stored to save disk space
.bind(None::<String>) // response_body - optional, not stored to save disk space
.execute(&mut *tx)
.await?;
// Update client usage statistics
sqlx::query(
r#"
UPDATE clients SET
total_requests = total_requests + 1,
total_tokens = total_tokens + ?,
total_cost = total_cost + ?,
updated_at = CURRENT_TIMESTAMP
WHERE client_id = ?
"#,
)
.bind(log.total_tokens as i64)
.bind(log.cost)
.bind(&log.client_id)
.execute(&mut *tx)
.await?;
// Deduct from provider balance if successful.
// Providers configured with billing_mode = 'postpaid' will not have their
// credit_balance decremented. Use a conditional UPDATE so we don't need
// a prior SELECT and avoid extra round-trips.
if log.cost > 0.0 {
sqlx::query(
"UPDATE provider_configs SET credit_balance = credit_balance - ? WHERE id = ? AND (billing_mode IS NULL OR billing_mode != 'postpaid')",
)
.bind(log.cost)
.bind(&log.provider)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(())
}
}
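The conditional UPDATE in `insert_log` encodes a small predicate: deduct credit only when the cost is positive and the provider is not postpaid. The same rule as a pure function, a sketch of the logic rather than code from the crate:

```rust
// Mirrors the WHERE clause of the credit-deduction UPDATE:
// deduct when cost > 0 and billing_mode is NULL or not 'postpaid'.
fn should_deduct(cost: f64, billing_mode: Option<&str>) -> bool {
    cost > 0.0 && billing_mode != Some("postpaid")
}

fn main() {
    assert!(should_deduct(0.02, None));              // NULL billing_mode: prepaid by default
    assert!(should_deduct(0.02, Some("prepaid")));
    assert!(!should_deduct(0.02, Some("postpaid"))); // postpaid: balance untouched
    assert!(!should_deduct(0.0, None));              // zero-cost requests skipped
    println!("predicate ok");
}
```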
// /// Middleware to log LLM API requests
// /// TODO: Implement proper middleware that can extract response body details
// pub async fn request_logging_middleware(
// // Extract the authenticated client (if available)
// auth_result: Result<AuthenticatedClient, AppError>,
// request: Request,
// next: Next,
// ) -> Response {
// let start_time = std::time::Instant::now();
//
// // Extract client_id from auth or use "unknown"
// let client_id = match auth_result {
// Ok(auth) => auth.client_id,
// Err(_) => "unknown".to_string(),
// };
//
// // Try to extract request details
// let (request_parts, request_body) = request.into_parts();
//
// // Clone request parts for logging
// let path = request_parts.uri.path().to_string();
//
// // Check if this is a chat completion request
// let is_chat_completion = path == "/v1/chat/completions";
//
// // Reconstruct request for downstream handlers
// let request = Request::from_parts(request_parts, request_body);
//
// // Process request and get response
// let response = next.run(request).await;
//
// // Calculate duration
// let duration = start_time.elapsed();
// let duration_ms = duration.as_millis() as u64;
//
// // Log basic request info
// info!(
// "Request from {} to {} - Status: {} - Duration: {}ms",
// client_id,
// path,
// response.status().as_u16(),
// duration_ms
// );
//
// // TODO: Extract more details from request/response for logging
// // For now, we'll need to modify the server handler to pass additional context
//
// response
// }
/// Context for request logging that can be passed through extensions
#[derive(Clone)]
pub struct LoggingContext {
pub client_id: String,
pub provider_name: String,
pub model: String,
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
pub cost: f64,
pub has_images: bool,
pub error: Option<AppError>,
}
impl LoggingContext {
pub fn new(client_id: String, provider_name: String, model: String) -> Self {
Self {
client_id,
provider_name,
model,
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
cost: 0.0,
has_images: false,
error: None,
}
}
pub fn with_token_counts(mut self, prompt_tokens: u32, completion_tokens: u32) -> Self {
self.prompt_tokens = prompt_tokens;
self.completion_tokens = completion_tokens;
self.total_tokens = prompt_tokens + completion_tokens;
self
}
pub fn with_cost(mut self, cost: f64) -> Self {
self.cost = cost;
self
}
pub fn with_images(mut self, has_images: bool) -> Self {
self.has_images = has_images;
self
}
pub fn with_error(mut self, error: AppError) -> Self {
self.error = Some(error);
self
}
}
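A typical call site chains the `with_*` builder methods. A self-contained sketch with a minimal mirror of the struct (the real `LoggingContext` carries more fields):

```rust
// Minimal mirror of LoggingContext showing the builder-style call pattern.
#[derive(Debug)]
struct Ctx {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
    cost: f64,
}

impl Ctx {
    fn new() -> Self {
        Ctx { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0, cost: 0.0 }
    }
    fn with_token_counts(mut self, prompt: u32, completion: u32) -> Self {
        self.prompt_tokens = prompt;
        self.completion_tokens = completion;
        self.total_tokens = prompt + completion; // derived, as in the impl above
        self
    }
    fn with_cost(mut self, cost: f64) -> Self {
        self.cost = cost;
        self
    }
}

fn main() {
    let ctx = Ctx::new().with_token_counts(10, 5).with_cost(0.0003);
    assert_eq!(ctx.total_tokens, 15);
    assert_eq!(ctx.cost, 0.0003);
    println!("total_tokens = {}", ctx.total_tokens);
}
```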


@@ -1,96 +0,0 @@
use anyhow::Result;
use axum::{Router, routing::get};
use std::net::SocketAddr;
use tracing::{error, info};
use llm_proxy::{
config::AppConfig,
dashboard, database,
providers::ProviderManager,
rate_limiting::{CircuitBreakerConfig, RateLimitManager, RateLimiterConfig},
server,
state::AppState,
utils::crypto,
};
#[tokio::main]
async fn main() -> Result<()> {
// Initialize tracing (logging)
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.with_target(false)
.init();
info!("Starting LLM Proxy Gateway v{}", env!("CARGO_PKG_VERSION"));
// Load configuration
let config = AppConfig::load().await?;
info!("Configuration loaded from {:?}", config.config_path);
// Initialize encryption
crypto::init_with_key(&config.encryption_key)?;
info!("Encryption initialized");
// Initialize database connection pool
let db_pool = database::init(&config.database).await?;
info!("Database initialized at {:?}", config.database.path);
// Initialize provider manager with configured providers
let provider_manager = ProviderManager::new();
// Initialize all supported providers (they handle their own enabled check)
let supported_providers = vec!["openai", "gemini", "deepseek", "grok", "ollama"];
for name in supported_providers {
if let Err(e) = provider_manager.initialize_provider(name, &config, &db_pool).await {
error!("Failed to initialize provider {}: {}", name, e);
}
}
// Create rate limit manager
let rate_limit_manager = RateLimitManager::new(RateLimiterConfig::default(), CircuitBreakerConfig::default());
// Fetch model registry from models.dev
let model_registry = match llm_proxy::utils::registry::fetch_registry().await {
Ok(registry) => registry,
Err(e) => {
error!("Failed to fetch model registry: {}. Using empty registry.", e);
llm_proxy::models::registry::ModelRegistry {
providers: std::collections::HashMap::new(),
}
}
};
// Create application state
let state = AppState::new(
config.clone(),
provider_manager,
db_pool,
rate_limit_manager,
model_registry,
config.server.auth_tokens.clone(),
);
// Initialize model config cache and start background refresh (every 30s)
state.model_config_cache.refresh().await;
state.model_config_cache.clone().start_refresh_task(30);
info!("Model config cache initialized");
// Create application router
let app = Router::new()
.route("/health", get(health_check))
.merge(server::router(state.clone()))
.merge(dashboard::router(state.clone()));
// Start server
let addr: SocketAddr = format!("{}:{}", config.server.host, config.server.port).parse()?;
info!("Server listening on http://{}", addr);
let listener = tokio::net::TcpListener::bind(&addr).await?;
axum::serve(listener, app).await?;
Ok(())
}
async fn health_check() -> &'static str {
"OK"
}
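The listener's `SocketAddr` can be built from the configured host/port pair with a plain `parse`. A std-only sketch using hypothetical `host`/`port` values shaped like `ServerConfig`:

```rust
use std::net::SocketAddr;

// Builds a bind address from ServerConfig-style host and port values.
fn bind_addr(host: &str, port: u16) -> Result<SocketAddr, std::net::AddrParseError> {
    format!("{}:{}", host, port).parse()
}

fn main() {
    let addr = bind_addr("0.0.0.0", 8080).expect("valid address");
    assert_eq!(addr.port(), 8080);
    println!("listening on {}", addr);
}
```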


@@ -1,381 +0,0 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
pub mod registry;
// ========== OpenAI-compatible Request/Response Structs ==========
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
#[serde(default)]
pub temperature: Option<f64>,
#[serde(default)]
pub top_p: Option<f64>,
#[serde(default)]
pub top_k: Option<u32>,
#[serde(default)]
pub n: Option<u32>,
#[serde(default)]
pub stop: Option<Value>, // Can be string or array of strings
#[serde(default)]
pub max_tokens: Option<u32>,
#[serde(default)]
pub presence_penalty: Option<f64>,
#[serde(default)]
pub frequency_penalty: Option<f64>,
#[serde(default)]
pub stream: Option<bool>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: String, // "system", "user", "assistant", "tool"
#[serde(flatten)]
pub content: MessageContent,
#[serde(alias = "reasoning", alias = "thought", skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
Text { content: String },
Parts { content: Vec<ContentPartValue> },
None, // Handle cases where content might be null but reasoning is present
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPartValue {
Text { text: String },
ImageUrl { image_url: ImageUrl },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrl {
pub url: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<String>,
}
// ========== Tool-Calling Types ==========
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
#[serde(rename = "type")]
pub tool_type: String,
pub function: FunctionDef,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDef {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub parameters: Option<Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
Mode(String), // "auto", "none", "required"
Specific(ToolChoiceSpecific),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceSpecific {
#[serde(rename = "type")]
pub choice_type: String,
pub function: ToolChoiceFunction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolChoiceFunction {
pub name: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
pub id: String,
#[serde(rename = "type")]
pub call_type: String,
pub function: FunctionCall,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
pub name: String,
pub arguments: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallDelta {
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
pub call_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function: Option<FunctionCallDelta>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDelta {
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<String>,
}
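Streaming providers emit tool calls as incremental `ToolCallDelta` fragments keyed by `index`, and clients are expected to concatenate the `arguments` pieces per index until the stream finishes. A minimal std-only sketch of that accumulation (the struct and function names here are illustrative stand-ins, not types from this file):

```rust
// Simplified stand-in for an accumulated ToolCall; fields trimmed for the sketch.
#[derive(Default, Clone)]
struct PartialCall {
    id: String,
    name: String,
    arguments: String,
}

/// Merge a stream of (index, id?, name?, argument-fragment?) deltas into
/// complete calls, concatenating argument fragments per index.
fn accumulate(deltas: &[(usize, Option<&str>, Option<&str>, Option<&str>)]) -> Vec<PartialCall> {
    let mut calls: Vec<PartialCall> = Vec::new();
    for &(index, id, name, args) in deltas {
        if calls.len() <= index {
            calls.resize(index + 1, PartialCall::default());
        }
        let call = &mut calls[index];
        if let Some(id) = id {
            call.id = id.to_string();
        }
        if let Some(name) = name {
            call.name = name.to_string();
        }
        if let Some(args) = args {
            call.arguments.push_str(args); // arguments arrive as JSON fragments
        }
    }
    calls
}

fn main() {
    let deltas = [
        (0, Some("call_1"), Some("get_weather"), Some("{\"ci")),
        (0, None, None, Some("ty\":\"Paris\"}")),
    ];
    let calls = accumulate(&deltas);
    assert_eq!(calls[0].arguments, "{\"city\":\"Paris\"}");
    println!("{} -> {}", calls[0].name, calls[0].arguments);
}
```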
// ========== OpenAI-compatible Response Structs ==========
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<ChatChoice>,
pub usage: Option<Usage>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatChoice {
pub index: u32,
pub message: ChatMessage,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_read_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cache_write_tokens: Option<u32>,
}
// ========== Streaming Response Structs ==========
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionStreamResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<ChatStreamChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatStreamChoice {
pub index: u32,
pub delta: ChatStreamDelta,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatStreamDelta {
pub role: Option<String>,
pub content: Option<String>,
#[serde(alias = "reasoning", alias = "thought", skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallDelta>>,
}
// ========== Unified Request Format (for internal use) ==========
#[derive(Debug, Clone)]
pub struct UnifiedRequest {
pub client_id: String,
pub model: String,
pub messages: Vec<UnifiedMessage>,
pub temperature: Option<f64>,
pub top_p: Option<f64>,
pub top_k: Option<u32>,
pub n: Option<u32>,
pub stop: Option<Vec<String>>,
pub max_tokens: Option<u32>,
pub presence_penalty: Option<f64>,
pub frequency_penalty: Option<f64>,
pub stream: bool,
pub has_images: bool,
pub tools: Option<Vec<Tool>>,
pub tool_choice: Option<ToolChoice>,
}
#[derive(Debug, Clone)]
pub struct UnifiedMessage {
pub role: String,
pub content: Vec<ContentPart>,
pub reasoning_content: Option<String>,
pub tool_calls: Option<Vec<ToolCall>>,
pub name: Option<String>,
pub tool_call_id: Option<String>,
}
#[derive(Debug, Clone)]
pub enum ContentPart {
Text { text: String },
Image(crate::multimodal::ImageInput),
}
// ========== Provider-specific Structs ==========
#[derive(Debug, Clone, Serialize)]
pub struct OpenAIRequest {
pub model: String,
pub messages: Vec<OpenAIMessage>,
pub temperature: Option<f64>,
pub max_tokens: Option<u32>,
pub stream: Option<bool>,
}
#[derive(Debug, Clone, Serialize)]
pub struct OpenAIMessage {
pub role: String,
pub content: Vec<OpenAIContentPart>,
}
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum OpenAIContentPart {
Text { text: String },
ImageUrl { image_url: ImageUrl },
}
// Note: ImageUrl struct is defined earlier in the file
// ========== Conversion Traits ==========
pub trait ToOpenAI {
fn to_openai(&self) -> Result<OpenAIRequest, anyhow::Error>;
}
pub trait FromOpenAI {
fn from_openai(request: &OpenAIRequest) -> Result<Self, anyhow::Error>
where
Self: Sized;
}
impl UnifiedRequest {
/// Hydrate all image content by fetching URLs and converting to base64/bytes
pub async fn hydrate_images(&mut self) -> anyhow::Result<()> {
if !self.has_images {
return Ok(());
}
for msg in &mut self.messages {
for part in &mut msg.content {
if let ContentPart::Image(image_input) = part {
// Pre-fetch and validate if it's a URL
if let crate::multimodal::ImageInput::Url(_url) = image_input {
let (base64_data, mime_type) = image_input.to_base64().await?;
*image_input = crate::multimodal::ImageInput::Base64 {
data: base64_data,
mime_type,
};
}
}
}
}
Ok(())
}
}
impl TryFrom<ChatCompletionRequest> for UnifiedRequest {
type Error = anyhow::Error;
fn try_from(req: ChatCompletionRequest) -> Result<Self, Self::Error> {
let mut has_images = false;
// Convert OpenAI-compatible request to unified format
let messages = req
.messages
.into_iter()
.map(|msg| {
let (content, _images_in_message) = match msg.content {
MessageContent::Text { content } => (vec![ContentPart::Text { text: content }], false),
MessageContent::Parts { content } => {
let mut unified_content = Vec::new();
let mut has_images_in_msg = false;
for part in content {
match part {
ContentPartValue::Text { text } => {
unified_content.push(ContentPart::Text { text });
}
ContentPartValue::ImageUrl { image_url } => {
has_images_in_msg = true;
has_images = true;
unified_content.push(ContentPart::Image(crate::multimodal::ImageInput::from_url(
image_url.url,
)));
}
}
}
(unified_content, has_images_in_msg)
}
MessageContent::None => (vec![], false),
};
UnifiedMessage {
role: msg.role,
content,
reasoning_content: msg.reasoning_content,
tool_calls: msg.tool_calls,
name: msg.name,
tool_call_id: msg.tool_call_id,
}
})
.collect();
let stop = match req.stop {
Some(Value::String(s)) => Some(vec![s]),
Some(Value::Array(a)) => Some(
a.into_iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect(),
),
_ => None,
};
Ok(UnifiedRequest {
client_id: String::new(), // Will be populated by auth middleware
model: req.model,
messages,
temperature: req.temperature,
top_p: req.top_p,
top_k: req.top_k,
n: req.n,
stop,
max_tokens: req.max_tokens,
presence_penalty: req.presence_penalty,
frequency_penalty: req.frequency_penalty,
stream: req.stream.unwrap_or(false),
has_images,
tools: req.tools,
tool_choice: req.tool_choice,
})
}
}
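The `stop` conversion above accepts both shapes OpenAI allows (a single string or an array of strings) and silently drops non-string array elements. A std-only sketch of the same normalization, using a hypothetical `StopValue` enum in place of `serde_json::Value`:

```rust
// Std-only stand-in for serde_json::Value, covering only the shapes `stop` can take.
enum StopValue {
    Str(String),
    Arr(Vec<StopValue>),
    Other,
}

/// Normalize OpenAI's `stop` field (string OR array of strings) into a
/// Vec<String>, dropping non-string array elements as the conversion does.
fn normalize_stop(stop: Option<StopValue>) -> Option<Vec<String>> {
    match stop {
        Some(StopValue::Str(s)) => Some(vec![s]),
        Some(StopValue::Arr(a)) => Some(
            a.into_iter()
                .filter_map(|v| match v {
                    StopValue::Str(s) => Some(s),
                    _ => None,
                })
                .collect(),
        ),
        _ => None,
    }
}

fn main() {
    assert_eq!(
        normalize_stop(Some(StopValue::Str("END".into()))),
        Some(vec!["END".to_string()])
    );
    let arr = StopValue::Arr(vec![StopValue::Str("a".into()), StopValue::Other]);
    assert_eq!(normalize_stop(Some(arr)), Some(vec!["a".to_string()]));
    assert_eq!(normalize_stop(None), None);
}
```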


@@ -1,219 +0,0 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRegistry {
#[serde(flatten)]
pub providers: HashMap<String, ProviderInfo>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderInfo {
pub id: String,
pub name: String,
pub models: HashMap<String, ModelMetadata>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
pub id: String,
pub name: String,
pub cost: Option<ModelCost>,
pub limit: Option<ModelLimit>,
pub modalities: Option<ModelModalities>,
pub tool_call: Option<bool>,
pub reasoning: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCost {
pub input: f64,
pub output: f64,
pub cache_read: Option<f64>,
pub cache_write: Option<f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelLimit {
pub context: u32,
pub output: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelModalities {
pub input: Vec<String>,
pub output: Vec<String>,
}
/// A model entry paired with its provider ID, returned by listing/filtering methods.
#[derive(Debug, Clone)]
pub struct ModelEntry<'a> {
pub model_key: &'a str,
pub provider_id: &'a str,
pub provider_name: &'a str,
pub metadata: &'a ModelMetadata,
}
/// Filter criteria for listing models. All fields are optional; `None` means no filter.
#[derive(Debug, Default, Clone, Deserialize)]
pub struct ModelFilter {
/// Filter by provider ID (exact match).
pub provider: Option<String>,
/// Text search on model ID or name (case-insensitive substring).
pub search: Option<String>,
/// Filter by input modality (e.g. "image", "text").
pub modality: Option<String>,
/// Only models that support tool calling.
pub tool_call: Option<bool>,
/// Only models that support reasoning.
pub reasoning: Option<bool>,
/// Only models that have pricing data.
pub has_cost: Option<bool>,
}
/// Sort field for model listings.
#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ModelSortBy {
#[default]
Name,
Id,
Provider,
ContextLimit,
InputCost,
OutputCost,
}
/// Sort direction.
#[derive(Debug, Clone, Deserialize, Default, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum SortOrder {
#[default]
Asc,
Desc,
}
impl ModelRegistry {
/// Find a model by its ID (searching across all providers)
pub fn find_model(&self, model_id: &str) -> Option<&ModelMetadata> {
// First, try an exact match against the keys of each provider's models map
for provider in self.providers.values() {
if let Some(model) = provider.models.get(model_id) {
return Some(model);
}
}
// Fall back to matching the `id` field inside the metadata when the map key differs
for provider in self.providers.values() {
for model in provider.models.values() {
if model.id == model_id {
return Some(model);
}
}
}
None
}
/// List all models with optional filtering and sorting.
pub fn list_models(
&self,
filter: &ModelFilter,
sort_by: &ModelSortBy,
sort_order: &SortOrder,
) -> Vec<ModelEntry<'_>> {
let mut entries: Vec<ModelEntry<'_>> = Vec::new();
for (p_id, p_info) in &self.providers {
// Provider filter
if let Some(ref prov) = filter.provider {
if p_id != prov {
continue;
}
}
for (m_key, m_meta) in &p_info.models {
// Text search filter
if let Some(ref search) = filter.search {
let search_lower = search.to_lowercase();
if !m_meta.id.to_lowercase().contains(&search_lower)
&& !m_meta.name.to_lowercase().contains(&search_lower)
&& !m_key.to_lowercase().contains(&search_lower)
{
continue;
}
}
// Modality filter
if let Some(ref modality) = filter.modality {
let has_modality = m_meta
.modalities
.as_ref()
.is_some_and(|m| m.input.iter().any(|i| i.eq_ignore_ascii_case(modality)));
if !has_modality {
continue;
}
}
// Tool call filter
if let Some(tc) = filter.tool_call {
if m_meta.tool_call.unwrap_or(false) != tc {
continue;
}
}
// Reasoning filter
if let Some(r) = filter.reasoning {
if m_meta.reasoning.unwrap_or(false) != r {
continue;
}
}
// Has cost filter
if let Some(hc) = filter.has_cost {
if hc != m_meta.cost.is_some() {
continue;
}
}
entries.push(ModelEntry {
model_key: m_key,
provider_id: p_id,
provider_name: &p_info.name,
metadata: m_meta,
});
}
}
// Sort
entries.sort_by(|a, b| {
let cmp = match sort_by {
ModelSortBy::Name => a.metadata.name.to_lowercase().cmp(&b.metadata.name.to_lowercase()),
ModelSortBy::Id => a.model_key.cmp(b.model_key),
ModelSortBy::Provider => a.provider_id.cmp(b.provider_id),
ModelSortBy::ContextLimit => {
let a_ctx = a.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0);
let b_ctx = b.metadata.limit.as_ref().map(|l| l.context).unwrap_or(0);
a_ctx.cmp(&b_ctx)
}
ModelSortBy::InputCost => {
let a_cost = a.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
let b_cost = b.metadata.cost.as_ref().map(|c| c.input).unwrap_or(0.0);
a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
}
ModelSortBy::OutputCost => {
let a_cost = a.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
let b_cost = b.metadata.cost.as_ref().map(|c| c.output).unwrap_or(0.0);
a_cost.partial_cmp(&b_cost).unwrap_or(std::cmp::Ordering::Equal)
}
};
match sort_order {
SortOrder::Asc => cmp,
SortOrder::Desc => cmp.reverse(),
}
});
entries
}
}
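The filter-then-`sort_by` pattern in `list_models` (case-insensitive substring match, `cmp.reverse()` for descending order, `partial_cmp` with an `Equal` fallback for floats) can be sketched std-only with a flattened entry type; the `Entry` struct and sample data below are hypothetical, not the registry's real shape:

```rust
// Hypothetical flat model entries; the real code walks a provider -> model map.
#[derive(Debug)]
struct Entry {
    provider: &'static str,
    name: &'static str,
    input_cost: Option<f64>,
}

/// Case-insensitive substring filter plus cost sort, mirroring list_models:
/// missing costs sort as 0.0, and Desc is just cmp.reverse().
fn search_and_sort(mut entries: Vec<Entry>, search: &str, desc: bool) -> Vec<Entry> {
    let needle = search.to_lowercase();
    entries.retain(|e| e.name.to_lowercase().contains(&needle));
    entries.sort_by(|a, b| {
        let cmp = a
            .input_cost
            .unwrap_or(0.0)
            .partial_cmp(&b.input_cost.unwrap_or(0.0))
            .unwrap_or(std::cmp::Ordering::Equal);
        if desc { cmp.reverse() } else { cmp }
    });
    entries
}

fn main() {
    let entries = vec![
        Entry { provider: "openai", name: "GPT-4o", input_cost: Some(2.5) },
        Entry { provider: "openai", name: "GPT-4o mini", input_cost: Some(0.15) },
        Entry { provider: "gemini", name: "Gemini Flash", input_cost: Some(0.1) },
    ];
    let result = search_and_sort(entries, "gpt", true);
    assert_eq!(result.len(), 2);
    assert_eq!(result[0].name, "GPT-4o"); // highest input cost first
}
```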


@@ -1,299 +0,0 @@
//! Multimodal support for image processing and conversion
//!
//! This module handles:
//! 1. Image format detection and conversion
//! 2. Base64 encoding/decoding
//! 3. URL fetching for images
//! 4. Provider-specific image format conversion
use anyhow::{Context, Result};
use base64::{Engine as _, engine::general_purpose};
use std::sync::LazyLock;
use tracing::{info, warn};
/// Shared HTTP client for image fetching — avoids creating a new TCP+TLS
/// connection for every image URL.
static IMAGE_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.timeout(std::time::Duration::from_secs(30))
.pool_idle_timeout(std::time::Duration::from_secs(60))
.build()
.expect("Failed to build image HTTP client")
});
/// Supported image formats for multimodal input
#[derive(Debug, Clone)]
pub enum ImageInput {
/// Base64-encoded image data with MIME type
Base64 { data: String, mime_type: String },
/// URL to fetch image from
Url(String),
/// Raw bytes with MIME type
Bytes { data: Vec<u8>, mime_type: String },
}
impl ImageInput {
/// Create ImageInput from base64 string
pub fn from_base64(data: String, mime_type: String) -> Self {
Self::Base64 { data, mime_type }
}
/// Create ImageInput from URL
pub fn from_url(url: String) -> Self {
Self::Url(url)
}
/// Create ImageInput from raw bytes
pub fn from_bytes(data: Vec<u8>, mime_type: String) -> Self {
Self::Bytes { data, mime_type }
}
/// Get MIME type if available
pub fn mime_type(&self) -> Option<&str> {
match self {
Self::Base64 { mime_type, .. } => Some(mime_type),
Self::Bytes { mime_type, .. } => Some(mime_type),
Self::Url(_) => None,
}
}
/// Convert to base64 if not already
pub async fn to_base64(&self) -> Result<(String, String)> {
match self {
Self::Base64 { data, mime_type } => Ok((data.clone(), mime_type.clone())),
Self::Bytes { data, mime_type } => {
let base64_data = general_purpose::STANDARD.encode(data);
Ok((base64_data, mime_type.clone()))
}
Self::Url(url) => {
// Fetch image from URL using shared client
info!("Fetching image from URL: {}", url);
let response = IMAGE_CLIENT
.get(url)
.send()
.await
.context("Failed to fetch image from URL")?;
if !response.status().is_success() {
anyhow::bail!("Failed to fetch image: HTTP {}", response.status());
}
let mime_type = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|h| h.to_str().ok())
.unwrap_or("image/jpeg")
.to_string();
let bytes = response.bytes().await.context("Failed to read image bytes")?;
let base64_data = general_purpose::STANDARD.encode(&bytes);
Ok((base64_data, mime_type))
}
}
}
/// Get image dimensions (width, height)
pub async fn get_dimensions(&self) -> Result<(u32, u32)> {
let bytes = match self {
Self::Base64 { data, .. } => general_purpose::STANDARD
.decode(data)
.context("Failed to decode base64")?,
Self::Bytes { data, .. } => data.clone(),
Self::Url(_) => {
let (base64_data, _) = self.to_base64().await?;
general_purpose::STANDARD
.decode(&base64_data)
.context("Failed to decode base64")?
}
};
let img = image::load_from_memory(&bytes).context("Failed to load image from bytes")?;
Ok((img.width(), img.height()))
}
/// Validate image size and format
pub async fn validate(&self, max_size_mb: f64) -> Result<()> {
let (width, height) = self.get_dimensions().await?;
// Check dimensions
if width > 4096 || height > 4096 {
warn!("Image dimensions too large: {}x{}", width, height);
// Continue anyway, but log warning
}
// Check file size
let size_bytes = match self {
Self::Base64 { data, .. } => {
// Base64 size is ~4/3 of original
(data.len() as f64 * 0.75) as usize
}
Self::Bytes { data, .. } => data.len(),
Self::Url(_) => {
// For URLs, we'd need to fetch to check size
// Skip size check for URLs for now
return Ok(());
}
};
let size_mb = size_bytes as f64 / (1024.0 * 1024.0);
if size_mb > max_size_mb {
anyhow::bail!("Image too large: {:.2}MB > {:.2}MB limit", size_mb, max_size_mb);
}
Ok(())
}
}
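The size check in `validate` estimates the decoded payload from the base64 length: 4 base64 characters encode 3 bytes, so multiplying the encoded length by 0.75 gives a slightly conservative estimate (padding inflates it a little, which errs on the safe side for a limit check). A std-only sketch of that arithmetic:

```rust
/// Estimate decoded size from a base64 string length. 4 base64 chars encode
/// 3 bytes, so the original is ~3/4 of the encoded length; `=` padding makes
/// the estimate slightly high, which is acceptable for a size-limit check.
fn estimated_decoded_bytes(base64_len: usize) -> usize {
    (base64_len as f64 * 0.75) as usize
}

fn main() {
    // "Hello World" (11 bytes) encodes to "SGVsbG8gV29ybGQ=" (16 chars).
    let est = estimated_decoded_bytes(16);
    assert_eq!(est, 12); // estimate includes padding, so slightly above 11
    // MB conversion, as `validate` does before comparing to max_size_mb:
    let size_mb = est as f64 / (1024.0 * 1024.0);
    assert!(size_mb < 1.0);
}
```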
/// Provider-specific image format conversion
pub struct ImageConverter;
impl ImageConverter {
/// Convert image to OpenAI-compatible format
pub async fn to_openai_format(image: &ImageInput) -> Result<serde_json::Value> {
let (base64_data, mime_type) = image.to_base64().await?;
// OpenAI expects data URL format: "data:image/jpeg;base64,{data}"
let data_url = format!("data:{};base64,{}", mime_type, base64_data);
Ok(serde_json::json!({
"type": "image_url",
"image_url": {
"url": data_url,
"detail": "auto" // Can be "low", "high", or "auto"
}
}))
}
/// Convert image to Gemini-compatible format
pub async fn to_gemini_format(image: &ImageInput) -> Result<serde_json::Value> {
let (base64_data, mime_type) = image.to_base64().await?;
// Gemini expects inline data format
Ok(serde_json::json!({
"inline_data": {
"mime_type": mime_type,
"data": base64_data
}
}))
}
/// Convert image to DeepSeek-compatible format
pub async fn to_deepseek_format(image: &ImageInput) -> Result<serde_json::Value> {
// DeepSeek uses OpenAI-compatible format for vision models
Self::to_openai_format(image).await
}
/// Detect if a model supports multimodal input
pub fn model_supports_multimodal(model: &str) -> bool {
// OpenAI vision models
if (model.starts_with("gpt-4") && (model.contains("vision") || model.contains("-v") || model.contains("4o")))
|| model.starts_with("o1-")
|| model.starts_with("o3-")
{
return true;
}
// Gemini vision models
if model.starts_with("gemini") {
// Most Gemini models support vision
return true;
}
// DeepSeek vision models
if model.starts_with("deepseek-vl") {
return true;
}
false
}
}
/// Parse OpenAI-compatible multimodal message content
pub fn parse_openai_content(content: &serde_json::Value) -> Result<Vec<(String, Option<ImageInput>)>> {
let mut parts = Vec::new();
if let Some(content_str) = content.as_str() {
// Simple text content
parts.push((content_str.to_string(), None));
} else if let Some(content_array) = content.as_array() {
// Array of content parts (text and/or images)
for part in content_array {
if let Some(part_obj) = part.as_object()
&& let Some(part_type) = part_obj.get("type").and_then(|t| t.as_str())
{
match part_type {
"text" => {
if let Some(text) = part_obj.get("text").and_then(|t| t.as_str()) {
parts.push((text.to_string(), None));
}
}
"image_url" => {
if let Some(image_url_obj) = part_obj.get("image_url").and_then(|o| o.as_object())
&& let Some(url) = image_url_obj.get("url").and_then(|u| u.as_str())
{
if url.starts_with("data:") {
// Parse data URL
if let Some((mime_type, data)) = parse_data_url(url) {
let image_input = ImageInput::from_base64(data, mime_type);
parts.push(("".to_string(), Some(image_input)));
}
} else {
// Regular URL
let image_input = ImageInput::from_url(url.to_string());
parts.push(("".to_string(), Some(image_input)));
}
}
}
_ => {
warn!("Unknown content part type: {}", part_type);
}
}
}
}
}
Ok(parts)
}
/// Parse data URL (data:image/jpeg;base64,{data})
fn parse_data_url(data_url: &str) -> Option<(String, String)> {
if !data_url.starts_with("data:") {
return None;
}
let parts: Vec<&str> = data_url[5..].split(";base64,").collect();
if parts.len() != 2 {
return None;
}
let mime_type = parts[0].to_string();
let data = parts[1].to_string();
Some((mime_type, data))
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_parse_data_url() {
let test_url = "data:image/jpeg;base64,SGVsbG8gV29ybGQ="; // "Hello World" in base64
let (mime_type, data) = parse_data_url(test_url).unwrap();
assert_eq!(mime_type, "image/jpeg");
assert_eq!(data, "SGVsbG8gV29ybGQ=");
}
#[tokio::test]
async fn test_model_supports_multimodal() {
assert!(ImageConverter::model_supports_multimodal("gpt-4-vision-preview"));
assert!(ImageConverter::model_supports_multimodal("gpt-4o"));
assert!(ImageConverter::model_supports_multimodal("gemini-pro-vision"));
assert!(ImageConverter::model_supports_multimodal("gemini-pro"));
assert!(!ImageConverter::model_supports_multimodal("gpt-3.5-turbo"));
assert!(!ImageConverter::model_supports_multimodal("claude-3-opus"));
}
}


@@ -1,268 +0,0 @@
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use futures::StreamExt;
use super::helpers;
use super::{ProviderResponse, ProviderStreamChunk};
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
pub struct DeepSeekProvider {
client: reqwest::Client,
config: crate::config::DeepSeekConfig,
api_key: String,
pricing: Vec<crate::config::ModelPricing>,
}
impl DeepSeekProvider {
pub fn new(config: &crate::config::DeepSeekConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("deepseek")?;
Self::new_with_key(config, app_config, api_key)
}
pub fn new_with_key(
config: &crate::config::DeepSeekConfig,
app_config: &AppConfig,
api_key: String,
) -> Result<Self> {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.timeout(std::time::Duration::from_secs(300))
.pool_idle_timeout(std::time::Duration::from_secs(90))
.pool_max_idle_per_host(4)
.tcp_keepalive(std::time::Duration::from_secs(30))
.build()?;
Ok(Self {
client,
config: config.clone(),
api_key,
pricing: app_config.pricing.deepseek.clone(),
})
}
}
#[async_trait]
impl super::Provider for DeepSeekProvider {
fn name(&self) -> &str {
"deepseek"
}
fn supports_model(&self, model: &str) -> bool {
model.starts_with("deepseek-") || model.contains("deepseek")
}
fn supports_multimodal(&self) -> bool {
false
}
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
let mut body = helpers::build_openai_body(&request, messages_json, false);
// Sanitize and fix for deepseek-reasoner (R1)
if request.model == "deepseek-reasoner" {
if let Some(obj) = body.as_object_mut() {
// Remove unsupported parameters
obj.remove("temperature");
obj.remove("top_p");
obj.remove("presence_penalty");
obj.remove("frequency_penalty");
obj.remove("logit_bias");
obj.remove("logprobs");
obj.remove("top_logprobs");
// Ensure every assistant message carries reasoning_content and string content
if let Some(messages) = obj.get_mut("messages").and_then(|m| m.as_array_mut()) {
for m in messages {
if m["role"].as_str() == Some("assistant") {
// DeepSeek R1 requires reasoning_content for consistency in history.
if m.get("reasoning_content").is_none() || m["reasoning_content"].is_null() {
m["reasoning_content"] = serde_json::json!(" ");
}
// DeepSeek R1 often requires content to be a string, not null/array
if m.get("content").is_none() || m["content"].is_null() || m["content"].is_array() {
m["content"] = serde_json::json!("");
}
}
}
}
}
}
let response = self
.client
.post(format!("{}/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
tracing::error!("DeepSeek API error ({}): {}", status, error_text);
tracing::error!("Offending DeepSeek Request Body: {}", serde_json::to_string(&body).unwrap_or_default());
return Err(AppError::ProviderError(format!("DeepSeek API error ({}): {}", status, error_text)));
}
let resp_json: serde_json::Value = response
.json()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
helpers::parse_openai_response(&resp_json, request.model)
}
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
fn calculate_cost(
&self,
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
cache_read_tokens: u32,
cache_write_tokens: u32,
registry: &crate::models::registry::ModelRegistry,
) -> f64 {
if let Some(metadata) = registry.find_model(model) {
if metadata.cost.is_some() {
return helpers::calculate_cost_with_registry(
model,
prompt_tokens,
completion_tokens,
cache_read_tokens,
cache_write_tokens,
registry,
&self.pricing,
0.28,
0.42,
);
}
}
// Custom DeepSeek fallback that correctly handles cache hits
let (prompt_rate, completion_rate) = self
.pricing
.iter()
.find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
.unwrap_or((0.28, 0.42)); // Default to DeepSeek's current API pricing
let cache_hit_rate = prompt_rate / 10.0;
let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
(non_cached_prompt as f64 * prompt_rate / 1_000_000.0)
+ (cache_read_tokens as f64 * cache_hit_rate / 1_000_000.0)
+ (completion_tokens as f64 * completion_rate / 1_000_000.0)
}
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
// DeepSeek doesn't support images in streaming, so fall back to text-only messages
let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?;
let mut body = helpers::build_openai_body(&request, messages_json, true);
// Sanitize and fix for deepseek-reasoner (R1)
if request.model == "deepseek-reasoner" {
if let Some(obj) = body.as_object_mut() {
// Keep stream_options if present (DeepSeek supports include_usage)
// Remove unsupported parameters
obj.remove("temperature");
obj.remove("top_p");
obj.remove("presence_penalty");
obj.remove("frequency_penalty");
obj.remove("logit_bias");
obj.remove("logprobs");
obj.remove("top_logprobs");
// Ensure every assistant message carries reasoning_content and string content
if let Some(messages) = obj.get_mut("messages").and_then(|m| m.as_array_mut()) {
for m in messages {
if m["role"].as_str() == Some("assistant") {
// DeepSeek R1 requires reasoning_content for consistency in history.
if m.get("reasoning_content").is_none() || m["reasoning_content"].is_null() {
m["reasoning_content"] = serde_json::json!(" ");
}
// DeepSeek R1 often requires content to be a string, not null/array
if m.get("content").is_none() || m["content"].is_null() || m["content"].is_array() {
m["content"] = serde_json::json!("");
}
}
}
}
}
}
let url = format!("{}/chat/completions", self.config.base_url);
let api_key = self.api_key.clone();
let probe_client = self.client.clone();
let probe_body = body.clone();
let model = request.model.clone();
let es = reqwest_eventsource::EventSource::new(
self.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body),
)
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(reqwest_eventsource::Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: serde_json::Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(p_chunk) = helpers::parse_openai_stream_chunk(&chunk, &model, None) {
yield p_chunk?;
}
}
Ok(_) => continue,
Err(e) => {
// Attempt to probe for the actual error body
let probe_resp = probe_client
.post(&url)
.header("Authorization", format!("Bearer {}", api_key))
.json(&probe_body)
.send()
.await;
match probe_resp {
Ok(r) if !r.status().is_success() => {
let status = r.status();
let error_body = r.text().await.unwrap_or_default();
tracing::error!("DeepSeek Stream Error Probe ({}): {}", status, error_body);
// Log the offending request body at ERROR level so it shows up in standard logs
tracing::error!("Offending DeepSeek Request Body: {}", serde_json::to_string(&probe_body).unwrap_or_default());
Err(AppError::ProviderError(format!("DeepSeek API error ({}): {}", status, error_body)))?;
}
Ok(_) => {
Err(AppError::ProviderError(format!("Stream error (probe returned 200): {}", e)))?;
}
Err(probe_err) => {
tracing::error!("DeepSeek Stream Error Probe failed: {}", probe_err);
Err(AppError::ProviderError(format!("Stream error (probe failed: {}): {}", probe_err, e)))?;
}
}
}
}
}
};
Ok(Box::pin(stream))
}
}
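The deepseek-reasoner sanitization above backfills two fields on every assistant message in the history: a whitespace `reasoning_content` when it is missing, and an empty-string `content` when it is null or an array. A std-only sketch of that logic (the `Msg` struct here is a simplified stand-in; the real code mutates `serde_json` objects in place):

```rust
// Simplified message stand-in; the real code mutates serde_json objects.
#[derive(Debug)]
struct Msg {
    role: String,
    content: Option<String>,
    reasoning_content: Option<String>,
}

/// deepseek-reasoner rejects assistant history entries that lack
/// reasoning_content or whose content is not a string, so backfill both
/// with harmless defaults before sending. Non-assistant roles are untouched.
fn sanitize_for_r1(messages: &mut [Msg]) {
    for m in messages {
        if m.role == "assistant" {
            if m.reasoning_content.is_none() {
                m.reasoning_content = Some(" ".to_string());
            }
            if m.content.is_none() {
                m.content = Some(String::new());
            }
        }
    }
}

fn main() {
    let mut history = vec![
        Msg { role: "user".into(), content: Some("hi".into()), reasoning_content: None },
        Msg { role: "assistant".into(), content: None, reasoning_content: None },
    ];
    sanitize_for_r1(&mut history);
    assert_eq!(history[1].reasoning_content.as_deref(), Some(" "));
    assert_eq!(history[1].content.as_deref(), Some(""));
    assert_eq!(history[0].reasoning_content, None); // user messages untouched
}
```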

File diff suppressed because it is too large


@@ -1,123 +0,0 @@
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use super::helpers;
use super::{ProviderResponse, ProviderStreamChunk};
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
pub struct GrokProvider {
client: reqwest::Client,
config: crate::config::GrokConfig,
api_key: String,
pricing: Vec<crate::config::ModelPricing>,
}
impl GrokProvider {
pub fn new(config: &crate::config::GrokConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("grok")?;
Self::new_with_key(config, app_config, api_key)
}
pub fn new_with_key(config: &crate::config::GrokConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.timeout(std::time::Duration::from_secs(300))
.pool_idle_timeout(std::time::Duration::from_secs(90))
.pool_max_idle_per_host(4)
.tcp_keepalive(std::time::Duration::from_secs(30))
.build()?;
Ok(Self {
client,
config: config.clone(),
api_key,
pricing: app_config.pricing.grok.clone(),
})
}
}
#[async_trait]
impl super::Provider for GrokProvider {
fn name(&self) -> &str {
"grok"
}
fn supports_model(&self, model: &str) -> bool {
model.starts_with("grok-")
}
fn supports_multimodal(&self) -> bool {
true
}
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
let body = helpers::build_openai_body(&request, messages_json, false);
let response = self
.client
.post(format!("{}/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("Grok API error: {}", error_text)));
}
let resp_json: serde_json::Value = response
.json()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
helpers::parse_openai_response(&resp_json, request.model)
}
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
fn calculate_cost(
&self,
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
cache_read_tokens: u32,
cache_write_tokens: u32,
registry: &crate::models::registry::ModelRegistry,
) -> f64 {
helpers::calculate_cost_with_registry(
model,
prompt_tokens,
completion_tokens,
cache_read_tokens,
cache_write_tokens,
registry,
&self.pricing,
5.0,
15.0,
)
}
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
let body = helpers::build_openai_body(&request, messages_json, true);
let es = reqwest_eventsource::EventSource::new(
self.client
.post(format!("{}/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body),
)
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
Ok(helpers::create_openai_stream(es, request.model, None))
}
}
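Both providers fall back to a per-million-token cost formula when the registry has no pricing; the DeepSeek variant additionally bills cache-read tokens at one tenth of the prompt rate. A std-only sketch of that arithmetic (the rates passed in below are the sketch's inputs, not authoritative pricing):

```rust
/// Fallback cost formula matching the DeepSeek path above: cached prompt
/// tokens are billed at 1/10 of the prompt rate; rates are USD per million
/// tokens. saturating_sub guards against cache_read > prompt.
fn cost_usd(
    prompt: u32,
    completion: u32,
    cache_read: u32,
    prompt_rate: f64,
    completion_rate: f64,
) -> f64 {
    let cache_hit_rate = prompt_rate / 10.0;
    let non_cached = prompt.saturating_sub(cache_read);
    (non_cached as f64 * prompt_rate
        + cache_read as f64 * cache_hit_rate
        + completion as f64 * completion_rate)
        / 1_000_000.0
}

fn main() {
    // 1M prompt tokens, half served from cache, 100k completion tokens.
    let c = cost_usd(1_000_000, 100_000, 500_000, 0.28, 0.42);
    // 0.5 * 0.28 + 0.5 * 0.028 + 0.1 * 0.42 = 0.14 + 0.014 + 0.042
    assert!((c - 0.196).abs() < 1e-9);
    println!("cost = ${:.3}", c);
}
```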


@@ -1,454 +0,0 @@
use super::{ProviderResponse, ProviderStreamChunk, StreamUsage};
use crate::errors::AppError;
use crate::models::{ContentPart, ToolCall, ToolCallDelta, UnifiedMessage, UnifiedRequest};
use futures::stream::{BoxStream, StreamExt};
use serde_json::Value;
/// Convert messages to OpenAI-compatible JSON, resolving images asynchronously.
///
/// This avoids the deadlock caused by `futures::executor::block_on` inside a
/// Tokio async context. All image base64 conversions are awaited properly.
/// Handles tool-calling messages: assistant messages with tool_calls, and
/// tool-role messages with tool_call_id/name.
pub async fn messages_to_openai_json(messages: &[UnifiedMessage]) -> Result<Vec<serde_json::Value>, AppError> {
let mut result = Vec::new();
for m in messages {
// Tool-role messages: { role: "tool", content: "...", tool_call_id: "...", name: "..." }
if m.role == "tool" {
let text_content = m
.content
.first()
.map(|p| match p {
ContentPart::Text { text } => text.clone(),
ContentPart::Image(_) => "[Image]".to_string(),
})
.unwrap_or_default();
let mut msg = serde_json::json!({
"role": "tool",
"content": text_content
});
if let Some(tool_call_id) = &m.tool_call_id {
// OpenAI and others have a 40-char limit for tool_call_id.
// Gemini signatures (56 chars) must be shortened for compatibility.
let id = if tool_call_id.len() > 40 {
&tool_call_id[..40]
} else {
tool_call_id
};
msg["tool_call_id"] = serde_json::json!(id);
}
if let Some(name) = &m.name {
msg["name"] = serde_json::json!(name);
}
result.push(msg);
continue;
}
// Build content parts for non-tool messages
let mut parts = Vec::new();
for p in &m.content {
match p {
ContentPart::Text { text } => {
parts.push(serde_json::json!({ "type": "text", "text": text }));
}
ContentPart::Image(image_input) => {
let (base64_data, mime_type) = image_input
.to_base64()
.await
.map_err(|e| AppError::MultimodalError(e.to_string()))?;
parts.push(serde_json::json!({
"type": "image_url",
"image_url": { "url": format!("data:{};base64,{}", mime_type, base64_data) }
}));
}
}
}
let mut msg = serde_json::json!({ "role": m.role });
// Include reasoning_content if present (DeepSeek R1/reasoner requires this in history)
if let Some(reasoning) = &m.reasoning_content {
msg["reasoning_content"] = serde_json::json!(reasoning);
}
// For assistant messages with tool_calls, content can be empty string
if let Some(tool_calls) = &m.tool_calls {
// Sanitize tool call IDs for OpenAI compatibility (max 40 chars)
let sanitized_calls: Vec<_> = tool_calls.iter().map(|tc| {
let mut sanitized = tc.clone();
if sanitized.id.len() > 40 {
sanitized.id = sanitized.id[..40].to_string();
}
sanitized
}).collect();
if parts.is_empty() {
msg["content"] = serde_json::json!("");
} else {
msg["content"] = serde_json::json!(parts);
}
msg["tool_calls"] = serde_json::json!(sanitized_calls);
} else {
msg["content"] = serde_json::json!(parts);
}
if let Some(name) = &m.name {
msg["name"] = serde_json::json!(name);
}
result.push(msg);
}
Ok(result)
}
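The 40-character `tool_call_id` sanitization used throughout this module can be isolated as a one-liner. A minimal sketch (note it assumes ASCII IDs, since byte slicing would panic on a multi-byte UTF-8 boundary):

```rust
/// Truncate a tool_call_id to the 40-char limit enforced by
/// OpenAI-compatible APIs. Assumes ASCII IDs, as the code above does.
fn sanitize_tool_call_id(id: &str) -> &str {
    if id.len() > 40 { &id[..40] } else { id }
}

fn main() {
    // A 56-char Gemini-style signature is shortened to 40 chars.
    let long_id = "a".repeat(56);
    assert_eq!(sanitize_tool_call_id(&long_id).len(), 40);
    // Short IDs pass through untouched.
    assert_eq!(sanitize_tool_call_id("call_123"), "call_123");
    println!("ok");
}
```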
/// Convert messages to OpenAI-compatible JSON, but replace images with a
/// text placeholder "[Image]". Useful for providers that don't support
/// multimodal in streaming mode or at all.
///
/// Handles tool-calling messages identically to `messages_to_openai_json`:
/// assistant messages with `tool_calls`, and tool-role messages with
/// `tool_call_id`/`name`.
pub async fn messages_to_openai_json_text_only(
messages: &[UnifiedMessage],
) -> Result<Vec<serde_json::Value>, AppError> {
let mut result = Vec::new();
for m in messages {
// Tool-role messages: { role: "tool", content: "...", tool_call_id: "...", name: "..." }
if m.role == "tool" {
let text_content = m
.content
.first()
.map(|p| match p {
ContentPart::Text { text } => text.clone(),
ContentPart::Image(_) => "[Image]".to_string(),
})
.unwrap_or_default();
let mut msg = serde_json::json!({
"role": "tool",
"content": text_content
});
if let Some(tool_call_id) = &m.tool_call_id {
// OpenAI and others have a 40-char limit for tool_call_id.
let id = if tool_call_id.len() > 40 {
&tool_call_id[..40]
} else {
tool_call_id
};
msg["tool_call_id"] = serde_json::json!(id);
}
if let Some(name) = &m.name {
msg["name"] = serde_json::json!(name);
}
result.push(msg);
continue;
}
// Build content parts for non-tool messages (images become "[Image]" text)
let mut parts = Vec::new();
for p in &m.content {
match p {
ContentPart::Text { text } => {
parts.push(serde_json::json!({ "type": "text", "text": text }));
}
ContentPart::Image(_) => {
parts.push(serde_json::json!({ "type": "text", "text": "[Image]" }));
}
}
}
let mut msg = serde_json::json!({ "role": m.role });
// Include reasoning_content if present (DeepSeek R1/reasoner requires this in history)
if let Some(reasoning) = &m.reasoning_content {
msg["reasoning_content"] = serde_json::json!(reasoning);
}
// For assistant messages with tool_calls, content can be empty string
if let Some(tool_calls) = &m.tool_calls {
// Sanitize tool call IDs for OpenAI compatibility (max 40 chars)
let sanitized_calls: Vec<_> = tool_calls.iter().map(|tc| {
let mut sanitized = tc.clone();
if sanitized.id.len() > 40 {
sanitized.id = sanitized.id[..40].to_string();
}
sanitized
}).collect();
if parts.is_empty() {
msg["content"] = serde_json::json!("");
} else {
msg["content"] = serde_json::json!(parts);
}
msg["tool_calls"] = serde_json::json!(sanitized_calls);
} else {
msg["content"] = serde_json::json!(parts);
}
if let Some(name) = &m.name {
msg["name"] = serde_json::json!(name);
}
result.push(msg);
}
Ok(result)
}
/// Build an OpenAI-compatible request body from a UnifiedRequest and pre-converted messages.
/// Includes tools and tool_choice when present.
/// When streaming, adds `stream_options.include_usage: true` so providers report
/// token counts in the final SSE chunk.
pub fn build_openai_body(
request: &UnifiedRequest,
messages_json: Vec<serde_json::Value>,
stream: bool,
) -> serde_json::Value {
let mut body = serde_json::json!({
"model": request.model,
"messages": messages_json,
"stream": stream,
});
if stream {
body["stream_options"] = serde_json::json!({ "include_usage": true });
}
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
if let Some(max_tokens) = request.max_tokens {
body["max_tokens"] = serde_json::json!(max_tokens);
}
if let Some(tools) = &request.tools {
body["tools"] = serde_json::json!(tools);
}
if let Some(tool_choice) = &request.tool_choice {
body["tool_choice"] = serde_json::json!(tool_choice);
}
body
}
/// Parse an OpenAI-compatible chat completion response JSON into a ProviderResponse.
/// Extracts tool_calls from the message when present.
/// Extracts cache token counts from:
/// - OpenAI/Grok: `usage.prompt_tokens_details.cached_tokens`
/// - DeepSeek: `usage.prompt_cache_hit_tokens` / `usage.prompt_cache_miss_tokens`
pub fn parse_openai_response(resp_json: &Value, model: String) -> Result<ProviderResponse, AppError> {
let choice = resp_json["choices"]
.get(0)
.ok_or_else(|| AppError::ProviderError("No choices in response".to_string()))?;
let message = &choice["message"];
let content = message["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = message["reasoning_content"].as_str().map(|s| s.to_string());
// Parse tool_calls from the response message
let tool_calls: Option<Vec<ToolCall>> = message
.get("tool_calls")
.and_then(|tc| serde_json::from_value(tc.clone()).ok());
let usage = &resp_json["usage"];
let prompt_tokens = usage["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = usage["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = usage["total_tokens"].as_u64().unwrap_or(0) as u32;
// Extract reasoning tokens
let reasoning_tokens = usage["completion_tokens_details"]["reasoning_tokens"]
.as_u64()
.unwrap_or(0) as u32;
// Extract cache tokens — try OpenAI/Grok format first, then DeepSeek format
let cache_read_tokens = usage["prompt_tokens_details"]["cached_tokens"]
.as_u64()
// DeepSeek uses a different field name
.or_else(|| usage["prompt_cache_hit_tokens"].as_u64())
.unwrap_or(0) as u32;
// DeepSeek reports prompt_cache_miss_tokens, which are just regular
// non-cached tokens; they incur no separate cache-write fee, so we
// don't map them here (doing so would double-charge).
let cache_write_tokens = 0;
Ok(ProviderResponse {
content,
reasoning_content,
tool_calls,
prompt_tokens,
completion_tokens,
reasoning_tokens,
total_tokens,
cache_read_tokens,
cache_write_tokens,
model,
})
}
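The OpenAI-then-DeepSeek fallback for cached token counts reduces to an `Option` chain. A standalone sketch, with the two usage fields passed in as already-extracted values:

```rust
/// Resolve cached prompt tokens: prefer the OpenAI/Grok field
/// (prompt_tokens_details.cached_tokens), fall back to the DeepSeek
/// field (prompt_cache_hit_tokens), default to 0.
fn cache_read_tokens(openai_cached: Option<u64>, deepseek_hit: Option<u64>) -> u32 {
    openai_cached.or(deepseek_hit).unwrap_or(0) as u32
}

fn main() {
    assert_eq!(cache_read_tokens(Some(128), None), 128); // OpenAI format wins
    assert_eq!(cache_read_tokens(None, Some(64)), 64);   // DeepSeek fallback
    assert_eq!(cache_read_tokens(None, None), 0);        // neither reported
    println!("ok");
}
```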
/// Parse a single OpenAI-compatible stream chunk into a ProviderStreamChunk.
/// Returns None if the chunk should be skipped (e.g. promptFeedback).
pub fn parse_openai_stream_chunk(
chunk: &Value,
model: &str,
reasoning_field: Option<&'static str>,
) -> Option<Result<ProviderStreamChunk, AppError>> {
// Parse usage from the final chunk (sent when stream_options.include_usage is true).
// This chunk may have an empty `choices` array.
let stream_usage = chunk.get("usage").and_then(|u| {
if u.is_null() {
return None;
}
let prompt_tokens = u["prompt_tokens"].as_u64().unwrap_or(0) as u32;
let completion_tokens = u["completion_tokens"].as_u64().unwrap_or(0) as u32;
let total_tokens = u["total_tokens"].as_u64().unwrap_or(0) as u32;
let reasoning_tokens = u["completion_tokens_details"]["reasoning_tokens"]
.as_u64()
.unwrap_or(0) as u32;
let cache_read_tokens = u["prompt_tokens_details"]["cached_tokens"]
.as_u64()
.or_else(|| u["prompt_cache_hit_tokens"].as_u64())
.unwrap_or(0) as u32;
let cache_write_tokens = 0;
Some(StreamUsage {
prompt_tokens,
completion_tokens,
reasoning_tokens,
total_tokens,
cache_read_tokens,
cache_write_tokens,
})
});
if let Some(choice) = chunk["choices"].get(0) {
let delta = &choice["delta"];
let content = delta["content"].as_str().unwrap_or_default().to_string();
let reasoning_content = delta["reasoning_content"]
.as_str()
.or_else(|| reasoning_field.and_then(|f| delta[f].as_str()))
.map(|s| s.to_string());
let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
// Parse tool_calls deltas from the stream chunk
let tool_calls: Option<Vec<ToolCallDelta>> = delta
.get("tool_calls")
.and_then(|tc| serde_json::from_value(tc.clone()).ok());
Some(Ok(ProviderStreamChunk {
content,
reasoning_content,
finish_reason,
tool_calls,
model: model.to_string(),
usage: stream_usage,
}))
} else if stream_usage.is_some() {
// Final usage-only chunk (empty choices array) — yield it so
// AggregatingStream can capture the real token counts.
Some(Ok(ProviderStreamChunk {
content: String::new(),
reasoning_content: None,
finish_reason: None,
tool_calls: None,
model: model.to_string(),
usage: stream_usage,
}))
} else {
None
}
}
/// Create an SSE stream that parses OpenAI-compatible streaming chunks.
///
/// The optional `reasoning_field` allows overriding the field name for
/// reasoning content (e.g., "thought" for Ollama).
/// Parses tool_calls deltas from streaming chunks when present.
/// When `stream_options.include_usage: true` was sent, the provider sends a
/// final chunk with `usage` data — this is parsed into `StreamUsage` and
/// attached to the yielded `ProviderStreamChunk`.
pub fn create_openai_stream(
es: reqwest_eventsource::EventSource,
model: String,
reasoning_field: Option<&'static str>,
) -> BoxStream<'static, Result<ProviderStreamChunk, AppError>> {
use reqwest_eventsource::Event;
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(p_chunk) = parse_openai_stream_chunk(&chunk, &model, reasoning_field) {
yield p_chunk?;
}
}
Ok(_) => continue,
Err(e) => {
Err(AppError::ProviderError(format!("Stream error: {}", e)))?;
}
}
}
};
Box::pin(stream)
}
/// Calculate cost using the model registry first, then falling back to provider pricing config.
///
/// When the registry provides `cache_read` / `cache_write` rates, the formula is:
/// (prompt_tokens - cache_read_tokens) * input_rate
/// + cache_read_tokens * cache_read_rate
/// + cache_write_tokens * cache_write_rate (if applicable)
/// + completion_tokens * output_rate
///
/// All rates are per-token (the registry stores per-million-token rates).
pub fn calculate_cost_with_registry(
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
cache_read_tokens: u32,
cache_write_tokens: u32,
registry: &crate::models::registry::ModelRegistry,
pricing: &[crate::config::ModelPricing],
default_prompt_rate: f64,
default_completion_rate: f64,
) -> f64 {
if let Some(metadata) = registry.find_model(model)
&& let Some(cost) = &metadata.cost
{
let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
let mut total = (non_cached_prompt as f64 * cost.input / 1_000_000.0)
+ (completion_tokens as f64 * cost.output / 1_000_000.0);
if let Some(cache_read_rate) = cost.cache_read {
total += cache_read_tokens as f64 * cache_read_rate / 1_000_000.0;
} else {
// No cache_read rate — charge cached tokens at full input rate
total += cache_read_tokens as f64 * cost.input / 1_000_000.0;
}
if let Some(cache_write_rate) = cost.cache_write {
total += cache_write_tokens as f64 * cache_write_rate / 1_000_000.0;
}
return total;
}
// Fallback: no registry entry — use provider pricing config (no cache awareness)
let (prompt_rate, completion_rate) = pricing
.iter()
.find(|p| model.contains(&p.model))
.map(|p| (p.prompt_tokens_per_million, p.completion_tokens_per_million))
.unwrap_or((default_prompt_rate, default_completion_rate));
(prompt_tokens as f64 * prompt_rate / 1_000_000.0) + (completion_tokens as f64 * completion_rate / 1_000_000.0)
}
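A worked instance of the registry-branch formula above, with illustrative (not real) per-million rates: 1,000 prompt tokens of which 400 were cache reads, 500 completion tokens, input $5/M, output $15/M, cache reads $0.50/M gives (600·5 + 400·0.5 + 500·15) / 1e6 = 0.0107.

```rust
/// Cache-aware cost, mirroring calculate_cost_with_registry's registry
/// branch. All rates are per million tokens; values are illustrative.
fn cost(prompt: u32, completion: u32, cache_read: u32,
        input_rate: f64, output_rate: f64, cache_read_rate: f64) -> f64 {
    let non_cached = prompt.saturating_sub(cache_read) as f64;
    (non_cached * input_rate
        + cache_read as f64 * cache_read_rate
        + completion as f64 * output_rate) / 1_000_000.0
}

fn main() {
    let c = cost(1_000, 500, 400, 5.0, 15.0, 0.5);
    // (600*5 + 400*0.5 + 500*15) / 1e6 = 10_700 / 1e6
    assert!((c - 0.0107).abs() < 1e-12);
    println!("{c}");
}
```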


@@ -1,365 +0,0 @@
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use sqlx::Row;
use std::sync::Arc;
use crate::errors::AppError;
use crate::models::UnifiedRequest;
pub mod deepseek;
pub mod gemini;
pub mod grok;
pub mod helpers;
pub mod ollama;
pub mod openai;
#[async_trait]
pub trait Provider: Send + Sync {
/// Get provider name (e.g., "openai", "gemini")
fn name(&self) -> &str;
/// Check if provider supports a specific model
fn supports_model(&self, model: &str) -> bool;
/// Check if provider supports multimodal (images, etc.)
fn supports_multimodal(&self) -> bool;
/// Process a chat completion request
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError>;
/// Process a chat request using provider-specific "responses" style endpoint
/// Default implementation falls back to `chat_completion` for providers
/// that do not implement a dedicated responses endpoint.
async fn chat_responses(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
self.chat_completion(request).await
}
/// Process a streaming chat completion request
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError>;
/// Process a streaming chat request using provider-specific "responses" style endpoint
/// Default implementation falls back to `chat_completion_stream` for providers
/// that do not implement a dedicated responses endpoint.
async fn chat_responses_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
self.chat_completion_stream(request).await
}
/// Estimate token count for a request (for cost calculation)
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32>;
/// Calculate cost based on token usage and model using the registry.
/// `cache_read_tokens` / `cache_write_tokens` allow cache-aware pricing
/// when the registry provides `cache_read` / `cache_write` rates.
fn calculate_cost(
&self,
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
cache_read_tokens: u32,
cache_write_tokens: u32,
registry: &crate::models::registry::ModelRegistry,
) -> f64;
}
pub struct ProviderResponse {
pub content: String,
pub reasoning_content: Option<String>,
pub tool_calls: Option<Vec<crate::models::ToolCall>>,
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub reasoning_tokens: u32,
pub total_tokens: u32,
pub cache_read_tokens: u32,
pub cache_write_tokens: u32,
pub model: String,
}
/// Usage data from the final streaming chunk (when providers report real token counts).
#[derive(Debug, Clone, Default)]
pub struct StreamUsage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub reasoning_tokens: u32,
pub total_tokens: u32,
pub cache_read_tokens: u32,
pub cache_write_tokens: u32,
}
#[derive(Debug, Clone)]
pub struct ProviderStreamChunk {
pub content: String,
pub reasoning_content: Option<String>,
pub finish_reason: Option<String>,
pub tool_calls: Option<Vec<crate::models::ToolCallDelta>>,
pub model: String,
/// Populated only on the final chunk when providers report usage (e.g. stream_options.include_usage).
pub usage: Option<StreamUsage>,
}
use tokio::sync::RwLock;
use crate::config::AppConfig;
use crate::providers::{
deepseek::DeepSeekProvider, gemini::GeminiProvider, grok::GrokProvider, ollama::OllamaProvider,
openai::OpenAIProvider,
};
#[derive(Clone)]
pub struct ProviderManager {
providers: Arc<RwLock<Vec<Arc<dyn Provider>>>>,
}
impl Default for ProviderManager {
fn default() -> Self {
Self::new()
}
}
impl ProviderManager {
pub fn new() -> Self {
Self {
providers: Arc::new(RwLock::new(Vec::new())),
}
}
/// Initialize a provider by name using config and database overrides
pub async fn initialize_provider(
&self,
name: &str,
app_config: &AppConfig,
db_pool: &crate::database::DbPool,
) -> Result<()> {
// Load override from database
let db_config = sqlx::query("SELECT enabled, base_url, api_key, api_key_encrypted FROM provider_configs WHERE id = ?")
.bind(name)
.fetch_optional(db_pool)
.await?;
let (enabled, base_url, api_key) = if let Some(row) = db_config {
let enabled = row.get::<bool, _>("enabled");
let base_url = row.get::<Option<String>, _>("base_url");
let api_key_encrypted = row.get::<bool, _>("api_key_encrypted");
let api_key = row.get::<Option<String>, _>("api_key");
// Decrypt API key if encrypted
let api_key = match (api_key, api_key_encrypted) {
(Some(key), true) => {
match crate::utils::crypto::decrypt(&key) {
Ok(decrypted) => Some(decrypted),
Err(e) => {
tracing::error!("Failed to decrypt API key for provider {}: {}", name, e);
None
}
}
}
(Some(key), false) => {
// Plaintext key. A lazy migration could encrypt it and update the
// row here; for now, use it as-is.
Some(key)
}
(None, _) => None,
};
(enabled, base_url, api_key)
} else {
// No database override, use defaults from AppConfig
match name {
"openai" => (
app_config.providers.openai.enabled,
Some(app_config.providers.openai.base_url.clone()),
None,
),
"gemini" => (
app_config.providers.gemini.enabled,
Some(app_config.providers.gemini.base_url.clone()),
None,
),
"deepseek" => (
app_config.providers.deepseek.enabled,
Some(app_config.providers.deepseek.base_url.clone()),
None,
),
"grok" => (
app_config.providers.grok.enabled,
Some(app_config.providers.grok.base_url.clone()),
None,
),
"ollama" => (
app_config.providers.ollama.enabled,
Some(app_config.providers.ollama.base_url.clone()),
None,
),
_ => (false, None, None),
}
};
if !enabled {
self.remove_provider(name).await;
return Ok(());
}
// Create provider instance with merged config
let provider: Arc<dyn Provider> = match name {
"openai" => {
let mut cfg = app_config.providers.openai.clone();
if let Some(url) = base_url {
cfg.base_url = url;
}
// Handle API key override if present
let p = if let Some(key) = api_key {
// Use the explicit API key loaded from the database override.
OpenAIProvider::new_with_key(&cfg, app_config, key)?
} else {
OpenAIProvider::new(&cfg, app_config)?
};
Arc::new(p)
}
"ollama" => {
let mut cfg = app_config.providers.ollama.clone();
if let Some(url) = base_url {
cfg.base_url = url;
}
Arc::new(OllamaProvider::new(&cfg, app_config)?)
}
"gemini" => {
let mut cfg = app_config.providers.gemini.clone();
if let Some(url) = base_url {
cfg.base_url = url;
}
let p = if let Some(key) = api_key {
GeminiProvider::new_with_key(&cfg, app_config, key)?
} else {
GeminiProvider::new(&cfg, app_config)?
};
Arc::new(p)
}
"deepseek" => {
let mut cfg = app_config.providers.deepseek.clone();
if let Some(url) = base_url {
cfg.base_url = url;
}
let p = if let Some(key) = api_key {
DeepSeekProvider::new_with_key(&cfg, app_config, key)?
} else {
DeepSeekProvider::new(&cfg, app_config)?
};
Arc::new(p)
}
"grok" => {
let mut cfg = app_config.providers.grok.clone();
if let Some(url) = base_url {
cfg.base_url = url;
}
let p = if let Some(key) = api_key {
GrokProvider::new_with_key(&cfg, app_config, key)?
} else {
GrokProvider::new(&cfg, app_config)?
};
Arc::new(p)
}
_ => return Err(anyhow::anyhow!("Unknown provider: {}", name)),
};
self.add_provider(provider).await;
Ok(())
}
pub async fn add_provider(&self, provider: Arc<dyn Provider>) {
let mut providers = self.providers.write().await;
// If provider with same name exists, replace it
if let Some(index) = providers.iter().position(|p| p.name() == provider.name()) {
providers[index] = provider;
} else {
providers.push(provider);
}
}
pub async fn remove_provider(&self, name: &str) {
let mut providers = self.providers.write().await;
providers.retain(|p| p.name() != name);
}
pub async fn get_provider_for_model(&self, model: &str) -> Option<Arc<dyn Provider>> {
let providers = self.providers.read().await;
providers.iter().find(|p| p.supports_model(model)).map(Arc::clone)
}
pub async fn get_provider(&self, name: &str) -> Option<Arc<dyn Provider>> {
let providers = self.providers.read().await;
providers.iter().find(|p| p.name() == name).map(Arc::clone)
}
pub async fn get_all_providers(&self) -> Vec<Arc<dyn Provider>> {
let providers = self.providers.read().await;
providers.clone()
}
}
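The add/remove semantics of `ProviderManager` (replace an existing provider of the same name rather than appending a duplicate) can be sketched without the async `RwLock`; the names here are illustrative:

```rust
/// Minimal synchronous sketch of ProviderManager's replace-on-add
/// behavior: re-registering a name swaps the instance in place.
struct Registry {
    providers: Vec<(String, &'static str)>, // (name, marker for the instance)
}

impl Registry {
    fn add(&mut self, name: &str, marker: &'static str) {
        if let Some(i) = self.providers.iter().position(|(n, _)| n == name) {
            self.providers[i] = (name.to_string(), marker);
        } else {
            self.providers.push((name.to_string(), marker));
        }
    }
    fn remove(&mut self, name: &str) {
        self.providers.retain(|(n, _)| n != name);
    }
}

fn main() {
    let mut r = Registry { providers: Vec::new() };
    r.add("openai", "v1");
    r.add("openai", "v2"); // replaces, does not duplicate
    assert_eq!(r.providers.len(), 1);
    assert_eq!(r.providers[0].1, "v2");
    r.remove("openai");
    assert!(r.providers.is_empty());
    println!("ok");
}
```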
// Create placeholder provider implementations
pub mod placeholder {
use super::*;
pub struct PlaceholderProvider {
name: String,
}
impl PlaceholderProvider {
pub fn new(name: &str) -> Self {
Self { name: name.to_string() }
}
}
#[async_trait]
impl Provider for PlaceholderProvider {
fn name(&self) -> &str {
&self.name
}
fn supports_model(&self, _model: &str) -> bool {
false
}
fn supports_multimodal(&self) -> bool {
false
}
async fn chat_completion_stream(
&self,
_request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
Err(AppError::ProviderError(
"Streaming not supported for placeholder provider".to_string(),
))
}
async fn chat_completion(&self, _request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
Err(AppError::ProviderError(format!(
"Provider {} not implemented",
self.name
)))
}
fn estimate_tokens(&self, _request: &UnifiedRequest) -> Result<u32> {
Ok(0)
}
fn calculate_cost(
&self,
_model: &str,
_prompt_tokens: u32,
_completion_tokens: u32,
_cache_read_tokens: u32,
_cache_write_tokens: u32,
_registry: &crate::models::registry::ModelRegistry,
) -> f64 {
0.0
}
}
}


@@ -1,140 +0,0 @@
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use super::helpers;
use super::{ProviderResponse, ProviderStreamChunk};
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
pub struct OllamaProvider {
client: reqwest::Client,
config: crate::config::OllamaConfig,
pricing: Vec<crate::config::ModelPricing>,
}
impl OllamaProvider {
pub fn new(config: &crate::config::OllamaConfig, app_config: &AppConfig) -> Result<Self> {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.timeout(std::time::Duration::from_secs(300))
.pool_idle_timeout(std::time::Duration::from_secs(90))
.pool_max_idle_per_host(4)
.tcp_keepalive(std::time::Duration::from_secs(30))
.build()?;
Ok(Self {
client,
config: config.clone(),
pricing: app_config.pricing.ollama.clone(),
})
}
}
#[async_trait]
impl super::Provider for OllamaProvider {
fn name(&self) -> &str {
"ollama"
}
fn supports_model(&self, model: &str) -> bool {
self.config.models.iter().any(|m| m == model) || model.starts_with("ollama/")
}
fn supports_multimodal(&self) -> bool {
true
}
async fn chat_completion(&self, mut request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
// Strip "ollama/" prefix if present for the API call
let api_model = request
.model
.strip_prefix("ollama/")
.unwrap_or(&request.model)
.to_string();
let original_model = request.model.clone();
request.model = api_model;
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
let body = helpers::build_openai_body(&request, messages_json, false);
let response = self
.client
.post(format!("{}/chat/completions", self.config.base_url))
.json(&body)
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("Ollama API error: {}", error_text)));
}
let resp_json: serde_json::Value = response
.json()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Ollama also supports "thought" as an alias for reasoning_content
let mut result = helpers::parse_openai_response(&resp_json, original_model)?;
if result.reasoning_content.is_none() {
result.reasoning_content = resp_json["choices"]
.get(0)
.and_then(|c| c["message"]["thought"].as_str())
.map(|s| s.to_string());
}
Ok(result)
}
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
fn calculate_cost(
&self,
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
cache_read_tokens: u32,
cache_write_tokens: u32,
registry: &crate::models::registry::ModelRegistry,
) -> f64 {
helpers::calculate_cost_with_registry(
model,
prompt_tokens,
completion_tokens,
cache_read_tokens,
cache_write_tokens,
registry,
&self.pricing,
0.0,
0.0,
)
}
async fn chat_completion_stream(
&self,
mut request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
let api_model = request
.model
.strip_prefix("ollama/")
.unwrap_or(&request.model)
.to_string();
let original_model = request.model.clone();
request.model = api_model;
let messages_json = helpers::messages_to_openai_json_text_only(&request.messages).await?;
let body = helpers::build_openai_body(&request, messages_json, true);
let es = reqwest_eventsource::EventSource::new(
self.client
.post(format!("{}/chat/completions", self.config.base_url))
.json(&body),
)
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
// Ollama uses "thought" as an alternative field for reasoning content
Ok(helpers::create_openai_stream(es, original_model, Some("thought")))
}
}
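The `ollama/` prefix handling above reduces to `strip_prefix` with a pass-through fallback; a minimal sketch:

```rust
/// Strip the "ollama/" routing prefix before calling the Ollama API,
/// mirroring the strip_prefix(...).unwrap_or(...) pattern above.
fn api_model(model: &str) -> &str {
    model.strip_prefix("ollama/").unwrap_or(model)
}

fn main() {
    assert_eq!(api_model("ollama/llama3"), "llama3"); // prefix removed for the API call
    assert_eq!(api_model("llama3"), "llama3");        // bare names pass through
    println!("ok");
}
```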


@@ -1,564 +0,0 @@
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use futures::StreamExt;
use super::helpers;
use super::{ProviderResponse, ProviderStreamChunk};
use crate::{config::AppConfig, errors::AppError, models::UnifiedRequest};
pub struct OpenAIProvider {
client: reqwest::Client,
config: crate::config::OpenAIConfig,
api_key: String,
pricing: Vec<crate::config::ModelPricing>,
}
impl OpenAIProvider {
pub fn new(config: &crate::config::OpenAIConfig, app_config: &AppConfig) -> Result<Self> {
let api_key = app_config.get_api_key("openai")?;
Self::new_with_key(config, app_config, api_key)
}
pub fn new_with_key(config: &crate::config::OpenAIConfig, app_config: &AppConfig, api_key: String) -> Result<Self> {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(5))
.timeout(std::time::Duration::from_secs(300))
.pool_idle_timeout(std::time::Duration::from_secs(90))
.pool_max_idle_per_host(4)
.tcp_keepalive(std::time::Duration::from_secs(15))
.build()?;
Ok(Self {
client,
config: config.clone(),
api_key,
pricing: app_config.pricing.openai.clone(),
})
}
}
#[async_trait]
impl super::Provider for OpenAIProvider {
fn name(&self) -> &str {
"openai"
}
fn supports_model(&self, model: &str) -> bool {
model.starts_with("gpt-") ||
model.starts_with("o1-") ||
model.starts_with("o2-") ||
model.starts_with("o3-") ||
model.starts_with("o4-") ||
model.starts_with("o5-") ||
model.contains("gpt-5")
}
fn supports_multimodal(&self) -> bool {
true
}
async fn chat_completion(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
// Proactively route models known to require the Responses API.
let model_lc = request.model.to_lowercase();
if model_lc.contains("gpt-5") || model_lc.contains("codex") {
return self.chat_responses(request).await;
}
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
let mut body = helpers::build_openai_body(&request, messages_json, false);
// Transition: Newer OpenAI models (o1, o3, gpt-5) require max_completion_tokens
// instead of the legacy max_tokens parameter.
if request.model.starts_with("o1-") || request.model.starts_with("o3-") || request.model.contains("gpt-5") {
if let Some(max_tokens) = body.as_object_mut().and_then(|obj| obj.remove("max_tokens")) {
body["max_completion_tokens"] = max_tokens;
}
}
let response = self
.client
.post(format!("{}/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
// If the error body indicates the model is only supported on the
// Responses API (v1/responses), retry against that endpoint.
if error_text.to_lowercase().contains("v1/responses") || error_text.to_lowercase().contains("only supported in v1/responses") {
return self.chat_responses(request).await;
}
tracing::error!("OpenAI API error ({}): {}", status, error_text);
return Err(AppError::ProviderError(format!("OpenAI API error ({}): {}", status, error_text)));
}
let resp_json: serde_json::Value = response
.json()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
helpers::parse_openai_response(&resp_json, request.model)
}
async fn chat_responses(&self, request: UnifiedRequest) -> Result<ProviderResponse, AppError> {
// Build a structured input for the Responses API.
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
let mut input_parts = Vec::new();
for m in &messages_json {
let mut role = m["role"].as_str().unwrap_or("user").to_string();
// Newer models (gpt-5, o1) prefer "developer" over "system"
if role == "system" {
role = "developer".to_string();
}
let mut content = m.get("content").cloned().unwrap_or(serde_json::json!([]));
// Map content types based on role for Responses API
if let Some(content_array) = content.as_array_mut() {
for part in content_array {
if let Some(part_obj) = part.as_object_mut() {
if let Some(t) = part_obj.get("type").and_then(|v| v.as_str()) {
match t {
"text" => {
let new_type = if role == "assistant" { "output_text" } else { "input_text" };
part_obj.insert("type".to_string(), serde_json::json!(new_type));
}
"image_url" => {
// Assistant typically doesn't have image_url in history this way, but for safety:
let new_type = if role == "assistant" { "output_image" } else { "input_image" };
part_obj.insert("type".to_string(), serde_json::json!(new_type));
if let Some(img_url) = part_obj.remove("image_url") {
part_obj.insert("image".to_string(), img_url);
}
}
_ => {}
}
}
}
}
} else if let Some(text) = content.as_str() {
let new_type = if role == "assistant" { "output_text" } else { "input_text" };
content = serde_json::json!([{ "type": new_type, "text": text }]);
}
input_parts.push(serde_json::json!({
"role": role,
"content": content
}));
}
let mut body = serde_json::json!({
"model": request.model,
"input": input_parts,
});
// Add standard parameters
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
// Newer models (gpt-5, o1) in Responses API use max_output_tokens
if let Some(max_tokens) = request.max_tokens {
if request.model.contains("gpt-5") || request.model.starts_with("o1-") || request.model.starts_with("o3-") {
body["max_output_tokens"] = serde_json::json!(max_tokens);
} else {
body["max_tokens"] = serde_json::json!(max_tokens);
}
}
if let Some(tools) = &request.tools {
body["tools"] = serde_json::json!(tools);
}
let resp = self
.client
.post(format!("{}/responses", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await
.map_err(|e| AppError::ProviderError(e.to_string()))?;
if !resp.status().is_success() {
let err = resp.text().await.unwrap_or_default();
return Err(AppError::ProviderError(format!("OpenAI Responses API error: {}", err)));
}
let resp_json: serde_json::Value = resp.json().await.map_err(|e| AppError::ProviderError(e.to_string()))?;
// Try to normalize: if it's chat-style, use existing parser
if resp_json.get("choices").is_some() {
return helpers::parse_openai_response(&resp_json, request.model);
}
// Normalize Responses API output into ProviderResponse
let mut content_text = String::new();
if let Some(output) = resp_json.get("output").and_then(|o| o.as_array()) {
for out in output {
if let Some(contents) = out.get("content").and_then(|c| c.as_array()) {
for item in contents {
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
if !content_text.is_empty() { content_text.push('\n'); }
content_text.push_str(text);
} else if let Some(parts) = item.get("parts").and_then(|p| p.as_array()) {
for p in parts {
if let Some(t) = p.as_str() {
if !content_text.is_empty() { content_text.push('\n'); }
content_text.push_str(t);
}
}
}
}
}
}
}
if content_text.is_empty() {
if let Some(cands) = resp_json.get("candidates").and_then(|c| c.as_array()) {
if let Some(c0) = cands.get(0) {
if let Some(content) = c0.get("content") {
if let Some(parts) = content.get("parts").and_then(|p| p.as_array()) {
for p in parts {
if let Some(t) = p.get("text").and_then(|v| v.as_str()) {
if !content_text.is_empty() { content_text.push('\n'); }
content_text.push_str(t);
}
}
}
}
}
}
}
let prompt_tokens = resp_json.get("usage").and_then(|u| u.get("prompt_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
let completion_tokens = resp_json.get("usage").and_then(|u| u.get("completion_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
let total_tokens = resp_json.get("usage").and_then(|u| u.get("total_tokens")).and_then(|v| v.as_u64()).unwrap_or(0) as u32;
Ok(ProviderResponse {
content: content_text,
reasoning_content: None,
tool_calls: None,
prompt_tokens,
completion_tokens,
reasoning_tokens: 0,
total_tokens,
cache_read_tokens: 0,
cache_write_tokens: 0,
model: request.model,
})
}
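The role and content-type remapping performed when building Responses API input (system becomes developer; text and image parts are renamed by role) can be condensed into two small helpers. This is an illustrative sketch, not part of the original provider code:

```rust
/// Hypothetical helper mirroring the role remapping above: the Responses
/// API prefers "developer" over "system" for newer models.
fn responses_role(role: &str) -> &str {
    if role == "system" { "developer" } else { role }
}

/// Hypothetical helper mirroring the part-type remapping above: assistant
/// parts become output_*, everything else becomes input_*; unknown part
/// types pass through unchanged, matching the `_ => {}` arm in the code.
fn responses_part_type<'a>(role: &str, part_type: &'a str) -> &'a str {
    match (role, part_type) {
        ("assistant", "text") => "output_text",
        (_, "text") => "input_text",
        ("assistant", "image_url") => "output_image",
        (_, "image_url") => "input_image",
        _ => part_type,
    }
}
```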
fn estimate_tokens(&self, request: &UnifiedRequest) -> Result<u32> {
Ok(crate::utils::tokens::estimate_request_tokens(&request.model, request))
}
fn calculate_cost(
&self,
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
cache_read_tokens: u32,
cache_write_tokens: u32,
registry: &crate::models::registry::ModelRegistry,
) -> f64 {
helpers::calculate_cost_with_registry(
model,
prompt_tokens,
completion_tokens,
cache_read_tokens,
cache_write_tokens,
registry,
&self.pricing,
0.15,
0.60,
)
}
async fn chat_completion_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
// Allow proactive routing to Responses API based on heuristic
let model_lc = request.model.to_lowercase();
if model_lc.contains("gpt-5") || model_lc.contains("codex") {
return self.chat_responses_stream(request).await;
}
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
let mut body = helpers::build_openai_body(&request, messages_json, true);
// Standard OpenAI cleanup
if let Some(obj) = body.as_object_mut() {
// stream_options.include_usage is supported by OpenAI for token usage in streaming
// Newer OpenAI models (o1, o3, gpt-5) require max_completion_tokens instead of max_tokens
if request.model.starts_with("o1-") || request.model.starts_with("o3-") || request.model.contains("gpt-5") {
if let Some(max_tokens) = obj.remove("max_tokens") {
obj.insert("max_completion_tokens".to_string(), max_tokens);
}
}
}
let url = format!("{}/chat/completions", self.config.base_url);
let api_key = self.api_key.clone();
let probe_client = self.client.clone();
let probe_body = body.clone();
let model = request.model.clone();
let es = reqwest_eventsource::EventSource::new(
self.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body),
)
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource: {}", e)))?;
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(reqwest_eventsource::Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: serde_json::Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse stream chunk: {}", e)))?;
if let Some(p_chunk) = helpers::parse_openai_stream_chunk(&chunk, &model, None) {
yield p_chunk?;
}
}
Ok(_) => continue,
Err(e) => {
// Attempt to probe for the actual error body
let probe_resp = probe_client
.post(&url)
.header("Authorization", format!("Bearer {}", api_key))
.json(&probe_body)
.send()
.await;
match probe_resp {
Ok(r) if !r.status().is_success() => {
let status = r.status();
let error_body = r.text().await.unwrap_or_default();
tracing::error!("OpenAI Stream Error Probe ({}): {}", status, error_body);
tracing::debug!("Offending OpenAI Request Body: {}", serde_json::to_string(&probe_body).unwrap_or_default());
Err(AppError::ProviderError(format!("OpenAI API error ({}): {}", status, error_body)))?;
}
Ok(_) => {
// Probe returned success? This is unexpected if the original stream failed.
Err(AppError::ProviderError(format!("Stream error (probe returned 200): {}", e)))?;
}
Err(probe_err) => {
// Probe itself failed
tracing::error!("OpenAI Stream Error Probe failed: {}", probe_err);
Err(AppError::ProviderError(format!("Stream error (probe failed: {}): {}", probe_err, e)))?;
}
}
}
}
}
};
Ok(Box::pin(stream))
}
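Both code paths above rename the token limit with the same model-name heuristic: reasoning-era models get `max_completion_tokens` on /chat/completions (and `max_output_tokens` on /responses), older chat models keep `max_tokens`. A standalone sketch of that heuristic, for illustration only:

```rust
/// Sketch of the parameter-name heuristic used above. Model-name
/// matching mirrors the checks in the code (contains "gpt-5", or
/// starts with "o1-"/"o3-"); it is not an exhaustive model list.
fn chat_max_tokens_param(model: &str) -> &'static str {
    if model.contains("gpt-5") || model.starts_with("o1-") || model.starts_with("o3-") {
        "max_completion_tokens"
    } else {
        "max_tokens"
    }
}
```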
async fn chat_responses_stream(
&self,
request: UnifiedRequest,
) -> Result<BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
// Build a structured input for the Responses API.
let messages_json = helpers::messages_to_openai_json(&request.messages).await?;
let mut input_parts = Vec::new();
for m in &messages_json {
let mut role = m["role"].as_str().unwrap_or("user").to_string();
// Newer models (gpt-5, o1) prefer "developer" over "system"
if role == "system" {
role = "developer".to_string();
}
let mut content = m.get("content").cloned().unwrap_or(serde_json::json!([]));
// Map content types based on role for Responses API
if let Some(content_array) = content.as_array_mut() {
for part in content_array {
if let Some(part_obj) = part.as_object_mut() {
if let Some(t) = part_obj.get("type").and_then(|v| v.as_str()) {
match t {
"text" => {
let new_type = if role == "assistant" { "output_text" } else { "input_text" };
part_obj.insert("type".to_string(), serde_json::json!(new_type));
}
"image_url" => {
// Assistant typically doesn't have image_url in history this way, but for safety:
let new_type = if role == "assistant" { "output_image" } else { "input_image" };
part_obj.insert("type".to_string(), serde_json::json!(new_type));
if let Some(img_url) = part_obj.remove("image_url") {
part_obj.insert("image".to_string(), img_url);
}
}
_ => {}
}
}
}
}
} else if let Some(text) = content.as_str() {
let new_type = if role == "assistant" { "output_text" } else { "input_text" };
content = serde_json::json!([{ "type": new_type, "text": text }]);
}
input_parts.push(serde_json::json!({
"role": role,
"content": content
}));
}
let mut body = serde_json::json!({
"model": request.model,
"input": input_parts,
"stream": true,
});
// Add standard parameters
if let Some(temp) = request.temperature {
body["temperature"] = serde_json::json!(temp);
}
// Newer models (gpt-5, o1) in Responses API use max_output_tokens
if let Some(max_tokens) = request.max_tokens {
if request.model.contains("gpt-5") || request.model.starts_with("o1-") || request.model.starts_with("o3-") {
body["max_output_tokens"] = serde_json::json!(max_tokens);
} else {
body["max_tokens"] = serde_json::json!(max_tokens);
}
}
let url = format!("{}/responses", self.config.base_url);
let api_key = self.api_key.clone();
let model = request.model.clone();
let probe_client = self.client.clone();
let probe_body = body.clone();
let es = reqwest_eventsource::EventSource::new(
self.client
.post(&url)
.header("Authorization", format!("Bearer {}", api_key))
.header("Accept", "text/event-stream")
.json(&body),
)
.map_err(|e| AppError::ProviderError(format!("Failed to create EventSource for Responses API: {}", e)))?;
let stream = async_stream::try_stream! {
let mut es = es;
while let Some(event) = es.next().await {
match event {
Ok(reqwest_eventsource::Event::Message(msg)) => {
if msg.data == "[DONE]" {
break;
}
let chunk: serde_json::Value = serde_json::from_str(&msg.data)
.map_err(|e| AppError::ProviderError(format!("Failed to parse Responses stream chunk: {}", e)))?;
// Try standard OpenAI parsing first (choices/usage)
if let Some(p_chunk) = helpers::parse_openai_stream_chunk(&chunk, &model, None) {
yield p_chunk?;
} else {
// Responses API specific parsing for streaming
let mut content = String::new();
let mut finish_reason = None;
let event_type = chunk.get("type").and_then(|v| v.as_str()).unwrap_or("");
match event_type {
"response.output_text.delta" => {
if let Some(delta) = chunk.get("delta").and_then(|v| v.as_str()) {
content.push_str(delta);
}
}
"response.output_text.done" => {
if let Some(text) = chunk.get("text").and_then(|v| v.as_str()) {
// Some implementations send the full text at the end
// We usually prefer deltas, but if we haven't seen them, this is the fallback.
// However, if we're already yielding deltas, we might not want this.
// For now, let's just use it as a signal that we're done.
finish_reason = Some("stop".to_string());
}
}
"response.done" => {
finish_reason = Some("stop".to_string());
}
_ => {
// Fallback to older nested structure if present
if let Some(output) = chunk.get("output").and_then(|o| o.as_array()) {
for out in output {
if let Some(contents) = out.get("content").and_then(|c| c.as_array()) {
for item in contents {
if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
content.push_str(text);
} else if let Some(delta) = item.get("delta").and_then(|d| d.get("text")).and_then(|t| t.as_str()) {
content.push_str(delta);
}
}
}
}
}
}
}
if !content.is_empty() || finish_reason.is_some() {
yield ProviderStreamChunk {
content,
reasoning_content: None,
finish_reason,
tool_calls: None,
model: model.clone(),
usage: None,
};
}
}
}
Ok(_) => continue,
Err(e) => {
// Attempt to probe for the actual error body
let probe_resp = probe_client
.post(&url)
.header("Authorization", format!("Bearer {}", api_key))
.header("Accept", "application/json") // Ask for JSON during probe
.json(&probe_body)
.send()
.await;
match probe_resp {
Ok(r) => {
let status = r.status();
let body = r.text().await.unwrap_or_default();
if status.is_success() {
tracing::warn!("Responses stream ended prematurely but probe returned 200 OK. Body: {}", body);
Err(AppError::ProviderError(format!("Responses stream ended (server sent 200 OK with body: {})", body)))?;
} else {
tracing::error!("OpenAI Responses Stream Error Probe ({}): {}", status, body);
Err(AppError::ProviderError(format!("OpenAI Responses API error ({}): {}", status, body)))?;
}
}
Err(probe_err) => {
tracing::error!("OpenAI Responses Stream Error Probe failed: {}", probe_err);
Err(AppError::ProviderError(format!("Responses stream error (probe failed: {}): {}", probe_err, e)))?;
}
}
}
}
}
};
Ok(Box::pin(stream))
}
}
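The Responses-API SSE handling above reduces to a small event classifier: delta events carry incremental text, the two "done" events only signal completion, and anything else falls through to the legacy nested parsing. A condensed sketch of that dispatch (illustrative, not the original code):

```rust
/// Condensed model of the Responses SSE dispatch above.
#[derive(Debug, PartialEq)]
enum ResponsesEvent<'a> {
    /// Incremental text from response.output_text.delta
    Delta(&'a str),
    /// Completion signal (response.output_text.done / response.done)
    Finished,
    /// Anything else: fall back to legacy nested parsing
    Other,
}

fn classify_responses_event<'a>(event_type: &str, delta: Option<&'a str>) -> ResponsesEvent<'a> {
    match event_type {
        "response.output_text.delta" => ResponsesEvent::Delta(delta.unwrap_or("")),
        "response.output_text.done" | "response.done" => ResponsesEvent::Finished,
        _ => ResponsesEvent::Other,
    }
}
```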


@@ -1,353 +0,0 @@
//! Rate limiting and circuit breaking for LLM proxy
//!
//! This module provides:
//! 1. Per-client rate limiting using governor crate
//! 2. Provider circuit breaking to handle API failures
//! 3. Global rate limiting for overall system protection
use anyhow::Result;
use governor::{Quota, RateLimiter, DefaultDirectRateLimiter};
use std::collections::HashMap;
use std::num::NonZeroU32;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{info, warn};
type GovRateLimiter = DefaultDirectRateLimiter;
/// Rate limiter configuration
#[derive(Debug, Clone)]
pub struct RateLimiterConfig {
/// Requests per minute per client
pub requests_per_minute: u32,
/// Burst size (maximum burst capacity)
pub burst_size: u32,
/// Global requests per minute (across all clients)
pub global_requests_per_minute: u32,
}
impl Default for RateLimiterConfig {
fn default() -> Self {
Self {
requests_per_minute: 60, // 1 request per second per client
burst_size: 10, // Allow bursts of up to 10 requests
global_requests_per_minute: 600, // 10 requests per second globally
}
}
}
/// Circuit breaker state
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CircuitState {
Closed, // Normal operation
Open, // Circuit is open, requests fail fast
HalfOpen, // Testing if service has recovered
}
/// Circuit breaker configuration
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
/// Failure threshold to open circuit
pub failure_threshold: u32,
/// Time window for failure counting (seconds)
pub failure_window_secs: u64,
/// Time to wait before trying half-open state (seconds)
pub reset_timeout_secs: u64,
/// Success threshold to close circuit
pub success_threshold: u32,
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
failure_threshold: 5, // 5 failures
failure_window_secs: 60, // within 60 seconds
reset_timeout_secs: 30, // wait 30 seconds before half-open
success_threshold: 3, // 3 successes to close circuit
}
}
}
/// Circuit breaker for a provider
#[derive(Debug)]
pub struct ProviderCircuitBreaker {
state: CircuitState,
failure_count: u32,
success_count: u32,
last_failure_time: Option<std::time::Instant>,
last_state_change: std::time::Instant,
config: CircuitBreakerConfig,
}
impl ProviderCircuitBreaker {
pub fn new(config: CircuitBreakerConfig) -> Self {
Self {
state: CircuitState::Closed,
failure_count: 0,
success_count: 0,
last_failure_time: None,
last_state_change: std::time::Instant::now(),
config,
}
}
/// Check if request is allowed
pub fn allow_request(&mut self) -> bool {
match self.state {
CircuitState::Closed => true,
CircuitState::Open => {
// Check if reset timeout has passed
let elapsed = self.last_state_change.elapsed();
if elapsed.as_secs() >= self.config.reset_timeout_secs {
self.state = CircuitState::HalfOpen;
self.last_state_change = std::time::Instant::now();
info!("Circuit breaker transitioning to half-open state");
true
} else {
false
}
}
CircuitState::HalfOpen => true,
}
}
/// Record a successful request
pub fn record_success(&mut self) {
match self.state {
CircuitState::Closed => {
// Reset failure count on success
self.failure_count = 0;
self.last_failure_time = None;
}
CircuitState::HalfOpen => {
self.success_count += 1;
if self.success_count >= self.config.success_threshold {
self.state = CircuitState::Closed;
self.success_count = 0;
self.failure_count = 0;
self.last_state_change = std::time::Instant::now();
info!("Circuit breaker closed after successful requests");
}
}
CircuitState::Open => {
// Should not happen, but handle gracefully
}
}
}
/// Record a failed request
pub fn record_failure(&mut self) {
let now = std::time::Instant::now();
// Check if failure window has expired
if let Some(last_failure) = self.last_failure_time
&& now.duration_since(last_failure).as_secs() > self.config.failure_window_secs
{
// Reset failure count if window expired
self.failure_count = 0;
}
self.failure_count += 1;
self.last_failure_time = Some(now);
if self.failure_count >= self.config.failure_threshold && self.state == CircuitState::Closed {
self.state = CircuitState::Open;
self.last_state_change = now;
warn!("Circuit breaker opened due to {} failures", self.failure_count);
} else if self.state == CircuitState::HalfOpen {
// Failure in half-open state, go back to open
self.state = CircuitState::Open;
self.success_count = 0;
self.last_state_change = now;
warn!("Circuit breaker re-opened after failure in half-open state");
}
}
/// Get current state
pub fn state(&self) -> CircuitState {
self.state
}
}
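With the default config above, the breaker opens after 5 failures, half-opens once the reset timeout elapses, and closes again after 3 successes. A timing-free model of those transitions (the reset timeout is replaced by an explicit `timeout()` call; this is an illustration, not the production type):

```rust
#[derive(Debug, PartialEq, Clone, Copy)]
enum State { Closed, Open, HalfOpen }

/// Condensed breaker: 5 failures open, any failure in half-open
/// re-opens, 3 half-open successes close. Timing is elided.
struct MiniBreaker { state: State, failures: u32, successes: u32 }

impl MiniBreaker {
    fn new() -> Self { Self { state: State::Closed, failures: 0, successes: 0 } }

    fn fail(&mut self) {
        self.failures += 1;
        if self.state == State::HalfOpen || self.failures >= 5 {
            self.state = State::Open;
            self.successes = 0;
        }
    }

    fn succeed(&mut self) {
        match self.state {
            State::Closed => self.failures = 0,
            State::HalfOpen => {
                self.successes += 1;
                if self.successes >= 3 {
                    self.state = State::Closed;
                    self.failures = 0;
                }
            }
            State::Open => {}
        }
    }

    /// Stand-in for the reset timeout elapsing.
    fn timeout(&mut self) {
        if self.state == State::Open {
            self.state = State::HalfOpen;
            self.successes = 0;
        }
    }
}
```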
/// Rate limiting and circuit breaking manager
#[derive(Debug)]
pub struct RateLimitManager {
client_buckets: Arc<RwLock<HashMap<String, GovRateLimiter>>>,
global_bucket: Arc<GovRateLimiter>,
circuit_breakers: Arc<RwLock<HashMap<String, ProviderCircuitBreaker>>>,
config: RateLimiterConfig,
circuit_config: CircuitBreakerConfig,
}
impl RateLimitManager {
pub fn new(config: RateLimiterConfig, circuit_config: CircuitBreakerConfig) -> Self {
// Create global rate limiter quota
// Use a much larger burst size for the global bucket to handle concurrent dashboard load
let global_burst = config.global_requests_per_minute / 6; // e.g., 100 for 600 req/min
let global_quota = Quota::per_minute(
NonZeroU32::new(config.global_requests_per_minute).expect("global_requests_per_minute must be positive")
)
.allow_burst(NonZeroU32::new(global_burst).expect("global_burst must be positive"));
let global_bucket = RateLimiter::direct(global_quota);
Self {
client_buckets: Arc::new(RwLock::new(HashMap::new())),
global_bucket: Arc::new(global_bucket),
circuit_breakers: Arc::new(RwLock::new(HashMap::new())),
config,
circuit_config,
}
}
/// Check if a client request is allowed
pub async fn check_client_request(&self, client_id: &str) -> Result<bool> {
// Check global rate limit first (1 token per request)
if self.global_bucket.check().is_err() {
warn!("Global rate limit exceeded");
return Ok(false);
}
// Check client-specific rate limit
let mut buckets = self.client_buckets.write().await;
let bucket = buckets.entry(client_id.to_string()).or_insert_with(|| {
let quota = Quota::per_minute(
NonZeroU32::new(self.config.requests_per_minute).expect("requests_per_minute must be positive")
)
.allow_burst(NonZeroU32::new(self.config.burst_size).expect("burst_size must be positive"));
RateLimiter::direct(quota)
});
Ok(bucket.check().is_ok())
}
/// Check if provider requests are allowed (circuit breaker)
pub async fn check_provider_request(&self, provider_name: &str) -> Result<bool> {
let mut breakers = self.circuit_breakers.write().await;
let breaker = breakers
.entry(provider_name.to_string())
.or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone()));
Ok(breaker.allow_request())
}
/// Record provider success
pub async fn record_provider_success(&self, provider_name: &str) {
let mut breakers = self.circuit_breakers.write().await;
if let Some(breaker) = breakers.get_mut(provider_name) {
breaker.record_success();
}
}
/// Record provider failure
pub async fn record_provider_failure(&self, provider_name: &str) {
let mut breakers = self.circuit_breakers.write().await;
let breaker = breakers
.entry(provider_name.to_string())
.or_insert_with(|| ProviderCircuitBreaker::new(self.circuit_config.clone()));
breaker.record_failure();
}
/// Get provider circuit state
pub async fn get_provider_state(&self, provider_name: &str) -> CircuitState {
let breakers = self.circuit_breakers.read().await;
breakers
.get(provider_name)
.map(|b| b.state())
.unwrap_or(CircuitState::Closed)
}
}
/// Axum middleware for rate limiting
pub mod middleware {
use super::*;
use crate::errors::AppError;
use crate::state::AppState;
use crate::auth::AuthInfo;
use axum::{
extract::{Request, State},
middleware::Next,
response::Response,
};
use sqlx;
/// Rate limiting middleware
pub async fn rate_limit_middleware(
State(state): State<AppState>,
mut request: Request,
next: Next,
) -> Result<Response, AppError> {
// Extract token synchronously from headers (avoids holding &Request across await)
let token = extract_bearer_token(&request);
// Resolve client_id and populate AuthInfo: DB token lookup, then prefix fallback
let auth_info = resolve_auth_info(token, &state).await;
let client_id = auth_info.client_id.clone();
// Check rate limits
if !state.rate_limit_manager.check_client_request(&client_id).await? {
return Err(AppError::RateLimitError("Rate limit exceeded".to_string()));
}
// Store AuthInfo in request extensions for extractors and downstream handlers
request.extensions_mut().insert(auth_info);
Ok(next.run(request).await)
}
/// Synchronously extract bearer token from request headers
fn extract_bearer_token(request: &Request) -> Option<String> {
request.headers().get("Authorization")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.strip_prefix("Bearer "))
.map(|t| t.to_string())
}
/// Resolve auth info: try DB token first, then fall back to token-prefix derivation
async fn resolve_auth_info(token: Option<String>, state: &AppState) -> AuthInfo {
if let Some(token) = token {
// Try DB token lookup first
match sqlx::query_scalar::<_, String>(
"UPDATE client_tokens SET last_used_at = CURRENT_TIMESTAMP WHERE token = ? AND is_active = TRUE RETURNING client_id",
)
.bind(&token)
.fetch_optional(&state.db_pool)
.await
{
Ok(Some(cid)) => {
return AuthInfo {
token,
client_id: cid,
};
}
Err(e) => {
warn!("DB error during token lookup: {}", e);
}
_ => {}
}
// Fallback to token-prefix derivation (env tokens / permissive mode).
// Take the first 8 characters rather than bytes so a multi-byte
// character at the boundary cannot cause a slice panic.
let prefix: String = token.chars().take(8).collect();
let client_id = format!("client_{}", prefix);
return AuthInfo { token, client_id };
}
// No token — anonymous
AuthInfo {
token: String::new(),
client_id: "anonymous".to_string(),
}
}
/// Circuit breaker middleware for provider requests
pub async fn circuit_breaker_middleware(provider_name: &str, state: &AppState) -> Result<(), AppError> {
if !state.rate_limit_manager.check_provider_request(provider_name).await? {
return Err(AppError::ProviderError(format!(
"Provider {} is currently unavailable (circuit breaker open)",
provider_name
)));
}
Ok(())
}
}
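The quotas above (per-client 60 req/min with burst 10, global 600 req/min with burst rpm/6) amount to token-bucket arithmetic: a bucket holds up to `burst` tokens and refills at `requests_per_minute / 60` tokens per second. Governor implements this via GCRA, but a plain token bucket approximates the behavior; a deterministic sketch with explicit time steps (not the governor implementation):

```rust
/// Simplified token bucket mirroring the quotas above: capacity = burst,
/// refilled at requests_per_minute / 60 tokens per second. Governor's
/// GCRA behaves similarly but is not identical.
struct TokenBucket { tokens: f64, capacity: f64, refill_per_sec: f64 }

impl TokenBucket {
    fn new(requests_per_minute: u32, burst: u32) -> Self {
        Self {
            tokens: burst as f64,
            capacity: burst as f64,
            refill_per_sec: requests_per_minute as f64 / 60.0,
        }
    }

    /// Advance time and refill, capped at capacity.
    fn tick(&mut self, elapsed_secs: f64) {
        self.tokens = (self.tokens + elapsed_secs * self.refill_per_sec).min(self.capacity);
    }

    /// Consume one token if available.
    fn check(&mut self) -> bool {
        if self.tokens >= 1.0 { self.tokens -= 1.0; true } else { false }
    }
}
```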


@@ -1,482 +0,0 @@
use axum::{
Json, Router,
extract::State,
response::IntoResponse,
response::sse::{Event, Sse},
routing::{get, post},
};
use axum::http::{header, HeaderValue};
use tower_http::{
limit::RequestBodyLimitLayer,
set_header::SetResponseHeaderLayer,
};
use futures::StreamExt;
use std::sync::Arc;
use uuid::Uuid;
use tracing::{info, warn};
use crate::{
auth::AuthenticatedClient,
errors::AppError,
models::{
ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse, ChatMessage,
ChatStreamChoice, ChatStreamDelta, Usage,
},
rate_limiting,
state::AppState,
};
pub fn router(state: AppState) -> Router {
// Security headers
let csp_header: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::CONTENT_SECURITY_POLICY,
"default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; connect-src 'self' ws:;"
.parse()
.unwrap(),
);
let x_frame_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_FRAME_OPTIONS,
"DENY".parse().unwrap(),
);
let x_content_type_options: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::X_CONTENT_TYPE_OPTIONS,
"nosniff".parse().unwrap(),
);
let strict_transport_security: SetResponseHeaderLayer<HeaderValue> = SetResponseHeaderLayer::overriding(
header::STRICT_TRANSPORT_SECURITY,
"max-age=31536000; includeSubDomains".parse().unwrap(),
);
Router::new()
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/models", get(list_models))
.layer(RequestBodyLimitLayer::new(10 * 1024 * 1024)) // 10 MB limit
.layer(csp_header)
.layer(x_frame_options)
.layer(x_content_type_options)
.layer(strict_transport_security)
.layer(axum::middleware::from_fn_with_state(
state.clone(),
rate_limiting::middleware::rate_limit_middleware,
))
.with_state(state)
}
/// GET /v1/models — OpenAI-compatible model listing.
/// Returns all models from enabled providers so clients like Open WebUI can
/// discover which models are available through the proxy.
async fn list_models(
State(state): State<AppState>,
_auth: AuthenticatedClient,
) -> Result<Json<serde_json::Value>, AppError> {
let registry = &state.model_registry;
let providers = state.provider_manager.get_all_providers().await;
let mut models = Vec::new();
for provider in &providers {
let provider_name = provider.name();
// Map internal provider names to registry provider IDs
let registry_key = match provider_name {
"gemini" => "google",
"grok" => "xai",
_ => provider_name,
};
// Find this provider's models in the registry
if let Some(provider_info) = registry.providers.get(registry_key) {
for (model_id, meta) in &provider_info.models {
// Skip disabled models via the config cache
if let Some(cfg) = state.model_config_cache.get(model_id).await {
if !cfg.enabled {
continue;
}
}
models.push(serde_json::json!({
"id": model_id,
"object": "model",
"created": 0,
"owned_by": provider_name,
"name": meta.name,
}));
}
}
// For Ollama, models are configured in the TOML, not the registry
if provider_name == "ollama" {
for model_id in &state.config.providers.ollama.models {
models.push(serde_json::json!({
"id": model_id,
"object": "model",
"created": 0,
"owned_by": "ollama",
}));
}
}
}
Ok(Json(serde_json::json!({
"object": "list",
"data": models
})))
}
async fn get_model_cost(
model: &str,
prompt_tokens: u32,
completion_tokens: u32,
cache_read_tokens: u32,
cache_write_tokens: u32,
provider: &Arc<dyn crate::providers::Provider>,
state: &AppState,
) -> f64 {
// Check in-memory cache for cost overrides (no SQLite hit)
if let Some(cached) = state.model_config_cache.get(model).await {
if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
// Manual overrides logic: if cache rates are provided, use cache-aware formula.
// Formula: (non_cached_prompt * input_rate) + (cache_read * read_rate) + (cache_write * write_rate) + (completion * output_rate)
let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
let mut total = (non_cached_prompt as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0);
if let Some(cr) = cached.cache_read_cost_per_m {
total += cache_read_tokens as f64 * cr / 1_000_000.0;
} else {
// No manual cache_read rate — charge cached tokens at full input rate (backwards compatibility)
total += cache_read_tokens as f64 * p / 1_000_000.0;
}
if let Some(cw) = cached.cache_write_cost_per_m {
total += cache_write_tokens as f64 * cw / 1_000_000.0;
}
return total;
}
}
// Fallback to provider's registry-based calculation (cache-aware)
provider.calculate_cost(model, prompt_tokens, completion_tokens, cache_read_tokens, cache_write_tokens, &state.model_registry)
}
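The override branch above prices non-cached prompt tokens at the input rate, cached reads and writes at their own rates when configured (falling back to the input rate for reads), and completion tokens at the output rate, all per million tokens. A standalone sketch of that arithmetic; the rates in the test are made-up examples, not real pricing:

```rust
/// Cache-aware cost in USD: rates are per 1M tokens. Mirrors the
/// override formula above, including the fallback that charges cached
/// reads at the full input rate when no read rate is configured.
fn override_cost(
    prompt: u32, completion: u32, cache_read: u32, cache_write: u32,
    input_rate: f64, output_rate: f64,
    cache_read_rate: Option<f64>, cache_write_rate: Option<f64>,
) -> f64 {
    let non_cached = prompt.saturating_sub(cache_read);
    let mut total = non_cached as f64 * input_rate / 1_000_000.0
        + completion as f64 * output_rate / 1_000_000.0;
    // Without a dedicated read rate, cached tokens fall back to the input rate.
    total += cache_read as f64 * cache_read_rate.unwrap_or(input_rate) / 1_000_000.0;
    if let Some(cw) = cache_write_rate {
        total += cache_write as f64 * cw / 1_000_000.0;
    }
    total
}
```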
async fn chat_completions(
State(state): State<AppState>,
auth: AuthenticatedClient,
Json(mut request): Json<ChatCompletionRequest>,
) -> Result<axum::response::Response, AppError> {
let client_id = auth.client_id.clone();
let token = auth.token.clone();
// Verify token if env tokens are configured
if !state.auth_tokens.is_empty() && !state.auth_tokens.contains(&token) {
// Not an env token. DB tokens resolve to a real client_id; a "client_" prefix means the DB lookup also failed, so reject.
if client_id.starts_with("client_") {
return Err(AppError::AuthError("Invalid authentication token".to_string()));
}
}
let start_time = std::time::Instant::now();
let model = request.model.clone();
info!("Chat completion request from client {} for model {}", client_id, model);
// Check if model is enabled via in-memory cache (no SQLite hit)
let cached_config = state.model_config_cache.get(&model).await;
let (model_enabled, model_mapping) = match cached_config {
Some(cfg) => (cfg.enabled, cfg.mapping),
None => (true, None),
};
if !model_enabled {
return Err(AppError::ValidationError(format!(
"Model {} is currently disabled",
model
)));
}
// Apply mapping if present
if let Some(target_model) = model_mapping {
info!("Mapping model {} to {}", model, target_model);
request.model = target_model;
}
// Find appropriate provider for the model
let provider = state
.provider_manager
.get_provider_for_model(&request.model)
.await
.ok_or_else(|| AppError::ProviderError(format!("No provider found for model: {}", request.model)))?;
let provider_name = provider.name().to_string();
// Check circuit breaker for this provider
rate_limiting::middleware::circuit_breaker_middleware(&provider_name, &state).await?;
// Convert to unified request format
let mut unified_request =
crate::models::UnifiedRequest::try_from(request).map_err(|e| AppError::ValidationError(e.to_string()))?;
// Set client_id from authentication
unified_request.client_id = client_id.clone();
// Hydrate images if present
if unified_request.has_images {
unified_request
.hydrate_images()
.await
.map_err(|e| AppError::ValidationError(format!("Failed to process images: {}", e)))?;
}
let has_images = unified_request.has_images;
// Measure proxy overhead (time spent before sending to upstream provider)
let proxy_overhead = start_time.elapsed();
// Check if streaming is requested
if unified_request.stream {
// Estimate prompt tokens for logging later
let prompt_tokens = crate::utils::tokens::estimate_request_tokens(&model, &unified_request);
// Handle streaming response
// Allow provider-specific routing for streaming too
let use_responses = provider.name() == "openai"
&& crate::utils::registry::model_prefers_responses(&state.model_registry, &unified_request.model);
let stream_result = if use_responses {
provider.chat_responses_stream(unified_request).await
} else {
provider.chat_completion_stream(unified_request).await
};
match stream_result {
Ok(stream) => {
// Record provider success
state.rate_limit_manager.record_provider_success(&provider_name).await;
info!(
"Streaming started for {} (proxy overhead: {}ms)",
model,
proxy_overhead.as_millis()
);
// Wrap with AggregatingStream for token counting and database logging
let aggregating_stream = crate::utils::streaming::AggregatingStream::new(
stream,
crate::utils::streaming::StreamConfig {
client_id: client_id.clone(),
provider: provider.clone(),
model: model.clone(),
prompt_tokens,
has_images,
logger: state.request_logger.clone(),
model_registry: state.model_registry.clone(),
model_config_cache: state.model_config_cache.clone(),
},
);
// Create the SSE stream
let stream_id = format!("chatcmpl-{}", Uuid::new_v4());
let stream_created = chrono::Utc::now().timestamp() as u64;
let stream_id_sse = stream_id.clone();
// Build stream that yields events wrapped in Result
let stream = async_stream::stream! {
let mut aggregator = Box::pin(aggregating_stream);
let mut first_chunk = true;
while let Some(chunk_result) = aggregator.next().await {
match chunk_result {
Ok(chunk) => {
let role = if first_chunk {
first_chunk = false;
Some("assistant".to_string())
} else {
None
};
let response = ChatCompletionStreamResponse {
id: stream_id_sse.clone(),
object: "chat.completion.chunk".to_string(),
created: stream_created,
model: chunk.model.clone(),
choices: vec![ChatStreamChoice {
index: 0,
delta: ChatStreamDelta {
role,
content: Some(chunk.content),
reasoning_content: chunk.reasoning_content,
tool_calls: chunk.tool_calls,
},
finish_reason: chunk.finish_reason,
}],
usage: chunk.usage.as_ref().map(|u| crate::models::Usage {
prompt_tokens: u.prompt_tokens,
completion_tokens: u.completion_tokens,
total_tokens: u.total_tokens,
reasoning_tokens: if u.reasoning_tokens > 0 { Some(u.reasoning_tokens) } else { None },
cache_read_tokens: if u.cache_read_tokens > 0 { Some(u.cache_read_tokens) } else { None },
cache_write_tokens: if u.cache_write_tokens > 0 { Some(u.cache_write_tokens) } else { None },
}),
};
// Use axum's Event directly, wrap in Ok
match Event::default().json_data(response) {
Ok(event) => yield Ok::<_, crate::errors::AppError>(event),
Err(e) => {
warn!("Failed to serialize SSE: {}", e);
}
}
}
Err(e) => {
warn!("Stream error: {}", e);
}
}
}
// Yield [DONE] at the end
yield Ok::<_, crate::errors::AppError>(Event::default().data("[DONE]"));
};
Ok(Sse::new(stream).into_response())
}
Err(e) => {
// Record provider failure
state.rate_limit_manager.record_provider_failure(&provider_name).await;
// Log failed request
let duration = start_time.elapsed();
warn!("Streaming request failed after {:?}: {}", duration, e);
Err(e)
}
}
} else {
// Handle non-streaming response
// Allow provider-specific routing: for OpenAI, some models prefer the
// Responses API (/v1/responses). Use the model registry heuristic to
// choose chat_responses vs chat_completion automatically.
let use_responses = provider.name() == "openai"
&& crate::utils::registry::model_prefers_responses(&state.model_registry, &unified_request.model);
let result = if use_responses {
provider.chat_responses(unified_request).await
} else {
provider.chat_completion(unified_request).await
};
match result {
Ok(response) => {
// Record provider success
state.rate_limit_manager.record_provider_success(&provider_name).await;
let duration = start_time.elapsed();
let cost = get_model_cost(
&response.model,
response.prompt_tokens,
response.completion_tokens,
response.cache_read_tokens,
response.cache_write_tokens,
&provider,
&state,
)
.await;
// Log request to database
state.request_logger.log_request(crate::logging::RequestLog {
timestamp: chrono::Utc::now(),
client_id: client_id.clone(),
provider: provider_name.clone(),
model: response.model.clone(),
prompt_tokens: response.prompt_tokens,
completion_tokens: response.completion_tokens,
reasoning_tokens: response.reasoning_tokens,
total_tokens: response.total_tokens,
cache_read_tokens: response.cache_read_tokens,
cache_write_tokens: response.cache_write_tokens,
cost,
has_images,
status: "success".to_string(),
error_message: None,
duration_ms: duration.as_millis() as u64,
});
// Convert ProviderResponse to ChatCompletionResponse
let finish_reason = if response.tool_calls.is_some() {
"tool_calls".to_string()
} else {
"stop".to_string()
};
let chat_response = ChatCompletionResponse {
id: format!("chatcmpl-{}", Uuid::new_v4()),
object: "chat.completion".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model: response.model,
choices: vec![ChatChoice {
index: 0,
message: ChatMessage {
role: "assistant".to_string(),
content: crate::models::MessageContent::Text {
content: response.content,
},
reasoning_content: response.reasoning_content,
tool_calls: response.tool_calls,
name: None,
tool_call_id: None,
},
finish_reason: Some(finish_reason),
}],
usage: Some(Usage {
prompt_tokens: response.prompt_tokens,
completion_tokens: response.completion_tokens,
total_tokens: response.total_tokens,
reasoning_tokens: if response.reasoning_tokens > 0 { Some(response.reasoning_tokens) } else { None },
cache_read_tokens: if response.cache_read_tokens > 0 { Some(response.cache_read_tokens) } else { None },
cache_write_tokens: if response.cache_write_tokens > 0 { Some(response.cache_write_tokens) } else { None },
}),
};
// Log successful request with proxy overhead breakdown
let upstream_ms = (duration.as_millis() as u64).saturating_sub(proxy_overhead.as_millis() as u64);
info!(
"Request completed in {:?} (proxy: {}ms, upstream: {}ms)",
duration,
proxy_overhead.as_millis(),
upstream_ms
);
Ok(Json(chat_response).into_response())
}
Err(e) => {
// Record provider failure
state.rate_limit_manager.record_provider_failure(&provider_name).await;
// Log failed request to database
let duration = start_time.elapsed();
state.request_logger.log_request(crate::logging::RequestLog {
timestamp: chrono::Utc::now(),
client_id: client_id.clone(),
provider: provider_name.clone(),
model: model.clone(),
prompt_tokens: 0,
completion_tokens: 0,
reasoning_tokens: 0,
total_tokens: 0,
cache_read_tokens: 0,
cache_write_tokens: 0,
cost: 0.0,
has_images: false,
status: "error".to_string(),
error_message: Some(e.to_string()),
duration_ms: duration.as_millis() as u64,
});
warn!("Request failed after {:?}: {}", duration, e);
Err(e)
}
}
}
}
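The streaming branch above emits OpenAI-style SSE frames and terminates the stream with a literal `data: [DONE]` sentinel. As a minimal sketch of the client side of that convention (pure stdlib, operating on an already-buffered hypothetical response body rather than a live connection):

```rust
/// Collect the JSON payloads from OpenAI-style SSE text: each frame is a
/// "data: {json}" line, and a "data: [DONE]" line terminates the stream.
fn collect_sse_payloads(raw: &str) -> Vec<String> {
    let mut payloads = Vec::new();
    for line in raw.lines() {
        // Blank keep-alive lines and non-data lines are skipped
        let Some(data) = line.strip_prefix("data: ") else { continue };
        if data == "[DONE]" {
            break; // end-of-stream sentinel
        }
        payloads.push(data.to_string());
    }
    payloads
}
```

A real client would parse each payload as a `chat.completion.chunk` object and concatenate the deltas.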

View File

@@ -1,133 +0,0 @@
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use tracing::warn;
use crate::{
client::ClientManager, config::AppConfig, database::DbPool, logging::RequestLogger,
models::registry::ModelRegistry, providers::ProviderManager, rate_limiting::RateLimitManager,
};
/// Cached model configuration entry
#[derive(Debug, Clone)]
pub struct CachedModelConfig {
pub enabled: bool,
pub mapping: Option<String>,
pub prompt_cost_per_m: Option<f64>,
pub completion_cost_per_m: Option<f64>,
pub cache_read_cost_per_m: Option<f64>,
pub cache_write_cost_per_m: Option<f64>,
}
/// In-memory cache for model_configs table.
/// Refreshes periodically to avoid hitting SQLite on every request.
#[derive(Clone)]
pub struct ModelConfigCache {
cache: Arc<RwLock<HashMap<String, CachedModelConfig>>>,
db_pool: DbPool,
}
impl ModelConfigCache {
pub fn new(db_pool: DbPool) -> Self {
Self {
cache: Arc::new(RwLock::new(HashMap::new())),
db_pool,
}
}
/// Load all model configs from the database into cache
pub async fn refresh(&self) {
match sqlx::query_as::<_, (String, bool, Option<String>, Option<f64>, Option<f64>, Option<f64>, Option<f64>)>(
"SELECT id, enabled, mapping, prompt_cost_per_m, completion_cost_per_m, cache_read_cost_per_m, cache_write_cost_per_m FROM model_configs",
)
.fetch_all(&self.db_pool)
.await
{
Ok(rows) => {
let mut map = HashMap::with_capacity(rows.len());
for (id, enabled, mapping, prompt_cost, completion_cost, cache_read_cost, cache_write_cost) in rows {
map.insert(
id,
CachedModelConfig {
enabled,
mapping,
prompt_cost_per_m: prompt_cost,
completion_cost_per_m: completion_cost,
cache_read_cost_per_m: cache_read_cost,
cache_write_cost_per_m: cache_write_cost,
},
);
}
*self.cache.write().await = map;
}
Err(e) => {
warn!("Failed to refresh model config cache: {}", e);
}
}
}
/// Get a cached model config. Returns None if not in cache (model is unconfigured).
pub async fn get(&self, model: &str) -> Option<CachedModelConfig> {
self.cache.read().await.get(model).cloned()
}
/// Invalidate the cache; call this after dashboard writes to model_configs
pub async fn invalidate(&self) {
self.refresh().await;
}
/// Start a background task that refreshes the cache every `interval` seconds
pub fn start_refresh_task(self, interval_secs: u64) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs));
loop {
interval.tick().await;
self.refresh().await;
}
});
}
}
/// Shared application state
#[derive(Clone)]
pub struct AppState {
pub config: Arc<AppConfig>,
pub provider_manager: ProviderManager,
pub db_pool: DbPool,
pub rate_limit_manager: Arc<RateLimitManager>,
pub client_manager: Arc<ClientManager>,
pub request_logger: Arc<RequestLogger>,
pub model_registry: Arc<ModelRegistry>,
pub model_config_cache: ModelConfigCache,
pub dashboard_tx: broadcast::Sender<serde_json::Value>,
pub auth_tokens: Vec<String>,
}
impl AppState {
pub fn new(
config: Arc<AppConfig>,
provider_manager: ProviderManager,
db_pool: DbPool,
rate_limit_manager: RateLimitManager,
model_registry: ModelRegistry,
auth_tokens: Vec<String>,
) -> Self {
let client_manager = Arc::new(ClientManager::new(db_pool.clone()));
let (dashboard_tx, _) = broadcast::channel(100);
let request_logger = Arc::new(RequestLogger::new(db_pool.clone(), dashboard_tx.clone()));
let model_config_cache = ModelConfigCache::new(db_pool.clone());
Self {
config,
provider_manager,
db_pool,
rate_limit_manager: Arc::new(rate_limit_manager),
client_manager,
request_logger,
model_registry: Arc::new(model_registry),
model_config_cache,
dashboard_tx,
auth_tokens,
}
}
}
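`ModelConfigCache` follows a swap-on-refresh pattern: readers take a cheap read lock, and `refresh` builds the full replacement map before swapping it in under one write lock, so readers never see a half-updated cache. A minimal synchronous sketch of that pattern, using `std::sync::RwLock` in place of tokio's async lock and plain tuples in place of the SQLite rows:

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

/// Sketch of the swap-on-refresh cache: `refresh` replaces the whole map
/// atomically (from a reader's point of view) instead of mutating in place.
#[derive(Clone)]
struct Cache {
    inner: Arc<RwLock<HashMap<String, bool>>>,
}

impl Cache {
    fn new() -> Self {
        Self { inner: Arc::new(RwLock::new(HashMap::new())) }
    }

    /// Build the new map first, then swap under a single write lock.
    fn refresh(&self, rows: Vec<(String, bool)>) {
        let map: HashMap<_, _> = rows.into_iter().collect();
        *self.inner.write().unwrap() = map;
    }

    fn get(&self, key: &str) -> Option<bool> {
        self.inner.read().unwrap().get(key).copied()
    }
}
```

The real implementation additionally runs `refresh` on a timer via `tokio::spawn`, trading a bounded staleness window for zero SQLite reads on the request path.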

View File

@@ -1,171 +0,0 @@
use aes_gcm::{
aead::{Aead, AeadCore, KeyInit, OsRng},
Aes256Gcm, Key, Nonce,
};
use anyhow::{anyhow, Context, Result};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use std::env;
use std::sync::OnceLock;
static RAW_KEY: OnceLock<[u8; 32]> = OnceLock::new();
/// Initialize the encryption key from a hex or base64 encoded string.
/// Must be called before any encryption/decryption operations.
/// Returns error if the key is invalid or already initialized with a different key.
pub fn init_with_key(key_str: &str) -> Result<()> {
let key_bytes = hex::decode(key_str)
.or_else(|_| BASE64.decode(key_str))
.context("Encryption key must be hex or base64 encoded")?;
if key_bytes.len() != 32 {
anyhow::bail!(
"Encryption key must be 32 bytes (256 bits), got {} bytes",
key_bytes.len()
);
}
let key_array: [u8; 32] = key_bytes.try_into().unwrap(); // safe due to length check
// Check if already initialized with same key
if let Some(existing) = RAW_KEY.get() {
if existing == &key_array {
// Same key already initialized, okay
return Ok(());
} else {
anyhow::bail!("Encryption key already initialized with a different key");
}
}
// Store raw key bytes
RAW_KEY
.set(key_array)
.map_err(|_| anyhow::anyhow!("Encryption key already initialized"))?;
Ok(())
}
/// Initialize the encryption key from the environment variable `LLM_PROXY__ENCRYPTION_KEY`.
/// Must be called before any encryption/decryption operations.
/// Panics if the environment variable is missing or invalid.
pub fn init_from_env() -> Result<()> {
let key_str =
env::var("LLM_PROXY__ENCRYPTION_KEY").context("LLM_PROXY__ENCRYPTION_KEY environment variable not set")?;
init_with_key(&key_str)
}
/// Get the encryption key bytes, panicking if not initialized.
fn get_key() -> &'static [u8; 32] {
RAW_KEY
.get()
.expect("Encryption key not initialized. Call crypto::init_with_key() or crypto::init_from_env() first.")
}
/// Encrypt a plaintext string and return a base64-encoded ciphertext (nonce || ciphertext || tag).
pub fn encrypt(plaintext: &str) -> Result<String> {
let key = Key::<Aes256Gcm>::from_slice(get_key());
let cipher = Aes256Gcm::new(key);
let nonce = Aes256Gcm::generate_nonce(&mut OsRng); // 12 bytes
let ciphertext = cipher
.encrypt(&nonce, plaintext.as_bytes())
.map_err(|e| anyhow!("Encryption failed: {}", e))?;
// Combine nonce and ciphertext (ciphertext already includes tag)
let mut combined = Vec::with_capacity(nonce.len() + ciphertext.len());
combined.extend_from_slice(&nonce);
combined.extend_from_slice(&ciphertext);
Ok(BASE64.encode(combined))
}
/// Decrypt a base64-encoded ciphertext (nonce || ciphertext || tag) to a plaintext string.
pub fn decrypt(ciphertext_b64: &str) -> Result<String> {
let key = Key::<Aes256Gcm>::from_slice(get_key());
let cipher = Aes256Gcm::new(key);
let combined = BASE64.decode(ciphertext_b64).context("Invalid base64 ciphertext")?;
if combined.len() < 12 {
anyhow::bail!("Ciphertext too short");
}
let (nonce_bytes, ciphertext_and_tag) = combined.split_at(12);
let nonce = Nonce::from_slice(nonce_bytes);
let plaintext_bytes = cipher
.decrypt(nonce, ciphertext_and_tag)
.map_err(|e| anyhow!("Decryption failed (invalid key or corrupted ciphertext): {}", e))?;
String::from_utf8(plaintext_bytes).context("Decrypted bytes are not valid UTF-8")
}
#[cfg(test)]
mod tests {
use super::*;
const TEST_KEY_HEX: &str = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f";
#[test]
fn test_encrypt_decrypt() {
init_with_key(TEST_KEY_HEX).unwrap();
let plaintext = "super secret api key";
let ciphertext = encrypt(plaintext).unwrap();
assert_ne!(ciphertext, plaintext);
let decrypted = decrypt(&ciphertext).unwrap();
assert_eq!(decrypted, plaintext);
}
#[test]
fn test_different_inputs_produce_different_ciphertexts() {
init_with_key(TEST_KEY_HEX).unwrap();
let plaintext = "same";
let cipher1 = encrypt(plaintext).unwrap();
let cipher2 = encrypt(plaintext).unwrap();
assert_ne!(cipher1, cipher2, "Nonce should make ciphertexts differ");
assert_eq!(decrypt(&cipher1).unwrap(), plaintext);
assert_eq!(decrypt(&cipher2).unwrap(), plaintext);
}
#[test]
fn test_invalid_key_length() {
let result = init_with_key("tooshort");
assert!(result.is_err());
}
#[test]
fn test_init_from_env() {
unsafe { std::env::set_var("LLM_PROXY__ENCRYPTION_KEY", TEST_KEY_HEX) };
let result = init_from_env();
assert!(result.is_ok());
// Ensure encryption works
let ciphertext = encrypt("test").unwrap();
let decrypted = decrypt(&ciphertext).unwrap();
assert_eq!(decrypted, "test");
}
#[test]
fn test_missing_env_key() {
unsafe { std::env::remove_var("LLM_PROXY__ENCRYPTION_KEY") };
let result = init_from_env();
assert!(result.is_err());
}
#[test]
fn test_key_hex_and_base64() {
// Hex key works
init_with_key(TEST_KEY_HEX).unwrap();
// Base64 key (same bytes encoded as base64)
let base64_key = BASE64.encode(hex::decode(TEST_KEY_HEX).unwrap());
// Re-initialization with same key (different encoding) is allowed
let result = init_with_key(&base64_key);
assert!(result.is_ok());
// Encryption should still work
let ciphertext = encrypt("test").unwrap();
let decrypted = decrypt(&ciphertext).unwrap();
assert_eq!(decrypted, "test");
}
#[test]
#[ignore] // conflicts with global state from other tests
fn test_already_initialized() {
init_with_key(TEST_KEY_HEX).unwrap();
let result = init_with_key(TEST_KEY_HEX);
assert!(result.is_ok()); // same key allowed
}
#[test]
#[ignore] // conflicts with global state from other tests
fn test_already_initialized_different_key() {
init_with_key(TEST_KEY_HEX).unwrap();
let different_key = "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e20";
let result = init_with_key(different_key);
assert!(result.is_err());
}
}
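`init_with_key` accepts the key as either hex or base64 and rejects anything that does not decode to exactly 32 bytes. A dependency-free sketch of just the hex path (the module above uses the `hex` and `base64` crates and also accepts base64 input):

```rust
/// Decode a 64-character hex string into a 32-byte AES-256 key,
/// mirroring the length check in init_with_key. Hex only in this sketch.
fn decode_hex_key(s: &str) -> Result<[u8; 32], String> {
    if s.len() != 64 {
        return Err(format!("expected 64 hex chars (32 bytes), got {}", s.len()));
    }
    let mut key = [0u8; 32];
    for (i, chunk) in s.as_bytes().chunks(2).enumerate() {
        // Each byte is two hex digits
        let pair = std::str::from_utf8(chunk).map_err(|e| e.to_string())?;
        key[i] = u8::from_str_radix(pair, 16).map_err(|e| e.to_string())?;
    }
    Ok(key)
}
```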

View File

@@ -1,4 +0,0 @@
pub mod crypto;
pub mod registry;
pub mod streaming;
pub mod tokens;

View File

@@ -1,49 +0,0 @@
use crate::models::registry::ModelRegistry;
use anyhow::Result;
use tracing::info;
const MODELS_DEV_URL: &str = "https://models.dev/api.json";
pub async fn fetch_registry() -> Result<ModelRegistry> {
info!("Fetching model registry from {}", MODELS_DEV_URL);
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.build()?;
let response = client.get(MODELS_DEV_URL).send().await?;
if !response.status().is_success() {
return Err(anyhow::anyhow!("Failed to fetch registry: HTTP {}", response.status()));
}
let registry: ModelRegistry = response.json().await?;
info!("Successfully loaded model registry");
Ok(registry)
}
/// Heuristic: decide whether a model should be routed to OpenAI Responses API
/// instead of the legacy chat/completions endpoint.
///
/// Currently this uses simple patterns (codex, gpt-5 series) and also checks
/// the loaded registry metadata name for the substring "codex" as a hint.
pub fn model_prefers_responses(registry: &ModelRegistry, model: &str) -> bool {
let model_lc = model.to_lowercase();
if model_lc.contains("codex") {
return true;
}
if model_lc.starts_with("gpt-5") || model_lc.contains("gpt-5.") {
return true;
}
if let Some(meta) = registry.find_model(model) {
if meta.name.to_lowercase().contains("codex") {
return true;
}
}
false
}
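Stripped of the registry lookup, the routing heuristic above reduces to simple name patterns. A name-only sketch:

```rust
/// Name-only version of model_prefers_responses: route codex and gpt-5
/// family models to the Responses API. The full function above also
/// consults registry metadata for a "codex" hint in the model's name.
fn prefers_responses(model: &str) -> bool {
    let m = model.to_lowercase();
    m.contains("codex") || m.starts_with("gpt-5") || m.contains("gpt-5.")
}
```

The `contains("gpt-5.")` clause catches prefixed identifiers such as a hypothetical `openai/gpt-5.1`, which `starts_with` alone would miss.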

View File

@@ -1,340 +0,0 @@
use crate::errors::AppError;
use crate::logging::{RequestLog, RequestLogger};
use crate::models::ToolCall;
use crate::providers::{Provider, ProviderStreamChunk, StreamUsage};
use crate::state::ModelConfigCache;
use crate::utils::tokens::estimate_completion_tokens;
use futures::stream::Stream;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
/// Configuration for creating an AggregatingStream.
pub struct StreamConfig {
pub client_id: String,
pub provider: Arc<dyn Provider>,
pub model: String,
pub prompt_tokens: u32,
pub has_images: bool,
pub logger: Arc<RequestLogger>,
pub model_registry: Arc<crate::models::registry::ModelRegistry>,
pub model_config_cache: ModelConfigCache,
}
pub struct AggregatingStream<S> {
inner: S,
client_id: String,
provider: Arc<dyn Provider>,
model: String,
prompt_tokens: u32,
has_images: bool,
accumulated_content: String,
accumulated_reasoning: String,
accumulated_tool_calls: Vec<ToolCall>,
/// Real usage data from the provider's final stream chunk (when available).
real_usage: Option<StreamUsage>,
logger: Arc<RequestLogger>,
model_registry: Arc<crate::models::registry::ModelRegistry>,
model_config_cache: ModelConfigCache,
start_time: std::time::Instant,
has_logged: bool,
}
impl<S> AggregatingStream<S>
where
S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin,
{
pub fn new(inner: S, config: StreamConfig) -> Self {
Self {
inner,
client_id: config.client_id,
provider: config.provider,
model: config.model,
prompt_tokens: config.prompt_tokens,
has_images: config.has_images,
accumulated_content: String::new(),
accumulated_reasoning: String::new(),
accumulated_tool_calls: Vec::new(),
real_usage: None,
logger: config.logger,
model_registry: config.model_registry,
model_config_cache: config.model_config_cache,
start_time: std::time::Instant::now(),
has_logged: false,
}
}
fn finalize(&mut self) {
if self.has_logged {
return;
}
self.has_logged = true;
let duration = self.start_time.elapsed();
let client_id = self.client_id.clone();
let provider_name = self.provider.name().to_string();
let model = self.model.clone();
let logger = self.logger.clone();
let provider = self.provider.clone();
let estimated_prompt_tokens = self.prompt_tokens;
let has_images = self.has_images;
let registry = self.model_registry.clone();
let config_cache = self.model_config_cache.clone();
let real_usage = self.real_usage.take();
// Estimate completion tokens (including reasoning if present)
let estimated_content_tokens = estimate_completion_tokens(&self.accumulated_content, &model);
let estimated_reasoning_tokens = if !self.accumulated_reasoning.is_empty() {
estimate_completion_tokens(&self.accumulated_reasoning, &model)
} else {
0
};
let estimated_completion = estimated_content_tokens + estimated_reasoning_tokens;
// Spawn a background task to log the completion
tokio::spawn(async move {
// Use real usage from the provider when available, otherwise fall back to estimates
let (prompt_tokens, completion_tokens, reasoning_tokens, total_tokens, cache_read_tokens, cache_write_tokens) =
if let Some(usage) = &real_usage {
(
usage.prompt_tokens,
usage.completion_tokens,
usage.reasoning_tokens,
usage.total_tokens,
usage.cache_read_tokens,
usage.cache_write_tokens,
)
} else {
(
estimated_prompt_tokens,
estimated_completion,
estimated_reasoning_tokens,
estimated_prompt_tokens + estimated_completion,
0u32,
0u32,
)
};
// Check in-memory cache for cost overrides (no SQLite hit)
let cost = if let Some(cached) = config_cache.get(&model).await {
if let (Some(p), Some(c)) = (cached.prompt_cost_per_m, cached.completion_cost_per_m) {
// Manual cost overrides: bill non-cached prompt and completion tokens first, then add cache reads/writes at their configured rates.
let non_cached_prompt = prompt_tokens.saturating_sub(cache_read_tokens);
let mut total = (non_cached_prompt as f64 * p / 1_000_000.0) + (completion_tokens as f64 * c / 1_000_000.0);
if let Some(cr) = cached.cache_read_cost_per_m {
total += cache_read_tokens as f64 * cr / 1_000_000.0;
} else {
// Charge cached tokens at full input rate if no specific rate provided
total += cache_read_tokens as f64 * p / 1_000_000.0;
}
if let Some(cw) = cached.cache_write_cost_per_m {
total += cache_write_tokens as f64 * cw / 1_000_000.0;
}
total
} else {
provider.calculate_cost(
&model,
prompt_tokens,
completion_tokens,
cache_read_tokens,
cache_write_tokens,
&registry,
)
}
} else {
provider.calculate_cost(
&model,
prompt_tokens,
completion_tokens,
cache_read_tokens,
cache_write_tokens,
&registry,
)
};
// Log to database
logger.log_request(RequestLog {
timestamp: chrono::Utc::now(),
client_id: client_id.clone(),
provider: provider_name,
model,
prompt_tokens,
completion_tokens,
reasoning_tokens,
total_tokens,
cache_read_tokens,
cache_write_tokens,
cost,
has_images,
status: "success".to_string(),
error_message: None,
duration_ms: duration.as_millis() as u64,
});
});
}
}
impl<S> Stream for AggregatingStream<S>
where
S: Stream<Item = Result<ProviderStreamChunk, AppError>> + Unpin,
{
type Item = Result<ProviderStreamChunk, AppError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let result = Pin::new(&mut self.inner).poll_next(cx);
match &result {
Poll::Ready(Some(Ok(chunk))) => {
self.accumulated_content.push_str(&chunk.content);
if let Some(reasoning) = &chunk.reasoning_content {
self.accumulated_reasoning.push_str(reasoning);
}
// Capture real usage from the provider when present (typically on the final chunk)
if let Some(usage) = &chunk.usage {
self.real_usage = Some(usage.clone());
}
// Accumulate tool call deltas into complete tool calls
if let Some(deltas) = &chunk.tool_calls {
for delta in deltas {
let idx = delta.index as usize;
// Grow the accumulated_tool_calls vec if needed
while self.accumulated_tool_calls.len() <= idx {
self.accumulated_tool_calls.push(ToolCall {
id: String::new(),
call_type: "function".to_string(),
function: crate::models::FunctionCall {
name: String::new(),
arguments: String::new(),
},
});
}
let tc = &mut self.accumulated_tool_calls[idx];
if let Some(id) = &delta.id {
tc.id.clone_from(id);
}
if let Some(ct) = &delta.call_type {
tc.call_type.clone_from(ct);
}
if let Some(f) = &delta.function {
if let Some(name) = &f.name {
tc.function.name.push_str(name);
}
if let Some(args) = &f.arguments {
tc.function.arguments.push_str(args);
}
}
}
}
}
Poll::Ready(Some(Err(_))) => {
// On upstream error, finalize with whatever content accumulated so far
// so partial usage is still logged; skip if nothing arrived yet.
if !self.accumulated_content.is_empty() {
self.finalize();
}
}
Poll::Ready(None) => {
self.finalize();
}
Poll::Pending => {}
}
result
}
}
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
use futures::stream::{self, StreamExt};
// Simple mock provider for testing
struct MockProvider;
#[async_trait::async_trait]
impl Provider for MockProvider {
fn name(&self) -> &str {
"mock"
}
fn supports_model(&self, _model: &str) -> bool {
true
}
fn supports_multimodal(&self) -> bool {
false
}
async fn chat_completion(
&self,
_req: crate::models::UnifiedRequest,
) -> Result<crate::providers::ProviderResponse, AppError> {
unimplemented!()
}
async fn chat_completion_stream(
&self,
_req: crate::models::UnifiedRequest,
) -> Result<futures::stream::BoxStream<'static, Result<ProviderStreamChunk, AppError>>, AppError> {
unimplemented!()
}
fn estimate_tokens(&self, _req: &crate::models::UnifiedRequest) -> Result<u32> {
Ok(10)
}
fn calculate_cost(&self, _model: &str, _p: u32, _c: u32, _cr: u32, _cw: u32, _r: &crate::models::registry::ModelRegistry) -> f64 {
0.05
}
}
#[tokio::test]
async fn test_aggregating_stream() {
let chunks = vec![
Ok(ProviderStreamChunk {
content: "Hello".to_string(),
reasoning_content: None,
finish_reason: None,
tool_calls: None,
model: "test".to_string(),
usage: None,
}),
Ok(ProviderStreamChunk {
content: " World".to_string(),
reasoning_content: None,
finish_reason: Some("stop".to_string()),
tool_calls: None,
model: "test".to_string(),
usage: None,
}),
];
let inner_stream = stream::iter(chunks);
let pool = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap();
let (dashboard_tx, _) = tokio::sync::broadcast::channel(16);
let logger = Arc::new(RequestLogger::new(pool.clone(), dashboard_tx));
let registry = Arc::new(crate::models::registry::ModelRegistry {
providers: std::collections::HashMap::new(),
});
let mut agg_stream = AggregatingStream::new(
inner_stream,
StreamConfig {
client_id: "client_1".to_string(),
provider: Arc::new(MockProvider),
model: "test".to_string(),
prompt_tokens: 10,
has_images: false,
logger,
model_registry: registry,
model_config_cache: ModelConfigCache::new(pool.clone()),
},
);
while let Some(item) = agg_stream.next().await {
assert!(item.is_ok());
}
assert_eq!(agg_stream.accumulated_content, "Hello World");
assert!(agg_stream.has_logged);
}
}
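The manual-override branch in `finalize` above is a cache-aware cost formula: non-cached prompt tokens at the prompt rate, cache reads at their own rate when one is configured (else the full prompt rate), cache writes only when a write rate is set. Isolated as a pure function (the rates here are hypothetical USD-per-million-token prices, not real provider pricing):

```rust
/// Cache-aware cost calculation matching the override branch in finalize.
/// All rates are USD per million tokens.
fn cache_aware_cost(
    prompt: u32, completion: u32, cache_read: u32, cache_write: u32,
    prompt_rate: f64, completion_rate: f64,
    read_rate: Option<f64>, write_rate: Option<f64>,
) -> f64 {
    const M: f64 = 1_000_000.0;
    // Cached reads are carved out of the prompt count and billed separately
    let non_cached = prompt.saturating_sub(cache_read);
    let mut total = non_cached as f64 * prompt_rate / M
        + completion as f64 * completion_rate / M;
    // No read rate configured: charge cached tokens at the full input rate
    total += cache_read as f64 * read_rate.unwrap_or(prompt_rate) / M;
    if let Some(w) = write_rate {
        total += cache_write as f64 * w / M;
    }
    total
}
```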

View File

@@ -1,69 +0,0 @@
use crate::models::UnifiedRequest;
use tiktoken_rs::get_bpe_from_model;
/// Count tokens for a given model and text
pub fn count_tokens(model: &str, text: &str) -> u32 {
// If we can't get the bpe for the model, fallback to a safe default (cl100k_base for GPT-4/o1)
let bpe = get_bpe_from_model(model)
.unwrap_or_else(|_| tiktoken_rs::cl100k_base().expect("Failed to get cl100k_base encoding"));
bpe.encode_with_special_tokens(text).len() as u32
}
/// Estimate tokens for a unified request.
/// Uses exact tiktoken counts for small inputs and a fast chars/4 heuristic
/// for large ones, to avoid blocking the async runtime on big prompts.
pub fn estimate_request_tokens(model: &str, request: &UnifiedRequest) -> u32 {
let mut total_text = String::new();
let msg_count = request.messages.len();
// Base tokens per message for OpenAI (approximate)
let tokens_per_message: u32 = 3;
for msg in &request.messages {
for part in &msg.content {
match part {
crate::models::ContentPart::Text { text } => {
total_text.push_str(text);
total_text.push('\n');
}
crate::models::ContentPart::Image { .. } => {
// Image tokens are added separately below (flat 1000-token estimate each)
}
}
}
}
// Exact tiktoken count for small inputs (< 1KB): cheap enough to run inline
if total_text.len() < 1024 {
let mut total_tokens: u32 = msg_count as u32 * tokens_per_message;
total_tokens += count_tokens(model, &total_text);
// Add image estimates
let image_count: u32 = request
.messages
.iter()
.flat_map(|m| m.content.iter())
.filter(|p| matches!(p, crate::models::ContentPart::Image { .. }))
.count() as u32;
total_tokens += image_count * 1000;
total_tokens += 3; // assistant reply header
return total_tokens;
}
// For large inputs, use the fast heuristic (chars / 4) to avoid blocking
// the async runtime. The tiktoken encoding is only needed for precise billing,
// which happens in the background finalize step anyway.
let estimated_text_tokens = (total_text.len() as u32) / 4;
let image_count: u32 = request
.messages
.iter()
.flat_map(|m| m.content.iter())
.filter(|p| matches!(p, crate::models::ContentPart::Image { .. }))
.count() as u32;
(msg_count as u32 * tokens_per_message) + estimated_text_tokens + (image_count * 1000) + 3
}
/// Estimate tokens for completion text
pub fn estimate_completion_tokens(text: &str, model: &str) -> u32 {
count_tokens(model, text)
}
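The large-input fallback in `estimate_request_tokens` is the common chars/4 approximation plus the same flat per-message, per-image, and reply-header overheads. As a standalone sketch:

```rust
/// Fast token estimate for large requests: roughly one token per four
/// characters of English text, plus per-message overhead, a flat 1000
/// tokens per image, and a small reply-header allowance, mirroring the
/// constants used by estimate_request_tokens.
fn rough_tokens(text: &str, messages: u32, images: u32) -> u32 {
    const TOKENS_PER_MESSAGE: u32 = 3;
    const TOKENS_PER_IMAGE: u32 = 1000;
    const REPLY_HEADER: u32 = 3;
    (text.len() as u32) / 4
        + messages * TOKENS_PER_MESSAGE
        + images * TOKENS_PER_IMAGE
        + REPLY_HEADER
}
```

This intentionally overshoots for non-English or whitespace-heavy text; precise counts are only needed for billing, which the background finalize step recomputes anyway.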

View File

@@ -1,38 +0,0 @@
#!/bin/bash
# Test script for LLM Proxy Dashboard
echo "Building LLM Proxy Gateway..."
cargo build --release
echo ""
echo "Starting server in background..."
./target/release/llm-proxy &
SERVER_PID=$!
# Stop the background server when the script exits (including Ctrl+C)
trap 'kill $SERVER_PID 2>/dev/null' EXIT
# Wait for server to start
sleep 3
echo ""
echo "Testing dashboard endpoints..."
# Test health endpoint
echo "1. Testing health endpoint:"
curl -s http://localhost:8080/health
echo ""
echo "2. Testing dashboard static files:"
curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/
echo ""
echo "3. Testing API endpoints:"
curl -s http://localhost:8080/api/auth/status | jq . 2>/dev/null || echo "JSON response received"
echo ""
echo "Dashboard should be available at: http://localhost:8080"
echo "Default login: admin / admin"
echo ""
echo "Press Ctrl+C to stop the server"
# Keep script running
wait $SERVER_PID

View File

@@ -1,75 +0,0 @@
#!/bin/bash
# Test script for LLM Proxy Gateway
echo "Building LLM Proxy Gateway..."
cargo build --release
if [ $? -ne 0 ]; then
echo "Build failed!"
exit 1
fi
echo "Build successful!"
echo ""
echo "Project Structure Summary:"
echo "=========================="
echo "Core Components:"
echo " - main.rs: Application entry point with server setup"
echo " - config/: Configuration management"
echo " - server/: API route handlers"
echo " - auth/: Bearer token authentication"
echo " - database/: SQLite database setup"
echo " - models/: Data structures (OpenAI-compatible)"
echo " - providers/: LLM provider implementations (OpenAI, Gemini, DeepSeek, Grok)"
echo " - errors/: Custom error types"
echo " - dashboard/: Admin dashboard with WebSocket support"
echo " - logging/: Request logging middleware"
echo " - state/: Shared application state"
echo " - multimodal/: Image processing support (basic structure)"
echo ""
echo "Key Features Implemented:"
echo "=========================="
echo "✓ OpenAI-compatible API endpoint (/v1/chat/completions)"
echo "✓ Bearer token authentication"
echo "✓ SQLite database for request tracking"
echo "✓ Request logging with token/cost calculation"
echo "✓ Provider abstraction layer"
echo "✓ Admin dashboard with real-time monitoring"
echo "✓ WebSocket support for live updates"
echo "✓ Configuration management (config.toml, .env, env vars)"
echo "✓ Multimodal support structure (images)"
echo "✓ Error handling with proper HTTP status codes"
echo ""
echo "Next Steps Needed:"
echo "=================="
echo "1. Add API keys to .env file:"
echo " OPENAI_API_KEY=your_key_here"
echo " GEMINI_API_KEY=your_key_here"
echo " DEEPSEEK_API_KEY=your_key_here"
echo " GROK_API_KEY=your_key_here (optional)"
echo ""
echo "2. Create config.toml for custom configuration (optional)"
echo ""
echo "3. Run the server:"
echo " cargo run"
echo ""
echo "4. Access dashboard at: http://localhost:8080"
echo ""
echo "5. Test API with curl:"
echo " curl -X POST http://localhost:8080/v1/chat/completions \\"
echo " -H 'Authorization: Bearer your_token' \\"
echo " -H 'Content-Type: application/json' \\"
echo " -d '{\"model\": \"gpt-4\", \"messages\": [{\"role\": \"user\", \"content\": \"Hello\"}]}'"
echo ""
echo "Deployment Notes:"
echo "================="
echo "Memory: Designed for 512MB RAM (LXC container)"
echo "Database: SQLite (./data/llm_proxy.db)"
echo "Port: 8080 (configurable)"
echo "Authentication: Single Bearer token (configurable)"
echo "Providers: OpenAI, Gemini, DeepSeek, Grok (disabled by default)"